In [1]:
import os
from argparse import ArgumentParser
from shutil import copyfile
from pytorch_lightning.utilities.apply_func import move_data_to_device

import torch
from flask import Flask, render_template
from tqdm.auto import tqdm


def get_best_ckpt(folder):
    tmp = [el[:-5].split("-") for el in sorted(os.listdir(folder))]
    tmp = sorted(
        tmp,
        key=lambda x: (
            float(x[1].split("=")[1]),
            float(x[2].split("=")[1]),
            float(x[3].split("=")[1]),
        ),
        reverse=True,
    )
    return os.path.join(folder, "-".join(tmp[0]) + ".ckpt")


def get_50_intersection_metric(y_pred, y_true, threshold = 0.5):
    abc = torch.logical_and(y_pred, y_true).sum()
#     print(abc, y_true.sum())
    return abc > threshold * y_true.sum()


def get_preds(nets, dataloader):
    print("Running through dataset")
    outs = []
    for idx, batch in enumerate(tqdm(dataloader)):
        data, meta = move_data_to_device(batch, nets[0].device)
        y_preds = []
        for i, net in enumerate(nets):
            y_pred, y_true = net(data["feature"], meta["length"]), data["label"]
            y_preds.append((torch.sigmoid(y_pred) > 0.5).float())
        y_pred = torch.mean(torch.stack(y_preds), dim=0)
        for i in range(len(y_true)):
            outs.append(get_50_intersection_metric(y_pred[i], y_true[i]))
    return outs


In [2]:
from pbsp.metrics import batch_metrics, batch_work
from pbsp.net import Net

data_dir = os.path.abspath(os.path.expanduser("../data"))
ckpt_dir = os.path.abspath(os.path.expanduser("~/logs/cv_0"))
nets = []
metrics = []
for i in range(10):
    ckpt = get_best_ckpt(os.path.join(ckpt_dir, "fold_" + str(i), "checkpoints"))
    print(ckpt)
    nets.append(Net.load_from_checkpoint(ckpt).cuda())
    device = nets[i].device
    nets[i].freeze()
    nets[i].eval()
    # break

/home/crvineeth97/logs/cv_0/fold_0/checkpoints/epoch=73-v_mcc=0.553-v_acc=0.926-v_f1=0.579.ckpt
Loading pssm
Loading ss2
Loading solv
Positional weights will be computed on the fly
Loading pssm
Loading ss2
Loading solv
/home/crvineeth97/logs/cv_0/fold_1/checkpoints/epoch=54-v_mcc=0.533-v_acc=0.923-v_f1=0.562.ckpt
Loading pssm
Loading ss2
Loading solv
Positional weights will be computed on the fly
Loading pssm
Loading ss2
Loading solv
/home/crvineeth97/logs/cv_0/fold_2/checkpoints/epoch=58-v_mcc=0.534-v_acc=0.931-v_f1=0.555.ckpt
Loading pssm
Loading ss2
Loading solv
Positional weights will be computed on the fly
Loading pssm
Loading ss2
Loading solv
/home/crvineeth97/logs/cv_0/fold_3/checkpoints/epoch=52-v_mcc=0.482-v_acc=0.911-v_f1=0.511.ckpt
Loading pssm
Loading ss2
Loading solv
Positional weights will be computed on the fly
Loading pssm
Loading ss2
Loading solv
/home/crvineeth97/logs/cv_0/fold_4/checkpoints/epoch=55-v_mcc=0.547-v_acc=0.918-v_f1=0.573.ckpt
Loading pssm
Loading ss2
Loa

In [4]:
validations = []
for i in range(10):
    validations.append(get_preds([nets[i]], nets[i].val_dataloader()))

Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


Running through dataset


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




In [5]:
for validation in validations:
    valid = torch.stack(validation)
    print(valid.sum() * 1.0 / len(valid) * 100.0)

tensor(72.6986, device='cuda:0')
tensor(71.6898, device='cuda:0')
tensor(72.4464, device='cuda:0')
tensor(71.6267, device='cuda:0')
tensor(75.9773, device='cuda:0')
tensor(76.1034, device='cuda:0')
tensor(78.6255, device='cuda:0')
tensor(75.5359, device='cuda:0')
tensor(75.2837, device='cuda:0')
tensor(74.6532, device='cuda:0')


In [6]:
test = get_preds(nets, nets[0].test_dataloader())

Running through dataset


HBox(children=(FloatProgress(value=0.0, max=77.0), HTML(value='')))




In [7]:
testy = torch.stack(test)
print(testy.sum() * 1.0 / len(testy) * 100.0)

tensor(86.4854, device='cuda:0')
