# Train Channel Reduction

- Train deep classifiers for EEG-based diagnostic classification with channel reduction.
    - CAUEEG-Dementia benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
    - CAUEEG-Abnormal benchmark: Classification of **Normal** and **Abnormal** symptoms    

-----

## Load Packages

In [None]:
# Enable auto-reloading of external modules before each cell execution
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Move the working directory one level up; presumably the notebook lives in a
# sub-folder and `run_train` / `local/...` are resolved from the parent -- TODO confirm
%cd ..

In [None]:
# Load some packages
import hydra
from omegaconf import OmegaConf
import wandb
import pprint
from copy import deepcopy
import itertools
from tqdm.auto import tqdm

# custom package
from run_train import check_device_env
from run_train import prepare_and_run_train

---

## Specify the dataset, model, and training settings

In [None]:
# Space-separated list of configuration overrides, one `key=value` token per
# line below. Adjacent string literals are concatenated by Python, so the
# resulting value is a single line with exactly one space between tokens.
script = (
    "train=distillation-logit "
    "train.distil_teacher_logit=local/logit-ensemble-by-timing-abnormal.pt "
    "data=caueeg-abnormal "
    "data.EKG=O "
    "data.awgn=0.09950720401001538 "
    "data.awgn_age=0.17083174076744487 "
    "data.mgn=0.01320748620440252 "
    "data.photic=O "
    "data.seq_length=3000 "
    "model=1D-ResNet-18 "
    "model.activation=relu "
    "model.dropout=0.45123289319795373 "
    "model.fc_stages=3 "
    "model.use_age=fc "
    "train.criterion=multi-bce "
    "train.lr_scheduler_type=transformer_style "
    "train.mixup=0.1 "
    "train.weight_decay=0.0001440435267533863 "
    "++train.base_lr=0.0025298221281347044 "
    "++train.search_lr=False"
)
print(script)

In [None]:
# Convert the space-separated `script` string into a list of Hydra overrides.
add_configs_base = []

# split() with no argument collapses repeated / leading / trailing whitespace,
# so an accidental double space can never produce an empty override token.
for seg in script.split():
    # The project name is set explicitly further down; drop any value that the
    # script string carries for it.
    if 'train.project' in seg:
        continue
    # Dotted keys are given the "++" prefix (Hydra's add-or-override form)
    # unless they already start with it. Tokens without a dot
    # (e.g. "train=...", "model=...") are passed through unchanged.
    if "." in seg and not seg.startswith("++"):
        seg = "++" + seg
    add_configs_base.append(seg)

# Settings specific to this channel-reduction experiment.
add_configs_base.append("++model.base_model=4439k9pg")
add_configs_base.append("++train.project=caueeg-abnormal-channel-reduction")
add_configs_base.append("++train.total_samples=1.0e+7")

pprint.pprint(add_configs_base)

---

## Initialize the configuration with Hydra and train

In [None]:
# Channel labels in recording order. The list index of each label is what the
# channel-pair sweep below refers to (e.g. index 7 -> "C4-AVG").
signal_header = [
    # indices 0-4
    "Fp1-AVG", "F3-AVG", "C3-AVG", "P3-AVG", "O1-AVG",
    # indices 5-9
    "Fp2-AVG", "F4-AVG", "C4-AVG", "P4-AVG", "O2-AVG",
    # indices 10-12
    "F7-AVG", "T3-AVG", "T5-AVG",
    # indices 13-15
    "F8-AVG", "T4-AVG", "T6-AVG",
    # indices 16-18
    "FZ-AVG", "CZ-AVG", "PZ-AVG",
    # indices 19-20: the two non-"-AVG" channels
    "EKG", "Photic",
]

In [None]:
# Launch one training run for every (channel pair, distillation alpha)
# combination. itertools.product iterates the alphas fastest, i.e. in the
# same order as two nested loops with the channel pairs on the outside.
channel_pairs = [[7, 17], [5, 17], [0, 1], [7, 11], [1, 3]]
distil_alphas = [0.1, 0.9, 0.98]

for channel_difference, distil_alpha in itertools.product(channel_pairs, distil_alphas):
    add_configs = deepcopy(add_configs_base)

    with hydra.initialize(config_path="../config"):
        cfg = hydra.compose(config_name="default", overrides=add_configs)

    # Flatten the data / train / model sections into a single dict; on a key
    # clash the later section wins (data < train < model).
    config = {}
    for section in (cfg.data, cfg.train, cfg.model):
        config.update(OmegaConf.to_container(section))

    # Human-readable name of the channel pair, e.g. "C4 - CZ".
    montage_name = ' - '.join(signal_header[i].split('-')[0] for i in channel_difference)

    # Per-run settings applied on top of the composed configuration.
    config.update({
        'seq_length': 2000,
        'crop_length': 2000,
        'channel_difference': list(channel_difference),
        'crop_timing_analysis': True,
        'criterion': 'cross-entropy',
        'mixup': 0.2,
        'distil_logit_stand': True,
        'distil_alpha': distil_alpha,
        'distil_tau': 2.0,
        'awgn': 0.1,
        'mgn': 0.1,
        'EKG': "X",
        'photic': "X",
        'mixed_precision': True,
        'montage': montage_name,
        'device': 'cuda:0',
        'total_samples': 1.0e+7,
    })

    check_device_env(config)
    # pprint.pprint(config)
    prepare_and_run_train(rank=None, world_size=None, config=config)

## Parse Results

In [None]:
import csv
import numpy as np

# 0-based column positions read from local/wandb_save.csv.
# NOTE(review): this layout is assumed -- verify against the actual export.
MONTAGE_COL = 13    # montage string of the form "<Mont1> - <Mont2>"
TEST_ACC_COL = 16   # stored under the 'Test' key
TTA_ACC_COL = 17    # stored under the 'TTA' key

# One dict per data row: {'Mont1', 'Mont2', 'Test', 'TTA'}.
performance_list = []

# newline="" is the mode the csv module documents for file objects, so that
# newlines embedded in quoted fields are parsed correctly.
with open(r"local/wandb_save.csv", newline="") as fp:
    rdr = csv.reader(fp)
    next(rdr, None)  # skip the header row (no-op on an empty file)
    for line in rdr:
        if not line:
            # csv.reader yields [] for blank lines (e.g. a trailing newline).
            continue
        mont1, mont2 = line[MONTAGE_COL].split(" - ")
        performance_list.append({
            'Mont1': mont1,
            'Mont2': mont2,
            # Convert once here instead of relying on implicit string -> float
            # coercion when the values are written into a NumPy array.
            'Test': float(line[TEST_ACC_COL]),
            'TTA': float(line[TTA_ACC_COL]),
        })

# print(performance_list)

In [None]:
# Electrode names (label text before the first '-') in signal_header order,
# with the EKG and Photic channels left out, plus the reverse name -> index map.
idx_to_mont = [mont.split('-')[0] for mont in signal_header if mont.lower() not in ['ekg', 'photic']]
mont_to_idx = {mont: i for i, mont in enumerate(idx_to_mont)}

# Size the matrices from the electrode list instead of a hard-coded 19 so they
# always agree with idx_to_mont.
n_mont = len(idx_to_mont)
test = np.zeros((n_mont, n_mont))
mtest = np.zeros((n_mont, n_mont))

# Only the (Mont1, Mont2) cell is filled for each result; the mirrored
# (Mont2, Mont1) cell and all pairs without a result stay at 0.
for perf in performance_list:
    row = mont_to_idx[perf['Mont1']]
    col = mont_to_idx[perf['Mont2']]
    # float() accepts both numeric values and numeric strings read from CSV.
    test[row, col] = float(perf['Test'])
    mtest[row, col] = float(perf['TTA'])

In [None]:
import matplotlib.pyplot as plt

# Default colormap for images drawn from here on (alternatives: nipy_spectral, jet).
plt.rcParams['image.cmap'] = 'jet'

# Heat map of the `test` matrix with each cell's value printed on top of it.
fig, ax = plt.subplots(num=1, clear=True, figsize=(10.0, 10.0))
im = ax.imshow(test)

n_mont = len(idx_to_mont)
ax.set_xticks(np.arange(n_mont), labels=idx_to_mont)
ax.set_yticks(np.arange(n_mont), labels=idx_to_mont)

# Visit the cells row by row; ax.text takes (x, y), i.e. (column, row).
for row, col in itertools.product(range(n_mont), repeat=2):
    value = round(test[row, col] * 100) / 100
    text = ax.text(col, row, value,
                   ha="center", va="center", color="w")

fig.tight_layout()
plt.show()

In [None]:
# Heat map of the `mtest` matrix with each cell's value printed on top of it.
fig, ax = plt.subplots(num=1, clear=True, figsize=(10.0, 10.0))
im = ax.imshow(mtest)

n_mont = len(idx_to_mont)
ax.set_xticks(np.arange(n_mont), labels=idx_to_mont)
ax.set_yticks(np.arange(n_mont), labels=idx_to_mont)

# Visit the cells row by row; ax.text takes (x, y), i.e. (column, row).
for row, col in itertools.product(range(n_mont), repeat=2):
    value = round(mtest[row, col] * 100) / 100
    text = ax.text(col, row, value,
                   ha="center", va="center", color="w")

fig.tight_layout()
plt.show()