# Evaluate

This notebook evaluates the network trained in the previous notebooks and analyzes the results.

-----

## Load Packages

In [1]:
# Step up one directory (the printed path below is the resulting working directory),
# presumably so that the project packages imported in the next cell can be found — TODO confirm.
# NOTE: `%cd ..` is relative, so re-running this cell moves up one more directory each time.
# The autoreload extension is then enabled for auto-reloading external modules;
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%cd ..
%load_ext autoreload
%autoreload 2

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pprint
import matplotlib
import matplotlib.pyplot as plt

# custom package
from datasets.caueeg_script import build_dataset_for_train
import models
from train.evaluate import check_accuracy
from train.evaluate import check_accuracy_extended
from train.evaluate import check_accuracy_extended_debug
from train.evaluate import check_accuracy_multicrop
from train.evaluate import check_accuracy_multicrop_extended
from train.visualize import draw_roc_curve
from train.visualize import draw_confusion
from train.visualize import draw_class_wise_metrics
from train.visualize import draw_error_table
from train.visualize import annotate_heatmap

In [3]:
# Report the installed PyTorch build, then pick the compute device:
# the GPU when CUDA is usable, the CPU otherwise.
print('PyTorch version:', torch.__version__)

if torch.cuda.is_available():
    device = torch.device('cuda')
    print('cuda is available.')
else:
    device = torch.device('cpu')
    print('cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


-----

## Load the configuration used during the training phase

In [4]:
# Identifier of the run whose checkpoint is evaluated in this notebook.
model_name = 'lo88puq7'
# NOTE(review): absolute, machine-specific checkpoint location; adapt it when running elsewhere.
model_path = os.path.join(r'E:\CAUEEG\checkpoint', model_name, 'checkpoint.pt')

# torch.load unpickles the file, so only load checkpoints that come from a trusted source.
# map_location remaps the stored tensors onto the device selected earlier.
ckpt = torch.load(model_path, map_location=device)
print(ckpt.keys())

dict_keys(['model_state', 'config', 'optimizer_state', 'scheduler_state'])


In [5]:
# Pull the four checkpoint entries out into notebook-level names.
# `optimizer` and `scheduler` receive whatever was stored under 'optimizer_state' and
# 'scheduler_state' (saved state entries, not freshly constructed objects).
model_state, config, optimizer, scheduler = (
    ckpt[entry] for entry in ('model_state', 'config', 'optimizer_state', 'scheduler_state')
)

In [6]:
# Pretty-print the configuration stored in the checkpoint; width=250 allows output lines
# of up to 250 characters before pprint wraps an entry.
pprint.pprint(config, width=250)

{'EKG': 'O',
 '_target_': 'models.vgg_2d.VGG2D',
 'activation': 'gelu',
 'age_mean': tensor([71.2425], device='cuda:0'),
 'age_std': tensor([9.5208], device='cuda:0'),
 'awgn': 0.0479404157791391,
 'awgn_age': 0.031086177931273194,
 'base_lr': 8.479608095398431e-05,
 'batch_norm': True,
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'criterion': 'multi-bce',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'cwd': '/home/minjae/Desktop/eeg_analysis',
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda'),
 'draw_result': True,
 'dropout': 0.20149124331753415,
 'fc_stages': 2,
 'file_format': 'memmap',
 'in_channels': 40,
 'input_norm': 'dataset',
 'iterations': 781250,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'constant_with_decay',
 'mgn': 0.003676528367854981,
 'minibatch': 128,
 'mixup': 0.3,
 'model': '2

-----

## Load the target model

In [7]:
# Rebuild the network from the stored configuration (Hydra instantiates the class named
# by config['_target_']) and move it to the selected device.
model = hydra.utils.instantiate(config).to(device)

if config.get('ddp', False):
    # A model saved through a DataParallel/DistributedDataParallel wrapper has every
    # parameter name prefixed with 'module.'. Strip the prefix only where it is actually
    # present, so that keys saved without it are not truncated by a blind slice.
    ddp_prefix = 'module.'
    model_state = OrderedDict(
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
        for k, v in model_state.items()
    )

# Kept as the cell's last expression so that the key-matching report is displayed.
model.load_state_dict(model_state)

<All keys matched successfully>

-----

## Evaluate the model and analyze the performance by the crop timing

### Configurations

In [8]:
# Start from the configuration stored in the checkpoint and adapt it for evaluation.
config = ckpt['config']

# Drop the stored 'cwd' entry (no-op when the key is absent).
config.pop('cwd', 0)

# Evaluation-time overrides, applied in the listed order.
# Alternatives kept for reference (TODO):
#   config['seq_length'] = 16000
#   config['test_crop_multiple'] = 32
eval_overrides = (
    ('ddp', False),
    ('crop_timing_analysis', True),
    ('eval', True),
    ('device', device),
    ('crop_multiple', 128),  # TODO
    ('mgn', 0),
    ('awgn', 0),
)
for key, value in eval_overrides:
    config[key] = value

# Number of evaluation passes, chosen so that repeat * crop_multiple is about 512.
repeat = round(512 / config['crop_multiple'])
print(repeat)

4


### Build Dataset

In [9]:
# Build the train/validation/test loaders plus the multi-crop test loader from the
# evaluation configuration; verbose=True prints the transform and task summary shown below.
# NOTE(review): later cells read config['preprocess_test'] and config['preprocess_train'];
# presumably this call adds them to `config` — confirm in datasets.caueeg_script.
train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_train(config, verbose=True)

transform: Compose(
    EegRandomCrop(crop_length=4000, length_limit=10000000, multiple=128, latency=2000, return_timing=True)
    EegDropChannels(drop_index=[20])
    EegToTensor()
)

----------------------------------------------------------------------------------------------------

transform_multicrop: Compose(
    EegRandomCrop(crop_length=4000, length_limit=10000000, multiple=8, latency=2000, return_timing=True)
    EegDropChannels(drop_index=[20])
    EegToTensor()
)

----------------------------------------------------------------------------------------------------


task config:
{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-Dementia benchmark'}

 ---------------------------------------------------------------------------------------------------- 

train_dataset[0].keys():
dict

In [10]:
from datasets.pipeline import EegAgeBias

# Test-time preprocessing with an EegAgeBias(-5.0) stage placed in front of the regular stages.
preprocess_test_age_bias = nn.Sequential(EegAgeBias(-5.0), *config['preprocess_test'])
print(preprocess_test_age_bias)

Sequential(
  (0): EegAgeBias(bias=-5.0)
  (1): EegToDevice(device=cuda)
  (2): EegNormalizeAge(mean=tensor([71.2425], device='cuda:0'),std=tensor([9.5208], device='cuda:0'),eps=1e-08)
  (3): EegNormalizeMeanStd(mean=tensor([ 0.0794, -0.0131, -0.0121, -0.0244,  0.0004,  0.0295,  0.0506,  0.0152,
          -0.0003,  0.0337, -0.0050, -0.0096, -0.0336,  0.0111,  0.0077,  0.0076,
           0.0110,  0.0071, -0.0213,  0.0056], device='cuda:0'),std=tensor([49.7625, 21.7403, 12.5883, 12.5998, 17.0438, 53.6425, 21.1902, 11.1606,
          12.5307, 16.9399, 21.8603, 15.2040, 14.7430, 23.0870, 18.4153, 15.7697,
          20.6772, 12.2012, 12.5872, 94.9068], device='cuda:0'),eps=1e-08)
  (4): EegSpectrogram(n_fft=179, complex_mode=as_real, stft_kwargs={'hop_length': 45})
  (5): EegNormalizeMeanStd(mean=tensor([[ 8.8867e-02, -5.7190e-03, -1.6277e-03,  ...,  2.6429e-04,
           -6.7322e-04, -3.5536e-04],
          [ 1.6543e-01, -4.2610e-03, -1.3435e-03,  ...,  1.1995e-04,
           -7.5047e-04,

In [11]:
# Print the 'age' field of the first test batch before and after preprocessing:
# first with the age-bias pipeline, then with the regular test pipeline for comparison.
for position, pipeline in enumerate((preprocess_test_age_bias, config['preprocess_test'])):
    if position:
        print('-----')
    for sample_batched in test_loader:
        print(sample_batched['age'])
        pipeline(sample_batched)  # the batch is updated in place, as the second print shows
        print(sample_batched['age'])
        break

tensor([62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62.])
tensor([-1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959,
        -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959,
        -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959, -1.4959,
        -1.4959, -1.4959, -1.4959, -1.4

In [12]:
from datasets.pipeline import EegAgeZero

# Test-time preprocessing with an EegAgeZero(-1.0) stage appended after the regular stages.
preprocess_test_age_zero = nn.Sequential(*config['preprocess_test'], EegAgeZero(-1.0))
print(preprocess_test_age_zero)

Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([71.2425], device='cuda:0'),std=tensor([9.5208], device='cuda:0'),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([ 0.0794, -0.0131, -0.0121, -0.0244,  0.0004,  0.0295,  0.0506,  0.0152,
          -0.0003,  0.0337, -0.0050, -0.0096, -0.0336,  0.0111,  0.0077,  0.0076,
           0.0110,  0.0071, -0.0213,  0.0056], device='cuda:0'),std=tensor([49.7625, 21.7403, 12.5883, 12.5998, 17.0438, 53.6425, 21.1902, 11.1606,
          12.5307, 16.9399, 21.8603, 15.2040, 14.7430, 23.0870, 18.4153, 15.7697,
          20.6772, 12.2012, 12.5872, 94.9068], device='cuda:0'),eps=1e-08)
  (3): EegSpectrogram(n_fft=179, complex_mode=as_real, stft_kwargs={'hop_length': 45})
  (4): EegNormalizeMeanStd(mean=tensor([[ 8.8867e-02, -5.7190e-03, -1.6277e-03,  ...,  2.6429e-04,
           -6.7322e-04, -3.5536e-04],
          [ 1.6543e-01, -4.2610e-03, -1.3435e-03,  ...,  1.1995e-04,
           -7.5047e-04, -5.0288e-04],
          [ 3.

In [13]:
# Print the 'age' field of the first test batch before and after preprocessing:
# first with the age-zero pipeline, then with the regular test pipeline for comparison.
for position, pipeline in enumerate((preprocess_test_age_zero, config['preprocess_test'])):
    if position:
        print('-----')
    for sample_batched in test_loader:
        print(sample_batched['age'])
        pipeline(sample_batched)  # the batch is updated in place, as the second print shows
        print(sample_batched['age'])
        break

tensor([62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62., 62.,
        62., 62.])
tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1

### Val set

In [14]:
# Validation set, regular test-time preprocessing.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
val_results = check_accuracy_extended_debug(model, val_loader,
                                            config['preprocess_test'], config, repeat=repeat)
(val_acc, val_score, val_target,
 val_confusion, val_error_table, val_crop_timing) = val_results[:6]

print(val_acc)

64.59755777310924


In [15]:
# Validation set, evaluated with the training-time preprocessing pipeline.
# NOTE: this overwrites the val_* variables produced by the preceding evaluation.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
val_results = check_accuracy_extended_debug(model, val_loader,
                                            config['preprocess_train'], config, repeat=repeat)
(val_acc, val_score, val_target,
 val_confusion, val_error_table, val_crop_timing) = val_results[:6]

print(val_acc)

64.51549369747899


In [16]:
# Validation set, evaluated with the age-bias preprocessing pipeline.
# NOTE: this overwrites the val_* variables produced by the preceding evaluation.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
val_results = check_accuracy_extended_debug(model, val_loader,
                                            preprocess_test_age_bias, config, repeat=repeat)
(val_acc, val_score, val_target,
 val_confusion, val_error_table, val_crop_timing) = val_results[:6]

print(val_acc)

63.00715598739496


In [17]:
# Validation set, evaluated with the age-zero preprocessing pipeline.
# NOTE: this overwrites the val_* variables produced by the preceding evaluation.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
val_results = check_accuracy_extended_debug(model, val_loader,
                                            preprocess_test_age_zero, config, repeat=repeat)
(val_acc, val_score, val_target,
 val_confusion, val_error_table, val_crop_timing) = val_results[:6]

print(val_acc)

54.41012342436975


### Test set

In [18]:
# Test set, regular test-time preprocessing.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
test_results = check_accuracy_extended_debug(model, test_loader,
                                             config['preprocess_test'], config, repeat=repeat)
(test_acc, test_score, test_target,
 test_confusion, test_error_table, test_crop_timing) = test_results[:6]

print(test_acc)

68.06971663135593


In [19]:
# Test set, evaluated with the training-time preprocessing pipeline.
# NOTE: this overwrites the test_* variables produced by the preceding evaluation.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
test_results = check_accuracy_extended_debug(model, test_loader,
                                             config['preprocess_train'], config, repeat=repeat)
(test_acc, test_score, test_target,
 test_confusion, test_error_table, test_crop_timing) = test_results[:6]

print(test_acc)

68.25178760593221


In [20]:
# Test set, evaluated with the age-bias preprocessing pipeline.
# NOTE: this overwrites the test_* variables produced by the preceding evaluation.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
test_results = check_accuracy_extended_debug(model, test_loader,
                                             preprocess_test_age_bias, config, repeat=repeat)
(test_acc, test_score, test_target,
 test_confusion, test_error_table, test_crop_timing) = test_results[:6]

print(test_acc)

66.67604608050848


In [21]:
# Test set, evaluated with the age-zero preprocessing pipeline.
# NOTE: this overwrites the test_* variables produced by the preceding evaluation.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
test_results = check_accuracy_extended_debug(model, test_loader,
                                             preprocess_test_age_zero, config, repeat=repeat)
(test_acc, test_score, test_target,
 test_confusion, test_error_table, test_crop_timing) = test_results[:6]

print(test_acc)

56.88559322033898


### Test set (with test-time augmentation)

In [22]:
# Multi-crop (test-time augmentation) evaluation with the regular test-time preprocessing.
# `repeat` and config['crop_multiple'] are both integers, so their product needs no rounding.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
multi_test_results = check_accuracy_multicrop_extended(model, multicrop_test_loader,
                                                       config['preprocess_test'], config,
                                                       repeat=repeat * config['crop_multiple'])
(multi_test_acc, multi_test_score, multi_test_target,
 multi_test_confusion, multi_test_error_table) = multi_test_results[:5]

print(multi_test_acc)

70.1320842161017


In [23]:
# Multi-crop (test-time augmentation) evaluation with the training-time preprocessing pipeline.
# NOTE: this overwrites the multi_test_* variables produced by the preceding evaluation.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
multi_test_results = check_accuracy_multicrop_extended(model, multicrop_test_loader,
                                                       config['preprocess_train'], config,
                                                       repeat=repeat * config['crop_multiple'])
(multi_test_acc, multi_test_score, multi_test_target,
 multi_test_confusion, multi_test_error_table) = multi_test_results[:5]

print(multi_test_acc)

69.95332362288136


In [24]:
# Multi-crop (test-time augmentation) evaluation with the age-bias preprocessing pipeline.
# NOTE: this overwrites the multi_test_* variables produced by the preceding evaluation.
# `repeat` and config['crop_multiple'] are both integers, so their product needs no rounding.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
multi_test_results = check_accuracy_multicrop_extended(model, multicrop_test_loader,
                                                       preprocess_test_age_bias, config,
                                                       repeat=repeat * config['crop_multiple'])
(multi_test_acc, multi_test_score, multi_test_target,
 multi_test_confusion, multi_test_error_table) = multi_test_results[:5]

print(multi_test_acc)

67.72543697033898


In [25]:
# Multi-crop (test-time augmentation) evaluation with the age-zero preprocessing pipeline.
# NOTE: this overwrites the multi_test_* variables produced by the preceding evaluation.
# `repeat` and config['crop_multiple'] are both integers, so their product needs no rounding.
# The result gets its own name instead of `_`, which IPython reserves for the last
# cell output (assigning to it stops IPython from updating it).
multi_test_results = check_accuracy_multicrop_extended(model, multicrop_test_loader,
                                                       preprocess_test_age_zero, config,
                                                       repeat=repeat * config['crop_multiple'])
(multi_test_acc, multi_test_score, multi_test_target,
 multi_test_confusion, multi_test_error_table) = multi_test_results[:5]

print(multi_test_acc)

56.90545550847458
