In [None]:
from torchmetrics import AUROC
from torchmetrics.classification import BinaryAveragePrecision

import os
from os.path import basename, dirname, join, exists
import pickle
import time
import glob
import argparse
from tqdm import tqdm
import torch
from torchmdnet import datasets, attention_weights
from torchmdnet.models.model import load_model
from torchmdnet.utils import make_splits
from torchmdnet.data import Subset
from torch_geometric.data import DataLoader
from torch_scatter import scatter
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from torchmetrics.functional.classification import binary_average_precision
from torch.nn.functional import mse_loss, l1_loss

def rmse(pred, target):
    """Root-mean-squared error between ``pred`` and ``target``.

    Computed as the square root of ``torch.nn.functional.mse_loss`` with its
    default mean reduction, so the result is a 0-d tensor.
    """
    mean_squared_error = mse_loss(pred, target)
    return mean_squared_error.sqrt()

def null_model(predicted, ground_truth, final_test=False):
    """Mean ROC-AUC of an uninformative baseline that scores every sample 1.0.

    For each task (column) of ``ground_truth`` that contains both a 0 and a 1,
    a constant score vector is fed to torchmetrics' binary AUROC. This gives a
    chance-level reference against which the model's ROC-AUC can be compared.

    Args:
        predicted: model scores; unused, accepted so the signature matches the
            other metric helpers in this notebook.
        ground_truth: 2-D label tensor indexed ``[sample, task]`` (0/1, with
            -100 marking entries that torchmetrics should ignore).
        final_test: if True, also return the per-task scores.

    Returns:
        The mean baseline ROC-AUC over the scored tasks, or NaN if no task
        contains both classes. With ``final_test=True``, a tuple
        ``(per_task, mean)`` where ``per_task`` maps every task index to its
        score as a float (NaN for skipped tasks).
    """
    n_tasks = ground_truth.shape[1]
    ground_truth_np = ground_truth.cpu().numpy()
    per_task = {i: float("nan") for i in range(n_tasks)}
    scores = []
    for i in range(n_tasks):
        column = ground_truth_np[:, i]
        # ROC-AUC is undefined unless both classes are present.
        if not (np.any(column == 0) and np.any(column == 1)):
            continue
        auroc = AUROC(task="binary", ignore_index=-100)
        # Constant score for every sample: same dtype/device as the original
        # `zeros_like(...) + 1.` construction.
        constant_scores = torch.ones_like(ground_truth[:, i], dtype=torch.float)
        score = auroc(constant_scores, ground_truth[:, i])
        scores.append(score)
        # Record this task's score (previously the whole running list was stored).
        per_task[i] = float(score)
    # Avoid ZeroDivisionError when no task could be scored.
    mean = sum(scores) / len(scores) if scores else float("nan")
    if final_test:
        return per_task, mean
    return mean


def null_model_pr(predicted, ground_truth, final_test=False):
    """Mean PR-AUC (average precision) of a baseline that scores every sample 1.0.

    For each task (column) of ``ground_truth`` that contains both a 0 and a 1,
    a constant score vector is fed to torchmetrics' ``BinaryAveragePrecision``.
    For a constant scorer the result is presumably the task's positive rate,
    which makes it a reference for the model's PR-AUC.

    Args:
        predicted: model scores; unused, accepted so the signature matches the
            other metric helpers in this notebook.
        ground_truth: 2-D label tensor indexed ``[sample, task]`` (0/1, with
            -100 marking entries that torchmetrics should ignore).
        final_test: if True, also return the per-task scores.

    Returns:
        The mean baseline average precision over the scored tasks, or NaN if
        no task contains both classes. With ``final_test=True``, a tuple
        ``(per_task, mean)`` where ``per_task`` maps every task index to its
        score as a float (NaN for skipped tasks).
    """
    n_tasks = ground_truth.shape[1]
    ground_truth_np = ground_truth.cpu().numpy()
    per_task = {i: float("nan") for i in range(n_tasks)}
    scores = []
    for i in range(n_tasks):
        column = ground_truth_np[:, i]
        # Skip single-class tasks, mirroring the ROC-AUC helpers.
        if not (np.any(column == 0) and np.any(column == 1)):
            continue
        average_precision = BinaryAveragePrecision(threshold=None, ignore_index=-100)
        constant_scores = torch.ones_like(ground_truth[:, i], dtype=torch.float)
        score = average_precision(constant_scores, ground_truth[:, i])
        scores.append(score)
        # Record this task's score (previously the whole running list was stored).
        per_task[i] = float(score)
    # Avoid ZeroDivisionError when no task could be scored.
    mean = sum(scores) / len(scores) if scores else float("nan")
    if final_test:
        return per_task, mean
    return mean

def multitask_prauc(predicted, ground_truth, final_test=False):
    """Mean PR-AUC (average precision) over the tasks of a multi-task binary problem.

    Only tasks whose ground-truth column contains at least one 0 and one 1 are
    scored. Entries equal to -100 are passed to torchmetrics as
    ``ignore_index`` and excluded from the metric.

    Args:
        predicted: 2-D score tensor indexed ``[sample, task]``.
        ground_truth: 2-D label tensor with the same layout (0/1, or -100 to
            ignore an entry).
        final_test: if True, also return the per-task scores.

    Returns:
        The mean average precision over the scored tasks, or NaN if no task
        contains both classes. With ``final_test=True``, a tuple
        ``(per_task, mean)`` where ``per_task`` maps every task index to its
        score as a float (NaN for skipped tasks).
    """
    n_tasks = ground_truth.shape[1]
    ground_truth_np = ground_truth.cpu().numpy()
    per_task = {i: float("nan") for i in range(n_tasks)}
    scores = []
    for i in range(n_tasks):
        column = ground_truth_np[:, i]
        # Single-class tasks cannot be scored meaningfully; skip them.
        if not (np.any(column == 0) and np.any(column == 1)):
            continue
        average_precision = BinaryAveragePrecision(threshold=None, ignore_index=-100)
        score = average_precision(predicted[:, i], ground_truth[:, i])
        scores.append(score)
        # Record this task's score (previously the whole running list was stored).
        per_task[i] = float(score)
    # Avoid ZeroDivisionError when no task could be scored.
    mean = sum(scores) / len(scores) if scores else float("nan")
    if final_test:
        return per_task, mean
    return mean

def multitask_auc(predicted, ground_truth, final_test=False):
    """Mean ROC-AUC over the tasks of a multi-task binary problem.

    Only tasks whose ground-truth column contains at least one 0 and one 1 are
    scored, because ROC-AUC is undefined for a single-class column. Entries
    equal to -100 are passed to torchmetrics as ``ignore_index`` and excluded.

    Args:
        predicted: 2-D score tensor indexed ``[sample, task]``.
        ground_truth: 2-D label tensor with the same layout (0/1, or -100 to
            ignore an entry).
        final_test: if True, also return the per-task scores.

    Returns:
        The mean ROC-AUC over the scored tasks, or NaN if no task contains
        both classes. With ``final_test=True``, a tuple ``(per_task, mean)``
        where ``per_task`` maps every task index to its score as a float
        (NaN for skipped tasks).
    """
    n_tasks = ground_truth.shape[1]
    ground_truth_np = ground_truth.cpu().numpy()
    per_task = {i: float("nan") for i in range(n_tasks)}
    scores = []
    for i in range(n_tasks):
        column = ground_truth_np[:, i]
        # ROC-AUC is undefined unless both classes are present.
        if not (np.any(column == 0) and np.any(column == 1)):
            continue
        auroc = AUROC(task="binary", ignore_index=-100)
        score = auroc(predicted[:, i], ground_truth[:, i])
        scores.append(score)
        # Record this task's score (previously the whole running list was stored).
        per_task[i] = float(score)
    # Avoid ZeroDivisionError when no task could be scored.
    mean = sum(scores) / len(scores) if scores else float("nan")
    if final_test:
        return per_task, mean
    return mean

# --- Configuration ----------------------------------------------------------
# Endpoint to evaluate; it also selects the checkpoint and the split file below.
dataset_name = "dili"

model_path = f"./models/{dataset_name}.ckpt"

dataset = "TDCTox"
dataset_arg = {"num_conformers": 1, "conformer": "best", "dataset": dataset_name}
dataset_root = "./data/TDCTox"
# Not read below; the split actually used is whatever `splits_path` contains.
dataset_split = "scaffold"
splits_path = f"./data/TDCTox/splits/{dataset_name}_split_1_scaffold.npz"

# Alternative configuration for MoleculeNet datasets:
#dataset = "MoleculeNet"
#dataset_arg = {"num_conformers": 1, "conformer": "best", "data_version": "geom", "dataset": dataset_name}
#dataset_root = "./data/MoleculeNet"
#dataset_split = "scaffold"
#splits_path = f"./data/MoleculeNet/splits/{dataset_name}_seed1_confs1_scaffold.npz"

# Number of worker processes for the test-set DataLoader.
num_workers = 6

# --- Load the test split ----------------------------------------------------
# The dataset class is looked up by name in `torchmdnet.datasets`.
full_dataset = getattr(datasets, dataset)(dataset_root, **dataset_arg)

# Read the test indices and close the .npz archive right away.
with np.load(splits_path) as splits:
    idx_test = splits["idx_test"]
data = DataLoader(
    torch.utils.data.Subset(full_dataset, idx_test),
    batch_size=1,
    num_workers=num_workers,
)

# --- Load the model ---------------------------------------------------------
print("loading model")
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = load_model(model_path, device=device).eval()





In [None]:
# Run the model over the test set and collect predictions and targets.
labels = []
preds_test = []
targets_test = []
energies = []


for batch in tqdm(data):
    labels.append(batch.tox_labels.numpy())
    # Extending a list with a tensor appends its slices along the first dimension.
    targets_test.extend(batch.tox_labels.cpu())
    energies.append(batch.y)

    # Use a separate name for `batch.batch` so the loop variable `batch`
    # is not shadowed.
    z = batch.z.to(device)
    pos = batch.pos.to(device)
    batch_index = batch.batch.to(device)
    Q = batch.Q.to(device)
    pred, deriv = model(z, pos, batch_index, Q=Q)
    preds_test.extend(pred.detach().cpu())

targets = torch.stack(targets_test)
preds = torch.stack(preds_test)

if dataset_name == "ld50":
    # Regression endpoint: compare the model with a constant-prediction baseline.
    # NOTE(review): 2.54 is presumably the mean training-set label -- confirm.
    ld50_null_prediction = 2.54
    # Match the targets' shape exactly so mse_loss cannot silently broadcast.
    null_preds = torch.full(targets.shape, ld50_null_prediction)
    print(f"The MAE is {l1_loss(preds, targets)}")
    print(f"The RMSE is {rmse(preds, targets)}")
    print(f"The null-model RMSE is {rmse(null_preds, targets)}")
else:
    # Classification endpoints: model metrics and constant-score baselines.
    aucs, mean_auc = multitask_auc(preds, targets.long(), final_test=True)
    aucs_pr, mean_auc_pr = multitask_prauc(preds, targets.long(), final_test=True)
    null_aucs, mean_null_aucs = null_model(preds, targets.long(), final_test=True)
    null_aucs_pr, mean_null_aucs_pr = null_model_pr(preds, targets.long(), final_test=True)

    print(f"\n The ROC-AUC is: {mean_auc}\n")
    print(f"\n The PR-AUC is: {mean_auc_pr}\n")
    print(f"\n The null model is: {mean_null_aucs}\n")
    print(f"\n The PR null model is: {mean_null_aucs_pr}\n")
    
    
    