# Train Networks

- Train a SoftMax or Multi-BCE classifier for EEG diagnosis classification
    - CAUEEG-task1 benchmark: Classification of **Normal** and **Abnormal** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
- A `Weights and Biases` (W&B) sweep is used for the hyperparameter search

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import json
from copy import deepcopy

import hydra
from omegaconf import OmegaConf
import yaml
import wandb
import pprint

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

## Environment

In [3]:
# Hydra `data` config option used for every sweep run (passed as `data=<value>`).
data_cfg_file = 'task2'
# Device forwarded to the training config as `++train.device=<value>`.
device = 'cuda:0'
# Derive the sweep definition from the data config so the two cannot drift apart
# (e.g. switching to 'task1' but still running the task2 sweep).
# NOTE(review): assumes a `config/sweep/sweep_<task>.yaml` file exists per task.
sweep_yaml_path = f'config/sweep/sweep_{data_cfg_file}.yaml'
# Number of sweep runs the agent below executes in this kernel.
count = 3

## Sweep configurations

In [4]:
# Load the sweep definition. Read it as UTF-8 explicitly: the platform default
# encoding (e.g. on Windows) is not necessarily UTF-8.
with open(sweep_yaml_path, encoding='utf-8') as f:
    sweep_yaml = yaml.load(f, Loader=yaml.FullLoader)

# A sweep's `command` entry describes how a CLI agent launches a training script.
# Here the agent calls `train_sweep` directly, so drop it (tolerating its absence).
sweep_yaml.pop('command', None)
pprint.pprint(sweep_yaml)

{'entity': 'ipis-mjkim',
 'method': 'random',
 'name': 'sweep-task2',
 'parameters': {'data.EKG': {'values': ['O', 'X']},
                'data.awgn': {'distribution': 'uniform', 'max': 0.12, 'min': 0},
                'data.awgn_age': {'distribution': 'uniform',
                                  'max': 0.3,
                                  'min': 0},
                'data.mgn': {'distribution': 'uniform', 'max': 0.1, 'min': 0},
                'data.photic': {'values': ['O', 'X']},
                'data.seq_length': {'values': [1000, 2000, 3000, 4000]},
                'model': {'values': ['1D-VGG-16',
                                     '1D-VGG-19',
                                     '1D-ResNet-18',
                                     '1D-ResNet-50',
                                     '1D-ResNeXt-50',
                                     '1D-ResNeXt-101',
                                     '1D-Wide-ResNet-50',
                                     '1D-CNN-Transformer',
      

In [5]:
# Register the sweep with W&B under the project named in the sweep file.
sweep_id = wandb.sweep(sweep=sweep_yaml, project=sweep_yaml['project'])

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: b2u16zn2
Sweep URL: https://wandb.ai/ipis-mjkim/sweep-test/sweeps/b2u16zn2


## Train

In [6]:
def _merge_sweep_config(defaults, sweep_config):
    """Flatten a W&B sweep config and merge it over the default config.

    Args:
        defaults: flat dict of default values (merged data/train/model sections).
        sweep_config: mapping of dotted sweep keys (e.g. ``'data.seq_length'``,
            ``'model.dropout'``) to the values sampled for this run.

    Returns:
        A new flat dict keyed by the last segment of each dotted sweep key.
        Sampled sweep values replace the corresponding defaults; all other
        defaults are kept. If two sweep keys share the same last segment, the
        first one encountered wins.
    """
    # Compute the overridden leaf names once instead of once per default key.
    swept_keys = {key.split('.')[-1] for key in sweep_config.keys()}

    # Defaults are copied from the local dict rather than pushed through
    # wandb.config, so non-primitive values (e.g. the device object seen in the
    # printed config) are not type-converted to str.
    config = {k: v for k, v in defaults.items() if k not in swept_keys}

    for key, value in sweep_config.items():
        config.setdefault(key.split('.')[-1], value)
    return config


def train_sweep():
    """Run one W&B sweep trial: build its configuration and train a model.

    Intended as the ``function`` argument of ``wandb.agent``. Reads the
    notebook-level ``data_cfg_file`` and ``device`` settings, composes the Hydra
    ``default`` config with the model chosen by the sweep, overrides the defaults
    with the sweep-sampled hyperparameters, and launches training — through
    ``torch.multiprocessing.spawn`` when the config enables ``ddp``.
    """
    wandb_run = wandb.init()
    # Use the run id as the run's display name.
    wandb_run.name = wandb_run.id

    with wandb_run:
        # init hydra; the sweep decides which `model` config option is loaded
        with hydra.initialize(config_path='../config'):
            add_configs = [f"data={data_cfg_file}",
                           f"model={wandb.config.model}",
                           f"++train.device={device}"]
            # For a quick smoke test, also append:
            # "++train.total_samples=200", "++train.search_lr=false", "++train.num_history=1"

            cfg = hydra.compose(config_name='default', overrides=add_configs)

            cfg_default = {**OmegaConf.to_container(cfg.data),
                           **OmegaConf.to_container(cfg.train),
                           **OmegaConf.to_container(cfg.model)}

            check_device_env(cfg_default)
            pprint.pprint(cfg_default)

        # defaults not selected by the sweep + the values the sweep sampled
        config = _merge_sweep_config(cfg_default, wandb.config)

        # build the dataset and train the model
        if config.get('ddp', False):
            # `mp` is not imported anywhere in this notebook, so this branch
            # previously raised NameError; import it where it is needed.
            import torch.multiprocessing as mp
            mp.spawn(prepare_and_run_train,
                     args=(config['ddp_size'], config),
                     nprocs=config['ddp_size'],
                     join=True)
        else:
            prepare_and_run_train(rank=None, world_size=None, config=config)

In [None]:
# Execute `count` trials of the sweep in this kernel; each trial calls `train_sweep`.
wandb.agent(sweep_id=sweep_id, function=train_sweep, count=count)

[34m[1mwandb[0m: Agent Starting Run: b8vfjkxv with config:
[34m[1mwandb[0m: 	data.EKG: X
[34m[1mwandb[0m: 	data.awgn: 0.008231824375577354
[34m[1mwandb[0m: 	data.awgn_age: 0.2840872646906441
[34m[1mwandb[0m: 	data.mgn: 5.618596001919496e-05
[34m[1mwandb[0m: 	data.photic: O
[34m[1mwandb[0m: 	data.seq_length: 3000
[34m[1mwandb[0m: 	model: 1D-VGG-19
[34m[1mwandb[0m: 	model.activation: gelu
[34m[1mwandb[0m: 	model.dropout: 0.1405113964207221
[34m[1mwandb[0m: 	model.fc_stages: 3
[34m[1mwandb[0m: 	model.use_age: fc
[34m[1mwandb[0m: 	train.criterion: cross-entropy
[34m[1mwandb[0m: 	train.lr_scheduler_type: linear_decay_with_warmup
[34m[1mwandb[0m: 	train.mixup: 0.2
[34m[1mwandb[0m: 	train.num_history: 5
[34m[1mwandb[0m: 	train.search_lr: False
[34m[1mwandb[0m: 	train.total_samples: 1000
[34m[1mwandb[0m: 	train.weight_decay: 0.016962071926639524
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_N

{'EKG': 'O',
 '_target_': 'models.vgg_1d.VGG1D',
 'activation': 'mish',
 'awgn': 0.05,
 'awgn_age': 0.05,
 'base_lr': 0.0001,
 'batch_norm': True,
 'criterion': 'cross-entropy',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.3,
 'fc_stages': 3,
 'file_format': 'memmap',
 'in_channels': '???',
 'input_norm': 'dataset',
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'constant_with_decay',
 'mgn': 0.0001,
 'minibatch': 160,
 'mixup': 0.0,
 'model': '1D-VGG-19',
 'num_history': 500,
 'out_dims': '???',
 'photic': 'X',
 'run_mode': 'train',
 'save_model': True,
 'search_lr': True,
 'search_multiplier': 1.0,
 'seq_length': 2000,
 'signal_length_limit': 10000000,
 'task': 'task2',
 'test_crop_multiple': 8,
 'total_samples': 35000000.0,
 'use_age': 'fc',
 'warmup_min': 3000,
 'warmup_ratio': 0.05,
 'watch_model': True,
 'weig

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


VBox(children=(Label(value='0.813 MB of 0.813 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Learning Rate,▁▂▄▅▇█
Loss,▅█▇▅▃▁
Multi-Crop Test Accuracy,▁
Test Accuracy,▁
Train Accuracy,▆▇▂▂█▁
Validation Accuracy,▇█▂▂▁▁

0,1
Learning Rate,0.0
Loss,1.09118
Multi-Crop Test Accuracy,33.33333
Test Accuracy,47.01271
Train Accuracy,33.12499
Validation Accuracy,46.47059


[34m[1mwandb[0m: Agent Starting Run: qurpdqpw with config:
[34m[1mwandb[0m: 	data.EKG: X
[34m[1mwandb[0m: 	data.awgn: 0.11399776695701888
[34m[1mwandb[0m: 	data.awgn_age: 0.02551916907612511
[34m[1mwandb[0m: 	data.mgn: 0.0941359708828874
[34m[1mwandb[0m: 	data.photic: O
[34m[1mwandb[0m: 	data.seq_length: 4000
[34m[1mwandb[0m: 	model: 1D-ResNet-18
[34m[1mwandb[0m: 	model.activation: relu
[34m[1mwandb[0m: 	model.dropout: 0.21478873596450715
[34m[1mwandb[0m: 	model.fc_stages: 3
[34m[1mwandb[0m: 	model.use_age: fc
[34m[1mwandb[0m: 	train.criterion: cross-entropy
[34m[1mwandb[0m: 	train.lr_scheduler_type: constant_with_decay
[34m[1mwandb[0m: 	train.mixup: 0.3
[34m[1mwandb[0m: 	train.num_history: 5
[34m[1mwandb[0m: 	train.search_lr: False
[34m[1mwandb[0m: 	train.total_samples: 1000
[34m[1mwandb[0m: 	train.weight_decay: 0.0006325576279088609
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME

{'EKG': 'O',
 '_target_': 'models.resnet_1d.ResNet1D',
 'activation': 'mish',
 'awgn': 0.05,
 'awgn_age': 0.05,
 'base_channels': 64,
 'base_lr': 0.0001,
 'block': 'basic',
 'conv_layers': [2, 2, 2, 2],
 'criterion': 'cross-entropy',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.1,
 'fc_stages': 3,
 'file_format': 'memmap',
 'in_channels': '???',
 'input_norm': 'dataset',
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'constant_with_decay',
 'mgn': 0.0001,
 'minibatch': 160,
 'mixup': 0.0,
 'model': '1D-ResNet-18',
 'num_history': 500,
 'out_dims': '???',
 'photic': 'X',
 'run_mode': 'train',
 'save_model': True,
 'search_lr': True,
 'search_multiplier': 1.0,
 'seq_length': 2000,
 'signal_length_limit': 10000000,
 'task': 'task2',
 'test_crop_multiple': 8,
 'total_samples': 35000000.0,
 'use_age': 'fc',
 'warmup_min'

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
