# Train Networks

- Train a deep classifier for EEG-based diagnostic classification
    - CAUEEG-Dementia benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
    - CAUEEG-Abnormal benchmark: Classification of **Normal** and **Abnormal** symptoms    

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# move the working directory up one level; presumably this is the repository
# root, so that the project-local imports in the next cell resolve -- confirm
# against the repository layout
%cd ..

In [None]:
# Load some packages
import hydra
from omegaconf import OmegaConf
import wandb
import numpy as np
import pprint
import gc
from copy import deepcopy

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp

from train.train_script import train_script
from datasets.caueeg_script import build_dataset_for_train
from datasets.temple_eeg_script import build_dataset_for_tuab_train
from models.utils import count_parameters

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

---

## Specify the dataset, model, and training settings

In [None]:
# Experiment-level settings; each one is injected into the Hydra overrides
# of the configuration cell.
project = "caueeg-dementia-train-size"  # forwarded as `train.project`
data_cfg_file = "caueeg-dementia"       # selects the `data=` config group entry
device = "cuda:0"                       # forwarded as `train.device`
seed = 0                                # forwarded as `train.seed`

---

## Initializing configurations using Hydra

In [None]:
# Overrides layered on top of the 'default' Hydra config. In Hydra's override
# grammar a leading '+' adds a new key and '++' adds the key or overrides it
# when it already exists.
add_configs = [
    f"data={data_cfg_file}",
    f"+train.device={device}",
    f"+train.project={project}",
    f"++train.seed={seed}",
    "++data.EKG=O",
    "++data.awgn=0.004872735559634612",
    "++data.awgn_age=0.03583361229344302",
    "++data.mgn=0.09575622309480344",
    "++data.photic=O",
    "++data.seq_length=2000",
    "model=1D-ResNet-18",
    "++model.activation=gelu",
    "++model.dropout=0.3",
    "++model.fc_stages=3",
    "++model.use_age=conv",
    "++train.criterion=multi-bce",
    "++train.lr_scheduler_type=cosine_decay_with_warmup_half",
    "++train.mixup=0.2",
    "++train.weight_decay=0.04394746639552375",
]

# Compose the configuration inside a temporary Hydra context.
with hydra.initialize(config_path='../config'):
    cfg = hydra.compose(config_name='default', overrides=add_configs)

# Flatten the data / train / model sections into a single plain dict.
# Sections are merged in that order, so on a key collision the later
# section's value wins.
config_base = {}
for section in (cfg.data, cfg.train, cfg.model):
    config_base.update(OmegaConf.to_container(section))

# NOTE(review): check_device_env comes from run_train; presumably it validates
# or adjusts the device-related entries of the config -- confirm there.
check_device_env(config_base)
pprint.pprint(config_base)

In [None]:
# NOTE(review): presumably distributed-training placeholders; neither name is
# read by the cells visible in this notebook -- confirm before removing.
rank = world_size = None

## Train

In [None]:
repeat = 1  # number of runs per training-set ratio

# Sweep the fraction of the training set that is kept: 0.2, 0.4, 0.6, 0.8.
for ratio in np.linspace(0.2, 0.8, num=4):
    for r in range(repeat):
        # start every run from an untouched copy of the base configuration
        config = deepcopy(config_base)

        # collect some garbage left over from the previous run
        gc.collect()
        if torch.cuda.is_available():
            # torch.cuda.synchronize() raises on a machine without CUDA,
            # so only touch the CUDA runtime when it is actually present
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        # fix the seed for reproducibility (a negative seed value means not fixing)
        if config.get('seed', 0) >= 0:
            # offset the seed per repetition so repeated runs differ
            config['seed'] = config.get('seed', 0) + r * 10
            # use a loop-local name so the notebook-level `seed` from the
            # settings cell is not overwritten
            run_seed = config['seed']
            torch.manual_seed(run_seed)
            np.random.seed(run_seed)

        # compose dataset
        if config.get('dataset_name', None) == 'tuab':
            train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_tuab_train(config)
        else:
            train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_train(config)

        # reduce the training set size with a per-class (stratified) subsample:
        # first group the sample indices of the training set by class name
        # (assumes position i in the dataset matches data_list[i] -- TODO confirm)
        serial_dict_by_class = {}
        for i, data in enumerate(train_loader.dataset):
            serial_dict_by_class.setdefault(data['class_name'], []).append(i)

        # then keep `ratio` of each class. replace=False is essential here:
        # np.random.choice samples WITH replacement by default, which would
        # duplicate recordings and leave fewer unique samples than the
        # train_set_size reported below.
        kept_indices = []
        for class_indices in serial_dict_by_class.values():
            # int(...) guards against round() returning a NumPy float
            num_keep = int(round(len(class_indices) * ratio))
            kept_indices.extend(
                np.random.choice(np.array(class_indices), num_keep, replace=False).tolist()
            )
        keep_list = np.array(kept_indices, dtype=int)

        full_data_list = train_loader.dataset.data_list
        train_loader.dataset.data_list = [full_data_list[keep] for keep in keep_list]

        # record the reduced size alongside the other run settings
        config['train_set_size'] = len(train_loader.dataset.data_list)

        # generate the model and update some configurations
        model = hydra.utils.instantiate(config)
        model = model.to(config['device'])
        config['output_length'] = model.get_output_length()
        config['num_params'] = count_parameters(model)

        # train
        train_script(config, model, train_loader, val_loader, test_loader, multicrop_test_loader,
                     config['preprocess_train'], config['preprocess_test'])