In [None]:
import os
import pickle
import random

import deepchem as dc
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing
# from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader
# from torch.utils.data import DataLoader
from torch_geometric.utils import to_dense_batch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.utils.data import Dataset

import numpy as np
from rdkit import Chem

import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from rdkit.Chem import AllChem

import pandas as pd
import json

from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

In [2]:
# Read json.config
with open('config_fnp_nn_0630.json', 'r') as f:
    config = json.load(f)
    


In [3]:
print(config)

{'experiment_name': 'gnn_0629', 'not_in_route_sample_size': 100, 'seed': 42, 'run_id': '202305-2911-2320-5a95df0e-3008-4ebe-acd8-ecb3b50607c7', 'nr_sample_targets': -1, 'model_type': 'gnn', 'validation_ratio': 0.2, 'train_batch_size': 64, 'train_shuffle': True, 'val_batch_size': 64, 'val_shuffle': False, 'pos_sampling': 'uniform', 'neg_sampling': 'uniform', 'hidden_dim': 512, 'output_dim': 256, 'temperature': 0.1, 'lr': 0.001, 'num_epochs': 100}


In [4]:
experiment_name = f"{config['experiment_name']}"

checkpoint_folder = f'GraphRuns/{experiment_name}/'
if not os.path.exists(checkpoint_folder):
    os.makedirs(checkpoint_folder)

checkpoint_name = 'checkpoint.pth'


In [5]:
experiment_name

'gnn_0629'

In [6]:
# Save 
with open(f'{checkpoint_folder}/config.json', 'w') as f:
    json.dump(config, f, indent=4)

In [7]:
save_preprocessed_data = True

### Read routes data - consider only route 1

In [8]:
input_file_routes = f'Runs/{config["run_id"]}/targ_routes.pickle'
# input_file_distances = f'Runs/{config["run_id"]}/targ_to_purch_distances.pickle'

# Routes data
with open(input_file_routes, 'rb') as handle:
    targ_routes_dict = pickle.load(handle)
    
# # Load distances data
# with open(input_file_distances, 'rb') as handle:
#     distances_dict = pickle.load(handle)


# Inventory
from paroutes import PaRoutesInventory, get_target_smiles

inventory=PaRoutesInventory(n=5)
purch_smiles = [mol.smiles for mol in inventory.purchasable_mols()]
len(purch_smiles)

def num_heavy_atoms(mol):
    return Chem.rdchem.Mol.GetNumAtoms(mol, onlyExplicit=True)

purch_mol_to_exclude = []
purch_nr_heavy_atoms = {}
for smiles in purch_smiles:
    nr_heavy_atoms = num_heavy_atoms(Chem.MolFromSmiles(smiles))
    if nr_heavy_atoms < 2:
        purch_mol_to_exclude = purch_mol_to_exclude + [smiles]
    purch_nr_heavy_atoms[smiles] = nr_heavy_atoms 

    
if config["run_id"]=="202305-2911-2320-5a95df0e-3008-4ebe-acd8-ecb3b50607c7":
    all_targets = get_target_smiles(n=5)
elif config["run_id"]=='Guacamol_combined':
     with open('Data/Guacamol/guacamol_v1_test_10ksample.txt', "r") as f:
        all_targets = [line.strip() for line in f.readlines()]  

    
targ_route_not_in_route_dict = {}
for target in all_targets:
    
    targ_route_not_in_route_dict[target] = {}
    
    target_routes_dict = targ_routes_dict.get(target, 'Target_Not_Solved')
    
    if target_routes_dict=='Target_Not_Solved':
        purch_in_route = []
    else:
        target_route_df = target_routes_dict["route_1"]
        purch_in_route = list(target_route_df.loc[target_route_df['label']!='Target', 'smiles'])
#         purch_in_route = [smiles for smiles in purch_in_route if smiles in purch_smiles]
    purch_not_in_route = [purch_smile for purch_smile in purch_smiles if purch_smile not in purch_in_route]
    random.seed(config["seed"])
    
    if config["neg_sampling"] == "uniform":
        purch_not_in_route_sample = random.sample(purch_not_in_route, config["not_in_route_sample_size"])
    elif config["neg_sampling"] == "...":
        pass
    else:
        raise NotImplementedError(f'{config["neg_sampling"]}')
    
    # Filter out molecules with only one atom (problems with featurizer)
    purch_in_route = [smiles for smiles in purch_in_route if smiles not in purch_mol_to_exclude]
    purch_not_in_route_sample = [smiles for smiles in purch_not_in_route_sample if smiles not in purch_mol_to_exclude]
    
    targ_route_not_in_route_dict[target]['positive_samples'] = purch_in_route
    targ_route_not_in_route_dict[target]['negative_samples'] = purch_not_in_route_sample




In [11]:
purch_mol_to_exclude

['Br', 'I', 'Cl', '[S-2]', 'N', 'S', 'F', 'O', '[Mg]']

In [12]:
# targ_route_not_in_route_dict['COc1ccc(F)c(-c2ccc(COc3cc(C(CC(=O)O)C4CC4)ccc3F)nc2CC(C)(C)C)c1'] 
# No positives

# targ_route_not_in_route_dict['CCc1cc2c(Br)cnnc2cc1OC']
# Si positives

{'positive_samples': ['O=[N+]([O-])O',
  'CC(=O)Cl',
  'CCc1ccccc1OC',
  'O=P(Br)(Br)Br',
  'O=NO'],
 'negative_samples': ['COC(=O)c1ccc(Br)c(C)c1',
  'CC(=O)Nc1ccc(B(O)O)cc1',
  'C[C@H](Cc1c[nH]c2ccccc12)NCC(C)(C)F',
  'NC(=O)Nc1ccc(S(=O)(=O)Cl)cc1',
  'COC(=O)c1ccc(-c2cc(OC)ncc2F)c(C2(C=O)CCCC2)c1',
  'CC(C)(C)[SiH2]OC(C)(C)c1cccc(CO)n1',
  'CCCC(C)(C#N)NC(=O)OC(C)(C)C',
  'C[C@H]1NC(=O)c2cc(B3OC(C)(C)C(C)(C)O3)[nH]c21',
  'CC(C)(C)OC(=O)Nc1ccc(F)c([C@]2(C)Cn3c(nc(Cl)c3C#N)C(N)=N2)c1',
  'COC(=O)C(CO)(CO)c1cc(-c2nc3cc(C(=N)N)ccc3[nH]2)c(O)c(-c2cc(C#N)ccc2O)c1',
  'O=C1c2ccccc2C(=O)N1[C@@H]1CCC[C@@H]1c1ccccc1',
  'C[C@@H](c1cc2cccc(Cl)c2nc1C(=O)O)N1C(=O)c2ccccc2C1=O',
  'C#CC',
  'OCC(CO)CO',
  'O=C(CBr)c1cccc(Cl)c1',
  'Cc1nn2c(Cl)cc(Cl)nc2c1Cc1cccc(C(F)(F)F)c1C',
  'COc1ccc(Oc2ccccc2)c(N)c1',
  'COC(=O)c1ccc2nc(C)c(-c3ccccc3)n2c1',
  'Brc1cnc2c(c1)OCCN2',
  'COC(=O)C=Cc1ccc(CO)cc1C',
  'OCc1cc(Cl)c(Cl)c(Cl)c1',
  'COC(=O)[C@@H]1C[C@H](NC2CCC(C)(C)CC2)CN1C(=O)OC(C)(C)C',
  'Nc1ccc(SC

In [13]:
# list(targ_route_not_in_route_dict.keys())[6]

'CCc1cc2c(Br)cnnc2cc1OC'

In [15]:
len(targ_routes_dict.keys())

9895

In [16]:
len(targ_route_not_in_route_dict.keys())



10000

### Temp - select a sample of targets

In [17]:
# Get a random sample of keys from targ_routes_dict
if config["nr_sample_targets"]!= -1:
    sample_targets = random.sample(list(targ_route_not_in_route_dict.keys()), config["nr_sample_targets"])
else:
    sample_targets = targ_route_not_in_route_dict


# Create targ_routes_dict_sample with the sampled keys and their corresponding values
targ_route_not_in_route_dict_sample = {target: targ_route_not_in_route_dict[target] for target in sample_targets}


### Run

In [18]:
def gnn_preprocess_input(input_data, featurizer, purch_featurizer_dict):
#     featurizer = dc.feat.MolGraphConvFeaturizer()
    targets = []
    positive_samples = []
    negative_samples = []

    for target_smiles, samples in tqdm(input_data.items()):
#         try:
        target_feats = featurizer.featurize(Chem.MolFromSmiles(target_smiles))
        pos_feats = [purch_featurizer_dict[positive_smiles] for positive_smiles in samples['positive_samples']]
        neg_feats = [purch_featurizer_dict[negative_smiles] for negative_smiles in samples['negative_samples']]
#         pos_feats = featurizer.featurize(pos_mols)
#         neg_feats = featurizer.featurize(neg_mols)

        targets.append(target_feats[0])
        positive_samples.append(pos_feats)
        negative_samples.append(neg_feats)
#             targets_torch = torch.tensor(target_feats, dtype=torch.double)
#             positive_samples = torch.tensor(pos_feats, dtype=torch.double)
#             negative_samples = torch.tensor(neg_feats, dtype=torch.double)
#             targets.append(targets_torch)
#             positive_samples.append(positive_samples)
#             negative_samples.append(negative_samples)
            
#         except:
#             # Handle the case where featurization fails for a sample
#             print(f"Featurization failed for sample: {target_smiles}")
        
#     targets_tensor = torch.stack(targets)
#     positive_samples_tensor = torch.stack(positive_samples)
#     negative_samples_tensor = torch.stack(negative_samples)
#     return targets_tensor, positive_samples_tensor, negative_samples_tensor
    return targets, positive_samples, negative_samples


def fingerprint_vect_from_smiles(mol_smiles):
    return AllChem.GetMorganFingerprintAsBitVect(AllChem.MolFromSmiles(mol_smiles), radius=3)

def fingerprint_preprocess_input(input_data, purch_fingerprints_dict):
    targets = []
    positive_samples = []
    negative_samples = []

    for target_smiles, samples in tqdm(input_data.items()):
#         target_feats = fingerprint_from_smiles(Chem.MolFromSmiles(target_smiles))
#         pos_mols = [Chem.MolFromSmiles(positive_smiles) for positive_smiles in samples['positive_samples']]
#         neg_mols = [Chem.MolFromSmiles(negative_smiles) for negative_smiles in samples['negative_samples']]
        target_feats = fingerprint_vect_from_smiles(target_smiles)
#         pos_feats = list(map(fingerprint_vect_from_smiles, samples['positive_samples']))
#         neg_feats = list(map(fingerprint_vect_from_smiles, samples['negative_samples']))
        pos_feats = [purch_fingerprints_dict[positive_smiles] for positive_smiles in samples['positive_samples']]
        neg_feats = [purch_fingerprints_dict[negative_smiles] for negative_smiles in samples['negative_samples']]
        
#         targets.append(target_feats[0])
#         positive_samples.append(pos_feats)
#         negative_samples.append(neg_feats)
        targets.append(torch.tensor(target_feats, dtype=torch.double))
        positive_samples.append(torch.tensor(pos_feats, dtype=torch.double))
        negative_samples.append(torch.tensor(neg_feats, dtype=torch.double))
        

    return targets, positive_samples, negative_samples



class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, targets, positive_samples, negative_samples):
        self.targets = targets
        self.positive_samples = positive_samples
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        target = self.targets[idx]
        positive = self.positive_samples[idx]
        negative = self.negative_samples[idx]

        return target, positive, negative


# def collate_fn(data):
#     print("Using collate_fn")
#     targets = []
#     positive_samples = []
#     negative_samples = []

#     for target, positive, negative in data:
#         targets.append(target)
#         positive_samples.extend(positive)
#         negative_samples.extend(negative)

#     targets = torch.stack(targets, dim=0)
#     positive_samples = torch.stack(positive_samples, dim=0)
#     negative_samples = torch.stack(negative_samples, dim=0)

#     return targets, positive_samples, negative_samples

def collate_fn(data):
    targets, positive_samples, negative_samples = zip(*data)

    return targets, positive_samples, negative_samples


In [72]:

class SampleData:
    def __init__(self, target, positive_samples, negative_samples, pos_weights):
        self.target = target
        self.positive_samples = positive_samples
        self.negative_samples = negative_samples
        self.pos_weights = pos_weights


# def preprocess_input(input_data):
#     featurizer = dc.feat.MolGraphConvFeaturizer()
#     data_list = []
    
#     for target_smiles, samples in tqdm(input_data.items()):
#         target_feats = featurizer.featurize(Chem.MolFromSmiles(target_smiles))
#         pos_mols = [Chem.MolFromSmiles(positive_smiles) for positive_smiles in samples['positive_samples']]
#         neg_mols = [Chem.MolFromSmiles(negative_smiles) for negative_smiles in samples['negative_samples']]
#         pos_feats = featurizer.featurize(pos_mols)
#         neg_feats = featurizer.featurize(neg_mols)
#         data_list = data_list + [SampleData(target=target_feats, positive_samples=pos_feats, negative_samples=neg_feats)]
#     return data_list
        
# class CustomDataset(Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         sample = self.data[idx]
# #         return sample

# #         # Extract the target_smiles, positive_sample, and negative_samples
# # #         target_smiles = sample['target_smiles']
# # #         positive_sample = sample['positive_sample']
# # #         negative_samples = sample['negative_samples']
#         target = sample.target
#         positive_samples = sample.positive_samples
#         negative_samples = sample.negative_samples

#         # Convert the data to tensors or any other necessary preprocessing

#         # Return the sample with named attributes
# #         return {
# #             'target': target,
# #             'positive_samples': positive_samples,
# #             'negative_samples': negative_samples
# #         }
#         return target, positive_samples, negative_samples

In [67]:
# Step 2: Define embedding model
class FullyConnectedModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FullyConnectedModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def custom_global_max_pool(x):
    return torch.max(x, dim=0)[0]

class GNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
#     def __init__(self, hidden_dim, output_dim):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # Global max pooling (from node level to graph level embeddings)
#         x = global_max_pool(x) #, edge_index[0]
        x = custom_global_max_pool(x)

        x = self.fc(x)
        return x

class FingerprintModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FingerprintModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, dtype=torch.double)
        self.fc2 = nn.Linear(hidden_dim, output_dim, dtype=torch.double)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

# from scipy.special import logsumexp

# Step 3: Create a contrastive learning loss function
class NTXentLoss(nn.Module):
    def __init__(self, temperature):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.cos_sim = nn.CosineSimilarity(dim=-1)

    def forward(self, embeddings):
        sample_losses = []
        for single_sample_embeddings in embeddings:
            target_emb = single_sample_embeddings.target
            positive_embs = single_sample_embeddings.positive_samples
            negative_embs = single_sample_embeddings.negative_samples
            pos_weights = single_sample_embeddings.pos_weights
            
            
            # Positive similarity
            nr_positives = positive_embs.size(0)
            if nr_positives == 0:
                positive_similarity = torch.tensor(0.0)
            else:
                # Sample one positive
                if pos_weights is None:
                    # Randomly select a positive sample
                    row_index = torch.randint(0, nr_positives, (1,))
                    positive_emb = torch.index_select(positive_embs, dim=0, index=row_index)

                else:
                    assert len(pos_weights) == nr_positives, f'len pos_weight {len(pos_weights)} different from nr_positives{nr_positives} '            
                    row_index = torch.multinomial(pos_weights, 1)
                    positive_emb = torch.index_select(positive_embs, dim=0, index=row_index)
            
                positive_similarity = self.cos_sim(target_emb, positive_emb)
                positive_similarity /= self.temperature           
            
            # Negative similarity
            negative_similarity = self.cos_sim(target_emb, negative_embs)
            negative_similarity /= self.temperature
            
            # Old implementation
            numerator = torch.exp(positive_similarity)
            denominator = torch.sum(torch.exp(negative_similarity))
            sample_loss = -torch.log(numerator / (numerator + denominator))
            # End Old implementation
#             # New implementation
#             all_similarities = torch.cat([positive_similarity, negative_similarity], dim=0)
#             sample_loss = -positive_similarity + torch.logsumexp(all_similarities, dim=0, keepdims=True)
#             # End New implementation
            
            sample_losses = sample_losses + [sample_loss]
        
        return sum(sample_losses) / len(sample_losses)



# import torch
# from pytorch_metric_learning.losses import GenericPairLoss

# class NTXentLoss(GenericPairLoss):

#     def __init__(self, temperature, **kwargs):
#         super().__init__(use_similarity=True, mat_based_loss=False, **kwargs)
#         self.temperature = temperature

#     def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
#         a1, p, a2, _ = indices_tuple

#         if len(a1) > 0 and len(a2) > 0:
#             pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
#             neg_pairs = neg_pairs / self.temperature
#             n_per_p = (a2.unsqueeze(0) == a1.unsqueeze(1)).float()
#             neg_pairs = neg_pairs*n_per_p
#             neg_pairs[n_per_p==0] = float('-inf')

#             max_val = torch.max(pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0].half()) ###This is the line change
#             numerator = torch.exp(pos_pairs - max_val).squeeze(1)
#             denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
#             log_exp = torch.log((numerator/denominator) + 1e-20)
#             return {"loss": {"losses": -log_exp, "indices": (a1, p), "reduction_type": "pos_pair"}}
#         return self.zero_losses()




# Or use NTXentMultiplePositives

In [53]:
# # Step 1: Prepare the dataset
# # Assuming we have a dataset file named 'molecules.csv' with molecular structures and labels
# dataset = dc.data.CSVLoader(tasks=['property'], feature_field='smiles')
# dataset.load_from_file('molecules.csv')
# splitter = dc.splits.RandomSplitter()
# train_dataset, valid_dataset, _ = splitter.train_valid_test_split(dataset)

# Step 1: Create data dictionary getting negative samples for each target
input_data = targ_route_not_in_route_dict_sample


# Step 2: Featurizer and cast as CustomDataset
# Step 3: Create DataLoader


if config["model_type"] == 'gnn':
    featurizer = dc.feat.MolGraphConvFeaturizer()
    
    purch_mols = [Chem.MolFromSmiles(smiles) for smiles in purch_smiles]
    purch_featurizer = featurizer.featurize(purch_mols)
    purch_featurizer_dict = dict(zip(purch_smiles, purch_featurizer))
    with open(f'{checkpoint_folder}/purch_featurizer_dict.pickle', 'wb') as handle:
            pickle.dump(purch_featurizer_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    fingerprint_num_atoms_dict = None
    
    preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples = gnn_preprocess_input(input_data, featurizer, purch_featurizer_dict)
    dataset = CustomDataset(preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples)
elif config["model_type"] == 'fingerprints':
    purch_fingerprints = list(map(fingerprint_vect_from_smiles, purch_smiles))
    purch_fingerprints_dict = dict(zip(purch_smiles, purch_fingerprints))
    with open(f'{checkpoint_folder}/purch_fingerprints_dict.pickle', 'wb') as handle:
            pickle.dump(purch_fingerprints_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
    # Also save dict to retrieve number of atoms from fingerprints
    fingerprint_num_atoms_dict = {fp: purch_nr_heavy_atoms[smiles] for smiles, fp in purch_fingerprints_dict.items()}
    with open(f'{checkpoint_folder}/fingerprint_num_atoms_dict.pickle', 'wb') as handle:
            pickle.dump(fingerprint_num_atoms_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
    preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples = fingerprint_preprocess_input(input_data, purch_fingerprints_dict)
    dataset = CustomDataset(preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples)
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')
    
    

# def collate_fn(data):
#     targets = [sample[0] for sample in data]
#     positive_samples = [sample[1] for sample in data]
#     negative_samples = [sample[2] for sample in data]

#     return targets, positive_samples, negative_samples
# def collate_fn(data):
#     targets = [sample[0] for sample in data]
#     positive_samples = [sample[1] for sample in data]
#     negative_samples = [sample[2] for sample in data]

#     return targets, positive_samples, negative_samples

# Save
if save_preprocessed_data:
    with open(f'{checkpoint_folder}/preprocessed_targets.pickle', 'wb') as handle:
            pickle.dump(preprocessed_targets, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(f'{checkpoint_folder}/preprocessed_positive_samples.pickle', 'wb') as handle:
            pickle.dump(preprocessed_positive_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(f'{checkpoint_folder}/preprocessed_negative_samples.pickle', 'wb') as handle:
            pickle.dump(preprocessed_negative_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

# # Load
# with open(f'{checkpoint_folder}/preprocessed_targets.pickle', 'wb') as handle:
#     preprocessed_targets = pickle.load(handle)
# with open(f'{checkpoint_folder}/preprocessed_positive_samples.pickle', 'wb') as handle:
#     preprocessed_positive_samples = pickle.load(handle)
# with open(f'{checkpoint_folder}/preprocessed_negative_samples.pickle', 'wb') as handle:
#     preprocessed_negative_samples = pickle.load(handle)


KeyboardInterrupt: 

##### Train and validation split

In [None]:
validation_ratio = config["validation_ratio"]
num_samples = len(dataset)
num_val_samples = int(validation_ratio * num_samples)

train_indices, val_indices = train_test_split(range(num_samples), test_size=num_val_samples, random_state=42)

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

train_data_loader = DataLoader(train_dataset, batch_size=config["train_batch_size"], shuffle=config["train_shuffle"], collate_fn=collate_fn)
val_data_loader = DataLoader(val_dataset, batch_size=config["val_batch_size"], shuffle=config["val_shuffle"], collate_fn=collate_fn)


# Batch size: The batch size determines the number of samples processed in each iteration during training or validation. In most cases, it is common to use the same batch size for both training and validation to maintain consistency. However, there are situations where you might choose a different batch size for validation. For instance, if memory constraints are more relaxed during validation, you can use a larger batch size to speed up evaluation.
# Shuffle training data: Shuffling the training data before each epoch is beneficial because it helps the model see the data in different orders, reducing the risk of the model learning patterns specific to the order of the data. Shuffling the training data introduces randomness and promotes better generalization.
# No shuffle for validation data: It is generally not necessary to shuffle the validation data because validation is meant to evaluate the model's performance on unseen data that is representative of the real-world scenarios. Shuffling the validation data could lead to inconsistent evaluation results between different validation iterations, making it harder to track the model's progress and compare performance.


In [23]:
print(train_data_loader.collate_fn)

<function collate_fn at 0x2cc63bc70>


In [24]:
# len(preprocessed_targets[0])
# (preprocessed_targets[0].size()[0])

In [25]:
if config["model_type"] == 'gnn':
    gnn_input_dim = preprocessed_targets[0].node_features.shape[1]
    gnn_hidden_dim = config["hidden_dim"]
    gnn_output_dim = config["output_dim"]
    
    with open(f'{checkpoint_folder}/input_dim.pickle' , 'wb') as f:
        pickle.dump({'input_dim': gnn_input_dim}, f)
    
elif config["model_type"] == 'fingerprints':
#     fingerprint_input_dim = preprocessed_targets[0].GetNumBits()
    fingerprint_input_dim = (preprocessed_targets[0].size()[0]) # len(preprocessed_targets[0].node_features)
    fingerprint_hidden_dim = config["hidden_dim"]
    fingerprint_output_dim = config["output_dim"]
    
    with open(f'{checkpoint_folder}/input_dim.pickle' , 'wb') as f:
        pickle.dump({'input_dim': fingerprint_input_dim}, f)
    
    
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')



In [68]:
# Step 3: Set up the training loop for the GNN model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if config["model_type"] == 'gnn':
    model = GNNModel(
        input_dim=gnn_input_dim, 
        hidden_dim=gnn_hidden_dim, 
        output_dim=gnn_output_dim).to(device)
    model.double()
    
elif config["model_type"] == 'fingerprints':
    model = FingerprintModel(
        input_dim=fingerprint_input_dim, 
        hidden_dim=fingerprint_hidden_dim, 
        output_dim=fingerprint_output_dim).to(device)
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')



loss_fn = NTXentLoss(temperature=config["temperature"])
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

num_epochs = config["num_epochs"]



In [56]:
# def compute_gnn_embedding(gnn_model):
#     node_features = torch.tensor(example.node_features, dtype=torch.double)
#     edge_index = torch.tensor(example.edge_index, dtype=torch.long)  # Assuming edge_index is of type 'long'

#     # Convert the input node features to double
#     node_features = node_features.double()

#     # Compute the embeddings for the positive example
#     embedding = model(node_features, edge_index)

In [57]:
train_data_loader.collate_fn

<function __main__.collate_fn(data)>

In [58]:
load_from_checkpoint=False

In [74]:
def compute_embeddings_and_loss(model, batch_targets, batch_positive_samples, batch_negative_samples, loss_fn, pos_sampling, fingerprint_num_atoms_dict=None):
    embeddings = []
    for i in range(len(batch_targets)):
        target = batch_targets[i]
        positives = batch_positive_samples[i]
        negatives = batch_negative_samples[i]

        if config["model_type"] == 'gnn':
            target_node_features = torch.tensor(target.node_features, dtype=torch.double)
            target_edge_index = torch.tensor(target.edge_index, dtype=torch.long)
            target_embedding = model(target_node_features, target_edge_index)
            
            if len(positives)==0:
                positive_samples_embeddings = torch.empty((0, target_embedding.size(0)))
            else:
                positive_samples_embeddings = torch.stack([
                    model(torch.tensor(example.node_features, dtype=torch.double),
                          torch.tensor(example.edge_index, dtype=torch.long))
                    for example in positives
                ], dim=0)

            negative_samples_embeddings = torch.stack([
                model(torch.tensor(example.node_features, dtype=torch.double),
                      torch.tensor(example.edge_index, dtype=torch.long))
                for example in negatives
            ], dim=0)

            
            if pos_sampling == "uniform":
                pos_weights = None
            elif pos_sampling == "prop_num_atoms":
                pos_weights = []
                for positive in positives:
                    pos_weights.append(positive.node_features.shape[0])
#                 print('Len pos weights', len(pos_weights))
                pos_weights = torch.tensor(pos_weights, dtype=torch.double)
                # Normalize the tensor to sum up to 1
                pos_weights = pos_weights / pos_weights.sum()
            else:
                raise NotImplementedError(f'{config["pos_sampling"]}')
                
            embeddings.append(
                SampleData(target=target_embedding, positive_samples=positive_samples_embeddings,
                           negative_samples=negative_samples_embeddings, pos_weights=pos_weights)
            )


        elif config["model_type"] == 'fingerprints':
            target_embedding = model(target)
            if len(positives)==0:
                positive_samples_embeddings = torch.empty((0, target_embedding.size(0)))
            else:
                positive_samples_embeddings = model(positives)
            negative_samples_embeddings = model(negatives)

            
            if pos_sampling == "uniform":
                pos_weights = None
            elif pos_sampling == "prop_num_atoms":
                pos_weights = [fingerprint_num_atoms_dict[positive]for positive in positives]  
                pos_weights = torch.tensor(pos_weights, dtype=torch.double)
                # Normalize the tensor to sum up to 1
                pos_weights = pos_weights / pos_weights.sum()

            else:
                raise NotImplementedError(f'{config["pos_sampling"]}')
            embeddings.append(
                SampleData(target=target_embedding, positive_samples=positive_samples_embeddings,
                           negative_samples=negative_samples_embeddings, pos_weights=pos_weights)
            )
            
#             embeddings = embeddings + [SampleData(target=target_embedding, positive_samples=positive_samples_embeddings, negative_samples=negative_samples_embeddings)]
        else:
            raise NotImplementedError(f'Model type {config["model_type"]}')

    # Compute loss for the batch
    loss = loss_fn(embeddings)

    return embeddings, loss

In [75]:
# Check if a checkpoint exists and load the model state and optimizer state if available
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
else:
    start_epoch = 0

# Create a SummaryWriter for TensorBoard logging
log_dir = f'{checkpoint_folder}/logs'  # Specify the directory to store TensorBoard logs
writer = SummaryWriter(log_dir)


best_val_loss = float('inf')
best_model = None

epoch_loss = pd.DataFrame(columns=['Epoch', 'TrainLoss', 'ValLoss'])
for epoch in tqdm(range(start_epoch, num_epochs)):
    # TRAIN
    model.train()
    train_loss = 0.0
    train_batches = 0
    
    for batch_idx, batch_data in enumerate(train_data_loader):
        batch_targets, batch_positive_samples, batch_negative_samples = batch_data

        optimizer.zero_grad()
        
        embeddings, loss = compute_embeddings_and_loss(model, batch_targets, batch_positive_samples, batch_negative_samples, loss_fn, config["pos_sampling"], fingerprint_num_atoms_dict)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Track total loss
        train_loss += loss.item()
        train_batches += 1
    
    # VALIDATION
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    val_batches = 0
    with torch.no_grad():  # Disable gradient calculation during validation
        for val_batch_idx, val_batch_data in enumerate(val_data_loader):
            val_batch_targets, val_batch_positive_samples, val_batch_negative_samples = val_batch_data

            val_embeddings, val_batch_loss = compute_embeddings_and_loss(model, val_batch_targets,
                                                                     val_batch_positive_samples,
                                                                     val_batch_negative_samples, loss_fn, config["pos_sampling"], fingerprint_num_atoms_dict)


            val_loss += val_batch_loss.item()
            val_batches += 1
            
    # METRICS
    # - TRAIN
    # Compute average loss for the epoch
    average_train_loss = train_loss / train_batches
        
    # Log the loss to TensorBoard
    writer.add_scalar('Loss/train', average_train_loss, epoch+1)
    
    # - VALIDATION
    average_val_loss = val_loss / val_batches
    
    # Log the loss to TensorBoard
    writer.add_scalar('Loss/val', average_val_loss, epoch+1)
    
    new_row = pd.DataFrame({'Epoch': [epoch], 'TrainLoss': [average_train_loss], 'ValLoss': [average_val_loss]})
    epoch_loss = pd.concat([epoch_loss, new_row], axis=0)

    
    if ((epoch%10==0) | (epoch==num_epochs-1)):
        print(f"{config['model_type']} Model - Epoch {epoch+1}/{num_epochs}, TrainLoss: {average_train_loss}, ValLoss: {average_val_loss}")
        
        # Save the model and optimizer state as a checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        checkpoint_path = f'{checkpoint_folder}/epoch_{epoch+1}_{checkpoint_name}'  # Specify the checkpoint file path
        torch.save(checkpoint, checkpoint_path)
        
#         loss_df = pd.DataFrame({'Epoch': range(len(epoch_loss)), 'TrainLoss': epoch_loss})
        epoch_loss.to_csv(f'{checkpoint_folder}/train_val_loss.csv', index=False)
    
    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        best_model = model
    
    
        

# Close the SummaryWriter
writer.close()

# Save the best model as a pickle
best_model_path = f'{checkpoint_folder}/model_min_val.pkl' #'path/to/best_model.pkl'

with open(best_model_path, 'wb') as f:
    pickle.dump(best_model, f)
    

  1%|▉                                                                                            | 1/100 [22:30<37:07:46, 1350.16s/it]

gnn Model - Epoch 1/100, TrainLoss: 3.3300414522612325, ValLoss: 2.778004522677617


  1%|▉                                                                                            | 1/100 [41:13<68:01:57, 2473.91s/it]


KeyboardInterrupt: 

In [None]:
import plotly.express as px
# fig = px.line(x=epoch_loss['Epoch'], y=epoch_loss['TrainLoss'], title="Train loss")
# fig.update_layout(width=1000, height=600, showlegend=False)
# fig.write_image(f"{checkpoint_folder}/Train_loss.pdf")
# fig.show()

# Create a new figure with two lines
fig = px.line()

# Add the TrainLoss line to the figure
fig.add_scatter(x=epoch_loss['Epoch'], y=epoch_loss['TrainLoss'], name='Train Loss')

# Add the ValLoss line to the figure
fig.add_scatter(x=epoch_loss['Epoch'], y=epoch_loss['ValLoss'], name='Validation Loss')

# Set the title of the figure
fig.update_layout(title="Train and Validation Loss")

# Set the layout size and show the legend
fig.update_layout(width=1000, height=600, showlegend=True)

# Save the figure as a PDF file
fig.write_image(f"{checkpoint_folder}/Train_and_Val_loss.pdf")


In [87]:
0 % np.nan == 0

False