# Evaluate

The network trained by `05_Train_Sweep.ipynb` is evaluated in this notebook.

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Move to the parent directory: the cells below import local packages
# (datasets, models, train) and open paths relative to the working directory.
# NOTE: this is not idempotent - re-running the cell moves up one more level.
%cd ..

In [2]:
# Load some packages
import os
import json
import yaml
from copy import deepcopy

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pprint
import matplotlib.pyplot as plt

# custom package
from datasets.cau_eeg_dataset import *
from datasets.cau_eeg_script import *
import models
from train import *

In [3]:
# Report the installed PyTorch build, then pick the compute device once.
print('PyTorch version:', torch.__version__)

cuda_ready = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ready else 'cpu')

print('cuda is available.' if cuda_ready else 'cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available.


## Load the configuration used during the training phase

In [4]:
# Name of the trained run to evaluate; its files are read from a
# sub-directory of the checkpoint root that carries the same name.
model_name = 'zlypije8'
checkpoint_root = 'local/checkpoint'
model_path = os.path.join(checkpoint_root, model_name)

In [5]:
# Read the run configuration stored next to the checkpoint.
# NOTE(review): yaml.FullLoader is kept as-is; consider yaml.safe_load if the
# file is known to contain only plain data types.
config_file = os.path.join(model_path, 'config.yaml')
with open(config_file, 'r') as stream:
    wandb_config = yaml.load(stream, Loader=yaml.FullLoader)

# Skip the string keys prefixed with 'wandb' / '_wandb' and unwrap the
# 'value' field of every remaining entry.
config = {
    key: entry['value']
    for key, entry in wandb_config.items()
    if not (type(key) is str and key.startswith(('wandb', '_wandb')))
}

pprint.pprint(config)

{'EKG': 'X',
 'LR': 0.00028586537399823023,
 'activation': 'gelu',
 'age_mean': 69.92779783393502,
 'age_std': 9.817569889945597,
 'awgn': 0.09612556281067966,
 'awgn_age': 0.232644717265596,
 'base_channels': 64,
 'base_stride': 3,
 'block': 'models.resnet_1d.BottleneckBlock1D',
 'conv_layers': [3, 4, 6, 3],
 'criterion': 'multi-bce',
 'crop_length': 2000,
 'crop_multiple': 2,
 'data_path': 'dataset/02_Curated_Data/',
 'dataset': 'CAUHS',
 'device': 'cuda',
 'draw_result': True,
 'dropout': 0.4741286361591117,
 'fc_stages': 3,
 'final_pool': 'average',
 'final_shape': [64, 2048, 12],
 'first_dilation': 3,
 'generator': 'models.resnet_1d.ResNet1D',
 'groups': 32,
 'history_interval': 200,
 'in_channels': 20,
 'input_norm': 'dataset',
 'iterations': 200000,
 'longer_crop_length': 20000,
 'lr_decay_gamma': 0.1721164909561929,
 'lr_decay_step': 150000,
 'meta_path': 'dataset/02_Curated_Data/metadata_debug.json',
 'minibatch': 64,
 'mixup': 0.3,
 'model': '1D-ResNeXt-5x',
 'model_index': 2

In [6]:
# The preprocessing order changed from (normalize, drop channel) to
# (drop channel, normalize), so the stored per-channel statistics still carry
# the entry of the channel that is now dropped before normalization.
# NOTE(review): index 19 is presumably the EKG channel removed by
# EegDropEKGChannel - confirm against the dataset's channel order.
DROPPED_CHANNEL_INDEX = 19

for stat_key in ('signal_mean', 'signal_std'):
    stat = np.array(config[stat_key])
    # Trim only while the statistics hold more channels than the model input.
    # Without this guard every re-run of the cell would delete one more entry.
    if stat.shape[0] > config['in_channels']:
        stat = np.delete(stat, DROPPED_CHANNEL_INDEX, 0)
    config[stat_key] = stat

-----

## Build the dataset

In [7]:
# Evaluation-time overrides of the training configuration.
# The dataset path uses forward slashes: a backslash is a path separator on
# Windows only, while forward slashes work on every OS and match 'meta_path'.
config['data_path'] = 'local/dataset/02_Cuated_TEMP'
config['meta_path'] = 'local/dataset/02_Cuated_TEMP/annotations.json'
config['file_format'] = 'edf'
config['minibatch'] = 2
config['crop_multiple'] = 1
config['crop_timing_analysis'] = True
config['evaluation_phase'] = True
# Latency in samples; 200 * 10 equals 10 seconds, i.e. 200 samples per second.
config['latency'] = 200 * 10  # 10 seconds
# config['longer_crop_length'] = 20000 * 2
config['device'] = device
# Number of evaluation repetitions, scaled down by the crops taken per sample.
repeat = round(50 / config['crop_multiple'])

In [8]:
# Parse the annotation file referenced by the configuration.
annotation_file = config['meta_path']
with open(annotation_file, 'r') as fp:
    metadata = json.load(fp)

In [9]:
# Build the class filters and the label-to-name table for the target task,
# store both in the config and keep module-level aliases for the cells below.
config['diagnosis_filter'], config['class_label_to_name'] = define_target_task(config, verbose=True)
diagnosis_filter = config['diagnosis_filter']
class_label_to_name = config['class_label_to_name']

def generate_class_label(label):
    """Map a record's diagnosis tags to a class index and class name.

    The filters in the module-level ``diagnosis_filter`` are tried in order.
    A filter matches when every tag in its 'include' list is present in
    ``label`` and none of the tags in its 'exclude' list is.

    Args:
        label: iterable of diagnosis tags attached to one record.

    Returns:
        ``(index, name)`` of the first matching filter, where ``name`` is the
        filter's 'type' entry, or ``(-1, 'The others')`` when nothing matches.
    """
    for class_index, rule in enumerate(diagnosis_filter):
        has_all_included = set(rule['include']).issubset(label)
        has_no_excluded = set(rule['exclude']).isdisjoint(label)
        if has_all_included and has_no_excluded:
            return class_index, rule['type']
    return -1, 'The others'

# One bucket per target class, indexed by class label.
splitted_metadata = [[] for _ in diagnosis_filter]

for record in metadata:
    class_label, class_type = generate_class_label(record['label'])
    if class_label < 0:
        # Records that match no class filter are left out.
        continue
    # Annotate the record in place before putting it into its class bucket.
    record['class_type'] = class_type
    record['class_label'] = class_label
    splitted_metadata[class_label].append(record)

class_label_to_name: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']

----------------------------------------------------------------------------------------------------



In [10]:
# Narrow the evaluation set to a single record: the first one of class 0.
first_record = splitted_metadata[0][0]
metadata = [first_record]
print(metadata)

[{'serial': '00001', 'age': 23, 'label': ['normal', 'smi'], 'class_type': 'Normal', 'class_label': 0}]


In [11]:
# Build the transform pipelines for this config; the call returns three
# pipelines (train, test, and test with the longer crop).
composed_pipelines = compose_transforms(config, verbose=True)
composed_train, composed_test, composed_test_longer = composed_pipelines

composed_train: Compose(
    EegDropEKGChannel()
    EegRandomCrop(crop_length=2000, multiple=1, latency=2000, return_timing=True)
    EegToTensor()
)

----------------------------------------------------------------------------------------------------

composed_test: Compose(
    EegDropEKGChannel()
    EegRandomCrop(crop_length=2000, multiple=1, latency=2000, return_timing=True)
    EegToTensor()
)

----------------------------------------------------------------------------------------------------

longer_composed_test: Compose(
    EegDropEKGChannel()
    EegRandomCrop(crop_length=20000, multiple=1, latency=2000, return_timing=True)
    EegToTensor()
)

----------------------------------------------------------------------------------------------------




In [12]:
# Both datasets share every argument except the transform pipeline.
dataset_args = (
    config['data_path'],
    metadata,
    config.get('load_event', False),
    config.get('file_format', 'feather'),
)

test_dataset = CauEegDataset(*dataset_args, composed_test)
test_dataset_longer = CauEegDataset(*dataset_args, composed_test_longer)

In [13]:
# Worker processes stay disabled on every device
# (a number other than 0 causes an error).
num_workers = 0
# pin_memory is switched on only when the configured device is CUDA.
pin_memory = config['device'].type == 'cuda'

# Loader batch size: the minibatch divided by the crops taken per sample.
batch_size = config['minibatch'] // config.get('crop_multiple', 1)



# Loader over the regular-length crops.
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         drop_last=False,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         collate_fn=eeg_collate_fn)

# Loader over the longer crops. The batch size is halved to save memory and
# clamped to 1, because `batch_size // 2` is 0 when batch_size is 1 and
# DataLoader rejects a batch size of 0.
test_loader_longer = DataLoader(test_dataset_longer,
                                batch_size=max(1, batch_size // 2),
                                shuffle=False,
                                drop_last=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                collate_fn=eeg_collate_fn)

# Batch-level preprocessing applied before the model: move to the device,
# normalize the signal, then normalize the age with the stored statistics.
test_stages = [
    EegToDevice(device=config['device']),
    EegNormalizeMeanStd(mean=config['signal_mean'], std=config['signal_std']),
    EegNormalizeAge(mean=config['age_mean'], std=config['age_std']),
]
# Wrap the composed stages in an nn.Sequential so they run as one module.
preprocess_test = torch.nn.Sequential(*transforms.Compose(test_stages).transforms)

In [14]:
# Sanity check: print the signal batch shape produced by each loader,
# regular crops first and longer crops second.
for loader in (test_loader, test_loader_longer):
    for sample in loader:
        print(sample['signal'].shape)

torch.Size([1, 20, 2000])
torch.Size([1, 20, 20000])


-----

## Load the target model

In [15]:
# Resolve the dotted class paths stored in the config (for example
# 'models.resnet_1d.ResNet1D') to the classes exposed by the `models` package.
# The isinstance guard keeps the cell re-runnable: after the first run these
# entries already hold classes, which have no `.split`.
for class_key in ('generator', 'block'):
    if isinstance(config.get(class_key), str):
        config[class_key] = getattr(models, config[class_key].split('.')[-1])

model = config['generator'](**config).to(device)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU also loads on a CPU-only machine.
model_state = torch.load(os.path.join(model_path, config['model']),
                         map_location=device)

# Put the model in inference mode (for example, this disables dropout).
# NOTE(review): harmless if check_accuracy already switches the mode itself.
model.eval()

# Kept as the last expression so the key-matching report stays the cell output.
model.load_state_dict(model_state)

<All keys matched successfully>

-----

## Evaluate the model

### Test set

In [19]:
# Evaluate on the regular-length crops. `repeat` is the value computed in the
# configuration cell, so the number of repetitions follows
# config['crop_multiple'] instead of a separately hard-coded constant.
test_acc = check_accuracy(model, test_loader, preprocess_test, config, repeat=repeat)
print(test_acc)

100.0


### Test set (with test-time augmentation)

In [20]:
# Evaluate on the longer crops. `repeat` is the value computed in the
# configuration cell, so the number of repetitions follows
# config['crop_multiple'] instead of a separately hard-coded constant.
test_longer_acc = check_accuracy(model, test_loader_longer, preprocess_test, config, repeat=repeat)
print(test_longer_acc)

100.0
