In [None]:
%%capture
# Imports

# To import from the gauche package
import sys
sys.path.append('..')
sys.path.append('../benchmarks/')

import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import torch, torch_geometric

from gauche.dataloader import DataLoaderMP
from gauche.dataloader.data_utils import transform_data
from gpytorch_metrics import negative_log_predictive_density, mean_standardized_log_loss, quantile_coverage_error

import gpytorch
from botorch import fit_gpytorch_model
from rdkit.Chem import MolFromSmiles

import scipy.sparse as sp
from functools import lru_cache
import warnings

# Silence gpytorch's UserWarnings and NumericalWarnings so the benchmark output stays readable.
for _category in (UserWarning, gpytorch.utils.warnings.NumericalWarning):
    warnings.filterwarnings(action='ignore', category=_category, module=r'gpytorch')
del _category

In [None]:
class WLFeatureFunction(torch.nn.Module):
    """Weisfeiler-Lehman subtree features for a batch of graphs.

    Runs ``n_iter`` successive ``WLConv`` relabelling rounds, concatenates the
    (unnormalised) colour-count histogram produced after every round, and
    L2-normalises each row of the result.

    Args:
        n_iter: number of WL relabelling rounds (one ``WLConv`` per round).
    """

    def __init__(self, n_iter):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            torch_geometric.nn.WLConv() for _ in range(n_iter))

    def forward(self, graphs):
        """Return row-normalised, concatenated WL histograms as a float tensor.

        ``graphs`` must expose ``x`` (node labels; in this notebook a one-hot
        encoding), ``edge_index`` and ``batch`` (graph id of every node).
        Presumably one output row per graph in ``batch`` — TODO confirm against
        ``WLConv.histogram``.
        """
        edge_index, batch = graphs.edge_index, graphs.batch
        colours = graphs.x
        per_round = []
        for wl in self.convs:
            # Each round relabels nodes from the previous round's colours.
            colours = wl(colours, edge_index)
            per_round.append(wl.histogram(colours, batch, norm=False))

        features = torch.cat(per_round, dim=1).float()
        return torch.nn.functional.normalize(features, dim=1)

class GraphDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns benchmark SMILES into ``torch_geometric`` graphs.

    Molecules and labels come from ``DataLoaderMP.load_benchmark``. Each item is a
    ``torch_geometric.data.Data`` whose ``x`` is a NumPy string array of atom symbols
    (one-hot encoding is left to the caller) and whose ``edge_index`` lists every
    bond in both directions.

    Args:
        dataset_name: benchmark name passed to ``load_benchmark``.
        data_loc: path of the benchmark CSV passed to ``load_benchmark``.
    """

    def __init__(self, dataset_name, data_loc):
        super().__init__()
        self.loader = DataLoaderMP()
        self.loader.load_benchmark(dataset_name, data_loc)
        self.y = self.loader.labels

    @staticmethod
    def _build_edge_index(pairs, num_nodes):
        """Convert directed ``(src, dst)`` pairs into a ``(2, E)`` long tensor.

        A graph without any edge (a single atom, or atoms with no bonds between
        them) gets one self-loop per node; for a single atom this is exactly the
        ``[(0, 0)]`` edge list used previously. A graph with no nodes gets an
        empty ``(2, 0)`` tensor instead of failing to stack an empty list.
        """
        if not pairs:
            pairs = [(i, i) for i in range(num_nodes)]
        if not pairs:
            return torch.empty((2, 0), dtype=torch.long)
        # Explicit long dtype: NumPy's default integer width is platform dependent.
        return torch.tensor(pairs, dtype=torch.long).t().contiguous()

    def __getitem__(self, idx):
        """Return molecule ``idx`` as a graph of atom symbols and bidirectional bonds.

        Raises:
            ValueError: if RDKit cannot parse the SMILES string.
        """
        smiles = self.loader.features[idx]
        mol = MolFromSmiles(smiles)
        if mol is None:  # MolFromSmiles returns None for unparsable SMILES
            raise ValueError(f"RDKit could not parse SMILES at index {idx}: {smiles!r}")

        node_labels = np.array([
            mol.GetAtomWithIdx(i).GetSymbol()
            for i in range(mol.GetNumAtoms())
        ], dtype=str)

        # Each bond becomes two directed edges so both endpoints see each other.
        pairs = []
        for bond in mol.GetBonds():
            start_idx, end_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            pairs.append((start_idx, end_idx))
            pairs.append((end_idx, start_idx))

        return torch_geometric.data.Data(
                    x=node_labels,
                    edge_index=self._build_edge_index(pairs, len(node_labels)),
                )

    def __len__(self):
        return len(self.y)

class GraphGP(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a linear kernel on graph feature vectors.

    Args:
        train_x: training features, one row per graph.
        train_y: training targets.
        likelihood: gpytorch likelihood (a ``GaussianLikelihood`` in this notebook).
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean = gpytorch.means.ConstantMean()
        # NOTE(review): the positional argument is LinearKernel's ``num_dimensions``,
        # which newer gpytorch releases deprecate — confirm against the installed version.
        self.covariance = gpytorch.kernels.LinearKernel(len(train_x.T))
        # Keep the kernel's ``offset`` parameter fixed during training.
        self.covariance.offset.requires_grad_(False)

    def forward(self, x):
        # ConstantMean only uses x's leading (batch) shape, so passing x directly gives
        # one mean value per row without a temporary zeros tensor or a forced dtype cast.
        mean = self.mean(x)
        covariance = self.covariance(x)
        return gpytorch.distributions.MultivariateNormal(mean, covariance)

In [None]:
# Benchmark name -> CSV path handed to DataLoaderMP.load_benchmark.
datasets = {
    'Photoswitch': '../data/property_prediction/photoswitches.csv',
    'ESOL': '../data/property_prediction/ESOL.csv',
    'FreeSolv': '../data/property_prediction/FreeSolv.csv',
    'Lipophilicity': '../data/property_prediction/Lipophilicity.csv'
}

n_trials, test_set_size = 20, 0.2
n_wl_iter = 3       # WL relabelling rounds used to build the graph features
n_adam_steps = 150  # Adam iterations for GP hyperparameter fitting
adam_lr = 0.5       # Adam learning rate
metric_names = ['NLPD', 'MSLL', 'QCE', 'R^2', 'RMSE', 'MAE']


def _featurise(dataset_name, data_loc, n_iter):
    """Load one benchmark and return ``(X, y)``: WL histogram features and labels."""
    dataset = GraphDataset(dataset_name, data_loc)
    # Collate every molecule into a single batch so one set of WL layers labels them all.
    batch = next(iter(torch_geometric.loader.DataLoader(
        dataset, batch_size=len(dataset), shuffle=False)))
    # batch.x holds one atom-symbol array per molecule; flatten and one-hot encode.
    symbols = np.hstack(batch.x).reshape(-1, 1)
    batch.x = torch.tensor(OneHotEncoder().fit_transform(symbols).toarray()).float()
    X = WLFeatureFunction(n_iter=n_iter)(batch).numpy()
    return X, dataset.y


def _run_trial(X, y, seed):
    """Fit the GP on one seeded train/test split and return its test metrics.

    Returns a dict keyed by ``metric_names``. NLPD, MSLL and QCE are computed on the
    standardised targets; R^2, RMSE and MAE in the original label units.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_set_size, random_state=seed)

    # Standardise the outputs but leave the WL features unchanged (transform_data is
    # given dummy zero inputs). NOTE(review): the original comment said "this seems to
    # introduce numerical instabilities" — unclear whether that refers to input scaling.
    _, y_train, _, y_test, y_scaler = transform_data(
        np.zeros_like(y_train), y_train, np.zeros_like(y_test), y_test)

    y_train = torch.tensor(y_train).flatten().float()
    y_test = torch.tensor(y_test).flatten().float()
    X_train = torch.tensor(X_train).float()
    X_test = torch.tensor(X_test).float()

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = GraphGP(X_train, y_train, likelihood)

    # Maximise the exact marginal log likelihood with Adam. (botorch's
    # fit_gpytorch_model, imported at the top, is the L-BFGS-B alternative.)
    model.train()
    likelihood.train()
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=adam_lr)  # Includes GaussianLikelihood parameters
    for _ in range(n_adam_steps):
        optimizer.zero_grad()
        loss = -mll(model(X_train), y_train)
        loss.backward()
        optimizer.step()

    model.eval()
    likelihood.eval()
    # Evaluation needs no gradients; skip building autograd graphs.
    with torch.no_grad():
        f_pred = model(X_test)                  # latent posterior at the test inputs
        trained_pred_dist = likelihood(f_pred)  # predictive distribution incl. noise
        try:
            nlpd = negative_log_predictive_density(trained_pred_dist, y_test)
        except RuntimeError:
            # Cholesky failures (gpytorch's NotPSDError subclasses RuntimeError):
            # retry with a much larger jitter rather than swallowing every exception.
            with gpytorch.settings.cholesky_jitter(1e-1):
                nlpd = negative_log_predictive_density(trained_pred_dist, y_test)
        msll = mean_standardized_log_loss(trained_pred_dist, y_test)
        qce = quantile_coverage_error(trained_pred_dist, y_test, quantile=95)
        y_pred = f_pred.mean

    # Back to the original label units for the point-prediction metrics.
    y_pred = y_scaler.inverse_transform(y_pred.unsqueeze(dim=1).numpy())
    y_true = y_scaler.inverse_transform(y_test.unsqueeze(dim=1).numpy())
    return {
        'NLPD': float(nlpd),
        'MSLL': float(msll),
        'QCE': float(qce),
        'R^2': r2_score(y_true, y_pred),
        'RMSE': np.sqrt(mean_squared_error(y_true, y_pred)),
        'MAE': mean_absolute_error(y_true, y_pred),
    }


def _mean_sem(values):
    """Return the mean and standard error of ``values``.

    Uses the Bessel-corrected std (ddof=1) for every metric; previously NLPD/MSLL/QCE
    used ``torch.std`` (ddof=1) while R^2/RMSE/MAE used ``np.std`` (ddof=0).
    """
    values = np.asarray(values, dtype=float)
    return values.mean(), values.std(ddof=1) / np.sqrt(len(values))


for dataset_name, data_loc in datasets.items():
    X, y = _featurise(dataset_name, data_loc, n_wl_iter)

    print(dataset_name); print('-'*50)
    trials = [_run_trial(X, y, seed) for seed in range(n_trials)]

    print()
    for metric in metric_names:
        avg, sem = _mean_sem([trial[metric] for trial in trials])
        print("mean {}: {:.4f} +- {:.4f}".format(metric, avg, sem))
    print()
