In [8]:
import time
import os
import pickle

import albumentations as A
import numpy as np
import pandas as pd
from pycocotools.coco import COCO
from tqdm import tqdm

In [2]:
# Per-model inputs to ensemble: pickled uint8 logit arrays, one file per model.
pkl_file_paths = [
    './upernet_swin_albu_mb_gd_40k_logit.pkl',
    './ocrnet_hr48_40k_logit.pkl',
    './deeplabv3+_se_resnext101_32x4d_ALL_40_logit.pkl',
]

# Outputs: the submission CSV and the ensembled logits (re-quantized to uint8).
out_csv_file_path = "./submission/swin_hrnet_deep3.csv"
out_pkl_file_path = "./saved/swin_hrnet_deep3.pkl"

# If True, write a CSV meant for pseudo-labelling at the full 512 x 512
# resolution instead of the resized submission format.
for_pseudo = False

In [3]:
def load_pkl(pkl_file_path):
    """Read a pickled uint8 score array and rescale it to float32.

    Args:
        pkl_file_path: path to a pickle file holding a numpy array, presumably
            uint8 in 0..255 (as written by this notebook's ensemble cell).

    Returns:
        A float32 array of the same shape with every value divided by 255,
        i.e. in [0, 1] when the stored array is uint8.

    Note: pickle.load can execute arbitrary code while unpickling; only load
    files this pipeline produced itself.
    """
    with open(pkl_file_path, 'rb') as stream:
        quantized = pickle.load(stream)
    return quantized.astype(dtype=np.float32) / 255.

In [4]:
# Ensemble the models by averaging their per-pixel class scores.
# A running sum keeps at most one extra float32 array alive at a time. Stacking
# every model with np.array(list_of_arrays) would hold all of them plus a stacked
# copy: at the printed shape (819, 11, 512, 512) each float32 array is ~9.4 GB.
if not pkl_file_paths:
    raise ValueError("pkl_file_paths is empty; nothing to ensemble")

avg_logit = None
for pkl_file_path in pkl_file_paths:
    logit = load_pkl(pkl_file_path)
    print(logit.shape)
    if avg_logit is None:
        # load_pkl returns a freshly allocated array, so accumulating into it is safe.
        avg_logit = logit
    elif logit.shape != avg_logit.shape:
        # Broadcasting could otherwise silently combine incompatible maps.
        raise ValueError(
            f"{pkl_file_path} has shape {logit.shape}, expected {avg_logit.shape}"
        )
    else:
        avg_logit += logit
    del logit
avg_logit /= len(pkl_file_paths)  # running sum -> mean, in place

# Re-quantize to uint8 on the same 0..255 scale that load_pkl reads back.
# Round rather than truncate: astype(np.uint8) alone drops the fraction, biasing
# averaged values downward (a mean of 100.67 would be stored as 100) and turning
# float error such as 199.99998 into 199. Clamp first as a safety net.
logit_to_pickle = avg_logit * 255.
np.clip(logit_to_pickle, 0., 255., out=logit_to_pickle)
np.rint(logit_to_pickle, out=logit_to_pickle)
logit_to_pickle = logit_to_pickle.astype(np.uint8)

# The array is built before the file is opened, so a failure above cannot leave
# a truncated pickle behind; create the output directory if it is missing.
os.makedirs(os.path.dirname(out_pkl_file_path) or '.', exist_ok=True)
with open(out_pkl_file_path, 'wb') as pkl_file:
    pickle.dump(logit_to_pickle, pkl_file, protocol=4)


(819, 11, 512, 512)
(819, 11, 512, 512)
(819, 11, 512, 512)


In [5]:
# Hard per-pixel prediction: index of the highest averaged score along axis 1,
# the class axis of the printed (N, 11, H, W) shape, giving an (N, H, W) array.
soft_pred = np.argmax(avg_logit, axis=1)

In [9]:
# Nearest-neighbour resize (interpolation=0 is cv2.INTER_NEAREST) so class ids
# are copied, never blended into meaningless in-between values.
resize_transform = A.Compose([A.Resize(256, 256, interpolation=0)])

coco = COCO("/opt/ml/segmentation/input/data/test.json")

# Row i of soft_pred is paired with the COCO image whose id is i. That assumes
# the pickles were written in test.json id order (TODO confirm); at minimum the
# counts must agree, otherwise rows would be matched to the wrong files.
num_test_images = len(coco.getImgIds())
if num_test_images != len(soft_pred):
    raise ValueError(
        f"soft_pred has {len(soft_pred)} masks but test.json lists {num_test_images} images"
    )

prediction_strings = list()
file_names = list()
for image_id, pred_mask in enumerate(tqdm(soft_pred, desc="encoding")):
    file_names.append(coco.loadImgs(image_id)[0]['file_name'])

    # for_pseudo keeps the full-resolution mask (512 x 512 for pseudo-labelling);
    # otherwise downsample to the 256 x 256 submission size.
    out_fit_array = pred_mask if for_pseudo else resize_transform(image=pred_mask)['image']

    # Space-separated class ids in row-major order.
    prediction_strings.append(' '.join(map(str, out_fit_array.ravel().tolist())))

os.makedirs(os.path.dirname(out_csv_file_path) or '.', exist_ok=True)
submission = pd.DataFrame({'image_id': file_names, 'PredictionString': prediction_strings})
submission.to_csv(out_csv_file_path, index=False)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
