In [1]:
"""
Notebook to evaluate a trained molecule-embedding network: rebuild the dataset, load a checkpoint and compute train/validation losses
"""
from __future__ import annotations

import argparse
import os
import pickle
import json
import random
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import time

from tqdm.auto import tqdm
from paroutes import PaRoutesInventory, get_target_smiles
from embedding_model import (
    fingerprint_preprocess_input,
    gnn_preprocess_input,
    CustomDataset,
    collate_fn,
    # SampleData,
    fingerprint_vect_from_smiles,
    compute_embeddings,
    GNNModel,
    FingerprintModel,
    NTXentLoss,
    num_heavy_atoms
)
from paroutes import PaRoutesInventory
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import plotly.express as px
from rdkit import Chem
import deepchem as dc


    
# Read the experiment configuration that drives every later cell.
with open("config_gnn_0627.json", "r") as config_file:
    config = json.load(config_file)

# MODEL TO EVALUATE
# Checkpoints are looked up in a per-experiment folder under GraphRuns/.
experiment_name = config["experiment_name"]  # e.g. gnn_0629
checkpoint_folder = "GraphRuns/{}/".format(experiment_name)
input_checkpoint_name = "epoch_5_checkpoint.pth"




No normalization for AvgIpc. Feature removed!
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading modules with transformers dependency. No module named 'transformers'
cannot import name 'HuggingFaceModel' from 'deepchem.models.torch_models' (/Users/ilariasartori/miniforge3/envs/syntheseus_temp/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'pytorch_lightning'
Skipped loading some Jax models, missing a dependency. jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.


In [2]:
# 1. PREPROCESS DATA AGAIN
# if not args.load_from_preprocessed_data:

# Read routes data
input_file_routes = f'Runs/{config["run_id"]}/targ_routes.pickle'
# input_file_distances = f'Runs/{config["run_id"]}/targ_to_purch_distances.pickle'

# Routes data
# NOTE: pickle.load can execute arbitrary code; only load files from trusted runs.
with open(input_file_routes, "rb") as handle:
    targ_routes_dict = pickle.load(handle)

# # Load distances data
# with open(input_file_distances, 'rb') as handle:
#     distances_dict = pickle.load(handle)

# Inventory: SMILES of every purchasable molecule.
inventory = PaRoutesInventory(n=5)
purch_smiles = [mol.smiles for mol in inventory.purchasable_mols()]

# Record the heavy-atom count of each purchasable molecule and collect the
# ones with fewer than two heavy atoms; they are filtered out of the samples
# later because the graph featurizer cannot handle them.
purch_mol_to_exclude = []
purch_nr_heavy_atoms = {}
for smiles in purch_smiles:
    nr_heavy_atoms = num_heavy_atoms(Chem.MolFromSmiles(smiles))
    if nr_heavy_atoms < 2:
        # append() instead of list concatenation: no new list per excluded molecule
        purch_mol_to_exclude.append(smiles)
    purch_nr_heavy_atoms[smiles] = nr_heavy_atoms

# Target molecules depend on which run produced the routes.
if config["run_id"] == "202305-2911-2320-5a95df0e-3008-4ebe-acd8-ecb3b50607c7":
    all_targets = get_target_smiles(n=5)
elif config["run_id"] == "Guacamol_combined":
    with open("Data/Guacamol/guacamol_v1_test_10ksample.txt", "r") as f:
        all_targets = [line.strip() for line in f]
else:
    # Fail here with a clear message instead of a NameError on `all_targets` later.
    raise NotImplementedError(f'Run id {config["run_id"]}')

# For every target, collect the molecules appearing in its first route
# (positive samples) and a random sample of purchasable molecules that do not
# appear in it (negative samples).

# Set for O(1) membership tests; the exclusion list is invariant across targets.
excluded_smiles = set(purch_mol_to_exclude)

targ_route_not_in_route_dict = {}
for target in all_targets:
    target_routes_dict = targ_routes_dict.get(target, "Target_Not_Solved")

    if target_routes_dict == "Target_Not_Solved":
        # No route stored for this target: it has no positives.
        purch_in_route = []
    else:
        target_route_df = target_routes_dict["route_1"]
        purch_in_route = list(
            target_route_df.loc[target_route_df["label"] != "Target", "smiles"]
        )
    #         purch_in_route = [smiles for smiles in purch_in_route if smiles in purch_smiles]

    # Same result as testing against the list, but without rescanning the
    # route for every purchasable molecule. Order of purch_smiles is preserved.
    purch_in_route_set = set(purch_in_route)
    purch_not_in_route = [
        purch_smile
        for purch_smile in purch_smiles
        if purch_smile not in purch_in_route_set
    ]

    # NOTE(review): re-seeding on every iteration restarts the RNG from the
    # same state for each target, so all targets draw with identical random
    # choices. Left unchanged because moving it would alter the generated
    # dataset -- confirm this matches the preprocessing used for training.
    random.seed(config["seed"])

    if config["neg_sampling"] == "uniform":
        purch_not_in_route_sample = random.sample(
            purch_not_in_route, config["not_in_route_sample_size"]
        )
    else:
        # Also covers the former "..." placeholder branch, which did nothing
        # and left purch_not_in_route_sample unbound (or stale).
        raise NotImplementedError(f'{config["neg_sampling"]}')

    # Filter out molecules with only one atom (problems with featurizer)
    purch_in_route = [
        smiles for smiles in purch_in_route if smiles not in excluded_smiles
    ]
    purch_not_in_route_sample = [
        smiles
        for smiles in purch_not_in_route_sample
        if smiles not in excluded_smiles
    ]

    targ_route_not_in_route_dict[target] = {
        "positive_samples": purch_in_route,
        "negative_samples": purch_not_in_route_sample,
    }

# Optionally keep only a random subset of targets (-1 means keep all).
if config["nr_sample_targets"] != -1:
    sample_targets = random.sample(
        list(targ_route_not_in_route_dict.keys()), config["nr_sample_targets"]
    )
else:
    sample_targets = targ_route_not_in_route_dict
# Restrict the dictionary to the selected targets.
targ_route_not_in_route_dict_sample = {
    target: targ_route_not_in_route_dict[target] for target in sample_targets
}

input_data = targ_route_not_in_route_dict_sample

# Build the dataset in the representation expected by the configured model.
if config["model_type"] not in ("gnn", "fingerprints"):
    raise NotImplementedError(f'Model type {config["model_type"]}')

if config["model_type"] == "gnn":
    # Featurize every purchasable molecule once as a graph and index the
    # results by SMILES so preprocessing can look them up.
    featurizer = dc.feat.MolGraphConvFeaturizer()

    purch_mols = list(map(Chem.MolFromSmiles, purch_smiles))
    purch_featurizer = featurizer.featurize(purch_mols)
    purch_featurizer_dict = {
        smiles_key: features
        for smiles_key, features in zip(purch_smiles, purch_featurizer)
    }

    dataset = gnn_preprocess_input(
        input_data=input_data,
        featurizer=featurizer,
        featurizer_dict=purch_featurizer_dict,
        pos_sampling=config["pos_sampling"],
    )
else:
    # Fingerprint model: one fingerprint vector per purchasable molecule,
    # indexed by SMILES.
    purch_fingerprints = [
        fingerprint_vect_from_smiles(smiles_key) for smiles_key in purch_smiles
    ]
    purch_fingerprints_dict = dict(zip(purch_smiles, purch_fingerprints))

    dataset = fingerprint_preprocess_input(
        input_data,
        fingerprints_dict=purch_fingerprints_dict,
        pos_sampling=config["pos_sampling"],
    )


# 2. TRAIN VALIDATION SPLIT
# The validation set holds a fixed fraction of the samples; the split itself
# is made reproducible through a fixed random_state.
validation_ratio = config["validation_ratio"]
num_samples = len(dataset)
num_val_samples = int(validation_ratio * num_samples)

index_split = train_test_split(
    range(num_samples),
    test_size=num_val_samples,
    random_state=42,
)
train_indices, val_indices = index_split

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Both loaders share the collate function; batch size and shuffling come
# from the configuration.
train_data_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=config["train_batch_size"],
    shuffle=config["train_shuffle"],
)
val_data_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    batch_size=config["val_batch_size"],
    shuffle=config["val_shuffle"],
)

# Batch size: The batch size determines the number of samples processed in each iteration during training or validation. In most cases, it is common to use the same batch size for both training and validation to maintain consistency. However, there are situations where you might choose a different batch size for validation. For instance, if memory constraints are more relaxed during validation, you can use a larger batch size to speed up evaluation.
# Shuffle training data: Shuffling the training data before each epoch is beneficial because it helps the model see the data in different orders, reducing the risk of the model learning patterns specific to the order of the data. Shuffling the training data introduces randomness and promotes better generalization.
# No shuffle for validation data: It is generally not necessary to shuffle the validation data because validation is meant to evaluate the model's performance on unseen data that is representative of the real-world scenarios. Shuffling the validation data could lead to inconsistent evaluation results between different validation iterations, making it harder to track the model's progress and compare performance.







Failed to featurize datapoint 146, N. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 703, S. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 1055, Br. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 1126, I. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 3401, O. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 4103, F. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint

  0%|          | 0/10000 [00:00<?, ?it/s]

In [3]:
# 0. LOAD MODEL
# Option 1: FROM CHECKPOINT
input_checkpoint_folder = f'GraphRuns/{experiment_name}'
input_checkpoint_path = f'{input_checkpoint_folder}/{input_checkpoint_name}'

# The network input dimension is stored in the checkpoint folder.
# os.path.join avoids the doubled separator that checkpoint_folder's trailing
# slash would otherwise produce.
# NOTE: pickle.load can execute arbitrary code; only open trusted run folders.
with open(os.path.join(checkpoint_folder, "input_dim.pickle"), "rb") as f:
    input_dim = pickle.load(f)['input_dim']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define network dimensions and build the matching architecture.
if config["model_type"] == "gnn":
    gnn_input_dim = input_dim
    gnn_hidden_dim = config["hidden_dim"]
    gnn_output_dim = config["output_dim"]

    model = GNNModel(
        input_dim=gnn_input_dim,
        hidden_dim=gnn_hidden_dim,
        output_dim=gnn_output_dim,
    ).to(device)
    # The GNN is run in float64.
    model.double()

elif config["model_type"] == "fingerprints":
    #     fingerprint_input_dim = preprocessed_targets[0].GetNumBits()
    fingerprint_input_dim = input_dim
    fingerprint_hidden_dim = config["hidden_dim"]
    fingerprint_output_dim = config["output_dim"]

    model = FingerprintModel(
        input_dim=fingerprint_input_dim,
        hidden_dim=fingerprint_hidden_dim,
        output_dim=fingerprint_output_dim,
    ).to(device)

else:
    raise NotImplementedError(f'Model type {config["model_type"]}')

loss_fn = NTXentLoss(temperature=config["temperature"], device=device)

# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU machine can still be loaded on a CPU-only one.
checkpoint = torch.load(input_checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# # # OPTION 2: From pickle
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# loss_fn = NTXentLoss(temperature=config["temperature"], device=device)
# with open(f'{checkpoint_folder}/model_min_val.pkl', "rb") as f:
#     model = pickle.load(f)



<All keys matched successfully>

In [4]:
# Load the same checkpoint file twice, independently, so the two resulting
# state dicts can be compared against each other.
# map_location keeps the load working when the checkpoint was saved on a
# different device than the one in use.
checkpoint1 = torch.load(input_checkpoint_path, map_location=device)
checkpoint2 = torch.load(input_checkpoint_path, map_location=device)

aa = checkpoint1["model_state_dict"]
# Take bb from the second load; reading it from checkpoint1 would make aa and
# bb the very same object and the comparison meaningless.
bb = checkpoint2["model_state_dict"]

In [5]:
# List the parameter names stored in the checkpoint's model state dict.
print(aa.keys())


odict_keys(['conv1.bias', 'conv1.lin.weight', 'conv2.bias', 'conv2.lin.weight', 'fc.weight', 'fc.bias'])


In [6]:
# Display the tensor stored under "conv1.lin.weight" in the state dict aa.
aa["conv1.lin.weight"]

tensor([[-0.0182,  0.0882,  0.0097,  ..., -0.0630,  0.0137, -0.0142],
        [-0.0137, -0.0323,  0.0690,  ...,  0.0586, -0.0718,  0.0425],
        [-0.0948, -0.0169, -0.2129,  ..., -0.0453,  0.0261,  0.0008],
        ...,
        [ 0.0455,  0.0312,  0.0367,  ...,  0.1213,  0.0331, -0.1026],
        [ 0.0481,  0.1346,  0.0286,  ..., -0.0846,  0.0547, -0.0316],
        [ 0.0465, -0.1177,  0.0113,  ...,  0.0605,  0.0737,  0.0774]],
       dtype=torch.float64)

In [7]:
# Display the tensor stored under "conv1.lin.weight" in the state dict bb,
# for visual comparison with the same entry of aa.
bb["conv1.lin.weight"]

tensor([[-0.0182,  0.0882,  0.0097,  ..., -0.0630,  0.0137, -0.0142],
        [-0.0137, -0.0323,  0.0690,  ...,  0.0586, -0.0718,  0.0425],
        [-0.0948, -0.0169, -0.2129,  ..., -0.0453,  0.0261,  0.0008],
        ...,
        [ 0.0455,  0.0312,  0.0367,  ...,  0.1213,  0.0331, -0.1026],
        [ 0.0481,  0.1346,  0.0286,  ..., -0.0846,  0.0547, -0.0316],
        [ 0.0465, -0.1177,  0.0113,  ...,  0.0605,  0.0737,  0.0774]],
       dtype=torch.float64)

In [8]:
# Two GNNs with the same architecture, each initialised from one of the two
# loaded state dicts.
gnn_kwargs = {
    "input_dim": gnn_input_dim,
    "hidden_dim": gnn_hidden_dim,
    "output_dim": gnn_output_dim,
}

# Move to the target device and switch parameters to float64.
model1 = GNNModel(**gnn_kwargs).to(device).double()
model2 = GNNModel(**gnn_kwargs).to(device).double()

model1.load_state_dict(aa)
model2.load_state_dict(bb)





<All keys matched successfully>

In [9]:
# Average validation loss of model1.
# Switch to evaluation mode first, as the final evaluation cell does for
# `model`, so that any train-only layers (e.g. dropout, if the model has
# them) do not add randomness to the measured loss.
# NOTE(review): if config["val_shuffle"] is true the batch composition changes
# between passes, which may also shift the loss -- confirm.
model1.eval()

val_loss1 = 0.0
val_batches1 = 0
with torch.no_grad():  # Disable gradient calculation during validation
    # Wrap the loader itself (not enumerate) so tqdm can show the total.
    for val_batch_data in tqdm(val_data_loader):
        # Compute embeddings
        val_embeddings1 = compute_embeddings(
            device=device,
            model_type=config['model_type'],
            model=model1,
            batch_data=val_batch_data,
        )
        # Compute loss
        val_batch_loss1 = loss_fn(val_embeddings1)

        val_loss1 += val_batch_loss1.item()
        val_batches1 += 1

    # Compute average loss
    average_val_loss1 = val_loss1 / val_batches1

print("Validation loss: ", average_val_loss1)



0it [00:00, ?it/s]

Validation loss:  3.619757167994976


In [10]:
# Average validation loss of model2.
# Switch to evaluation mode first, as the final evaluation cell does for
# `model`, so that any train-only layers (e.g. dropout, if the model has
# them) do not add randomness to the measured loss.
# NOTE(review): if config["val_shuffle"] is true the batch composition changes
# between passes, which may also shift the loss -- confirm.
model2.eval()

val_loss2 = 0.0
val_batches2 = 0
with torch.no_grad():  # Disable gradient calculation during validation
    # Wrap the loader itself (not enumerate) so tqdm can show the total.
    for val_batch_data in tqdm(val_data_loader):
        # Compute embeddings
        val_embeddings2 = compute_embeddings(
            device=device,
            model_type=config['model_type'],
            model=model2,
            batch_data=val_batch_data,
        )
        # Compute loss
        val_batch_loss2 = loss_fn(val_embeddings2)

        val_loss2 += val_batch_loss2.item()
        val_batches2 += 1

    # Compute average loss
    average_val_loss2 = val_loss2 / val_batches2

print("Validation loss: ", average_val_loss2)



0it [00:00, ?it/s]

Validation loss:  3.585004560649395


In [None]:
# 3: EVALUATE MODEL
# Evaluation mode for the whole cell; gradients are not needed for either pass.
model.eval()

# --- Loss on the training set ---
train_loss, train_batches = 0.0, 0
with torch.no_grad():
    for train_batch_data in train_data_loader:
        # Embed the batch, then accumulate its loss as a Python float.
        train_embeddings = compute_embeddings(
            model=model,
            model_type=config['model_type'],
            device=device,
            batch_data=train_batch_data,
        )
        train_loss += loss_fn(train_embeddings).item()
        train_batches += 1

# Mean loss per batch.
average_train_loss = train_loss / train_batches
print("Train loss: ", average_train_loss)

# --- Loss on the validation set ---
val_loss, val_batches = 0.0, 0
with torch.no_grad():
    for val_batch_data in val_data_loader:
        # Embed the batch, then accumulate its loss as a Python float.
        val_embeddings = compute_embeddings(
            model=model,
            model_type=config['model_type'],
            device=device,
            batch_data=val_batch_data,
        )
        val_loss += loss_fn(val_embeddings).item()
        val_batches += 1

# Mean loss per batch.
average_val_loss = val_loss / val_batches
print("Validation loss: ", average_val_loss)



