In [1]:
import os
import pickle
import random

import deepchem as dc
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing
# from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader
# from torch.utils.data import DataLoader
from torch_geometric.utils import to_dense_batch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.utils.data import Dataset

import numpy as np
from rdkit import Chem

import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from rdkit.Chem import AllChem

import pandas as pd
import pickle
import json

Skipped loading some Jax models, missing a dependency. jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.


In [2]:
# Read json.config
with open('config_gnn.json', 'r') as f:
    config = json.load(f)
    


In [3]:
print(config)

{'not_in_route_sample_size': 100, 'seed': 42, 'run_id': '202305-2911-2320-5a95df0e-3008-4ebe-acd8-ecb3b50607c7', 'nr_sample_targets': -1, 'model_type': 'gnn', 'batch_size': 64, 'shuffle': True, 'hidden_dim': 512, 'output_dim': 256, 'temperature': 0.1, 'lr': 0.001, 'num_epochs': 100}


In [4]:
experiment_name = f"{config['model_type']}_v1"

checkpoint_folder = f'GraphRuns/{experiment_name}/'
# if not os.path.exists(checkpoint_folder):
#     os.makedirs(checkpoint_folder)

checkpoint_name = 'checkpoint.pth'


In [5]:
experiment_name

'gnn_v1'

In [6]:
# Save 
with open(f'{checkpoint_folder}/config.json', 'w') as f:
    json.dump(config, f, indent=4)

In [7]:
save_preprocessed_data = True

### Read routes data - consider only route 1

In [8]:
input_file_routes = f'Runs/{config["run_id"] }/targ_routes.pickle'
input_file_distances = f'Runs/{config["run_id"]}/targ_to_purch_distances.pickle'

# Routes data
with open(input_file_routes, 'rb') as handle:
    targ_routes_dict = pickle.load(handle)
    
# Load distances data
with open(input_file_distances, 'rb') as handle:
    distances_dict = pickle.load(handle)


# Inventory
from paroutes import PaRoutesInventory
inventory=PaRoutesInventory(n=5)
purch_smiles = [mol.smiles for mol in inventory.purchasable_mols()]
len(purch_smiles)

def num_heavy_atoms(mol):
    return Chem.rdchem.Mol.GetNumAtoms(mol, onlyExplicit=True)

purch_mol_to_exclude = []
for smiles in purch_smiles:
    if num_heavy_atoms(Chem.MolFromSmiles(smiles)) < 2:
        purch_mol_to_exclude = purch_mol_to_exclude + [smiles]



targ_route_not_in_route_dict = {}
for target, target_routes_dict in targ_routes_dict.items():
    targ_route_not_in_route_dict[target] = {}
    
    target_route_df = target_routes_dict["route_1"]
    purch_in_route = list(target_route_df['smiles'])
    purch_not_in_route = [purch_smile for purch_smile in purch_smiles if purch_smile not in purch_in_route]
    random.seed(config["seed"])
    purch_not_in_route_sample = random.sample(purch_not_in_route, config["not_in_route_sample_size"])
    
    # Filter out molecules with only one atom (problems with featurizer)
    purch_in_route = [smiles for smiles in purch_in_route if smiles not in purch_mol_to_exclude]
    purch_not_in_route_sample = [smiles for smiles in purch_not_in_route_sample if smiles not in purch_mol_to_exclude]
    
    targ_route_not_in_route_dict[target]['positive_samples'] = purch_in_route
    targ_route_not_in_route_dict[target]['negative_samples'] = purch_not_in_route_sample




### Temp - select a sample of targets

In [9]:
# Get a random sample of keys from targ_routes_dict
if config["nr_sample_targets"]!= -1:
    sample_targets = random.sample(list(targ_route_not_in_route_dict.keys()), config["nr_sample_targets"])
else:
    sample_targets = targ_route_not_in_route_dict


# Create targ_routes_dict_sample with the sampled keys and their corresponding values
targ_route_not_in_route_dict_sample = {target: targ_route_not_in_route_dict[target] for target in sample_targets}


### Run

In [10]:
def gnn_preprocess_input(input_data):
    featurizer = dc.feat.MolGraphConvFeaturizer()
    targets = []
    positive_samples = []
    negative_samples = []

    for target_smiles, samples in tqdm(input_data.items()):
#         try:
        target_feats = featurizer.featurize(Chem.MolFromSmiles(target_smiles))
        pos_mols = [Chem.MolFromSmiles(positive_smiles) for positive_smiles in samples['positive_samples']]
        neg_mols = [Chem.MolFromSmiles(negative_smiles) for negative_smiles in samples['negative_samples']]
        pos_feats = featurizer.featurize(pos_mols)
        neg_feats = featurizer.featurize(neg_mols)

        targets.append(target_feats[0])
        positive_samples.append(pos_feats)
        negative_samples.append(neg_feats)
#             targets_torch = torch.tensor(target_feats, dtype=torch.double)
#             positive_samples = torch.tensor(pos_feats, dtype=torch.double)
#             negative_samples = torch.tensor(neg_feats, dtype=torch.double)
#             targets.append(targets_torch)
#             positive_samples.append(positive_samples)
#             negative_samples.append(negative_samples)
            
#         except:
#             # Handle the case where featurization fails for a sample
#             print(f"Featurization failed for sample: {target_smiles}")
        
#     targets_tensor = torch.stack(targets)
#     positive_samples_tensor = torch.stack(positive_samples)
#     negative_samples_tensor = torch.stack(negative_samples)
#     return targets_tensor, positive_samples_tensor, negative_samples_tensor
    return targets, positive_samples, negative_samples


def fingerprint_vect_from_smiles(mol_smiles):
    return AllChem.GetMorganFingerprintAsBitVect(AllChem.MolFromSmiles(mol_smiles), radius=3)

def fingerprint_preprocess_input(input_data):
    targets = []
    positive_samples = []
    negative_samples = []

    for target_smiles, samples in tqdm(input_data.items()):
#         target_feats = fingerprint_from_smiles(Chem.MolFromSmiles(target_smiles))
#         pos_mols = [Chem.MolFromSmiles(positive_smiles) for positive_smiles in samples['positive_samples']]
#         neg_mols = [Chem.MolFromSmiles(negative_smiles) for negative_smiles in samples['negative_samples']]
        target_feats = fingerprint_vect_from_smiles(target_smiles)
        pos_feats = list(map(fingerprint_vect_from_smiles, samples['positive_samples']))
        neg_feats = list(map(fingerprint_vect_from_smiles, samples['negative_samples']))
        
#         targets.append(target_feats[0])
#         positive_samples.append(pos_feats)
#         negative_samples.append(neg_feats)
        targets.append(torch.tensor(target_feats, dtype=torch.double))
        positive_samples.append(torch.tensor(pos_feats, dtype=torch.double))
        negative_samples.append(torch.tensor(neg_feats, dtype=torch.double))
        

    return targets, positive_samples, negative_samples



class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, targets, positive_samples, negative_samples):
        self.targets = targets
        self.positive_samples = positive_samples
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        target = self.targets[idx]
        positive = self.positive_samples[idx]
        negative = self.negative_samples[idx]

        return target, positive, negative


# def collate_fn(data):
#     print("Using collate_fn")
#     targets = []
#     positive_samples = []
#     negative_samples = []

#     for target, positive, negative in data:
#         targets.append(target)
#         positive_samples.extend(positive)
#         negative_samples.extend(negative)

#     targets = torch.stack(targets, dim=0)
#     positive_samples = torch.stack(positive_samples, dim=0)
#     negative_samples = torch.stack(negative_samples, dim=0)

#     return targets, positive_samples, negative_samples

def collate_fn(data):
    targets, positive_samples, negative_samples = zip(*data)

    return targets, positive_samples, negative_samples


In [11]:

class SampleData:
    def __init__(self, target, positive_samples, negative_samples):
        self.target = target
        self.positive_samples = positive_samples
        self.negative_samples = negative_samples


# def preprocess_input(input_data):
#     featurizer = dc.feat.MolGraphConvFeaturizer()
#     data_list = []
    
#     for target_smiles, samples in tqdm(input_data.items()):
#         target_feats = featurizer.featurize(Chem.MolFromSmiles(target_smiles))
#         pos_mols = [Chem.MolFromSmiles(positive_smiles) for positive_smiles in samples['positive_samples']]
#         neg_mols = [Chem.MolFromSmiles(negative_smiles) for negative_smiles in samples['negative_samples']]
#         pos_feats = featurizer.featurize(pos_mols)
#         neg_feats = featurizer.featurize(neg_mols)
#         data_list = data_list + [SampleData(target=target_feats, positive_samples=pos_feats, negative_samples=neg_feats)]
#     return data_list
        
# class CustomDataset(Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         sample = self.data[idx]
# #         return sample

# #         # Extract the target_smiles, positive_sample, and negative_samples
# # #         target_smiles = sample['target_smiles']
# # #         positive_sample = sample['positive_sample']
# # #         negative_samples = sample['negative_samples']
#         target = sample.target
#         positive_samples = sample.positive_samples
#         negative_samples = sample.negative_samples

#         # Convert the data to tensors or any other necessary preprocessing

#         # Return the sample with named attributes
# #         return {
# #             'target': target,
# #             'positive_samples': positive_samples,
# #             'negative_samples': negative_samples
# #         }
#         return target, positive_samples, negative_samples

In [12]:
# Step 2: Define embedding model
class FullyConnectedModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FullyConnectedModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def custom_global_max_pool(x):
    return torch.max(x, dim=0)[0]

class GNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
#     def __init__(self, hidden_dim, output_dim):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
#         self.conv1 = GCNConv(None,hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # Global max pooling (from node level to graph level embeddings)
#         x = global_max_pool(x) #, edge_index[0]
        x = custom_global_max_pool(x)

        x = self.fc(x)
        return x

class FingerprintModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FingerprintModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, dtype=torch.double)
        self.fc2 = nn.Linear(hidden_dim, output_dim, dtype=torch.double)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


# Step 3: Create a contrastive learning loss function
class NTXentLoss(nn.Module):
    def __init__(self, temperature):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.cos_sim = nn.CosineSimilarity(dim=-1)

    def forward(self, embeddings):
        sample_losses = []
        for single_sample_embeddings in embeddings:
            target_emb = single_sample_embeddings.target
            positive_embs = single_sample_embeddings.positive_samples
            negative_embs = single_sample_embeddings.negative_samples
            
            # Sample one positive
            nr_positives = positive_embs.size(0)
            # Randomly select a positive sample
            row_index = torch.randint(0, nr_positives, (1,))
            positive_emb = torch.index_select(positive_embs, dim=0, index=row_index)
#             positive_emb = positive_emb.squeeze(0)
            
            positive_similarity = self.cos_sim(target_emb, positive_emb)
            positive_similarity /= self.temperature           
            negative_similarity = self.cos_sim(target_emb, negative_embs)
            negative_similarity /= self.temperature
            
            numerator = torch.exp(positive_similarity)
            denominator = torch.sum(torch.exp(negative_similarity))
            sample_loss = -torch.log(numerator / (numerator + denominator))
            sample_losses = sample_losses + [sample_loss]
        
        return sum(sample_losses) / len(sample_losses)



# import torch
# from pytorch_metric_learning.losses import GenericPairLoss

# class NTXentLoss(GenericPairLoss):

#     def __init__(self, temperature, **kwargs):
#         super().__init__(use_similarity=True, mat_based_loss=False, **kwargs)
#         self.temperature = temperature

#     def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
#         a1, p, a2, _ = indices_tuple

#         if len(a1) > 0 and len(a2) > 0:
#             pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
#             neg_pairs = neg_pairs / self.temperature
#             n_per_p = (a2.unsqueeze(0) == a1.unsqueeze(1)).float()
#             neg_pairs = neg_pairs*n_per_p
#             neg_pairs[n_per_p==0] = float('-inf')

#             max_val = torch.max(pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0].half()) ###This is the line change
#             numerator = torch.exp(pos_pairs - max_val).squeeze(1)
#             denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
#             log_exp = torch.log((numerator/denominator) + 1e-20)
#             return {"loss": {"losses": -log_exp, "indices": (a1, p), "reduction_type": "pos_pair"}}
#         return self.zero_losses()




# Or use NTXentMultiplePositives

In [13]:
# # Step 1: Prepare the dataset
# # Assuming we have a dataset file named 'molecules.csv' with molecular structures and labels
# dataset = dc.data.CSVLoader(tasks=['property'], feature_field='smiles')
# dataset.load_from_file('molecules.csv')
# splitter = dc.splits.RandomSplitter()
# train_dataset, valid_dataset, _ = splitter.train_valid_test_split(dataset)

# Step 1: Create data dictionary getting negative samples for each target
input_data = targ_route_not_in_route_dict_sample


# Step 2: Featurizer and cast as CustomDataset
# Step 3: Create DataLoader


if config["model_type"] == 'gnn':
    preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples = gnn_preprocess_input(input_data)
    dataset = CustomDataset(preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples)
    data_loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=config["shuffle"], collate_fn=collate_fn)
elif config["model_type"] == 'fingerprints':
    preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples = fingerprint_preprocess_input(input_data)
    dataset = CustomDataset(preprocessed_targets, preprocessed_positive_samples, preprocessed_negative_samples)
    data_loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=config["shuffle"], collate_fn=collate_fn)
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')
    
    

# def collate_fn(data):
#     targets = [sample[0] for sample in data]
#     positive_samples = [sample[1] for sample in data]
#     negative_samples = [sample[2] for sample in data]

#     return targets, positive_samples, negative_samples
# def collate_fn(data):
#     targets = [sample[0] for sample in data]
#     positive_samples = [sample[1] for sample in data]
#     negative_samples = [sample[2] for sample in data]

#     return targets, positive_samples, negative_samples

# Save
if save_preprocessed_data:
    with open(f'{checkpoint_folder}/preprocessed_targets.pickle', 'wb') as handle:
            pickle.dump(preprocessed_targets, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(f'{checkpoint_folder}/preprocessed_positive_samples.pickle', 'wb') as handle:
            pickle.dump(preprocessed_positive_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(f'{checkpoint_folder}/preprocessed_negative_samples.pickle', 'wb') as handle:
            pickle.dump(preprocessed_negative_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

# # Load
# with open(f'{checkpoint_folder}/preprocessed_targets.pickle', 'wb') as handle:
#     preprocessed_targets = pickle.load(handle)
# with open(f'{checkpoint_folder}/preprocessed_positive_samples.pickle', 'wb') as handle:
#     preprocessed_positive_samples = pickle.load(handle)
# with open(f'{checkpoint_folder}/preprocessed_negative_samples.pickle', 'wb') as handle:
#     preprocessed_negative_samples = pickle.load(handle)


100%|███████████████████████████████████████████████████████████████████████| 9895/9895 [1:13:36<00:00,  2.24it/s]


In [14]:
print(data_loader.collate_fn)

<function collate_fn at 0x2d45fe440>


In [15]:
# len(preprocessed_targets[0])
# (preprocessed_targets[0].size()[0])

In [16]:
if config["model_type"] == 'gnn':
    gnn_input_dim = preprocessed_targets[0].node_features.shape[1]
    gnn_hidden_dim = config["hidden_dim"]
    gnn_output_dim = config["output_dim"]
elif config["model_type"] == 'fingerprints':
#     fingerprint_input_dim = preprocessed_targets[0].GetNumBits()
    fingerprint_input_dim = len(preprocessed_targets[0].node_features) #(preprocessed_targets[0].size()[0])
    fingerprint_hidden_dim = config["hidden_dim"]
    fingerprint_output_dim = config["output_dim"]
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')



In [17]:
train_data_loader = data_loader

In [18]:
train_data_loader

<torch.utils.data.dataloader.DataLoader at 0x2d58d5de0>

In [19]:
# Step 3: Set up the training loop for the GNN model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if config["model_type"] == 'gnn':
    model = GNNModel(
        input_dim=gnn_input_dim, 
        hidden_dim=gnn_hidden_dim, 
        output_dim=gnn_output_dim).to(device)
    model.double()
    
elif config["model_type"] == 'fingerprints':
    model = FingerprintModel(
        input_dim=fingerprint_input_dim, 
        hidden_dim=fingerprint_hidden_dim, 
        output_dim=fingerprint_output_dim).to(device)
else:
    raise NotImplementedError(f'Model type {config["model_type"]}')



loss_fn = NTXentLoss(temperature=config["temperature"])
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

num_epochs = config["num_epochs"]



In [20]:
# def compute_gnn_embedding(gnn_model):
#     node_features = torch.tensor(example.node_features, dtype=torch.double)
#     edge_index = torch.tensor(example.edge_index, dtype=torch.long)  # Assuming edge_index is of type 'long'

#     # Convert the input node features to double
#     node_features = node_features.double()

#     # Compute the embeddings for the positive example
#     embedding = model(node_features, edge_index)

In [21]:
train_data_loader.collate_fn

<function __main__.collate_fn(data)>

In [22]:
load_from_checkpoint=False

In [None]:
# Check if a checkpoint exists and load the model state and optimizer state if available
if load_from_checkpoint:
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
else:
    start_epoch = 0

# Create a SummaryWriter for TensorBoard logging
log_dir = f'{checkpoint_folder}/logs'  # Specify the directory to store TensorBoard logs
writer = SummaryWriter(log_dir)

model.train()
epoch_loss = pd.DataFrame(columns=['Epoch', 'TrainLoss'])
for epoch in tqdm(range(start_epoch, num_epochs)):
    total_loss = 0.0
    total_batches = 0
    for batch_idx, batch_data in enumerate(train_data_loader):
        batch_targets, batch_positive_samples, batch_negative_samples = batch_data

        optimizer.zero_grad()
        
        # Forward pass
        embeddings = []
        for i in range(len(batch_targets)):
            target = batch_targets[i]
            positives = batch_positive_samples[i]
            negatives = batch_negative_samples[i]
            if config["model_type"] == 'gnn':
                target_node_features = torch.tensor(target.node_features, dtype=torch.double)
                target_edge_index = torch.tensor(target.edge_index, dtype=torch.long)  # Assuming edge_index is of type 'long'
                target_embedding = model(target_node_features, target_edge_index)
                
                positive_samples_embeddings = torch.stack([
                    model(torch.tensor(example.node_features, dtype=torch.double).double(),
                          torch.tensor(example.edge_index, dtype=torch.long)) 
                    for example in positives
                ], dim=0)
                
                negative_samples_embeddings = torch.stack([
                    model(torch.tensor(example.node_features, dtype=torch.double).double(),
                          torch.tensor(example.edge_index, dtype=torch.long)) 
                    for example in negatives
                ], dim=0)
                embeddings = embeddings + [SampleData(target=target_embedding, positive_samples=positive_samples_embeddings, negative_samples=negative_samples_embeddings)]

            elif config["model_type"] == 'fingerprints':
                target_embedding = model(target)
                positive_samples_embeddings = model(positives)
                negative_samples_embeddings = model(negatives)
                embeddings = embeddings + [SampleData(target=target_embedding, positive_samples=positive_samples_embeddings, negative_samples=negative_samples_embeddings)]
            else:
                raise NotImplementedError(f'Model type {config["model_type"]}')


        # Compute loss
        loss = loss_fn(embeddings)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Track total loss
        total_loss += loss.item()
        total_batches += 1

    # Compute average loss for the epoch
    average_loss = total_loss / total_batches
    
    new_row = pd.DataFrame({'Epoch': [epoch], 'TrainLoss': [average_loss]})
    epoch_loss = pd.concat([epoch_loss, new_row], axis=0)
    
    # Log the loss to TensorBoard
    writer.add_scalar('Loss/train', loss.item(), epoch+1)


    
    if ((epoch%10==0) | (epoch==num_epochs-1)):
        print(f"{config['model_type']} Model - Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
        
        # Save the model and optimizer state as a checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        checkpoint_path = f'{checkpoint_folder}/epoch_{epoch+1}_{checkpoint_name}'  # Specify the checkpoint file path
        torch.save(checkpoint, checkpoint_path)
        
#         loss_df = pd.DataFrame({'Epoch': range(len(epoch_loss)), 'TrainLoss': epoch_loss})
        epoch_loss.to_csv(f'{checkpoint_folder}/train_loss.csv', index=False)
        

# Close the SummaryWriter
writer.close()

# Step 4: Evaluate and use the trained embeddings
# You can evaluate the embeddings on downstream tasks or use them for molecular similarity search or property prediction


  1%|▋                                                                       | 1/100 [24:21<40:10:50, 1461.11s/it]

gnn Model - Epoch 1/100, Loss: 2.103701420761078


 11%|███████▌                                                             | 11/100 [4:47:43<39:34:54, 1601.06s/it]

gnn Model - Epoch 11/100, Loss: 0.18274515600282853


 13%|████████▉                                                            | 13/100 [5:40:24<38:31:01, 1593.82s/it]

In [None]:
import plotly.express as px
fig = px.line(x=epoch_loss['Epoch'], y=epoch_loss['TrainLoss'], title="Train loss")
fig.update_layout(width=1000, height=600, showlegend=False)
fig.write_image(f"{checkpoint_folder}/Train_loss.pdf")
fig.show()

In [None]:
# # Save
# with open(f'{checkpoint_folder}/preprocessed_targets.pickle', 'wb') as handle:
#         pickle.dump(preprocessed_targets, handle, protocol=pickle.HIGHEST_PROTOCOL)
# with open(f'{checkpoint_folder}/preprocessed_positive_samples.pickle', 'wb') as handle:
#         pickle.dump(preprocessed_positive_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)
# with open(f'{checkpoint_folder}/preprocessed_negative_samples.pickle', 'wb') as handle:
#         pickle.dump(preprocessed_negative_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

