In [1]:
import os
import re
import glob
import json
import sys
import random
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
# import torchvision.models as models
import torch.utils.data as data

import numpy as np
import matplotlib.pyplot as plt

from torchvision.transforms import Resize

In [2]:
sys.path.append('/opt/ml/pstage01')
from model import models,loss, metric
from dataloader import mask
from util import meter, transformers

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hard-coded root of the training data and the image folder beneath it.
data_dir = '/opt/ml/input/data/train'
img_dir = f'{data_dir}/images'

In [4]:
# Per-channel mean / std handed to the transform factory below.
# NOTE(review): presumably RGB statistics of the training images — confirm the source.
mean, std = (0.56019358, 0.52410121, 0.501457), (0.23318603, 0.24300033, 0.24567522)

In [5]:
dataset = mask.MaskBaseDataset(data_dir = data_dir,img_dir=img_dir)
# Number of target classes; reused by the metric and plotting helpers.
num_classes = 18

# Build the transforms with the normalisation statistics defined above and
# apply the validation variant to the dataset.
transform = transformers.get_transforms(mean=mean, std=std, transform_type = 'basic')
dataset.set_transform(transform['val'])

# -- data loader (only the validation split is used in this notebook)
_, val_set = dataset.split_dataset()

val_loader = data.DataLoader(
    val_set,
    num_workers=4,
    shuffle=False,
    # Evaluation must see every sample: dropping the last partial batch would
    # silently exclude up to batch_size - 1 images from all metrics.
    drop_last=False,
    batch_size=32,
    pin_memory=torch.cuda.is_available(),
)

In [6]:
def change_2d_to_1d(tensor_2d):
    """
    Flatten a 2-D tensor into 1-D; any other rank is returned untouched.
    """
    is_two_dimensional = len(tensor_2d.shape) == 2
    return tensor_2d.reshape(-1) if is_two_dimensional else tensor_2d


def tensor_images_to_numpy_images(images, renormalize=False, mean=None, std=None):
    """
    Convert a batch of image tensors to a channels-last NumPy array.

    Args:
        images: tensor indexed as (batch, channel, height, width); the axes
            are permuted to (batch, height, width, channel) on return.
        renormalize: when True, undo the input normalisation
            (``image * std + mean``) and clip the result to [0, 1].
        mean: optional per-channel mean used to undo the normalisation.
            When omitted, the module-level ``MEAN`` is used.
        std: optional per-channel std used to undo the normalisation.
            When omitted, the module-level ``STD`` is used.

    Returns:
        NumPy array with the channel axis moved to the end.
    """
    images = images.detach().cpu().numpy()
    if renormalize:
        # Explicit statistics take precedence; the globals are only a fallback
        # so existing calls keep working unchanged.
        mean = MEAN if mean is None else mean
        std = STD if std is None else std
        # Shape the statistics as (channel, 1, 1) so they broadcast over
        # (batch, channel, height, width).
        mean = np.asarray(mean).reshape(-1, 1, 1)
        std = np.asarray(std).reshape(-1, 1, 1)
        images = np.clip((images * std) + mean, 0, 1)
    return images.transpose(0, 2, 3, 1)


def tensor_to_numpy(tensors):
    """
    Return the tensor's data as a NumPy array (detached and moved to the CPU first).
    """
    host_tensor = tensors.detach().cpu()
    return host_tensor.numpy()


In [7]:
def get_all_datas(model, device, dataloader, argmax=True):
    """
    Run ``model`` over ``dataloader`` and gather inputs, targets and outputs.

    Args:
        model: network to evaluate; it is switched to eval mode.
        device: device the image and label batches are moved to.
        dataloader: iterable yielding ``(images, labels, _)`` triples.
        argmax: when True the outputs are reduced with ``argmax(dim=1)``;
            when False the raw model outputs are kept.

    Returns:
        Tuple ``(all_images, all_labels, all_preds)``, each the concatenation
        of the per-batch tensors along dim 0. Each tensor keeps the dtype of
        its batches (labels are no longer promoted to float). An empty
        dataloader yields three empty tensors on ``device``.
    """
    model.eval()

    # Collect batches in lists and concatenate once at the end: calling
    # torch.cat inside the loop re-copies everything gathered so far on every
    # iteration, and seeding it with a float ``torch.tensor([])`` forces a
    # float/int dtype mix for labels.
    image_batches, label_batches, pred_batches = [], [], []

    with torch.no_grad():
        for images, labels, _ in tqdm(dataloader):
            images, labels = images.to(device), labels.to(device)

            preds = model(images)
            if argmax:
                preds = change_2d_to_1d(torch.argmax(preds, dim=1))

            image_batches.append(images)
            label_batches.append(labels)
            pred_batches.append(preds)

    if not image_batches:
        return tuple(torch.tensor([]).to(device) for _ in range(3))

    return torch.cat(image_batches), torch.cat(label_batches), torch.cat(pred_batches)

In [8]:
model_path = "/opt/ml/model_save/04_04_v3/004_loss_7.5e-05_acc_0.97.ckpt"
model = models.MyResNext()
# map_location lets a checkpoint saved from a GPU be restored on whichever
# device was selected above (including a CPU-only host).
model.load_state_dict(torch.load(model_path, map_location=device))
# Follow the notebook-wide `device` instead of unconditionally calling .cuda().
model.to(device)
model = torch.nn.DataParallel(model)

In [1]:
# Run the model over the validation loader; argmax=False keeps the raw per-class outputs in `preds`.
images, labels, preds = get_all_datas(model, device, val_loader, argmax=False)
print(images.shape, labels.shape, preds.shape)
# Convert labels and predictions to NumPy arrays for the metric code that follows.
labels, preds = tensor_to_numpy(labels), tensor_to_numpy(preds)

In [10]:
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

# Lift pandas' display limits so every column and row of the tables is rendered.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

def log_f1_and_acc_scores(labels, preds, num_classes):
    """
    Build per-class (one-vs-rest) metric tables.

    For every class index the problem is binarised over ALL samples
    ("is this class" vs. "is any other class"). Restricting the evaluation to
    the rows whose true label is the class would make every target positive,
    which pins precision at 1 (or 0) and makes F1 meaningless.

    Args:
        labels: 1-D array of true class indices.
        preds: 1-D array of predicted class indices, same length as ``labels``.
        num_classes: number of classes; indices ``0 .. num_classes - 1`` are scored.

    Returns:
        Tuple ``(summary_table, false_table)``:
        - ``summary_table``: single row labelled ``"1003"`` with the columns
          ``"<class> f1"``, ``"<class> pr"``, ``"<class> re"`` and
          ``"<class> acc"``, formatted as percentage strings.
        - ``false_table``: one row per class (index is the class index as a
          string) with a ``"false count"`` column holding the number of
          samples of that class that were predicted as some other class.
    """
    labels = np.asarray(labels)
    preds = np.asarray(preds)

    metrics = {}
    false_counts = {}

    for class_idx in range(num_classes):
        binary_labels = labels == class_idx
        binary_outputs = preds == class_idx

        metrics[f"{class_idx} f1"] = f1_score(binary_labels, binary_outputs, average="binary")
        metrics[f"{class_idx} pr"] = precision_score(binary_labels, binary_outputs, average="binary")
        metrics[f"{class_idx} re"] = recall_score(binary_labels, binary_outputs, average="binary")
        metrics[f"{class_idx} acc"] = accuracy_score(binary_labels, binary_outputs)

        # Samples that belong to this class but were predicted as another one.
        false_counts[f"{class_idx}"] = int((binary_labels & ~binary_outputs).sum())

    summary_table = pd.DataFrame(metrics, index=["1003"]).fillna(0)
    # Column-wise Series.map instead of DataFrame.applymap, which newer pandas
    # releases deprecate.
    summary_table = summary_table.apply(
        lambda column: column.map(lambda x: "{:,.1f}%".format(x * 100))
    )

    false_table = pd.DataFrame({"false count": pd.Series(false_counts)})

    return summary_table, false_table

In [11]:
# `preds` holds one score per class, so argmax over axis 1 yields the predicted class index.
summary_df, false_df = log_f1_and_acc_scores(labels, np.argmax(preds, axis=1), num_classes)

In [1]:
# Display the per-class metric table.
summary_df

In [2]:
# Display the per-class "false count" table.
false_df

In [14]:
import itertools
from sklearn.metrics import confusion_matrix

def _draw_confusion_panel(ax, matrix, num_classes, cmap):
    """Draw ``matrix`` on ``ax`` as a heat-map with the count written in every cell."""
    ax.imshow(matrix, interpolation="nearest", cmap=cmap)

    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")

    # Cells darker than half the maximum get white text so the count stays readable.
    thresh = matrix.max() / 2.0

    for i, j in itertools.product(range(matrix.shape[0]), range(matrix.shape[1])):
        ax.text(
            j,
            i,
            matrix[i, j],
            horizontalalignment="center",
            color="white" if matrix[i, j] > thresh else "black",
        )


def log_confusion_matrix(labels, preds, num_classes):
    """
    Plot the confusion matrix twice, side by side.

    The left panel shows the full matrix (a strong diagonal means many
    correct predictions). The right panel shows the same matrix with the
    diagonal zeroed, so only the misclassifications remain visible.

    Args:
        labels: 1-D array of true class indices.
        preds: 1-D array of predicted class indices.
        num_classes: number of classes; fixes the matrix to
            ``num_classes x num_classes`` even when a class is absent from
            ``labels`` / ``preds``, so the ticks always line up with the cells.

    Returns:
        The matplotlib figure holding both panels.
    """
    fig, axes = plt.subplots(nrows = 1, ncols = 2, figsize=(18, 9))
    fig.suptitle("Confusion Matrix", fontsize=16)
    cmap = plt.cm.GnBu

    # Pin the label set so the matrix shape does not depend on which classes
    # happen to occur in the data.
    cm = confusion_matrix(labels, preds, labels=list(range(num_classes)))
    _draw_confusion_panel(axes[0], cm, num_classes, cmap)

    # Keep only the mistakes; work on a copy so the full matrix stays intact.
    errors = cm.copy()
    np.fill_diagonal(errors, 0)
    _draw_confusion_panel(axes[1], errors, num_classes, cmap)

    return fig
    
    

In [3]:
# Plot the confusion matrix for the arg-max class predictions.
fig = log_confusion_matrix(labels, np.argmax(preds, axis=1), num_classes)

In [16]:
# Inspect how the predicted scores are distributed
from scipy.special import softmax

# De-normalisation must use the same statistics that were passed to the
# transforms (`mean` / `std`), otherwise the restored images are colour-shifted.
# Shaped (channel, 1, 1) so they broadcast over (batch, channel, height, width).
MEAN = np.array(mean).reshape(-1, 1, 1)
STD = np.array(std).reshape(-1, 1, 1)

In [17]:
# Channels-last, de-normalised images for plotting.
np_images = tensor_images_to_numpy_images(images, renormalize=True)
# Use the builtin `int`: the `np.int` alias was deprecated in NumPy 1.20 and
# removed in 1.24. Integer labels are needed to index class arrays later.
labels = labels.astype(int)

In [18]:
# One (image, true label, raw prediction) triple per sample, in dataset order.
search_information = [
    (np_images[i], labels[i], preds[i]) for i in range(len(np_images))
]

# Shared iterator so successive plotting calls continue where the last one stopped.
search_information_iter = iter(search_information)

In [19]:
def _log_plots_image(ax, image, pred, pred_label, true_label, num_classes):
    ax.grid(False)
    color = "blue" if pred_label == true_label else "red"
    
    classes = np.array([str(i) for i in range(num_classes)])
    
    ax.imshow(image)
    ax.set_xlabel(
        "pred: {} {:2.0f}% | (true: {})".format(
            classes[pred_label], 100 * pred[pred_label], classes[true_label]
        ),
        color=color,
        fontsize=18
    )

def _log_plots_distribution(ax, pred, pred_label, true_label, num_classes):
    ax.grid(False)
    ax.set_ylim([0, 1])

    thisplot = ax.bar(range(num_classes), pred, color="#777777")

    thisplot[pred_label].set_color("red")
    thisplot[true_label].set_color("blue")

def plots_result(info_iter,n,num_classes, title = "Predict Analysis"):
    """
    Plot the next samples from ``info_iter`` together with their score distributions.

    Every sample takes two adjacent cells in the grid: the image with a
    caption on the left and a bar chart of the softmax scores on the right.

    Args:
        info_iter: iterator yielding ``(image, true_label, raw_prediction)``
            triples; up to ``n * n`` items are consumed from it.
        n: grid side length, i.e. at most ``n * n`` samples are drawn.
        num_classes: number of classes, i.e. the length of each prediction vector.
        title: figure title.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: if ``info_iter`` has no items left.

    Notes:
        When fewer than ``n * n`` items remain, the largest square grid that
        fits the remaining items is drawn instead of failing.
    """
    # Pull from the iterator that was passed in, not from a notebook global.
    batch = []
    for _ in range(n * n):
        try:
            batch.append(next(info_iter))
        except StopIteration:
            break

    if not batch:
        raise ValueError("info_iter has no items left to plot")

    images = np.array([item[0] for item in batch])
    labels = np.array([item[1] for item in batch])
    preds = np.array([item[2] for item in batch])

    # Turn the raw outputs into per-sample probability distributions.
    preds = softmax(preds, axis = 1)

    grid_size = int(len(images) ** 0.5)
    num_images = grid_size * grid_size

    # squeeze=False keeps `axes` two-dimensional even for a 1 x 1 grid.
    fig, axes = plt.subplots(
        nrows=grid_size, ncols=grid_size * 2, figsize=(36, 18), squeeze=False
    )
    fig.suptitle(title, fontsize=54)

    plt.setp(axes, xticks=[], yticks=[])

    for idx in range(num_images):
        image, pred = images[idx], preds[idx]
        num_row, num_col = divmod(idx, grid_size)

        pred_label = np.argmax(pred)
        # Labels may arrive as floats; an integer is required for indexing.
        true_label = int(labels[idx])

        _log_plots_image(
            axes[num_row][num_col * 2], image, pred, pred_label, true_label, num_classes
        )

        _log_plots_distribution(
            axes[num_row][num_col * 2 + 1], pred, pred_label, true_label, num_classes
        )

    return fig

In [1]:
# Plot the next 4 x 4 = 16 samples taken from the shared iterator.
fig = plots_result(search_information_iter, 4 ,num_classes)

## visualization 실습

In [21]:
def normalize(tensor):
    """
    Min-max scale ``tensor``: subtract its minimum, then divide by the shifted
    maximum plus 1e-9 (the epsilon avoids division by zero for constant input).
    """
    shifted = tensor - tensor.min()
    peak = shifted.max()
    return shifted / (peak + 1e-9)

def draw_border(image_np, color, thickness=5):
    """
    Return a copy of ``image_np`` with a solid frame painted on all four edges.

    Args:
        image_np: array indexed as (height, width, channel); it is not modified.
        color: per-channel colour of the frame, one value per channel.
        thickness: frame width in pixels (default 5).

    Returns:
        A new array of the same shape with the frame drawn.
    """
    color = np.asarray(color)
    height, width = image_np.shape[:2]
    bordered = image_np.copy()
    fill = color[np.newaxis, np.newaxis, :]

    # Clamp at 0 so a frame thicker than the image cannot wrap around.
    bottom_start = max(height - thickness, 0)
    right_start = max(width - thickness, 0)

    bordered[0:thickness, :, :] = fill
    bordered[:, 0:thickness, :] = fill
    bordered[bottom_start:height, :, :] = fill
    # The right edge is measured from the WIDTH; using the height here would
    # misplace the frame on non-square images.
    bordered[:, right_start:width, :] = fill
    return bordered

def show_image(image, title=None):
    """
    Display a single image in a 4 x 4 inch figure without axes.

    The input is converted with ``image_tensor_to_numpy``; when the result has
    more than three dimensions only its first entry is kept. The values are
    then min-max scaled with ``normalize`` and size-1 axes are squeezed away.

    Args:
        image: image to display.
        title: optional figure title; omitted when falsy.
    """
    # NOTE(review): image_tensor_to_numpy is not defined in the visible part of
    # this notebook — confirm it is defined or imported before this cell runs.
    pixels = image_tensor_to_numpy(image)
    has_leading_axis = len(pixels.shape) > 3
    pixels = normalize(pixels[0] if has_leading_axis else pixels)

    # plot
    pixels = pixels.squeeze()
    plt.figure(figsize=(4,4))
    plt.imshow(pixels)
    plt.axis('off')
    if title:
        plt.title(title)
    plt.show()
    
def show_images(image_list):
    """
    Display a list of image rows, one figure per row.

    Args:
        image_list: iterable of rows; each row is a sequence of images that
            are drawn side by side in a single figure. Empty rows are skipped.

    Each image is converted with ``image_tensor_to_numpy``; when the result
    has more than three dimensions only its first entry is kept. It is then
    min-max scaled with ``normalize`` and squeezed before being drawn.
    """
    for row in image_list:
        if len(row) == 0:
            continue
        # squeeze=False keeps `axarr` two-dimensional even for a single image.
        _, axarr = plt.subplots(1, len(row), squeeze=False)
        # The drawing loop must run once per row; outside the row loop only
        # the last row would ever be filled in.
        for i, img in enumerate(row):
            np_img = image_tensor_to_numpy(img)
            if len(np_img.shape) > 3:
                np_img = np_img[0]
            np_img = normalize(np_img)

            np_img = np_img.squeeze()
            axarr[0][i].imshow(np_img)
            axarr[0][i].axis('off')
    plt.show()