# Load all necessary libraries

In [None]:
import os,sys
import os.path as osp
import pandas as pd
import numpy as np
# from tqdm.notebook import tqdm
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from sklearn.metrics import root_mean_squared_error
from sklearn.metrics import mean_absolute_error
import scipy.stats as sc
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import load_model
from argparse import ArgumentParser
from keras.models import load_model
from keras.saving import register_keras_serializable
from keras.losses import MeanSquaredError

# Reproducibility: seed every RNG once for the whole notebook.
tf.keras.utils.set_random_seed(42)  # sets seeds for base-python, numpy and tf
# Ask TensorFlow to use deterministic op implementations so repeated runs give identical results.
tf.config.experimental.enable_op_determinism()

# Load inference functions

In [4]:
def load_data_dataset_dir(evaluation_dataset_path, evaluation_features_dir_dir):
    
    """
    Load the direct-mutation dataset and attach its pre-computed .npy features.

    Every feature array is held in RAM, so be careful with large datasets.

    Args:

      evaluation_dataset_path:
          Path to the dataset csv file. This function reads the columns
          ``pdb_id``, ``wild_type``, ``position`` and ``mutant``.

      evaluation_features_dir_dir:
          Directory with the direct features, laid out as
          ``<dir>/<pdb_id>/<pdb_id>_<wild_type><position><mutant>.npy``.

    Returns:

      Pandas dataframe restricted to the rows whose feature file exists, with a
      ``features`` column holding each loaded array transposed by axes
      (1, 2, 3, 0), i.e. with its first axis moved to the end.
    """

    print("Loading dataset direct mutations")
    df = pd.read_csv(evaluation_dataset_path)
    print(f'Total unique mutations: {len(df)}')

    df['features'] = df.apply(lambda r: f'{evaluation_features_dir_dir}/{r.pdb_id}/{r.pdb_id}_{r.wild_type}{r.position}{r.mutant}.npy', axis=1)
    # .copy() detaches the filtered rows from the original frame; without it the
    # column assignment below triggers SettingWithCopyWarning whenever some
    # feature files are missing.
    df = df[df.features.apply(os.path.exists)].copy()
    print(f'Total mutations with features: {len(df)}')
    # Load and move the leading axis last in one pass (presumably channels-first
    # on disk -> channels-last for the Keras model; confirm).
    df['features'] = [np.transpose(np.load(f), (1, 2, 3, 0)) for f in tqdm(df.features, desc="2. Loading features")]
    print(f'Total mutations after filtering: {len(df)}')

    return df

def load_data_dataset_rev(evaluation_dataset_path, evaluation_features_dir_rev):
    
    """
    Load the reverse-mutation dataset and attach its pre-computed .npy features.

    The ``ddg`` column of the csv is negated for the reverse direction. Every
    feature array is held in RAM, so be careful with large datasets.

    Args:

      evaluation_dataset_path:
          Path to the dataset csv file. This function reads the columns
          ``pdb_id``, ``wild_type``, ``position``, ``mutant`` and ``ddg``.

      evaluation_features_dir_rev:
          Directory with the reverse features, laid out as
          ``<dir>/<pdb_id>/<pdb_id>_<wild_type><position><mutant>.npy``.

    Returns:

      Pandas dataframe restricted to the rows whose feature file exists, with
      negated ``ddg`` and a ``features`` column holding each loaded array
      transposed by axes (1, 2, 3, 0), i.e. with its first axis moved to the end.
    """

    print('Loading dataset reverse mutations')
    df_rev = pd.read_csv(evaluation_dataset_path)
    # Reverse mutations carry the opposite sign of the direct ddg.
    df_rev.ddg = -df_rev.ddg

    df_rev['features'] = df_rev.apply(lambda r: f'{evaluation_features_dir_rev}/{r.pdb_id}/{r.pdb_id}_{r.wild_type}{r.position}{r.mutant}.npy', axis=1)
    # .copy() detaches the filtered rows from the original frame; without it the
    # column assignment below triggers SettingWithCopyWarning whenever some
    # feature files are missing.
    df_rev = df_rev[df_rev.features.apply(os.path.exists)].copy()
    print(f'Total mutations with features: {len(df_rev)}')

    # Load and move the leading axis last in one pass (presumably channels-first
    # on disk -> channels-last for the Keras model; confirm).
    df_rev['features'] = [np.transpose(np.load(f), (1, 2, 3, 0)) for f in tqdm(df_rev.features, desc="3. Loading features")]
    print(f'Total mutations after filtering: {len(df_rev)}')

    return df_rev

@register_keras_serializable()
def rmse(y_val_direct, y_pred):
    """
    Root-mean-square error between targets and predictions, usable as a Keras metric.

    Both tensors are squeezed first so that size-1 axes (e.g. a trailing output
    dimension) do not break the element-wise difference. Registered as
    serializable so saved models referencing it can be reloaded.
    """
    residual = tf.squeeze(y_val_direct) - tf.squeeze(y_pred)
    return tf.sqrt(tf.reduce_mean(tf.square(residual)))

# def pearson_r(y_val_direct, y_pred):
    
#     """
#     Simple function to calculate Pearson correlation coefficient, needed for model to load.
#     """

#     if tf.shape(y_val_direct)[0] == 1:
#         y_val_direct = tf.concat([y_val_direct, y_val_direct], axis=0)
#         y_pred = tf.concat([y_pred, y_pred], axis=0)

#         pr, _ = tf.py_function(sc.pearsonr, [y_val_direct, y_pred], [tf.float64, tf.float64])
#         #tf.print("Pearson correlation coefficient:", pr)
#     else:
#         y_val_direct = tf.squeeze(y_val_direct)
#         y_pred = tf.squeeze(y_pred)
    
#         pr, _ = tf.py_function(sc.pearsonr, [y_val_direct, y_pred], [tf.float64, tf.float64])
#         #tf.print("Pearson correlation coefficient:", pr)

#     return pr

@register_keras_serializable()
def pearson_r(y_true, y_pred):
    """
    Pearson correlation coefficient between targets and predictions, as a Keras metric.

    Both inputs are flattened to 1-D. Covariance and standard deviations are
    population (1/N) statistics, and ``epsilon`` is added to the denominator to
    avoid division by zero when either input is constant.
    """
    flat_true = tf.reshape(y_true, [-1])
    flat_pred = tf.reshape(y_pred, [-1])

    centered_true = flat_true - tf.reduce_mean(flat_true)
    centered_pred = flat_pred - tf.reduce_mean(flat_pred)
    cov = tf.reduce_mean(centered_true * centered_pred)

    spread = tf.math.reduce_std(flat_true) * tf.math.reduce_std(flat_pred)
    return cov / (spread + tf.keras.backend.epsilon())

# Function to prepare datasets
def prepare_datasets(df_train_dataset_dir, df_train_dataset_rev):
    """
    Turn the direct and reverse dataframes into numpy inputs/labels for inference.

    Args:
      df_train_dataset_dir:
          Dataframe with loaded direct features, e.g. from load_data_dataset_dir().
      df_train_dataset_rev:
          Dataframe with loaded reverse features, e.g. from load_data_dataset_rev().

    Returns:
      Tuple ``(X_dir, y_dir, X_rev, y_rev, X_tot, y_tot)``. Each X stacks the
      ``features`` arrays along a new first axis and each y is the ``ddg``
      column. The "tot" pair is the direct rows followed by the reverse rows.
    """
    def _stack(frame):
        # Stack per-row feature arrays into one array; labels come from ddg.
        return np.array(frame.features.to_list()), frame.ddg.to_numpy()

    X_dir, y_dir = _stack(df_train_dataset_dir)
    X_rev, y_rev = _stack(df_train_dataset_rev)
    X_tot, y_tot = _stack(pd.concat([df_train_dataset_dir, df_train_dataset_rev]))

    return X_dir, y_dir, X_rev, y_rev, X_tot, y_tot

# Function to evaluate a single model
def evaluate_model(model_path, X_dir, y_dir, X_rev, y_rev, X_tot, y_tot, model_name):
    """
    Run inference with one saved model and score it on the direct, reverse and total sets.

    Args:
      model_path:
          Path to a single saved model file.
      X_dir, y_dir:
          Direct dataset inputs and labels in numpy format.
      X_rev, y_rev:
          Reverse dataset inputs and labels in numpy format.
      X_tot, y_tot:
          Total dataset inputs and labels in numpy format (as built by
          prepare_datasets: direct rows followed by reverse rows).
      model_name:
          Label stored under the ``"model_name"`` key of the result.

    Returns:
      Dictionary with ``model_name`` followed by MAE, MSE, RMSE and Pearson r
      for each split, keyed ``<metric>_dir``, ``<metric>_rev``, ``<metric>_tot``.
    """
    model = load_model(model_path, custom_objects={"mse": MeanSquaredError(), "rmse": rmse, "pearson_r": pearson_r})

    splits = (("dir", X_dir, y_dir), ("rev", X_rev, y_rev), ("tot", X_tot, y_tot))
    # Predict every split first (flattened to 1-D), then compute the metrics.
    flat_preds = {tag: model.predict(inputs).reshape(-1) for tag, inputs, _ in splits}

    scores = {"model_name": model_name}
    for tag, _, target in splits:
        guess = flat_preds[tag]
        scores[f"mae_{tag}"] = mean_absolute_error(target, guess)
        scores[f"mse_{tag}"] = mean_squared_error(target, guess)
        scores[f"rmse_{tag}"] = root_mean_squared_error(target, guess)
        scores[f"pearson_r_{tag}"] = sc.pearsonr(target, guess)[0]
    return scores

# Function to evaluate all models
def evaluate_models(df_train_dataset_dir, df_train_dataset_rev, model_dir, evalpathsave, model_type, v):
    
    """
    Evaluate the models in a directory, one by one or as an averaged ensemble, and save the metrics.

    Args:

      df_train_dataset_dir:
          Dataframe with loaded direct features, e.g. from load_data_dataset_dir().

      df_train_dataset_rev:
          Dataframe with loaded reverse features, e.g. from load_data_dataset_rev().

      model_dir:
          Path to directory with models (a trailing separator is optional).

      evalpathsave:
          Directory in which the result CSV is written.

      model_type:
          "sing" to score every model separately, "ens" to score their averaged ensemble.

      v:
          Tag appended to the output CSV file name.

    Returns:

      Inference dataframe with metrics: one row per model for "sing", a single
      row for "ens". It is also written to
      ``<evalpathsave>/<model_dir name>_dataset_eval_results_<v>.csv``.

    Raises:

      ValueError: if ``model_type`` is neither "sing" nor "ens".
    """

    # Fail fast: an unknown type would otherwise produce (and save) an empty table.
    if model_type not in ("sing", "ens"):
        raise ValueError(f"model_type must be 'sing' or 'ens', got {model_type!r}")

    X_dir, y_dir, X_rev, y_rev, X_tot, y_tot = prepare_datasets(df_train_dataset_dir, df_train_dataset_rev)

    eval_results = []
    if model_type == "sing":
        # Order models by the integer suffix after the last "_" in the file stem.
        model_files = sorted(os.listdir(model_dir), key=lambda x: int(x.split(".")[0].split("_")[-1]))
        for model_name in model_files:
            model_path = os.path.join(model_dir, model_name)
            eval_results.append(evaluate_model(model_path, X_dir, y_dir, X_rev, y_rev, X_tot, y_tot, model_name))
    else:
        eval_results.append(evaluate_ensemble(model_dir, X_dir, y_dir, X_rev, y_rev, X_tot, y_tot))

    eval_df = pd.DataFrame(eval_results)
    # Name the CSV after the model directory itself. normpath strips a trailing
    # separator, so this works with or without one; split('/')[-2] would pick
    # the parent directory whenever the trailing slash is missing.
    run_name = os.path.basename(os.path.normpath(model_dir))
    eval_df.to_csv(os.path.join(evalpathsave, f"{run_name}_dataset_eval_results_{v}.csv"), index=False)

    return eval_df

def evaluate_ensemble(models_dir, X_dir, y_dir, X_rev, y_rev, X_tot, y_tot):
    
    """
    Evaluate the ensemble formed by averaging the predictions of every model in a directory.

    Args:

      models_dir:
          Path to models directory. Every entry is loaded as a model; entries are
          ordered by the integer after the last "_" in the part before the first ".".

      X_dir, y_dir:
          Direct dataset inputs and labels in numpy format.

      X_rev, y_rev:
          Reverse dataset inputs and labels in numpy format.

      X_tot, y_tot:
          Total dataset inputs and labels in numpy format.

    Returns:

      Dictionary with MAE, MSE, RMSE and Pearson r of the averaged predictions,
      keyed ``<metric>_dir``, ``<metric>_rev`` and ``<metric>_tot``.

    Raises:

      ValueError: if ``models_dir`` contains no entries.
    """

    model_files = sorted(os.listdir(models_dir), key=lambda x: int(x.split(".")[0].split("_")[-1]))
    if not model_files:
        # Without this guard np.mean over an empty list returns NaN (with a
        # RuntimeWarning) and the metric calls below fail with an opaque error.
        raise ValueError(f"No model files found in {models_dir!r}; cannot evaluate an ensemble.")

    splits = (("dir", X_dir, y_dir), ("rev", X_rev, y_rev), ("tot", X_tot, y_tot))
    ensemble_preds = {name: [] for name, _, _ in splits}

    for model_name in model_files:
        model_path = os.path.join(models_dir, model_name)
        model = load_model(model_path, custom_objects={"mse": MeanSquaredError(), "rmse": rmse, "pearson_r": pearson_r})
        for name, inputs, _ in splits:
            ensemble_preds[name].append(model.predict(inputs).reshape(-1))

    results = {}
    for name, _, y_true in splits:
        # Average predictions across all models, then score the ensemble.
        avg_pred = np.mean(ensemble_preds[name], axis=0)
        results[f"mae_{name}"] = mean_absolute_error(y_true, avg_pred)
        results[f"mse_{name}"] = mean_squared_error(y_true, avg_pred)
        results[f"rmse_{name}"] = root_mean_squared_error(y_true, avg_pred)
        results[f"pearson_r_{name}"] = sc.pearsonr(y_true, avg_pred)[0]
    return results

def add_model_col(df, key):
    """
    Add (in place) a ``model`` column of the form ``<key>_<index>`` and return ``df``.

    ``<index>`` is the text after the last "_" in the part of each
    ``model_name`` that precedes the first ".".
    """
    labels = []
    for name in df['model_name']:
        stem = name.partition(".")[0]
        labels.append(key + "_" + stem.rpartition("_")[2])
    df['model'] = labels
    return df
def add_model_col_ens(df, key):
    """Add (in place) a ``model`` column set to ``key`` on every row and return ``df``."""
    label_column = 'model'
    df[label_column] = key
    return df

def table_report(evalpathsave, model_dir, w, df_train_dataset_dir, df_train_dataset_rev):
    """
    General function for OrgNet inference.

    Scores every model in ``model_dir`` individually ("sing") and as an
    averaged ensemble ("ens"), and stacks both result tables.

    Args:
      evalpathsave:
          Directory passed to evaluate_models for saving the per-mode result CSVs.
      model_dir:
          Path to the directory with the OrgNet models.
      w:
          Run tag; combined with the mode into the keys "<w>_sing" and "<w>_ens".
      df_train_dataset_dir:
          Dataframe with loaded direct features (from load_data_dataset_dir()).
      df_train_dataset_rev:
          Dataframe with loaded reverse features (from load_data_dataset_rev()).

    Returns:
      Inference dataframe with metrics: one row per single model followed by
      the ensemble row, each labelled in the ``model`` column.
    """
    frames = []
    for model_type, tag_rows in (("sing", add_model_col), ("ens", add_model_col_ens)):
        key = f"{w}_{model_type}"
        scored = evaluate_models(df_train_dataset_dir, df_train_dataset_rev, model_dir, evalpathsave, model_type, key)
        frames.append(tag_rows(scored, key))
    return pd.concat(frames)

# Column order for the final report: model label, then MAE/MSE/RMSE/Pearson r
# on the direct, reverse and total sets.
params = [
    "model",
    "mae_dir", "mse_dir", "rmse_dir", "pearson_r_dir",
    "mae_rev", "mse_rev", "rmse_rev", "pearson_r_rev",
    "mae_tot", "mse_tot", "rmse_tot", "pearson_r_tot",
]

# Run the cell below

In [None]:
# Fill in your paths before running this cell.

# Feature variant to evaluate ("defdif" or "def").
feature_type = "defdif"
# feature_type = "def"

datasets_dir = ".../data_preprocessing/datasets"
features_dir = ".../data_preprocessing/Ssym/features/Ssym_ori"
# features_dir = ".../data_preprocessing/Ssym/features/Ssym_nonori"
inference_dataset_path = f"{datasets_dir}/Ssym.csv"
inference_features_dir_dir = f"{features_dir}/Ssym_{feature_type}_direct/"
inference_features_dir_rev = f"{features_dir}/Ssym_{feature_type}_reverse/"

# Load the direct and reverse mutations together with their features.
df_train_dataset_dir = load_data_dataset_dir(inference_dataset_path, inference_features_dir_dir)
df_train_dataset_rev = load_data_dataset_rev(inference_dataset_path, inference_features_dir_rev)

# Score single models and the ensemble, then display the metric table.
# evalpathsave = "/evalpathsave/"
evalpathsave = ".../ThermoNet-like/"
model_dir = f".../ThermoNet-like/TH_{feature_type}_Q1744/"
w = "fflagff"
res_df = table_report(evalpathsave, model_dir, w, df_train_dataset_dir, df_train_dataset_rev)
res_df[params]