# Evaluate

This notebook evaluates the networks trained in the previous notebooks and analyzes the results.

-----

## Load Packages

In [1]:
# Move to the parent directory (presumably the project root, so that the
# `datasets`, `models` and `train` packages imported below can be resolved).
# NOTE(review): `%cd ..` is relative, so re-running this cell without
# restarting the kernel moves one more level up each time.
%cd ..
# Auto-reload external modules so that edits to them are picked up without
# restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd

import pprint
import wandb
import matplotlib
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [3]:
# Report the installed PyTorch build and pick the compute device:
# the GPU when CUDA is usable, otherwise the CPU.
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('cuda is available.' if device.type == 'cuda' else 'cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


-----

## List the models to evaluate

In [4]:
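# Manual alternative (disabled): pin specific run names instead of querying W&B.
# NOTE: the cells below expect each entry to be a dict with 'name' and 'path' keys,
# so bare run names like these would need to be converted first.
# model_pool = ['1sl7ipca', ]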

In [5]:
# Build `model_pool`: one {'name', 'path'} entry for every run of the W&B
# project whose checkpoint file can actually be deserialized.
checkpoint_root = r'E:\CAUEEG\checkpoint'  # machine-specific location of the checkpoints
model_pool = []

api = wandb.Api()
runs = api.runs('ipis-mjkim/caueeg-task2-ablation')

for run in runs:
    path = os.path.join(checkpoint_root, run.name, 'checkpoint.pt')
    try:
        # This is only a readability probe: deserialize onto the CPU and do not
        # keep the result, so no accelerator memory is held by checkpoints that
        # are loaded again (onto `device`) by the evaluation loop later on.
        torch.load(path, map_location='cpu')
    except Exception as e:
        # Best effort: report the unreadable checkpoint and continue with the rest.
        print(e)
        print(f'- {run.name}\'s checkpoint cannot be opened: {path}')
    else:
        model_pool.append({'name': run.name, 'path': path})

pprint.pprint([model_dict['name'] for model_dict in model_pool])

['1242j086',
 'dejxbfxu',
 '2z7is3m0',
 '2h30v12p',
 '3c9rbdxu',
 '2615y2jm',
 '2452kedi',
 '2k8xomy6',
 '2bw00w05',
 '3874otvp']


---

## Configurations

In [6]:
# Base repetition count for the accuracy checks; the evaluation loop divides it
# by each model's config['crop_multiple'] to obtain the `repeat` argument.
base_repeat = 200
# Forwarded to build_dataset_for_train(..., verbose=verbose).
verbose = False

-----

## Load and check accuracy

In [7]:
def _strip_ddp_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from its keys.

    Keys that do not start with ``prefix`` are kept unchanged, so a state dict
    that was saved without the DataParallel/DistributedDataParallel wrapper is
    not corrupted by blindly slicing characters off every key.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# For every model in the pool: restore the network from its checkpoint, rebuild
# the data loaders from the stored config, and record the train / validation /
# test / multi-crop test accuracies in the model's dict.
for model_dict in model_pool:
    # load and parse the checkpoint (optimizer/scheduler states are not needed here)
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']

    # initiate the model
    model = hydra.utils.instantiate(config).to(device)

    if config.get('ddp', False):
        # remove 'module.' of DataParallel/DistributedDataParallel
        model_state = _strip_ddp_prefix(model_state)

    model.load_state_dict(model_state)

    # reconfigure for single-process evaluation on the current device
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_timing_analysis'] = True
    config['eval'] = True
    config['device'] = device
    # keep at least one repetition even when crop_multiple is very large
    repeat = max(1, round(base_repeat / config['crop_multiple']))

    # build dataset (named result instead of `_`, which IPython uses for the last output)
    loaders = build_dataset_for_train(config, verbose=verbose)
    train_loader = loaders[0]
    val_loader = loaders[1]
    test_loader = loaders[2]
    multicrop_test_loader = loaders[3]

    # single-crop accuracy on each split, stored in the same order as before
    single_crop_splits = (
        ('Train Accuracy', train_loader),
        ('Validation Accuracy', val_loader),
        ('Test Accuracy', test_loader),
    )
    for result_key, loader in single_crop_splits:
        model_dict[result_key] = check_accuracy(model, loader,
                                                config['preprocess_test'], config, repeat=repeat)

    # Multi-crop test accuracy
    model_dict['Multi-Crop Test Accuracy'] = check_accuracy_multicrop(
        model, multicrop_test_loader,
        config['preprocess_test'], config, repeat=repeat)

In [8]:
# Show every model's entry, including the accuracies stored by the evaluation loop above.
pprint.pprint(model_pool)

[{'Multi-Crop Test Accuracy': 58.86440677966102,
  'Test Accuracy': 55.69491525423729,
  'Train Accuracy': 100.0,
  'Validation Accuracy': 60.94957983193277,
  'name': '1242j086',
  'path': 'E:\\CAUEEG\\checkpoint\\1242j086\\checkpoint.pt'},
 {'Multi-Crop Test Accuracy': 65.15254237288136,
  'Test Accuracy': 62.65677966101695,
  'Train Accuracy': 100.0,
  'Validation Accuracy': 60.621848739495796,
  'name': 'dejxbfxu',
  'path': 'E:\\CAUEEG\\checkpoint\\dejxbfxu\\checkpoint.pt'},
 {'Multi-Crop Test Accuracy': 61.05084745762712,
  'Test Accuracy': 59.74152542372882,
  'Train Accuracy': 99.98158482142857,
  'Validation Accuracy': 62.13025210084034,
  'name': '2z7is3m0',
  'path': 'E:\\CAUEEG\\checkpoint\\2z7is3m0\\checkpoint.pt'},
 {'Multi-Crop Test Accuracy': 56.220338983050844,
  'Test Accuracy': 55.894067796610166,
  'Train Accuracy': 99.83928571428571,
  'Validation Accuracy': 54.365546218487395,
  'name': '2h30v12p',
  'path': 'E:\\CAUEEG\\checkpoint\\2h30v12p\\checkpoint.pt'},
 {'M

In [10]:
# Copy the settings that differ between runs from each checkpoint's stored
# config into the model's dict, so they appear next to the accuracies.
config_keys = ('awgn', 'awgn_age', 'mgn', 'mixup', 'dropout',
               'fc_stages', 'use_age', 'photic', 'EKG', 'seed')

for model_dict in model_pool:
    # Only the stored config is read here, so deserialize onto the CPU instead
    # of moving the whole checkpoint to the accelerator.
    ckpt = torch.load(model_dict['path'], map_location='cpu')
    config = ckpt['config']

    for key in config_keys:
        model_dict[key] = config[key]

In [12]:
# Tabulate the collected results once, save them to 'temp.csv' in the current
# working directory, then display the table as the cell output.
results_table = pd.DataFrame(model_pool)
results_table.to_csv('temp.csv')
results_table

Unnamed: 0,name,path,Train Accuracy,Validation Accuracy,Test Accuracy,Multi-Crop Test Accuracy,awgn,awgn_age,mgn,mixup,dropout,fc_stages,use_age,photic,EKG,seed
0,1242j086,E:\CAUEEG\checkpoint\1242j086\checkpoint.pt,100.0,60.94958,55.694915,58.864407,0.0,0.0,0.0,0.0,0.0,1,no,X,X,0
1,dejxbfxu,E:\CAUEEG\checkpoint\dejxbfxu\checkpoint.pt,100.0,60.621849,62.65678,65.152542,0.004873,0.035834,0.095756,0.2,0.3,3,conv,O,O,0
2,2z7is3m0,E:\CAUEEG\checkpoint\2z7is3m0\checkpoint.pt,99.981585,62.130252,59.741525,61.050847,0.0,0.0,0.095756,0.2,0.3,3,conv,O,O,1
3,2h30v12p,E:\CAUEEG\checkpoint\2h30v12p\checkpoint.pt,99.839286,54.365546,55.894068,56.220339,0.0,0.0,0.0,0.0,0.0,1,no,O,O,0
4,3c9rbdxu,E:\CAUEEG\checkpoint\3c9rbdxu\checkpoint.pt,100.0,57.163866,65.0,66.983051,0.0,0.0,0.0,0.0,0.0,1,conv,O,O,0
5,2615y2jm,E:\CAUEEG\checkpoint\2615y2jm\checkpoint.pt,99.997768,60.785714,62.894068,63.728814,0.0,0.0,0.0,0.2,0.0,3,conv,O,O,0
6,2452kedi,E:\CAUEEG\checkpoint\2452kedi\checkpoint.pt,100.0,57.470588,64.20339,65.779661,0.0,0.0,0.0,0.0,0.0,3,conv,O,O,0
7,2k8xomy6,E:\CAUEEG\checkpoint\2k8xomy6\checkpoint.pt,100.0,58.478992,63.152542,64.745763,0.0,0.0,0.0,0.0,0.3,3,conv,O,O,0
8,2bw00w05,E:\CAUEEG\checkpoint\2bw00w05\checkpoint.pt,99.829241,62.504202,62.152542,63.542373,0.0,0.0,0.0,0.2,0.3,3,conv,O,O,0
9,3874otvp,E:\CAUEEG\checkpoint\3874otvp\checkpoint.pt,99.999442,59.004202,63.122881,64.474576,0.0,0.0,0.095756,0.2,0.3,3,conv,O,O,0
