# Train Networks

- Train a SoftMax or Multi-BCE classifier for EEG diagnosis classification:
    - CAUEEG-task1 benchmark: Classification of **Normal** and **Abnormal** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import pprint
import warnings

import hydra
from omegaconf import OmegaConf
import wandb

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

---

## Specify the dataset, model, and training settings

In [3]:
# Hydra config-group options to compose (see the `data=`, `train=`, `model=`
# overrides in the next section).
data_cfg_file = 'task2_segmented_ieracitano'  # alternative: 'task2_ieracitano'
model_cfg_file = 'Ieracitano-CNN'
train_cfg_file = 'base_train'

# Run settings injected into the train group (`+train.device`, `+train.project`).
# NOTE(review): `project` presumably names the wandb project used for logging — confirm.
device = 'cuda:0'
project = 'noname'

---

## Initialize the configuration using Hydra

In [4]:
# Compose the full Hydra configuration from the config-group options chosen above.
with hydra.initialize(config_path='../config'):
    add_configs = [f"data={data_cfg_file}",
                   f"train={train_cfg_file}",
                   f"+train.device={device}",
                   f"+train.project={project}",
                   f"model={model_cfg_file}",]

    cfg = hydra.compose(config_name='default', overrides=add_configs)

# Flatten the data/train/model groups into one plain dict for the training code.
# The groups are merged in this order, so a key defined in several groups
# silently takes its value from the last one (model > train > data).
# Warn about such keys so that any override is a deliberate choice.
config_groups = {
    'data': OmegaConf.to_container(cfg.data),
    'train': OmegaConf.to_container(cfg.train),
    'model': OmegaConf.to_container(cfg.model),
}

key_owners = {}
for group_name, group in config_groups.items():
    for key in group:
        key_owners.setdefault(key, []).append(group_name)
duplicated_keys = {key: owners for key, owners in key_owners.items() if len(owners) > 1}
if duplicated_keys:
    warnings.warn(f"Config keys defined in more than one group (last group wins): {duplicated_keys}")

config = {}
for group in config_groups.values():
    config.update(group)

# NOTE(review): judging from the printed output, check_device_env turns the
# 'device' string into a torch.device object — confirm in run_train.
check_device_env(config)
pprint.pprint(config)

{'EKG': 'X',
 '_target_': 'models.simple_cnn_2d.IeracitanoCNN',
 'activation': 'relu',
 'awgn': 0,
 'awgn_age': 0,
 'base_lr': 0.0001,
 'criterion': 'cross-entropy',
 'crop_multiple': 1,
 'crop_timing_analysis': False,
 'dataset_path': 'local/dataset/02_Curated_Data_220715_seg_30s/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'fc_stages': 1,
 'file_format': 'memmap',
 'in_channels': '???',
 'input_norm': 'dataset',
 'latency': 0,
 'load_event': False,
 'lr_scheduler_type': 'constant_with_decay',
 'mgn': 0,
 'minibatch': 512,
 'mixup': 0.0,
 'model': 'Ieracitano-CNN',
 'num_history': 500,
 'out_dims': '???',
 'photic': 'X',
 'project': 'noname',
 'run_mode': 'train',
 'save_model': True,
 'search_lr': True,
 'search_multiplier': 1.0,
 'seed': 0,
 'segment_simulation': True,
 'seq_length': 1000,
 'signal_length_limit': 10000000,
 'task': 'task2',
 'test_crop_multiple': 1,
 'total_samples': 100000000.0,
 'use_age': 'no',
 'use_wandb': True,
 'warmup_min

## Train

In [None]:
# Launch training with the composed config. rank/world_size are left as None.
# NOTE(review): presumably these are only set for multi-process (DDP) runs;
# config['ddp'] is False above — confirm in run_train.prepare_and_run_train.
prepare_and_run_train(config=config, rank=None, world_size=None)