In [47]:
import os
import pickle
import random

import deepchem as dc
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing
from torch_geometric.data import DataLoader
# from torch.utils.data import DataLoader
from torch_geometric.utils import to_dense_batch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.utils.data import Dataset

import numpy as np
from rdkit import Chem



In [None]:
checkpoint_folder = 'GraphRuns'
# if not os.path.exists(checkpoint_folder):
#     os.makedirs(checkpoint_folder)

checkpoint_name = 'gnn_checkpoint.pth'

not_in_route_sample_size = 10
seed=42

### Read routes data - consider only route 1

In [6]:
run_id = '202305-2911-2320-5a95df0e-3008-4ebe-acd8-ecb3b50607c7'

input_file_routes = f'Runs/{run_id}/targ_routes.pickle'
input_file_distances = f'Runs/{run_id}/targ_to_purch_distances.pickle'

# Routes data
with open(input_file_routes, 'rb') as handle:
    targ_routes_dict = pickle.load(handle)
    
# Load distances data
with open(input_file_distances, 'rb') as handle:
    distances_dict = pickle.load(handle)


# Inventory
from paroutes import PaRoutesInventory
inventory=PaRoutesInventory(n=5)
purch_smiles = [mol.smiles for mol in inventory.purchasable_mols()]
len(purch_smiles)

def num_heavy_atoms(mol):
    return Chem.rdchem.Mol.GetNumAtoms(mol, onlyExplicit=True)

purch_mol_to_exclude = []
for smiles in purch_smiles:
    if num_heavy_atoms(Chem.MolFromSmiles(smiles)) < 2:
        purch_mol_to_exclude = purch_mol_to_exclude + [smiles]



targ_route_not_in_route_dict = {}
for target, target_routes_dict in targ_routes_dict.items():
    targ_route_not_in_route_dict[target] = {}
    
    target_route_df = target_routes_dict["route_1"]
    purch_in_route = list(target_route_df['smiles'])
    purch_not_in_route = [purch_smile for purch_smile in purch_smiles if purch_smile not in purch_in_route]
    random.seed(seed)
    purch_not_in_route_sample = random.sample(purch_not_in_route,not_in_route_sample_size)
    
    # Filter out molecules with only one atom (problems with featurizer)
    purch_in_route = [smiles for smiles in purch_in_route if smiles not in purch_mol_to_exclude]
    purch_not_in_route_sample = [smiles for smiles in purch_not_in_route_sample if smiles not in purch_mol_to_exclude]
    
    targ_route_not_in_route_dict[target]['positive_samples'] = purch_in_route
    targ_route_not_in_route_dict[target]['negative_samples'] = purch_not_in_route_sample




### Temp - select a sample of targets

In [7]:
# Get a random sample of keys from targ_routes_dict
nr_sample_targets = 200
sample_targets = random.sample(list(targ_route_not_in_route_dict.keys()), nr_sample_targets)

# Create targ_routes_dict_sample with the sampled keys and their corresponding values
targ_route_not_in_route_dict_sample = {target: targ_route_not_in_route_dict[target] for target in sample_targets}


### Run

In [8]:

class SampleData:
    def __init__(self, target, positive_samples, negative_samples):
        self.target = target
        self.positive_samples = positive_samples
        self.negative_samples = negative_samples


# INPUT

# OUTPUT 
# data = [
#     {
#         'target': 'target1_feats',
#         'positive_samples': ['positive1_feats', 'positive2_feats'],
#         'negative_samples': ['negative1_feats', 'negative2_feats']
#     },
#     {
#         'target': 'target2_feats',
#         'positive_samples': ['positive3_feats', 'positive4_feats']
#         'negative_samples': ['negative3_feats', 'negative4_feats']
#     },
#     # ... add more data samples
# ]
def preprocess_input(input_data):
    featurizer = dc.feat.MolGraphConvFeaturizer()
    data_list = []
    
    for target_smiles, samples in tqdm(input_data.items()):
        target_feats = featurizer.featurize(Chem.MolFromSmiles(target_smiles))
        pos_mols = [Chem.MolFromSmiles(positive_smiles) for positive_smiles in samples['positive_samples']]
        neg_mols = [Chem.MolFromSmiles(negative_smiles) for negative_smiles in samples['negative_samples']]
        pos_feats = featurizer.featurize(pos_mols)
        neg_feats = featurizer.featurize(neg_mols)
        data_list = data_list + [SampleData(target=target_feats, positive_samples=pos_feats, negative_samples=neg_feats)]
    return data_list
        
#     X = []
#     y = []
#     featurizer = dc.feat.MolGraphConvFeaturizer()
#     for target_smiles, samples in tqdm(input_data.items()):
#         target_label = 1.0  # Assign a positive label to the target molecule
#         target_mol = Chem.MolFromSmiles(target_smiles)

#         # Add positive samples to the dataset
#         for positive_smiles in samples['positive_samples']:
#             positive_mol = Chem.MolFromSmiles(positive_smiles)
#             target_positive_feats = featurizer.featurize([target_mol, positive_mol])
#             X.append(target_positive_feats)
#             y.append(target_label)

#         # Add negative samples to the dataset
#         for negative_smiles in samples['negative_samples']:
#             negative_mol = Chem.MolFromSmiles(negative_smiles)
#             target_negative_feats = featurizer.featurize([target_mol, negative_mol])
#             X.append(target_negative_feats)
#             y.append(0.0)  # Assign a negative label

#     X = np.concatenate(X, axis=0)
#     y = np.array(y)

#     dataset = dc.data.NumpyDataset(X, y)
#     return dataset



# import torch
# from torch_geometric.data import Data

# def preprocess_input(input_data):
#     data_list = []
    
#     for target, samples in input_data.items():
#         target_embedding = get_embedding(target)  # Assuming you have a function to obtain the embedding for a target molecule
        
#         for positive_sample in samples['positive_samples']:
#             positive_embedding = get_embedding(positive_sample)  # Assuming you have a function to obtain the embedding for a positive sample
            
#             positive_data = Data(x=positive_embedding, y=torch.tensor([1.0]))  # Create a positive data instance
#             data_list.append(positive_data)
        
#         for negative_sample in samples['negative_samples']:
#             negative_embedding = get_embedding(negative_sample)  # Assuming you have a function to obtain the embedding for a negative sample
            
#             negative_data = Data(x=negative_embedding, y=torch.tensor([0.0]))  # Create a negative data instance
#             data_list.append(negative_data)
    
#     return data_list



In [53]:
# class CustomDataset(Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         target_smiles = list(self.data.keys())[idx]
#         sample = self.data[target_smiles]

#         # Extract the positive_samples and negative_samples
#         positive_samples = sample['positive_samples']
#         negative_samples = sample['negative_samples']

#         # Convert the data to tensors or any other necessary preprocessing

#         # Return the sample with named attributes
#         return {
#             'target_smiles': target_smiles,
#             'positive_samples': positive_samples,
#             'negative_samples': negative_samples
#         }
    
# Takes a list of samples, where each sample is a dictionary containing the 'target_smiles', 
# 'positive_sample', and 'negative_samples'. 
# In the __getitem__ method, we extract these attributes and return the sample as a dictionary 
# with named attributes.
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
#         return sample

#         # Extract the target_smiles, positive_sample, and negative_samples
# #         target_smiles = sample['target_smiles']
# #         positive_sample = sample['positive_sample']
# #         negative_samples = sample['negative_samples']
        target = sample.target
        positive_samples = sample.positive_samples
        negative_samples = sample.negative_samples

        # Convert the data to tensors or any other necessary preprocessing

        # Return the sample with named attributes
#         return {
#             'target': target,
#             'positive_samples': positive_samples,
#             'negative_samples': negative_samples
#         }
        return target, positive_samples, negative_samples

In [10]:
# Step 2: Define embedding model
class FullyConnectedModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FullyConnectedModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# class GNNModel(MessagePassing):
#     def __init__(self, input_dim, hidden_dim, output_dim):
#         super(GNNModel, self).__init__(aggr='mean')
#         self.gnn = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )

#     def forward(self, x, edge_index):
#         x = self.gnn(x)
#         x = self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
#         return x
    
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
#     def __init__(self, hidden_dim, output_dim):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
#         self.conv1 = GCNConv(None,hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.fc(x)
        return x
    
# Step 3: Create a contrastive learning loss function
class NTXentLoss(nn.Module):
    def __init__(self, temperature):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.cos_sim = nn.CosineSimilarity(dim=-1)

#     def forward(self, embeddings, positive_pairs):
#         anchor, positive = positive_pairs[:, 0], positive_pairs[:, 1]
#         anchor_emb = embeddings[anchor]
#         positive_emb = embeddings[positive]
#         anchor_similarity = self.cos_sim(anchor_emb, positive_emb)
#         anchor_similarity /= self.temperature

#         all_similarity = self.cos_sim(embeddings.unsqueeze(1), embeddings.unsqueeze(0))
#         all_similarity = all_similarity.view(-1)
#         all_similarity = all_similarity[~positive_pairs.byte()]

#         numerator = torch.exp(anchor_similarity)
#         denominator = torch.sum(torch.exp(all_similarity))
#         loss = -torch.log(numerator / (numerator + denominator))
#         return loss.mean()

    def forward(self, embeddings):
        sample_losses = []
        for single_sample_embeddings in embeddings:
            target_emb = single_sample_embeddings.target
            positive_embs = single_sample_embeddings.positive_samples
            negative_embs = single_sample_embeddings.negative_samples
            
            # Sample one positive
            nr_positives = len(positive_embs)
            positive_emb = sample(positive_embs, 1)
            
            positive_similarity = self.cos_sim(target_emb, positive_emb)
            positive_similarity /= self.temperature
            negative_similarity = self.cos_sim(target_emb, negative_embs)
            negative_similarity /= self.temperature
            
            numerator = torch.exp(positive_similarity)
            denominator = torch.sum(torch.exp(negative_similarity))
            sample_loss = -torch.log(numerator / (numerator + denominator))
            sample_losses = sample_losses + [sample_loss]
        
        return sample_losses.mean()
    
    
# import torch
# from pytorch_metric_learning.losses import GenericPairLoss

# class NTXentLoss(GenericPairLoss):

#     def __init__(self, temperature, **kwargs):
#         super().__init__(use_similarity=True, mat_based_loss=False, **kwargs)
#         self.temperature = temperature

#     def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
#         a1, p, a2, _ = indices_tuple

#         if len(a1) > 0 and len(a2) > 0:
#             pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
#             neg_pairs = neg_pairs / self.temperature
#             n_per_p = (a2.unsqueeze(0) == a1.unsqueeze(1)).float()
#             neg_pairs = neg_pairs*n_per_p
#             neg_pairs[n_per_p==0] = float('-inf')

#             max_val = torch.max(pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0].half()) ###This is the line change
#             numerator = torch.exp(pos_pairs - max_val).squeeze(1)
#             denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
#             log_exp = torch.log((numerator/denominator) + 1e-20)
#             return {"loss": {"losses": -log_exp, "indices": (a1, p), "reduction_type": "pos_pair"}}
#         return self.zero_losses()




# Or use NTXentMultiplePositives

In [11]:
# # Step 1: Prepare the dataset
# # Assuming we have a dataset file named 'molecules.csv' with molecular structures and labels
# dataset = dc.data.CSVLoader(tasks=['property'], feature_field='smiles')
# dataset.load_from_file('molecules.csv')
# splitter = dc.splits.RandomSplitter()
# train_dataset, valid_dataset, _ = splitter.train_valid_test_split(dataset)

# Step 1: Create data dictionary getting negative samples for each target
input_data = targ_route_not_in_route_dict_sample

# Step 2: Featurizer and cast as CustomDataset
preprocessed_data = preprocess_input(input_data)
dataset = CustomDataset(preprocessed_data)

# Step 3: Create DataLoader
# def collate_fn(data):
#     targets = [sample[0] for sample in data]
#     positive_samples = [sample[1] for sample in data]
#     negative_samples = [sample[2] for sample in data]

#     return targets, positive_samples, negative_samples
def collate_fn(data):
    targets = [sample[0] for sample in data]
    positive_samples = [sample[1] for sample in data]
    negative_samples = [sample[2] for sample in data]

    return targets, positive_samples, negative_samples

batch_size = 64  
shuffle = True  
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)




100%|███████████████████████████████████████████████████████████████████████████| 200/200 [01:23<00:00,  2.40it/s]


In [54]:

dataset = CustomDataset(preprocessed_data)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)



In [12]:
gnn_input_dim = preprocessed_data[0].target[0].node_features.shape[1]
gnn_hidden_dim = 512
gnn_output_dim = 256

In [55]:
dataset[0]

(array([GraphData(node_features=[20, 30], edge_index=[2, 42], edge_features=None, pos=[0])],
       dtype=object),
 array([GraphData(node_features=[20, 30], edge_index=[2, 42], edge_features=None, pos=[0]),
        GraphData(node_features=[14, 30], edge_index=[2, 30], edge_features=None, pos=[0]),
        GraphData(node_features=[12, 30], edge_index=[2, 22], edge_features=None, pos=[0]),
        GraphData(node_features=[11, 30], edge_index=[2, 22], edge_features=None, pos=[0]),
        GraphData(node_features=[5, 30], edge_index=[2, 8], edge_features=None, pos=[0])],
       dtype=object),
 array([GraphData(node_features=[18, 30], edge_index=[2, 36], edge_features=None, pos=[0]),
        GraphData(node_features=[7, 30], edge_index=[2, 14], edge_features=None, pos=[0]),
        GraphData(node_features=[20, 30], edge_index=[2, 42], edge_features=None, pos=[0]),
        GraphData(node_features=[8, 30], edge_index=[2, 14], edge_features=None, pos=[0]),
        GraphData(node_features=[21, 3

In [56]:
preprocessed_data[0].positive_samples

array([GraphData(node_features=[20, 30], edge_index=[2, 42], edge_features=None, pos=[0]),
       GraphData(node_features=[14, 30], edge_index=[2, 30], edge_features=None, pos=[0]),
       GraphData(node_features=[12, 30], edge_index=[2, 22], edge_features=None, pos=[0]),
       GraphData(node_features=[11, 30], edge_index=[2, 22], edge_features=None, pos=[0]),
       GraphData(node_features=[5, 30], edge_index=[2, 8], edge_features=None, pos=[0])],
      dtype=object)

In [57]:
gnn_train_data_loader = data_loader

In [51]:
gnn_train_data_loader

<torch.utils.data.dataloader.DataLoader at 0x2d70764a0>

In [58]:
# Step 3: Set up the training loop for the GNN model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gnn_model = GNNModel(
    input_dim=gnn_input_dim, 
    hidden_dim=gnn_hidden_dim, 
    output_dim=gnn_output_dim).to(device)
# gnn_model = GNNModel(hidden_dim=gnn_hidden_dim, output_dim=gnn_output_dim).to(device)

gnn_loss_fn = NTXentLoss(temperature=0.1)
gnn_optimizer = optim.Adam(gnn_model.parameters(), lr=0.001)

num_epochs = 10
checkpoint_path = f'{checkpoint_folder}/{checkpoint_name}'  # Specify the checkpoint file path

# Check if a checkpoint exists and load the model state and optimizer state if available
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    gnn_model.load_state_dict(checkpoint['model_state_dict'])
    gnn_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
else:
    start_epoch = 0

# Create a SummaryWriter for TensorBoard logging
log_dir = f'{checkpoint_folder}/logs'  # Specify the directory to store TensorBoard logs
writer = SummaryWriter(log_dir)

gnn_model.train()
for epoch in tqdm(range(start_epoch, num_epochs)):
    total_loss = 0.0
    total_batches = 0
    for batch_idx, batch_data in enumerate(gnn_train_data_loader):
        optimizer.zero_grad()
        
        # Forward pass
        # TODO change, scan trough 3 list at the same time (targets, pos, neg)
        embeddings = []
        for sample_data in batch_data:
            target_embedding = gnn_model(sample_data.target)
            positive_sample_embeddings = gnn_model(sample_data.positive_samples)
            negative_samples_embeddings = gnn_model(sample_data.negative_samples)
            embeddings = embeddings + [SampleData(target=target_embedding, positive_samples=positive_sample_embeddings, negative_samples=negative_samples_embeddings)]

        # Compute loss
        loss = gnn_loss_fn(embeddings)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Track total loss
        total_loss += loss.item()
        total_batches += 1

    # Compute average loss for the epoch
    average_loss = total_loss / total_batches

    print(f"GNN Model - Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

    # Log the loss to TensorBoard
    writer.add_scalar('Loss/train', loss.item(), epoch+1)

    # Save the model and optimizer state as a checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': gnn_model.state_dict(),
        'optimizer_state_dict': gnn_optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)

# Close the SummaryWriter
writer.close()

# Step 4: Evaluate and use the trained embeddings
# You can evaluate the embeddings on downstream tasks or use them for molecular similarity search or property prediction


  0%|                                                                                      | 0/10 [00:00<?, ?it/s]


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object