## Evaluating a classification model with sklearn

In [1]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix,classification_report
# augmentation
from albumentations.augmentations.transforms import Resize,Normalize
from albumentations import Compose

from dataset import load_train_data
import models
import constants as cons
from convenient_function import fix_model_state_dict

In [2]:
# --- Evaluation configuration -------------------------------------------
# Checkpoint to evaluate. The path is later split on '/' to recover the
# model name (component 2) and the fold directory (component 3).
PARAM_PATH = 'params/model_se_weight/se_resnet50/fold0/model_se_weight_790.pth'

# Locations handed to load_train_data for the images and their labels.
TRAIN_IMAGES_PATH = 'data'
TRAIN_LABELS_PATH = 'data'

# Seed forwarded to load_train_data.
train_seed = 98

# When True, the loaded state dict is passed through fix_model_state_dict
# before being applied to the model.
fix_state_dict = False

In [3]:
def valid_loop(model, loader, criterion):
    """Run one evaluation pass over ``loader`` and collect classification metrics.

    Args:
        model: Network to evaluate. It is switched to eval mode and the batches
            are moved to the device of its first parameter.
        loader: Iterable yielding ``(inputs, labels)`` batches. The class index
            of each sample is taken as the arg-max of ``labels`` over dim 1, so
            ``labels`` is presumably one-hot encoded -- confirm against the
            dataset.
        criterion: Loss callable, invoked as
            ``criterion(outputs.double(), labels)``.

    Returns:
        A tuple ``(mean_loss, accuracy_percent, result)``. ``result`` is a dict
        with ``"confusion_matrix"`` (output of sklearn ``confusion_matrix``)
        and ``"precision_score"`` (text output of sklearn
        ``classification_report``).

    Raises:
        ValueError: If ``loader`` yields no samples, since no metric can be
            computed in that case.
    """
    model.eval()

    # Follow the model's own device instead of hard-coding CUDA so the loop
    # also works for a CPU model; fall back to CUDA (the previous behaviour)
    # for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Per-batch index arrays are collected here and concatenated once after the
    # loop; concatenating inside the loop re-copies everything on every batch.
    pred_batches, label_batches = [], []
    total_loss, total_correct, total_num = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            # Class indices, computed once and reused for both the confusion
            # matrix and the accuracy count.
            pred_idx = outputs.argmax(dim=1)
            label_idx = labels.argmax(dim=1)
            pred_batches.append(pred_idx.cpu().numpy())
            label_batches.append(label_idx.cpu().numpy())

            # Calculate loss and accuracy. The batch loss is weighted by the
            # batch size so that a smaller final batch does not skew the mean
            # (assumes the criterion averages over the batch -- TODO confirm).
            loss = criterion(outputs.double(), labels)
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_correct += pred_idx.eq(label_idx).sum().item()
            total_num += batch_size

    if total_num == 0:
        raise ValueError("valid_loop received an empty loader; nothing to evaluate")

    # Integer class indices (previously these were silently promoted to float
    # because the accumulation started from an empty Python list).
    pred_cm = np.concatenate(pred_batches)
    labels_cm = np.concatenate(label_batches)
    print(pred_cm)
    print(labels_cm)

    result = {
        "confusion_matrix": confusion_matrix(labels_cm, pred_cm),
        "precision_score": classification_report(labels_cm, pred_cm),
    }
    return total_loss / total_num, total_correct / total_num * 100, result

# Evaluation metrics
- Confusion matrix

- Precision / recall / **F-score**
- Macro average / weighted average

[reference](https://note.nkmk.me/python-sklearn-confusion-matrix-score/)

In [4]:
# The checkpoint path encodes the experiment layout:
#   params/<experiment>/<model_name>/fold<k>/<file>.pth
# Component 2 is the model name and component 3 the fold directory.
_param_parts = PARAM_PATH.split('/')
model_name = _param_parts[2]
# Take everything after the 'fold' prefix so multi-digit folds (fold10, ...)
# are parsed correctly rather than only the first digit.
nfold = int(_param_parts[3][len('fold'):])

valid_loader = load_train_data(
    train_images_path=TRAIN_IMAGES_PATH,
    train_labels_path=TRAIN_LABELS_PATH,
    batch_size=32,
    valid=True,
    nfold=nfold,
    seed=train_seed,
    transform=Compose([
        Resize(cons.IMAGE_SIZE, cons.IMAGE_SIZE),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0),
    ]),
)

model = models.get_model(model_name=model_name, num_classes=cons.NUM_CLASSES)

# Deserialize onto the CPU first so a checkpoint saved from a GPU run can also
# be read on a machine without CUDA; load_state_dict then copies the weights
# into the model's own parameters.
state_dict = torch.load(PARAM_PATH, map_location='cpu')
if fix_state_dict:
    state_dict = fix_model_state_dict(state_dict)
model.load_state_dict(state_dict)

# NOTE: the optimizer is not used by the evaluation below; it is kept so the
# name remains defined for any later cell that expects it.
optimizer = optim.Adam(model.parameters(), lr=cons.start_lr)
criterion = nn.BCEWithLogitsLoss()

valid_loss, valid_acc, result = valid_loop(model, valid_loader, criterion)
print(valid_acc)
print('- confusion matrix(x,t)\n', result["confusion_matrix"])
print('\n- precision score\n', result["precision_score"])

100%|██████████| 20/20 [00:02<00:00,  8.15it/s]

[6. 4. 0. 0. 6. 4. 3. 1. 8. 9. 4. 6. 8. 0. 1. 0. 4. 6. 3. 0. 9. 2. 6. 8.
 2. 1. 9. 3. 7. 0. 1. 5. 6. 4. 3. 8. 0. 8. 8. 6. 8. 8. 3. 1. 6. 3. 3. 7.
 1. 3. 6. 0. 3. 7. 0. 3. 0. 0. 5. 4. 8. 0. 5. 4. 6. 0. 0. 2. 6. 6. 7. 2.
 4. 1. 4. 5. 4. 1. 5. 0. 0. 8. 4. 6. 8. 3. 2. 1. 2. 8. 8. 1. 1. 3. 1. 1.
 4. 1. 7. 0. 1. 2. 1. 6. 3. 1. 4. 6. 3. 4. 4. 2. 0. 0. 2. 0. 1. 9. 4. 4.
 4. 2. 0. 3. 5. 3. 6. 0. 2. 1. 1. 0. 3. 3. 4. 5. 1. 3. 6. 1. 6. 6. 1. 1.
 6. 0. 6. 0. 6. 9. 6. 0. 2. 0. 6. 6. 0. 6. 0. 1. 6. 1. 1. 2. 8. 5. 2. 6.
 8. 7. 0. 4. 5. 6. 6. 5. 8. 0. 4. 1. 8. 2. 3. 3. 1. 1. 1. 4. 1. 4. 0. 1.
 0. 4. 5. 1. 2. 6. 0. 3. 0. 1. 3. 2. 0. 1. 9. 6. 2. 2. 3. 0. 2. 9. 0. 8.
 1. 9. 2. 8. 2. 1. 0. 5. 8. 2. 1. 3. 4. 6. 0. 4. 0. 6. 1. 1. 1. 1. 1. 4.
 1. 3. 0. 0. 1. 6. 7. 9. 8. 2. 1. 9. 3. 6. 2. 0. 3. 2. 5. 0. 6. 5. 2. 5.
 1. 1. 6. 1. 7. 1. 1. 3. 7. 3. 1. 4. 0. 5. 7. 6. 3. 9. 6. 4. 0. 0. 0. 1.
 1. 2. 4. 0. 2. 4. 5. 0. 3. 2. 1. 9. 0. 1. 0. 0. 1. 2. 2. 2. 5. 2. 2. 1.
 1. 0. 5. 8. 0. 0. 3. 1. 2. 0. 0. 8. 0. 0. 5. 3. 6.


