In [1]:
import copy
import sys

import einops
import numpy
import numpy as np
import torch

sys.exc_info()
from utils.config import process_config
from datasets.sleepset import *
from graphs.models.attention_models.windowFeature_base import *
from sklearn.metrics import f1_score, cohen_kappa_score, roc_auc_score, confusion_matrix
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from utils.deterministic_pytorch import deterministic
# import umap
from scipy.stats import entropy
import os
from torchdistill.core.forward_hook import ForwardHookManager
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
import pandas as pd
%config InlineBackend.figure_format = 'retina'

In [2]:
def sleep_load_encoder(encoders):
    """Build one encoder per config entry and optionally restore its weights.

    Args:
        encoders: sequence of mapping-like encoder configs. Each entry is read
            for ``"model"`` (the literal ``"TF"`` or the name of a class
            reachable through ``globals()``), ``"args"`` (non-"TF" models only)
            and ``"pretrainedEncoder"`` (keys ``"use"`` and ``"dir"``).

    Returns:
        list: the constructed encoders, in the same order as ``encoders``.
    """
    encs = []
    for enc_cfg in encoders:
        if enc_cfg["model"] == "TF":
            layers = ["huy_pos_inner", "inner_att", "aggregation_att_contx_inner", "huy_pos_outer", "outer_att"]
            # NOTE: unlike the generic branch, this encoder is not wrapped in DataParallel.
            enc = Multi_Transformer(128, inner=29, outer=21, modalities=1, heads=8,
                                    layers=layers, num_layers=4, pos=False)
        else:
            # The class is looked up by name among the names pulled in by the star imports.
            enc_class = globals()[enc_cfg["model"]]
            enc = enc_class(args=enc_cfg["args"])
            enc = nn.DataParallel(enc, device_ids=[torch.device(0)])

        pretrained = enc_cfg["pretrainedEncoder"]
        if pretrained["use"]:
            print("Loading encoder from {}".format(pretrained["dir"]))
            # Deserialize on the CPU so a checkpoint saved on another GPU index
            # still loads; load_state_dict copies the values into the existing
            # parameters wherever they live.
            checkpoint = torch.load(pretrained["dir"], map_location="cpu")
            enc.load_state_dict(checkpoint["encoder_state_dict"])
        encs.append(enc)
    return encs

def load_models(config, device, checkpoint, only_model=False):
    """Build the network described by ``config`` and, optionally, its best-checkpoint twin.

    Args:
        config: experiment config; ``config.model.model_class`` names a class
            reachable through ``globals()``, ``config.model.encoders`` and
            ``config.model.args`` are forwarded to the constructors, and
            ``config.gpu_device`` lists the DataParallel device indices.
        device: device the network is moved to before being wrapped.
        checkpoint: mapping holding ``"best_model_state_dict"``; it is only
            read when ``only_model`` is False.
        only_model: when True, return just the freshly built network.

    Returns:
        ``model`` when ``only_model`` is True, otherwise ``(model, best_model)``
        where ``best_model`` is a deep copy of ``model`` with the checkpoint's
        best weights loaded. ``model`` itself is never restored from
        ``checkpoint``.
    """
    network_cls = globals()[config.model.model_class]
    encoder_list = sleep_load_encoder(encoders=config.model.encoders)

    model = nn.DataParallel(
        network_cls(encoder_list, args=config.model.args).to(device),
        device_ids=[torch.device(gpu_idx) for gpu_idx in config.gpu_device],
    )

    if not only_model:
        # Clone first so that restoring the best weights leaves `model` untouched.
        best_model = copy.deepcopy(model)
        best_model.load_state_dict(checkpoint["best_model_state_dict"])
        return model, best_model

    return model

def find_patient_list(data_loader):
    """Collect the integer patient ids encoded in the loader's file paths.

    Reads the paths in ``data_loader.dataset.dataset[0]``; for each one whose
    final ``/``-separated component is not the literal ``"empty"``, characters
    1-4 of that component are parsed as an int.

    Returns:
        list[int]: one id per non-"empty" path, in the original order.
    """
    patient_ids = []
    for path in data_loader.dataset.dataset[0]:
        file_name = path.split("/")[-1]
        if file_name == "empty":
            continue
        patient_ids.append(int(file_name[1:5]))
    return patient_ids

def perf_measure(y_actual, y_hat):
    """Count binary confusion-matrix cells, treating 1 as positive and 0 as negative.

    Args:
        y_actual: indexable sequence of ground-truth labels.
        y_hat: indexable sequence of predicted labels; its length drives the loop.

    Returns:
        tuple: ``(TP, FP, TN, FN)``. Positions whose prediction is neither 0
        nor 1 are not counted in any cell.
    """
    true_pos = false_pos = true_neg = false_neg = 0

    for idx in range(len(y_hat)):
        actual, predicted = y_actual[idx], y_hat[idx]
        agrees = actual == predicted
        if predicted == 1:
            if agrees:
                true_pos += 1
            if actual != predicted:
                false_pos += 1
        if predicted == 0:
            if agrees:
                true_neg += 1
            if actual != predicted:
                false_neg += 1

    return (true_pos, false_pos, true_neg, false_neg)

def print_perf(model_name, patient_num, preds, tts, multiclass=True):
    """Print accuracy, F1, Cohen's kappa and per-class F1 for one patient.

    Args:
        model_name: label printed at the start of the line.
        patient_num: patient identifier printed in the line.
        preds: array whose last axis is arg-maxed to obtain the predicted labels.
        tts: ground-truth labels, compared element-wise with the predictions.
        multiclass: when True, F1 is macro-averaged and the per-class F1 is
            reported separately; when False, the default (binary) F1 is used
            for both.

    Returns:
        Cohen's kappa between ``tts`` and the predicted labels.
    """
    pred_labels = preds.argmax(-1)
    test_acc = np.equal(tts, pred_labels).sum() / len(tts)

    # scikit-learn metrics take (y_true, y_pred) in that order.
    if multiclass:
        test_f1 = f1_score(tts, pred_labels, average="macro")
        test_perclass_f1 = f1_score(tts, pred_labels, average=None)
    else:
        test_f1 = f1_score(tts, pred_labels)
        test_perclass_f1 = test_f1

    test_k = cohen_kappa_score(tts, pred_labels)

    # AUC, confusion matrix, specificity and sensitivity used to be computed
    # here but were never reported; the AUC call could also raise for a
    # single-class recording in binary mode, so they have been removed.
    print("{} Patient {} has acc: {}, f1: {}, k:{} and f1_per_class: {}".format(model_name, patient_num,
                                                                             round(test_acc * 100, 1),
                                                                             round(test_f1 * 100, 1),
                                                                             round(test_k, 3),
                                                                             np.round(test_perclass_f1 * 100,
                                                                                      1)))
    return test_k

def get_performance_windows(preds, tts, print_it=True, window=40):
    """Compute the accuracy of ``preds`` against ``tts`` in consecutive windows.

    Both arrays are cut along axis 0 into non-overlapping windows of ``window``
    elements (a trailing remainder is dropped by ``unfold``).

    Args:
        preds: numpy array of predicted labels.
        tts: numpy array of ground-truth labels.
        print_it: when True, print the per-window accuracies on one line.
        window: window length, also used as the step.

    Returns:
        numpy array with one accuracy value per window.
    """
    target_windows = torch.from_numpy(tts).unfold(0, window, window).numpy()
    pred_windows = torch.from_numpy(preds).unfold(0, window, window).numpy()

    scores = []
    for row_idx in range(target_windows.shape[0]):
        truth_row = target_windows[row_idx]
        n_correct = np.equal(truth_row, pred_windows[row_idx]).sum()
        scores.append(n_correct / len(truth_row))
    perf_window = np.array(scores)

    # Any NaN score is replaced by a perfect score.
    perf_window[np.isnan(perf_window)] = 1

    if print_it:
        print("".join("{:.3f} ".format(score) for score in perf_window))
    return perf_window

def get_windows(input, window=40):
    """Average a numpy array over consecutive, non-overlapping windows of axis 0.

    Args:
        input: numpy array; it is cut along axis 0 into windows of ``window``
            elements (a trailing remainder is dropped by ``unfold``).
        window: window length, also used as the step.

    Returns:
        numpy array holding the mean of each window (taken over the last axis
        of the unfolded array). A warning line is printed when any windowed
        value is NaN; the NaNs are not removed.
    """
    windows = torch.from_numpy(input).unfold(0, window, window).numpy()
    # NaN is the only value that is not equal to itself.
    if (windows != windows).any():
        # The message previously began with a Greek capital Tau (U+03A4), a
        # homoglyph of "T" that broke text searches for this warning.
        print("There are nan")
    return windows.mean(axis=-1)

def change_numbers_preds(preds, tts, argmax=True):
    """Return relabelled copies of the predictions and targets.

    The relabelling is applied as four sequential in-place replacements
    (4->5, 3->4, 2->3, 5->2), so the net effect is 2->3, 3->4, 4->2. Because 5
    is used as the temporary value, a pre-existing label 5 also ends up as 2.
    The inputs themselves are not modified.

    Args:
        preds: predictions; arg-maxed over the last axis when ``argmax`` is True.
        tts: target labels.
        argmax: whether ``preds`` holds per-class scores rather than labels.

    Returns:
        tuple: ``(pred_plus, target_plus)``, the relabelled copies.
    """
    def _cycle_labels(labels):
        # Order matters: each step frees the value the next step writes to.
        for old_label, new_label in ((4, 5), (3, 4), (2, 3), (5, 2)):
            labels[labels == old_label] = new_label
        return labels

    pred_plus = copy.deepcopy(preds)
    if argmax:
        pred_plus = pred_plus.argmax(-1)
    pred_plus = _cycle_labels(pred_plus)

    target_plus = _cycle_labels(copy.deepcopy(tts))

    return pred_plus, target_plus

def find_matches(pred_plus, target_plus):
    """Return the indices where the two label arrays disagree.

    Despite its name, the function reports the non-matching positions: the
    first index array of ``nonzero()`` over the element-wise inequality.

    Args:
        pred_plus: numpy array of predicted labels.
        target_plus: numpy array of target labels.

    Returns:
        numpy array of indices (along axis 0) of the mismatching entries.
    """
    return (pred_plus != target_plus).astype(int).nonzero()[0]

def print_max_perf(predictors, performs, tts, window_floor, patient_num=None):
    """Report the score obtained by picking, per window, the best-performing predictor.

    For every window index the predictor with the highest entry in
    ``performs`` is selected and its labels for that window are kept; the
    stitched label sequence is then scored against ``tts``.

    Args:
        predictors: sequence of per-model label sequences, sliced as
            ``predictors[m][i*window_floor:(i+1)*window_floor]``.
        performs: sequence of per-model window scores; ``performs[m][i]`` is
            the score of model ``m`` on window ``i``.
        tts: ground-truth labels, compared element-wise with the stitched
            predictions.
            NOTE(review): presumably of length
            ``len(performs[0]) * window_floor`` - confirm against the caller.
        window_floor: number of labels per window.
        patient_num: identifier printed in the report. When omitted, a
            notebook-level ``patient_num`` is used if one exists (the previous
            behaviour), otherwise ``"N/A"`` is printed instead of raising
            NameError.

    Returns:
        tuple: ``(kappa, max_preds)`` - Cohen's kappa of the stitched
        predictions and the stitched predictions themselves.
    """
    if patient_num is None:
        # The body used to read an undeclared global; keep honouring it when present.
        patient_num = globals().get("patient_num", "N/A")

    max_preds = []
    for i in range(len(performs[0])):
        max_mod = np.argmax(np.array([perf[i] for perf in performs]))
        max_preds.append(predictors[max_mod][i * window_floor:(i + 1) * window_floor])
    max_preds = np.array(max_preds).flatten()

    model_name = "Max"
    test_acc = np.equal(tts, max_preds).sum() / len(tts)
    # scikit-learn metrics take (y_true, y_pred) in that order.
    test_f1 = f1_score(tts, max_preds, average="macro")
    test_perclass_f1 = f1_score(tts, max_preds, average=None)
    test_k = cohen_kappa_score(tts, max_preds)

    # AUC, confusion matrix, specificity and sensitivity were computed here but
    # never reported, so they have been removed.
    print("{} Patient {} has acc: {}, f1: {}, k:{} and f1_per_class: {}".format(model_name, patient_num,
                                                                             round(test_acc * 100, 1),
                                                                             round(test_f1 * 100, 1),
                                                                             round(test_k, 3),
                                                                             np.round(test_perclass_f1 * 100,
                                                                                      1)))
    return test_k, max_preds

# Printed once the helper definitions above in this cell have executed.
print("Done Loading")

Done Loading


In [3]:
# Locations of the JSON experiment configs, one per model used in this notebook.
multimodal_config_name = "./configs/shhs/multi_modal/eeg_eog/established_models/fourier_transformer_eeg_eog_mat_BIOBLIP_lossw.json"
eeg_config_name = "./configs/shhs/single_channel/fourier_transformer_cls_eeg_mat_adv.json"
eog_config_name = "./configs/shhs/single_channel/fourier_transformer_cls_eog_mat.json"
emg_config_name = "./configs/shhs/single_channel/fourier_transformer_cls_emg_mat.json"
router_config_name = "./configs/shhs/router/router_fourier_tf_eeg_eog.json"

# Parse each config file (the second positional argument is passed as False).
multimodal_config = process_config(multimodal_config_name, False)
eeg_config = process_config(eeg_config_name, False)
eog_config = process_config(eog_config_name, False)
emg_config = process_config(emg_config_name, False)
router_config = process_config(router_config_name, False)

# Override the data views of the multimodal config: every modality is listed
# once as "stft" and once as "time", in the order eeg, eog, emg.
multimodal_config.data_view_dir = [
    {"list_dir": list_dir, "data_type": data_type, "mod": mod, "num_ch": 1}
    for data_type in ("stft", "time")
    for mod, list_dir in (
        ("eeg", "patient_mat_list.txt"),
        ("eog", "patient_eog_mat_list.txt"),
        ("emg", "patient_emg_mat_list.txt"),
    )
]

In [4]:
# NOTE(review): the device is hardcoded; this cell needs a visible CUDA device 0.
device = "cuda:0"

# Load the checkpoints onto the CPU first; load_models moves the networks to `device`.
checkpoint_multimodal = torch.load(multimodal_config.model.save_dir, map_location="cpu")
checkpoint_eeg = torch.load(eeg_config.model.save_dir, map_location="cpu")
checkpoint_eog = torch.load(eog_config.model.save_dir, map_location="cpu")
checkpoint_emg = torch.load(emg_config.model.save_dir, map_location="cpu")
checkpoint_router = torch.load(router_config.model.save_dir, map_location="cpu")

# The dataloader class is looked up by name among the star-imported names and
# is seeded with the metrics and weights stored in the multimodal checkpoint.
dataloader = globals()[multimodal_config.dataloader_class]
data_loader = dataloader(config=multimodal_config)
data_loader.load_metrics_ongoing(checkpoint_multimodal["metrics"])
data_loader.weights = checkpoint_multimodal["logs"]["weights"]

# Only the best-checkpoint copy of each model is kept.
_, best_model_multimodal = load_models(config=multimodal_config, device=device, checkpoint=checkpoint_multimodal)
_, best_model_eeg = load_models(config=eeg_config, device=device, checkpoint=checkpoint_eeg)
_, best_model_eog = load_models(config=eog_config, device=device, checkpoint=checkpoint_eog)
_, best_model_emg = load_models(config=emg_config, device=device, checkpoint=checkpoint_emg)

# The router is paired with its own checkpoint (this call used to receive
# checkpoint_emg). With only_model=True, load_models returns before reading the
# checkpoint, so the router starts from freshly constructed weights either way.
router_model = load_models(config=router_config, device=device, checkpoint=checkpoint_router, only_model=True)


We are splitting dataset by huy splits
True
We are loading weights
[0. 0. 0. 0. 0.]


In [12]:
# Training loop for the router model.
# NOTE(review): this cell reads like a trainer-class method pasted into the
# notebook. It uses `time`, `logs`, `optimizer`, `scheduler`, `Fore`,
# `monitor_n_saver`, `teacher_preds` and `self`, none of which is explicitly
# defined or imported in the visible cells (some may arrive through the star
# imports) - confirm before running on a fresh kernel.
router_model.train()
start = time.time()

# Running buffers that are emptied after every validation/checkpoint step.
tts, preds, batch_loss, early_stop = [], [], [], False
saved_at_step, prev_epoch_time = 0, 0
# The loop target is the dict entry itself, so logs["current_epoch"] is updated
# in place on every epoch and the loop resumes from its stored value.
for logs["current_epoch"] in range(logs["current_epoch"], router_config.max_epoch):
    pbar = tqdm(enumerate(data_loader.train_loader), desc="Training", leave=None, disable=router_config.tdqm_disable, position=0)
    for batch_idx, batch in pbar:
        data, target, inits = batch[0], batch[1], batch[2]

        # Every view of the batch is cast to float and moved to the GPU.
        data = [data[i].float().cuda() for i in range(len(data))]

        # Soft labels stay as float targets; otherwise the first two target
        # dimensions are flattened and, if the result still has more than one
        # dimension, reduced to class indices with argmax.
        if "softlabels" in router_config.config and router_config.config.softlabels:
            target = target.float().cuda()
        else:
            target = target.cuda().flatten(start_dim=0, end_dim=1).long()
            if len(target.shape) > 1:
                target = target.argmax(dim=1).cuda()

        optimizer.zero_grad()
        # NOTE(review): `self` does not exist at cell scope; presumably the
        # train-step body pasted further down in this cell is what should run here.
        loss, pred = self.this_train_step_func(data, target, inits, teacher_preds)
        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        optimizer.step()
        scheduler.step(logs["current_step"]+1)
        # if "sparse_loss" in self.config and self.config.sparse_loss:
        # loss, pred = self.sleep_train_one_step_sparse(data, target, inits)

        batch_loss.append(loss)
        tts.append(target)
        preds.append(pred)

        # Drop the per-batch names; the appended objects stay alive in the buffers.
        del data, target, pred, loss

        pbar_message = Fore.WHITE + "Training batch {0:d}/{1:d} steps no improve {2:d} with ".format(batch_idx, len(self.agent.data_loader.train_loader)-1, self.agent.logs["steps_no_improve"])
        mean_batch = self._calc_mean_batch_loss(batch_loss=batch_loss)
        for mean_key in mean_batch: pbar_message += "{}: {:.3f} ".format(mean_key, mean_batch[mean_key])

        # Validate/checkpoint every `validate_every` steps, once at least
        # `validate_after` such intervals have passed, and never on the first
        # batch of an epoch.
        if logs["current_step"] % router_config.validate_every == 0 and \
            logs["current_step"] // router_config.validate_every >= router_config.validate_after and \
            batch_idx!=0:
            early_stop = monitor_n_saver.checkpointing(batch_loss = mean_batch, predictions = preds, targets = tts)
            batch_loss, tts, preds = [], [], []

            # NOTE(review): this only leaves the batch loop; the enclosing epoch
            # loop never checks `early_stop`, so training resumes next epoch.
            if early_stop: break
            saved_at_step = logs["current_step"] // router_config.validate_every
            # NOTE(review): the timer is read from `self.agent.start` and reset
            # on `self.start`, while the cell initialises a plain `start` above.
            prev_epoch_time = time.time() - self.agent.start
            self.start = time.time()

        pbar_message += " time {:.1f} sec/step, ".format(prev_epoch_time)
        pbar_message += "saved at {}".format(saved_at_step)
        pbar.set_description(pbar_message)
        pbar.refresh()
        # NOTE(review): the counter incremented here is self.agent.logs, whereas
        # the conditions above read the notebook-level `logs` - confirm that
        # both names refer to the same dict.
        self.agent.logs["current_step"] += 1

            # NOTE(review): from here on the cell contains the body of a
            # train-step method without its `def` header (it uses `self`, is
            # indented deeper than the loop above, and ends in `return`), so
            # the cell cannot execute as written. Presumably this is the
            # function called as `self.this_train_step_func` - confirm.

            # The auxiliary outputs are only requested when their loss weight is non-zero.
            return_matches= True if self.w_loss["alignments"]!=0 else False
            return_order= True if self.w_loss["order"]!=0 else False

            # pred, matches = self.get_predictions_time_series_alignment(views, inits)
            output = self.agent.model(data, return_matches=return_matches, return_order=return_order)

            # Teacher predictions are only used when `kd_label` is enabled in the config.
            teacher_preds = teacher_preds.to(self.agent.device).flatten(start_dim=0, end_dim=1)  if "kd_label" in self.agent.config and self.agent.config.kd_label else None
            ce_loss = {}
            # NOTE(review): `if teacher_preds` takes the truth value of the
            # object; for a multi-element torch tensor that raises, so
            # `is not None` is probably what is meant - confirm.
            for k, v in output["preds"].items():
                ce_loss[k] = self.agent.loss(v, target, teacher_preds) if teacher_preds else self.agent.loss(v, target)

            # Weighted sum of the per-output losses; each one is then detached
            # to numpy for reporting.
            total_loss = 0
            output_losses = {}
            for i in ce_loss:
                total_loss += self.w_loss[i] * ce_loss[i]
                ce_loss[i] = ce_loss[i].detach().cpu().numpy()
                output_losses.update({"ce_loss_{}".format(i): ce_loss[i]})

            if return_matches:
                matches = output["matches"].flatten(start_dim=0, end_dim=1).flatten(start_dim=1)
                # The stored alignment target is cropped to the first two
                # dimensions of the first data view (three axes when
                # `blip_loss` is configured, two otherwise).
                if "blip_loss" in self.agent.config:
                    alignment_target = self.agent.alignment_target[:data[0].shape[0], :data[0].shape[1], :data[0].shape[1]].flatten(start_dim=0, end_dim=1)
                else:
                    alignment_target = self.agent.alignment_target[:data[0].shape[0], :data[0].shape[1]].flatten(start_dim=0, end_dim=1)

                alignment_loss = self.agent.alignment_loss(matches, alignment_target)
                total_loss += self.w_loss["alignments"]*alignment_loss
                alignment_loss = alignment_loss.detach().cpu().numpy()
                output_losses.update({"alignment_loss": alignment_loss})

            if return_order:
                # Reshape the flat targets to (b, outer) and slide a window of
                # three consecutive labels along the second axis.
                unfolded_target = einops.rearrange(target," (b outer) -> b outer", b=data[0].shape[0], outer=data[0].shape[1])
                unfolded_target = unfolded_target.unfold(1,3,1)
                same_label_left = unfolded_target[:, :, 0] == unfolded_target[:, :, 1]
                same_label_right = unfolded_target[:, :, 2] == unfolded_target[:, :, 1]

                # The order target is True only where both neighbours carry the
                # same label as the centre element of the window.
                order_target = torch.zeros([data[0].shape[0], data[0].shape[1]-2]).cuda() != 0 #Initially everything is False
                order_target[same_label_left] = same_label_right[same_label_left] == True #If the ones that are left are True, index right and take only the True ones from that.

                order_target = order_target.flatten().long()
                order_loss = self.agent.order_loss(output["order"], order_target)
                total_loss += self.w_loss["order"]*order_loss
                order_loss = order_loss.detach().cpu().numpy()
                output_losses.update({"order_loss": order_loss})


            # Gradients are computed here; the optimizer step is left to the caller.
            total_loss.backward()

            total_loss =  total_loss.detach().cpu().numpy()
            output_losses.update({"total": total_loss})

            # Predictions are handed back as numpy arrays, detached from the graph.
            for i in output["preds"]:  output["preds"][i] =  output["preds"][i].detach().cpu().numpy()

            return output_losses, output["preds"]