# 1. Import

In [None]:
# python native
import os
import json
import random
import datetime
from functools import partial

# external library
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import albumentations as A

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models

# visualization
import matplotlib.pyplot as plt

# 2. Model

In [None]:
# # Install the library from within Jupyter
# !pip install git+https://github.com/qubvel/segmentation_models.pytorch
# import segmentation_models_pytorch as smp

# # Load the model
# # Set the number of output labels (classes=29)
# model = smp.UnetPlusPlus(
#     encoder_name="tu-densenet201", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=29,                     # model output channels (number of classes in your dataset)
# )

# 3. Path setting

In [None]:
PT_PATH = "/opt/ml/level2_cv_semanticsegmentation-cv-15/codebook/ensemble_pt/"

# Collect every PyTorch checkpoint (``*.pt``) in PT_PATH. ``endswith('.pt')``
# avoids matching names such as ``*.ckpt``; sorting makes the ensemble order
# (and therefore the floating-point summation order) reproducible.
pt_list = sorted(pt for pt in os.listdir(PT_PATH) if pt.endswith('.pt'))
# pt_list = ["UNetplusplus_densenet_1024_BEST_MODEL.pt"]
num_pt = len(pt_list)
print(pt_list)

# 4. Dataset

In [None]:
# Segmentation classes in model-output channel order: the 19 finger classes
# ("finger-1" ... "finger-19") followed by ten named bone classes.
CLASSES = [f'finger-{i}' for i in range(1, 20)] + [
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
    'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

In [None]:
# Class name -> channel index (position in CLASSES).
CLASS2IND = dict(zip(CLASSES, range(len(CLASSES))))

In [None]:
# Channel index -> class name (inverse of CLASS2IND).
IND2CLASS = dict(zip(CLASS2IND.values(), CLASS2IND.keys()))

In [None]:
# Enter the path to the test data (root directory searched for .png images).
IMAGE_ROOT = "/opt/ml/input/data/test/DCM"

In [None]:
def _list_pngs(image_root):
    """Return the set of ``.png`` file paths under *image_root*, relative to it.

    ``.ipynb_checkpoints`` directories are skipped so Jupyter checkpoint copies
    are not picked up as extra test images.
    """
    found = set()
    for root, dirs, files in os.walk(image_root):
        # Prune in place so os.walk never descends into checkpoint folders.
        dirs[:] = [d for d in dirs if d != ".ipynb_checkpoints"]
        for fname in files:
            if os.path.splitext(fname)[1].lower() == ".png":
                found.add(os.path.relpath(os.path.join(root, fname), start=image_root))
    return found


pngs = _list_pngs(IMAGE_ROOT)

In [None]:
# Encode the mask maps produced by inference as run-length encoding (RLE).

def encode_mask_to_rle(mask):
    '''
    Run-length encode a binary mask.

    mask: numpy array, 1 (or True) = foreground, 0 = background.
    Returns a space-separated string of alternating 1-based start positions
    and run lengths over the row-major flattened mask ("" for an empty mask).
    '''
    # Pad with zeros on both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([0], mask.flatten(), [0]))
    edges = np.where(padded[1:] != padded[:-1])[0] + 1

    tokens = []
    for start, stop in zip(edges[0::2], edges[1::2]):
        tokens.append(str(start))
        tokens.append(str(stop - start))
    return ' '.join(tokens)

In [None]:
# Restore an RLE-encoded result back into a mask map.

def decode_rle_to_mask(rle, height, width):
    """Decode a space-separated RLE string into a (height, width) uint8 mask.

    The string alternates 1-based start positions and run lengths over the
    row-major flattened image, as produced by ``encode_mask_to_rle``.
    """
    tokens = rle.split()
    starts = np.asarray(tokens[0::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)
    starts -= 1  # RLE positions are 1-based
    stops = starts + lengths

    flat = np.zeros(height * width, dtype=np.uint8)
    for begin, end in zip(starts, stops):
        flat[begin:end] = 1

    return flat.reshape(height, width)

In [None]:
class XRayInferenceDataset(Dataset):
    """Test-time dataset over the PNG paths in ``pngs`` (relative to IMAGE_ROOT).

    Each item is ``(image, image_name)``: ``image`` is a float32 channel-first
    tensor of the (optionally transformed) image divided by 255, and
    ``image_name`` is the path relative to IMAGE_ROOT.
    """

    def __init__(self, transforms=None):
        # Sorted so batches (and therefore the output rows) come out in a stable order.
        self.filenames = np.array(sorted(pngs))
        self.transforms = transforms

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        image_name = self.filenames[item]
        image_path = os.path.join(IMAGE_ROOT, image_name)

        # cv2.imread returns None (it does not raise) when the file can't be read;
        # fail here with a clear message instead of a TypeError further down.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        if self.transforms is not None:
            image = self.transforms(image=image)["image"]

        image = image / 255.

        # HWC -> CHW (channel first); tensor conversion happens right after.
        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image).float()

        return image, image_name

In [None]:
# Test-time preprocessing: always apply CLAHE, then resize to 1024x1024
# (test() later upsamples predictions back to 2048x2048).
tf = A.Compose([A.CLAHE(p=1.0), A.Resize(height=1024, width=1024)])

In [None]:
# Test dataset using the preprocessing pipeline `tf` defined above.
test_dataset = XRayInferenceDataset(transforms=tf)

In [None]:
# Fixed order (no shuffling): test() runs the loader once per model and relies on
# batch `step` lining up across passes.
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4, drop_last=False)

# 5. Ensemble

In [None]:
def _accumulate_logits(data_loader):
    """Run every checkpoint in ``pt_list`` over ``data_loader`` and sum logits per batch.

    Returns ``(logit_sums, batch_names)``: ``logit_sums[i]`` is the float32 CPU
    tensor of model outputs for batch ``i`` summed over all checkpoints, and
    ``batch_names[i]`` is the list of image names in that batch.
    """
    logit_sums = []
    batch_names = []
    for model_idx, pt_name in enumerate(pt_list):
        # torch.load unpickles a whole model object; only load trusted checkpoints.
        model = torch.load(os.path.join(PT_PATH, pt_name))
        model = model.cuda()
        model.eval()

        with torch.no_grad():
            for step, (images, image_names) in tqdm(enumerate(data_loader), total=len(data_loader)):
                logits = model(images.cuda()).float().cpu()
                if model_idx == 0:
                    # The first model creates the buffers, so their shapes follow the
                    # real batch sizes (including a smaller final batch).
                    logit_sums.append(logits)
                    batch_names.append(list(image_names))
                else:
                    logit_sums[step] += logits

        # Free this model's GPU memory before loading the next checkpoint.
        del model
        torch.cuda.empty_cache()

    return logit_sums, batch_names


def _encode_predictions(logit_sums, batch_names, n_models, thr):
    """Average summed logits, upsample to 2048x2048, threshold, and RLE-encode.

    Returns ``(rles, filename_and_class)`` with one entry per (image, class);
    each label is ``"<class>_<image name>"``.
    """
    rles = []
    filename_and_class = []
    for summed, image_names in tqdm(zip(logit_sums, batch_names), total=len(logit_sums)):
        outputs = summed / n_models
        # restore original size
        outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
        outputs = torch.sigmoid(outputs)
        masks = (outputs > thr).detach().cpu().numpy()

        for mask, image_name in zip(masks, image_names):
            for c, segm in enumerate(mask):
                rles.append(encode_mask_to_rle(segm))
                filename_and_class.append(f"{IND2CLASS[c]}_{image_name}")

    return rles, filename_and_class


def test(data_loader, thr=0.5):
    """Ensemble every checkpoint in ``pt_list`` on ``data_loader`` and RLE-encode masks.

    Logits are averaged across models, upsampled to 2048x2048, passed through a
    sigmoid and thresholded at ``thr``.

    Returns:
        rles: one RLE string per (image, class).
        filename_and_class: matching ``"<class>_<image name>"`` labels.
        image_names_list: image names grouped per batch.

    Raises:
        ValueError: if ``pt_list`` is empty.
    """
    if not pt_list:
        raise ValueError(f"No .pt checkpoints found in {PT_PATH}")

    logit_sums, image_names_list = _accumulate_logits(data_loader)
    rles, filename_and_class = _encode_predictions(
        logit_sums, image_names_list, len(pt_list), thr
    )
    return rles, filename_and_class, image_names_list

In [None]:
# Run the ensemble: per-(image, class) RLE strings, their "<class>_<image name>"
# labels, and the image names grouped by batch.
rles, filename_and_class, image_names_list = test(test_loader)

# 6. Visualization

In [None]:
# Peek at the first "<class>_<image name>" label.
filename_and_class[0]

In [None]:
# Reload the image behind the first prediction. Labels are "<class>_<image path>";
# split only on the first "_" so underscores inside the image path are kept.
image = cv2.imread(os.path.join(IMAGE_ROOT, filename_and_class[0].split("_", 1)[1]))

In [None]:
# Decode the first image's masks (its first len(CLASSES) RLEs, one per class)
# into a (n_classes, 2048, 2048) stack.
preds = np.stack(
    [decode_rle_to_mask(rle, height=2048, width=2048) for rle in rles[:len(CLASSES)]],
    axis=0,
)

In [None]:
# define colors
PALETTE = [
    (220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
    (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
    (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
    (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
    (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
    (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
]

# utility function
# this does not care overlap
def label2rgb(label):
    image_size = label.shape[1:] + (3, )
    image = np.zeros(image_size, dtype=np.uint8)
    
    for i, class_label in enumerate(label):
        image[class_label == 1] = PALETTE[i]
        
    return image

In [None]:
# Side by side: the reloaded test image (as returned by cv2.imread, i.e. BGR order)
# and its predicted masks colored with PALETTE.
fig, ax = plt.subplots(1, 2, figsize=(24, 12))
ax[0].imshow(image)    # shown as loaded (H, W, 3); no channel is removed
ax[1].imshow(label2rgb(preds))

plt.show()

# 7. To csv

In [None]:
# Labels are "<class>_<image path>". Class names contain no "_", but image paths
# may, so split on the first "_" only to always get exactly two parts.
classes, filename = zip(*[x.split("_", 1) for x in filename_and_class])

In [None]:
# Inspect the per-batch image-name groups returned by test().
print(image_names_list)

In [None]:
# Keep only the file name (drop the per-patient sub-directory).
image_name = list(map(os.path.basename, filename))

In [None]:
# Sanity check: the three columns must have equal length to build the DataFrame.
print(len(image_name),len(classes),len(rles))

In [None]:
# Submission table: one row per (image, class) with its RLE-encoded mask.
df = pd.DataFrame({
    "image_name": image_name,
    "class": classes,
    "rle": rles,
})

In [None]:
# Preview the first 30 rows.
df.head(30)

In [None]:
# Write the submission CSV. Create the directory the file is actually written to
# (previously a CWD-relative folder was created while an absolute path was used).
RESULT_DIR = '/opt/ml/level2_cv_semanticsegmentation-cv-15/codebook/ensemble_results'
os.makedirs(RESULT_DIR, exist_ok=True)

file_path = os.path.join(RESULT_DIR, 'ensemble.csv')
df.to_csv(file_path, index=False)