In [10]:
import itertools
import math
import os
import pickle
import shutil
import sys

import gpytorch
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
import torch
from gpytorch.kernels import ScaleKernel, RBFKernel, MaternKernel
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score, recall_score, roc_auc_score, matthews_corrcoef, balanced_accuracy_score, confusion_matrix, f1_score, roc_curve,precision_recall_curve, auc
from sklearn.metrics import accuracy_score, precision_score, f1_score, roc_auc_score, roc_curve, precision_recall_curve, auc, recall_score, confusion_matrix
from sklearn.model_selection import KFold
from sklearn.model_selection import RandomizedSearchCV
# from scipy.stats

# Local helper modules live next to the notebooks; the path must be registered before importing them.
sys.path.append('/Users/radhi/Desktop/GitHub/atom2024/atom2024/notebooks/')
from RF_GSCV import * # RF_GSCV contains the calculate metrics function to get the TP, TN, FP, FN scores 
from RF_atomver import prediction_type 



In [11]:

class DirichletGPModel(ExactGP):
    """
    A Dirichlet Gaussian Process (GP) model for multi-class classification.
    It extends the ExactGP class from GPyTorch and keeps one batched mean/kernel per class.

    Attributes:
        mean_module (gpytorch.means.ConstantMean): Constant mean with batch shape ``(num_classes,)``.
        covar_module (gpytorch.kernels.ScaleKernel): Scaled Matern (nu=0.5) or RBF kernel with
            batch shape ``(num_classes,)``, selected by ``kernal``.

    Args:
        train_x (torch.Tensor): Training data features.
        train_y (torch.Tensor): Training targets passed through to ExactGP.
        likelihood (gpytorch.likelihoods.Likelihood): The likelihood function.
        num_classes (int): The number of classes for the classification task.
        kernal (str): Kernel selector, either ``'matern'`` or ``'RBF'``.

    Raises:
        ValueError: If ``kernal`` is not one of the supported names.
    """
    def __init__(self, train_x, train_y, likelihood, num_classes,kernal):
        # Fail fast: previously an unknown kernel only printed 'invalid' and left
        # covar_module undefined, which surfaced later as an AttributeError in forward().
        if kernal not in ('matern', 'RBF'):
            raise ValueError(
                f"invalid kernel {kernal!r}; expected 'matern' or 'RBF'"
            )
        super(DirichletGPModel, self).__init__(train_x, train_y, likelihood)
        batch_shape = torch.Size((num_classes,))
        self.mean_module = ConstantMean(batch_shape=batch_shape)
        if kernal == 'matern':
            base_kernel = MaternKernel(nu=0.5, batch_shape=batch_shape)
        else:
            base_kernel = RBFKernel(batch_shape=batch_shape)
        self.covar_module = ScaleKernel(base_kernel, batch_shape=batch_shape)

    def forward(self, x):
        """
        Evaluate the mean and covariance modules at ``x``.

        Args:
            x (torch.Tensor): Input data features.
        Returns:
            gpytorch.distributions.MultivariateNormal: Distribution built from the mean and
            covariance modules evaluated at ``x``.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


In [12]:
class Trainer: 
    """
    Train, query and score a GP classifier together with its likelihood.

    Args:
        model: The GP model to optimise (its ``parameters()`` are handed to Adam).
        likelihood: The likelihood paired with ``model``; must expose ``transformed_targets``.
        iterations (int): Number of optimisation steps (forced to 2 when the ``CI``
            environment variable is set, as a smoke test).
    """
    def __init__(self,model, likelihood, iterations): 
        self.model = model
        self.likelihood = likelihood 
        smoke_test = ('CI' in os.environ)
        self.n_iterations = 2 if smoke_test else iterations
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
        self.loss_fn = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)
        
    def train(self, train_x, train_y): 
        """
        Run ``n_iterations`` Adam steps on the negative marginal log likelihood.

        Args:
            train_x (torch.Tensor): Training features fed to the model.
            train_y: Unused; the targets are taken from ``self.likelihood.transformed_targets``.
                Kept for interface compatibility.
        """
        self.model.train()
        self.likelihood.train()
        for i in range(self.n_iterations): 
            self.optimizer.zero_grad()
            output = self.model(train_x)
            loss = -self.loss_fn(output, self.likelihood.transformed_targets).sum()
            loss.backward()
            # Progress line every 10th iteration.
            if (i%10==0): 
                print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
                    i + 1, self.n_iterations, loss.item(),
                    self.model.covar_module.base_kernel.lengthscale.mean().item(),
                    self.model.likelihood.second_noise_covar.noise.mean().item()
                ))
             
            self.optimizer.step() 

    def predict(self, input): 
        """
        Make predictions using the GP model.

        Args:
            input (torch.Tensor): The input data for making predictions.
        
        Returns:
            dist (gpytorch.distributions.MultivariateNormal): The model output distribution.
            observed_pred (gpytorch.distributions.MultivariateNormal): ``dist`` passed through the likelihood.
            pred_means (torch.Tensor): The means (``loc``) of ``dist``.
            class_pred (torch.Tensor): Index of the largest mean along dimension 0 (predicted class).
        """
        self.model.eval()
        self.likelihood.eval()

        with gpytorch.settings.fast_pred_var(), torch.no_grad():
            # One forward pass is enough; the distribution is reused for every returned value.
            dist = self.model(input)
            pred_means = dist.loc
            observed_pred = self.likelihood(dist)
            class_pred = pred_means.max(0)[1]
            
        return dist, observed_pred, pred_means, class_pred
    

    def evaluate(self, x_input, y_true): 
        """
        Return hard class predictions for ``x_input``.

        The model is switched to eval mode here so the result does not depend on the
        caller having called ``predict`` first.

        Args:
            x_input (torch.Tensor): The input data features.
            y_true: Unused; kept for interface compatibility.
        
        Returns:
            y_pred (numpy.ndarray): The predicted class labels.
        """
        self.model.eval()
        self.likelihood.eval()
        with torch.no_grad():
            y_pred = self.model(x_input).loc.max(0)[1].numpy()
        
        return y_pred
    
    @staticmethod
    def calculate_metrics(y_true, y_pred): 
        """
        Count confusion-matrix cells for binary labels encoded as 0 and 1.

        Args:
            y_true: True labels (any sequence accepted by ``pd.Series``).
            y_pred: Predicted labels, aligned with ``y_true``.

        Returns:
            tuple: ``(tp, tn, fp, fn)``.
        """
        y_true = pd.Series(y_true) if not isinstance(y_true, pd.Series) else y_true
        y_pred = pd.Series(y_pred) if not isinstance(y_pred, pd.Series) else y_pred
        
        tp = np.sum((y_true == 1) & (y_pred == 1))
        tn = np.sum((y_true == 0) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        
        return tp, tn, fp, fn

    def gp_results(self, x_input, y_true, plot_title=None): 
        """
        Calculate evaluation metrics and print results.

        Args:
            x_input (torch.Tensor): The input data features.
            y_true (torch.Tensor or numpy.ndarray): The true labels for the input data.
            plot_title (str, optional): Currently unused (the confusion-matrix plot is disabled).
        
        Returns:
            dict: Keys ``accuracy``, ``precision``, ``recall``, ``specificity``, ``f1``,
            ``ROC_AUC``, ``MCC``, ``balanced_accuracy``, ``cm`` (flattened confusion matrix)
            and the counts ``TN``, ``FN``, ``FP``, ``TP``.
        """
        y_pred = self.evaluate(x_input, y_true) 
        if isinstance(y_true, torch.Tensor):
            y_true = y_true.numpy().reshape(-1)
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        # Use the class's own counter instead of relying on a same-named global from a star import.
        tp, tn, fp, fn = self.calculate_metrics(y_true, y_pred) 
       
        # Guard the empty-negatives case (0/0) explicitly instead of relying on a numpy warning.
        specificity = tn / (tn + fp) if (tn + fp) else float('nan')
        cm = confusion_matrix(y_true, y_pred)
        cm_flattened = cm.flatten().tolist()
        f1 = f1_score(y_true,y_pred)
        # NOTE: ROC AUC is computed from hard class labels, not from predicted probabilities.
        roc_auc = roc_auc_score(y_true,y_pred)
        mcc = matthews_corrcoef(y_true,y_pred)
        bal_acc = balanced_accuracy_score(y_true,y_pred)
        print(f'accuracy: {accuracy:.4f}, precision: {precision:.4f}, recall: {recall:.4f}, specificity: {specificity:.4f}, cm: {cm}')
        return {'accuracy': accuracy, 'precision': precision,  'recall':recall, 'specificity':specificity, 
                'f1':f1,'ROC_AUC': roc_auc,'MCC': mcc,'balanced_accuracy': bal_acc,'cm': cm_flattened,
                'TN': tn, 'FN': fn, 'FP': fp, 'TP': tp }

       

In [13]:
def make_torch_tens_float(filepath, filename): 
    """
    Load the train/test CSV splits of one dataset as float32 tensors.

    The four files are located by plain string concatenation:
    ``filepath + filename + suffix`` with the suffixes ``_trainX.csv``,
    ``_train_y.csv``, ``_testX.csv`` and ``_test_y.csv``.

    Args:
        filepath (str): Directory prefix (concatenated as-is, so include the trailing separator).
        filename (str): Dataset root name shared by the four files.

    Returns:
        tuple: ``(trainX, trainy, testX, testy)`` as ``torch.float32`` tensors; the
        label frames are flattened to 1-D.
    """
    prefix = filepath + filename
    suffixes = ('_trainX.csv', '_train_y.csv', '_testX.csv', '_test_y.csv')
    # Read every file first (same order as listed) before any conversion happens.
    x_train_df, y_train_df, x_test_df, y_test_df = [
        pd.read_csv(prefix + suffix) for suffix in suffixes
    ]

    def _to_float_tensor(array):
        # Values go through float64 first, then are stored as float32.
        return torch.as_tensor(array.astype("double"), dtype=torch.float32)

    trainX = _to_float_tensor(x_train_df.to_numpy())
    testX = _to_float_tensor(x_test_df.to_numpy())
    trainy = _to_float_tensor(y_train_df.to_numpy().flatten())
    testy = _to_float_tensor(y_test_df.to_numpy().flatten())
    return trainX, trainy, testX, testy

In [14]:

def save_results(trainX, trainy, testX, testy, root_name, kernal, n_iterations=300, n_samples=100):
    """
    Train a Dirichlet Gaussian Process model and collect train/test performance tables.

    The model is trained on the training data, then both subsets are predicted and scored.
    Each returned DataFrame holds per-row predictions plus the subset-level metrics from
    ``Trainer.gp_results`` repeated on every row. The per-class columns are written for
    class 0 and class 1 only.

    Args:
        trainX (torch.Tensor): The training data features.
        trainy (torch.Tensor): The training data labels.
        testX (torch.Tensor): The test data features.
        testy (torch.Tensor): The test data labels.
        root_name (str): Root name of the dataset; currently unused inside this function.
        kernal (str): Kernel selector forwarded to ``DirichletGPModel`` (``'RBF'`` or ``'matern'``).
        n_iterations (int, optional): The number of training iterations. Default is 300.
        n_samples (int, optional): Number of samples drawn from each predictive distribution
            to estimate class probabilities and their standard deviations. Default is 100.

    Returns:
        train_perf_df (pd.DataFrame): Predictions and metrics for the training data.
        test_perf_df (pd.DataFrame): Predictions and metrics for the test data.
        model (DirichletGPModel): The trained model.
        likelihood (DirichletClassificationLikelihood): The likelihood used for training.
    """
    likelihood = DirichletClassificationLikelihood(trainy.long(), learn_additional_noise=True)
    model = DirichletGPModel(trainX, likelihood.transformed_targets, likelihood, num_classes=likelihood.num_classes, kernal=kernal)
    trainer = Trainer(model, likelihood, n_iterations)
    trainer.train(trainX, trainy) 

    def _build_perf_df(features, labels, subset):
        # Predict, score and tabulate a single subset.
        dist, observed_pred, pred_means, _ = trainer.predict(features)
        metrics = trainer.gp_results(features, labels)

        # Monte-Carlo estimate of class probabilities: exponentiate samples and
        # normalise across the class dimension (-2).
        samples = dist.sample(torch.Size((n_samples,))).exp()
        class_probs = samples / samples.sum(-2, keepdim=True)
        prob_mean = class_probs.mean(0).numpy()
        prob_std = class_probs.std(0).numpy()

        obs_mean = observed_pred.mean.numpy()
        obs_var = observed_pred.variance.numpy()

        perf_df = pd.DataFrame()
        perf_df['mean_pred_class0'] = obs_mean[0,]
        perf_df['mean_pred_class1'] = obs_mean[1,]
        perf_df['y'] = labels.numpy() if isinstance(labels, torch.Tensor) else labels
        perf_df['y_pred'] = pred_means.max(0)[1].numpy()
        perf_df['var_pred_class0'] = obs_var[0,]
        perf_df['var_pred_class1'] = obs_var[1,]
        perf_df['pred_prob_class0'] = prob_mean[0,]
        perf_df['pred_prob_class1'] = prob_mean[1,]
        perf_df['pred_prob_std_class0'] = prob_std[0,]
        perf_df['pred_prob_std_class1'] = prob_std[1,]
        perf_df['subset'] = subset

        # Repeat each subset-level metric on every row; 'cm' is a list, so it must be
        # wrapped explicitly rather than broadcast.
        n_rows = len(perf_df)
        for metric_name, metric_value in metrics.items():
            perf_df[metric_name] = [metric_value] * n_rows
        return perf_df

    train_perf_df = _build_perf_df(trainX, trainy, 'train')
    test_perf_df = _build_perf_df(testX, testy, 'test')
   
    return train_perf_df, test_perf_df, model, likelihood


In [15]:
# Output directory for per-fold CSVs/pickles and input directory of the k-fold validation splits.
gp_kfold_results = "/Users/radhi/Desktop/GitHub/atom2024/atom2024/notebooks/paper/results/gp_kfold_results/"
datapath = '/Users/radhi/Desktop/GitHub/atom2024/atom2024/notebooks/paper/datasets/80train_20test/k_fold/validation/'
neks = ['NEK2_binding', 'NEK2_inhibition', 'NEK3_binding', 'NEK5_binding','NEK9_binding','NEK9_inhibition']
feats = ['MOE','MFP']
samps = ['none_scaled','UNDER', 'SMOTE', 'ADASYN']
kernal_type = ['RBF','matern' ]
folds = ['fold1','fold2','fold3','fold4','fold5']
train_results = []
test_results = []
final_cols=['model','NEK','strategy','feat_type','kernel_type','fold', 'cm','recall', 'specificity', 'accuracy', 'precision', 
                'f1', 'ROC_AUC', 'MCC', 'balanced_accuracy']


def _summary_row(df):
    """Build one summary row (ordered like ``final_cols``) from a labelled performance frame.

    Metrics are computed from the frame's ``y`` / ``y_pred`` columns; the metadata
    columns are read from its first row.
    """
    y_true = df['y'].to_numpy().astype(int)
    y_pred = df['y_pred'].to_numpy().astype(int)
    # Fix the label order so the matrix is always 2x2, even if one class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    row = df.iloc[0][['model', 'NEK', 'strategy', 'feat_type', 'kernel_type', 'fold']].to_dict()
    row.update({
        'cm': cm.flatten().tolist(),
        'recall': recall_score(y_true, y_pred),
        'specificity': tn / (tn + fp) if (tn + fp) else float('nan'),
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        'ROC_AUC': roc_auc_score(y_true, y_pred),
        'MCC': matthews_corrcoef(y_true, y_pred),
        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
    })
    return [row[col] for col in final_cols]


for nek, feat, samp, fold in itertools.product(neks, feats, samps, folds):
    root_name = f'{nek}_{feat}_{samp}_{fold}'
    # The data only depends on the dataset/fold, so load it once for all kernels.
    trainX, trainy, testX, testy = make_torch_tens_float(datapath,f'{root_name}_validation')
    for kernal in kernal_type:
        train_perf, test_perf, model, likelihood= save_results(trainX, trainy, testX, testy, root_name, kernal, n_iterations=300, n_samples=100)   
        # The whole model and likelihood objects are pickled; only load these files from trusted sources.
        with open(f'{gp_kfold_results}{root_name}_{kernal}.pkl', 'wb') as f: 
            pickle.dump(model,f)
        with open(f'{gp_kfold_results}{root_name}_{kernal}_likelihood.pkl', 'wb') as f: 
            pickle.dump(likelihood,f)

        for subset, df, collected in (('train', train_perf, train_results),
                                      ('test', test_perf, test_results)):
            df['NEK'] = nek
            df['feat_type']=feat 
            # The sampling strategy is `samp`; this column previously duplicated the feature type.
            df['strategy']=samp 
            df['fold']=fold 
            df['kernel_type']=f'GP_{kernal}'
            df['model'] =f'{root_name}_{kernal}'
            df.to_csv(f'{gp_kfold_results}{root_name}_{kernal}_{subset}.csv', index=False)
            collected.append(_summary_row(df))
                        

Iter 1/300 - Loss: 7.107   lengthscale: 0.693   noise: 0.693
Iter 11/300 - Loss: 5.931   lengthscale: 0.693   noise: 1.297
Iter 21/300 - Loss: 5.441   lengthscale: 0.693   noise: 1.982
Iter 31/300 - Loss: 5.272   lengthscale: 0.693   noise: 2.576
Iter 41/300 - Loss: 5.222   lengthscale: 0.693   noise: 3.023
Iter 51/300 - Loss: 5.208   lengthscale: 0.693   noise: 3.341
Iter 61/300 - Loss: 5.204   lengthscale: 0.693   noise: 3.562
Iter 71/300 - Loss: 5.203   lengthscale: 0.693   noise: 3.710
Iter 81/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.804
Iter 91/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.858
Iter 101/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.885
Iter 111/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.897
Iter 121/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.899
Iter 131/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.898
Iter 141/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.896
Iter 151/300 - Loss: 5.202   lengthscale: 0.693   noise: 3.895
Ite

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Iter 11/300 - Loss: 5.931   lengthscale: 1.251   noise: 1.297
Iter 21/300 - Loss: 5.435   lengthscale: 2.036   noise: 1.981
Iter 31/300 - Loss: 5.252   lengthscale: 3.138   noise: 2.567
Iter 41/300 - Loss: 5.188   lengthscale: 4.413   noise: 2.983
Iter 51/300 - Loss: 5.163   lengthscale: 5.598   noise: 3.240
Iter 61/300 - Loss: 5.152   lengthscale: 6.563   noise: 3.371
Iter 71/300 - Loss: 5.145   lengthscale: 7.327   noise: 3.410
Iter 81/300 - Loss: 5.140   lengthscale: 7.946   noise: 3.385
Iter 91/300 - Loss: 5.136   lengthscale: 8.467   noise: 3.318
Iter 101/300 - Loss: 5.132   lengthscale: 8.918   noise: 3.224
Iter 111/300 - Loss: 5.129   lengthscale: 9.319   noise: 3.112
Iter 121/300 - Loss: 5.126   lengthscale: 9.680   noise: 2.988
Iter 131/300 - Loss: 5.123   lengthscale: 10.008   noise: 2.857
Iter 141/300 - Loss: 5.120   lengthscale: 10.307   noise: 2.722
Iter 151/300 - Loss: 5.118   lengthscale: 10.582   noise: 2.582
Iter 161/300 - Loss: 5.115   lengthscale: 10.834   noise: 2.4



Iter 21/300 - Loss: 5.405   lengthscale: 1.960   noise: 1.977
Iter 31/300 - Loss: 5.169   lengthscale: 3.030   noise: 2.520
Iter 41/300 - Loss: 5.085   lengthscale: 4.127   noise: 2.824
Iter 51/300 - Loss: 5.066   lengthscale: 4.820   noise: 2.928
Iter 61/300 - Loss: 5.057   lengthscale: 5.114   noise: 2.905
Iter 71/300 - Loss: 5.049   lengthscale: 5.196   noise: 2.800
Iter 81/300 - Loss: 5.042   lengthscale: 5.206   noise: 2.637
Iter 91/300 - Loss: 5.035   lengthscale: 5.215   noise: 2.432
Iter 101/300 - Loss: 5.027   lengthscale: 5.243   noise: 2.200
Iter 111/300 - Loss: 5.020   lengthscale: 5.285   noise: 1.955
Iter 121/300 - Loss: 5.013   lengthscale: 5.329   noise: 1.708
Iter 131/300 - Loss: 5.005   lengthscale: 5.367   noise: 1.468
Iter 141/300 - Loss: 4.998   lengthscale: 5.399   noise: 1.243
Iter 151/300 - Loss: 4.991   lengthscale: 5.431   noise: 1.037
Iter 161/300 - Loss: 4.985   lengthscale: 5.468   noise: 0.855
Iter 171/300 - Loss: 4.979   lengthscale: 5.512   noise: 0.698




accuracy: 1.0000, precision: 1.0000, recall: 1.0000, specificity: 1.0000, cm: [[36  0]
 [ 0 36]]
accuracy: 0.8889, precision: 0.8889, recall: 0.8889, specificity: 0.8889, cm: [[8 1]
 [1 8]]
Iter 1/300 - Loss: 7.106   lengthscale: 0.693   noise: 0.693
Iter 11/300 - Loss: 5.908   lengthscale: 1.291   noise: 1.296
Iter 21/300 - Loss: 5.368   lengthscale: 2.157   noise: 1.965
Iter 31/300 - Loss: 5.198   lengthscale: 3.099   noise: 2.506
Iter 41/300 - Loss: 5.154   lengthscale: 3.758   noise: 2.864
Iter 51/300 - Loss: 5.138   lengthscale: 4.160   noise: 3.057
Iter 61/300 - Loss: 5.131   lengthscale: 4.424   noise: 3.129
Iter 71/300 - Loss: 5.127   lengthscale: 4.627   noise: 3.116
Iter 81/300 - Loss: 5.123   lengthscale: 4.803   noise: 3.049
Iter 91/300 - Loss: 5.120   lengthscale: 4.968   noise: 2.945
Iter 101/300 - Loss: 5.117   lengthscale: 5.127   noise: 2.818
Iter 111/300 - Loss: 5.113   lengthscale: 5.281   noise: 2.676
Iter 121/300 - Loss: 5.110   lengthscale: 5.430   noise: 2.523
It

