# Train Networks

- Three-way SoftMax classifier of normal, non-vascular MCI, and non-vascular dementia

-----

## Load Packages and Get Ready for Training

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
# Load some packages
import os
import glob
import json
import datetime
from copy import deepcopy

import matplotlib.pyplot as plt
import pprint
from IPython.display import clear_output
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import label_binarize
from tqdm.auto import tqdm 

import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from itertools import cycle

from torch.utils.tensorboard import SummaryWriter
import wandb

# custom package
from utils.eeg_dataset import *
from models import *
from utils.train_utils import *

In [None]:
# notebook name
def get_notebook_name():
    """Return the name of the running notebook as reported by `ipynbname`."""
    # Imported lazily so the dependency is only needed when the name is requested.
    from ipynbname import name as current_notebook_name
    return current_notebook_name()
nb_fname = get_notebook_name()

In [None]:
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

# Start from matplotlib's default style sheet.
plt.style.use('default')
# Other style sheet names that can be passed to plt.style.use():
# ['Solarize_Light2', '_classic_test_patch', 'bmh', 'classic', 'dark_background', 'fast', 
#  'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn', 'seaborn-bright', 'seaborn-colorblind', 
#  'seaborn-dark', 'seaborn-dark-palette', 'seaborn-darkgrid', 'seaborn-deep', 'seaborn-muted', 
#  'seaborn-notebook', 'seaborn-paper', 'seaborn-pastel', 'seaborn-poster', 'seaborn-talk', 
#  'seaborn-ticks', 'seaborn-white', 'seaborn-whitegrid', 'tableau-colorblind10']

# Individual overrides applied on top of the style sheet.
plt.rcParams.update({
    'image.interpolation': 'nearest',
    'font.family': 'NanumGothic',  # for Hangul in Windows
})

In [None]:
print('PyTorch version:', torch.__version__)
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Report which backend is available.
print('cuda is available.' if torch.cuda.is_available() else 'cuda is unavailable.')

-----

## Set up the Dataset and the PyTorch Dataloader

In [None]:
# Data-related settings; later cells read these keys and add a few derived ones.
cfg_data = {
    'dataset': 'CAUHS',
    'vascular': 'X',            # 'X': exclude vascular cases, 'O': keep all cases
    'segment': 'no',            # other values noted by the author: 'train', 'all'
    'seed': 0,                  # seed for the train/val/test split
    'crop_length': 200 * 10,    # 10 seconds per the original note (implies 200 samples/s — TODO confirm)
    'input_norm': 'dataset',    # one of 'dataset', 'datapoint', 'no'
    'EKG': 'O',                 # 'O': keep the EKG channel, 'X': drop it
    'photic': 'X',              # 'O': keep the photic channel, 'X': drop it
    'awgn': 5e-2,               # std of additive Gaussian noise for training, or None to disable
    'minibatch': 32,            # batch size used by the data loaders
}

In [None]:
# Location of the curated dataset and of its metadata index.
# NOTE(review): this points at 'metadata_debug.json' — confirm a debug file is intended for real runs.
data_path = r'dataset/02_Curated_Data/'
meta_path = os.path.join(data_path, 'metadata_debug.json')

# Parse the whole metadata file up front; the handle is closed on leaving the block.
with open(meta_path) as json_file:
    metadata = json.loads(json_file.read())

# pprint.pprint(metadata[0])

#### Filter the Data According to the Target Task

In [None]:
# Only 'X' (non-vascular symptoms only) and 'O' (all cases) are valid settings.
if cfg_data['vascular'] not in ('O', 'X'):
    raise ValueError(f"cfg_data['vascular'] have to be set to one of ['O', 'X']")

# One entry per target class, in class-label order. A record matches an entry when it
# carries every 'include' tag and none of the 'exclude' tags.
# NOTE(review): in 'O' mode nothing is excluded, yet the type names still say
# "Non-vascular" — the names are kept as they are.
diagnosis_filter = [
    {'type': 'Normal',
     'include': ['normal'],
     'exclude': []},
    {'type': 'Non-vascular MCI',
     'include': ['mci'],
     'exclude': ['mci_vascular'] if cfg_data['vascular'] == 'X' else []},
    {'type': 'Non-vascular dementia',
     'include': ['dementia'],
     'exclude': ['vd'] if cfg_data['vascular'] == 'X' else []},
]

# Lookup table: class label (list index) -> class type name.
class_label_to_type = [entry['type'] for entry in diagnosis_filter]
print('class_label_to_type:', class_label_to_type)

In [None]:
def generate_class_label(label):
    """Map a record's diagnosis tags to ``(class_index, class_type)``.

    The entries of the global ``diagnosis_filter`` are tried in order. The first
    entry whose ``include`` tags are all present in ``label`` and whose
    ``exclude`` tags are all absent determines the result.

    Returns ``(-1, 'The others')`` when no entry matches.
    """
    for class_index, rule in enumerate(diagnosis_filter):
        has_all_required = set(rule['include']).issubset(label)
        has_no_forbidden = set(rule['exclude']).isdisjoint(label)
        if has_all_required and has_no_forbidden:
            return (class_index, rule['type'])
    return (-1, 'The others')


# One bucket per target class, indexed by class label.
splitted_metadata = [[] for _ in diagnosis_filter]

# Tag every matching record with its class and file it into the bucket of that class.
for record in metadata:
    class_index, class_name = generate_class_label(record['label'])
    if class_index < 0:
        continue  # the record belongs to none of the target classes
    record.update(class_type=class_name, class_label=class_index)
    splitted_metadata[class_index].append(record)

# Report the size of every bucket.
for group_index, group in enumerate(splitted_metadata):
    if group:
        print(f'- There are {len(group)} data belonging to {group[0]["class_type"]}')
    else:
        print(f'(Warning) Split group {group_index} has no data.')

#### Split the filtered dataset and shuffle it

In [None]:
# Seed Python's RNG so that the split below is reproducible.
random.seed(cfg_data['seed'])

# Train : Val : Test = 8 : 1 : 1
ratio1 = 0.8
ratio2 = 0.1

metadata_train = []
metadata_val = []
metadata_test = []

# Split every class separately so that each subset keeps the class proportions.
for split in splitted_metadata:
    # Shuffle a copy: shuffling `split` itself would reorder the shared per-class
    # list, so re-running this cell would yield a different split for the same seed.
    shuffled = list(split)
    random.shuffle(shuffled)

    n1 = round(len(shuffled) * ratio1)
    n2 = n1 + round(len(shuffled) * ratio2)

    metadata_train.extend(shuffled[:n1])
    metadata_val.extend(shuffled[n1:n2])
    metadata_test.extend(shuffled[n2:])

# Mix the classes within each subset.
random.shuffle(metadata_train)
random.shuffle(metadata_val)
random.shuffle(metadata_test)

print('Train data size\t\t:', len(metadata_train))
print('Validation data size\t:', len(metadata_val))
print('Test data size\t\t:', len(metadata_test))

# Count the records per class label in every subset as a sanity check.
print('\n', '--- Recheck ---', '\n')
train_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_train:
    train_class_nums[m['class_label']] += 1

val_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_val:
    val_class_nums[m['class_label']] += 1

test_class_nums = np.zeros((len(class_label_to_type)), dtype=np.int32)
for m in metadata_test:
    test_class_nums[m['class_label']] += 1

print('Train data label distribution\t:', train_class_nums, train_class_nums.sum())
print('Val data label distribution\t:', val_class_nums, val_class_nums.sum())
print('Test data label distribution\t:', test_class_nums, test_class_nums.sum())

# Re-seed without an argument so later uses of `random` are no longer tied to the fixed seed.
random.seed()

# print([m['serial']  for m in metadata_train[:15]])
# print([m['serial']  for m in metadata_val[:15]])
# print([m['serial']  for m in metadata_test[:15]])

#### Compose the dataset transforms

In [None]:
# Age statistics are taken from the training subset only.
ages = np.array([record['age'] for record in metadata_train])
age_mean = np.mean(ages)
age_std = np.std(ages)

print('Age mean and standard deviation:')
print(age_mean, age_std)

cfg_data['age_mean'] = age_mean
cfg_data['age_std'] = age_std

# Every pipeline starts with a random crop followed by age normalization.
# The "longer" test pipeline crops ten times the configured length.
composed_train, composed_test, longer_composed_test = (
    [EEGRandomCrop(crop_length=cfg_data['crop_length'] * length_factor),
     EEGNormalizeAge(mean=cfg_data['age_mean'], std=cfg_data['age_std'])]
    for length_factor in (1, 1, 10)
)

In [None]:
# Reject unknown settings before touching any pipeline.
if cfg_data['input_norm'] not in ('dataset', 'datapoint', 'no'):
    raise ValueError(f"cfg_data['input_norm'] have to be set to one of ['dataset', 'datapoint', 'no']")

if cfg_data['input_norm'] == 'dataset':
    # Recipe that was used to obtain the statistics hard-coded below (kept for reference):
    #
    # composed = transforms.Compose([EEGRandomCrop(crop_length=cfg_data['crop_length'])])
    # train_dataset = EEGDataset(data_path, metadata_train, composed)

    # signal_means = []
    # signal_stds = []

    # for i in range(10):
    #     for d in train_dataset:
    #         signal_means.append(d['signal'].mean(axis=1, keepdims=True))
    #         signal_stds.append(d['signal'].std(axis=1, keepdims=True))

    # signal_mean = np.mean(np.array(signal_means), axis=0)
    # signal_std = np.mean(np.array(signal_stds), axis=0)

    # print('Mean and standard deviation for signal:')
    # print(signal_means, '\n\n', signal_stds)

    # SPEED-UP: pre-computed values, one row per channel (column vectors).
    signal_mean = np.array([
        0.1127599, 0.06298441, -0.02522413, 0.00508518,
        0.12026667, -0.19987741, -0.00516898, 0.00239212,
        -0.02861219, -0.02973673, -0.02515898, -0.00060568,
        0.04921601, -0.00562142, -0.04888308, -0.0438447,
        0.07532331, -0.01890181, -0.044876, -0.00365138, -0.01564376,
    ]).reshape(-1, 1)
    signal_std = np.array([
        46.09896, 20.50783, 11.196733, 11.236944, 15.070532,
        47.664406, 19.32747, 10.106162, 11.314243, 15.065008,
        20.478817, 13.86243, 13.2378435, 21.554531, 16.875841,
        13.989367, 19.789454, 10.839711, 11.179158, 94.12114, 65.64865,
    ]).reshape(-1, 1)

    cfg_data['signal_mean'] = signal_mean
    cfg_data['signal_std'] = signal_std

    # Each pipeline receives its own transform instance.
    for pipeline in (composed_train, composed_test, longer_composed_test):
        pipeline += [EEGNormalizeMeanStd(mean=cfg_data['signal_mean'],
                                         std=cfg_data['signal_std'])]
elif cfg_data['input_norm'] == 'datapoint':
    for pipeline in (composed_train, composed_test, longer_composed_test):
        pipeline += [EEGNormalizePerSignal()]
# 'no': the pipelines are left untouched.

In [None]:
# 'O' keeps the EKG channel; 'X' appends a transform that drops it in every pipeline.
if cfg_data['EKG'] == 'X':
    for pipeline in (composed_train, composed_test, longer_composed_test):
        pipeline += [EEGDropEKGChannel()]
elif cfg_data['EKG'] != 'O':
    raise ValueError(f"cfg_data['EKG'] have to be set to one of ['O', 'X']")

In [None]:
# 'O' keeps the photic channel; 'X' appends a transform that drops it in every pipeline.
if cfg_data['photic'] == 'X':
    for pipeline in (composed_train, composed_test, longer_composed_test):
        pipeline += [EEGDropPhoticChannel()]
elif cfg_data['photic'] != 'O':
    raise ValueError(f"cfg_data['photic'] have to be set to one of ['O', 'X']")

In [None]:
# Additive Gaussian noise is applied to the training pipeline only; None disables it.
if cfg_data['awgn'] is not None:
    if not cfg_data['awgn'] > 0.0:
        raise ValueError(f"cfg_data['awgn'] have to be None or a positive floating point number")
    composed_train += [EEGAddGaussianNoise(mean=0.0, std=cfg_data['awgn'])]

In [None]:
# Tensor conversion is the last step of every pipeline.
for pipeline in (composed_train, composed_test, longer_composed_test):
    pipeline += [EEGToTensor()]

# Wrap the transform lists into single callables.
composed_train, composed_test, longer_composed_test = map(
    transforms.Compose, (composed_train, composed_test, longer_composed_test))

# Show the final pipelines, each followed by a separator line.
for caption, pipeline in (('composed_train:', composed_train),
                          ('composed_test:', composed_test),
                          ('longer_composed_test:', longer_composed_test)):
    print(caption, pipeline)
    print(f"\n{'-' * 100}\n")

#### Wrap the split data using PyTorch Dataset

In [None]:
# Validation and test data share the test-time pipeline; the "longer" test set
# reads the same test records through the longer-crop pipeline.
train_dataset = EEGDataset(data_path, metadata_train, composed_train)
val_dataset = EEGDataset(data_path, metadata_val, composed_test)
test_dataset = EEGDataset(data_path, metadata_test, composed_test)
longer_test_dataset = EEGDataset(data_path, metadata_test, longer_composed_test)

# Preview the first item of every dataset, with a separator line between datasets.
for position, dataset in enumerate((train_dataset, val_dataset,
                                    test_dataset, longer_test_dataset)):
    if position > 0:
        print(f"\n{'-' * 100}\n")
    print(dataset[0]['signal'].shape)
    print(dataset[0])

#### Train, validation, test dataloaders

In [None]:
# Worker processes stay disabled on every device: per the original note, a number
# other than 0 causes an error. Pinned host memory is only requested for CUDA runs.
num_workers = 0
pin_memory = device.type == 'cuda'

# Settings shared by all four loaders.
common_loader_kwargs = dict(num_workers=num_workers,
                            pin_memory=pin_memory,
                            collate_fn=eeg_collate_fn)

# Training batches are reshuffled every epoch and incomplete batches are dropped.
train_loader = DataLoader(train_dataset,
                          batch_size=cfg_data['minibatch'],
                          shuffle=True,
                          drop_last=True,
                          **common_loader_kwargs)

val_loader = DataLoader(val_dataset,
                        batch_size=cfg_data['minibatch'],
                        shuffle=False,
                        drop_last=False,
                        **common_loader_kwargs)

test_loader = DataLoader(test_dataset,
                         batch_size=cfg_data['minibatch'],
                         shuffle=False,
                         drop_last=False,
                         **common_loader_kwargs)

longer_test_loader = DataLoader(longer_test_dataset,
                                batch_size=cfg_data['minibatch'] // 2,  # memory capacity
                                shuffle=False,
                                drop_last=False,
                                **common_loader_kwargs)

# With drop_last=True a training set smaller than one minibatch produces no batch at all.
# Later cells rely on `sample_batched` left behind by the loop below, so say so loudly.
if len(train_loader) == 0:
    print('(Warning) train_loader yields no batches: '
          'the training set is smaller than one minibatch and drop_last=True.')

# Smoke test: fetch a few batches, move them to the device and show their shapes.
for i_batch, sample_batched in enumerate(train_loader):
    # Tensor.to() returns the moved tensor instead of modifying it in place, so the
    # results have to be kept. `sample_batched` itself is left unchanged.
    signal_on_device = sample_batched['signal'].to(device)
    age_on_device = sample_batched['age'].to(device)
    label_on_device = sample_batched['class_label'].to(device)

    print(i_batch,
          signal_on_device.shape,
          age_on_device.shape,
          label_on_device.shape,
          len(sample_batched['metadata']))

    if i_batch > 3:
        break

-----

## Define Network Models

In [None]:
def count_parameters(model):
    """Return the total number of elements over all parameters of `model` that require gradients."""
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements


def calculate_final_shape(model):
    """Run one dummy forward pass and return what `model.get_final_shape()` reports.

    The dummy input is an all-zero tensor shaped like the signal of the global
    `sample_batched` (the last batch fetched by the data-loader smoke test), together
    with the ages of that batch. Both are moved to the global `device`.

    The pass only probes shapes, so it runs under `torch.no_grad()` and no autograd
    graph is built for it.
    """
    x = torch.zeros_like(sample_batched['signal']).to(device)
    age = sample_batched['age'].to(device)
    with torch.no_grad():
        model(x, age=age)
    return model.get_final_shape()


def visualize_network_tensorboard(model, name):
    """Write the graph of `model` to a TensorBoard run and print its layer shapes.

    The run directory is ``runs/<notebook name>_<name>``. The first batch of the
    global `train_loader` is used as example input; nothing is logged when the
    loader yields no batch.

    Args:
        model: network called as ``model(signal, age)``; it must also accept
            ``print_shape=True``.
        name: suffix of the run directory.
    """
    # default `log_dir` is "runs" - we'll be more specific here
    writer = SummaryWriter('runs/' + nb_fname + '_' + name)

    try:
        for batch in train_loader:
            # pull up the batch data
            x = batch['signal'].to(device)
            age = batch['age'].to(device)

            # log the graph, then run once more so the model prints its tensor shapes
            writer.add_graph(model, (x, age))
            model(x, age, print_shape=True)
            break  # a single batch is enough
    finally:
        # Close the writer even when tracing or the forward pass fails.
        writer.close()

In [None]:
# Collects the configuration of every candidate model defined below.
model_pool = []

# Settings shared by all models: the input channel count is read off the first
# training sample, the output size is the number of target classes.
cfg_common_model = dict(
    in_channels=train_dataset[0]['signal'].shape[0],
    out_dims=len(class_label_to_type),
)

#### 1D Tiny CNN

In [None]:
# 1D Tiny CNN (use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-Tiny-CNN',
    'generator': TinyCNN1D,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-Tiny-CNN-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### M7 model (fc-age)

In [None]:
# M7 (use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': 'M7',
    'generator': M7,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 256,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, 'M7-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### M7 model (conv-age)

In [None]:
# M7 (use_age='conv') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': 'M7',
    'generator': M7,
    'use_age': 'conv',
    'final_pool': 'max',
    'base_channels': 256,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, 'M7-conv-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### M7 model (no-age)

In [None]:
# M7 without age input (use_age=None) on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': 'M7',
    'generator': M7,
    'use_age': None,
    'final_pool': 'max',
    'base_channels': 256,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, 'M7-no-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### 1D ResNet model (fc-age)

In [None]:
# 1D ResNet-29 (bottleneck blocks, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-ResNet-29',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [2, 2, 2, 2],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-ResNet-29-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### 1D ResNet model (conv-age)

In [None]:
# 1D ResNet-29 (bottleneck blocks, use_age='conv') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-ResNet-29',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [2, 2, 2, 2],
    'fc_stages': 3,
    'use_age': 'conv',
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-ResNet-29-conv-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### 1D ResNet model (no-age)

In [None]:
# 1D ResNet-29 (bottleneck blocks) without age input (use_age=None) on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-ResNet-29',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [2, 2, 2, 2],
    'fc_stages': 3,
    'use_age': None,
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-ResNet-29-no-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### Deeper 1D ResNet model

In [None]:
# Deeper 1D ResNet-53 (bottleneck blocks, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-ResNet-53',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [3, 4, 6, 3],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-ResNet-53-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### Shallower 1D ResNet

In [None]:
# Shallower 1D ResNet-21 (basic blocks, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-ResNet-21',
    'generator': ResNet1D,
    'block': BasicBlock1D,
    'conv_layers': [2, 2, 2, 2],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-ResNet-21-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### Tiny 1D ResNet model

In [None]:
# Tiny 1D ResNet-13 (basic blocks, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-ResNet-13',
    'generator': ResNet1D,
    'block': BasicBlock1D,
    'conv_layers': [1, 1, 1, 1],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-ResNet-13-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### Multi-Dilated 1D ResNet model

In [None]:
# Multi-dilated 1D ResNet-53 (multi-bottleneck blocks, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-Multi-Dilated-ResNet-53',
    'generator': ResNet1D,
    'block': MultiBottleneckBlock1D,
    'conv_layers': [3, 4, 6, 3],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 32,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-Multi-Dilated-ResNet-53-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### 1D ResNeXt-53

In [None]:
# 1D ResNeXt-53 (bottleneck blocks with groups=32, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '1D-ResNeXt-53',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [3, 4, 6, 3],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'groups': 32,
    'LR': None,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '1D-ResNeXt-53-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### 2D ResNet-20 model

In [None]:
# 2D ResNet-20 (basic blocks, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    'model': '2D-ResNet-20',        # resnet-18 + two more fc layers
    'generator': ResNet2D,
    'block': BasicBlock2D,
    'conv_layers': [2, 2, 2, 2],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'n_fft': 100,
    'complex_mode': 'as_real',      # other values noted by the author: 'power', 'remove'
}
# The hop length is tied to the FFT size (half of it).
cfg_model['hop_length'] = cfg_model['n_fft'] // 2
cfg_model['LR'] = None

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '2D-ResNet-20-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

#### 2D ResNet-52 model

In [None]:
# 2D ResNet-52 (bottleneck blocks, use_age='fc') on top of the shared settings.
# Keys are listed in their original insertion order.
cfg_model = {
    **cfg_common_model,
    # NOTE(review): the original note here read "resnet-18 + two more fc layer", which looks
    # copy-pasted from the ResNet-20 cell; presumably resnet-50 + two more fc layers — confirm.
    'model': '2D-ResNet-52',
    'generator': ResNet2D,
    'block': Bottleneck2D,
    'conv_layers': [3, 4, 6, 3],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'n_fft': 100,
    'complex_mode': 'as_real',      # other values noted by the author: 'power', 'remove'
}
# The hop length is tied to the FFT size (half of it).
cfg_model['hop_length'] = cfg_model['n_fft'] // 2
cfg_model['LR'] = None

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print(f"\n{'-' * 100}\n")

# The whole config dict is forwarded to the model class as keyword arguments.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print(f"\n{'-' * 100}\n")

# Record model statistics while the instance is still around.
cfg_model.update(num_params=count_parameters(model),
                 final_shape=calculate_final_shape(model))

print('- The Number of parameters of the model: {:,}'.format(cfg_model['num_params']))
print('- Tensor shape right before FC stage:', cfg_model['final_shape'])

# tensorboard visualization (disabled)
# visualize_network_tensorboard(model, '2D-ResNet-52-fc-age')

# Only the config goes into the pool; the instance itself is dropped.
del model
model_pool.append(cfg_model)

In [None]:
# Summary of every pooled config; pprint.pp keeps the keys in insertion order.
for cfg_model in model_pool:
    pprint.pp(cfg_model, width=150)
    print(f"\n{'-' * 100}\n")

-----

## Some useful functions for training

In [None]:
def train_multistep(model, loader, optimizer, scheduler, config, steps):
    """Train `model` for `steps` optimization steps and report loss and accuracy.

    The loader is iterated repeatedly until `steps` batches have been processed.
    The scheduler is stepped after every optimizer step. Batches are moved to the
    global `device`.

    Args:
        model: network called as ``model(signal, age)``; its output is treated as raw
            class scores.
        loader: iterable of batch dicts with the keys 'signal', 'age' and 'class_label'.
        optimizer: optimizer updating the parameters of `model`.
        scheduler: learning-rate scheduler, stepped once per batch.
        config: dict whose 'criterion' entry is 'cross-entropy' or 'multi-bce'.
        steps: number of optimization steps to run.

    Returns:
        Tuple ``(avg_loss, train_acc)``: the loss averaged over `steps` and the
        training accuracy in percent.

    Raises:
        ValueError: if `config['criterion']` is not supported, or if a full pass over
            `loader` yields no batch (which would otherwise loop forever).
    """
    model.train()

    i = 0
    cumu_loss = 0
    correct, total = (0, 0)

    while True:
        batches_in_pass = 0
        for sample_batched in loader:
            batches_in_pass += 1
            optimizer.zero_grad()

            # load the mini-batched data
            x = sample_batched['signal'].to(device)
            age = sample_batched['age'].to(device)
            y = sample_batched['class_label'].to(device)

            # forward pass
            output = model(x, age)

            # loss function
            if config['criterion'] == 'cross-entropy':
                s = F.log_softmax(output, dim=1)
                loss = F.nll_loss(s, y)
            elif config['criterion'] == 'multi-bce':
                y_oh = F.one_hot(y, num_classes=len(class_label_to_type))
                # `s` is only used for the prediction; the loss works on the raw scores
                s = torch.sigmoid(output)
                loss = F.binary_cross_entropy_with_logits(output, y_oh.float())
            else:
                raise ValueError("config['criterion'] must be one of "
                                 "['cross-entropy', 'multi-bce']")

            # backward and update
            loss.backward()
            optimizer.step()
            scheduler.step()

            # train accuracy
            pred = s.argmax(dim=-1)
            correct += pred.squeeze().eq(y).sum().item()
            total += pred.shape[0]
            cumu_loss += loss.item()

            i += 1
            if steps <= i: break
        if steps <= i: break

        # Without this guard an empty loader would keep the outer loop spinning forever.
        if batches_in_pass == 0:
            raise ValueError('loader yielded no batches, so the requested steps cannot be reached')

    train_acc = 100.0 * correct / total
    avg_loss = cumu_loss / steps

    return (avg_loss, train_acc)


def train_mixup_multistep(model, loader, optimizer, scheduler, config, steps):
    """Train `model` for `steps` optimization steps with mixup and report loss and accuracy.

    Every batch is blended with a randomly permuted copy of itself. The blending
    weight is drawn from ``Beta(config['mixup'], config['mixup'])`` and applied to
    the signals, the ages and the loss targets. The loader is iterated repeatedly
    until `steps` batches have been processed; the scheduler is stepped after every
    optimizer step. Batches are moved to the global `device`.

    Args:
        model: network called as ``model(signal, age)``; its output is treated as raw
            class scores.
        loader: iterable of batch dicts with the keys 'signal', 'age' and 'class_label'.
        optimizer: optimizer updating the parameters of `model`.
        scheduler: learning-rate scheduler, stepped once per batch.
        config: dict with 'criterion' ('cross-entropy' or 'multi-bce') and 'mixup'
            (the Beta distribution parameter).
        steps: number of optimization steps to run.

    Returns:
        Tuple ``(avg_loss, train_acc)``: the loss averaged over `steps` and the
        mixup-weighted training accuracy in percent.

    Raises:
        ValueError: if `config['criterion']` is not supported, or if a full pass over
            `loader` yields no batch (which would otherwise loop forever).
    """
    model.train()

    i = 0
    cumu_loss = 0
    correct, total = (0, 0)

    while True:
        batches_in_pass = 0
        for sample_batched in loader:
            batches_in_pass += 1
            optimizer.zero_grad()

            # load and mixup the mini-batched data
            x1 = sample_batched['signal'].to(device)
            age1 = sample_batched['age'].to(device)
            y1 = sample_batched['class_label'].to(device)

            # Put the permutation on the device of the batch instead of calling .cuda(),
            # so the function also works on CPU-only machines.
            index = torch.randperm(x1.shape[0]).to(x1.device)
            x2 = x1[index]
            age2 = age1[index]
            y2 = y1[index]

            mixup_alpha = config['mixup']
            lam = float(np.random.beta(mixup_alpha, mixup_alpha))
            x = lam * x1 + (1.0 - lam) * x2
            age = lam * age1 + (1.0 - lam) * age2

            # forward pass
            output = model(x, age)

            # loss function
            if config['criterion'] == 'cross-entropy':
                s = F.log_softmax(output, dim=1)
                loss1 = F.nll_loss(s, y1)
                loss2 = F.nll_loss(s, y2)
                loss = lam * loss1 + (1 - lam) * loss2
            elif config['criterion'] == 'multi-bce':
                # one_hot() returns integer tensors; cast before blending so the
                # soft target is a float tensor, as the loss expects
                y1_oh = F.one_hot(y1, num_classes=len(class_label_to_type)).float()
                y2_oh = F.one_hot(y2, num_classes=len(class_label_to_type)).float()
                y_oh = lam * y1_oh + (1.0 - lam) * y2_oh
                # `s` is only used for the prediction; the loss works on the raw scores
                s = torch.sigmoid(output)
                loss = F.binary_cross_entropy_with_logits(output, y_oh)
            else:
                raise ValueError("config['criterion'] must be one of "
                                 "['cross-entropy', 'multi-bce']")

            # backward and update
            loss.backward()
            optimizer.step()
            scheduler.step()

            # train accuracy: a hit on either label counts with its mixup weight
            pred = s.argmax(dim=-1)
            correct1 = pred.squeeze().eq(y1).sum().item()
            correct2 = pred.squeeze().eq(y2).sum().item()
            correct += lam * correct1 + (1.0 - lam) * correct2
            total += pred.shape[0]
            cumu_loss += loss.item()

            i += 1
            if steps <= i: break
        if steps <= i: break

        # Without this guard an empty loader would keep the outer loop spinning forever.
        if batches_in_pass == 0:
            raise ValueError('loader yielded no batches, so the requested steps cannot be reached')

    train_acc = 100.0 * correct / total
    avg_loss = cumu_loss / steps

    return (avg_loss, train_acc)

In [None]:
def check_accuracy(model, loader, config, repeat=1):
    """Evaluate `model` on `loader` and gather accuracy, confusion and per-sample statistics.

    Args:
        model: Network called as ``model(x, age)``; it is switched to eval mode.
        loader: Iterable of batches. Each batch is indexed with the keys
            'signal', 'age', 'class_label' and 'metadata'; ``loader.dataset``
            is iterated once up front to seed the per-sample table.
        config: Mapping whose 'criterion' entry is 'cross-entropy' or 'multi-bce'.
        repeat: Number of full passes over `loader`.

    Returns:
        Tuple ``(accuracy, confusion_matrix, debug_table, score, target)``:
        accuracy in percent, a (C, C) int32 array, a tuple of four parallel
        lists (serials, edf names, per-class prediction counts, ground-truth
        labels), and the stacked class scores / targets of every evaluated
        sample (both ``None`` when the loader yields nothing).

    Raises:
        ValueError: If ``config['criterion']`` is not a supported value.
    """
    criterion = config['criterion']
    if criterion not in ('cross-entropy', 'multi-bce'):
        # Fail early: otherwise the score tensor would stay undefined inside the loop.
        raise ValueError(f"unsupported criterion: {criterion!r}")

    model.eval()

    # for accuracy
    correct, total = (0, 0)

    # for confusion matrix
    C = len(class_label_to_type)
    confusion_matrix = np.zeros((C, C), dtype=np.int32)

    # for debug table: one entry per dataset sample, keyed by its serial
    debug_table = {data['metadata']['serial']: {'GT': data['class_label'].item(),
                                                'Pred': [0] * C} for data in loader.dataset}

    # for ROC curve: collect per-batch arrays and join them once at the end
    scores = []
    targets = []

    with torch.no_grad():
        for _ in range(repeat):
            for sample_batched in loader:
                # pull up the data
                x = sample_batched['signal'].to(device)
                age = sample_batched['age'].to(device)
                y = sample_batched['class_label'].to(device)

                # apply model on whole batch directly on device
                output = model(x, age)

                if criterion == 'cross-entropy':
                    s = F.softmax(output, dim=1)
                else:  # 'multi-bce'
                    s = torch.sigmoid(output)

                # calculate accuracy
                pred = s.argmax(dim=-1)
                correct += pred.squeeze().eq(y).sum().item()
                total += pred.shape[0]

                scores.append(s.detach().cpu().numpy())
                targets.append(y.detach().cpu().numpy())

                # confusion matrix
                confusion_matrix += calculate_confusion_matrix(pred, y)

                # debug table: a single host transfer per batch instead of one per sample
                for n, predicted_label in enumerate(pred.tolist()):
                    meta = sample_batched['metadata'][n]
                    entry = debug_table[meta['serial']]
                    entry['edfname'] = meta['edfname']
                    entry['Pred'][predicted_label] += 1

    # debug table update: unzip the per-sample entries into parallel lists
    debug_table_serial = []
    debug_table_edf = []
    debug_table_pred = []
    debug_table_gt = []

    for key, val in debug_table.items():
        debug_table_serial.append(key)
        # 'edfname' is only filled in while iterating the loader, so every
        # dataset sample must have been yielded at least once.
        debug_table_edf.append(val['edfname'])
        debug_table_pred.append(val['Pred'])
        debug_table_gt.append(val['GT'])

    debug_table = (debug_table_serial, debug_table_edf, debug_table_pred, debug_table_gt)

    score = np.concatenate(scores, axis=0) if scores else None
    target = np.concatenate(targets, axis=0) if targets else None

    accuracy = 100.0 * correct / total
    return (accuracy, confusion_matrix, debug_table, score, target)


def calculate_confusion_matrix(pred, target):
    """Count (ground truth, prediction) pairs into a C x C matrix.

    Args:
        pred: Predicted class indices (torch tensor or array-like), one per sample.
        target: Ground-truth class indices; ``target.shape[0]`` fixes the
            number of samples that are counted.

    Returns:
        np.ndarray of shape (C, C) and dtype int32, where C is
        ``len(class_label_to_type)``; rows index the ground truth and columns
        the prediction.
    """
    def _flat_numpy(values):
        # One host transfer for the whole batch instead of one sync per element.
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    N = target.shape[0]
    C = len(class_label_to_type)
    confusion = np.zeros((C, C), dtype=np.int32)
    if N == 0:
        return confusion

    rows = _flat_numpy(target)[:N]
    cols = _flat_numpy(pred)[:N]

    # np.add.at accumulates repeated (row, col) pairs, unlike fancy-index `+=`.
    np.add.at(confusion, (rows, cols), 1)
    return confusion

In [None]:
def draw_confusion(confusion, use_wandb=False):
    """Plot a confusion matrix as an annotated heat map.

    Args:
        confusion: 2-D array indexed as ``confusion[ground_truth, prediction]``
            with one row/column per entry of ``class_label_to_type``.
        use_wandb: When true the figure is logged to wandb under
            'Confusion Matrix (Image)'; otherwise it is shown inline.
    """
    n_classes = len(class_label_to_type)
    tick_positions = np.arange(n_classes)

    # default, ggplot, fivethirtyeight, classic
    plt.style.use('default')
    # 'nipy_spectral' is an alternative colour map
    plt.rcParams['image.cmap'] = 'jet'

    fig = plt.figure(num=1, clear=True, figsize=(4.0, 4.0), constrained_layout=True)
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(confusion, alpha=0.8)

    # class names on both axes
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(class_label_to_type)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(class_label_to_type)

    # write the count into every cell (x is the column, y is the row)
    for row, col in np.ndindex(n_classes, n_classes):
        ax.text(col, row, confusion[row, col],
                ha="center", va="center", color='k')

    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Prediction')
    ax.set_ylabel('Ground Truth')
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # draw
    if not use_wandb:
        plt.show()
    else:
        wandb.log({'Confusion Matrix (Image)': wandb.Image(plt)})

    # release the figure so repeated calls do not accumulate memory
    fig.clear()
    plt.close(fig)
    

def draw_roc_curve(score, target, use_wandb=False):
    """Plot per-class ROC curves next to the micro/macro-averaged ROC curves.

    Args:
        score: Array of shape (n_samples, n_classes) with one score per class.
        target: Array of integer class labels, one per sample.
        use_wandb: When true the figure is logged to wandb under
            'ROC Curve (Image)'; otherwise it is shown inline.
    """
    plt.style.use('default') # default, ggplot, fivethirtyeight, classic

    # Binarize the labels into one indicator column per class
    n_classes = len(class_label_to_type)
    target = label_binarize(target, classes=np.arange(n_classes))
    if n_classes == 2 and target.shape[1] == 1:
        # label_binarize returns a single column for two classes; expand it so
        # that the per-class indexing below also works in the binary case.
        target = np.hstack([1 - target, target])

    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(target[:, i], score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(target.ravel(), score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: aggregate all false positive rates ...
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # ... interpolate every class-wise ROC curve at these points ...
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

    # ... and average them before computing the AUC
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    fig = plt.figure(num=1, clear=True, figsize=(8.5, 4.0), constrained_layout=True)
    lw = 1.5

    # left panel: class-wise ROC curves
    ax = fig.add_subplot(1, 2, 1)
    colors = cycle(['limegreen', 'mediumpurple', 'darkorange',
                    'dodgerblue', 'lightcoral', 'goldenrod',
                    'indigo', 'darkgreen', 'navy', 'brown'])
    for i, color in zip(range(n_classes), colors):
        ax.plot(fpr[i], tpr[i], color=color, lw=lw,
                label='{0} (area = {1:0.2f})'
                ''.format(class_label_to_type[i], roc_auc[i]))
    ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Class-Wise ROC Curves')
    ax.legend(loc="lower right")

    # right panel: class-agnostic (micro/macro-averaged) ROC curves.
    # Drawn on the explicit axes object rather than through pyplot's implicit
    # current-axes state.
    ax = fig.add_subplot(1, 2, 2)
    ax.plot(fpr["micro"], tpr["micro"],
            label='micro-average (area = {0:0.2f})'
                  ''.format(roc_auc["micro"]),
            color='deeppink', linestyle='-', linewidth=lw)

    ax.plot(fpr["macro"], tpr["macro"],
            label='macro-average (area = {0:0.2f})'
                  ''.format(roc_auc["macro"]),
            color='navy', linestyle='-', linewidth=lw)

    ax.plot([0, 1], [0, 1], 'k--', lw=lw)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Class-Agnostic ROC Curves')
    ax.legend(loc="lower right")

    # draw
    if use_wandb:
        wandb.log({'ROC Curve (Image)': wandb.Image(plt)})
    else:
        plt.show()

    # release the figure so repeated calls do not accumulate memory
    fig.clear()
    plt.close(fig)
    
    
def draw_debug_table(debug_table, use_wandb=False):
    """Draw the per-recording error rate as a bar chart and optionally log the raw table.

    Args:
        debug_table: Tuple of four parallel lists
            ``(serials, edf_names, pred_counts, gt_labels)`` where
            ``pred_counts[k]`` holds, per class, how often sample ``k`` was
            predicted as that class and ``gt_labels[k]`` is its true class index.
        use_wandb: When true the table and the figure are logged to wandb
            ('Debug Table', 'Debug Table (Image)'); otherwise the figure is shown.
    """
    (debug_table_serial, debug_table_edf, debug_table_pred, debug_table_gt) = debug_table

    fig = plt.figure(num=1, clear=True, figsize=(20.0, 4.0), constrained_layout=True)
    ax = fig.add_subplot(1, 1, 1)

    # Group the samples by EDF name in a single pass.
    per_edf = {}
    for edf, pred, gt in zip(debug_table_edf, debug_table_pred, debug_table_gt):
        stats = per_edf.setdefault(edf, {'err': 0, 'cnt': 0, 'gt': gt})
        votes = sum(pred)
        stats['cnt'] += votes
        stats['err'] += votes - pred[gt]
        # the bar is coloured by the label of the last sample of the recording
        stats['gt'] = gt

    bar_colors = ['g', 'b', 'r']
    total_error, total_count = (0, 0)

    # sorted order matches the alphabetical order the bars were drawn in before
    for edf in sorted(per_edf):
        stats = per_edf[edf]
        total_error += stats['err']
        total_count += stats['cnt']

        # a recording without any prediction gets an empty bar instead of a crash
        error_rate = stats['err'] / stats['cnt'] if stats['cnt'] else 0.0
        # wrap around so more than three classes do not raise IndexError
        ax.bar(edf, error_rate, color=bar_colors[stats['gt'] % len(bar_colors)])

    # accuracy is reported in percent to match the '%' sign in the title
    if total_count:
        accuracy = 100.0 * (1.0 - total_error / total_count)
    else:
        accuracy = float('nan')

    ax.set_title(f'Debug Table (Acc. {accuracy:.2f}%)', fontsize=18)
    ax.set_ylim(0.0, 1.0)
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", fontsize=9, visible=True)

    if use_wandb:
        table = [[serial, edf, pred, gt] for serial, edf, pred, gt in zip(*debug_table)]
        table = wandb.Table(data=table, columns=['Serial', 'EDF', 'Prediction', 'Ground-truth'])
        wandb.log({'Debug Table': table})

        wandb.log({'Debug Table (Image)': wandb.Image(plt)})
    else:
        plt.show()

    # release the figure so repeated calls do not accumulate memory
    fig.clear()
    plt.close(fig)

In [None]:
def learning_rate_search(model, min_log_lr, max_log_lr, trials, config, steps):
    """Random search over the learning rate on a log10 scale.

    For each trial an exponent is drawn uniformly from
    ``[min_log_lr, max_log_lr)``, the model weights are reset, and the model is
    trained for `steps` steps on the notebook-level ``train_loader`` using the
    training routine stored in ``config['tr_ms']``.

    Args:
        model: Network exposing ``reset_weights()``.
        min_log_lr: Lower bound of the log10 learning rate.
        max_log_lr: Upper bound of the log10 learning rate.
        trials: Number of learning rates to try.
        config: Mapping providing 'weight_decay', 'lr_decay_step',
            'lr_decay_gamma' and 'tr_ms'.
        steps: Number of training steps per trial.

    Returns:
        List of ``(log_lr, train_accuracy)`` tuples, one per trial.
    """
    history = []

    for _ in tqdm(range(trials)):
        # sample the exponent, not the rate itself, so every decade is equally likely
        exponent = np.random.uniform(min_log_lr, max_log_lr)
        candidate_lr = 10 ** exponent

        # every trial starts from freshly initialised weights
        model.reset_weights()
        model.train()

        trial_optimizer = optim.AdamW(model.parameters(),
                                      lr=candidate_lr,
                                      weight_decay=config["weight_decay"])
        trial_scheduler = optim.lr_scheduler.StepLR(trial_optimizer,
                                                    step_size=config['lr_decay_step'],
                                                    gamma=config['lr_decay_gamma'])

        train_step_fn = config['tr_ms']
        _, accuracy = train_step_fn(model, train_loader, trial_optimizer,
                                    trial_scheduler, config, steps)

        # keep the train accuracy reported by the training routine for this trial
        history.append((exponent, accuracy))

    return history


def draw_learning_rate_record(learning_rate_record, use_wandb=False):
    """Scatter-plot the outcome of a learning-rate search.

    Args:
        learning_rate_record: Sequence of ``(log_lr, accuracy)`` pairs; the
            pair with the highest accuracy is highlighted with a green ring.
        use_wandb: When true the figure is logged to wandb under
            'Learning Rate Search (Image)'; otherwise it is shown inline.
    """
    # default, ggplot, fivethirtyeight, classic
    plt.style.use('default')

    fig = plt.figure(num=1, clear=True, constrained_layout=True, figsize=(5.0, 5.0))
    ax = fig.add_subplot(1, 1, 1)

    ax.set_title('Learning Rate Search')
    ax.set_xlabel('Learning rate in log-scale')
    ax.set_ylabel('Train accuracy')

    # highlight the best trial first so the red markers are drawn on top of it
    best_log_lr, best_accuracy = max(learning_rate_record, key=lambda rec: rec[1])
    ax.scatter(best_log_lr, best_accuracy,
               s=150, c='w', marker='o', edgecolors='limegreen')

    # one translucent red dot per trial
    for trial_log_lr, trial_accuracy in learning_rate_record:
        ax.scatter(trial_log_lr, trial_accuracy, c='r',
                   alpha=0.5, edgecolors='none')

    if not use_wandb:
        plt.show()
    else:
        wandb.log({'Learning Rate Search (Image)': wandb.Image(plt)})

    # release the figure so repeated calls do not accumulate memory
    fig.clear()
    plt.close(fig)

-----

## Train models

In [None]:
# Training configuration shared by the learning-rate search and the main training loop.
# The derived entries are computed from 'iterations', so that key is set first.
cfg_train = {'iterations': 100000}
cfg_train.update(
    history_interval=cfg_train['iterations'] // 500,     # steps between validation/log points
    lr_decay_step=round(cfg_train['iterations'] * 0.8),  # StepLR decays once, at 80% of training
    lr_decay_gamma=0.1,
    weight_decay=1e-2,
    mixup=0.0,                                           # 0 for no usage
    criterion='cross-entropy',                           # 'cross-entropy', 'multi-bce'
)
# Choose the training routine: plain multistep unless mixup is effectively enabled.
cfg_train['tr_ms'] = train_mixup_multistep if cfg_train['mixup'] >= 1e-3 else train_multistep

In [None]:
# Fill in a learning rate for every model configuration that does not define one yet.
for cfg_model in model_pool:
    # a preset learning rate is only reported, not searched
    if cfg_model["LR"] is not None:
        print(f'{cfg_model["model"]}: {cfg_model["LR"]:.5e}')
        continue

    print(f'{cfg_model["model"]} LR searching..')
    model = cfg_model['generator'](**cfg_model).to(device)
    model.train()

    record = learning_rate_search(model,
                                  min_log_lr=-4.5,
                                  max_log_lr=-1.4,
                                  trials=5,
                                  config=cfg_train,
                                  steps=300)

    # pick the exponent of the trial with the highest train accuracy
    best_log_lr = record[np.argmax(np.array([accuracy for _, accuracy in record]))][0]

    # store the result on the model config so the training cell can reuse it
    cfg_model['LR'] = 10 ** best_log_lr
    cfg_model['lr_search'] = record

    print(f'best lr {cfg_model["LR"]:.5e} / log_lr {best_log_lr}')

In [None]:
# Train every model of `model_pool`, track it with wandb, and evaluate the
# selected checkpoint on the test sets.
save_model = True
save_temporary = False
draw_result = True

# progress bar
pbar = tqdm(total=len(model_pool) * cfg_train['iterations'])

# train process on model_pool
for cfg_model in model_pool:
    print(f'{"*"*40} {cfg_model["model"]} train starts {"*"*40}')
    
    # merge data / train / model settings into a single run configuration
    config = {}
    config.update(cfg_data)
    config.update(cfg_train)
    config.update(cfg_model)
    
    # generate model and its trainer
    model = config['generator'](**config).to(device)
    optimizer = optim.AdamW(model.parameters(), 
                            lr=config['LR'], 
                            weight_decay=config['weight_decay'])
    scheduler = optim.lr_scheduler.StepLR(optimizer, 
                                          step_size=config['lr_decay_step'], 
                                          gamma=config['lr_decay_gamma'])
    
    # wandb initialization
    wandb_run = wandb.init(project="eeg-analysis", 
                           entity="ipis-mjkim", 
                           reinit=True,
                           save_code=True, 
                           notes=nb_fname,
                           config=config)
    wandb.run.name = wandb.run.id
    
    save_path = f'history_temp/{wandb.run.name}/'
    os.makedirs(save_path, exist_ok=True)
    
    with wandb_run:
        wandb.watch(model, log='all', 
                    log_freq=config['history_interval'], 
                    log_graph=True)
        
        # train and validation routine
        # Start below any reachable accuracy and reset the checkpoint for each
        # model, so a state left over from a previous model (or a missing one)
        # can never be loaded below.
        best_val_acc = float('-inf')
        best_model_state = None
        for i in range(0, config["iterations"], config["history_interval"]):
            # train 'history_interval' steps
            loss, train_acc = cfg_train['tr_ms'](model, train_loader, optimizer, scheduler, 
                                                 config, config["history_interval"])
            
            # validation
            val_acc, _, _, _, _ = check_accuracy(model, val_loader, config, repeat=10)
            
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_model_state = deepcopy(model.state_dict())                
                if save_model and save_temporary:
                    path = os.path.join(save_path, f'{config["model"]}')
                    torch.save(best_model_state, path)                    
                
            # log
            wandb.log({'Loss': loss, 
                       'Train Accuracy': train_acc, 
                       'Validation Accuracy': val_acc}, step=i)
            pbar.update(config['history_interval'])

        # calculate the test accuracies for best and last models
        last_model_state = deepcopy(model.state_dict())
        if best_model_state is None:
            # no validation round ran (e.g. zero iterations): fall back to the last state
            best_model_state = last_model_state
        last_test_result = check_accuracy(model, test_loader, config, repeat=30)
        last_test_acc = last_test_result[0]
        
        model.load_state_dict(best_model_state)
        best_test_result = check_accuracy(model, test_loader, config, repeat=30)
        best_test_acc = best_test_result[0]
 
        # NOTE(review): the checkpoint is chosen by comparing *test* accuracies,
        # so the reported test accuracy is optimistically biased; consider
        # selecting on validation accuracy only.
        if last_test_acc < best_test_acc:
            model_state = best_model_state
            test_result = best_test_result
        else:
            model_state = last_model_state
            test_result = last_test_result
            
        model.load_state_dict(model_state)
        test_acc, test_confusion, test_debug, score, target = test_result
        
        # calculate the test accuracies for final model on much longer sequence
        longer_test_result = check_accuracy(model, longer_test_loader, config, repeat=30)
        longer_test_acc = longer_test_result[0]
        
        # save the model
        if save_model:
            path = os.path.join(save_path, f'{config["model"]}')
            torch.save(model_state, path)
            
        # leave the message
        wandb.log({'Test Accuracy': test_acc,
                   '(Best / Last) Test Accuracy': ('Best' if last_test_acc < best_test_acc else 'Last', 
                                                   round(best_test_acc, 2), round(last_test_acc, 2)),
                   'Confusion Matrix (Array)': test_confusion,
                   'Test Accuracy (Longer)': longer_test_acc, 
                   'Test Debug Table/Serial': test_debug[0], 
                   'Test Debug Table/EDF': test_debug[1], 
                   'Test Debug Table/Pred': test_debug[2], 
                   'Test Debug Table/GT': test_debug[3]})
        
        if 'lr_search' in config:
            draw_learning_rate_record(config['lr_search'], use_wandb=True)
            
        
        if draw_result:
            draw_roc_curve(score, target, use_wandb=True)
            draw_confusion(test_confusion, use_wandb=True)
            draw_debug_table(test_debug, use_wandb=True)
            wandb.log({"Confusion Matrix": wandb.plot.confusion_matrix(y_true=target, 
                                                                              preds=score.argmax(axis=-1), 
                                                                              class_names=class_label_to_type)})
            wandb.log({"ROC Curve": wandb.plot.roc_curve(target, score, labels=class_label_to_type)})
            
            
    print('\n' + '-' * 100)

# release the progress bar once every model has been trained
pbar.close()