# Train Networks

- Three-way SoftMax or Multi-BCE classifier of normal, non-vascular MCI, and non-vascular dementia
- `Weights and Biases` sweep is used for hyperparameter search

-----

## Load Packages

In [1]:
# Load IPython's autoreload extension so that the project modules imported
# below (datasets, models, train) can be edited without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import json
from copy import deepcopy
import gc
import time

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pprint
import wandb

# custom package
from datasets.cau_eeg_dataset import *
from datasets.cau_eeg_script import *
from models import *
from train import *

In [3]:
# Report the installed PyTorch build and pick the compute device:
# the GPU when CUDA is usable, the CPU otherwise.
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The device type already encodes the availability check made above.
print('cuda is available.' if device.type == 'cuda' else 'cuda is unavailable.')

PyTorch version: 1.10.1+cu113
cuda is available.


-----

## Set the default configuration for building the dataset

In [4]:
# Default dataset configuration handed to build_dataset().
# Keys are inserted in the same order as before so printed dumps stay unchanged.
cfg_data = dict(
    device=device,
    dataset='CAUHS',
    data_path=r'local/dataset/02_Curated_Data_220402/',
)
# The metadata file sits directly inside the data directory.
cfg_data['meta_path'] = os.path.join(cfg_data['data_path'], 'metadata.json')
cfg_data.update({
    'file_format': 'feather',  # 'feather', 'memmap'
    'target_task': 'Normal, MCI, Dementia',  # e.g. 'Normal, MCI, Dementia'
    'vascular': 'X',
    'segment': 'no',  # 'train', 'all', 'no'
    'seed': 0,
    # The lengths below are in samples; the "seconds" notes imply 200 samples
    # per second -- NOTE(review): confirm against the dataset's sampling rate.
    'latency': 200 * 10,  # 10 seconds
    'crop_length': 200 * 10,  # 10 seconds
    'longer_crop_length': 200 * 10 * 10,  # 100 seconds
    'crop_multiple': 4,
    'minibatch': 128,
    'input_norm': 'dataset',  # 'dataset', 'datapoint', 'no'
    'EKG': 'X',
    'photic': 'X',
    'awgn': 3e-2,
    'mgn': 1e-4,
    'awgn_age': 5e-2,
})

In [5]:
# Build the data loaders and preprocessing pipelines from cfg_data.
# The result is bound to a descriptive temporary rather than `_`: in IPython,
# rebinding `_` stops the automatic last-output variables (`_`, `__`, `___`)
# from being updated, and it would keep the whole result tuple alive.
# NOTE(review): a later cell reads cfg_data['class_label_to_name'], which is
# not set in the configuration cell, so build_dataset presumably adds it to
# cfg_data -- confirm in datasets.cau_eeg_script.
dataset_parts = build_dataset(cfg_data, verbose=True)

# Only the first six entries are used: three loaders for the train/val/test
# splits, a test loader with longer crops, and the two preprocessing pipelines.
(train_loader, val_loader, test_loader, test_loader_longer,
 preprocess_train, preprocess_test) = dataset_parts[:6]
del dataset_parts

class_label_to_name: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']

----------------------------------------------------------------------------------------------------

- There are 458 data belonging to Normal
- There are 350 data belonging to Non-vascular MCI
- There are 233 data belonging to Non-vascular dementia

----------------------------------------------------------------------------------------------------

Train data label distribution	: [366, 280, 186] 832
Train data label distribution	: [46, 35, 23] 104
Train data label distribution	: [46, 35, 24] 105

----------------------------------------------------------------------------------------------------

composed_train: Compose(
    EegDropEKGChannel()
    EegDropPhoticChannel()
    EegRandomCrop(crop_length=2000, multiple=4, latency=2000, return_timing=False)
    EegToTensor()
)

----------------------------------------------------------------------------------------------------

composed_test: Compose(
    EegDro

-----

## Define Network Models

In [6]:
# Settings shared by every architecture:
# - in_channels: size of dimension 1 of the 'signal' entry of one training batch
# - out_dims: number of class labels recorded in cfg_data
cfg_common_model = dict(
    in_channels=next(iter(train_loader))['signal'].shape[1],
    out_dims=len(cfg_data['class_label_to_name']),
)

# Candidate model configurations; each model cell below adds its own entry.
cfg_model_pool = list()

#### 1D Tiny CNN

In [7]:
# Configuration of the 1D Tiny CNN: the shared settings first, then the
# model-specific ones (key order is kept so printed dumps stay the same).
cfg_model = {
    **cfg_common_model,
    'model': '1D-Tiny-CNN',
    'generator': TinyCNN1D,
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 64,
    'LR': None,
    'dropout': 0.3,
    'activation': 'mish',
}

# Plain print for the header: pprint would emit the string's repr, quotes included.
print('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate once as a sanity check and to show the layer layout, then drop
# the instance; only the configuration is kept for training.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

del model

# Register the configuration. Re-running this cell replaces the existing entry
# in place instead of appending a duplicate, so positions in cfg_model_pool
# (used later through model_index) do not shift.
registered_names = [pooled['model'] for pooled in cfg_model_pool]
if cfg_model['model'] in registered_names:
    cfg_model_pool[registered_names.index(cfg_model['model'])] = cfg_model
else:
    cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': None,
 'activation': 'mish',
 'base_channels': 64,
 'dropout': 0.3,
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.simple_cnn_1d.TinyCNN1D'>,
 'in_channels': 19,
 'model': '1D-Tiny-CNN',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

TinyCNN1D(
  (conv1): Conv1d(19, 64, kernel_size=(35,), stride=(7,))
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(64, 64, kernel_size=(7,), stride=(1,))
  (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (final_pool): AdaptiveMaxPool1d(output_size=1)
  (fc_stage): Sequential(
    (0): Sequential(
      (0): Linear(in_features=65, out_features=32, bias=Fals

#### M7 model (fc-age)

In [8]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-Mx'
# cfg_model['generator'] = M7
# cfg_model['fc_stages'] = 1
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 256
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### 1D ResNet model (fc-age)

In [9]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-2x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [2, 2, 2, 2]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### Deeper 1D ResNet model (fc-age)

In [10]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-5x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### Shallower 1D ResNet

In [11]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-2x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BasicBlock1D
# cfg_model['conv_layers'] = [2, 2, 2, 2]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### Tiny 1D ResNet model

In [12]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-1x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BasicBlock1D
# cfg_model['conv_layers'] = [1, 1, 1, 1]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = None
# cfg_model['dropout'] = 0.2
# cfg_model['activation'] = 'mish'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### Multi-Dilated 1D ResNet model

In [13]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-Multi-Dilated-ResNet-5x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = MultiBottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 32
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### 1D ResNeXt-53

In [14]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNeXt-5x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['groups'] = 32
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### 1D ResNeXt-103

In [15]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNeXt-10x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 23, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['groups'] = 32
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### 2D ResNet-20 model

In [16]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '2D-ResNet-2x' # resnet-18 + three more fc layer
# cfg_model['generator'] = ResNet2D
# cfg_model['block'] = BasicBlock2D
# cfg_model['conv_layers'] = [2, 2, 2, 2]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['n_fft'] = 100
# cfg_model['complex_mode'] = 'as_real' # 'power', 'remove'
# cfg_model['hop_length'] = cfg_model['n_fft'] // 2
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')

# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### 2D ResNet-52 model

In [17]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '2D-ResNet-5x' # resnet-50 + three more fc layer
# cfg_model['generator'] = ResNet2D
# cfg_model['block'] = Bottleneck2D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['n_fft'] = 100
# cfg_model['complex_mode'] = 'as_real' # 'power', 'remove'
# cfg_model['hop_length'] = cfg_model['n_fft'] // 2
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')

# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### 2D ResNeXt-104 model

In [18]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '2D-ResNeXt-10x' # resnet-101 + three more fc layer
# cfg_model['generator'] = ResNet2D
# cfg_model['block'] = Bottleneck2D
# cfg_model['conv_layers'] = [3, 4, 23, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['n_fft'] = 100
# cfg_model['complex_mode'] = 'as_real' # 'power', 'remove'
# cfg_model['hop_length'] = cfg_model['n_fft'] // 2
# cfg_model['groups'] = 32
# cfg_model['width_per_group'] = 8
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')

# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

#### CNN-Transformer

In [19]:
# Configuration of the CNN-Transformer: the shared settings first, then the
# model-specific ones (key order is kept so printed dumps stay the same).
cfg_model = {
    **cfg_common_model,
    'model': '1D-CNN-Transformer',
    'generator': CNNTransformer,
    'fc_stages': 2,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 192,
    'n_encoders': 8,
    'n_heads': 4,
    'dropout': 0.2,
    'LR': None,
    'activation': 'relu',
}

# Plain print for the header: pprint would emit the string's repr, quotes included.
print('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate once as a sanity check and to show the layer layout, then drop
# the instance; only the configuration is kept for training.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

del model

# Register the configuration. Re-running this cell replaces the existing entry
# in place instead of appending a duplicate, so positions in cfg_model_pool
# (used later through model_index) do not shift.
registered_names = [pooled['model'] for pooled in cfg_model_pool]
if cfg_model['model'] in registered_names:
    cfg_model_pool[registered_names.index(cfg_model['model'])] = cfg_model
else:
    cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': None,
 'activation': 'relu',
 'base_channels': 192,
 'dropout': 0.2,
 'fc_stages': 2,
 'final_pool': 'max',
 'generator': <class 'models.cnn_transformer.CNNTransformer'>,
 'in_channels': 19,
 'model': '1D-CNN-Transformer',
 'n_encoders': 8,
 'n_heads': 4,
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

CNNTransformer(
  (conv1): Conv1d(19, 192, kernel_size=(21,), stride=(9,))
  (bn1): BatchNorm1d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv1d(192, 192, kernel_size=(9,), stride=(3,))
  (bn2): BatchNorm1d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamically

#### Summarize the model pool

In [20]:
# Dump every registered configuration, separated by a horizontal rule.
# A dedicated loop variable is used so the notebook-level `cfg_model` is not
# silently rebound to the last pool entry by this purely informational cell.
for pooled_cfg in cfg_model_pool:
    pprint.pp(pooled_cfg, width=150)
    print('\n' + '-' * 100 + '\n')

{'in_channels': 19,
 'out_dims': 3,
 'model': '1D-Tiny-CNN',
 'generator': <class 'models.simple_cnn_1d.TinyCNN1D'>,
 'fc_stages': 3,
 'use_age': 'fc',
 'final_pool': 'max',
 'base_channels': 64,
 'LR': None,
 'dropout': 0.3,
 'activation': 'mish'}

----------------------------------------------------------------------------------------------------

{'in_channels': 19,
 'out_dims': 3,
 'model': '1D-CNN-Transformer',
 'generator': <class 'models.cnn_transformer.CNNTransformer'>,
 'fc_stages': 2,
 'use_age': 'fc',
 'final_pool': 'max',
 'base_channels': 192,
 'n_encoders': 8,
 'n_heads': 4,
 'dropout': 0.2,
 'LR': None,
 'activation': 'relu'}

----------------------------------------------------------------------------------------------------



#### Selected model

In [21]:
# Position in cfg_model_pool of the configuration to train. The index depends
# on which model cells above were executed, and in which order they appended.
model_index = 1
cfg_model = cfg_model_pool[model_index]

# Echo the chosen configuration so the selection is visible in the output.
pprint.pp(cfg_model, width=150)

{'in_channels': 19,
 'out_dims': 3,
 'model': '1D-CNN-Transformer',
 'generator': <class 'models.cnn_transformer.CNNTransformer'>,
 'fc_stages': 2,
 'use_age': 'fc',
 'final_pool': 'max',
 'base_channels': 192,
 'n_encoders': 8,
 'n_heads': 4,
 'dropout': 0.2,
 'LR': None,
 'activation': 'relu'}


-----

## Default Configurations for Training

In [22]:
# training configurations
cfg_train = {
    # Iteration budget shrinks as the minibatch grows; presumably this keeps
    # the sample count of 300000 iterations at batch size 32 -- TODO confirm.
    'iterations': (300000 * 32) // cfg_data['minibatch'],
    # The train cell derives history_interval as iterations // num_history.
    'num_history': 500,
    'lr_decay_timing': 0.45,
    'lr_decay_gamma': 0.2,
    'weight_decay': 1e-2,
    'mixup': 0.3,  # 0 for no usage
    'criterion': 'cross-entropy',  # 'cross-entropy', 'multi-bce'

    # runtime and logging switches
    'device': device,
    'save_model': True,
    'save_temporary': False,
    'draw_result': True,
    'watch_model': False,
}

-----

## Train

In [24]:
# Start the Weights & Biases run and name it after its own id.
wandb_run = wandb.init(project="eeg-analysis-2")
wandb_run.name = wandb_run.id

with wandb_run:
    # Pause briefly, then free Python garbage and cached GPU memory before training.
    time.sleep(10)
    gc.collect()
    # The CUDA calls are skipped on CPU-only machines, where synchronize() fails,
    # so the CPU fallback chosen for `device` keeps working.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Local defaults only fill in keys that wandb.config does not already
    # define (for example values injected by a sweep), so those values win.
    defaults = {**cfg_data, **cfg_train, **cfg_model}
    config = {k: v for k, v in defaults.items() if k not in wandb.config}

    # 'iterations' and 'num_history' may have been filtered out of `config`
    # above when wandb.config supplies them; read each from whichever side
    # owns it so the derived interval never fails with a KeyError.
    resolved = {name: (wandb.config[name] if name in wandb.config else config[name])
                for name in ('iterations', 'num_history')}
    config['history_interval'] = resolved['iterations'] // resolved['num_history']

    # Original note: "to prevent callables from type-conversion to str".
    # Presumably wandb.config stores a converted copy, so the local `config`
    # keeps the original Python objects (model class, device) and only gains
    # the keys that exist solely in wandb.config -- NOTE(review): confirm.
    wandb.config.update(config)
    for k, v in wandb.config.items():
        if k not in config:
            config[k] = v

    # train the model
    model = train_with_wandb(config, train_loader, val_loader, test_loader, test_loader_longer,
                             preprocess_train, preprocess_test)

    # release memory
    del model

************************************************************************************************************************
******************************              1D-CNN-Transformer train starts               ******************************
************************************************************************************************************************



Passing a schema to Validator.iter_errors is deprecated and will be removed in a future release. Call validator.evolve(schema=new_schema).iter_errors(...) instead.







Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display




Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display



VBox(children=(Label(value='2.454 MB of 2.454 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Loss,█▇▇▄▄▅▆▃▅▅▂▅▄▅▄▆▅▅▅▄▃▂▂▆▄▄▂▅▂▄▃▂▁▄▄▄▂▃▁▂
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▂▃▅▅▄▆▄▄█▃▄▂▅▂▃▅▃▂▅▆▄▄▄▃▅▃▅▅▄█▅▄▇▄█▇▇▅
Validation Accuracy,█▆▄▂▂▃▂▁▆▁▅▁▂▂▃▂▂▂▂▄▂▂▂▄▄▃▂▁▄▃▂▄▂▅▃▄▃▃▃▄

0,1
Loss,0.97169
Test Accuracy,60.0
Test Accuracy (Longer),60.0
Train Accuracy,51.50136
Validation Accuracy,53.84615
