# Train Networks

- Train a SoftMax or Multi-BCE classifier for EEG-based diagnosis classification
    - CAUEEG-task1 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal** and **Abnormal** symptoms

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# move one directory up (presumably the repository root) so the next cell can
# import project modules such as `run_train`
%cd ..

In [None]:
# Load some packages
import os
import json
from copy import deepcopy

import hydra
from omegaconf import OmegaConf
import wandb
import pprint

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

---

## Specify the dataset, model, and training settings

In [None]:
# Experiment settings: wandb project, Hydra config-group choices, and target device.
project = "sweep-test"  # wandb project the run is logged under

# Hydra overrides used in the next section: data=..., train=..., model=...
data_cfg_file, train_cfg_file, model_cfg_file = "task1", "base_train", "1D-ResNet-101"

device = "cuda:0"  # injected as +train.device

---

## Initializing configurations using Hydra

In [None]:
# Compose the full configuration from the Hydra config groups selected above.
with hydra.initialize(config_path='../config'):
    add_configs = [
        f"data={data_cfg_file}",
        f"train={train_cfg_file}",
        f"+train.device={device}",  # '+' adds a new key (Hydra override syntax)
        f"model={model_cfg_file}",
    ]
    # For a quick smoke test, also append these overrides:
    #   "++train.total_samples=200", "++train.search_lr=false", "++train.num_history=1"
    cfg = hydra.compose(config_name='default', overrides=add_configs)

# Flatten the data / train / model sections into one plain dict; on a duplicate key
# the later section (train over data, model over both) takes precedence.
cfg_default = {}
for cfg_section in (cfg.data, cfg.train, cfg.model):
    cfg_default.update(OmegaConf.to_container(cfg_section))

# check the device environment for this configuration (see run_train.check_device_env),
# then print the merged configuration
check_device_env(cfg_default)
pprint.pprint(cfg_default)

## Train

In [None]:
# Start a wandb run; under a wandb sweep agent, wandb.config carries the sweep's choices.
wandb_run = wandb.init(project=f"{project}")
# use the run id as the run's display name
wandb.run.name = wandb.run.id

with wandb_run:
    # Leaf names of the keys set by wandb.sweep (a dotted key such as 'train.lr' maps to 'lr').
    # Built once up front instead of being recomputed for every default entry.
    sweep_keys = {wandb_key.split('.')[-1] for wandb_key in wandb.config.keys()}

    # Load the default configurations not selected by wandb.sweep. They are copied directly
    # from cfg_default rather than read back from wandb.config, which keeps values such as
    # callables from being type-converted to str.
    config = {k: v for k, v in cfg_default.items() if k not in sweep_keys}

    # Add the values selected by the wandb sweep under their leaf names.
    # If two sweep keys share a leaf name, the first one encountered wins.
    for k, v in wandb.config.items():
        k = k.split('.')[-1]
        if k not in config:
            config[k] = v

    # build the dataset and train the model
    if config.get('ddp', False):
        # `mp` was previously used without ever being imported, so this branch raised
        # NameError. It is imported here because only the DDP path needs it.
        import torch.multiprocessing as mp

        # one process per rank; spawn calls prepare_and_run_train(rank, world_size, config)
        mp.spawn(prepare_and_run_train,
                 args=(config['ddp_size'], config,),
                 nprocs=config['ddp_size'],
                 join=True)
    else:
        # single-process training
        prepare_and_run_train(rank=None, world_size=None, config=config)