# Evaluate

This notebook evaluates the networks trained in the previous notebooks and analyzes the results.

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Move to the parent directory (the printed path below is the resulting working
# directory), presumably so the local `datasets`, `models` and `train` packages
# imported in the next cell can be found.
# NOTE: `%cd ..` is relative, so re-running this cell climbs one more directory.
%cd ..
%load_ext autoreload
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import sys
import pickle
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd

import pprint
import wandb
import matplotlib
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_and_throughput
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [3]:
# Report the installed PyTorch build and pick the evaluation device:
# the GPU when CUDA is usable, otherwise the CPU.
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('cuda is available.' if device.type == 'cuda' else 'cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


-----

## List the models to evaluate

In [4]:
# Checkpoints to evaluate, identified by run name (a directory under the checkpoint root).
# Smaller subset for quick checks:
# model_names = ['cix5o8ut', '76efnucq', '6a1tnz3r']
model_names = ['lo88puq7',
               'glem9euv',
               'ehh4f6gj',
               '1sl7ipca',
               'gefkh07l',
               'v301o425',
               'l8524nml',
               't1xzqrjh',
               'ikik8fuj',
               'gvqyvmrj',
               'j5ybdq18',
               'gchbqf3f',
               'gexwncpq',
               '2svr4k2h',
               'vuk1myed',
               'cix5o8ut',
               'cgszsx9e',
               'i8bx2s6r',
               'xci5svkl',
               '6a1tnz3r',
               'r7fr150u',
               '1tdyketc',
               '4lv906r9',
               'gjkysllw',
               't0j6ipeb',
               'd37sa3wb',
               'z6yuten2',]
model_pool = []

for model_name in model_names:
    path = os.path.join(r'E:\CAUEEG\checkpoint', model_name, 'checkpoint.pt')
    try:
        # Load once only to confirm the checkpoint is readable; it is reloaded by the
        # evaluation cell, so release it right away instead of keeping it on `device`.
        ckpt = torch.load(path, map_location=device)
        del ckpt
        model_pool.append({'name': model_name, 'path': path})
    except Exception as e:
        # Best-effort: report and skip unreadable checkpoints rather than abort the cell.
        # (Uses `model_name`; `run` only exists in the W&B variant of this cell.)
        print(e)
        print(f'- {model_name}\'s checkpoint cannot be opened: {path}')

pprint.pprint([model_dict['name'] for model_dict in model_pool])

['lo88puq7',
 'glem9euv',
 'ehh4f6gj',
 '1sl7ipca',
 'gefkh07l',
 'v301o425',
 'l8524nml',
 't1xzqrjh',
 'ikik8fuj',
 'gvqyvmrj',
 'j5ybdq18',
 'gchbqf3f',
 'gexwncpq',
 '2svr4k2h',
 'vuk1myed',
 'cix5o8ut',
 'cgszsx9e',
 'i8bx2s6r',
 'xci5svkl',
 '6a1tnz3r',
 'r7fr150u',
 '1tdyketc',
 '4lv906r9',
 'gjkysllw',
 't0j6ipeb',
 'd37sa3wb',
 'z6yuten2']


In [5]:
# Alternative source for `model_pool`: enumerate every run of the W&B project
# 'ipis-mjkim/caueeg-task2-ablation' instead of the hard-coded `model_names` list.
# Uncomment to use; it rebuilds `model_pool` from scratch.
# model_pool = []

# api = wandb.Api()
# runs = api.runs('ipis-mjkim/caueeg-task2-ablation')

# for run in runs:
#     path = os.path.join(r'E:\CAUEEG\checkpoint', run.name, 'checkpoint.pt')
#     try:
#         ckpt = torch.load(path, map_location=device)
#         model_pool.append({'name': run.name, 'path': path})
#     except Exception as e:
#         print(e)
#         print(f'- {run.name}\'s checkpoint cannot be opened: {path}')
        
# pprint.pprint([model_dict['name'] for model_dict in model_pool])

---

## Configurations

In [6]:
# Evaluation settings shared by every model in the pool.
verbose = False    # passed through to build_dataset_for_train
base_repeat = 200  # divided by each model's crop_multiple to get its `repeat` count

-----

## Load each model and check its accuracy

In [7]:
def _strip_ddp_prefix(state_dict):
    """Return ``state_dict`` with the ``'module.'`` wrapper prefix removed from its keys.

    Checkpoints saved from a DataParallel/DistributedDataParallel-wrapped model store
    parameters as ``module.<name>``. Only keys that actually carry the prefix are
    rewritten, so keys without it are never truncated. Tensors are shared with the
    input rather than deep-copied; ``load_state_dict`` copies them into the model.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _build_model(config):
    """Instantiate the network described by a checkpoint ``config`` and move it to ``device``.

    Three config layouts are handled:
      * Hydra-style configs containing ``_target_`` (built via ``hydra.utils.instantiate``);
      * configs whose ``generator`` (and optional ``block``) are dotted-name strings,
        resolved by their last component against the project ``models`` package;
      * configs holding the generator class itself, where known ``block`` classes are
        mapped back to the ``'bottleneck'`` / ``'basic'`` names.

    In the last two cases every entry of ``config`` is passed to the generator as a
    keyword argument. ``config`` is modified in place (``generator`` / ``block``).
    """
    if '_target_' in config:
        return hydra.utils.instantiate(config).to(device)

    if isinstance(config['generator'], str):
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
    elif 'block' in config:
        if config['block'] == models.resnet_1d.BottleneckBlock1D:
            config['block'] = 'bottleneck'
        elif config['block'] == models.resnet_2d.Bottleneck2D:
            config['block'] = 'bottleneck'
        elif config['block'] == models.resnet_1d.BasicBlock1D:
            config['block'] = 'basic'
        elif config['block'] == models.resnet_2d.BasicBlock2D:
            config['block'] = 'basic'

    return config['generator'](**config).to(device)


for model_dict in model_pool:
    # Load and parse the checkpoint. Optimizer/scheduler states are not needed for
    # evaluation, so they are not read (checkpoints lacking them still work).
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    # `config` is the very same dict object as ckpt['config']: all changes below land there.
    config = ckpt['config']

    print('- checking for', model_dict['name'], config['model'], '...')

    # initiate the model and load its weights
    model = _build_model(config)
    if config.get('ddp', False):
        model_state = _strip_ddp_prefix(model_state)
    model.load_state_dict(model_state)

    # reconfigure for single-process evaluation
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_timing_analysis'] = False
    config['eval'] = True
    config['device'] = device
    # scale the repeat count down by crop_multiple (presumably to keep the number of
    # evaluated crops comparable across models)
    repeat = round(base_repeat / config['crop_multiple'])

    # build dataset; only the first four returned loaders are used here
    loaders = build_dataset_for_train(config, verbose=verbose)
    train_loader = loaders[0]
    val_loader = loaders[1]
    test_loader = loaders[2]
    multicrop_test_loader = loaders[3]

    # train accuracy
    model_dict['Train Accuracy'] = check_accuracy(
        model, train_loader, config['preprocess_test'], config, repeat=repeat)

    # validation accuracy
    model_dict['Validation Accuracy'] = check_accuracy(
        model, val_loader, config['preprocess_test'], config, repeat=repeat)

    # test accuracy and throughput (first two entries of the returned sequence)
    test_result = check_accuracy_and_throughput(
        model, test_loader, config['preprocess_test'], config, repeat=repeat)
    model_dict['Test Accuracy'] = test_result[0]
    throughput = test_result[1]

    # multi-crop test accuracy
    model_dict['Multi-Crop Test Accuracy'] = check_accuracy_multicrop(
        model, multicrop_test_loader, config['preprocess_test'], config, repeat=repeat)
    # stored last so the column order of the summary table stays the same
    model_dict['Throughput'] = throughput

- checking for lo88puq7 2D-VGG-19 ...
- checking for glem9euv 2D-VGG-19 ...
- checking for ehh4f6gj 1D-ResNet-18 ...
- checking for 1sl7ipca 2D-ResNeXt-50 ...
- checking for gefkh07l 2D-ResNeXt-50 ...
- checking for v301o425 1D-ResNeXt-50 ...
- checking for l8524nml 1D-ResNet-18 ...
- checking for t1xzqrjh 1D-ResNet-18 ...
- checking for ikik8fuj 1D-ResNet-18 ...
- checking for gvqyvmrj 1D-ResNet-50 ...
- checking for j5ybdq18 1D-ResNet-18 ...
- checking for gchbqf3f 2D-ResNeXt-50 ...
- checking for gexwncpq 1D-ResNet-18 ...
- checking for 2svr4k2h 1D-ResNet-18 ...
- checking for vuk1myed 1D-ResNet-18 ...
- checking for cix5o8ut 1D-ResNet-18 ...
- checking for cgszsx9e 1D-ResNet-50 ...
- checking for i8bx2s6r 1D-ResNet-18 ...
- checking for xci5svkl 2D-ResNet-18 ...
- checking for 6a1tnz3r 2D-ResNeXt-50 ...
- checking for r7fr150u 1D-ResNet-18 ...
- checking for 1tdyketc 2D-ResNet-18 ...
- checking for 4lv906r9 1D-ResNet-18 ...
- checking for gjkysllw 2D-ViT-B-16 ...
- checking for t0j

In [8]:
# Dump the collected per-model results (keys sorted, as pprint formats dicts).
print(pprint.pformat(model_pool))

[{'Multi-Crop Test Accuracy': 69.91525423728814,
  'Test Accuracy': 68.27542372881356,
  'Throughput': 0.002077540856598075,
  'Train Accuracy': 100.0,
  'Validation Accuracy': 64.65126050420169,
  'name': 'lo88puq7',
  'path': 'E:\\CAUEEG\\checkpoint\\lo88puq7\\checkpoint.pt'},
 {'Multi-Crop Test Accuracy': 69.33898305084746,
  'Test Accuracy': 63.74576271186441,
  'Throughput': 0.0029901065251292697,
  'Train Accuracy': 92.52473684210527,
  'Validation Accuracy': 60.945378151260506,
  'name': 'glem9euv',
  'path': 'E:\\CAUEEG\\checkpoint\\glem9euv\\checkpoint.pt'},
 {'Multi-Crop Test Accuracy': 68.35593220338983,
  'Test Accuracy': 66.52966101694915,
  'Throughput': 0.005824099987002574,
  'Train Accuracy': 99.99720982142857,
  'Validation Accuracy': 59.49159663865546,
  'name': 'ehh4f6gj',
  'path': 'E:\\CAUEEG\\checkpoint\\ehh4f6gj\\checkpoint.pt'},
 {'Multi-Crop Test Accuracy': 68.03389830508475,
  'Test Accuracy': 64.70338983050847,
  'Throughput': 0.001304019380401496,
  'Train 

In [13]:
# Copy model and training hyper-parameters from each checkpoint's stored config into
# `model_pool`, so they appear as columns of the summary table.
# '???' is a placeholder for keys absent from a checkpoint's config.
for model_dict in model_pool:
    # load and parse the checkpoint
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']

    model_dict['model'] = config['model']
    model_dict['num_params'] = config.get('num_params', '???')
    # Size of the pickled state dict in MiB. len() counts the serialized bytes;
    # sys.getsizeof would also include the bytes object's own header.
    model_dict['model size (MiB)'] = len(pickle.dumps(model_state)) / (1024 * 1024)

    model_dict['seq_length'] = config['seq_length']
    model_dict['use_age'] = config['use_age']
    model_dict['photic'] = config['photic']
    model_dict['EKG'] = config['EKG']

    model_dict['awgn'] = config.get('awgn', 0)
    model_dict['awgn_age'] = config.get('awgn_age', 0)
    model_dict['mgn'] = config.get('mgn', 0)
    model_dict['mixup'] = config.get('mixup', 0)
    model_dict['dropout'] = config.get('dropout', 0)
    model_dict['weight_decay'] = config.get('weight_decay', '???')
    model_dict['fc_stages'] = config.get('fc_stages', 1)

    model_dict['minibatch'] = round(config['minibatch'])
    # Fall back to iterations * minibatch only when 'total_samples' is absent. Passing the
    # product as a .get() default evaluated it eagerly, raising KeyError for configs
    # that have 'total_samples' but no 'iterations'.
    model_dict['total_samples'] = round(
        config['total_samples'] if 'total_samples' in config
        else config['iterations'] * config['minibatch'])
    model_dict['base_lr'] = config.get('base_lr', config.get('LR', '???'))
    model_dict['lr_scheduler_type'] = config.get('lr_scheduler_type', 'constant_with_decay')
    model_dict['warmup_steps'] = config.get('warmup_steps', '???')
    model_dict['seed'] = config.get('seed', '???')

In [14]:
# Summary table: one row per evaluated checkpoint, one column per collected field.
# NOTE(review): 'Throughput' in the displayed table is 1e6x the value printed by the
# In [8] cell (e.g. 2077.54 vs 0.00208 for lo88puq7), and execution counts jump from
# 8 to 13, so a since-removed cell apparently rescaled it. A fresh Restart & Run All
# will not reproduce this column as shown. TODO: restore that conversion step.
pd.DataFrame(data=model_pool)

Unnamed: 0,name,path,Train Accuracy,Validation Accuracy,Test Accuracy,Multi-Crop Test Accuracy,Throughput,model,num_params,model size (MiB),...,mixup,dropout,weight_decay,fc_stages,minibatch,total_samples,base_lr,lr_scheduler_type,warmup_steps,seed
0,lo88puq7,E:\CAUEEG\checkpoint\lo88puq7\checkpoint.pt,100.0,64.651261,68.275424,69.915254,2077.540857,2D-VGG-19,20184131,77.073749,...,0.3,0.201491,0.012529,2,128,100000000,8.5e-05,constant_with_decay,39062,0
1,glem9euv,E:\CAUEEG\checkpoint\glem9euv\checkpoint.pt,92.524737,60.945378,63.745763,69.338983,2990.106525,2D-VGG-19,20225763,77.23999,...,0.2,0.262623,0.001386,5,192,31999872,8e-05,constant_with_twice_decay,8333,???
2,ehh4f6gj,E:\CAUEEG\checkpoint\ehh4f6gj\checkpoint.pt,99.99721,59.491597,66.529661,68.355932,5824.099987,1D-ResNet-18,11358787,43.413528,...,0.0,0.3,0.007744,2,256,100000000,0.000121,linear_decay_with_warmup,19531,0
3,1sl7ipca,E:\CAUEEG\checkpoint\1sl7ipca\checkpoint.pt,97.394211,63.647059,64.70339,68.033898,1304.01938,2D-ResNeXt-50,25886467,99.1367,...,0.0,0.041975,0.015874,5,192,31999872,0.00253,constant_with_twice_decay,8333,???
4,gefkh07l,E:\CAUEEG\checkpoint\gefkh07l\checkpoint.pt,99.601695,61.97479,65.508475,67.033898,1845.540406,2D-ResNeXt-50,25199235,96.502495,...,0.2,0.116098,0.022836,2,64,100000000,0.000899,cosine_decay_with_warmup_one_and_half,78125,0
5,v301o425,E:\CAUEEG\checkpoint\v301o425\checkpoint.pt,99.729911,65.147059,64.949153,68.661017,5733.165558,1D-ResNeXt-50,25650051,98.227805,...,0.1,0.272182,0.048178,3,256,50000000,0.001897,cosine_decay_with_warmup_half,9766,0
6,l8524nml,E:\CAUEEG\checkpoint\l8524nml\checkpoint.pt,99.412388,61.516807,65.665254,67.322034,11142.824765,1D-ResNet-18,11394051,43.550922,...,0.2,0.3,0.043947,3,256,100000000,0.000469,cosine_decay_with_warmup_half,19531,0
7,t1xzqrjh,E:\CAUEEG\checkpoint\t1xzqrjh\checkpoint.pt,99.982701,61.037815,65.411017,66.79661,9256.958633,1D-ResNet-18,11227971,42.910619,...,0.0,0.3,0.049071,1,256,100000000,0.000238,cosine_decay_with_warmup_one_and_half,19531,0
8,ikik8fuj,E:\CAUEEG\checkpoint\ikik8fuj\checkpoint.pt,100.0,56.546218,64.855932,66.881356,6002.273648,1D-ResNet-18,11391427,43.540937,...,0.0,0.3,0.037617,3,256,100000000,0.000238,linear_decay_with_warmup,19531,0
9,gvqyvmrj,E:\CAUEEG\checkpoint\gvqyvmrj\checkpoint.pt,99.289474,60.768908,61.360169,67.101695,8805.0907,1D-ResNet-50,26178179,100.184862,...,0.2,0.268876,4e-05,3,192,31999872,0.001423,constant_with_twice_decay,8333,???


In [15]:
# Persist the summary table. The path is relative to the working directory chosen by
# `%cd ..` in the first cell; create the output folder so to_csv does not fail when it
# does not exist yet (e.g. on a fresh checkout).
summary_csv_path = 'local/output/caueeg_task2_summary.csv'
os.makedirs(os.path.dirname(summary_csv_path), exist_ok=True)
pd.DataFrame(model_pool).to_csv(summary_csv_path)