# Train Networks

- Train a SoftMax or Multi-BCE classifier for EEG-based diagnosis classification
    - CAUEEG-task1 benchmark: Classification of **Normal** and **Abnormal** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
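
To make the two classifier heads concrete, the sketch below compares the two losses on a single sample. It assumes that "Multi-BCE" means an independent binary cross-entropy per class against a one-hot target; that reading, the function names, and the example logits are illustrative assumptions and are not taken from `run_train`.

```python
import math


def softmax_cross_entropy(logits, target):
    """Cross-entropy of softmax(logits) against an integer class index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def multi_bce(logits, target):
    """Mean per-class binary cross-entropy against a one-hot target (assumed definition)."""
    total = 0.0
    for k, z in enumerate(logits):
        y = 1.0 if k == target else 0.0
        # softplus(z) - y * z is the numerically stable form of BCE on logits
        softplus = max(z, 0.0) + math.log1p(math.exp(-abs(z)))
        total += softplus - y * z
    return total / len(logits)


# hypothetical scores for the three CAUEEG-task2 classes
logits = [2.0, 0.5, -1.0]
print(softmax_cross_entropy(logits, target=0))
print(multi_bce(logits, target=0))
```

The SoftMax loss couples the classes through a shared normalizer, while the per-class BCE scores each class independently with a sigmoid.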

---

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

In [None]:
# Load some packages
import hydra
from omegaconf import OmegaConf
import wandb
import pprint

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

---

## Specify the dataset, model, and training settings

In [None]:
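# project and device are added to the train config as train.project and train.device
project = 'noname'
# option names for the data, train, and model config groups
data_cfg_file = 'caueeg-dementia-segmented-ieracitano'  # 'task2_ieracitano'
train_cfg_file = 'base_train'
model_cfg_file = 'Ieracitano-CNN'
device = 'cuda:0'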

---

## Initialize configurations using Hydra

In [None]:
with hydra.initialize(config_path='../config'):
    # select one option per config group; '+' appends new keys to the config
    add_configs = [
        f"data={data_cfg_file}",
        f"train={train_cfg_file}",
        f"+train.device={device}",
        f"+train.project={project}",
        f"model={model_cfg_file}",
    ]

    cfg = hydra.compose(config_name='default', overrides=add_configs)

# flatten the three groups into one dict; on duplicate keys, later groups win
config = {
    **OmegaConf.to_container(cfg.data),
    **OmegaConf.to_container(cfg.train),
    **OmegaConf.to_container(cfg.model),
}

check_device_env(config)
pprint.pprint(config)
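
Because the data, train, and model groups are flattened into one dictionary with `{**a, **b, **c}`, a key defined in more than one group silently keeps only the last value. The following sketch shows one way to detect such collisions before flattening. The helper `find_key_collisions` and the example keys are hypothetical and are not part of this repository; plain dicts stand in for the output of `OmegaConf.to_container`.

```python
def find_key_collisions(**groups):
    """Return {key: [group names]} for every top-level key defined in more than one group."""
    owners = {}
    for group_name, group in groups.items():
        for key in group:
            owners.setdefault(key, []).append(group_name)
    return {key: names for key, names in owners.items() if len(names) > 1}


# hypothetical stand-ins for the data, train, and model config groups
data_cfg = {"dataset_name": "example", "seq_length": 2000}
train_cfg = {"lr": 1e-3, "seq_length": 1000}
model_cfg = {"model": "example-cnn"}

# same flattening pattern as in the cell above
flat = {**data_cfg, **train_cfg, **model_cfg}
print(flat["seq_length"])  # the value from train_cfg, the later mapping

collisions = find_key_collisions(data=data_cfg, train=train_cfg, model=model_cfg)
print(collisions)
```

With the real configuration, the same check could be run on the three `OmegaConf.to_container(...)` results before they are merged into `config`.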

---

## Train

In [None]:
prepare_and_run_train(rank=None, world_size=None, config=config)