# Evaluate

This notebook evaluates the networks trained in the previous notebooks and analyzes the results.

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Change the working directory to the parent of the current one.
# NOTE: the move is relative, so re-running this cell goes up one more level each
# time; run it once per kernel session.
# (presumably needed so the project-local packages imported below resolve — confirm)
%cd ..
# Load the autoreload extension; mode 2 reloads all modules before running code.
%load_ext autoreload
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import sys
import pickle
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd

import pprint
import wandb
import matplotlib
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.evaluate import calculate_class_wise_metrics
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [3]:
# Report the installed PyTorch build, then pick the device used by every later cell:
# CUDA when PyTorch can see a GPU, otherwise the CPU.
print('PyTorch version:', torch.__version__)

if torch.cuda.is_available():
    device = torch.device('cuda')
    print('cuda is available.')
else:
    device = torch.device('cpu')
    print('cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


-----

## List the models to evaluate

In [5]:
# Run IDs whose checkpoints should be evaluated; comment entries in or out as needed.
# dict.fromkeys() removes duplicates while keeping the listed order, so model_pool
# (and the result table built from it) has the same order on every run. A set of
# strings iterates in an order that changes between kernel sessions.
model_names = list(dict.fromkeys([
    '22pj8bg1',
    # '1663fxhy',
    # '3td60j3k',
    # 'umtq5iu6',
    # 'f9g4k64e',
    # 'kxyux4lz',
    # '03iz3rso',
    # '9d9rnxrx',
    # '3smz5tul',
    # '34oksw2q',
    # 'pz76l4fq',
    # 'ocf38th1',
    # 'nd1ss8ci',
    # 'nkrh4pji',
]))

# One dict per usable checkpoint: {'name': run id, 'path': checkpoint file}.
model_pool = []

for model_name in model_names:
    path = os.path.join(r'E:\CAUEEG\checkpoint', model_name, 'checkpoint.pt')
    try:
        # Readability probe only: load onto the CPU and discard the result at once,
        # so no checkpoint stays resident on the evaluation device after this cell.
        torch.load(path, map_location='cpu')
    except Exception as e:
        # Best effort: report the unreadable checkpoint and continue with the rest.
        print(e)
        print(f'- checkpoint cannot be opened: {path}')
    else:
        model_pool.append({'name': model_name, 'path': path})

# Show which runs made it into the pool.
pprint.pprint([model_dict['name'] for model_dict in model_pool])

['22pj8bg1']


In [6]:
# Disabled alternative to the hand-written run list in the previous cell:
# builds model_pool from every run of a W&B project, keeping only the runs whose
# checkpoint file can be opened. Uncomment the whole cell to use it.
# model_pool = []

# api = wandb.Api()
# runs = api.runs('ipis-mjkim/caueeg-task2-ablation')

# for run in runs:
#     path = os.path.join(r'E:\CAUEEG\checkpoint', run.name, 'checkpoint.pt')
#     try:
#         ckpt = torch.load(path, map_location=device)
#         model_pool.append({'name': run.name, 'path': path})
#     except Exception as e:
#         print(e)
#         print(f'- {run.name}\'s checkpoint cannot be opened: {path}')
        
# pprint.pprint([model_dict['name'] for model_dict in model_pool])

---

## Configurations

In [7]:
# Evaluation settings shared by the cells below.

# Repetition budget: the evaluation loop divides it by each model's 'crop_multiple'
# to get the `repeat` argument passed to the accuracy checkers.
base_repeat = 32

# Forwarded to build_dataset_for_train(..., verbose=verbose).
verbose = False

# Dataset location relative to the working directory. os.path.join picks the right
# separator for the running OS; a literal backslash path only resolves on Windows.
dataset_path = os.path.join('local', 'dataset', '02_Curated_Data_220715_seg_30s')

-----

## Load and check accuracy

In [8]:
# Dry run of the evaluation pipeline: for every checkpoint in model_pool, record its
# hyperparameters in model_dict, rebuild the network, load the weights and build the
# data loaders. No accuracy is computed here. After the loop, `model`, `config` and
# the four loaders refer to the last model in the pool.
for model_dict in model_pool:
    # load and parse the checkpoint
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']
    config['dataset_path'] = dataset_path
    
    # model identity and size (size = length of the pickled state dict, in MiB)
    model_dict['model'] = config['model']
    model_dict['num_params'] = config.get('num_params', '???')
    model_dict['model size (MiB)'] = sys.getsizeof(pickle.dumps(model_state)) / (1024 * 1024)
    
    # input-related settings
    model_dict['seq_length'] = config['seq_length']
    model_dict['use_age'] = config['use_age']
    model_dict['photic'] = config['photic']
    model_dict['EKG'] = config['EKG']

    # optional settings; the defaults are used when the config lacks the key
    model_dict['awgn'] = config.get('awgn', 0)
    model_dict['awgn_age'] = config.get('awgn_age', 0)
    model_dict['mgn'] = config.get('mgn', 0)
    model_dict['mixup'] = config.get('mixup', 0)
    model_dict['dropout'] = config.get('dropout', 0)
    model_dict['weight_decay'] = config.get('weight_decay', '???')
    model_dict['fc_stages'] = config.get('fc_stages', 1)
    
    # optimisation settings
    model_dict['minibatch'] = round(config['minibatch'])
    # Use iterations * minibatch only when 'total_samples' is absent. Passing the
    # product as the default of config.get() evaluates it eagerly, which raises
    # KeyError for configs that store 'total_samples' but no 'iterations'.
    if 'total_samples' in config:
        total_samples = config['total_samples']
    else:
        total_samples = config['iterations'] * config['minibatch']
    model_dict['total_samples'] = round(total_samples)
    model_dict['base_lr'] = config.get('base_lr', config.get('LR', '???'))
    model_dict['lr_scheduler_type'] = config.get('lr_scheduler_type', 'constant_with_decay')
    model_dict['warmup_steps'] = config.get('warmup_steps', '???')
    model_dict['seed'] = config.get('seed', '???')
    
    print('- checking for', model_dict['name'], config['model'], '...')
    
    # initiate the model; three config layouts are supported
    if '_target_' in config:
        # hydra-style config: the class to build is named by '_target_'
        model = hydra.utils.instantiate(config).to(device)
    elif type(config['generator']) is str:
        # generator (and optional block) stored as dotted names: resolve the last
        # component against the `models` package
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
        model = config['generator'](**config).to(device)
    else:
        # generator stored as a callable; a block stored as a class is replaced
        # by its short string name before the generator is called
        if 'block' in config:
            if config['block'] == models.resnet_1d.BottleneckBlock1D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_2d.Bottleneck2D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_1d.BasicBlock1D:
                config['block'] = 'basic'
            elif config['block'] == models.resnet_2d.BasicBlock2D:
                config['block'] = 'basic'
                
        model = config['generator'](**config).to(device)
    
    if config.get('ddp', False):
        # Remove the 'module.' prefix of DataParallel/DistributedDataParallel.
        # Keys without the prefix are kept as they are instead of losing their
        # first seven characters, and the tensors are not copied.
        ddp_prefix = 'module.'
        model_state = OrderedDict(
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
            for k, v in model_state.items()
        )
    
    model.load_state_dict(model_state)
    
    # reconfigure and update: overrides applied before the dataset is built
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_timing_analysis'] = False
    config['eval'] = True
    config['device'] = device
    
    repeat = round(base_repeat / config['crop_multiple'])
    model_dict['repeat'] = repeat
    model_dict['crop_multiple'] = config['crop_multiple']
    model_dict['test_crop_multiple'] = config['test_crop_multiple']
    
    # build dataset (named result instead of `_`, which IPython uses for the
    # last cell output)
    loaders = build_dataset_for_train(config, verbose=verbose)
    train_loader = loaders[0]
    val_loader = loaders[1]
    test_loader = loaders[2]
    multicrop_test_loader = loaders[3]

- checking for 22pj8bg1 1D-Linear-SVM ...
47
6
6


In [10]:
# Print the number of samples behind the train, validation and test loaders,
# one value per line and in that order.
for loader in (train_loader, val_loader, test_loader):
    print(len(loader.dataset))

24417
3052
3052


In [7]:
# Full evaluation: for every checkpoint in model_pool, record its hyperparameters,
# rebuild the network, load the weights, build the data loaders and store the
# train / validation / test / multi-crop test results in model_dict (one dict per
# model, later turned into one table row).
for model_dict in model_pool:
    # load and parse the checkpoint
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']
    config['dataset_path'] = dataset_path
    
    # model identity and size (size = length of the pickled state dict, in MiB)
    model_dict['model'] = config['model']
    model_dict['num_params'] = config.get('num_params', '???')
    model_dict['model size (MiB)'] = sys.getsizeof(pickle.dumps(model_state)) / (1024 * 1024)
    
    # input-related settings
    model_dict['seq_length'] = config['seq_length']
    model_dict['use_age'] = config['use_age']
    model_dict['photic'] = config['photic']
    model_dict['EKG'] = config['EKG']

    # optional settings; the defaults are used when the config lacks the key
    model_dict['awgn'] = config.get('awgn', 0)
    model_dict['awgn_age'] = config.get('awgn_age', 0)
    model_dict['mgn'] = config.get('mgn', 0)
    model_dict['mixup'] = config.get('mixup', 0)
    model_dict['dropout'] = config.get('dropout', 0)
    model_dict['weight_decay'] = config.get('weight_decay', '???')
    model_dict['fc_stages'] = config.get('fc_stages', 1)
    model_dict['activation'] = config.get('activation', 0)
    
    # optimisation settings
    model_dict['minibatch'] = round(config['minibatch'])
    # Use iterations * minibatch only when 'total_samples' is absent. Passing the
    # product as the default of config.get() evaluates it eagerly, which raises
    # KeyError for configs that store 'total_samples' but no 'iterations'.
    if 'total_samples' in config:
        total_samples = config['total_samples']
    else:
        total_samples = config['iterations'] * config['minibatch']
    model_dict['total_samples'] = round(total_samples)
    model_dict['base_lr'] = config.get('base_lr', config.get('LR', '???'))
    model_dict['lr_scheduler_type'] = config.get('lr_scheduler_type', 'constant_with_decay')
    model_dict['warmup_steps'] = config.get('warmup_steps', '???')
    model_dict['seed'] = config.get('seed', '???')
    
    print('- checking for', model_dict['name'], config['model'], '...')
    
    # initiate the model; three config layouts are supported
    if '_target_' in config:
        # hydra-style config: the class to build is named by '_target_'
        model = hydra.utils.instantiate(config).to(device)
    elif type(config['generator']) is str:
        # generator (and optional block) stored as dotted names: resolve the last
        # component against the `models` package
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
        model = config['generator'](**config).to(device)
    else:
        # generator stored as a callable; a block stored as a class is replaced
        # by its short string name before the generator is called
        if 'block' in config:
            if config['block'] == models.resnet_1d.BottleneckBlock1D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_2d.Bottleneck2D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_1d.BasicBlock1D:
                config['block'] = 'basic'
            elif config['block'] == models.resnet_2d.BasicBlock2D:
                config['block'] = 'basic'
                
        model = config['generator'](**config).to(device)
    
    if config.get('ddp', False):
        # Remove the 'module.' prefix of DataParallel/DistributedDataParallel.
        # Keys without the prefix are kept as they are instead of losing their
        # first seven characters, and the tensors are not copied.
        ddp_prefix = 'module.'
        model_state = OrderedDict(
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
            for k, v in model_state.items()
        )
    
    model.load_state_dict(model_state)
    
    # reconfigure and update: overrides applied before the dataset is built
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_timing_analysis'] = False
    config['eval'] = True
    config['device'] = device
    
    repeat = round(base_repeat / config['crop_multiple'])
    model_dict['repeat'] = repeat
    model_dict['crop_multiple'] = config['crop_multiple']
    model_dict['test_crop_multiple'] = config['test_crop_multiple']
    
    # build dataset (named results instead of `_`, which IPython uses for the
    # last cell output)
    loaders = build_dataset_for_train(config, verbose=verbose)
    train_loader = loaders[0]
    val_loader = loaders[1]
    test_loader = loaders[2]
    multicrop_test_loader = loaders[3]
    
    # train accuracy
    train_acc = check_accuracy(model, train_loader, 
                               config['preprocess_test'], config, repeat=repeat)
    model_dict['Train Accuracy'] = train_acc
    
    # val accuracy
    val_acc = check_accuracy(model, val_loader, 
                             config['preprocess_test'], config, repeat=repeat)
    model_dict['Validation Accuracy'] = val_acc
    
    # Test accuracy: element 0 is stored as the accuracy, element 4 as the
    # throughput; element 3 feeds the class-wise metrics
    # (presumably a confusion matrix — confirm against train.evaluate)
    test_result = check_accuracy_extended(model, test_loader, 
                                          config['preprocess_test'], config, repeat=repeat)
    model_dict['Test Throughput'] = test_result[4]
    model_dict['Test Accuracy'] = test_result[0]
    test_class_wise_metrics = calculate_class_wise_metrics(test_result[3])
    
    # one column per (metric, class) pair, e.g. 'F1-score (<class name>)'
    for k, v in test_class_wise_metrics.items():
        for c in range(config['out_dims']):
            c_name = config['class_label_to_name'][c]
            model_dict[f'{k} ({c_name})'] = v[c]
    
    # Multi-crop test accuracy: same result layout as the single-crop test above
    multi_test_result = check_accuracy_multicrop_extended(model, multicrop_test_loader, 
                                                          config['preprocess_test'], config, repeat=repeat)
    model_dict['Multi-Crop Test Throughput'] = multi_test_result[4]
    model_dict['Multi-Crop Test Accuracy'] = multi_test_result[0]
    multi_test_class_wise_metrics = calculate_class_wise_metrics(multi_test_result[3])
    
    for k, v in multi_test_class_wise_metrics.items():
        for c in range(config['out_dims']):
            c_name = config['class_label_to_name'][c]
            model_dict[f'Multi-Crop {k} ({c_name})'] = v[c]
            
print('==== Finished ====')

- checking for 3smz5tul Ieracitano-CNN ...
- checking for 9d9rnxrx 2D-ResNet-18 ...
- checking for umtq5iu6 2D-ResNeXt-50 ...
- checking for kxyux4lz 1D-ResNet-101 ...
- checking for 3td60j3k 2D-ViT-B-16 ...
- checking for 22pj8bg1 1D-Linear-SVM ...
- checking for ocf38th1 1D-ResNet-50 ...
- checking for pz76l4fq 2D-VGG-19 ...
- checking for nd1ss8ci 1D-ResNet-18 ...
- checking for nkrh4pji 1D-VGG-19 ...
- checking for f9g4k64e 1D-ResNeXt-50 ...
- checking for 03iz3rso 2D-ResNet-50 ...
- checking for 1663fxhy 2D-ViT-B-16 ...
- checking for 34oksw2q Ieracitano-CNN ...
==== Finished ====


In [8]:
# Pretty-print every field collected for every evaluated model (default formatting).
pprint.PrettyPrinter().pprint(model_pool)

[{'Accuracy (Dementia)': 0.864474115334207,
  'Accuracy (MCI)': 0.7742566349934469,
  'Accuracy (Normal)': 0.8196162352555701,
  'EKG': 'X',
  'F1-score (Dementia)': 0.74906154021158,
  'F1-score (MCI)': 0.6952097877929081,
  'F1-score (Normal)': 0.7492134895440374,
  'Multi-Crop Accuracy (Dementia)': 0.8632044560943644,
  'Multi-Crop Accuracy (MCI)': 0.7747276376146789,
  'Multi-Crop Accuracy (Normal)': 0.8201486730013107,
  'Multi-Crop F1-score (Dementia)': 0.7474098162292975,
  'Multi-Crop F1-score (MCI)': 0.6956016436764116,
  'Multi-Crop F1-score (Normal)': 0.7496401031941734,
  'Multi-Crop Precision (Dementia)': 0.7770875923887404,
  'Multi-Crop Precision (MCI)': 0.6978706865439605,
  'Multi-Crop Precision (Normal)': 0.7262960201066092,
  'Multi-Crop Recall (Dementia)': 0.7199155011655012,
  'Multi-Crop Recall (MCI)': 0.693347308031774,
  'Multi-Crop Recall (Normal)': 0.7745346371347785,
  'Multi-Crop Sensitivity (Dementia)': 0.7199155011655012,
  'Multi-Crop Sensitivity (MCI)': 

In [9]:
# Show the collected results as a table: one row per model, one column per field.
pd.DataFrame(data=model_pool)

Unnamed: 0,name,path,model,num_params,model size (MiB),seq_length,use_age,photic,EKG,awgn,...,Multi-Crop Specificity (Dementia),Multi-Crop Precision (Normal),Multi-Crop Precision (MCI),Multi-Crop Precision (Dementia),Multi-Crop Recall (Normal),Multi-Crop Recall (MCI),Multi-Crop Recall (Dementia),Multi-Crop F1-score (Normal),Multi-Crop F1-score (MCI),Multi-Crop F1-score (Dementia)
0,3smz5tul,E:\CAUEEG\checkpoint\3smz5tul\checkpoint.pt,Ieracitano-CNN,3457995,13.19866,1000,no,X,X,0.0,...,0.91924,0.726296,0.697871,0.777088,0.774535,0.693347,0.719916,0.74964,0.695602,0.74741
1,9d9rnxrx,E:\CAUEEG\checkpoint\9d9rnxrx\checkpoint.pt,2D-ResNet-18,11466947,43.826613,2000,conv,O,O,0.05,...,1.0,1.0,0.999669,1.0,0.999764,1.0,0.999854,0.999882,0.999835,0.999927
2,umtq5iu6,E:\CAUEEG\checkpoint\umtq5iu6\checkpoint.pt,2D-ResNeXt-50,25731395,98.538363,2000,conv,O,O,0.05,...,0.999487,0.999294,1.0,0.998691,1.0,0.998345,1.0,0.999647,0.999172,0.999345
3,kxyux4lz,E:\CAUEEG\checkpoint\kxyux4lz\checkpoint.pt,1D-ResNet-101,45174531,172.949571,2000,conv,O,O,0.05,...,0.998804,0.999293,0.998895,0.996948,0.998822,0.997573,0.999272,0.999057,0.998234,0.998108
4,3td60j3k,E:\CAUEEG\checkpoint\3td60j3k\checkpoint.pt,2D-ViT-B-16,86272899,329.164175,2000,conv,O,O,0.05,...,0.990656,0.984726,0.98889,0.976246,0.987394,0.982017,0.981935,0.986058,0.985441,0.979082
5,22pj8bg1,E:\CAUEEG\checkpoint\22pj8bg1\checkpoint.pt,1D-Linear-SVM,63006,0.241155,1000,fc,O,O,0.0,...,0.76863,0.533868,0.411436,0.36223,0.558524,0.416179,0.336029,0.545918,0.413794,0.348638
6,ocf38th1,E:\CAUEEG\checkpoint\ocf38th1\checkpoint.pt,1D-ResNet-50,26182403,100.200901,2000,conv,O,O,0.05,...,0.99755,0.992952,0.998004,0.993755,0.995877,0.992829,0.996941,0.994412,0.99541,0.995345
7,pz76l4fq,E:\CAUEEG\checkpoint\pz76l4fq\checkpoint.pt,2D-VGG-19,20217923,77.205525,2000,conv,O,O,0.05,...,0.999601,0.994471,0.992849,0.998974,0.995994,0.995697,0.993298,0.995232,0.994271,0.996128
8,nd1ss8ci,E:\CAUEEG\checkpoint\nd1ss8ci\checkpoint.pt,1D-ResNet-18,11394051,43.550922,2000,conv,O,O,0.05,...,0.998234,0.99495,0.996456,0.995488,0.998115,0.992718,0.996503,0.99653,0.994584,0.995996
9,nkrh4pji,E:\CAUEEG\checkpoint\nkrh4pji\checkpoint.pt,1D-VGG-19,20205827,77.159308,2000,conv,O,O,0.05,...,0.998234,0.994248,0.997451,0.99549,0.997879,0.993049,0.996795,0.99606,0.995245,0.996142


In [10]:
# Save the result table as CSV. The output directory is created first because
# DataFrame.to_csv() does not create missing parent directories and would
# otherwise fail after the whole evaluation has already run.
output_path = 'local/output/caueeg-task2-segmented.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
pd.DataFrame(model_pool).to_csv(output_path)