In [None]:
import numpy as np
import sparsechem as sc
import torch
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import spearmanr, pearsonr
import matplotlib.pyplot as plt

In [None]:
# Input locations -- replace each placeholder below with a full path.
weights_path = 'path to reg_weights.csv'
y_true_path = 'path to reg_T10_y.npz'
y_mask_path = 'path to /reg_T10_censor_y.npz'
folding_path = 'path to reg_T11_fold_vector.npy'

# Prediction files; each one is loaded as a separate ensemble member.
paths = [
    'cp464*****/pred/pred.json',
    'cp465***pred/pred.json',
    'cp568***/pred/pred.json',
    'cp569****/pred/pred.json',
    'cp570*****/pred/pred.json',
]

In [None]:
# Read the task weights table and keep the ids of the tasks whose
# aggregation weight equals 1; only those tasks are analysed below.
weights = pd.read_csv(weights_path)
has_unit_weight = weights['aggregation_weight'] == 1
tasks_to_consider = weights.loc[has_unit_weight, 'task_id'].values

In [None]:
# Load the label matrix, the censoring mask and the fold assignment, then
# keep only the rows assigned to fold 4 (used here as the validation fold).
y_all_folds = sc.load_sparse(y_true_path)
y_mask_all_folds = sc.load_sparse(y_mask_path)
folding = np.load(folding_path)
in_validation_fold = folding == 4
y = y_all_folds[in_validation_fold, :]
y_mask = y_mask_all_folds[in_validation_fold, :]

In [None]:
# Load the predictions of every ensemble member, cast to float64 and
# converted to CSR so that single task columns can be sliced out later.
# NOTE(review): torch.load unpickles the file contents -- only point `paths`
# at prediction files that come from a trusted source.
y_ens = [
    torch.load(pred_path).astype('float64').tocsr()
    for pred_path in paths
]

In [None]:
# For every selected task, compare the spread of the ensemble (std across
# members) with the absolute error of the ensemble mean, and record how well
# the two correlate together with the R2 of the mean prediction.
MIN_SAMPLES = 2  # pearsonr raises ValueError when given fewer than 2 points

summary = []
skipped_tasks = []  # tasks left out because too few data points remained
for task in tasks_to_consider:
    labels = y[:, task].data
    mask = y_mask[:, task].data
    # NOTE(review): the filtering below assumes that the label, mask and
    # prediction matrices store the same entries in the same order within a
    # task column -- confirm against how these files are produced.
    keep = mask == 0  # presumably 0 marks an uncensored value -- TODO confirm
    labels = labels[keep]

    # Without this guard a single sparse task aborts the whole loop, because
    # the correlation cannot be computed on fewer than two points.
    if labels.size < MIN_SAMPLES:
        skipped_tasks.append(task)
        continue

    # One row per ensemble member, one column per retained data point.
    ens_pred = np.vstack([yhat[:, task].data[keep] for yhat in y_ens])
    ens_means = ens_pred.mean(axis=0)
    ens_std = ens_pred.std(axis=0)
    ens_abs_error = np.abs(ens_means - labels)
    summary.append({
        'task': task,
        'spearman': spearmanr(ens_abs_error, ens_std)[0],
        'pearson': pearsonr(ens_abs_error, ens_std)[0],
        'R2': r2_score(labels, ens_means)
    })

In [None]:
# Turn the list of per-task result dicts into a table with one row per task.
summary = pd.DataFrame.from_records(data=summary)

In [None]:
# At this point `summary` is a dataframe with one row per task,
# holding that task's R2 and its Spearman and Pearson correlations.
# The cells below are just simple plots based on it.

In [None]:
# Distribution over tasks of the Spearman correlation between the ensemble
# standard deviation and the absolute error of the ensemble mean.
plt.hist(summary.spearman.values, bins=20)
plt.xlabel('spearman')
plt.ylabel('number of tasks')
plt.title('Spearman: ensemble std vs. absolute error');

In [None]:
# Per-task R2 of the ensemble mean plotted against the per-task Spearman
# correlation.
r2_values = summary['R2'].values
spearman_values = summary['spearman'].values
plt.scatter(r2_values, spearman_values)
plt.xlabel('r2')
plt.ylabel('spearman')

In [None]:
# Same Spearman histogram, restricted to the tasks whose ensemble-mean R2
# exceeds the threshold below.
R2_THRESHOLD = 0.5
high_r2_tasks = summary[summary.R2 > R2_THRESHOLD]
plt.hist(high_r2_tasks.spearman.values, bins=20)
plt.xlabel('spearman')
plt.ylabel('number of tasks')
plt.title('Spearman: ensemble std vs. absolute error (R2 > {})'.format(R2_THRESHOLD));

In [None]:
# Distribution over tasks of the Pearson correlation between the ensemble
# standard deviation and the absolute error of the ensemble mean.
plt.hist(summary.pearson.values, bins=20)
plt.xlabel('pearson')
plt.ylabel('number of tasks')
plt.title('Pearson: ensemble std vs. absolute error');