# Train Networks

- Train a SoftMax or Multi-BCE classifier for EEG-based diagnosis classification
    - CAUEEG-task1 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal** and **Abnormal** symptoms
- A `Weights & Biases` (W&B) sweep is used for the hyperparameter search

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import json
from copy import deepcopy

import hydra
from omegaconf import OmegaConf
import yaml
import wandb
import pprint

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

## Environment

In [3]:
# Name of the Hydra `data` config to use (passed as the override `data=<name>`).
data_cfg_file = 'task1'

# Torch device string, forwarded to Hydra as `++train.device=<device>`.
device = 'cuda:0'

# W&B sweep definition for the selected data config.
# NOTE(review): assumes sweep files are named `sweep_<data config>.yaml` -- adjust if not.
sweep_yaml_path = f'config/sweep/sweep_{data_cfg_file}.yaml'

## Sweep configurations

In [4]:
# Read the W&B sweep definition. The encoding is explicit because the platform default
# (e.g. a legacy code page on Windows) may not be UTF-8.
with open(sweep_yaml_path, encoding='utf-8') as f:
    sweep_yaml = yaml.load(f, Loader=yaml.FullLoader)

# Drop the `command` entry before registering the sweep -- presumably because trials are
# run in-process by this notebook rather than by a command line. A default is given so
# sweep files without a `command` entry do not raise KeyError.
sweep_yaml.pop('command', None)
pprint.pprint(sweep_yaml)

{'entity': 'ipis-mjkim',
 'method': 'random',
 'name': 'sweep-task1',
 'parameters': {'data.EKG': {'values': ['O', 'X']},
                'data.awgn': {'distribution': 'uniform', 'max': 0.12, 'min': 0},
                'data.awgn_age': {'distribution': 'uniform',
                                  'max': 0.3,
                                  'min': 0},
                'data.mgn': {'distribution': 'uniform', 'max': 0.1, 'min': 0},
                'data.photic': {'values': ['O', 'X']},
                'data.seq_length': {'values': [1000, 2000, 3000, 4000]},
                'model': {'values': ['1D-VGG-16',
                                     '1D-VGG-19',
                                     '1D-ResNet-18',
                                     '1D-ResNet-50',
                                     '1D-ResNeXt-50',
                                     '1D-ResNeXt-101',
                                     '1D-Wide-ResNet-50',
                                     '1D-CNN-Transformer',
      

In [5]:
# Register the sweep with W&B. The project name is taken from the sweep YAML; if the file
# has no `project` entry, None is passed and W&B falls back to its own project resolution.
sweep_id = wandb.sweep(sweep_yaml, project=sweep_yaml.get('project'))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: 6h6cr48q
Sweep URL: https://wandb.ai/ipis-mjkim/sweep-test/sweeps/6h6cr48q


## Train

In [None]:
def _merge_sweep_config(default_config, sweep_config):
    """Merge flattened default settings with the values selected by a W&B sweep.

    Sweep parameter names may be dotted (e.g. ``data.awgn``); only their last
    component (``awgn``) is used as the key in the result.

    Args:
        default_config: flat mapping of default settings (composed by Hydra).
        sweep_config: mapping of sweep-selected values (e.g. ``wandb.config``).

    Returns:
        A new dict containing every default whose key is not overridden by the
        sweep, followed by the sweep values. If several sweep keys share the same
        last component, the first one encountered wins.
    """
    # Computed once instead of rebuilding the list for every default key.
    sweep_keys = {key.split('.')[-1] for key in sweep_config.keys()}

    # Defaults are copied straight from `default_config`, never round-tripped through
    # `wandb.config`, so values such as callables are not type-converted to str.
    merged = {k: v for k, v in default_config.items() if k not in sweep_keys}

    for key, value in sweep_config.items():
        short_key = key.split('.')[-1]
        if short_key not in merged:
            merged[short_key] = value
    return merged


def train_sweep():
    """Run a single W&B sweep trial: compose the config and train one model.

    Takes no arguments. Hyperparameters are read from ``wandb.config`` of the run
    started here; presumably this is the ``function`` handed to ``wandb.agent``
    (the agent call is not part of this section). Also reads the notebook-level
    ``data_cfg_file`` and ``device`` variables.
    """
    # wandb_run = wandb.init(project=f"caueeg-{cfg_default['task']}")
    wandb_run = wandb.init()
    # Display the run under its run id.
    wandb.run.name = wandb.run.id

    with wandb_run:
        # Compose the Hydra config from the notebook's data config and the sweep's model choice.
        with hydra.initialize(config_path='../config'):
            add_configs = [f"data={data_cfg_file}", 
                           f"model={wandb.config.model}",
                           f"++train.device={device}",]
                           # f"++train.total_samples=200", f"++train.search_lr=false", f"++train.num_history=1"]  # for test

            cfg = hydra.compose(config_name='default', overrides=add_configs)

            # Flatten the three groups into one dict; on duplicate keys later groups win
            # (train over data, model over train).
            cfg_default = {**OmegaConf.to_container(cfg.data), 
                           **OmegaConf.to_container(cfg.train),
                           **OmegaConf.to_container(cfg.model)}

            # Return value is unused -- presumably checks/updates cfg_default in place; verify in run_train.
            check_device_env(cfg_default)
            pprint.pprint(cfg_default)    

        # Sweep-selected values override the Hydra defaults.
        config = _merge_sweep_config(cfg_default, wandb.config)

        # build the dataset and train the model
        if config.get('ddp', False):
            # torch.multiprocessing is only needed for DDP runs and is not imported in the
            # notebook's import cell, so import it here (otherwise `mp` is a NameError).
            # spawn() calls fn(process_index, *args), i.e. each process runs
            # prepare_and_run_train(rank, world_size, config).
            import torch.multiprocessing as mp

            mp.spawn(prepare_and_run_train,
                     args=(config['ddp_size'], config,),
                     nprocs=config['ddp_size'],
                     join=True)
        else:
            prepare_and_run_train(rank=None, world_size=None, config=config)