# Segment Classification

In [None]:
import os.path
import sys

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from torch.utils.data import DataLoader

from dataset.individual import IndividualTrialDataset
from dataset.ki import KIDataset
from dataset.triplet import TripletSegmentDataset
from models.classifier import IndividualClassifier
from models.inception import InceptionTimeModel
from train.inceptiontime import train_inception_time
from train.transformer import train_transformer
from utils.const import SEED
from utils.data import PadCollate
from utils.misc import set_random_state
from utils.path import checkpoints_path as MODELS_PATH

# NOTE(review): this runs after the project imports above, so it cannot help
# resolve them; it only affects imports executed later in the session.
sys.path.append('../')

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
torch.set_default_dtype(torch.float)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing on the first `.to(DEVICE)`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Add the 'PD_ON' source to the trial-level / segment-level training data.
INCLUDE_PDON_TF = True
INCLUDE_PDON_SEGM = True
# Forwarded to TripletSegmentDataset and used to pick 2 vs. 3 output classes
# for the trial-level classifier.
BINARY_CLF = True
# Epoch budgets: segment embedder training and trial-level classifier training.
N_EPOCHS = 400
N_EPOCHS_CLF = 100

# Checkpoint files carry a "_pdon" suffix when PD_ON segments are included.
_PDON_SUFFIX = "_pdon" if INCLUDE_PDON_SEGM else ""
EMBEDDER_CHKPT_FPATH = os.path.join(MODELS_PATH, f'inception_time{_PDON_SUFFIX}.pth')
CLF_CHKPT_FPATH = os.path.join(MODELS_PATH, f'linear_clf{_PDON_SUFFIX}.pth')
IND_CLF_CHKPT_FPATH = os.path.join(MODELS_PATH, f'ind_clf{_PDON_SUFFIX}.pth')

# Not read anywhere in the visible part of the notebook.
TEST_CLF = False

# Seed the random state before any dataset or model is created.
set_random_state(SEED)

## Dataset/Dataloader Initialization

In [None]:
# Segment-level datasets. 'PD_ON' is part of the training sources only when
# INCLUDE_PDON_SEGM is set; the test split always lists all three sources.
segm_common_kwargs = dict(which='segments', config='ki_auto', ki_data_dirname='KI')
if INCLUDE_PDON_SEGM:
    segm_train_sources = ['HC', 'PD_OFF', 'PD_ON']
else:
    segm_train_sources = ['HC', 'PD_OFF']
ds_segm_train = KIDataset(train=True, data_sources=segm_train_sources, **segm_common_kwargs)
ds_segm_test = KIDataset(train=False, data_sources=['HC', 'PD_OFF', 'PD_ON'], **segm_common_kwargs)

# Dataloaders over the TripletSegmentDataset wrappers; only the training
# loader shuffles.
dl_segm_train = DataLoader(
    TripletSegmentDataset(ds_segm_train, binary_clf=BINARY_CLF),
    batch_size=128,
    shuffle=True,
    pin_memory=True,
)
dl_segm_test = DataLoader(
    TripletSegmentDataset(ds_segm_test, binary_clf=BINARY_CLF),
    batch_size=128,
    shuffle=False,
    pin_memory=True,
)

## InceptionTime Initialization

In [None]:
# Segment embedder: an InceptionTimeModel created and moved to DEVICE in one go.
embedder = InceptionTimeModel(
    bottleneck_channels=0,
    in_channels=4,
    kernel_sizes=65,
    num_blocks=3,
    num_pred_classes=32,
    out_channels=64,
    use_residuals='default',
).to(DEVICE)

# Classification head stacked on the embedder output:
# batch norm -> linear(out_dim, 10) -> SELU -> linear(10, 2).
# NOTE(review): the output width is fixed at 2 here, whereas the trial-level
# classifier switches between 2 and 3 classes on BINARY_CLF - confirm intended.
embed_dim = embedder.out_dim
clf_layers = [
    nn.BatchNorm1d(embed_dim),
    nn.Linear(embed_dim, 10),
    nn.SELU(inplace=True),
    nn.Linear(10, 2),
]
clf = nn.Sequential(*clf_layers).to(DEVICE)

## Training Loop

In [None]:
# One AdamW optimizer with two parameter groups: the embedder uses the
# constructor defaults (lr 1e-3), the classifier head a smaller lr (1e-4).
segm_optim = torch.optim.AdamW(
    [
        {'params': embedder.parameters()},
        {'params': clf.parameters(), 'lr': 1e-4, 'weight_decay': 1e-2},
    ],
    lr=1e-3,
    weight_decay=1e-2,
)

# Learning-rate schedule: linear warm-up for the first N_EPOCHS // 20 steps,
# cosine annealing for the remainder. The warm-up scheduler is created first,
# matching the order in which SequentialLR runs them.
segm_warmup_steps = N_EPOCHS // 20
segm_warmup_sched = torch.optim.lr_scheduler.LinearLR(segm_optim, total_iters=segm_warmup_steps)
segm_cosine_sched = torch.optim.lr_scheduler.CosineAnnealingLR(segm_optim, T_max=N_EPOCHS - segm_warmup_steps)
segm_optim_sched = torch.optim.lr_scheduler.SequentialLR(
    segm_optim,
    schedulers=[segm_warmup_sched, segm_cosine_sched],
    milestones=[segm_warmup_steps],
)

# Loss functions handed to the segment training routine.
bce_crit = nn.BCEWithLogitsLoss()
triplet_crit = nn.TripletMarginLoss(margin=1.0, p=2, swap=True)

In [None]:
# Train the segment embedder and its head only when no embedder checkpoint is
# on disk yet.
# NOTE(review): when the checkpoint already exists nothing is loaded from it
# (a `torch.load(EMBEDDER_CHKPT_FPATH)` call used to sit here, commented out),
# so `embedder` and `clf` keep whatever weights they currently hold in memory.
if not os.path.exists(EMBEDDER_CHKPT_FPATH):
    lt, lv, at, av = train_inception_time(
        embedder, clf, N_EPOCHS, dl_segm_train, dl_segm_test, DEVICE,
        segm_optim, segm_optim_sched, bce_crit, triplet_crit,
    )

    # Persist both modules as whole pickled objects.
    for module_to_save, chkpt_fpath in ((embedder, EMBEDDER_CHKPT_FPATH), (clf, CLF_CHKPT_FPATH)):
        torch.save(module_to_save, chkpt_fpath)

    # One figure for the loss curves, one for the accuracy curves.
    segm_figures = (
        ((lt, 'loss train'), (lv, 'loss test')),
        ((at, 'acc train'), (av, 'acc test')),
    )
    for curves in segm_figures:
        for values, label in curves:
            plt.plot(values, label=label)
        plt.legend()
        plt.show()

# Individual Classification

## Datasets and Dataloader Initialization

In [None]:
# Trial-level datasets. 'PD_ON' is part of the training sources only when
# INCLUDE_PDON_TF is set; the test split always lists all three sources.
trial_common_kwargs = dict(which='trials', config='ki_auto', ki_data_dirname='KI')
if INCLUDE_PDON_TF:
    trial_train_sources = ['HC', 'PD_OFF', 'PD_ON']
else:
    trial_train_sources = ['HC', 'PD_OFF']
ds_trial_train = KIDataset(train=True, data_sources=trial_train_sources, **trial_common_kwargs)
ds_trial_test = KIDataset(train=False, data_sources=['HC', 'PD_OFF', 'PD_ON'], **trial_common_kwargs)

# Dataloaders over IndividualTrialDataset wrappers, batched through
# PadCollate(dim=0). The training loader shuffles and drops the last
# incomplete batch; the test loader does neither.
dl_trial_train = DataLoader(
    IndividualTrialDataset(ds_trial_train),
    batch_size=16,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
    collate_fn=PadCollate(dim=0),
)
dl_trial_test = DataLoader(
    IndividualTrialDataset(ds_trial_test),
    batch_size=16,
    shuffle=False,
    pin_memory=True,
    collate_fn=PadCollate(dim=0),
)

## Models Initialization

In [None]:
# Wrap the embedder behind a Rearrange layer that maps 'b n l d' inputs to
# '(b n) d l', then place the wrapper on DEVICE and switch it to eval mode.
embedder_re = nn.Sequential(Rearrange('b n l d -> (b n) d l'), embedder)
embedder_re = embedder_re.to(DEVICE)
embedder_re = embedder_re.eval()

# Freeze the embedder: none of its parameters receive gradients any more.
embedder.requires_grad_(False)


# Weight Initializer
def init_weights(module: nn.Module) -> None:
    """Re-initialise the parameters of a single layer in place.

    Intended to be passed to ``nn.Module.apply`` so that it is called once for
    every sub-module of a model.

    * ``nn.Embedding``: weights drawn from N(0, 1); the ``padding_idx`` row,
      when one is configured, is reset to zero.
    * ``nn.LayerNorm``: weight set to 1 and bias set to 0.
    * ``nn.Linear``: weights drawn from a normal with std 0.1, bias set to 0.

    Layers created without a bias (or without affine parameters) expose those
    attributes as ``None``; such missing parameters are skipped instead of
    raising ``AttributeError``. Any other module type is left untouched.

    Args:
        module: The layer to initialise.
    """
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=1.0)
        if module.padding_idx is not None:
            # Keep the padding embedding at exactly zero.
            with torch.no_grad():
                module.weight[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        # Both are None when the layer has no learnable affine parameters.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if module.weight is not None:
            nn.init.ones_(module.weight)
    elif isinstance(module, nn.Linear):
        # bias is None for nn.Linear(..., bias=False).
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        nn.init.normal_(module.weight, mean=0.0, std=0.1)


# Trial-level classifier: two output classes in binary mode, three otherwise.
ind_clf_n_classes = 2 if BINARY_CLF else 3
ind_clf = IndividualClassifier(
    in_features=32,
    d_model=128,
    nhead=4,
    num_layers=2,
    dim_feedforward=256,
    batch_first=True,
    dropout=0.3,
    n_classes=ind_clf_n_classes,
)
# Apply the custom initialiser to every sub-module, then move to DEVICE.
ind_clf.apply(init_weights)
ind_clf = ind_clf.to(DEVICE)

## Training Loop

In [None]:
# AdamW over the trial-level classifier parameters only.
trial_optim = torch.optim.AdamW(ind_clf.parameters(), lr=1e-5, weight_decay=1e-1)

# Learning-rate schedule: linear warm-up for the first N_EPOCHS_CLF // 20
# steps, cosine annealing for the remainder. The warm-up scheduler is created
# first, matching the order in which SequentialLR runs them.
trial_warmup_steps = N_EPOCHS_CLF // 20
trial_warmup_sched = torch.optim.lr_scheduler.LinearLR(trial_optim, total_iters=trial_warmup_steps)
trial_cosine_sched = torch.optim.lr_scheduler.CosineAnnealingLR(trial_optim, T_max=N_EPOCHS_CLF - trial_warmup_steps)
trial_optim_sched = torch.optim.lr_scheduler.SequentialLR(
    trial_optim,
    schedulers=[trial_warmup_sched, trial_cosine_sched],
    milestones=[trial_warmup_steps],
)

# Loss function handed to the trial-level training routine.
bce_crit = nn.BCEWithLogitsLoss()

In [None]:
# Train the trial-level classifier on top of the embedder wrapper.
lt, lv, at, av = train_transformer(
    embedder_re, ind_clf, N_EPOCHS_CLF, dl_trial_train, dl_trial_test, DEVICE,
    trial_optim, trial_optim_sched, bce_crit,
)

# Persist the classifier as a whole pickled module.
torch.save(ind_clf, IND_CLF_CHKPT_FPATH)

# One figure for the loss curves, one for the accuracy curves.
trial_figures = (
    ((lt, 'loss train'), (lv, 'loss test')),
    ((at, 'acc train'), (av, 'acc test')),
)
for curves in trial_figures:
    for values, label in curves:
        plt.plot(values, label=label)
    plt.legend()
    plt.show()