# Train Networks

- Train a softmax or multi-BCE classifier for EEG diagnosis classification
    - CAUEEG-task1 benchmark: Classification of **Normal** and **Abnormal** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import json
import pprint
from copy import deepcopy

import hydra
from omegaconf import OmegaConf
import torch.multiprocessing as mp
import wandb

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

---

## Specify the dataset, model, and training settings

In [3]:
# Weights & Biases project that receives this run.
project = 'sweep-test'

# Hydra config-group options for the data, train, and model groups.
data_cfg_file, train_cfg_file, model_cfg_file = 'task2', 'base_train', '1D-ResNet-101'

# Device string passed to Hydra as the `+train.device` override.
device = 'cuda:0'

---

## Initialize the configuration with Hydra

In [4]:
# Set to True for a quick smoke test (total_samples=200, LR search disabled,
# num_history=1); leave False for a full run.
quick_test = False

add_configs = [f"data={data_cfg_file}",
               f"train={train_cfg_file}",
               f"+train.device={device}",
               f"model={model_cfg_file}"]
if quick_test:
    add_configs += ["++train.total_samples=200",
                    "++train.search_lr=false",
                    "++train.num_history=1"]

with hydra.initialize(config_path='../config'):
    cfg = hydra.compose(config_name='default', overrides=add_configs)

# Flatten the data, train, and model groups into a single plain dict.
# NOTE: on duplicate keys the later group silently wins (model > train > data).
cfg_default = {**OmegaConf.to_container(cfg.data),
               **OmegaConf.to_container(cfg.train),
               **OmegaConf.to_container(cfg.model)}

# check_device_env is given the dict and presumably resolves the device settings
# in place (the printed output shows 'device' as a torch device object).
check_device_env(cfg_default)
pprint.pprint(cfg_default)

{'EKG': 'O',
 '_target_': 'models.resnet_1d.ResNet1D',
 'activation': 'mish',
 'awgn': 0.05,
 'awgn_age': 0.05,
 'base_channels': 64,
 'base_lr': 0.0001,
 'block': 'bottleneck',
 'conv_layers': [3, 4, 23, 3],
 'criterion': 'cross-entropy',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.1,
 'fc_stages': 3,
 'file_format': 'memmap',
 'in_channels': '???',
 'input_norm': 'dataset',
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'constant_with_decay',
 'mgn': 0.0001,
 'minibatch': 160,
 'mixup': 0.0,
 'model': '1D-ResNet-101',
 'num_history': 1,
 'out_dims': '???',
 'photic': 'X',
 'run_mode': 'train',
 'save_model': True,
 'search_lr': False,
 'search_multiplier': 1.0,
 'seq_length': 2000,
 'signal_length_limit': 10000000,
 'task': 'task2',
 'test_crop_multiple': 8,
 'total_samples': 200,
 'use_age': 'fc',
 'warmup_min':

## Train

In [5]:
# Start a W&B run (when launched by a sweep agent, wandb.config holds the values
# the sweep selected) and name the run after its id.
wandb_run = wandb.init(project=f"{project}")
wandb.run.name = wandb.run.id

with wandb_run:
    # Keys chosen by the wandb sweep, reduced to their last dotted component
    # (e.g. 'train.base_lr' -> 'base_lr'). Built once as a set instead of being
    # rebuilt as a list for every default key.
    sweep_keys = {wandb_key.split('.')[-1] for wandb_key in wandb.config.keys()}

    # Load default configurations not selected by wandb.sweep. They are copied
    # straight from cfg_default (not through wandb.config) to prevent callables
    # and other objects from being type-converted to str.
    config = {k: v for k, v in cfg_default.items() if k not in sweep_keys}

    # Add the values selected by the sweep; if two sweep keys share the same last
    # component, the first one encountered wins.
    for wandb_key, v in wandb.config.items():
        k = wandb_key.split('.')[-1]
        if k not in config:
            config[k] = v

    # build the dataset and train the model
    if config.get('ddp', False):
        # mp (torch.multiprocessing) spawn calls fn(process_index, *args), so each
        # worker receives its rank as the first positional argument.
        mp.spawn(prepare_and_run_train,
                 args=(config['ddp_size'], config,),
                 nprocs=config['ddp_size'],
                 join=True)
    else:
        prepare_and_run_train(rank=None, world_size=None, config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mipis-mjkim[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


VBox(children=(Label(value='0.825 MB of 0.825 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Learning Rate,▁
Loss,▁
Multi-Crop Test Accuracy,▁
Test Accuracy,▁
Train Accuracy,▁
Validation Accuracy,▁

0,1
Learning Rate,1e-05
Loss,1.21899
Multi-Crop Test Accuracy,33.33333
Test Accuracy,32.26695
Train Accuracy,36.25
Validation Accuracy,35.71429
