In [2]:
"""
Script to train neural network to get molecules embeddings
"""
from __future__ import annotations

import argparse
import os
import pickle
import json
import random
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import time

from tqdm.auto import tqdm
from paroutes import PaRoutesInventory, get_target_smiles
from embedding_model import (
    fingerprint_preprocess_input,
    gnn_preprocess_input,
    CustomDataset,
    collate_fn,
    # SampleData,
    fingerprint_vect_from_smiles,
    compute_embeddings,
    GNNModel,
    FingerprintModel,
    NTXentLoss,
    num_heavy_atoms
)
from paroutes import PaRoutesInventory
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import plotly.express as px
from rdkit import Chem
import deepchem as dc


def gnn_preprocess_input_v1(
    input_data,
    pos_sampling=None,
):
    """Gather raw (un-featurized) samples per target into a ``CustomDataset``.

    Targets and samples are stored exactly as they appear in ``input_data``;
    no featurizer is applied here.

    Args:
        input_data: Mapping from a target SMILES to a dict holding the keys
            ``"positive_samples"`` and ``"negative_samples"``.
        pos_sampling: Weighting scheme for the positives of each target.
            ``"uniform"`` stores ``None`` as the weight entry;
            ``"prop_num_atoms"`` stores a double tensor with the value of
            ``num_heavy_atoms`` for each positive, normalised to sum to 1.

    Returns:
        A ``CustomDataset`` built from four parallel lists (targets,
        positives, negatives, positive weights), one entry per target.

    Raises:
        NotImplementedError: For any other ``pos_sampling`` value (including
            the default ``None``), raised while handling the first target.
    """
    all_targets = []
    all_positives = []
    all_negatives = []
    all_weights = []

    for target_smiles, samples in tqdm(input_data.items()):
        target_positives = samples["positive_samples"]

        all_targets.append(target_smiles)
        all_positives.append(target_positives)
        all_negatives.append(samples["negative_samples"])

        # Resolve the per-target weights over the positives.
        if pos_sampling == "prop_num_atoms":
            atom_counts = torch.tensor(
                [
                    num_heavy_atoms(Chem.MolFromSmiles(smi))
                    for smi in target_positives
                ],
                dtype=torch.double,
            )
            # Normalise so the weights of one target sum to 1.
            target_weights = atom_counts / atom_counts.sum()
        elif pos_sampling == "uniform":
            target_weights = None
        else:
            raise NotImplementedError(f"{pos_sampling}")

        all_weights.append(target_weights)

    return CustomDataset(
        targets=all_targets,
        positive_samples=all_positives,
        negative_samples=all_negatives,
        pos_weights=all_weights,
    )

from embedding_model import gnn_preprocess_target_pos_negs, compute_actual_embeddings
def preprocess_and_compute_embeddings(
    device,
    model_type,
    model,
    batch_data,
    featurizer,
    featurizer_dict,
    not_in_route_sample_size=None,
):
    """Sample negatives, featurize and embed one batch of raw samples.

    For every element of the batch a random subset of its negatives is
    drawn, the target/positives/negatives are featurized with
    ``gnn_preprocess_target_pos_negs`` and then embedded with
    ``compute_actual_embeddings``.

    Because the negatives are drawn with ``random.sample`` on every call,
    two calls on the same batch generally use different negatives unless the
    ``random`` module is re-seeded in between.

    Args:
        device: Forwarded to ``compute_actual_embeddings``.
        model_type: Forwarded to ``compute_actual_embeddings``.
        model: Forwarded to ``compute_actual_embeddings``.
        batch_data: 4-tuple ``(targets, positives, negatives, pos_weights)``
            of equally long sequences, one entry per batch element.
        featurizer: Forwarded to ``gnn_preprocess_target_pos_negs``.
        featurizer_dict: Forwarded to ``gnn_preprocess_target_pos_negs``.
        not_in_route_sample_size: Number of negatives drawn per target.
            When ``None`` (default) the value is read from the module-level
            ``config["not_in_route_sample_size"]``, which is what this
            function always did before the parameter existed.

    Returns:
        A ``CustomDataset`` whose targets / positive / negative entries are
        the embeddings returned by ``compute_actual_embeddings`` and whose
        ``pos_weights`` are passed through unchanged from ``batch_data``.

    Raises:
        ValueError: From ``random.sample`` when a target has fewer negatives
            than the requested sample size.
    """
    if not_in_route_sample_size is None:
        # Backward-compatible fallback to the notebook-global configuration.
        not_in_route_sample_size = config["not_in_route_sample_size"]

    targets = []
    positive_samples = []
    negative_samples = []
    pos_weights = []

    (
        batch_targets,
        batch_positive_samples,
        batch_negative_samples,
        batch_pos_weights,
    ) = batch_data

    for target_i, positives_i, negatives_i, pos_weights_i in zip(
        batch_targets,
        batch_positive_samples,
        batch_negative_samples,
        batch_pos_weights,
    ):
        # Draw the negatives used for this target in this call.
        samples = {
            "positive_samples": positives_i,
            "negative_samples": random.sample(
                negatives_i, not_in_route_sample_size
            ),
        }

        # Featurize target, positives and sampled negatives.
        target_feats_i, pos_feats_i, neg_feats_i = gnn_preprocess_target_pos_negs(
            target_i, samples, featurizer, featurizer_dict
        )

        (
            target_embedding,
            positive_samples_embeddings,
            negative_samples_embeddings,
        ) = compute_actual_embeddings(
            device=device,
            model_type=model_type,
            model=model,
            target=target_feats_i,
            positives=pos_feats_i,
            negatives=neg_feats_i,
        )

        targets.append(target_embedding)
        positive_samples.append(positive_samples_embeddings)
        negative_samples.append(negative_samples_embeddings)
        pos_weights.append(pos_weights_i)

    # The dataset container is reused here to carry embeddings to the loss.
    return CustomDataset(
        targets=targets,
        positive_samples=positive_samples,
        negative_samples=negative_samples,
        pos_weights=pos_weights,
    )




# Load the experiment configuration (a JSON file in the working directory).
with open("config_gnn_0709_sampleInLoss.json", "r") as f:
    config = json.load(f)

experiment_name = config["experiment_name"]
checkpoint_folder = f"GraphRuns/{experiment_name}/"
# exist_ok=True replaces the former exists()-then-makedirs() pair, which
# could fail if the folder appeared between the check and the creation.
os.makedirs(checkpoint_folder, exist_ok=True)

checkpoint_name = "checkpoint.pth"

# Save a copy of the config next to the checkpoints for reproducibility.
with open(f"{checkpoint_folder}/config.json", "w") as f:
    json.dump(config, f, indent=4)

# Read routes data
input_file_routes = f'Runs/{config["run_id"]}/targ_routes.pickle'
# input_file_distances = f'Runs/{config["run_id"]}/targ_to_purch_distances.pickle'

# Routes data
# NOTE: pickle.load can execute arbitrary code; only load files produced by
# trusted runs of this project.
with open(input_file_routes, "rb") as handle:
    targ_routes_dict = pickle.load(handle)

# # Load distances data
# with open(input_file_distances, 'rb') as handle:
#     distances_dict = pickle.load(handle)

# Inventory: SMILES of every purchasable molecule.
inventory = PaRoutesInventory(n=5)
purch_smiles = [mol.smiles for mol in inventory.purchasable_mols()]

# Record the num_heavy_atoms() value of every purchasable molecule and
# collect the ones below 2, which are filtered out of the samples later
# (the featurizer rejects single-atom molecules).
purch_mol_to_exclude = []
purch_nr_heavy_atoms = {}
for smiles in purch_smiles:
    nr_heavy_atoms = num_heavy_atoms(Chem.MolFromSmiles(smiles))
    if nr_heavy_atoms < 2:
        # append() instead of rebuilding the list with `list + [smiles]`.
        purch_mol_to_exclude.append(smiles)
    purch_nr_heavy_atoms[smiles] = nr_heavy_atoms

# Pick the list of target SMILES that belongs to the configured run.
if config["run_id"] == "202305-2911-2320-5a95df0e-3008-4ebe-acd8-ecb3b50607c7":
    all_targets = get_target_smiles(n=5)
elif config["run_id"] == "Guacamol_combined":
    with open("Data/Guacamol/guacamol_v1_test_10ksample.txt", "r") as f:
        all_targets = [line.strip() for line in f]
else:
    # Fail here instead of hitting a NameError (or silently reusing a stale
    # `all_targets` from an earlier notebook run) further down.
    raise NotImplementedError(f'Run id {config["run_id"]}')

# Seed the RNG once, before the loop. Nothing inside the loop consumes
# random numbers, so this leaves the generator in the same state as the
# former per-iteration re-seeding, and also seeds it when there are no
# targets at all.
random.seed(config["seed"])

# Set view of the exclusion list for O(1) membership tests inside the loop.
_excluded_purch_smiles = set(purch_mol_to_exclude)

# For every target: positives are the non-target molecules of its first
# route ("route_1"), negatives are all other purchasable molecules.
targ_route_not_in_route_dict = {}
for target in all_targets:
    target_routes_dict = targ_routes_dict.get(target, "Target_Not_Solved")

    if target_routes_dict == "Target_Not_Solved":
        # No route stored for this target: it gets no positives.
        purch_in_route = []
    else:
        target_route_df = target_routes_dict["route_1"]
        purch_in_route = list(
            target_route_df.loc[target_route_df["label"] != "Target", "smiles"]
        )

    # Membership is tested against a set; order and duplicates of the
    # resulting lists are the same as with the original list lookups.
    _in_route_lookup = set(purch_in_route)
    purch_not_in_route = [
        purch_smile
        for purch_smile in purch_smiles
        if purch_smile not in _in_route_lookup
    ]

    # Negatives are not sub-sampled here; the full list is kept and the
    # sampling happens per batch in preprocess_and_compute_embeddings.
    # Filter out molecules with only one atom (problems with featurizer)
    purch_in_route = [
        smiles for smiles in purch_in_route if smiles not in _excluded_purch_smiles
    ]
    purch_not_in_route_sample = [
        smiles
        for smiles in purch_not_in_route
        if smiles not in _excluded_purch_smiles
    ]

    targ_route_not_in_route_dict[target] = {
        "positive_samples": purch_in_route,
        "negative_samples": purch_not_in_route_sample,
    }

# Choose which targets go into the dataset: -1 keeps every target,
# any other value draws that many targets at random.
if config["nr_sample_targets"] == -1:
    sample_targets = targ_route_not_in_route_dict
else:
    sample_targets = random.sample(
        list(targ_route_not_in_route_dict.keys()), config["nr_sample_targets"]
    )

# Restrict the samples dictionary to the chosen targets; `input_data` is a
# second name for that same dictionary.
targ_route_not_in_route_dict_sample = dict(
    (chosen, targ_route_not_in_route_dict[chosen]) for chosen in sample_targets
)
input_data = targ_route_not_in_route_dict_sample

# Build the per-molecule lookup for the configured model type, persist it in
# the checkpoint folder, and create `dataset`.
if config["model_type"] not in ("gnn", "fingerprints"):
    raise NotImplementedError(f'Model type {config["model_type"]}')

if config["model_type"] == "fingerprints":
    # Fingerprint path: one fingerprint per purchasable molecule.
    purch_fingerprints = [
        fingerprint_vect_from_smiles(purch_smile) for purch_smile in purch_smiles
    ]
    purch_fingerprints_dict = {
        purch_smile: fingerprint
        for purch_smile, fingerprint in zip(purch_smiles, purch_fingerprints)
    }
    with open(
        f"{checkpoint_folder}/purch_fingerprints_dict.pickle", "wb"
    ) as handle:
        pickle.dump(
            purch_fingerprints_dict, handle, protocol=pickle.HIGHEST_PROTOCOL
        )

    dataset = fingerprint_preprocess_input(
        input_data,
        fingerprints_dict=purch_fingerprints_dict,
        pos_sampling=config["pos_sampling"],
    )

else:
    # GNN path: featurize every purchasable molecule once up front. The
    # dataset itself keeps raw SMILES; featurization of the samples is
    # deferred to batch time.
    featurizer = dc.feat.MolGraphConvFeaturizer()

    purch_mols = list(map(Chem.MolFromSmiles, purch_smiles))
    purch_featurizer = featurizer.featurize(purch_mols)
    purch_featurizer_dict = {
        purch_smile: feats
        for purch_smile, feats in zip(purch_smiles, purch_featurizer)
    }
    with open(f"{checkpoint_folder}/purch_featurizer_dict.pickle", "wb") as handle:
        pickle.dump(purch_featurizer_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    fingerprint_num_atoms_dict = None

    dataset = gnn_preprocess_input_v1(
        input_data=input_data,
        pos_sampling=config["pos_sampling"],
    )

# if args.save_preprocessed_data:
#     with open(f"{checkpoint_folder}/preprocessed_targets.pickle", "wb") as handle:
#         pickle.dump(preprocessed_targets, handle, protocol=pickle.HIGHEST_PROTOCOL)
#     with open(
#         f"{checkpoint_folder}/preprocessed_positive_samples.pickle", "wb"
#     ) as handle:
#         pickle.dump(
#             preprocessed_positive_samples, handle, protocol=pickle.HIGHEST_PROTOCOL
#         )
#     with open(
#         f"{checkpoint_folder}/preprocessed_negative_samples.pickle", "wb"
#     ) as handle:
#         pickle.dump(
#             preprocessed_negative_samples, handle, protocol=pickle.HIGHEST_PROTOCOL
#         )
# else:
#     with open(f"{checkpoint_folder}/preprocessed_targets.pickle", "rb") as handle:
#         preprocessed_targets = pickle.load(handle)
#     with open(f"{checkpoint_folder}/preprocessed_positive_samples.pickle", "rb") as handle:
#         preprocessed_positive_samples = pickle.load(handle)
#     with open(f"{checkpoint_folder}/preprocessed_negative_samples.pickle", "rb") as handle:
#         preprocessed_negative_samples = pickle.load(handle)
#     if config["model_type"] == "fingerprints":
#         with open(
#             f"{checkpoint_folder}/fingerprint_num_atoms_dict.pickle", "rb"
#         ) as handle:
#             fingerprint_num_atoms_dict = pickle.load(handle)
#     else:
#         fingerprint_num_atoms_dict = None

#     dataset = CustomDataset(
#             preprocessed_targets,
#             preprocessed_positive_samples,
#             preprocessed_negative_samples,
#         )

# Train validation split
# The validation size is converted to an absolute number of samples (int)
# before being handed to train_test_split.
validation_ratio = config["validation_ratio"]
num_samples = len(dataset)
num_val_samples = int(validation_ratio * num_samples)

# The split is done on indices; random_state is fixed to 42 here, i.e. it
# does not follow config["seed"].
train_indices, val_indices = train_test_split(
    range(num_samples), test_size=num_val_samples, random_state=42
)

# Index-based views on the single underlying dataset.
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Batch size and shuffling are configured independently for the two loaders;
# both use the project's collate_fn.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=config["train_batch_size"],
    shuffle=config["train_shuffle"],
    collate_fn=collate_fn,
)
val_data_loader = DataLoader(
    val_dataset,
    batch_size=config["val_batch_size"],
    shuffle=config["val_shuffle"],
    collate_fn=collate_fn,
)

# Batch size: The batch size determines the number of samples processed in each iteration during training or validation. In most cases, it is common to use the same batch size for both training and validation to maintain consistency. However, there are situations where you might choose a different batch size for validation. For instance, if memory constraints are more relaxed during validation, you can use a larger batch size to speed up evaluation.
# Shuffle training data: Shuffling the training data before each epoch is beneficial because it helps the model see the data in different orders, reducing the risk of the model learning patterns specific to the order of the data. Shuffling the training data introduces randomness and promotes better generalization.
# No shuffle for validation data: It is generally not necessary to shuffle the validation data because validation is meant to evaluate the model's performance on unseen data that is representative of the real-world scenarios. Shuffling the validation data could lead to inconsistent evaluation results between different validation iterations, making it harder to track the model's progress and compare performance.



Failed to featurize datapoint 1585, Cl. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 3118, F. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 3161, [S-2]. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 4706, [Mg]. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 8338, N. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 10127, O. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize

  0%|          | 0/10000 [00:00<?, ?it/s]

AttributeError: 'str' object has no attribute 'node_features'

In [13]:
# Define network dimensions
# The input dimension is also pickled to the checkpoint folder so that a
# model with the same shape can be rebuilt later without the dataset.
if config["model_type"] == "gnn":
    # Hard-coded: the dataset built by gnn_preprocess_input_v1 stores SMILES
    # strings as targets, so the former lookup
    # dataset.targets[0].node_features.shape[1] no longer works.
    # NOTE(review): 30 is presumably the node-feature width produced by the
    # featurizer — TODO confirm.
    gnn_input_dim = 30 # dataset.targets[0].node_features.shape[1]
    gnn_hidden_dim = config["hidden_dim"]
    gnn_output_dim = config["output_dim"]

    with open(f"{checkpoint_folder}/input_dim.pickle", "wb") as f:
        pickle.dump({"input_dim": gnn_input_dim}, f)

elif config["model_type"] == "fingerprints":
    #     fingerprint_input_dim = preprocessed_targets[0].GetNumBits()
    # Input width is read from the first target of the dataset.
    fingerprint_input_dim = dataset.targets[0].size()[
        0
    ]  # len(preprocessed_targets[0].node_features)
    fingerprint_hidden_dim = config["hidden_dim"]
    fingerprint_output_dim = config["output_dim"]

    with open(f"{checkpoint_folder}/input_dim.pickle", "wb") as f:
        pickle.dump({"input_dim": fingerprint_input_dim}, f)

else:
    raise NotImplementedError(f'Model type {config["model_type"]}')

# Step 3: Set up the training loop for the GNN model
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if config["model_type"] == "gnn":
    model = GNNModel(
        input_dim=gnn_input_dim,
        hidden_dim=gnn_hidden_dim,
        output_dim=gnn_output_dim,
    ).to(device)
    # Only the GNN model is switched to double precision here.
    model.double()

elif config["model_type"] == "fingerprints":
    model = FingerprintModel(
        input_dim=fingerprint_input_dim,
        hidden_dim=fingerprint_hidden_dim,
        output_dim=fingerprint_output_dim,
    ).to(device)
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')

# Loss temperature and learning rate come from the config.
loss_fn = NTXentLoss(temperature=config["temperature"], device=device)
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

num_epochs = config["num_epochs"]

# Set to True to resume training. In that case the two names below must be
# defined first (they are referenced by the resume branch).
load_from_checkpoint = False
# input_checkpoint_folder  = 'GraphRuns/gnn_0629'
# input_checkpoint_path = f'{checkpoint_folder}/epoch_71_checkpoint.pth'

# STEP 5: Train loop
# Check if a checkpoint exists and load the model state and optimizer state if available
if load_from_checkpoint:
    # map_location lets a checkpoint written on a GPU machine be restored on
    # whatever device this session actually uses (e.g. CPU only).
    checkpoint = torch.load(input_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # The checkpoint stores the last finished epoch; continue with the next.
    start_epoch = checkpoint["epoch"] + 1

    # Restore the loss history and the best model seen so far.
    epoch_loss = pd.read_csv(f'{input_checkpoint_folder}/train_val_loss.csv')
    best_val_loss = epoch_loss["ValLoss"].min()
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(f'{input_checkpoint_folder}/model_min_val.pkl', "rb") as handle:
        best_model = pickle.load(handle)
else:
    # Fresh run: empty history, no best model yet.
    start_epoch = 0
    best_val_loss = float("inf")
    best_model = None
    epoch_loss = pd.DataFrame(columns=["Epoch", "TrainLoss", "ValLoss"])

import copy  # cell-local import: used to snapshot the best model below

# Create a SummaryWriter for TensorBoard logging
log_dir = (
    f"{checkpoint_folder}/logs"  # Specify the directory to store TensorBoard logs
)
writer = SummaryWriter(log_dir)

# NOTE(review): the loop passes `featurizer` and `purch_featurizer_dict`,
# which the dataset-building cell only defines for model_type "gnn".
for epoch in tqdm(range(start_epoch, num_epochs)):
    # TRAIN
    model.train()
    train_loss = 0.0
    train_batches = 0

    for batch_idx, batch_data in enumerate(train_data_loader):
        optimizer.zero_grad()

        # Sample negatives, featurize and embed the raw batch
        embeddings_dataset = preprocess_and_compute_embeddings(
            device=device,
            model_type=config['model_type'],
            model=model,
            batch_data=batch_data,
            featurizer=featurizer,
            featurizer_dict=purch_featurizer_dict,
        )
        # Compute loss
        loss = loss_fn(embeddings_dataset)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Track total loss
        train_loss += loss.item()
        train_batches += 1

    # VALIDATION
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    val_batches = 0
    with torch.no_grad():  # Disable gradient calculation during validation
        for val_batch_idx, val_batch_data in enumerate(val_data_loader):
            # Negatives are re-sampled on every call, so the validation loss
            # of a fixed model still varies slightly between evaluations.
            val_embeddings = preprocess_and_compute_embeddings(
                device=device,
                model_type=config['model_type'],
                model=model,
                batch_data=val_batch_data,
                featurizer=featurizer,
                featurizer_dict=purch_featurizer_dict,
            )
            # Compute loss
            val_batch_loss = loss_fn(val_embeddings)

            val_loss += val_batch_loss.item()
            val_batches += 1

    # METRICS
    # - TRAIN: average loss over the batches of this epoch
    average_train_loss = train_loss / train_batches
    writer.add_scalar("Loss/train", average_train_loss, epoch + 1)

    # - VALIDATION
    average_val_loss = val_loss / val_batches
    writer.add_scalar("Loss/val", average_val_loss, epoch + 1)

    new_row = pd.DataFrame(
        {
            "Epoch": [epoch],
            "TrainLoss": [average_train_loss],
            "ValLoss": [average_val_loss],
        }
    )
    # ignore_index keeps a clean 0..n-1 index instead of repeating label 0.
    epoch_loss = pd.concat([epoch_loss, new_row], axis=0, ignore_index=True)

    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        # Snapshot the weights. Storing `model` itself would only alias the
        # live network, so the "best" model would silently keep training and
        # always equal the latest epoch.
        best_model = copy.deepcopy(model)

    # `epoch % 1 == 0` is always true, i.e. a checkpoint is written after
    # every epoch; raise the modulus to checkpoint less often.
    if (epoch % 1 == 0) or (epoch == num_epochs - 1):
        print(
            f"{config['model_type']} Model - Epoch {epoch+1}/{num_epochs}, TrainLoss: {average_train_loss}, ValLoss: {average_val_loss}"
        )

        # Save the model and optimizer state as a checkpoint
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        checkpoint_path = f"{checkpoint_folder}/epoch_{epoch+1}_{checkpoint_name}"  # Specify the checkpoint file path
        torch.save(checkpoint, checkpoint_path)

        # Persist the loss history collected so far
        epoch_loss.to_csv(f"{checkpoint_folder}/train_val_loss.csv", index=False)

        # Save the best model (lowest validation loss so far) as a pickle
        best_model_path = (
            f"{checkpoint_folder}/model_min_val.pkl"  #'path/to/best_model.pkl'
        )

        with open(best_model_path, "wb") as f:
            pickle.dump(best_model, f)



# Close the SummaryWriter
writer.close()



# STEP 6: Plot
# Draw the per-epoch train and validation losses collected in `epoch_loss`
# and export the figure to the checkpoint folder.

# fig = px.line(x=epoch_loss['Epoch'], y=epoch_loss['TrainLoss'], title="Train loss")
# fig.update_layout(width=1000, height=600, showlegend=False)
# fig.write_image(f"{checkpoint_folder}/Train_loss.pdf")
# fig.show()

# Create a new (empty) figure; the two loss curves are added below
fig = px.line()

# Add the TrainLoss line to the figure
fig.add_scatter(x=epoch_loss["Epoch"], y=epoch_loss["TrainLoss"], name="Train Loss")

# Add the ValLoss line to the figure
fig.add_scatter(
    x=epoch_loss["Epoch"], y=epoch_loss["ValLoss"], name="Validation Loss"
)

# Set the title of the figure
fig.update_layout(title="Train and Validation Loss")

# Set the layout size and show the legend
fig.update_layout(width=1000, height=600, showlegend=True)

# Save the figure as a PDF file
# The same file is written twice with a 10 s pause in between.
# NOTE(review): presumably a workaround for an artefact in the first PDF
# export — TODO confirm, and drop the sleep if it is not needed.
fig.write_image(f"{checkpoint_folder}/Train_and_Val_loss.pdf")
time.sleep(10)
fig.write_image(f"{checkpoint_folder}/Train_and_Val_loss.pdf")



  0%|          | 0/5 [00:00<?, ?it/s]

gnn Model - Epoch 1/5, TrainLoss: 3.904505304336548, ValLoss: 3.3473444059491158
gnn Model - Epoch 2/5, TrainLoss: 3.1036663646697997, ValLoss: 3.030165858566761
gnn Model - Epoch 3/5, TrainLoss: 2.8874917736053467, ValLoss: 2.8653238862752914
gnn Model - Epoch 4/5, TrainLoss: 2.7190915870666506, ValLoss: 2.6334672793745995
gnn Model - Epoch 5/5, TrainLoss: 2.6033140258789063, ValLoss: 2.571866288781166


In [None]:
# Re-evaluate the in-memory `model` on the validation loader after training
# and print the average loss over the validation batches.
# NOTE(review): this calls compute_embeddings on each batch, whereas the
# validation step inside the training loop calls
# preprocess_and_compute_embeddings (which samples negatives and featurizes
# first) — confirm compute_embeddings accepts the raw batches of
# val_data_loader.
val_loss_post = 0.0
val_batches_post = 0
with torch.no_grad():  # Disable gradient calculation during validation
    for val_batch_idx, val_batch_data in tqdm(enumerate(val_data_loader)):
        # Compute embeddings
        val_embeddings_post = compute_embeddings(
            device=device,
            model_type=config['model_type'],
            model=model,
            batch_data=val_batch_data,
        )
        # Compute loss
        val_batch_loss_post = loss_fn(val_embeddings_post)

        val_loss_post += val_batch_loss_post.item()
        val_batches_post += 1

    # Compute average loss
    average_val_loss_post = val_loss_post / val_batches_post

print("Validation loss: ", average_val_loss_post)

# Should be the same



In [None]:
# MODEL TO EVALUATE
# Run folder and checkpoint file to evaluate. The checkpoint name is fixed to
# the file written after epoch 5 and does not follow config["num_epochs"].
experiment_name = config["experiment_name"]  # gnn_0627
checkpoint_folder = f"GraphRuns/{experiment_name}/"
input_checkpoint_name = f"epoch_5_checkpoint.pth"



In [None]:
# 0. LOAD MODEL
# Option 1: FROM CHECKPOINT
# Rebuild a model with the stored input dimension and the configured
# hidden/output sizes, then restore its weights from the checkpoint file.
input_checkpoint_folder  = f'GraphRuns/{experiment_name}'
input_checkpoint_path = f'{input_checkpoint_folder}/{input_checkpoint_name}'


# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(f'{checkpoint_folder}/input_dim.pickle', "rb") as f:
    input_dim = pickle.load(f)['input_dim']

# Define network dimensions
if config["model_type"] == "gnn":
    gnn_input_dim = input_dim
    gnn_hidden_dim = config["hidden_dim"]
    gnn_output_dim = config["output_dim"]

elif config["model_type"] == "fingerprints":
    fingerprint_input_dim = input_dim
    fingerprint_hidden_dim = config["hidden_dim"]
    fingerprint_output_dim = config["output_dim"]

else:
    raise NotImplementedError(f'Model type {config["model_type"]}')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if config["model_type"] == "gnn":
    model_loaded = GNNModel(
        input_dim=gnn_input_dim,
        hidden_dim=gnn_hidden_dim,
        output_dim=gnn_output_dim,
    ).to(device)
    # Same precision as the model built for training.
    model_loaded.double()

elif config["model_type"] == "fingerprints":
    model_loaded = FingerprintModel(
        input_dim=fingerprint_input_dim,
        hidden_dim=fingerprint_hidden_dim,
        output_dim=fingerprint_output_dim,
    ).to(device)
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')

loss_fn_loaded = NTXentLoss(temperature=config["temperature"], device=device)

# map_location makes a checkpoint saved on a GPU loadable on a CPU-only host.
checkpoint = torch.load(input_checkpoint_path, map_location=device)
model_loaded.load_state_dict(checkpoint["model_state_dict"])
# A freshly constructed module is in training mode. Switch to evaluation
# mode so the comparison against `model` (left in eval mode by the training
# loop) is like-for-like.
model_loaded.eval()

# # # OPTION 2: From pickle
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# loss_fn = NTXentLoss(temperature=config["temperature"], device=device)
# with open(f'{checkpoint_folder}/model_min_val.pkl', "rb") as f:
#     model = pickle.load(f)




In [None]:
# Evaluate the model restored from the checkpoint (`model_loaded`) on the
# original validation loader and print the average loss over its batches.
# NOTE(review): as in the post-training check, compute_embeddings is used
# here instead of preprocess_and_compute_embeddings — confirm it accepts the
# raw batches of val_data_loader.
val_loss1 = 0.0
val_batches1 = 0
with torch.no_grad():  # Disable gradient calculation during validation
    for val_batch_idx, val_batch_data in tqdm(enumerate(val_data_loader)):
        # Compute embeddings
        val_embeddings1_loaded = compute_embeddings(
            device=device,
            model_type=config['model_type'],
            model=model_loaded,
            batch_data=val_batch_data,
        )
        # Compute loss
        val_batch_loss1 = loss_fn_loaded(val_embeddings1_loaded)

        val_loss1 += val_batch_loss1.item()
        val_batches1 += 1

    # Compute average loss
    average_val_loss1 = val_loss1 / val_batches1

print("Validation loss: ", average_val_loss1)

# If it is not the same, there is an issue in the save
# If it is the same, there is an issue in the preprocessing 
# --> Try applying model and model_loaded (on the new preprocessed data)


In [None]:
# 1. PREPROCESS DATA AGAIN (Only gnn_preprocess_input)
# Rebuild the dataset from scratch with fresh "_1" objects so that the result
# does not depend on anything created by the first preprocessing pass.
# Note: dict.copy() is shallow; the per-target sample dicts are shared with
# `input_data`.
input_data_1 = input_data.copy()

if config["model_type"] == "gnn":
    featurizer_1 = dc.feat.MolGraphConvFeaturizer()

    purch_mols_1 = [Chem.MolFromSmiles(smiles) for smiles in purch_smiles]
    # Use the freshly created featurizer_1 (this line previously reused the
    # `featurizer` object from the first pass).
    purch_featurizer_1 = featurizer_1.featurize(purch_mols_1)
    purch_featurizer_dict_1 = dict(zip(purch_smiles, purch_featurizer_1))

    dataset_1 = gnn_preprocess_input(
        input_data=input_data_1,
        featurizer=featurizer_1,
        featurizer_dict=purch_featurizer_dict_1,
        pos_sampling=config["pos_sampling"],
    )

# Only the "gnn" path is re-preprocessed here; any other model type
# (including "fingerprints") ends up in the branch below.
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')


# 2. TRAIN VALIDATION SPLIT
# Same procedure and fixed random_state as the original split.
validation_ratio = config["validation_ratio"]
num_samples_1 = len(dataset_1)
num_val_samples_1 = int(validation_ratio * num_samples_1)

train_indices_1, val_indices_1 = train_test_split(
    range(num_samples_1), test_size=num_val_samples_1, random_state=42
)

train_dataset_1 = Subset(dataset_1, train_indices_1)
val_dataset_1 = Subset(dataset_1, val_indices_1)

train_data_loader_1 = DataLoader(
    train_dataset_1,
    batch_size=config["train_batch_size"],
    shuffle=config["train_shuffle"],
    collate_fn=collate_fn,
)
val_data_loader_1 = DataLoader(
    val_dataset_1,
    batch_size=config["val_batch_size"],
    shuffle=config["val_shuffle"],
    collate_fn=collate_fn,
)

# Batch size: The batch size determines the number of samples processed in each iteration during training or validation. In most cases, it is common to use the same batch size for both training and validation to maintain consistency. However, there are situations where you might choose a different batch size for validation. For instance, if memory constraints are more relaxed during validation, you can use a larger batch size to speed up evaluation.
# Shuffle training data: Shuffling the training data before each epoch is beneficial because it helps the model see the data in different orders, reducing the risk of the model learning patterns specific to the order of the data. Shuffling the training data introduces randomness and promotes better generalization.
# No shuffle for validation data: It is generally not necessary to shuffle the validation data because validation is meant to evaluate the model's performance on unseen data that is representative of the real-world scenarios. Shuffling the validation data could lead to inconsistent evaluation results between different validation iterations, making it harder to track the model's progress and compare performance.








In [None]:
# Evaluate on the validation loader built from the re-preprocessed dataset
# (val_data_loader_1) and print the average loss over its batches.
# Note that this cell embeds with the in-memory `model` (not `model_loaded`)
# while computing the loss with `loss_fn_loaded`.
val_loss_new_preprocess = 0.0
val_batches_new_preprocess = 0
with torch.no_grad():  # Disable gradient calculation during validation
    for val_batch_idx, val_batch_data in tqdm(enumerate(val_data_loader_1)):
        # Compute embeddings
        val_embeddings_new_preprocess = compute_embeddings(
            device=device,
            model_type=config['model_type'],
            model=model,
            batch_data=val_batch_data,
        )
        # Compute loss
        val_batch_loss_new_preprocess = loss_fn_loaded(val_embeddings_new_preprocess)

        val_loss_new_preprocess += val_batch_loss_new_preprocess.item()
        val_batches_new_preprocess += 1

    # Compute average loss
    average_val_loss_new_preprocess = val_loss_new_preprocess / val_batches_new_preprocess

print("Validation loss: ", average_val_loss_new_preprocess)

# If it is not the same, there is an issue in the save
# If it is the same, there is an issue in the preprocessing 
# --> Try applying model and model_loaded (on the new preprocessed data)

