In [1]:
# Import libraries
import pandas as pd
import pickle

from src.model_training_functions import train_model_on_chunks, fine_tune_model_on_chunks, convert_compound_pairs, NeuralNetworkModel
from src.model_training_functions import evaluate_test_data, prepare_and_evaluate_pairs

from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

import torch
import os

In [2]:
def get_roc_auc(test_data, pred_col, binary_sim_col='y'):
    """
    Compute ROC curves and AUC values for the test set.

    Parameters
    ----------
    test_data : pd.DataFrame
        DataFrame containing test results, including ground truth and predicted scores.
    pred_col : str
        Column name for the model's predicted values.
    binary_sim_col : str, optional
        Column name for the binary similarity label (default: 'y').

    Returns
    -------
    roc_auc_pred : float
        Area under the ROC curve for the model's predictions.
    """

    # Compute ROC curve and AUC for model predictions
    fpr_pred, tpr_pred, _ = roc_curve(test_data[binary_sim_col], test_data[pred_col])
    roc_auc_pred = auc(fpr_pred, tpr_pred)

    return roc_auc_pred

def get_pr_auc(test_data, pred_col, binary_sim_col='y'):
    """
    Compute the area under the Precision-Recall curve (PR AUC) for the test set.

    Parameters
    ----------
    test_data : pd.DataFrame
        DataFrame containing test results, including ground truth and predicted scores.
    pred_col : str
        Column name for the model's predicted values.
    binary_sim_col : str, optional
        Column name for the binary similarity label (default: 'y').

    Returns
    -------
    pr_auc : float
        Area under the Precision-Recall curve.
    precision : np.ndarray
        Precision values for the curve.
    recall : np.ndarray
        Recall values for the curve.
    """
    # Compute precision-recall curve
    precision, recall, _ = precision_recall_curve(test_data[binary_sim_col], test_data[pred_col])
    # Compute the area under the curve (AUC)
    pr_auc = auc(recall, precision)

    return pr_auc

In [3]:
# Load Fingerprints database
data_dir = 'data'

# Convert fingerprints to np.float32 format 
if not 'comps_fps_np.pkl' in os.listdir(data_dir):    
    with open(f'{data_dir}/comps_fps.pkl','rb') as f:
        db_ligs = pickle.load(f)
    
    db_ligs = {l:np.array(db_ligs[l], dtype=np.float32) for l in db_ligs}
    with open(f'{data_dir}/comps_fps_np.pkl','wb') as f:
        pickle.dump(db_ligs,f)
else:
    with open(f'{data_dir}/comps_fps_np.pkl','rb') as f:
        db_ligs = pickle.load(f)

## Train model

In [5]:
train_dir = 'train_datasets'
model = train_model_on_chunks(
        train_dir,
        db_ligs,                # o cualquier recurso que necesites para conv_suma
        hidden_layers=[512, 256, 128, 64], 
        dropout_prob=0.3, 
        n_epochs=5)

torch.save(model.state_dict(), f'model.pth')


=== Epoch 1/5 ===
 - chunk_11.csv | Train Loss: 0.5264, Val Loss: 0.5174
 - chunk_4.csv | Train Loss: 0.5148, Val Loss: 0.5073
 - chunk_30.csv | Train Loss: 0.5103, Val Loss: 0.4993
 - chunk_15.csv | Train Loss: 0.5057, Val Loss: 0.5003
 - chunk_21.csv | Train Loss: 0.5058, Val Loss: 0.4955
 - chunk_22.csv | Train Loss: 0.5021, Val Loss: 0.5044
 - chunk_8.csv | Train Loss: 0.4983, Val Loss: 0.4868
 - chunk_13.csv | Train Loss: 0.4955, Val Loss: 0.4903
 - chunk_16.csv | Train Loss: 0.4964, Val Loss: 0.4841
 - chunk_24.csv | Train Loss: 0.4880, Val Loss: 0.4771
 - chunk_6.csv | Train Loss: 0.4880, Val Loss: 0.4743
 - chunk_26.csv | Train Loss: 0.4837, Val Loss: 0.4787
 - chunk_1.csv | Train Loss: 0.4816, Val Loss: 0.4701
 - chunk_12.csv | Train Loss: 0.4785, Val Loss: 0.4708
 - chunk_25.csv | Train Loss: 0.4759, Val Loss: 0.4622
 - chunk_9.csv | Train Loss: 0.4715, Val Loss: 0.4613
 - chunk_3.csv | Train Loss: 0.4715, Val Loss: 0.4627
 - chunk_7.csv | Train Loss: 0.4698, Val Loss: 0.465

## Evaluation on test datasets

In [4]:
def evaluate_test_data(model,test_data,db_ligs):
    model.eval()
    X_test,y_test = convert_compound_pairs(test_data,db_ligs)
    X_test = torch.tensor(X_test, dtype=torch.float32)
    device = next(model.parameters()).device
    X_test = X_test.to(device)
    with torch.no_grad():
        preds = model(X_test).flatten().cpu().numpy()
    return preds

In [5]:
# Load model (same parameters as the trained model)
input_size = len(next(iter(db_ligs.values())))
model = NeuralNetworkModel(input_size=input_size,hidden_layers=[512,256,128,64],output_size=1,dropout_prob=0.3)
model.load_state_dict(torch.load('model.pth'))

  model.load_state_dict(torch.load('model.pth'))


<All keys matched successfully>

In [6]:
# Load test data
test_data = pd.read_csv('./test_datasets/test_pairs.csv')

In [7]:
# Make predictions
preds = evaluate_test_data(model,test_data,db_ligs)
test_data['pred'] = preds

In [8]:
# Compute ROC AUC and PR AUC per protein and collect results in a DataFrame

results = [
    {
        "prot": prot,
        "ROC AUC": get_roc_auc(data_prot := test_data[test_data["prot"] == prot], "pred"),
        "PR AUC": get_pr_auc(data_prot, "pred")
    }
    for prot in test_data["prot"].unique()
]

test_results = pd.DataFrame(results)

# Show test results
display(test_results)

Unnamed: 0,prot,ROC AUC,PR AUC
0,Q9P0X4,0.73709,0.545869
1,Q9UKV0,0.818359,0.510388


## Example of fine tuning

In [9]:
# Load model
input_size = len(next(iter(db_ligs.values())))
model = NeuralNetworkModel(input_size=input_size,hidden_layers=[512,256,128,64],output_size=1,dropout_prob=0.3)
model.load_state_dict(torch.load('model.pth'))

  model.load_state_dict(torch.load('model.pth'))


<All keys matched successfully>

In [10]:
# Fine tune model (in this example, using the test data)
ft_train_dir = './test_datasets/'

ft_model = fine_tune_model_on_chunks(
    ft_train_dir,
    db_ligs,                     # fingerprint dictionary used by convert_compound_pairs
    model,               # path to .pt file or an nn.Module instance
    n_epochs=5
)

# To save the model
torch.save(ft_model.state_dict(), f'ft_model.pth')


=== Fine‑tuning Epoch 1/5 ===
 - test_pairs.csv | Train Loss: 0.4433 | Val Loss: 0.3813
=== End of Epoch 1/5 | Avg Train Loss: 0.4433 ===

=== Fine‑tuning Epoch 2/5 ===
 - test_pairs.csv | Train Loss: 0.3994 | Val Loss: 0.3515
=== End of Epoch 2/5 | Avg Train Loss: 0.3994 ===

=== Fine‑tuning Epoch 3/5 ===
 - test_pairs.csv | Train Loss: 0.3649 | Val Loss: 0.3245
=== End of Epoch 3/5 | Avg Train Loss: 0.3649 ===

=== Fine‑tuning Epoch 4/5 ===
 - test_pairs.csv | Train Loss: 0.3370 | Val Loss: 0.2995
=== End of Epoch 4/5 | Avg Train Loss: 0.3370 ===

=== Fine‑tuning Epoch 5/5 ===
 - test_pairs.csv | Train Loss: 0.3097 | Val Loss: 0.2759
=== End of Epoch 5/5 | Avg Train Loss: 0.3097 ===


## Evaluation on new data

In [11]:
# Load pairs of compounds (in this example, two lists of SMILES, one for each compound from each pair)
compound_list_1 = ['CCCO','O=C(c1ccc(Oc2ccccc2)cc1)N1CCN(c2ncccn2)CC1','CC(C)Nc1ncnc(SC#N)c1[N+](=O)[O-]']
compound_list_2 = ['NCCCN(Cc1nn2ccc(Cl)c2c(=O)n1Cc1ccccc1)C(=O)c1ccc(Cl)cc1','CCCCOc1ccccc1C[C@H]1COC(=O)[C@@H]1Cc1ccc(Cl)c(Cl)c1','C=CC(=O)Nc1cccc(Nc2nc(Nc3ccc(SCC(=O)N4CCOCC4)cc3)ncc2Cl)c1']

# Generate pair DataFrame
new_pairs = pd.DataFrame({'l1':compound_list_1,'l2':compound_list_2})

In [13]:
# Note: Pairs with Tanimoto > 0.4 should not be considered,
# as the model was exclusively trained on pairs with Tanimoto < 0.4.

new_pairs_pred = prepare_and_evaluate_pairs(new_pairs, model)

# Show predictions
display(new_pairs_pred)



Unnamed: 0,l1,l2,Tanimoto,pred
0,CCCO,NCCCN(Cc1nn2ccc(Cl)c2c(=O)n1Cc1ccccc1)C(=O)c1c...,0.035088,0.124464
1,O=C(c1ccc(Oc2ccccc2)cc1)N1CCN(c2ncccn2)CC1,CCCCOc1ccccc1C[C@H]1COC(=O)[C@@H]1Cc1ccc(Cl)c(...,0.169014,0.017712
2,CC(C)Nc1ncnc(SC#N)c1[N+](=O)[O-],C=CC(=O)Nc1cccc(Nc2nc(Nc3ccc(SCC(=O)N4CCOCC4)c...,0.179487,0.005534
