## Evaluate classification models with sklearn

In [1]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix,classification_report
# augmentation
from albumentations.augmentations.transforms import Resize,Normalize
from albumentations import Compose

from dataset import load_train_data
from dataset import UkiyoeTestDataset
from dataset import UkiyoeTrainDataset

import models
import constants as cons
from convenient_function import fix_model_state_dict

In [2]:
# Where the training images/labels live, and where checkpoints are stored.
train_images_path = train_labels_path = 'data'
params_path = 'params'

# Checkpoints to evaluate. Each 'model' entry is a path fragment that starts
# with '/' (presumably appended to params_path -- TODO confirm at the load
# site). 'fix' and 'ratio' are per-checkpoint settings consumed later in the
# notebook; 'fix' plausibly toggles fix_model_state_dict and 'ratio' an
# ensemble weight -- NOTE(review): confirm against their use.
PARAM_LIST = [
    dict(model='/model_dense/densenet161/fold0/model_dense_880.pth', fix=False, ratio=1),
    dict(model='/aug_decrease_effi_b3/efficientnet_b3/fold0/aug_decrease_effi_b3_790.pth', fix=True, ratio=1),
    dict(model='/model_effi_b3/efficientnet_b3/fold0/model_effi_b3_980.pth', fix=True, ratio=1),
    dict(model='/model_effi_b3_sampling/efficientnet_b3/fold0/with_pseudo_labeling/model_effi_b3_sampling_best.pth', fix=False, ratio=2),
    dict(model='/model_senet154/senet154/fold0/model_senet154_680.pth', fix=True, ratio=1),
    dict(model='/partical_augmentation/senet154/fold0/partical_augmentation_380.pth', fix=True, ratio=1),
    dict(model='/se_oversampling/se_resnet50/fold0/se_oversampling_490.pth', fix=False, ratio=3),
    dict(model='/ince3_w/inceptionv3/fold0/ince3_weight_670.pth', fix=False, ratio=3),
]
DATA_LIST = list()

In [3]:
def valid_loop(model, loader, data_len, num_classes=None, *, device=None, show_progress=True):
    """Run ``model`` over ``loader`` and collect sigmoid scores and labels.

    Rows are written at a running offset, so loaders with any batch size
    (not only 1) are supported.

    Args:
        model: torch module to evaluate; it is switched to eval mode here.
        loader: iterable of ``(inputs, labels)`` tensor batches, e.g. a
            ``torch.utils.data.DataLoader``.
        data_len: total number of samples the loader yields; sizes the
            returned arrays.
        num_classes: number of columns in the returned arrays; defaults to
            ``cons.NUM_CLASSES``.
        device: device the input batches are moved to. Defaults to the device
            of the model's first parameter, or CUDA if the model has none.
        show_progress: wrap the loader in a tqdm progress bar.

    Returns:
        ``(output, label)``: two float64 arrays of shape
        ``(data_len, num_classes)`` holding ``sigmoid(model(inputs))`` and the
        labels, one row per sample. Rows the loader never fills stay zero.

    Raises:
        ValueError: if the loader yields more than ``data_len`` samples.
    """
    if num_classes is None:
        num_classes = cons.NUM_CLASSES
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    model.eval()
    output = np.zeros((data_len, num_classes))
    label = np.zeros((data_len, num_classes))
    batches = tqdm(loader) if show_progress else loader

    start = 0
    with torch.no_grad():
        for inputs, labels in batches:
            probs = torch.sigmoid(model(inputs.to(device))).cpu().numpy()
            batch_size = probs.shape[0]
            end = start + batch_size
            if end > data_len:
                raise ValueError(
                    f"loader yielded more than data_len={data_len} samples")
            output[start:end] = probs.reshape(batch_size, -1)
            # Labels are only copied back to host memory, so they are never
            # moved to the device. reshape(batch_size, -1) accepts both
            # (B, C) label rows and (B,) per-sample labels; the latter are
            # broadcast across the row, as a single label was before.
            label[start:end] = labels.cpu().numpy().reshape(batch_size, -1)
            start = end
    return output, label

In [4]:
def confidence_check(pred, confidence, data_len, num_classes=None):
    """Print and return how many predictions fall into each class.

    Args:
        pred: 1-D sequence of integer class indices (e.g. an argmax over
            per-class scores).
        confidence: currently unused; kept for call compatibility.
        data_len: currently unused; kept for call compatibility.
        num_classes: number of classes; defaults to ``cons.NUM_CLASSES``.

    Returns:
        float64 array of length ``num_classes`` with the per-class counts.
        These are the same values that are printed.

    Raises:
        TypeError: if ``pred`` is non-empty and not an integer array.
        ValueError: if any index lies outside ``[0, num_classes)``.
    """
    if num_classes is None:
        num_classes = cons.NUM_CLASSES
    pred = np.asarray(pred)
    if pred.size:
        if not np.issubdtype(pred.dtype, np.integer):
            raise TypeError("pred must contain integer class indices")
        # Reject out-of-range indices explicitly. A negative index would
        # otherwise be counted against the wrong class, and a too-large one
        # would silently lengthen the result.
        if pred.min() < 0 or pred.max() >= num_classes:
            raise ValueError(
                f"pred contains class indices outside [0, {num_classes})")
    # bincount avoids materialising a len(pred) x num_classes one-hot matrix.
    # The float64 cast keeps the output identical to summing one-hot rows.
    counts = np.bincount(pred.astype(np.intp), minlength=num_classes).astype(np.float64)

    print(counts)
    return counts

In [5]:
# Smoke-check the test dataset: print its size and the sample at index 396.
# Use the len()/[] protocols instead of calling __len__/__getitem__ directly.
dataset = UkiyoeTestDataset(data_path='data')
print(len(dataset))
print(dataset[396])

397
(397, tensor([[[254, 255, 246,  ..., 253, 239, 252],
         [253, 255, 245,  ..., 255, 240, 247],
         [248, 254, 247,  ..., 255, 247, 250],
         ...,
         [248, 251, 249,  ..., 243, 233, 250],
         [245, 252, 246,  ..., 238, 235, 245],
         [244, 254, 244,  ..., 247, 241, 233]],

        [[234, 235, 224,  ..., 240, 228, 244],
         [232, 235, 223,  ..., 246, 229, 236],
         [227, 234, 225,  ..., 244, 234, 239],
         ...,
         [239, 240, 238,  ..., 236, 227, 245],
         [235, 241, 234,  ..., 238, 235, 246],
         [234, 242, 232,  ..., 250, 245, 236]],

        [[233, 230, 213,  ..., 221, 208, 223],
         [231, 230, 212,  ..., 224, 209, 216],
         [224, 227, 212,  ..., 223, 215, 219],
         ...,
         [230, 234, 232,  ..., 226, 227, 251],
         [233, 239, 234,  ..., 230, 235, 251],
         [233, 244, 234,  ..., 241, 246, 243]]], dtype=torch.uint8))


In [7]:
# Build the dataset with valid=True (presumably the validation split -- TODO
# confirm in dataset.py), then print its size and one sample as a sanity check.
# Normalize with mean=std=0.5 and max_pixel_value=255 maps pixels from
# [0, 255] to [-1, 1]. The paths reuse the variables defined in cell In[2].
dataset = UkiyoeTrainDataset(
    train_images_path=train_images_path,
    train_labels_path=train_labels_path,
    valid=True,
    transform=Compose([
        Resize(cons.IMAGE_SIZE, cons.IMAGE_SIZE),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0),
    ], p=1),
    as_numpy=True,
)

print(len(dataset))
print(dataset[1])

631
(array([[[0.8745099 , 0.9294118 , 0.86666673],
        [0.882353  , 0.93725497, 0.8745099 ],
        [0.882353  , 0.93725497, 0.8745099 ],
        ...,
        [0.9215687 , 0.9215687 , 0.9058824 ],
        [0.9294118 , 0.91372555, 0.95294124],
        [0.9294118 , 0.9058824 , 0.9607844 ]],

       [[0.8745099 , 0.9294118 , 0.86666673],
        [0.882353  , 0.93725497, 0.8745099 ],
        [0.882353  , 0.93725497, 0.8745099 ],
        ...,
        [0.9215687 , 0.9215687 , 0.9058824 ],
        [0.9294118 , 0.91372555, 0.95294124],
        [0.9294118 , 0.9058824 , 0.9607844 ]],

       [[0.882353  , 0.93725497, 0.8745099 ],
        [0.882353  , 0.93725497, 0.8745099 ],
        [0.882353  , 0.93725497, 0.8745099 ],
        ...,
        [0.9215687 , 0.9215687 , 0.9058824 ],
        [0.9294118 , 0.91372555, 0.95294124],
        [0.9294118 , 0.9058824 , 0.9607844 ]],

       ...,

       [[0.8588236 , 0.8352942 , 0.7960785 ],
        [0.8431373 , 0.8352942 , 0.7960785 ],
        [0.843137