In [1]:
import os
import pickle
import numpy as np
import yaml
import random
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from hmc_uq.utils.evaluation import HMCSampleEvaluation, PredictiveEvaluation
from scipy.signal import correlate

from torch.utils.data import DataLoader
from hmc_uq.utils.data import SparseDataset

import torch
import torch.nn.functional as F
import hamiltorch
from hmc_uq.utils.models import MLP

from sklearn.metrics import roc_auc_score



In [2]:
target = 'CYP'
target_id = 1908
nr_eval_params = 10000
init = 'bbb'
nr_samples = 10000
nr_chains = 5
step_size = 0.0013
l = 1000
num_input_features = 4096


### Load Data

In [3]:
folds = {'train':[2,3,4], 'val': [1], 'test': [0]}

ds_type = 'test'
burnin = 4000
fold = folds[ds_type ]

In [4]:
ckpt_path = 'configs/ckpt_paths/HMC.yaml'
with open(ckpt_path) as f:
    ckpt_paths = yaml.load(f, Loader=yaml.FullLoader)[f'{init}init']

models_config = 'configs/models/HMC.yaml'
with open(models_config) as f:
    models_configs = yaml.load(f, Loader=yaml.FullLoader)   

ckpt_paths = list(ckpt_paths.values())

In [5]:
device = 'cpu'
target_info = models_configs[target]
hidden_sizes = target_info['hidden_sizes']
nr_layers = target_info['nr_layers']
tau_list = [target_info['weight_decay']]
tau_list = torch.tensor(tau_list).to(device)
tau_out = target_info['tau_out']
dropout = 0
model_loss = 'binary_class_linear_output'

In [6]:
target_info = models_configs[target]
hidden_sizes = target_info['hidden_sizes']
nr_layers = target_info['nr_layers']
#l = target_info['L']
#step_size = target_info['step_size']

parameter_sizes = [4096 * hidden_sizes, hidden_sizes]
for layer in range(nr_layers - 1):
    parameter_sizes.append(hidden_sizes * hidden_sizes, hidden_sizes)
parameter_sizes.extend([hidden_sizes, 1])


param_names = []
for i in range(nr_layers + 1):
    for t in ['weight', 'bias']:
        param_names.append(f'{t}.{i}')

In [7]:
X_singleTask = np.load('data/chembl_29/chembl_29_X.npy', allow_pickle=True).item().tocsr()
Y_singleTask = np.load('data/chembl_29/chembl_29_thresh.npy', allow_pickle=True).item().tocsr()[:,target_id]
folding=np.load('data/chembl_29/folding.npy')

dataset = SparseDataset(X_singleTask, Y_singleTask, folding, fold, device)
dataloader = DataLoader(dataset, batch_size=200, shuffle=True)

fp, labels = dataset.__getdatasets__()
labels = labels.squeeze(dim = 1)

# Predict
### *if no prediction, yet*

### Load Samples and predict

In [10]:
def load_params(burnin, chain, ckpt_paths, parameter_sizes, return_layer = True):
    params = np.load(f'{ckpt_paths[chain]}.npy')[:, burnin:10000]
    
    if return_layer:
        cumsum = np.cumsum(parameter_sizes)[:-1]
        params= np.array_split(params, cumsum, axis = 2)
        return params
    else:
        return params

In [19]:
preds_chains = []
for chain in range(nr_chains):
    params = load_params(4000, chain, ckpt_paths, parameter_sizes, return_layer=False)[0]
    net = MLP(
        input_features=num_input_features, 
        output_features=1,
        nr_layers=nr_layers,
        hidden_sizes=hidden_sizes,          
        dropout=dropout
        )

    params_torch = torch.unbind(torch.from_numpy(params))
    
    preds, _ = hamiltorch.predict_model(net, test_loader = dataloader, samples=params_torch, model_loss=model_loss, tau_out=tau_out, tau_list=tau_list)
    preds = torch.squeeze(preds, 2)  
    preds_chains.append(preds)

preds_chains = torch.stack(preds_chains)
preds_chains = F.sigmoid(preds_chains)
preds_chains = np.reshape(preds_chains, (-1, len(labels))).mean(axis = 0)

In [49]:
file_name = f'results/predictions/HMC/{target_id}_e{step_size}_l{l}_nrs{nr_samples}_nrc{nr_chains}_{init}init_{ds_type}'
np.save(file_name, preds_chains.numpy())


# Load Prediction
### *if prediction already exists*

In [8]:
file_name = f'results/predictions/HMC/{target_id}_e{step_size}_l{l}_nrs{nr_samples}_nrc{nr_chains}_{init}init_{ds_type}.npy'
preds_chains = np.load(file_name)

# Evaluate

#### 1. AUC

In [9]:
eval = PredictiveEvaluation(preds_chains, labels, ds_type)

In [10]:
results = eval.evaluate()
print(results)


        auc        nll       ece       ace        bs
0  0.495606  52.312923  0.276079  0.276079  0.276213
