# Train Networks

- Three-way SoftMax or Multi-BCE classifier of normal, non-vascular MCI, and non-vascular dementia

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import json
from copy import deepcopy

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pprint
import wandb

# custom package
from models import *
from utils.eeg_dataset import *
from utils.train_utils import *

In [3]:
# Name of the notebook currently being executed
def get_notebook_name():
    """Return the current notebook's name as reported by ``ipynbname.name()``.

    The import is kept local so ``ipynbname`` is only required when this
    function is actually called.
    """
    from ipynbname import name as _query_notebook_name
    return _query_notebook_name()


nb_fname = get_notebook_name()

In [4]:
# Report the PyTorch build and pick the compute device: CUDA when PyTorch can
# see a GPU, otherwise the CPU.
print('PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The selected device type mirrors CUDA availability, so report it from there.
print('cuda is available.' if device.type == 'cuda' else 'cuda is unavailable.')

PyTorch version: 1.9.0
cuda is available.


-----

## Set the default configuration for building the dataset

In [5]:
# Default dataset options; this dict is what gets passed to build_dataset().
cfg_data = {
    'device': device,
    'dataset': 'CAUHS',
    'data_path': r'dataset/02_Curated_Data/',
}
# The metadata file is looked up inside the data directory configured above.
cfg_data['meta_path'] = os.path.join(cfg_data['data_path'], 'metadata_debug.json')
cfg_data.update({
    'target_task': 'Normal, MCI, Dementia',
    'vascular': 'X',
    'segment': 'no',  # 'train', 'all', 'no'
    'seed': 0,
    'crop_length': 200 * 10,  # 10 seconds
    'longer_crop_length': 200 * 10 * 10,  # 100 seconds
    'input_norm': 'dataset',  # 'dataset', 'datapoint', 'no'
    'EKG': 'O',
    'photic': 'X',
    'awgn': 5e-2,
    'awgn_age': 5e-2,
    'minibatch': 32,
})

In [6]:
# Build the dataset once with the default configuration (verbose=True so the
# builder reports what it created). The result is kept in `_`; a later cell
# reads `_[0]` and `_[-1]` from it to derive the shared model settings.
# NOTE(review): `_` is also IPython's "last output" name — a descriptive
# variable name would make this dependency less fragile.
_ = build_dataset(cfg_data, verbose=True)

class_label_to_type: ['Normal', 'Non-vascular MCI', 'Non-vascular dementia']

----------------------------------------------------------------------------------------------------

- There are 463 data belonging to Normal
- There are 347 data belonging to Non-vascular MCI
- There are 229 data belonging to Non-vascular dementia

----------------------------------------------------------------------------------------------------

Train data label distribution	: [370, 278, 183] 831
Train data label distribution	: [46, 35, 23] 104
Train data label distribution	: [47, 34, 23] 104

----------------------------------------------------------------------------------------------------

composed_train: Compose(
    <utils.eeg_dataset.EEGRandomCrop object at 0x0000027E42CE18E0>
    <utils.eeg_dataset.EEGNormalizeMeanStd object at 0x0000027E4820A880>
    <utils.eeg_dataset.EEGNormalizeAge object at 0x0000027E4820A8E0>
    <utils.eeg_dataset.EEGDropPhoticChannel object at 0x0000027E4B280BE0>
    <uti

-----

## Define Network Models

In [7]:
# Settings shared by every model config, derived from the dataset built above:
# the input channel count comes from the first sample's 'signal' entry and the
# output size from the length of the last element returned by build_dataset().
cfg_common_model = dict(
    in_channels=_[0].dataset[0]['signal'].shape[0],
    out_dims=len(_[-1]),
)

# Model configs are appended to this pool by the cells below.
cfg_model_pool = list()

#### 1D Tiny CNN

In [8]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-Tiny-CNN'
# cfg_model['generator'] = TinyCNN1D
# cfg_model['fc_stages'] = 1
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

#### M7 model (fc-age)

In [9]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-Mx'
# cfg_model['generator'] = M7
# cfg_model['fc_stages'] = 1
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 256
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 256,
 'fc_stages': 1,
 'final_pool': 'max',
 'generator': <class 'models.simple_cnn_1d.M7'>,
 'in_channels': 20,
 'model': '1D-Mx',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

M7(
  (conv1): Conv1d(20, 256, kernel_size=(41,), stride=(2,))
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(256, 256, kernel_size=(11,), stride=(1,))
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(256, 512, kernel_size=(11,), stride=(1,))
  (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool1d(kernel_size=3, stride=3, padding=0, d

#### 1D ResNet model (fc-age)

In [10]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-2x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [2, 2, 2, 2]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_1d.BottleneckBlock1D'>,
 'conv_layers': [2, 2, 2, 2],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'in_channels': 20,
 'model': '1D-ResNet-2x',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet1D(
  (input_stage): Sequential(
    (0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv_stage1): Sequential(
    (0): BottleneckBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum

#### Deeper 1D ResNet model (fc-age)

In [11]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-5x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_1d.BottleneckBlock1D'>,
 'conv_layers': [3, 4, 6, 3],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'in_channels': 20,
 'model': '1D-ResNet-5x',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet1D(
  (input_stage): Sequential(
    (0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv_stage1): Sequential(
    (0): BottleneckBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum

#### Shallower 1D ResNet

In [12]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-2x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BasicBlock1D
# cfg_model['conv_layers'] = [2, 2, 2, 2]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_1d.BasicBlock1D'>,
 'conv_layers': [2, 2, 2, 2],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'in_channels': 20,
 'model': '1D-ResNet-2x',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet1D(
  (input_stage): Sequential(
    (0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv_stage1): Sequential(
    (0): BasicBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
      (bn2): BatchNorm1d(64, eps=1e-05, mome

#### Tiny 1D ResNet model

In [13]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNet-1x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BasicBlock1D
# cfg_model['conv_layers'] = [1, 1, 1, 1]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_1d.BasicBlock1D'>,
 'conv_layers': [1, 1, 1, 1],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'in_channels': 20,
 'model': '1D-ResNet-1x',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet1D(
  (input_stage): Sequential(
    (0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv_stage1): Sequential(
    (0): BasicBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), bias=False)
      (bn2): BatchNorm1d(64, eps=1e-05, mome

#### Multi-Dilated 1D ResNet model

In [14]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-Multi-Dilated-ResNet-5x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = MultiBottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 32
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

#### 1D ResNeXt-53

In [15]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNeXt-5x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['groups'] = 32
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_1d.BottleneckBlock1D'>,
 'conv_layers': [3, 4, 6, 3],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'groups': 32,
 'in_channels': 20,
 'model': '1D-ResNeXt-5x',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet1D(
  (input_stage): Sequential(
    (0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv_stage1): Sequential(
    (0): BottleneckBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), groups=32, bias=False)
      (bn2): BatchNor

#### 1D ResNeXt-103

In [16]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '1D-ResNeXt-10x'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 23, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['groups'] = 32
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_1d.BottleneckBlock1D'>,
 'conv_layers': [3, 4, 23, 3],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'groups': 32,
 'in_channels': 20,
 'model': '1D-ResNeXt-10x',
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet1D(
  (input_stage): Sequential(
    (0): Conv1d(20, 64, kernel_size=(27,), stride=(2,), padding=(13,), bias=False)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv_stage1): Sequential(
    (0): BottleneckBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(9,), stride=(1,), padding=(4,), groups=32, bias=False)
      (bn2): BatchN

#### 2D ResNet-20 model

In [17]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '2D-ResNet-2x' # resnet-18 + three more fc layer
# cfg_model['generator'] = ResNet2D
# cfg_model['block'] = BasicBlock2D
# cfg_model['conv_layers'] = [2, 2, 2, 2]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['n_fft'] = 100
# cfg_model['complex_mode'] = 'as_real' # 'power', 'remove'
# cfg_model['hop_length'] = cfg_model['n_fft'] // 2
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')

# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_2d.BasicBlock2D'>,
 'complex_mode': 'as_real',
 'conv_layers': [2, 2, 2, 2],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_2d.ResNet2D'>,
 'hop_length': 50,
 'in_channels': 20,
 'model': '2D-ResNet-2x',
 'n_fft': 100,
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet2D(
  (conv1): Conv2d(40, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock2D(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

#### 2D ResNet-52 model

In [18]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '2D-ResNet-5x' # resnet-50 + three more fc layer
# cfg_model['generator'] = ResNet2D
# cfg_model['block'] = Bottleneck2D
# cfg_model['conv_layers'] = [3, 4, 6, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['n_fft'] = 100
# cfg_model['complex_mode'] = 'as_real' # 'power', 'remove'
# cfg_model['hop_length'] = cfg_model['n_fft'] // 2
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')

# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_2d.Bottleneck2D'>,
 'complex_mode': 'as_real',
 'conv_layers': [3, 4, 6, 3],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_2d.ResNet2D'>,
 'hop_length': 50,
 'in_channels': 20,
 'model': '2D-ResNet-5x',
 'n_fft': 100,
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

ResNet2D(
  (conv1): Conv2d(40, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck2D(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (con

#### 2D ResNeXt-104 model

In [19]:
# cfg_model = {}
# cfg_model.update(cfg_common_model)
# cfg_model['model'] = '2D-ResNeXt-10x' # resnet-101 + three more fc layer
# cfg_model['generator'] = ResNet2D
# cfg_model['block'] = Bottleneck2D
# cfg_model['conv_layers'] = [3, 4, 23, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'max'
# cfg_model['base_channels'] = 64
# cfg_model['n_fft'] = 100
# cfg_model['complex_mode'] = 'as_real' # 'power', 'remove'
# cfg_model['hop_length'] = cfg_model['n_fft'] // 2
# cfg_model['groups'] = 32
# cfg_model['width_per_group'] = 8
# cfg_model['LR'] = 1e-3

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')

# model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# # tensorboard visualization
# # visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

# del model
# cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 64,
 'block': <class 'models.resnet_2d.Bottleneck2D'>,
 'complex_mode': 'as_real',
 'conv_layers': [3, 4, 23, 3],
 'fc_stages': 3,
 'final_pool': 'max',
 'generator': <class 'models.resnet_2d.ResNet2D'>,
 'groups': 32,
 'hop_length': 50,
 'in_channels': 20,
 'model': '2D-ResNeXt-10x',
 'n_fft': 100,
 'out_dims': 3,
 'use_age': 'fc',
 'width_per_group': 8}

----------------------------------------------------------------------------------------------------

ResNet2D(
  (conv1): Conv2d(40, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck2D(
      (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine

#### CNN-Transformer

In [20]:
# CNN-Transformer configuration: the shared settings come first, followed by
# the model-specific ones (same key order as incremental assignment).
cfg_model = {
    **cfg_common_model,
    'model': '1D-CNN-Transformer',
    'generator': CNNTransformer,
    'fc_stages': 2,
    'use_age': 'fc',
    'final_pool': 'max',
    'base_channels': 256,
    'n_encoders': 4,
    'n_heads': 4,
    'dropout': 0.2,
    'LR': 1e-3,
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n', '-' * 100, '\n', sep='')

# Instantiate the network only to print its architecture; it is deleted again
# below and only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model).to(device, dtype=torch.float32)
print(model)
print('\n', '-' * 100, '\n', sep='')

# tensorboard visualization
# visualize_network_tensorboard(model, train_loader, device, nb_fname, '1D-Tiny-CNN-fc-age')

del model
# NOTE(review): re-running this cell appends the config to the pool again.
cfg_model_pool.append(cfg_model)

'Model config:'
{'LR': 0.001,
 'base_channels': 256,
 'dropout': 0.2,
 'fc_stages': 2,
 'final_pool': 'max',
 'generator': <class 'models.transformer.CNNTransformer'>,
 'in_channels': 20,
 'model': '1D-CNN-Transformer',
 'n_encoders': 4,
 'n_heads': 4,
 'out_dims': 3,
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------

CNNTransformer(
  (conv1): Conv1d(20, 256, kernel_size=(21,), stride=(9,))
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv1d(256, 256, kernel_size=(9,), stride=(3,))
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_featu

#### Summarize the loaded models

In [21]:
# Pretty-print every config currently in the pool, one separator after each.
for cfg_model in cfg_model_pool:
    pprint.pp(cfg_model, width=150)
    print('\n', '-' * 100, '\n', sep='')

{'in_channels': 20,
 'out_dims': 3,
 'model': '1D-Mx',
 'generator': <class 'models.simple_cnn_1d.M7'>,
 'fc_stages': 1,
 'use_age': 'fc',
 'final_pool': 'max',
 'base_channels': 256,
 'LR': 0.001}

----------------------------------------------------------------------------------------------------

{'in_channels': 20,
 'out_dims': 3,
 'model': '1D-ResNet-2x',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'block': <class 'models.resnet_1d.BottleneckBlock1D'>,
 'conv_layers': [2, 2, 2, 2],
 'fc_stages': 3,
 'use_age': 'fc',
 'final_pool': 'max',
 'base_channels': 64,
 'LR': 0.001}

----------------------------------------------------------------------------------------------------

{'in_channels': 20,
 'out_dims': 3,
 'model': '1D-ResNet-5x',
 'generator': <class 'models.resnet_1d.ResNet1D'>,
 'block': <class 'models.resnet_1d.BottleneckBlock1D'>,
 'conv_layers': [3, 4, 6, 3],
 'fc_stages': 3,
 'use_age': 'fc',
 'final_pool': 'max',
 'base_channels': 64,
 'LR': 0.001}

-----------

-----

## Default Configurations for Training

In [22]:
# Default training options. `iterations` is set first because the logging
# interval and the LR-decay step are derived from it.
cfg_train = {'iterations': 100000}
cfg_train.update({
    'history_interval': cfg_train['iterations'] // 500,
    'lr_decay_step': round(cfg_train['iterations'] * 0.8),
    'lr_decay_gamma': 0.1,
    'weight_decay': 1e-2,
    'mixup': 0.0,  # 0 for no usage
    'criterion': 'cross-entropy',  # 'cross-entropy', 'multi-bce'
})

# Runtime / bookkeeping switches.
cfg_train.update({
    'device': device,
    'save_model': True,
    'save_temporary': False,
    'draw_result': True,
    'watch_model': True,
})

In [23]:
def _save_model_state(model_state, config):
    """Save ``model_state`` to ``checkpoint_temp/<wandb run name>/<config["model"]>``."""
    save_path = f'checkpoint_temp/{wandb.run.name}/'
    os.makedirs(save_path, exist_ok=True)
    path = os.path.join(save_path, f'{config["model"]}')
    torch.save(model_state, path)


def train_with_wandb(config, train_loader, val_loader, test_loader, test_loader_longer, class_label_to_type):
    """Train one model, track it with wandb, and evaluate it on the test loaders.

    The model is built from ``config['generator'](**config)`` and optimised with
    AdamW plus a StepLR schedule. Training proceeds in chunks of
    ``config['history_interval']`` steps; after each chunk the validation
    accuracy is measured and the best-scoring weights are remembered. At the end
    both the best-validation and the last weights are evaluated on
    ``test_loader``, the better of the two is kept, evaluated again on
    ``test_loader_longer``, optionally saved, and the results are logged to wandb.

    Args:
        config: dict of data/training/model settings. Keys read here: 'model',
            'generator', 'LR', 'weight_decay', 'lr_decay_step', 'lr_decay_gamma',
            'iterations', 'history_interval', 'save_model', 'save_temporary',
            'draw_result', and optionally 'mixup', 'watch_model', 'lr_search'.
        train_loader: loader handed to the training routine.
        val_loader: loader used for the periodic validation check.
        test_loader: loader used for the final test evaluation.
        test_loader_longer: loader used for the additional "longer" test evaluation.
        class_label_to_type: class names used to label the plots.

    Returns:
        The trained model, loaded with the selected (best or last) weights.

    Note:
        Requires an active wandb run and uses the notebook-level ``device``.
    """
    print('*'*120)
    print(f'{"*"*30}{config["model"] + " train starts":^60}{"*"*30}')
    print('*'*120)

    # generate model and its trainer
    model = config['generator'](**config).to(device)

    optimizer = optim.AdamW(model.parameters(),
                            lr=config['LR'],
                            weight_decay=config['weight_decay'])
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=config['lr_decay_step'],
                                          gamma=config['lr_decay_gamma'])

    # mixup == 0 (or missing) selects the plain training routine
    tr_ms = train_multistep if config.get('mixup', 0) < 1e-12 else train_mixup_multistep

    # track gradients and weights statistics
    if config.get('watch_model', None):
        wandb.watch(model, log='all',
                    log_freq=config['history_interval'],
                    log_graph=True)

    # train and validation routine
    best_val_acc = 0
    # Seed with the initial weights so `best_model_state` is always defined,
    # even if validation accuracy never rises above 0 or the loop never runs.
    best_model_state = deepcopy(model.state_dict())
    for i in range(0, config["iterations"], config["history_interval"]):
        # train 'history_interval' steps
        loss, train_acc = tr_ms(model, train_loader, optimizer, scheduler,
                                config, config["history_interval"])

        # validation
        val_acc, _, _, _, _ = check_accuracy(model, val_loader, config, repeat=10)

        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_model_state = deepcopy(model.state_dict())
            if config['save_model'] and config['save_temporary']:
                _save_model_state(best_model_state, config)

        # log
        wandb.log({'Loss': loss,
                   'Train Accuracy': train_acc,
                   'Validation Accuracy': val_acc}, step=i)

    # calculate the test accuracies for best and last models
    last_model_state = deepcopy(model.state_dict())
    last_test_result = check_accuracy(model, test_loader, config, repeat=30)
    last_test_acc = last_test_result[0]

    model.load_state_dict(best_model_state)
    best_test_result = check_accuracy(model, test_loader, config, repeat=30)
    best_test_acc = best_test_result[0]

    # NOTE(review): choosing between the best-validation and last weights by
    # *test* accuracy makes the reported test accuracy optimistically biased.
    if last_test_acc < best_test_acc:
        model_state = best_model_state
        test_result = best_test_result
    else:
        model_state = last_model_state
        test_result = last_test_result

    model.load_state_dict(model_state)
    test_acc, test_confusion, test_debug, score, target = test_result

    # calculate the test accuracy of the selected model on the much longer sequences
    longer_test_result = check_accuracy(model, test_loader_longer, config, repeat=30)
    longer_test_acc = longer_test_result[0]

    # save the model
    if config['save_model']:
        _save_model_state(model_state, config)

    # leave the message
    wandb.config.final_shape = model.get_final_shape()
    wandb.config.num_params = count_parameters(model)
    wandb.log({'Test Accuracy': test_acc,
               '(Best / Last) Test Accuracy': ('Best' if last_test_acc < best_test_acc else 'Last',
                                               round(best_test_acc, 2), round(last_test_acc, 2)),
               'Confusion Matrix (Array)': test_confusion,
               'Test Accuracy (Longer)': longer_test_acc,
               'Test Debug Table/Serial': test_debug[0],
               'Test Debug Table/EDF': test_debug[1],
               'Test Debug Table/Pred': test_debug[2],
               'Test Debug Table/GT': test_debug[3]})

    if 'lr_search' in config:
        draw_learning_rate_record(config['lr_search'], use_wandb=True)

    if config['draw_result']:
        draw_roc_curve(score, target, class_label_to_type, use_wandb=True)
        draw_confusion(test_confusion, class_label_to_type, use_wandb=True)
        draw_debug_table(test_debug, use_wandb=True)
        wandb.log({"Confusion Matrix": wandb.plot.confusion_matrix(y_true=target,
                                                                   preds=score.argmax(axis=-1),
                                                                   class_names=class_label_to_type)})
        wandb.log({"ROC Curve": wandb.plot.roc_curve(target, score, labels=class_label_to_type)})

    return model

In [24]:
def train_sweep(cfg_data, cfg_train, cfg_model_pool):
    """Run one W&B sweep trial: merge the configs, build the data, and train.

    The function takes no hyperparameters directly; the values chosen by the
    sweep are read from ``wandb.config`` after ``wandb.init()`` and take
    precedence over the defaults passed in here.

    Args:
        cfg_data: Mapping with the default dataset configuration.
        cfg_train: Mapping with the default training configuration.
        cfg_model_pool: Indexable collection of model configuration mappings;
            the entry at ``wandb.config.model_index`` is used for this trial.

    Returns:
        None. The trained model returned by ``train_with_wandb`` is discarded.
    """
    wandb_run = wandb.init()
    wandb.run.name = wandb.run.id
    with wandb_run:
        # Pick the architecture selected by the sweep for this trial.
        cfg_model = cfg_model_pool[wandb.config.model_index]

        # Collect the defaults, skipping every key the sweep already decided.
        # On duplicate keys the later dict wins (model > train > data).
        config = {}
        for key, value in {**cfg_data, **cfg_train, **cfg_model}.items():
            if key not in wandb.config:
                config[key] = value

        # Publish the defaults to W&B, then pull the sweep-chosen values into
        # the local dict. The local dict (not wandb.config) is what gets passed
        # on, to prevent callables from type-conversion to str.
        wandb.config.update(config)
        for key, value in wandb.config.items():
            if key not in config:
                config[key] = value

        # build dataset
        (train_loader, val_loader, test_loader,
         test_loader_longer, class_label_to_type) = build_dataset(config)

        config['in_channels'] = train_loader.dataset[0]['signal'].shape[0]
        config['out_dims'] = len(class_label_to_type)

        # Learning rate search if needed. A missing 'LR' key is treated the
        # same as an explicit None instead of raising KeyError.
        if config.get('LR') is None:
            config['LR'], config['lr_search'] = learning_rate_search(config, train_loader,
                                                                     min_log_lr=-4.5, max_log_lr=-3.0,
                                                                     trials=100, steps=100)

        # train the model
        train_with_wandb(config, train_loader, val_loader, test_loader, test_loader_longer, class_label_to_type)

-----

## Train

In [25]:
# Sweep search space for the dataset-related options.
# 'values' entries are sampled from the listed choices; 'uniform' entries
# are drawn from the [min, max] range.
sweep_data = {
    'crop_length': {
        'values': [
            200 * 10,  # 10 sec
            200 * 20,  # 20 sec
        ],
    },
    'EKG': {'values': ['O', 'X']},
    'photic': {'values': ['O', 'X']},
    'awgn': {'distribution': 'uniform', 'min': 0, 'max': 0.3},
    'awgn_age': {'distribution': 'uniform', 'min': 0, 'max': 0.3},
    'minibatch': {'values': [32]},
}

In [26]:
# Sweep search space for the model-related options.
# 'model_index' selects one entry of cfg_model_pool per trial.
# The 'LR' bounds are passed in natural-log space (math.log) for 'log_uniform'.
sweep_model = {
    'model_index': {'values': list(range(len(cfg_model_pool)))},
    'fc_stages': {'distribution': 'int_uniform', 'min': 0, 'max': 4},
    'use_age': {'values': ['fc', 'conv']},
    'final_pool': {'values': ['max', 'average']},
    'first_dilation': {'distribution': 'int_uniform', 'min': 1, 'max': 2},
    'base_stride': {'distribution': 'int_uniform', 'min': 2, 'max': 4},
    'dropout': {'distribution': 'uniform', 'min': 0.0, 'max': 0.5},
    'LR': {
        'distribution': 'log_uniform',
        'min': math.log(5e-5),
        'max': math.log(1e-3),
    },
}

In [27]:
# Sweep search space for the training-related options.
# The 'weight_decay' bounds are passed in natural-log space (math.log)
# for 'log_uniform'.
sweep_train = {
    'iterations': {'values': [100000, 150000]},
    'lr_decay_gamma': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5},
    'lr_decay_step': {'values': [45000, 80000]},
    'weight_decay': {
        'distribution': 'log_uniform',
        'min': math.log(1e-5),
        'max': math.log(1e-1),
    },
    'mixup': {'values': [0, 0.15, 0.3]},
    'criterion': {'values': ['cross-entropy', 'multi-bce']},
}

In [28]:
# Assemble the full sweep definition from the three search spaces above
# (data, model, training) and register it with W&B; the returned id is
# what the agent uses to pull trials.
sweep_config = dict(
    entity="ipis-mjkim",
    name="my-sweep",
    method="random",
    parameters={**sweep_data, **sweep_model, **sweep_train},
)

sweep_id = wandb.sweep(sweep_config, project="eeg-analysis")

Create sweep with ID: p7oh32tt
Sweep URL: https://wandb.ai/ipis-mjkim/eeg-analysis/sweeps/p7oh32tt


In [29]:
# Run 10 sweep trials in this kernel. The lambda gives the agent a
# zero-argument callable while still passing the notebook-level configs
# to train_sweep.
wandb.agent(
    sweep_id,
    function=lambda: train_sweep(cfg_data, cfg_train, cfg_model_pool),
    count=10,
)

[34m[1mwandb[0m: Agent Starting Run: 3rtbt315 with config:
[34m[1mwandb[0m: 	EKG: X
[34m[1mwandb[0m: 	LR: 0.0004662746114238844
[34m[1mwandb[0m: 	awgn: 0.18027900945527034
[34m[1mwandb[0m: 	awgn_age: 0.019514074271718484
[34m[1mwandb[0m: 	base_stride: 2
[34m[1mwandb[0m: 	criterion: multi-bce
[34m[1mwandb[0m: 	crop_length: 2000
[34m[1mwandb[0m: 	dropout: 0.38915356956165076
[34m[1mwandb[0m: 	fc_stages: 2
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 2
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.24773656195738725
[34m[1mwandb[0m: 	lr_decay_step: 45000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0
[34m[1mwandb[0m: 	model_index: 10
[34m[1mwandb[0m: 	photic: O
[34m[1mwandb[0m: 	use_age: conv
[34m[1mwandb[0m: 	weight_decay: 0.08416659781989794
[34m[1mwandb[0m: Currently logged in as: [33mipis-mjkim[0m (use `wandb login --relogin` to force relogin)


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


************************************************************************************************************************
******************************              1D-CNN-Transformer train starts               ******************************
************************************************************************************************************************


VBox(children=(Label(value=' 0.83MB of 0.83MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▇▆▆▆▅▅▅▅▅▅▄▃▃▃▂▃▇▆▇▅▅▆▇▅▄▃▃▂▂▂▂▁▂▂▂▁▁▁▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▂▃▃▃▄▄▄▄▄▄▆▆▆▇▆▂▃▃▄▄▃▂▄▄▆▆▇▇▇▇█▇▇▇████
Validation Accuracy,▇▇▆▇▆▃█▄▇█▆▇▆▆▇▇█▇▆▅▆▅▁▄▇▇████▇▇█████▇█▃

0,1
Loss,0.40343
Test Accuracy,61.76282
Test Accuracy (Longer),57.82051
Train Accuracy,71.70312
Validation Accuracy,60.38462


[34m[1mwandb[0m: Agent Starting Run: z5b3ae0v with config:
[34m[1mwandb[0m: 	EKG: O
[34m[1mwandb[0m: 	LR: 1.927334146053636e-05
[34m[1mwandb[0m: 	awgn: 0.1666430308611518
[34m[1mwandb[0m: 	awgn_age: 0.19681091986729177
[34m[1mwandb[0m: 	base_stride: 2
[34m[1mwandb[0m: 	criterion: multi-bce
[34m[1mwandb[0m: 	crop_length: 2000
[34m[1mwandb[0m: 	dropout: 0.3645707616721826
[34m[1mwandb[0m: 	fc_stages: 1
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 2
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.10680692494469488
[34m[1mwandb[0m: 	lr_decay_step: 45000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0.15
[34m[1mwandb[0m: 	model_index: 3
[34m[1mwandb[0m: 	photic: O
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 0.020767396868719146


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


************************************************************************************************************************
******************************                 1D-ResNet-2x train starts                  ******************************
************************************************************************************************************************


  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


VBox(children=(Label(value=' 0.79MB of 0.79MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▇▆▅▅▄▃▃▃▂▂▂▁▁▁▁▂▂▂▂▁▂▂▁▁▁▁▂▁▁▂▁▁▁▁▁▁▂▂▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▃▄▅▆▆▆▇▇▇▇████████████████████████████
Validation Accuracy,▃▄▅▃▄▃▄▆▆█▆▁▇▆▆▅▅▆▄▆▇▆▇▅▆▆▆▆▅▆▇▅█▆▇▆▆▆▇▇

0,1
Loss,0.13572
Test Accuracy,59.61538
Test Accuracy (Longer),61.82692
Train Accuracy,93.39234
Validation Accuracy,58.26923


[34m[1mwandb[0m: Agent Starting Run: 1nszp2pu with config:
[34m[1mwandb[0m: 	EKG: O
[34m[1mwandb[0m: 	LR: 0.00015913197126607632
[34m[1mwandb[0m: 	awgn: 0.24974076941624496
[34m[1mwandb[0m: 	awgn_age: 0.2828054426976972
[34m[1mwandb[0m: 	base_stride: 3
[34m[1mwandb[0m: 	criterion: multi-bce
[34m[1mwandb[0m: 	crop_length: 4000
[34m[1mwandb[0m: 	dropout: 0.2541319034759994
[34m[1mwandb[0m: 	fc_stages: 2
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 1
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.4813210681030561
[34m[1mwandb[0m: 	lr_decay_step: 45000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0.15
[34m[1mwandb[0m: 	model_index: 6
[34m[1mwandb[0m: 	photic: X
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 0.0068031957257943245


************************************************************************************************************************
******************************                1D-ResNeXt-10x train starts                 ******************************
************************************************************************************************************************


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


VBox(children=(Label(value=' 0.81MB of 0.81MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▇▆▃▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▄▇▇▇▇████▇████████████████████████████
Validation Accuracy,▁▁▂▄▂▁▅▅▄▅█▄▄▅▅▄▃▆▅▆▂▄▅▆▅▄▄▄▆▅▅▅▄▄▅▆▅▆▅▆

0,1
Loss,0.11099
Test Accuracy,62.85256
Test Accuracy (Longer),63.75
Train Accuracy,94.38956
Validation Accuracy,60.09615


[34m[1mwandb[0m: Agent Starting Run: tmxlmto3 with config:
[34m[1mwandb[0m: 	EKG: X
[34m[1mwandb[0m: 	LR: 1.9022751778444847e-05
[34m[1mwandb[0m: 	awgn: 0.21450718138003624
[34m[1mwandb[0m: 	awgn_age: 0.2936510981269041
[34m[1mwandb[0m: 	base_stride: 4
[34m[1mwandb[0m: 	criterion: multi-bce
[34m[1mwandb[0m: 	crop_length: 4000
[34m[1mwandb[0m: 	dropout: 0.3017033749245153
[34m[1mwandb[0m: 	fc_stages: 3
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 1
[34m[1mwandb[0m: 	iterations: 100000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.16158844304584635
[34m[1mwandb[0m: 	lr_decay_step: 45000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0.3
[34m[1mwandb[0m: 	model_index: 7
[34m[1mwandb[0m: 	photic: O
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 4.6350867168959736e-05


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


************************************************************************************************************************
******************************                 2D-ResNet-2x train starts                  ******************************
************************************************************************************************************************


VBox(children=(Label(value=' 0.79MB of 0.79MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▂▁▁▂▁▂▁▂▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▃▃▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█▇█▇███████▇███
Validation Accuracy,▁▁▅▆█▄▇▂▃▆▆▆▅▅▅▄▄▃▅▄▄▅▅▆▃▅▆▅▃▃▃▄▁▃▄▄▅▂▃▃

0,1
Loss,0.35279
Test Accuracy,64.45513
Test Accuracy (Longer),66.41026
Train Accuracy,78.5195
Validation Accuracy,52.5


[34m[1mwandb[0m: Agent Starting Run: 9xak85tq with config:
[34m[1mwandb[0m: 	EKG: O
[34m[1mwandb[0m: 	LR: 0.00027470626593701644
[34m[1mwandb[0m: 	awgn: 0.2373224667067662
[34m[1mwandb[0m: 	awgn_age: 0.23081870794648376
[34m[1mwandb[0m: 	base_stride: 4
[34m[1mwandb[0m: 	criterion: cross-entropy
[34m[1mwandb[0m: 	crop_length: 4000
[34m[1mwandb[0m: 	dropout: 0.2220756752267376
[34m[1mwandb[0m: 	fc_stages: 1
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 2
[34m[1mwandb[0m: 	iterations: 100000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.49814726277294463
[34m[1mwandb[0m: 	lr_decay_step: 80000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0
[34m[1mwandb[0m: 	model_index: 8
[34m[1mwandb[0m: 	photic: X
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 0.00011833394967849464


************************************************************************************************************************
******************************                 2D-ResNet-5x train starts                  ******************************
************************************************************************************************************************


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


VBox(children=(Label(value=' 0.84MB of 0.84MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▇▆▅▄▄▃▃▃▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▃▄▄▅▆▆▆▇▇▇▇▇▇▇█████████████████████████
Validation Accuracy,▁▇▆▄▄▂▄▅▅▁▂▂▂▃█▆▃▅▅▆▅▆▄▂▆▅▇▁▆▆▃▄▅▃▄▃▂▃▁▅

0,1
Loss,0.01132
Test Accuracy,59.42308
Test Accuracy (Longer),59.16667
Train Accuracy,99.6875
Validation Accuracy,53.94231


[34m[1mwandb[0m: Agent Starting Run: tcvpp4xu with config:
[34m[1mwandb[0m: 	EKG: X
[34m[1mwandb[0m: 	LR: 1.9799656651249375e-05
[34m[1mwandb[0m: 	awgn: 0.2251956146685096
[34m[1mwandb[0m: 	awgn_age: 0.06846800908996704
[34m[1mwandb[0m: 	base_stride: 4
[34m[1mwandb[0m: 	criterion: multi-bce
[34m[1mwandb[0m: 	crop_length: 2000
[34m[1mwandb[0m: 	dropout: 0.11189508117406775
[34m[1mwandb[0m: 	fc_stages: 1
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 2
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.4474714285290865
[34m[1mwandb[0m: 	lr_decay_step: 80000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0.15
[34m[1mwandb[0m: 	model_index: 4
[34m[1mwandb[0m: 	photic: X
[34m[1mwandb[0m: 	use_age: conv
[34m[1mwandb[0m: 	weight_decay: 0.0014527363539353524


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


************************************************************************************************************************
******************************                 1D-ResNet-1x train starts                  ******************************
************************************************************************************************************************


VBox(children=(Label(value=' 0.84MB of 0.84MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▃▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██▇█▇██▇▇██████
Validation Accuracy,▅▇▇█▇▅▇▆▇▃▂▅▃▇▅▃▃▅▅▁▄▄▅▆▇▆▆▆▄▅▆▆▆▇▅▇▇▆▆▇

0,1
Loss,0.18957
Test Accuracy,58.91026
Test Accuracy (Longer),59.48718
Train Accuracy,91.1648
Validation Accuracy,53.46154


[34m[1mwandb[0m: Agent Starting Run: rnyr6o73 with config:
[34m[1mwandb[0m: 	EKG: O
[34m[1mwandb[0m: 	LR: 2.524966533535892e-05
[34m[1mwandb[0m: 	awgn: 0.079806973942033
[34m[1mwandb[0m: 	awgn_age: 0.014524014087147008
[34m[1mwandb[0m: 	base_stride: 2
[34m[1mwandb[0m: 	criterion: cross-entropy
[34m[1mwandb[0m: 	crop_length: 4000
[34m[1mwandb[0m: 	dropout: 0.3651749280879526
[34m[1mwandb[0m: 	fc_stages: 1
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 1
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.22604379050352763
[34m[1mwandb[0m: 	lr_decay_step: 80000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0.15
[34m[1mwandb[0m: 	model_index: 0
[34m[1mwandb[0m: 	photic: O
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 1.2845878600197326e-05


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


************************************************************************************************************************
******************************                     1D-Mx train starts                     ******************************
************************************************************************************************************************


[34m[1mwandb[0m: [32m[41mERROR[0m Error while calling W&B API: Error 1040: Too many connections (<Response [500]>)
[34m[1mwandb[0m: Network error (HTTPError), entering retry loop.


VBox(children=(Label(value=' 0.83MB of 0.83MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▆▄▃▂▃▂▂▂▁▂▂▁▂▂▁▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▂
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▄▅▇▇▇▇▇▇█▇██▇▇██████████▇██████████████
Validation Accuracy,▃▁▆▄▃█▇█▅▆▅▆█▆▇▇▆█▇▆▆▇▇▇▇▆█▇▇▇█▇▆▇██▇█▇▇

0,1
Loss,0.15442
Test Accuracy,62.72436
Test Accuracy (Longer),64.32692
Train Accuracy,95.28241
Validation Accuracy,57.21154


[34m[1mwandb[0m: Agent Starting Run: gc6vg8ac with config:
[34m[1mwandb[0m: 	EKG: X
[34m[1mwandb[0m: 	LR: 5.1223082636476106e-05
[34m[1mwandb[0m: 	awgn: 0.07523482506936899
[34m[1mwandb[0m: 	awgn_age: 0.17482965196475647
[34m[1mwandb[0m: 	base_stride: 2
[34m[1mwandb[0m: 	criterion: cross-entropy
[34m[1mwandb[0m: 	crop_length: 2000
[34m[1mwandb[0m: 	dropout: 0.05966831761549968
[34m[1mwandb[0m: 	fc_stages: 0
[34m[1mwandb[0m: 	final_pool: max
[34m[1mwandb[0m: 	first_dilation: 1
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.4472528136522146
[34m[1mwandb[0m: 	lr_decay_step: 45000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0.15
[34m[1mwandb[0m: 	model_index: 7
[34m[1mwandb[0m: 	photic: O
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 0.0008915401727405024


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


************************************************************************************************************************
******************************                 2D-ResNet-2x train starts                  ******************************
************************************************************************************************************************


VBox(children=(Label(value=' 0.85MB of 0.85MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▇▆▆▅▅▅▄▄▄▄▃▃▃▂▂▃▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▃▄▄▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█▇██████████████
Validation Accuracy,▁▅▆█▃▅▃▆▃▃▄▅▅▄▄▅▄▅▃▃▄▂▄▄▄▃▅▅▄▃▃▃▃▄▂▂▃▁▄▄

0,1
Loss,0.24632
Test Accuracy,62.53205
Test Accuracy (Longer),65.73718
Train Accuracy,91.67467
Validation Accuracy,52.30769


[34m[1mwandb[0m: Agent Starting Run: 5ols3kdc with config:
[34m[1mwandb[0m: 	EKG: X
[34m[1mwandb[0m: 	LR: 1.2255655843640337e-05
[34m[1mwandb[0m: 	awgn: 0.19019748906598336
[34m[1mwandb[0m: 	awgn_age: 0.29552546857876577
[34m[1mwandb[0m: 	base_stride: 4
[34m[1mwandb[0m: 	criterion: cross-entropy
[34m[1mwandb[0m: 	crop_length: 4000
[34m[1mwandb[0m: 	dropout: 0.0807983200151079
[34m[1mwandb[0m: 	fc_stages: 2
[34m[1mwandb[0m: 	final_pool: average
[34m[1mwandb[0m: 	first_dilation: 1
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.21792923960364663
[34m[1mwandb[0m: 	lr_decay_step: 80000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0
[34m[1mwandb[0m: 	model_index: 8
[34m[1mwandb[0m: 	photic: O
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 0.017318972727029834


************************************************************************************************************************
******************************                 2D-ResNet-5x train starts                  ******************************
************************************************************************************************************************


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


VBox(children=(Label(value=' 0.89MB of 0.89MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇▇▇▇███████████████
Validation Accuracy,▆▅▆▅▆▆▆▆█▆▆▆█▇▇▅▄▆▆▃▄▂▂▃▁▃▂▄▂▃▂▃▃▄▃▂▄▅▅▅

0,1
Loss,0.14033
Test Accuracy,56.34615
Test Accuracy (Longer),56.02564
Train Accuracy,95.14062
Validation Accuracy,49.23077


[34m[1mwandb[0m: Agent Starting Run: 7s9a9enm with config:
[34m[1mwandb[0m: 	EKG: O
[34m[1mwandb[0m: 	LR: 0.0002960940564321757
[34m[1mwandb[0m: 	awgn: 0.25883228515539264
[34m[1mwandb[0m: 	awgn_age: 0.15994726321766028
[34m[1mwandb[0m: 	base_stride: 2
[34m[1mwandb[0m: 	criterion: cross-entropy
[34m[1mwandb[0m: 	crop_length: 2000
[34m[1mwandb[0m: 	dropout: 0.4686601107588504
[34m[1mwandb[0m: 	fc_stages: 0
[34m[1mwandb[0m: 	final_pool: max
[34m[1mwandb[0m: 	first_dilation: 2
[34m[1mwandb[0m: 	iterations: 150000
[34m[1mwandb[0m: 	lr_decay_gamma: 0.10591285095852544
[34m[1mwandb[0m: 	lr_decay_step: 80000
[34m[1mwandb[0m: 	minibatch: 32
[34m[1mwandb[0m: 	mixup: 0
[34m[1mwandb[0m: 	model_index: 6
[34m[1mwandb[0m: 	photic: X
[34m[1mwandb[0m: 	use_age: fc
[34m[1mwandb[0m: 	weight_decay: 0.012047708105430491


************************************************************************************************************************
******************************                1D-ResNeXt-10x train starts                 ******************************
************************************************************************************************************************


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


VBox(children=(Label(value=' 0.91MB of 0.91MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss,█▇▇▇▆▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test Accuracy,▁
Test Accuracy (Longer),▁
Train Accuracy,▁▂▂▃▄▄▆▇▇▇██████████████████████████████
Validation Accuracy,▄█▄█▄▇▁▁▆▆▆▇▃▂▅▃▄▁▂▄▂▄▃▁▃▃▃▃▃▃▄▆▂▄▂▄▄▃▄▃

0,1
Loss,0.00312
Test Accuracy,59.80769
Test Accuracy (Longer),58.33333
Train Accuracy,99.92188
Validation Accuracy,51.73077
