In [1]:
import json

import torch
from torch import nn
from torch.utils.data import DataLoader
import pandas as pd
import tqdm
import numpy as np
from qumedl.mol.encoding.selfies_ import Selfies
from qumedl.models.transformer.pat import CausalMolPAT
from qumedl.models.transformer.loss_functions import compute_transformer_loss
from qumedl.training.collator import TensorBatchCollator
from qumedl.training.tensor_batch import TensorBatch
from qumedl.models.activations import NewGELU
from qumedl.models.priors import GaussianPrior
from orquestra.drug.discovery.docking.utils import process_molecule
from orquestra.drug.discovery.validator.filter_abstract import FilterAbstract
from torch.optim.lr_scheduler import CosineAnnealingLR

# from qumedl.models.priors import QCBMPrior
from orquestra.drug.discovery.validator import (
    GeneralFilter,
    PainFilter,
    WehiMCFilter,
    SybaFilter,
)
from orquestra.drug.metrics import MoleculeNovelty, get_diversity
from orquestra.drug.utils import ConditionFilters
import wandb  # Import wandb
import os
from datetime import datetime
import sys
import torch
import torch.nn as nn
import cloudpickle

## RBM
import optax
from orquestra.qml.models.rbm.jx import RBM
from orquestra.qml.api import Batch

# Initialize Qiskit Runtime Service with specific credentials
import pickle


class TartarusFilters(FilterAbstract):
    """Filter that accepts a molecule only when `process_molecule` reports "PASS"."""

    def apply(self, smile: str):
        """Return True if the status returned by `process_molecule(smile)` is "PASS", else False."""
        # process_molecule returns a pair; only the second element (the status) is used here.
        _ignored, verdict = process_molecule(smile)
        return True if verdict == "PASS" else False


def save_object(obj, filename):
    """Pickle `obj` into the file at `filename`, replacing any existing content."""
    # Pickle streams are bytes, so the target must be opened in binary write mode.
    with open(filename, "wb") as sink:
        pickle.Pickler(sink).dump(obj)


def load_object(filename):
    """Read the pickle file at `filename` and return the object stored in it.

    Caution: unpickling can execute arbitrary code, so only load files that
    were produced by a trusted source.
    """
    # Binary read mode matches the mode used when the pickle was written.
    with open(filename, "rb") as source:
        restored = pickle.Unpickler(source).load()
    return restored


class RBMModel(RBM):
    """RBM subclass that exposes `num_qubits` and a repeated single-batch training loop."""

    def __init__(
        self,
        n_visible: int,
        n_hidden: int,
        random_seed=32,
        optimizer=optax.sgd(learning_rate=1e-6),
    ):
        """Initialise the parent RBM and mirror its visible size as `num_qubits`.

        Args:
            n_visible: Number of visible units, forwarded to `RBM`.
            n_hidden: Number of hidden units, forwarded to `RBM`.
            random_seed: Seed forwarded to `RBM`.
            optimizer: Optax optimizer forwarded to `RBM`.
        """
        super().__init__(
            n_visible, n_hidden, random_seed=random_seed, optimizer=optimizer
        )
        # The notebook reads `prior.num_qubits` when building the transformer,
        # so the visible size is published under that name as well.
        self.num_qubits = self.n_visible

    def train(self, data, probs, n_epoch):
        """Call `_train_on_batch` `n_epoch` times on one batch built from `data` and `probs`.

        Returns:
            A list with the value returned by `_train_on_batch` for each pass, in order.
        """
        full_batch = Batch(data=data, probs=probs)
        # full_batch.batch_size = -1
        return [self._train_on_batch(full_batch) for _ in range(n_epoch)]


# Persist an object to disk with cloudpickle.
def save_obj(obj, file_path):
    """Write `obj` to `file_path` using cloudpickle and return whatever `cloudpickle.dump` returns."""
    with open(file_path, "wb") as sink:
        return cloudpickle.dump(obj, sink)


def load_obj(file_path):
    """Read `file_path` with cloudpickle and return the deserialized object.

    Caution: like pickle, cloudpickle can execute arbitrary code while loading,
    so only open files from a trusted source.
    """
    with open(file_path, "rb") as source:
        return cloudpickle.load(source)


def create_project_log_folder(project_name="pat"):
    """Create a timestamped run directory under ``./logs`` and return its identifiers.

    Args:
        project_name: Prefix for the run name. It is used verbatim: it is never
            passed through ``strftime``, so a ``%`` in the prefix is preserved
            instead of being interpreted as a format directive.

    Returns:
        A tuple ``(project_dir_path, project_name, project_today)`` where

        * ``project_name`` is ``<prefix>_<YYYY-MM-DD_HH-MM-SS.ffffff>``,
        * ``project_dir_path`` is ``./logs/<project_name>`` (created if missing),
        * ``project_today`` is ``<project_name>_<YYYY-MM-DD>``. Note that it is
          built from the *timestamped* name, not from the bare prefix; that
          composition is kept so names produced by earlier runs stay consistent.
    """
    current_date = datetime.now()

    # Format only the timestamp pieces; the caller's prefix is concatenated as-is.
    run_stamp = current_date.strftime("%Y-%m-%d_%H-%M-%S.%f")
    day_stamp = current_date.strftime("%Y-%m-%d")
    project_name = f"{project_name}_{run_stamp}"
    project_today = f"{project_name}_{day_stamp}"

    # Define the path for the logs directory and the run directory inside it
    logs_dir_path = "./logs"
    project_dir_path = os.path.join(logs_dir_path, project_name)

    # exist_ok=True creates both levels in one call and avoids the
    # check-then-create race when several runs start at the same time.
    os.makedirs(project_dir_path, exist_ok=True)

    print(f"Project log folder created at: {project_dir_path}")
    return (project_dir_path, project_name, project_today)

In [2]:
# Restrict this process to GPUs 0 and 1.
# NOTE(review): this runs after `import torch`; it only takes effect if CUDA has not been
# initialised yet in this kernel — confirm, or move it above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [3]:
# Echo the variable to confirm which GPU ids were requested for this kernel.
print(os.environ["CUDA_VISIBLE_DEVICES"])

0,1


In [4]:
# Use the GPU whenever torch can see one (needs to be cuda on the cluster), else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print("no input")

# Run configuration, hard-coded for notebook use.
prior_name, prior_size, random_seed = "rbm", 16, 0  # prior_size: int(sys.argv[2])
cuda_device_code = "4"
# os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device_code
dataset_arg, backend_sim, wandb_active = "tartarus", True, False

# One-line summary of the active configuration for the cell output.
print(
    f"prior_name:{prior_name},prior_size:{prior_size},DEVICE:{DEVICE},cuda_device_code:{cuda_device_code},dataset_arg:{dataset_arg},random_seed:{random_seed}"
)

no input
prior_name:rbm,prior_size:16,DEVICE:cuda,cuda_device_code:4,dataset_arg:tartarus,random_seed:0


In [5]:
# DEVICE = 'cpu'
# Batch size and prior width (the prior width follows `prior_size` from the config cell).
batch_size = 1024
prior_dim = prior_size

# Transformer architecture.
# NOTE(review): the trailing comment says model_dim "should be embedding_dim/n_attn_heads",
# yet both names are bound to the same value (256) here — confirm which relation is intended.
model_dim = embedding_dim = 256  # should be embedding_dim/n_attn_heads
n_attn_heads = 8
n_encoder_layers = 4

# Number of start tokens / prior samples drawn in the sampling cell further down.
n_g_samples = 5000

dropout = 0.2

# Training hyperparameters (not referenced by the cells shown in this notebook).
n_epochs = 100
learning_rate = 1e-3
min_learning_rate = 1e-6
gradient_accumulation_steps = 1

n_epochs_prior = 30
n_test_samples = 5000

# CSV of SMILES strings; read via its `smiles` column below.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
dataset_name = "/root/generative-models/scripts/data/docking_hill_climbing_0.csv"

In [6]:
# Cache the parsed Selfies object next to the CSV so later runs can skip re-parsing.
# os.path.splitext removes only the final extension; splitting on the first "." would
# truncate any path that contains another dot (e.g. "./data/set.v2.csv" -> "").
pickle_name = os.path.splitext(dataset_name)[0]
selfies_cache_path = f"{pickle_name}.pkl"

if os.path.isfile(selfies_cache_path):
    # NOTE: unpickling executes arbitrary code; only reuse caches written by this notebook.
    selfies = load_object(selfies_cache_path)
else:
    selfies = Selfies.from_smiles_csv(dataset_name)
    save_object(selfies, selfies_cache_path)

In [7]:
# Display the Selfies object as a quick check that loading/building it succeeded.
selfies

<qumedl.mol.encoding.selfies_.Selfies at 0x7f38270d2d50>

In [9]:
# Reference SMILES strings from the dataset CSV (column `smiles`).
smiles_dataset_df = pd.read_csv(dataset_name)
smiles_dataset = smiles_dataset_df.smiles.to_list()

selfies_dataset = selfies.as_dataset()

# Seeded torch generator (presumably meant for DataLoader shuffling; not used in the cells shown).
dl_shuffler = torch.Generator()
dl_shuffler.manual_seed(random_seed)


# Build the prior selected in the config cell.
if prior_name == "random":
    prior = GaussianPrior(dim=prior_dim)
    prior_trainable = False
elif prior_name == "rbm":
    prior = RBMModel(
        n_visible=prior_dim,
        n_hidden=2 * prior_dim,
        random_seed=random_seed,
        optimizer=optax.sgd(learning_rate=1e-6),
    )
    prior_trainable = True
else:
    # Fail here rather than with a NameError on `prior` in a later cell.
    raise ValueError(
        f"Unsupported prior_name {prior_name!r}; expected 'random' or 'rbm'."
    )

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [41]:
# Gather the architecture settings first so they can be inspected before the model is built.
# Vocabulary size and padding index come from the Selfies encoding; the prior width comes
# from the prior object created earlier.
pat_config = dict(
    vocab_size=selfies.n_tokens,
    embedding_dim=embedding_dim,
    prior_dim=prior.num_qubits,
    model_dim=model_dim,
    n_attn_heads=n_attn_heads,
    n_encoder_layers=n_encoder_layers,
    hidden_act=NewGELU(),
    dropout=dropout,
    padding_token_idx=selfies.pad_index,
)
model = CausalMolPAT(**pat_config)

In [12]:
# small transformer
# wandb_project_name = "pat_2024-08-16_02-09-32.570627_2024-08-16"
# log_path = "/root/generative-models/scripts/logs/pat_2024-08-16_02-09-32.570627/model_epoch_99.pt"

In [36]:
# big transformer model
# NOTE(review): this variable is spelled `wand_project_name`, while the commented-out
# small-transformer cell uses `wandb_project_name` — confirm which name is intended.
wand_project_name = "pat_2024-08-16_15-22-18.083128_2024-08-16"
# # /root/generative-models/scripts/logs/pat_2024-08-16_15-22-18.083128/model_epoch_99.pt
# Checkpoint file whose state_dict is loaded into `model` further down.
log_path = "/root/generative-models/scripts/logs/pat_2024-08-16_15-22-18.083128/model_epoch_99.pt"

In [37]:
# Number of CUDA devices torch can see in this kernel.
torch.cuda.device_count()

2

In [38]:
# if torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#     # Wrap the model with nn.DataParallel
#     model = nn.DataParallel(model)
#     batch_size = batch_size * torch.cuda.device_count()

Let's use 2 GPUs!


In [42]:
# Move the model to DEVICE; the returned module is shown as the cell output.
model.to(DEVICE)

CausalMolPAT(
  (embedding): Embedding(110, 256, padding_idx=13)
  (projectx_addy): ProjectXxY(
    (_projection): Linear(in_features=16, out_features=256, bias=False)
    (_activation): Identity()
  )
  (repeat): Identity()
  (pe): PositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (projection): Linear(in_features=256, out_features=256, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=256, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace

In [44]:
import torch


def remove_prefix_from_state_dict(state_dict, prefix="_orig_mod."):
    """Return a new dict whose keys have a leading `prefix` stripped.

    Keys that do not start with `prefix` are copied unchanged. Entry order
    follows `state_dict`, and the input mapping is not modified.
    """
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }


# Load the state_dict onto the CPU first; it is copied into the model in the next cell.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code. Only load
# checkpoints from a trusted source, or pass weights_only=True if the installed torch supports it.
state_dict = torch.load(log_path, map_location=torch.device("cpu"))

# Remove the "_orig_mod." prefix from the keys so they can match the plain model's parameter names
state_dict = remove_prefix_from_state_dict(state_dict)

# Load the modified state_dict into your model
# model.load_state_dict(state_dict)

In [45]:
# Copy the checkpoint weights into the model; the returned key-matching report is the cell output.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [23]:
# One start token per requested sample, laid out as a (n_g_samples, 1) integer tensor on DEVICE.
token_shape = (n_g_samples, 1)
start_tokens = torch.full(
    token_shape, fill_value=selfies.start_index, device=DEVICE, dtype=torch.int
)

# Draw samples from the prior, convert them to a torch tensor via numpy, then move to DEVICE.
raw_prior_samples = prior.generate(n_g_samples, random_seed=random_seed)
prior_samples = torch.tensor(np.asarray(raw_prior_samples)).to(DEVICE)

In [47]:
# Make sure the model is on DEVICE before generation; the returned module is the cell output.
model.to(DEVICE)

CausalMolPAT(
  (embedding): Embedding(110, 256, padding_idx=13)
  (projectx_addy): ProjectXxY(
    (_projection): Linear(in_features=16, out_features=256, bias=False)
    (_activation): Identity()
  )
  (repeat): Identity()
  (pe): PositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (projection): Linear(in_features=256, out_features=256, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=256, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace

In [48]:
n_samples = 1000
temperature = 1e-9
# Draw random bit-strings of shape (n_samples, prior_size) with a fixed seed.
torch.manual_seed(1260)
prior_samples = torch.randint(0, 2, (n_samples, prior_size)).float()
# Keep each distinct bit-string once so every generated molecule has a different prior input.
prior_samples = torch.unique(prior_samples, dim=0)
# One start token per remaining prior sample, created directly on DEVICE.
start_tokens = torch.full(
    (len(prior_samples), 1),
    fill_value=selfies.start_index,
    device=DEVICE,
    dtype=torch.int,
)
device = torch.device("cuda:0")  # not used below; kept in case other cells read it

prior_samples = prior_samples.to(DEVICE)

# Generate molecules from the prior samples. DataParallel wrappers do not forward
# `generate`, so unwrap to the underlying module when the model is wrapped.
generator = model.module if isinstance(model, torch.nn.DataParallel) else model
# Inference should run with dropout disabled and without building autograd graphs.
# `generate` may already take care of this; setting it here is harmless in that case.
generator.eval()
with torch.no_grad():
    generated = generator.generate(
        start_tokens,
        prior_samples,
        max_new_tokens=selfies.max_length,
        temperature=temperature,
    )
test_molecules = selfies.decode(generated.cpu().numpy())
ligands = selfies.selfie_to_smiles(test_molecules)

# Tensors hash by identity, so `set(prior_samples)` would always count every row;
# count distinct rows by value instead.
n_unique_priors = torch.unique(prior_samples, dim=0).shape[0]
print("bits", prior_size)
print("unique prior", n_unique_priors)
print("unique molecules", len(set(ligands)))
print("total molecules", len(ligands))
print("diversity", len(set(ligands)) / len(ligands))

bits 16
unique prior 994
unique molecules 930
total molecules 994
diversity 0.9356136820925554


In [49]:
# Repeat generation at a higher temperature, reusing `start_tokens` and `prior_samples`
# as left in the kernel by the most recently executed sampling cell.
temperature = 0.01

# DataParallel wrappers do not forward `generate`, so unwrap when the model is wrapped.
generator = model.module if isinstance(model, torch.nn.DataParallel) else model
# Inference should run with dropout disabled and without building autograd graphs.
# `generate` may already take care of this; setting it here is harmless in that case.
generator.eval()
with torch.no_grad():
    generated = generator.generate(
        start_tokens,
        prior_samples,
        max_new_tokens=selfies.max_length,
        temperature=temperature,
    )
test_molecules = selfies.decode(generated.cpu().numpy())
ligands = selfies.selfie_to_smiles(test_molecules)

# Tensors hash by identity, so `set(prior_samples)` would always count every row;
# count distinct rows by value instead.
n_unique_priors = torch.unique(prior_samples, dim=0).shape[0]
print("bits", prior_size)
print("unique prior", n_unique_priors)
print("unique molecules", len(set(ligands)))
print("total molecules", len(ligands))
print("diversity", len(set(ligands)) / len(ligands))

bits 16
unique prior 994
unique molecules 956
total molecules 994
diversity 0.9617706237424547


In [25]:
# Load the test molecules from the JSON file
smiles_dataset_df = pd.read_csv(dataset_name)
with open("/root/generative-models/scripts/logs/pat_2024-09-27_00-15-43.541538/test_molecules-99.json", "r") as file:
    test_molecules = json.load(file)

# Convert the stored molecules to SMILES
smiles_dataset = smiles_dataset_df.smiles.to_list()
decoded_ligands = selfies.selfie_to_smiles(test_molecules)
# Drop falsy entries (presumably failed conversions) but keep duplicates for the diversity ratio.
valid_ligands = [ligand for ligand in decoded_ligands if ligand]
# Distinct, non-empty ligands; these are what novelty is scored on.
ligands = [ligand for ligand in set(decoded_ligands) if ligand]

# Compute diversity as distinct / valid. It has to be measured against the list that still
# contains duplicates; dividing the de-duplicated list by itself is always 1.0.
diversity = len(ligands) / len(valid_ligands) if valid_ligands else 0.0

# Compute novelty
novelty = MoleculeNovelty(smiles_dataset)
novelty_score = novelty.get_novelity_smiles(ligands,threshold=0.6)

print("Diversity:", diversity)
print("Novelty:", novelty_score)

Diversity: 1.0
Novelty: 92.05298013245033


In [26]:
# Load the test molecules from the JSON file
smiles_dataset_df = pd.read_csv(dataset_name)
with open("/root/generative-models/scripts/logs/pat_2024-09-27_00-15-43.541538/test_molecules-399.json", "r") as file:
    test_molecules = json.load(file)

# Convert the stored molecules to SMILES
smiles_dataset = smiles_dataset_df.smiles.to_list()
decoded_ligands = selfies.selfie_to_smiles(test_molecules)
# Drop falsy entries (presumably failed conversions) but keep duplicates for the diversity ratio.
valid_ligands = [ligand for ligand in decoded_ligands if ligand]
# Distinct, non-empty ligands; these are what novelty is scored on.
ligands = [ligand for ligand in set(decoded_ligands) if ligand]

# Compute diversity as distinct / valid. It has to be measured against the list that still
# contains duplicates; dividing the de-duplicated list by itself is always 1.0.
diversity = len(ligands) / len(valid_ligands) if valid_ligands else 0.0

# Compute novelty
novelty = MoleculeNovelty(smiles_dataset)
novelty_score = novelty.get_novelity_smiles(ligands,threshold=0.6)

print("Diversity:", diversity)
print("Novelty:", novelty_score)

Diversity: 1.0
Novelty: 85.03575076608784


In [23]:
# Display the novelty score left in the kernel by the most recently executed evaluation cell.
novelty_score

47.22222222222222

In [28]:
# Fraction of distinct valid ligands relative to a fixed total of 2000.
# NOTE(review): 2000 is presumably the number of molecules in the loaded JSON file —
# confirm, or divide by len(test_molecules) instead of a hard-coded count.
len(ligands)/2000

0.979