In [None]:
import sys
sys.exc_info()
from utils.config import process_config
from datasets.sleepset import *
from graphs.models.attention_models.windowFeature_base import *
from sklearn.metrics import f1_score, cohen_kappa_score, roc_auc_score, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt


def get_predictions_time_series(model, views, inits):
    """
    Run ``model`` on a batch without letting it correlate windows from different recordings.

    A batch row can contain windows from two recordings when the number of windows per
    recording is not divisible by the sequence length. Rows whose ``inits`` mark two
    adjacent positions ``k`` and ``k + 1`` are split there: positions ``[0, k]`` and
    ``[k + 1, seq)`` are predicted by separate model calls. Every other row is predicted
    in one batched call.

    :param model: callable taking a list of view tensors and returning a 2-D tensor with one
        row per (batch row, sequence position) -- assumed layout, TODO confirm for all models.
    :param views: list of tensors (data views/modalities), each indexed as
        ``view[batch_row, seq_pos, ...]``.
    :param inits: tensor of shape (batch, seq) holding ones where the sequence is discontinuous.
    :return: predictions of the model on the batch, shape (batch * seq, n_classes), on the
        device/dtype the model returns.
    """
    split_rows = (inits.sum(dim=1) > 1).nonzero(as_tuple=True)[0]
    if len(split_rows) == 0:
        return model(views)

    batch, outer = views[0].shape[0], views[0].shape[1]
    # True for rows that are predicted together in the single batched call below.
    unsplit = torch.ones(batch, dtype=torch.bool)
    split_preds = []  # (row index, predictions for that row's `outer` positions)

    for idx in split_rows.tolist():
        ones_idx = (inits[idx] > 0).nonzero(as_tuple=True)[0].tolist()
        # Only a pair of adjacent marks is treated as a recording boundary; any other
        # pattern is left to the batched call (the original also ended up using it).
        if ones_idx[0] + 1 != ones_idx[1]:
            continue
        first, second = ones_idx[0], ones_idx[1]
        # Slicing keeps the sequence dimension even for single-window pieces (k == 0 or last).
        head = model([view[idx, :first + 1].unsqueeze(dim=0) for view in views])
        tail = model([view[idx, second:].unsqueeze(dim=0) for view in views])
        split_preds.append((idx, torch.cat([head, tail], dim=0)))
        unsplit[idx] = False

    rest = None
    if unsplit.any():  # avoid calling the model on an empty batch when every row was split
        rest = model([view[unsplit.to(view.device)] for view in views])

    # Allocate on the model's own device/dtype with its own class count instead of a fixed .cuda() x 5.
    reference = rest if rest is not None else split_preds[0][1]
    pred = reference.new_zeros((batch * outer, reference.shape[1]))
    for idx, row_pred in split_preds:
        pred[idx * outer:(idx + 1) * outer] = row_pred
    if rest is not None:
        pred[unsplit.repeat_interleave(outer).to(pred.device)] = rest
    return pred
def perf_measure(y_actual, y_hat):
    """
    Count binary confusion-matrix cells, with label 1 as positive and label 0 as negative.

    Positions whose prediction is neither 0 nor 1 are not counted in any cell.

    :param y_actual: indexable sequence of ground-truth labels.
    :param y_hat: indexable sequence of predicted labels; its length drives the loop.
    :return: tuple ``(TP, FP, TN, FN)``.
    """
    tp = fp = tn = fn = 0
    for pos in range(len(y_hat)):
        guess = y_hat[pos]
        hit = y_actual[pos] == guess
        if guess == 1:
            if hit:
                tp += 1
            else:
                fp += 1
        elif guess == 0:
            if hit:
                tn += 1
            else:
                fn += 1
    return tp, fp, tn, fn
def test(model, data_loader, device=None):
    """
    Evaluate ``model`` on ``data_loader.test_loader``, print a summary and return the metrics.

    :param model: model evaluated through ``get_predictions_time_series``; switched to eval mode.
    :param data_loader: object whose ``test_loader`` yields ``(data, target, init, _)`` batches,
        where ``data`` is a sequence of view tensors.
    :param device: device the views are moved to; defaults to the device of the model's
        first parameter (previously this came from a notebook-global ``device``).
    :return: ``(acc, f1, kappa, auc, confusion_matrix, per_class_f1, specificity, sensitivity)``.
        With more than two output columns, f1 is macro-averaged, per-class f1 is an array and
        auc is 0. Specificity/sensitivity only count labels 0 and 1 (see ``perf_measure``) and
        are NaN when their denominator is zero.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    labels, outputs = [], []
    with torch.no_grad():
        n_batches = len(data_loader.test_loader)
        pbar = tqdm(enumerate(data_loader.test_loader), desc="Test")
        for batch_idx, (data, target, init, _) in pbar:
            views = [view.float().to(device) for view in data]
            pred = get_predictions_time_series(model, views, init)
            # Move to CPU per batch so the whole test set does not accumulate in GPU memory.
            labels.append(target.flatten().cpu())
            outputs.append(pred.cpu())
            pbar.set_description("Test batch {0:d}/{1:d}".format(batch_idx, n_batches))
            pbar.refresh()
    tts = torch.cat(labels).numpy()
    scores = torch.cat(outputs).numpy()

    multiclass = scores.shape[1] > 2
    preds = scores.argmax(axis=1)
    test_acc = np.equal(tts, preds).sum() / len(tts)
    # sklearn's signature is (y_true, y_pred).
    test_f1 = f1_score(tts, preds, average="macro" if multiclass else "binary")
    test_perclass_f1 = f1_score(tts, preds, average=None) if multiclass else test_f1
    test_k = cohen_kappa_score(tts, preds)
    # AUC needs a continuous score, not hard labels. The difference of the two output
    # columns is monotone in P(class 1) for logits, log-probabilities or probabilities.
    test_auc = roc_auc_score(tts, scores[:, 1] - scores[:, 0]) if scores.shape[1] == 2 else 0
    test_conf = confusion_matrix(tts, preds)
    tp, fp, tn, fn = perf_measure(tts, preds)
    test_spec = tn / (tn + fp) if (tn + fp) else float("nan")
    test_sens = tp / (tp + fn) if (tp + fn) else float("nan")
    # atleast_1d: in the binary case the per-class f1 is a scalar, which list() cannot iterate.
    print("Test accuracy: {0:.2f}% f1 :{1:.4f}, k :{2:.4f}, sens:{3:.4f}, spec:{4:.4f}, f1_per_class :{5:40}".format(
            test_acc * 100,
            test_f1,
            test_k, test_sens, test_spec,
            "{}".format(list(np.atleast_1d(test_perclass_f1)))))
    return test_acc, test_f1, test_k, test_auc, test_conf, test_perclass_f1, test_spec, test_sens
def plot_hypnogram(data_loader, patient_num, model, device, max_hours=3, epochs_per_hour=120):
    """
    Plot predicted vs. labelled sleep stages for the first test batch of one patient.

    Side effects: calls ``choose_specific_patient(patient_num)`` on the test dataset (which
    presumably restricts later test batches to that patient -- depends on the dataset class)
    and puts ``model`` in eval mode.

    :param data_loader: object whose ``test_loader`` yields ``(data, target, inits, idxs)``.
    :param patient_num: patient identifier forwarded to ``choose_specific_patient``.
    :param model: model evaluated through ``get_predictions_time_series``.
    :param device: device the views are moved to.
    :param max_hours: how many hours are shown (previously fixed at 3).
    :param epochs_per_hour: scored epochs per hour, used for truncation and x-ticks
        (120 as before; presumably 30-second epochs).
    """
    data_loader.test_loader.dataset.choose_specific_patient(patient_num)
    data, target, inits, _ = next(iter(data_loader.test_loader))
    views = [view.float().to(device) for view in data]
    model.eval()
    with torch.no_grad():  # plotting only: no autograd graph needed
        pred = get_predictions_time_series(model, views, inits).argmax(dim=1)
    target = target.to(device).flatten()

    n_shown = max_hours * epochs_per_hour
    # Small vertical offset so the label trace stays visible where it overlaps the prediction.
    label_trace = (target + 0.05)[:n_shown].cpu().numpy()
    pred_trace = pred[:n_shown].cpu().numpy()
    n_ticks = len(label_trace) // epochs_per_hour + 1

    plt.plot(pred_trace)
    plt.plot(label_trace)
    plt.yticks([0, 1, 2, 3, 4], labels=["Wake", "N1", "N2", "N3", "REM"])
    plt.xticks([i * epochs_per_hour for i in range(n_ticks)],
               labels=["{}_hours".format(i) for i in range(n_ticks)])
    plt.legend(["Prediction", "Label"])
    plt.show()
def sleep_plot_losses(current_step, config, train_logs, test_logs, best_logs,
                      save_path="/users/sista/kkontras/Documents/Sleep_Project/data/2021_data/loss.png"):
    """
    Plot loss curves, mark the chosen checkpoint (and, with ``config.rec_test``, the best test
    points), then save the figure to ``save_path`` and show it.

    Column meanings are inferred from usage only -- TODO confirm: ``best_logs[0]`` and
    ``test_logs[:, 0]`` are used as the x (epoch) position, ``best_logs[1]`` and
    ``test_logs[:, 1]`` as loss values, and ``test_logs[:, 4]`` is maximised as kappa.

    :param current_step: number of logged steps; rows ``[1, step)`` of ``train_logs`` are drawn.
    :param save_path: output file for the figure (defaults to the previously hard-coded path).
    """
    step = int(current_step)

    def _mark(x, y, color, label=None):
        # Dashed drop-lines from (x, 0) up to (x, y) and from (0, y) across to (x, y).
        vertical = {"label": label} if label is not None else {}
        plt.plot((x, x), (0, y), linestyle="--", color=color, **vertical)
        plt.plot((0, x), (y, y), linestyle="--", color=color)

    plt.figure()
    epochs = range(1, step)
    # NOTE(review): column 1 is labelled "Train" and column 0 "Valid"; verify against the trainer.
    plt.plot(epochs, train_logs[1:step, 1].cpu().numpy(), label="Train")
    plt.plot(epochs, train_logs[1:step, 0].cpu().numpy(), label="Valid")
    _mark(best_logs[0].item(), best_logs[1].item(), "y", "Chosen Point")

    if config.rec_test:
        plt.plot(range(step), test_logs[0:step, 1].cpu().numpy(), label="Test")
        best_loss = test_logs[:, 1].argmin()
        _mark(test_logs[best_loss, 0].item(), test_logs[best_loss, 1].item(), "r", "Actual Best Loss")
        # Loss value at the epoch with the highest kappa.
        best_kappa = test_logs[:, 4].argmax()
        _mark(test_logs[best_kappa, 0].item(), test_logs[best_kappa, 1].item(), "g", "Actual Best Kappa")
    print(best_logs)
    plt.xlabel('Epochs')
    plt.ylabel('Loss Values')
    plt.title("Loss")
    plt.legend()
    plt.savefig(save_path)
    plt.show()
def sleep_plot_k(current_step, config, train_logs, test_logs, best_logs,
                 save_path="/users/sista/kkontras/Documents/Sleep_Project/data/2021_data/kappa.png"):
    """
    Plot Cohen's kappa curves, mark the chosen checkpoint (and, with ``config.rec_test``, the
    best test kappa), then save the figure to ``save_path`` and show it.

    Column meanings are inferred from usage only -- TODO confirm: ``train_logs[:, 6]`` and
    ``[:, 7]`` are drawn as Train/Valid kappa, ``best_logs[0]`` and ``test_logs[:, 0]`` as the
    x (epoch) position, and ``best_logs[4]`` and ``test_logs[:, 4]`` as kappa values.

    :param current_step: number of logged steps; nothing is plotted below 5.
    :param save_path: output file for the figure (defaults to the previously hard-coded path).
    """
    step = int(current_step)
    # Skip plotting until at least 5 steps have been logged.
    if step < 5:
        return

    def _mark(x, y, color, label=None):
        # Dashed drop-lines from (x, 0) up to (x, y) and from (0, y) across to (x, y).
        vertical = {"label": label} if label is not None else {}
        plt.plot((x, x), (0, y), linestyle="--", color=color, **vertical)
        plt.plot((0, x), (y, y), linestyle="--", color=color)

    plt.figure()
    epochs = range(1, step)
    plt.plot(epochs, train_logs[1:step, 6].cpu().numpy(), label="Train")
    plt.plot(epochs, train_logs[1:step, 7].cpu().numpy(), label="Valid")
    _mark(best_logs[0].item(), best_logs[4].item(), "y", "Chosen Point")

    if config.rec_test:
        plt.plot(range(step), test_logs[0:step, 4].cpu().numpy(), label="Test")
        best_test = test_logs[:, 4].argmax()
        _mark(test_logs[best_test, 0].item(), test_logs[best_test, 4].item(), "r", "Actual Best")

    plt.xlabel('Epochs')
    plt.ylabel("Cohen's Kappa")
    plt.title("Kappa")
    plt.legend()
    plt.savefig(save_path)
    plt.show()

In [None]:
# Load the experiment config, build the model it names, and restore the best checkpoint.
CONFIG_PATH = "./configs/shhs/fourier_transformer_eeg.json"
config = process_config(CONFIG_PATH)
config.test_batch = 512
config.print_statistics = False

device = "cuda:{}".format(config.gpu_device[0])
# Data loading is currently disabled; re-enable these two lines to run the evaluation calls below.
# dataloader = globals()[config.dataloader_class]
# data_loader = dataloader(config=config)
# The model class is looked up by name among the names imported at the top of the notebook.
model_class = globals()[config.model_class]
model = model_class([], channel=config.channel).to(device)
model = nn.DataParallel(model, device_ids=[torch.device(i) for i in config.gpu_device])
# map_location restores every tensor onto `device`, even if the checkpoint was saved on another GPU.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoint files.
checkpoint = torch.load(config.save_dir, map_location=device)
model.load_state_dict(checkpoint["best_model_state_dict"])

train_logs = checkpoint["train_logs"]
test_logs = checkpoint["test_logs"]
current_epoch = checkpoint["epoch"]
best_logs = checkpoint["best_logs"]

# test_acc, test_f1, test_k, test_auc, test_conf, test_perclass_f1, test_spec, test_sens = test(model, data_loader)
# plot_hypnogram(data_loader, 20, model, device)

In [1]:

# Display the training-log tensor loaded from the checkpoint (checkpoint["train_logs"]).
train_logs

/usr/local/opt/python/bin/python3.7
