# Test Ensemble

This notebook classifies a user-selected EDF recording with the pretrained models listed in the configuration and prints the resulting class scores.

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# Step up one directory first; presumably this makes the repository root the
# working directory so that the project packages imported below can be resolved.
# NOTE: the path is relative, so re-running this cell moves up one more level each time.
%cd ..
%load_ext autoreload
%autoreload 2

C:\Users\bengb\OneDrive\문서\GitHub\eeg_analysis


In [2]:
# Load some packages
import os
import sys
import pickle
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
from collections import OrderedDict

import numpy as np
import pandas as pd
import pyedflib
import datetime
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

import pprint
from tqdm import auto
import wandb
import matplotlib
import matplotlib.pyplot as plt
import tkinter as tk
from tkinter import filedialog

# custom package
from datasets.caueeg_dataset import CauEegDataset
from datasets.caueeg_data_curation import calculate_age, birth_to_datetime
from datasets.pipeline import EegChangeMontageOrder, EegResample
from datasets.pipeline import eeg_collate_fn
from datasets.pipeline import EegNormalizeMeanStd, EegNormalizePerSignal
from datasets.pipeline import EegNormalizeAge
from datasets.pipeline import EegToTensor, EegToDevice
from datasets.pipeline import EegSpectrogram
from datasets.caueeg_script import compose_transforms, compose_preprocess
import models
from train.evaluate import estimate_score



In [3]:
# Report the PyTorch build and choose the compute device: CUDA when usable, CPU otherwise.
print('PyTorch version:', torch.__version__)
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print('cuda is available.' if cuda_available else 'cuda is unavailable.')

num_workers = 0  # A number other than 0 causes an error
# Page-locked host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only machine just makes the DataLoader emit a warning, so follow CUDA availability.
pin_memory = cuda_available

PyTorch version: 1.11.0
cuda is available.


---

## Configurations

In [4]:
# Inference settings shared by every model in the pool.
base_repeat = 16  # 500
crop_multiple = 8
test_crop_multiple = 8
verbose = False

# Names of the checkpoint folders under ./local/checkpoint to evaluate.
model_names = [
    'lo88puq7',
]

# Keep only the checkpoints that can actually be deserialized.
model_pool = []
for model_name in model_names:
    path = os.path.join(r'./local/checkpoint', model_name, 'checkpoint.pt')
    try:
        # Availability probe only: deserialize on the CPU and discard the result so
        # no accelerator memory stays occupied; the evaluation cell reloads the file
        # onto the real device.
        # NOTE: torch.load unpickles the file, so only open checkpoints from a trusted source.
        torch.load(path, map_location='cpu')
    except Exception as e:
        print(e)
        print(f'- checkpoint cannot be opened: {path}')
    else:
        model_pool.append({'name': model_name, 'path': path})

pprint.pprint([model_dict['name'] for model_dict in model_pool])

['lo88puq7']


In [5]:
# Reference channel order handed to EegChangeMontageOrder as the "old" montage:
# 19 electrode names suffixed with '-AVG', followed by the 'EKG' and 'Photic' channels.
_avg_electrodes = (
    'Fp1', 'F3', 'C3', 'P3', 'O1',
    'Fp2', 'F4', 'C4', 'P4', 'O2',
    'F7', 'T3', 'T5',
    'F8', 'T4', 'T6',
    'FZ', 'CZ', 'PZ',
)
old_montage = [f'{electrode}-AVG' for electrode in _avg_electrodes] + ['EKG', 'Photic']

-----

## Evaluate each model and accumulate the logits

In [19]:
# file picker: EDF
root = tk.Tk()
root.withdraw()
root.focus_force()
root.wm_attributes('-topmost', 1)
try:
    # askopenfilename returns just the path; askopenfile would also open the file
    # and leave that handle unclosed, and returns None when the dialog is cancelled.
    edf_file = filedialog.askopenfilename(title="Select an EDF file",
                                          filetypes=(("EDF files", "*.edf"),
                                                     ("all files", "*.*")))
finally:
    # always tear the hidden Tk root down, even if the dialog raised
    root.destroy()

if not edf_file:
    # the dialog was cancelled: stop with a clear message instead of an AttributeError
    raise FileNotFoundError("No EDF file was selected.")
print("Selected EDF:", edf_file)

# read the EDF file
signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)
print("Loaded EDF signal has the shape of :", signals.shape)
print()

# old_sample_freq is fixed at 200; new_sample_freq is taken from the first signal header.
# NOTE(review): the evaluation loop passes these as
# EegResample(orig_freq=old_sample_freq, new_freq=new_sample_freq) - confirm that this
# direction is the intended one when the file is not sampled at 200.
old_sample_freq = 200
new_sample_freq = signal_headers[0]['sample_rate']

# calculate age
if edf_header.get('birthdate'):
    birth = datetime.datetime.strptime(edf_header['birthdate'], "%d %b %Y")
    age = calculate_age(birth, edf_header['startdate'])
else:
    # no birthdate stored in the header (missing or empty): fall back to a fixed default age
    age = 70

print(age)

# change channel order of montage if required
# (the channel name is taken as the last space-separated token of each header label)
current_montage = [sh['label'].split(' ')[-1] for sh in signal_headers]
eeg_change_montage_order = EegChangeMontageOrder(old_montage, current_montage)

# build data list
# single record: serial = EDF file name without its extension, plus the age computed above
data_list = [{'serial': os.path.splitext(os.path.basename(edf_file))[0], 'age': age}]

# Evaluate the selected EDF with every model in the pool and print its class scores.
for model_dict in model_pool:
    #################################
    # load and parse the checkpoint #
    #################################
    # NOTE: torch.load unpickles the file, so only open checkpoints from a trusted source.
    ckpt = torch.load(model_dict['path'], map_location=device)
    model_state = ckpt['model_state']
    config = ckpt['config']

    # holds the architecture name for now; replaced by the network itself once it is built
    model_dict['model'] = config['model']
    print('- checking for', model_dict['name'], config['model'], '...')

    # initiate the model
    if '_target_' in config:
        # hydra-style config: let hydra build the object named by '_target_'
        model = hydra.utils.instantiate(config).to(device)
    elif isinstance(config['generator'], str):
        # generator (and block) were stored as dotted names: look them up in `models`
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
        model = config['generator'](**config).to(device)
    else:
        # generator was stored as a callable; block classes are replaced by the
        # strings 'bottleneck' / 'basic' before calling it
        if 'block' in config:
            if config['block'] == models.resnet_1d.BottleneckBlock1D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_2d.Bottleneck2D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_1d.BasicBlock1D:
                config['block'] = 'basic'
            elif config['block'] == models.resnet_2d.BasicBlock2D:
                config['block'] = 'basic'

        model = config['generator'](**config).to(device)

    if config.get('ddp', False):
        # remove 'module.' of DataParallel/DistributedDataParallel, but only from keys
        # that really carry it, so an un-prefixed key is never truncated
        ddp_prefix = 'module.'
        model_state = OrderedDict(
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
            for k, v in model_state.items()
        )

    model.load_state_dict(model_state)
    model = model.requires_grad_(False)
    # inference only: switch layers such as dropout / batch-norm to evaluation behaviour
    model.eval()
    model_dict['model'] = model

    # reconfigure and update
    config.pop('cwd', 0)
    config['ddp'] = False
    config['crop_multiple'] = crop_multiple
    config['crop_timing_analysis'] = False
    config['eval'] = True
    config['device'] = device

    repeat = round(base_repeat / crop_multiple)
    model_dict['repeat'] = repeat
    model_dict['crop_multiple'] = crop_multiple
    model_dict['test_crop_multiple'] = test_crop_multiple

    #################
    # build dataset #
    #################
    if new_sample_freq != old_sample_freq:
        # NOTE(review): what 'O' selects is decided inside compose_transforms, which is
        # not visible here; the original remark was only "pretend to be"
        config['photic'] = 'O'  # pretend to be
    transform, transform_multicrop = compose_transforms(config)

    # put the montage re-ordering in as the second step of both pipelines
    transform.transforms.insert(1, eeg_change_montage_order)
    transform_multicrop.transforms.insert(1, eeg_change_montage_order)

    edf_dir = os.path.dirname(edf_file)
    test_dataset = CauEegDataset(edf_dir, data_list, load_event=False,
                                 file_format='edf', use_prefix_signal=False, transform=transform)
    multicrop_test_dataset = CauEegDataset(edf_dir, data_list, load_event=False,
                                           file_format='edf', use_prefix_signal=False,
                                           transform=transform_multicrop)

    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=False,
                             num_workers=num_workers, pin_memory=pin_memory, collate_fn=eeg_collate_fn)
    multicrop_test_loader = DataLoader(multicrop_test_dataset, batch_size=1, shuffle=False, drop_last=False,
                                       num_workers=num_workers, pin_memory=pin_memory, collate_fn=eeg_collate_fn)

    # batch-level preprocessing applied before the model
    preprocess_test = [
        EegToDevice(device=config['device']),
        EegResample(orig_freq=old_sample_freq, new_freq=new_sample_freq),
        EegNormalizeAge(mean=config['age_mean'], std=config['age_std']),
    ]

    if config['input_norm'] == 'dataset':
        preprocess_test += [EegNormalizeMeanStd(mean=config['signal_mean'], std=config['signal_std'])]
    elif config['input_norm'] == 'datapoint':
        preprocess_test += [EegNormalizePerSignal()]

    if config.get('model', '1D').startswith('2D'):
        # 2D models: add the spectrogram step, then normalize again with the 2D statistics
        preprocess_test += [EegSpectrogram(**config['stft_params'])]

        if config['input_norm'] == 'dataset':
            preprocess_test += [EegNormalizeMeanStd(mean=config['signal_2d_mean'], std=config['signal_2d_std'])]
        elif config['input_norm'] == 'datapoint':
            preprocess_test += [EegNormalizePerSignal()]

    preprocess_test = torch.nn.Sequential(*preprocess_test).to(device)

    ########
    # test #
    ########
    # each loader yields the single selected recording; the last batch's score is kept
    for sample_batched in test_loader:
        score = estimate_score(model, sample_batched, preprocess_test, config)

    for sample_batched in multicrop_test_loader:
        multi_score = estimate_score(model, sample_batched, preprocess_test, config)

    # stack single-crop and multi-crop scores along dim 0, average over that dim,
    # and print every resulting value as a percentage
    score = torch.cat((score, multi_score), dim=0).mean(dim=0).detach().cpu().numpy()
    for s in score:
        print(f"{s * 100:.2f}%")

print('==== Finished ====')

Selected EDF: C:/Users/bengb/Desktop/drive-download-20221104T094002Z-001/ref/00010.edf
Loaded EDF signal has the shape of : (21, 201600)

70
- checking for lo88puq7 2D-VGG-19 ...
4.66%
84.18%
15.19%
==== Finished ====


In [10]:
# Display the single-crop transform pipeline. `transform` is assigned inside the per-model
# loop of the evaluation cell, so this shows the pipeline built for the last model processed.
transform

Compose(
    EegRandomCrop(crop_length=4000, length_limit=10000000, multiple=8, latency=2000, return_timing=False)
    EegChangeMontageOrder(channel_change=[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20])
    EegDropChannels(drop_index=[20])
    EegToTensor()
)

## Conduct ensemble

In [None]:
# Average the per-model score arrays and add up latency / size figures over the pool.
# NOTE(review): the keys read here ('Test Score', 'Multi-Crop Test Score',
# 'Test Throughput', 'Multi-Crop Test Throughput', 'num_params', 'model size (MiB)')
# are not written by the evaluation cell of this notebook - confirm where they come from.
n_ensemble = len(model_pool)

first_entry = model_pool[0]
ensemble_test_score = np.zeros_like(first_entry['Test Score'])
ensemble_multi_test_score = np.zeros_like(first_entry['Multi-Crop Test Score'])

# latencies (reciprocal throughputs) are summed across the models
ensemble_test_latency = ensemble_multi_test_latency = 0

# parameter count and model size are summed as well
ensemble_params = ensemble_model_size = 0

for model_dict in model_pool:
    # every model contributes an equal share of the mean score
    ensemble_test_score += model_dict['Test Score'] / n_ensemble
    ensemble_multi_test_score += model_dict['Multi-Crop Test Score'] / n_ensemble

    ensemble_test_latency += 1 / model_dict['Test Throughput']
    ensemble_multi_test_latency += 1 / model_dict['Multi-Crop Test Throughput']

    ensemble_params += model_dict['num_params']
    ensemble_model_size += model_dict['model size (MiB)']

In [None]:
# ensemble accuracy: percentage of predictions equal to the first model's stored targets
pred = ensemble_test_score.argmax(axis=-1)
test_target = model_pool[0]['Test Target']
n_correct = (pred.squeeze() == test_target).sum()
ensemble_test_acc = 100.0 * n_correct / pred.shape[0]

# class wise metrics
# NOTE(review): calculate_confusion_matrix and calculate_class_wise_metrics are not
# imported by the import cell of this notebook - bring them into scope before running.
ensemble_test_confusion = calculate_confusion_matrix(
    pred, test_target, num_classes=ensemble_test_score.shape[-1]
)
ensembel_test_class_wise_metrics = calculate_class_wise_metrics(ensemble_test_confusion)

In [None]:
# multi-crop accuracy: percentage of predictions equal to the first model's stored targets
pred = ensemble_multi_test_score.argmax(axis=-1)
multi_test_target = model_pool[0]['Multi-Crop Test Target']
n_correct = (pred.squeeze() == multi_test_target).sum()
ensemble_multi_test_acc = 100.0 * n_correct / pred.shape[0]

# class wise metrics
# NOTE(review): calculate_confusion_matrix and calculate_class_wise_metrics are not
# imported by the import cell of this notebook - bring them into scope before running.
ensemble_multi_test_confusion = calculate_confusion_matrix(
    pred, multi_test_target, num_classes=ensemble_multi_test_score.shape[-1]
)
ensembel_multi_test_class_wise_metrics = calculate_class_wise_metrics(ensemble_multi_test_confusion)

## Summarize the ensemble results

In [None]:
# Collect the ensemble figures into one record shaped like the per-model entries.
ensemble_dict = {
    'name': 'Ensemble',
    # throughput is the reciprocal of the latency summed over the models
    'Test Throughput': 1 / ensemble_test_latency,
    'Test Accuracy': ensemble_test_acc,
    'Multi-Crop Test Throughput': 1 / ensemble_multi_test_latency,
    'Multi-Crop Test Accuracy': ensemble_multi_test_acc,
    'num_params': ensemble_params,
    'model size (MiB)': ensemble_model_size,
}

# One entry per (metric, class); multi-crop entries get the 'Multi-Crop ' prefix.
# `config` is the one left behind by the last model of the evaluation loop.
for prefix, class_wise_metrics in (('', ensembel_test_class_wise_metrics),
                                   ('Multi-Crop ', ensembel_multi_test_class_wise_metrics)):
    for metric_name, per_class in class_wise_metrics.items():
        for class_idx in range(config['out_dims']):
            class_name = config['class_label_to_name'][class_idx]
            ensemble_dict[f'{prefix}{metric_name} ({class_name})'] = per_class[class_idx]

model_pool.append(ensemble_dict)

In [None]:
# Remove the score / target entries from every record; keys that are absent are ignored.
for model_dict in model_pool:
    for key in ('Test Score', 'Test Target', 'Multi-Crop Test Score', 'Multi-Crop Test Target'):
        model_dict.pop(key, None)

In [None]:
# Show the records in model_pool (including the appended ensemble record, if that cell ran) as a table.
pd.DataFrame(model_pool)

In [None]:
# Write the same table to a CSV file under local/output.
# NOTE(review): `task` is not assigned in any cell visible in this notebook - define it before running.
pd.DataFrame(model_pool).to_csv(f'local/output/{task}-ensemble.csv')

In [None]:
import glob
for edf_file in glob.glob(r'C:\Users\bengb\Desktop\drive-download-20221104T094002Z-001\*.edf'):
    print('*' * 100)
    print(edf_file)
    print()
    signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)
    print("Loaded EDF signal has the shape of :", signals.shape)
    print()
    
    channel_config = [sh['label'].split(' ')[1].upper() for sh in signal_headers]
    if channel_config != ['FP1-REF', 'F3-REF', 'C3-REF', 'P3-REF', 'O1-REF', 'FP2-REF', 'F4-REF', 'C4-REF', 'P4-REF', 'O2-REF', 'F7-REF', 'T7-REF', 'P7-REF', 'F8-REF', 'T8-REF', 'P8-REF', 'FZ-REF', 'CZ-REF', 'PZ-REF', 'EKG1-EKG2']:
        print('WARNING:::::::::::::::::::::::::::::::')
    print(channel_config)
    print()

In [None]:
import tkinter as tk
from tkinter import filedialog
import pyedflib
# import the module, not `from pprint import pprint`: that form rebinds the name
# `pprint` to the function and breaks the `pprint.pprint(...)` call in the configuration cell
import pprint
import datetime
from datasets.caueeg_data_curation import calculate_age, birth_to_datetime
import numpy as np

# file picker: EDF
root = tk.Tk()
root.withdraw()
root.focus_force()
root.wm_attributes('-topmost', 1)
try:
    # askopenfilename returns just the path; askopenfile would also open the file
    # and leave that handle unclosed, and returns None when the dialog is cancelled.
    edf_file = filedialog.askopenfilename(title="Select an EDF file",
                                          filetypes=(("EDF files", "*.edf"),
                                                     ("all files", "*.*")))
finally:
    # always tear the hidden Tk root down, even if the dialog raised
    root.destroy()

if not edf_file:
    # the dialog was cancelled: stop with a clear message instead of an AttributeError
    raise FileNotFoundError("No EDF file was selected.")
print("Selected EDF:", edf_file)

# read the EDF file
signals, signal_headers, edf_header = pyedflib.highlevel.read_edf(edf_file)
print("Loaded EDF signal has the shape of :", signals.shape)
print()
# pprint.pprint(signal_headers)
# print()
# pprint.pprint(edf_header)

# Electrode names only (text before the first '-'), upper-cased, for the file and for the target montage.
# The channel name is the last space-separated token of the label, as in the evaluation cell,
# so a label without a space no longer raises IndexError.
channel_config = [sh['label'].split(' ')[-1].split('-')[0].upper() for sh in signal_headers]
target_channel_config = [ch.split('-')[0].upper() for ch in ["Fp1-AVG", "F3-AVG", "C3-AVG", "P3-AVG", "O1-AVG", "Fp2-AVG", "F4-AVG", "C4-AVG", "P4-AVG",
        "O2-AVG", "F7-AVG", "T3-AVG", "T5-AVG", "F8-AVG", "T4-AVG", "T6-AVG", "FZ-AVG", "CZ-AVG",
        "PZ-AVG", "EKG1"]]

# show both lists and the names present in only one of them
print(channel_config)
print(target_channel_config)
print('---')
print(set(channel_config) - set(target_channel_config))
print(set(target_channel_config) - set(channel_config))
print('---')
print()

# Build `montage` before printing it (it used to be printed ahead of its definition).
# NOTE(review): split('-')[1] keeps only the text after the first '-', e.g. 'Ref' for a
# label like 'EEG Fp1-Ref', while montage1 / montage2 hold full names such as 'Fp1-Ref';
# it also raises IndexError for a label without '-'. Confirm which token is intended.
montage = [sh['label'].split('-')[1] for sh in signal_headers]
print('***')
print(montage)
print('***')
montage1 = ['Fp1-Ref', 'F3-Ref', 'C3-Ref', 'P3-Ref', 'Fp2-Ref', 'F4-Ref', 'C4-Ref', 'P4-Ref', 'F7-Ref', 'T7-Ref', 'P7-Ref', 'O1-Ref', 'F8-Ref', 'T8-Ref', 'P8-Ref', 'O2-Ref', 'Fz-Ref', 'Cz-Ref', 'Pz-Ref', 'EKG1-EKG2']

# differences against the first candidate montage
print(set(montage) - set(montage1))
print(set(montage1) - set(montage))
print('***')

# differences against the second candidate montage
montage2 = ['Fp1-C3', 'C3-O1', 'Fp1-T7', 'T7-O1', 'Fp2-C4', 'C4-O2', 'Fp2-T8', 'T8-O2', 'F7-Fz', 'Fz-F8', 'T7-Cz', 'Cz-T8', 'Pz-P8', 'Fz-Pz', 'Fp1-O2', 'EKG1-EKG2', 'Pg2-A2']
print(set(montage) - set(montage2))
print(set(montage2) - set(montage))
print('===')