# Train Networks

- Train a softmax or multi-BCE classifier for EEG-based diagnosis classification
    - CAUEEG-task1 benchmark: classification of **Normal** vs. **Abnormal** symptoms
    - CAUEEG-task2 benchmark: classification of **Normal**, **MCI**, and **Dementia** symptoms
- A `Weights & Biases` (wandb) sweep is used for the hyperparameter search

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
# mode 2: reload all (non-excluded) modules before each cell is executed, so
# edits to project code are picked up without restarting the kernel
%autoreload 2
# Move up one directory so the project-local modules imported below
# (run_train, train.utils) and relative paths such as 'config/sweep/...'
# resolve -- presumably the notebook lives one level below the repository
# root; confirm.
%cd ..

In [None]:
# Load some packages
import os
import json
from copy import deepcopy

import hydra
from omegaconf import OmegaConf
import yaml
import wandb
import pprint

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train
from train.utils import load_sweep_config

## Environment

In [None]:
# Hydra config-group selections, applied as overrides inside `train_sweep`.
data_cfg_file = 'task2'         # used as the `data=<...>` override
train_cfg_file = 'tiny_train'   # used as the `train=<...>` override

# Device string, injected as the `+train.device=<...>` override.
device = 'cuda:0'

# Sweep definition file, and the `count` limit handed to `wandb.agent`.
sweep_yaml_path = 'config/sweep/sweep_task2.yaml'
count = 1

## Sweep configuration

In [None]:
# Load the sweep definition. A plain config file needs no Python-object tags,
# so `safe_load` is sufficient and refuses anything beyond standard YAML.
with open(sweep_yaml_path, encoding='utf-8') as f:
    sweep_yaml = yaml.safe_load(f)

# Drop the `command` entry -- presumably because the agent below runs
# `train_sweep` in-process rather than launching a command. Tolerate sweep
# files that do not define it instead of raising KeyError.
sweep_yaml.pop('command', None)
pprint.pprint(sweep_yaml)

In [None]:
# Register the sweep with W&B under the project named in the YAML file; the
# returned id is what `wandb.agent` uses to pull parameter sets.
project_name = sweep_yaml['project']
sweep_id = wandb.sweep(sweep_yaml, project=project_name)

---

## Train

In [None]:
def train_sweep():
    """Run a single W&B sweep trial: build the training config, then train.

    Meant to be passed as ``function`` to ``wandb.agent``. It composes the
    Hydra ``data``/``train``/``model`` groups selected by the module-level
    ``data_cfg_file``, ``train_cfg_file`` and ``device`` settings. The model
    group comes from the sweep-chosen ``wandb.config.model``. The merged dict
    is passed through ``load_sweep_config``, which presumably overlays the
    remaining sweep parameters; confirm in ``train.utils``. Training then
    runs either single-process or, when ``config['ddp']`` is truthy, with one
    spawned process per rank.
    """
    # Connect to W&B; inside an agent this run carries the sweep parameters.
    wandb_run = wandb.init(reinit=True)
    # Name the run after its id so it is easy to match up in the W&B UI.
    wandb_run.name = wandb_run.id

    with hydra.initialize(config_path='../config'):
        # The `+` prefix adds `device` as a new key under the train group.
        add_configs = [f"data={data_cfg_file}",
                       f"train={train_cfg_file}",
                       f"model={wandb.config.model}",
                       f"+train.device={device}"]
        cfg = hydra.compose(config_name='default', overrides=add_configs)
        # Flatten the three groups into one dict; on duplicate keys the later
        # group (train over data, model over both) wins.
        config = {**OmegaConf.to_container(cfg.data),
                  **OmegaConf.to_container(cfg.train),
                  **OmegaConf.to_container(cfg.model)}

    config = load_sweep_config(config)
    check_device_env(config)
    pprint.pprint(config)

    # Build the dataset and train the model.
    if config.get('ddp', False):
        # Imported here: only the DDP path needs it, and the notebook's import
        # cell does not bring `mp` into scope.
        import torch.multiprocessing as mp

        # mp.spawn calls fn(rank, *args), so each worker receives
        # prepare_and_run_train(rank, config['ddp_size'], config).
        mp.spawn(prepare_and_run_train,
                 args=(config['ddp_size'], config),
                 nprocs=config['ddp_size'],
                 join=True)
    else:
        prepare_and_run_train(rank=None, world_size=None, config=config)

In [None]:
# Start the agent in this process: it fetches parameter sets for `sweep_id`
# and invokes `train_sweep` for at most `count` runs.
wandb.agent(sweep_id=sweep_id, function=train_sweep, count=count)