# TOC

1. [Import](#1-import)
2. [필요한 정보 입력](#2-필요한-정보-입력)
3. [테두리를 잘 잡는지 확인](#3-테두리를-잘-잡는지-확인)   
    3.1. [필요한 값들 계산](#31-필요한-값들-계산)   
    3.2. [그려보기](#32-그려보기)   
    3.3. [값을 직접 보기](#33-값을-직접-보기)   

# 1. Import

In [1]:
import os
os.chdir('/opt/ml/input/code/local')

import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn.functional as F

from dataset import XRayDataset
from metric import dice_coef

# 2. 필요한 정보 입력

In [2]:
# Class names, in the order that is zipped against the per-class dice scores
# printed at the end of the notebook: 19 "finger-N" entries followed by ten
# individually named bones.
CLASSES = [f"finger-{i}" for i in range(1, 20)] + [
    "Trapezium",
    "Trapezoid",
    "Capitate",
    "Hamate",
    "Scaphoid",
    "Lunate",
    "Triquetrum",
    "Pisiform",
    "Radius",
    "Ulna",
]

In [3]:
# Dataset root and the experiment directory holding best_model.pt.
data_root, save_dir = (
    "/opt/ml/input/data",
    "/opt/ml/input/code/local/checkpoints/[test]ExpName",
)

# 3. 테두리를 잘 잡는지 확인

In [4]:
# Evaluation setup: sigmoid threshold, 512x512 resize, the "val1" split and
# the saved best model.
thr = 0.5
transform = A.Resize(512, 512)
dataset = XRayDataset(data_root, transforms=transform, split="val1")
# torch.load unpickles the whole nn.Module, so only point this at trusted
# checkpoints. NOTE(review): newer PyTorch releases default to
# weights_only=True, which may reject a fully pickled model — confirm the
# installed version if loading fails.
checkpoint_path = os.path.join(save_dir, "best_model.pt")
model = torch.load(checkpoint_path)

## 3.1 필요한 값들 계산

In [None]:
# Run the model on every sample of the split and keep, for the whole split:
#   preds - thresholded predictions, bool array (N, C, H, W)
#   gts   - ground-truth masks,      bool array (N, C, H, W)
#   dices - one dice_coef result per sample (concatenated along dim 0 later)
# C, H and W are taken from the first mask (assumes the dataset yields masks
# as (C, H, W) — TODO confirm) instead of being hard-coded to 29 x 2048 x 2048.
# NOTE: both buffers hold the entire split in memory (~120 MB per sample per
# array at 29 x 2048 x 2048), which is why later cells `del` aggressively.
dices = []
preds = None
gts = None
model.eval()  # switch to inference mode once, not on every sample
with torch.no_grad():
    for idx, (images, masks) in enumerate(tqdm(dataset)):
        images, masks = images.unsqueeze(0), masks.unsqueeze(0)
        outputs = model(images.cuda())["out"]

        # restore original size
        mask_h, mask_w = masks.size(-2), masks.size(-1)
        if outputs.shape[-2:] != masks.shape[-2:]:
            outputs = F.interpolate(outputs, size=(mask_h, mask_w), mode="bilinear")

        # Under no_grad nothing needs detaching; just threshold and move to CPU.
        outputs = (torch.sigmoid(outputs) > thr).cpu()
        masks = masks.cpu()

        if preds is None:
            # Allocate the split-wide buffers lazily, sized from the data.
            buffer_shape = (len(dataset), *masks.shape[1:])
            preds = np.zeros(buffer_shape, dtype=np.bool_)
            gts = np.zeros(buffer_shape, dtype=np.bool_)

        preds[idx] = outputs[0].numpy()
        gts[idx] = masks[0].numpy().astype(np.bool_)

        dice = dice_coef(outputs, masks)
        dices.append(dice)

if preds is None:
    raise RuntimeError("dataset is empty; nothing to evaluate")

In [6]:
# Free memory: drop the last per-sample tensors left over from the loop.
del images, masks, outputs, dice

In [None]:
# Black check: pixels where neither the GT nor the prediction has any class
# set (background in both). Each mask below is (N, 1, H, W).
gt_fg = gts.any(axis=1, keepdims=True)
pred_fg = preds.any(axis=1, keepdims=True)
fg = np.logical_or(gt_fg, pred_fg)
blacks = np.logical_not(fg)
blacks.shape

In [None]:
# Red check (wrong class): pixels that are foreground in both GT and
# prediction but where at least one class channel disagrees.
# The mismatch is OR-ed per class into an (N, 1, H, W) mask, so no full
# (N, C, H, W) temporary is created. The loop runs over the real class
# count instead of a hard-coded 29.
reds = np.zeros_like(gt_fg)
for c in range(gts.shape[1]):
    reds |= gts[:, c:c + 1] != preds[:, c:c + 1]
reds &= gt_fg
reds &= pred_fg
reds.shape

In [None]:
# Yellow check (under-prediction): GT has some class but the prediction is
# background.
pred_bg = np.logical_not(pred_fg)
yellows = np.logical_and(gt_fg, pred_bg)
yellows.shape

In [None]:
# Green check (over-prediction): the prediction has some class but the GT is
# background.
gt_bg = np.logical_not(gt_fg)
greens = np.logical_and(gt_bg, pred_fg)
greens.shape

In [None]:
# Blue check (fully correct): pixels that are foreground in both GT and
# prediction and where every class channel agrees.
# Start from "both foreground" and AND in each class's agreement, giving an
# (N, 1, H, W) mask with no full (N, C, H, W) temporary. The loop runs over
# the real class count instead of a hard-coded 29.
blues = gt_fg & pred_fg
for c in range(gts.shape[1]):
    blues &= gts[:, c:c + 1] == preds[:, c:c + 1]
blues.shape

In [12]:
# Overlap region: pixels where more than one ground-truth class is set,
# shaped (N, 1, H, W) to line up with the color masks.
# The bool array is summed directly with a uint8 accumulator (fine for up to
# 255 classes). This avoids keeping a full-size (N, C, H, W) uint8 copy of
# gts alive as a global, which a later `del gts` would not release.
gts_bool = np.sum(gts, axis=1, dtype=np.uint8, keepdims=True) > 1

In [13]:
# Collect the color masks under their display keys.
arrs = {"r": reds, "g": greens, "b": blues, "y": yellows}

In [14]:
# Free memory: the full (N, C, H, W) GT/prediction arrays are no longer
# needed once the derived masks exist.
del gts, preds

## 3.2 그려보기

In [15]:
def show(arrs, idx, color, overlap=False, *, overlap_mask=None):
    """Plot the chosen error-category masks of one sample as an RGB image.

    Each character of ``color`` selects one mask from ``arrs`` and lights the
    matching channel(s): "r" -> red, "g" -> green, "b" -> blue and
    "y" -> red + green (yellow). The selected masks are OR-ed together.
    Single-color and multi-color strings therefore render the same way, with
    unselected channels left at 0.

    Args:
        arrs: Mapping from color key to a per-sample bool mask array;
            ``arrs[key][idx]`` is expected to be (1, H, W) or (H, W).
        idx: Index of the sample to draw.
        color: Non-empty string made of the keys "r", "g", "b" and "y".
        overlap: If True, only pixels inside the overlap mask are displayed.
        overlap_mask: Bool array indexed like ``arrs`` whose ``[idx][0]`` is
            the (H, W) mask applied when ``overlap`` is True. Defaults to the
            notebook-global ``gts_bool``.

    Returns:
        The composed (H, W, 3) uint8 image, *without* the overlap mask applied.

    Raises:
        ValueError: If ``color`` is empty or contains an unknown key.
    """
    channels = {"r": slice(0, 1), "g": slice(1, 2), "b": slice(2, 3), "y": slice(0, 2)}
    unknown = set(color).difference(channels)
    if not color or unknown:
        raise ValueError(f"color must be a non-empty combination of 'rgby', got {color!r}")

    # Take the canvas size from the data instead of assuming 2048 x 2048.
    height, width = arrs[color[0]][idx].shape[-2:]
    rgb = np.zeros((3, height, width), dtype=np.bool_)
    for c in color:
        rgb[channels[c]] |= arrs[c][idx]
    arr = (rgb.transpose(1, 2, 0) * 255).astype(np.uint8)

    shown = arr
    if overlap:
        mask = gts_bool if overlap_mask is None else overlap_mask
        # Zero out every pixel outside the overlap mask (display only).
        shown = arr * np.expand_dims(mask[idx][0], axis=-1)

    plt.imshow(shown)
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()
    plt.show()

    return arr

In [None]:
# Draw sample 84 with all four error categories overlaid, without the
# overlap restriction.
idx, color = 84, "rgby"
arr = show(arrs, idx, color, overlap=False)

## 3.3 값을 직접 보기

In [17]:
# Pixel counts per color category: over the whole split ("full") and
# restricted to GT-overlap pixels ("overlap"). Counts are divided by 1e5 for
# readability and rounded to 2 decimals. "g+y" is over- plus
# under-prediction.
# Raw counts are summed before rounding, so "g+y" is itself a rounded value.
# Adding two already-rounded floats can print as e.g. 1.2300000000000002.
scale = 100000
full_raw = {key: np.count_nonzero(val) for key, val in arrs.items()}
overlap_raw = {key: np.count_nonzero(val & gts_bool) for key, val in arrs.items()}
full_raw["g+y"] = full_raw["g"] + full_raw["y"]
overlap_raw["g+y"] = overlap_raw["g"] + overlap_raw["y"]
full_cnt = {key: round(cnt / scale, 2) for key, cnt in full_raw.items()}
overlap_cnt = {key: round(cnt / scale, 2) for key, cnt in overlap_raw.items()}

In [None]:
# Report the scaled pixel counts, one line per scope.
for label, counts in (("Full", full_cnt), ("Overlap", overlap_cnt)):
    print(f"{label}: {counts}")

In [None]:
# Per-class dice: stack the per-sample results along dim 0 and average them
# (assumes each dice_coef result is (1, C) — TODO confirm in metric.py), then
# print one aligned line per class name.
dices_cat = torch.cat(dices, dim=0)
dices_per_class = dices_cat.mean(dim=0)
dice_str = "\n".join(
    f"{c:<12}: {d.item():.4f}" for c, d in zip(CLASSES, dices_per_class)
)
print(dice_str)