# Train Networks

- Train a SoftMax or Multi-BCE classifier for EEG-based diagnosis classification
    - CAUEEG-task1 benchmark: Classification of **Normal** and **Abnormal** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
- A `Weights and Biases` (W&B) sweep is used for hyperparameter search

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import json
import pprint
from copy import deepcopy

import hydra
import torch.multiprocessing as mp
import wandb
import yaml
from omegaconf import OmegaConf

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train
from train.utils import load_sweep_config

## Environment

In [3]:
# Hydra config-group choices, passed as the `data=` and `train=` overrides in `train_sweep`
data_cfg_file = "task2"
train_cfg_file = "tiny_train"

# Device string forwarded to Hydra as `+train.device`
device = "cuda:0"

# Sweep definition file (relative to the project root) and number of trials to run here
sweep_yaml_path = "config/sweep/sweep_task2.yaml"
count = 1

## Sweep configurations

In [4]:
# Load the sweep definition. The encoding is explicit because the platform default
# (e.g. a non-UTF-8 locale on Windows) may otherwise mis-decode the file.
with open(sweep_yaml_path, encoding='utf-8') as f:
    sweep_yaml = yaml.safe_load(f)

# The `command` entry describes how a CLI agent launches a script; here the agent runs
# `train_sweep` in-process, so the entry is dropped. Tolerate sweep files that lack it.
sweep_yaml.pop('command', None)
pprint.pprint(sweep_yaml)

{'entity': 'ipis-mjkim',
 'method': 'random',
 'parameters': {'data.EKG': {'values': ['O', 'X']},
                'data.awgn': {'distribution': 'uniform', 'max': 0.12, 'min': 0},
                'data.awgn_age': {'distribution': 'uniform',
                                  'max': 0.3,
                                  'min': 0},
                'data.mgn': {'distribution': 'uniform', 'max': 0.1, 'min': 0},
                'data.photic': {'values': ['O', 'X']},
                'data.seq_length': {'values': [1000, 2000, 3000, 4000]},
                'model': {'values': ['1D-VGG-16',
                                     '1D-VGG-19',
                                     '1D-ResNet-18',
                                     '1D-ResNet-50',
                                     '1D-ResNeXt-50',
                                     '1D-ResNeXt-101',
                                     '1D-Wide-ResNet-50',
                                     '1D-CNN-Transformer',
                              

In [5]:
# Register the sweep with W&B under the project named in the sweep YAML.
sweep_project = sweep_yaml['project']
sweep_id = wandb.sweep(sweep_yaml, project=sweep_project)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: i57vsv26
Sweep URL: https://wandb.ai/ipis-mjkim/sweep-test/sweeps/i57vsv26


---

## Train

In [6]:
def train_sweep():
    """Run one W&B sweep trial: build the training config and train the model.

    Intended as the ``function`` passed to ``wandb.agent``, which calls it once per
    trial. Reads the notebook-level settings ``data_cfg_file``, ``train_cfg_file``
    and ``device``. The Hydra ``model`` group is chosen by the sweep parameter
    ``wandb.config.model``.
    """
    # Connect to W&B; the sweep parameters for this trial are exposed via wandb.config.
    wandb_run = wandb.init(reinit=True)
    # Name the run after its id so it is easy to find in the W&B UI.
    wandb_run.name = wandb_run.id

    with hydra.initialize(config_path='../config'):
        add_configs = [f"data={data_cfg_file}",
                       f"train={train_cfg_file}",
                       f"model={wandb.config.model}",
                       f"+train.device={device}",]
        cfg = hydra.compose(config_name='default', overrides=add_configs)
        # Flatten the data/train/model groups into one plain dict; on duplicate keys
        # the later group wins (model over train over data).
        config = {**OmegaConf.to_container(cfg.data),
                  **OmegaConf.to_container(cfg.train),
                  **OmegaConf.to_container(cfg.model)}

    # NOTE(review): presumably merges the sweep parameters (wandb.config) into the
    # config -- confirm against train.utils.load_sweep_config.
    config = load_sweep_config(config)
    check_device_env(config)
    pprint.pprint(config)

    # Build the dataset and train the model.
    if config.get('ddp', False):
        # One process per DDP rank; mp.spawn passes the rank as the first argument,
        # matching prepare_and_run_train(rank, world_size, config).
        # `mp` is torch.multiprocessing, imported in the package cell.
        mp.spawn(prepare_and_run_train,
                 args=(config['ddp_size'], config,),
                 nprocs=config['ddp_size'],
                 join=True)
    else:
        prepare_and_run_train(rank=None, world_size=None, config=config)

In [7]:
# Run `count` trials of the sweep in this kernel; each trial calls `train_sweep`.
wandb.agent(sweep_id=sweep_id, count=count, function=train_sweep)

[34m[1mwandb[0m: Agent Starting Run: 355xewhy with config:
[34m[1mwandb[0m: 	data.EKG: X
[34m[1mwandb[0m: 	data.awgn: 0.05593757017042083
[34m[1mwandb[0m: 	data.awgn_age: 0.26182409123147676
[34m[1mwandb[0m: 	data.mgn: 0.004242764139073519
[34m[1mwandb[0m: 	data.photic: X
[34m[1mwandb[0m: 	data.seq_length: 3000
[34m[1mwandb[0m: 	model: 1D-ResNeXt-50
[34m[1mwandb[0m: 	model.activation: relu
[34m[1mwandb[0m: 	model.dropout: 0.4845203210781028
[34m[1mwandb[0m: 	model.fc_stages: 4
[34m[1mwandb[0m: 	model.use_age: fc
[34m[1mwandb[0m: 	train.criterion: cross-entropy
[34m[1mwandb[0m: 	train.lr_scheduler_type: cosine_decay_with_warmup_half
[34m[1mwandb[0m: 	train.mixup: 0
[34m[1mwandb[0m: 	train.weight_decay: 0.014934618498429425
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mipis-mjkim[0m. Use [1m`wandb

{'EKG': 'X',
 '_target_': 'models.resnet_1d.ResNet1D',
 'activation': 'relu',
 'awgn': 0.05593757017042083,
 'awgn_age': 0.26182409123147676,
 'base_channels': 64,
 'base_lr': 0.0001,
 'block': 'bottleneck',
 'conv_layers': [3, 4, 6, 3],
 'criterion': 'cross-entropy',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.4845203210781028,
 'fc_stages': 4,
 'file_format': 'memmap',
 'groups': 32,
 'in_channels': '???',
 'input_norm': 'dataset',
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'cosine_decay_with_warmup_half',
 'mgn': 0.004242764139073519,
 'minibatch': 160,
 'mixup': 0,
 'model': '1D-ResNeXt-50',
 'num_history': 2,
 'out_dims': '???',
 'photic': 'X',
 'run_mode': 'train',
 'save_model': False,
 'search_lr': False,
 'search_multiplier': 1.0,
 'seed': 0,
 'seq_length': 3000,
 'signal_length_limit': 10000000,
 'tas



VBox(children=(Label(value='0.034 MB of 0.034 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…



VBox(children=(Label(value='0.117 MB of 0.156 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.750024…

0,1
Learning Rate,▁▄█
Loss,██▁
Multi-Crop Test Accuracy,▁
Test Accuracy,▁
Train Accuracy,▁▂█
Validation Accuracy,█▁▁

0,1
Learning Rate,0.0
Loss,1.20239
Multi-Crop Test Accuracy,50.0
Test Accuracy,33.9548
Train Accuracy,35.625
Validation Accuracy,26.13445
