# TOC

1. [Import](#1-import)
2. [필요한 정보 입력](#2-필요한-정보-입력)
3. [Validation Dataset](#3-validation-dataset)   
    3.1. [GT와 Pred 살펴보기](#31-gt와-pred-살펴보기)   
        3.1.1 [색깔 가득 채우기](#311-색깔-가득-채우기)   
        3.1.2 [점으로 테두리만 표현하기](#312-점으로-테두리만-표현하기)   
4. [Test Dataset](#4-test-dataset)   
    4.1. [Pred 살펴보기](#41-pred-살펴보기)   
        4.1.1 [색깔 가득 채우기](#411-색깔-가득-채우기)   
        4.1.2 [점으로 테두리만 표현하기](#412-점으로-테두리만-표현하기)   
5. [CSV파일 시각화](#5-csv파일-시각화)    

# 1. Import

In [None]:
import os
import glob
os.chdir('/opt/ml/input/code/local')

import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

from dataset import XRayDataset, XRayInferenceDataset
from visualize import label2rgb
from inference import encode_mask_to_rle

# 2. 필요한 정보 입력

In [None]:
# Dataset root directory; it is handed to the dataset classes, and csv_show
# searches "<data_root>/test/DCM/*/*.png" for the test images.
data_root = "/opt/ml/input/data"
# Experiment directory from which "best_model.pt" is loaded.
save_dir = "/opt/ml/input/code/local/checkpoints/[test]Baseline1_1226"

In [None]:
# One RGB colour (0-255 per channel) per class index, 29 entries in total.
# The scatter plots look a colour up by class index and divide it by 255
# before handing it to matplotlib.
PALETTE = [
    (220, 20, 60),
    (119, 11, 32),
    (0, 0, 142),
    (0, 0, 230),
    (106, 0, 228),
    (0, 60, 100),
    (0, 80, 100),
    (0, 0, 70),
    (0, 0, 192),
    (250, 170, 30),
    (100, 170, 30),
    (220, 220, 0),
    (175, 116, 175),
    (250, 0, 30),
    (165, 42, 42),
    (255, 77, 255),
    (0, 226, 252),
    (182, 182, 255),
    (0, 82, 0),
    (120, 166, 157),
    (110, 76, 0),
    (174, 57, 255),
    (199, 100, 0),
    (72, 0, 118),
    (255, 179, 240),
    (0, 125, 92),
    (209, 0, 151),
    (188, 208, 182),
    (0, 220, 176),
]

In [None]:
def decode_rle_to_mask(rle, height, width):
    """Decode a run-length-encoded mask and collect the end points of its runs.

    The RLE string is a whitespace-separated sequence of ``start length``
    pairs, where ``start`` is a 1-based index into the row-major flattened
    ``height x width`` mask.

    Args:
        rle: RLE string. Anything that is not a string (for example the NaN
            pandas yields for an empty CSV cell) is treated as an empty mask.
        height: Number of rows of the decoded mask.
        width: Number of columns of the decoded mask.

    Returns:
        A tuple ``(mask, point)``. ``mask`` is a ``uint8`` array of shape
        ``(height, width)`` holding 1 inside every run. ``point`` is an
        integer array of shape ``(2 * n_runs, 2)`` with ``(row, col)``
        coordinates: the first pixel of every run followed by the last pixel
        of every run.
    """
    tokens = rle.split() if isinstance(rle, str) else []
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths  # exclusive end of each run

    mask = np.zeros(height * width, dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        mask[lo:hi] = 1

    # Use the last pixel *inside* each run (ends - 1); the exclusive end would
    # sit one pixel past the run and wrap to the next row at the right edge.
    # Coordinates are derived from ``width`` so any mask size decodes correctly.
    edges = np.concatenate((starts, ends - 1))
    point = np.stack((edges // width, edges % width), axis=1)
    return mask.reshape(height, width), point

# 3. Validation Dataset

In [None]:
# 512x512 resize, passed to the dataset as its transform.
transform = A.Resize(512, 512)
dataset = XRayDataset(data_root, transforms=transform, split="val1")
# NOTE(review): torch.load unpickles the stored object, so only load
# checkpoints from a trusted source. The result is called directly as a model
# below, i.e. the file is expected to hold the whole model object.
model = torch.load(os.path.join(save_dir, "best_model.pt"))
# Threshold applied to the sigmoid outputs to obtain binary masks.
thr = 0.5

## 3.1. GT와 Pred 살펴보기

### 3.1.1 색깔 가득 채우기

In [None]:
def full_show(idx):
    """Plot a validation image beside its ground-truth and predicted masks.

    Reads the notebook-level ``dataset``, ``model`` and ``thr``. The model is
    switched to eval mode and run without gradient tracking, so no autograd
    graph is kept alive for a pure visualisation call.

    Args:
        idx: Position of the sample in ``dataset``.
    """
    image, mask = dataset[idx]
    images, masks = image.unsqueeze(0), mask.unsqueeze(0)  # add a batch dim

    # Keep only "<parent folder>/<file name>" for the panel title.
    image_name = "/".join(dataset.df.iloc[idx]["filenames"].split("/")[-2:])

    target_size = tuple(masks.shape[-2:])
    model.eval()
    with torch.no_grad():
        outputs = model(images.cuda())["out"]
        # Bring the logits to the mask resolution before thresholding.
        if tuple(outputs.shape[-2:]) != target_size:
            outputs = F.interpolate(outputs, size=target_size, mode="bilinear")
        outputs = (torch.sigmoid(outputs) > thr).cpu()

    # assumes the dataset yields a float CHW image scaled to [0, 1] — TODO confirm
    img = np.transpose(images[0].cpu().numpy(), (1, 2, 0))
    img = (img * 255).astype(np.uint8)  # out-of-place: leaves the tensor untouched

    gt_mask = label2rgb(masks[0].cpu())
    pred_mask = label2rgb(outputs[0].cpu())

    fig, axes = plt.subplots(1, 3, figsize=(24, 12))
    panels = zip(axes, (img, gt_mask, pred_mask), (image_name, "GT", "Pred"))
    for axis, panel, title in panels:
        # cmap only matters for single-channel arrays; RGB arrays ignore it.
        axis.imshow(panel, cmap='gray')
        axis.set_title(title, fontsize=30)
        axis.set_xticks([])
        axis.set_yticks([])

    plt.tight_layout()
    plt.show()

In [None]:
# Show image, ground truth and prediction for the first sample of the dataset.
full_show(0)

### 3.1.2 점으로 테두리만 표현하기

In [None]:
def _run_edge_points(mask):
    """Return ``(row, col)`` of the first and last pixel of every mask run.

    Runs are maximal stretches of non-zero pixels in the row-major flattened
    mask. The result has shape ``(2 * n_runs, 2)``: all run starts first, then
    all run ends (inclusive).
    """
    width = mask.shape[-1]
    flat = np.asarray(mask).reshape(-1).astype(bool)
    padded = np.concatenate(([False], flat, [False]))
    # Index i is listed when pixel i differs from pixel i - 1: even entries
    # are run starts, odd entries are exclusive run ends.
    changes = np.flatnonzero(padded[1:] != padded[:-1])
    edges = np.concatenate((changes[0::2], changes[1::2] - 1))
    return np.stack((edges // width, edges % width), axis=1)


def point_show(idx):
    """Scatter the run end points of GT and predicted masks over the image.

    Reads the notebook-level ``dataset``, ``model``, ``thr`` and ``PALETTE``.
    Point coordinates are computed at the resolution of the ground-truth mask,
    and the displayed image is resized to that same resolution, so the points
    line up with the image whatever size the dataset returns.

    Args:
        idx: Position of the sample in ``dataset``.
    """
    image, mask = dataset[idx]
    images, masks = image.unsqueeze(0), mask.unsqueeze(0)  # add a batch dim

    # Keep only "<parent folder>/<file name>" for the panel title.
    image_name = "/".join(dataset.df.iloc[idx]["filenames"].split("/")[-2:])

    target_size = tuple(masks.shape[-2:])
    model.eval()
    with torch.no_grad():
        outputs = model(images.cuda())["out"]
        # Bring the logits to the mask resolution before thresholding.
        if tuple(outputs.shape[-2:]) != target_size:
            outputs = F.interpolate(outputs, size=target_size, mode="bilinear")
        outputs = (torch.sigmoid(outputs) > thr).cpu().numpy()

    # The points live in mask coordinates, so the backdrop must match them.
    if tuple(images.shape[-2:]) != target_size:
        images = F.interpolate(images, size=target_size, mode="bilinear")

    # assumes the dataset yields a float CHW image scaled to [0, 1] — TODO confirm
    img = np.transpose(images[0].cpu().numpy(), (1, 2, 0))
    img = (img * 255).astype(np.uint8)

    # One (class index, points) pair per mask channel.
    gts = [
        (cls_id, _run_edge_points(segm))
        for cls_id, segm in enumerate(masks[0].cpu().numpy())
    ]
    preds = [
        (cls_id, _run_edge_points(segm))
        for cls_id, segm in enumerate(outputs[0])
    ]

    fig, axes = plt.subplots(1, 2, figsize=(24, 12))
    for axis, layers, title in zip(axes, (gts, preds), (image_name, "Pred")):
        axis.imshow(img)
        for cls_id, points in layers:
            # points are (row, col); scatter wants x = col, y = row.
            axis.scatter(
                points[:, 1],
                points[:, 0],
                s=1,
                c=[np.array(PALETTE[cls_id]) / 255],
            )
        axis.set_title(title, fontsize=30)
        axis.set_xticks([])
        axis.set_yticks([])
    plt.tight_layout()
    plt.show()

In [None]:
# Show the outline-point view for the first sample of the dataset.
point_show(0)

# 4. Test Dataset

In [None]:
# From here on `dataset` refers to the inference dataset: these assignments
# rebind the notebook-level names that the show functions read.
transform = A.Resize(512, 512)
dataset = XRayInferenceDataset(data_root, transforms=transform)
# NOTE(review): torch.load unpickles the stored object, so only load
# checkpoints from a trusted source. This reloads the same checkpoint path as
# the validation section.
model = torch.load(os.path.join(save_dir, "best_model.pt"))
# Threshold applied to the sigmoid outputs to obtain binary masks.
thr = 0.5

## 4.1. Pred 살펴보기

### 4.1.1 색깔 가득 채우기

In [None]:
def full_show(idx):
    """Plot a test image beside its predicted mask.

    Reads the notebook-level ``dataset``, ``model`` and ``thr``. The logits
    are upsampled to 2048x2048 before thresholding. The model is switched to
    eval mode and run without gradient tracking, so the upsampled logits do
    not keep an autograd graph alive.

    Args:
        idx: Position of the sample in ``dataset``.
    """
    full_size = (2048, 2048)  # resolution the prediction is restored to

    image, _ = dataset[idx]
    images = image.unsqueeze(0)  # add a batch dim

    # Keep only "<parent folder>/<file name>" for the panel title.
    image_name = "/".join(dataset.df.iloc[idx]["filenames"].split("/")[-2:])

    model.eval()
    with torch.no_grad():
        outputs = model(images.cuda())["out"]
        if tuple(outputs.shape[-2:]) != full_size:
            outputs = F.interpolate(outputs, size=full_size, mode="bilinear")
        outputs = (torch.sigmoid(outputs) > thr).cpu()

    # assumes the dataset yields a float CHW image scaled to [0, 1] — TODO confirm
    img = np.transpose(images[0].cpu().numpy(), (1, 2, 0))
    img = (img * 255).astype(np.uint8)  # out-of-place: leaves the tensor untouched

    pred_mask = label2rgb(outputs[0].cpu())

    fig, axes = plt.subplots(1, 2, figsize=(24, 12))
    for axis, panel, title in zip(axes, (img, pred_mask), (image_name, "Pred")):
        # cmap only matters for single-channel arrays; RGB arrays ignore it.
        axis.imshow(panel, cmap='gray')
        axis.set_title(title, fontsize=30)
        axis.set_xticks([])
        axis.set_yticks([])

    plt.tight_layout()
    plt.show()

In [None]:
# Show image and prediction for the first sample of the dataset.
full_show(0)

### 4.1.2 점으로 테두리만 표현하기

In [None]:
def decode_rle_to_mask(rle, height, width):
    """Decode a run-length-encoded mask and collect the end points of its runs.

    The RLE string is a whitespace-separated sequence of ``start length``
    pairs, where ``start`` is a 1-based index into the row-major flattened
    ``height x width`` mask.

    Args:
        rle: RLE string. Anything that is not a string (for example the NaN
            pandas yields for an empty CSV cell) is treated as an empty mask.
        height: Number of rows of the decoded mask.
        width: Number of columns of the decoded mask.

    Returns:
        A tuple ``(mask, point)``. ``mask`` is a ``uint8`` array of shape
        ``(height, width)`` holding 1 inside every run. ``point`` is an
        integer array of shape ``(2 * n_runs, 2)`` with ``(row, col)``
        coordinates: the first pixel of every run followed by the last pixel
        of every run.
    """
    tokens = rle.split() if isinstance(rle, str) else []
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths  # exclusive end of each run

    mask = np.zeros(height * width, dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        mask[lo:hi] = 1

    # Use the last pixel *inside* each run (ends - 1); the exclusive end would
    # sit one pixel past the run and wrap to the next row at the right edge.
    # Coordinates are derived from ``width`` so any mask size decodes correctly.
    edges = np.concatenate((starts, ends - 1))
    point = np.stack((edges // width, edges % width), axis=1)
    return mask.reshape(height, width), point

In [None]:
def point_show(idx):
    """Scatter the run end points of the predicted masks over a test image.

    Reads the notebook-level ``dataset``, ``model``, ``thr`` and ``PALETTE``.
    Prediction and image are both upsampled to 2048x2048, each class mask is
    RLE-encoded with ``encode_mask_to_rle`` and decoded again to obtain the
    points that are drawn. The model is switched to eval mode and run without
    gradient tracking, so the upsampled logits do not keep an autograd graph
    alive.

    Args:
        idx: Position of the sample in ``dataset``.
    """
    full_h, full_w = 2048, 2048  # resolution the prediction is restored to

    image, _ = dataset[idx]
    images = image.unsqueeze(0)  # add a batch dim

    # Keep only "<parent folder>/<file name>" for the title.
    image_name = "/".join(dataset.df.iloc[idx]["filenames"].split("/")[-2:])

    model.eval()
    with torch.no_grad():
        outputs = model(images.cuda())["out"]
        if tuple(outputs.shape[-2:]) != (full_h, full_w):
            # The image is resized together with the logits so the points,
            # which are in full-resolution coordinates, line up with it.
            outputs = F.interpolate(outputs, size=(full_h, full_w), mode="bilinear")
            images = F.interpolate(images, size=(full_h, full_w), mode="bilinear")
        outputs = (torch.sigmoid(outputs) > thr).cpu().numpy()

    # assumes the dataset yields a float CHW image scaled to [0, 1] — TODO confirm
    img = np.transpose(images[0].cpu().numpy(), (1, 2, 0))
    img = (img * 255).astype(np.uint8)

    # One (class index, points) pair per predicted channel.
    preds = []
    for cls_id, segm in enumerate(outputs[0]):
        rle = encode_mask_to_rle(segm)
        _, points = decode_rle_to_mask(rle, height=full_h, width=full_w)
        preds.append((cls_id, points))

    fig, ax = plt.subplots(1, 1, figsize=(24, 12))
    ax.imshow(img)
    for cls_id, points in preds:
        # points are (row, col); scatter wants x = col, y = row.
        ax.scatter(
            points[:, 1],
            points[:, 0],
            s=1,
            c=[np.array(PALETTE[cls_id]) / 255],
        )
    ax.set_title(image_name, fontsize=30)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout()
    plt.show()

In [None]:
# Show the outline-point view for the first sample of the dataset.
point_show(0)

# 5. CSV파일 시각화

In [None]:
# Submission file to visualise; csv_show reads its rows from `df`.
csv_path = "/opt/ml/input/code/local/predictions/[test]Baseline1_1226/submission.csv"
df = pd.read_csv(csv_path)

In [None]:
def csv_show(idx, num_classes=29):
    """Scatter the masks stored in the submission CSV over their test image.

    Reads the notebook-level ``df``, ``data_root`` and ``PALETTE``. Each CSV
    row is unpacked as three values: the image name, an ignored field and the
    RLE string. The rows of one image are expected to be consecutive, one per
    class.

    Args:
        idx: Index of the image, counted in groups of ``num_classes`` rows.
        num_classes: Number of consecutive CSV rows belonging to one image.

    Raises:
        FileNotFoundError: If no PNG under ``<data_root>/test/DCM/*/``
            contains the image name in its path.
        OSError: If the matching file cannot be decoded as an image.
    """
    rles = []
    for offset in range(num_classes):
        image_name, _, rle = df.iloc[idx * num_classes + offset]
        # An empty RLE cell is read back by pandas as NaN, not as "".
        rles.append(rle if isinstance(rle, str) else "")

    candidates = glob.glob(os.path.join(data_root, "test/DCM/*/*.png"))
    img_path = next((path for path in candidates if image_name in path), None)
    if img_path is None:
        raise FileNotFoundError(
            f"no test image matching {image_name!r} under {data_root!r}"
        )

    img = cv2.imread(img_path)
    if img is None:
        raise OSError(f"could not read image {img_path!r}")
    # OpenCV loads colour images as BGR; matplotlib expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # NOTE(review): the RLEs are decoded on a 2048x2048 grid, the size the
    # test-set cells upsample predictions to before encoding — confirm the
    # CSV was produced the same way.
    preds = []
    for cls_id, rle in enumerate(rles):
        _, points = decode_rle_to_mask(rle, height=2048, width=2048)
        preds.append((cls_id, points))

    fig, ax = plt.subplots(1, 1, figsize=(24, 12))
    ax.imshow(img)
    for cls_id, points in preds:
        # points are (row, col); scatter wants x = col, y = row.
        ax.scatter(
            points[:, 1],
            points[:, 0],
            s=1,
            c=[np.array(PALETTE[cls_id]) / 255],
        )
    ax.set_title(image_name, fontsize=30)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout()
    plt.show()

In [None]:
# Visualise the first image's rows of the submission CSV.
csv_show(0)