# Train Networks

- Train a softmax or multi-label BCE (Multi-BCE) classifier for EEG diagnosis classification
    - CAUEEG-Dementia benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
    - CAUEEG-Abnormal benchmark: Classification of **Normal** and **Abnormal** symptoms

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import hydra
from omegaconf import OmegaConf
import wandb
import pprint

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

---

## Specify the dataset, model, and training settings

In [3]:
# Run-level settings, forwarded to Hydra as `+train.device` / `+train.project`.
device = 'cuda:0'           # torch device string used for training
project = 'sweep-test'      # project name (presumably the W&B project) — confirm

# Names of the config-group files selected for data / train / model.
data_cfg_file = 'caueeg-dementia'
train_cfg_file = 'base_train'
model_cfg_file = '2D-ResNeXt-50'

---

## Initialize configurations using Hydra

In [4]:
# Disabled variant: compose the configuration from the In [3] settings alone
# (data / train / model groups plus device and project), with no
# hyperparameter overrides. Uncomment to use it instead of the active cell.
# with hydra.initialize(config_path='../config'):
#     add_configs = [f"data={data_cfg_file}", 
#                    f"train={train_cfg_file}", 
#                    f"+train.device={device}", 
#                    f"+train.project={project}",
#                    f"model={model_cfg_file}",]
    
#     cfg = hydra.compose(config_name='default', overrides=add_configs)
    
# config = {**OmegaConf.to_container(cfg.data), 
#           **OmegaConf.to_container(cfg.train),
#           **OmegaConf.to_container(cfg.model)}

# check_device_env(config)
# pprint.pprint(config)

In [5]:
# Disabled variant: start a new run initialized from run 1sl7ipca
# (`train.init_from`) with the tuned hyperparameters, instead of resuming a run.
# NOTE(review): `++data.file_format=memmap` is listed twice below; one entry
# is redundant.
# with hydra.initialize(config_path='../config'):
#     add_configs = [f"data={data_cfg_file}", 
#                    f"train={train_cfg_file}", 
#                    f"+train.device={device}", 
#                    f"+train.project={project}",
#                    f"++train.init_from=1sl7ipca",
#                    f"model=2D-ResNeXt-50",
#                    f"++model.activation=mish",                   
#                    f"++model.criterion=cross-entropy",
#                    f"++model.fc_stages=5",
#                    f"++model.dropout=0.04197529259802718",
#                    f"++train.mixup=0",
#                    f"++train.seq_length=4000",
#                    f"++train.awgn=0.10394966750385833",
#                    f"++train.awgn_age=0.01804953928628786",
#                    f"++train.mgn=0.056254713649316834",
#                    f"++train.age_mean=71.35855",
#                    f"++train.age_std=9.637834",
#                    f"++train.lr_scheduler_type=constant_with_twice_decay",
#                    f"++data.file_format=memmap",
#                    f"++data.photic=O",
#                    f"++data.EKG=X",
#                    f"++data.file_format=memmap",
#                   ]
    
#     cfg = hydra.compose(config_name='default', overrides=add_configs)
    
# config = {**OmegaConf.to_container(cfg.data), 
#           **OmegaConf.to_container(cfg.train),
#           **OmegaConf.to_container(cfg.model)}

# check_device_env(config)
# pprint.pprint(config)

In [6]:
# Compose the active training configuration with Hydra: the config groups
# chosen in the settings cell plus `++` overrides that pin the tuned
# hyperparameters and resume the earlier run 35i3jb9v.
with hydra.initialize(config_path='../config'):
    add_configs = [f"data={data_cfg_file}",
                   f"train={train_cfg_file}",
                   f"+train.device={device}",
                   f"+train.project={project}",
                   # `train.init_from=1sl7ipca` is deliberately not set: the
                   # config printed after resuming already reports
                   # init_from='1sl7ipca'.
                   "++train.resume=35i3jb9v",
                   # Use the model group from the settings cell so editing
                   # `model_cfg_file` actually takes effect.
                   # NOTE(review): the overrides below were tuned for
                   # 2D-ResNeXt-50 — revisit them if the model changes.
                   f"model={model_cfg_file}",
                   "++model.activation=mish",
                   "++model.criterion=cross-entropy",
                   "++model.fc_stages=5",
                   "++model.dropout=0.04197529259802718",
                   "++train.mixup=0",
                   "++train.seq_length=4000",
                   "++train.awgn=0.10394966750385833",
                   "++train.awgn_age=0.01804953928628786",
                   "++train.mgn=0.056254713649316834",
                   "++train.age_mean=71.35855",
                   "++train.age_std=9.637834",
                   "++train.lr_scheduler_type=constant_with_twice_decay",
                   "++data.file_format=memmap",
                   "++data.photic=O",
                   "++data.EKG=X",
                  ]

    cfg = hydra.compose(config_name='default', overrides=add_configs)

# Flatten the data / train / model sections into one plain dict. On key
# collisions, later sections win (model > train > data). Keys still marked
# '???' (e.g. in_channels, out_dims) appear to be filled in later by the
# training code, judging from the printed output.
config = {**OmegaConf.to_container(cfg.data),
          **OmegaConf.to_container(cfg.train),
          **OmegaConf.to_container(cfg.model)}

# NOTE(review): judging from the printed output, check_device_env replaces the
# 'device' string with a torch.device object — confirm in run_train.py.
check_device_env(config)
pprint.pprint(config)

{'EKG': 'X',
 '_target_': 'models.resnet_2d.ResNet2D',
 'activation': 'mish',
 'age_mean': 71.35855,
 'age_std': 9.637834,
 'awgn': 0.10394966750385833,
 'awgn_age': 0.01804953928628786,
 'base_channels': 64,
 'base_lr': 0.0001,
 'block': 'bottleneck',
 'conv_layers': [3, 4, 6, 3],
 'criterion': 'cross-entropy',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.04197529259802718,
 'fc_stages': 5,
 'file_format': 'memmap',
 'groups': 32,
 'in_channels': '???',
 'input_norm': 'dataset',
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'constant_with_twice_decay',
 'mgn': 0.056254713649316834,
 'minibatch': 128,
 'mixup': 0,
 'model': '2D-ResNeXt-50',
 'num_history': 500,
 'out_dims': '???',
 'photic': 'O',
 'project': 'sweep-test',
 'resume': '35i3jb9v',
 'run_mode': 'train',
 'save_model': True,
 'search_lr': True,
 'search

## Train

In [7]:
# Launch training in a single process: the composed config has ddp=False, so
# no process rank or world size is passed.
prepare_and_run_train(config=config, rank=None, world_size=None)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mipis-mjkim[0m. Use [1m`wandb login --relogin`[0m to force relogin


'Training resumes from 35i3jb9v'
{'EKG': 'X',
 '_target_': 'models.resnet_2d.ResNet2D',
 'activation': 'mish',
 'age_mean': 71.35855,
 'age_std': 9.637834,
 'awgn': 0.10394966750385833,
 'awgn_age': 0.01804953928628786,
 'base_channels': 64,
 'base_lr': 4.30168830445025e-05,
 'block': 'bottleneck',
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'conv_layers': [3, 4, 6, 3],
 'criterion': 'cross-entropy',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.04197529259802718,
 'fc_stages': 5,
 'file_format': 'memmap',
 'groups': 32,
 'in_channels': 40,
 'init_from': '1sl7ipca',
 'input_norm': 'dataset',
 'iterations': 781250,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'constant_with_twice_decay',
 'mgn': 0.056254713

VBox(children=(Label(value='0.129 MB of 0.129 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Learning Rate,██████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss,▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁
Multi-Crop Test Accuracy,▁
Test Accuracy,▁
Train Accuracy,█▇█▇▇█▇▇██████▇█▅█████▁██████▇████▇▇████
Validation Accuracy,▆▅▄▇▆▄▅▄▇▅▆▄▅▅▄▄▃▇▇▅█▃▅▆▅▄▅▁▄▅▅▆█▂▁▄▁▂▆▃

0,1
Learning Rate,0.0
Loss,0.0
Multi-Crop Test Accuracy,67.65537
Test Accuracy,66.10169
Train Accuracy,100.0
Validation Accuracy,63.9916
