In [None]:
import json

import torch
from torch import nn
from torch.utils.data import DataLoader
import tqdm

from qumedl.mol.encoding.selfies_ import Selfies
from qumedl.models.transformer.pat import CausalMolPAT
from qumedl.models.transformer.loss_functions import causal_transformer_compute_losses
from qumedl.training.collator import CollatorForCausalModeling
from qumedl.training.tensor_batch import TensorBatch
from qumedl.models.activations import NewGELU
from qumedl.models.priors import GaussianPrior

: 

In [2]:
# --- Run configuration -------------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Reproducibility / hardware
random_seed = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG. Without this, `random_seed` is only used for the
# DataLoader shuffle generator, while weight initialisation, dropout and any
# torch-based sampling draw from an unseeded generator and change on every run.
torch.manual_seed(random_seed)

# Data / prior
batch_size = 32
prior_dim = 16

# Model architecture (the embedding width is tied to the model width)
model_dim = embedding_dim = 128
n_attn_heads = 4
n_encoder_layers = 4
dropout = 0.2

# Optimisation
n_epochs = 2
learning_rate = 0.001
gradient_accumulation_steps = 4

# Number of molecules sampled at the end of every epoch
n_test_samples = 10

In [3]:
# Build the `Selfies` object from the SMILES CSV at this location.
# NOTE(review): this is an absolute, machine-specific path; consider deriving it
# from a configurable data directory so the notebook runs elsewhere.
smiles_csv_path = "/root/data/drug-discovery/1Kstoned_vsc_initial_dataset_insilico_chemistry42_filtered.csv"
selfies = Selfies.from_smiles_csv(smiles_csv_path)

100%|██████████| 1000/1000 [00:00<00:00, 2230.08it/s]


In [4]:
# --- Data pipeline -----------------------------------------------------------
selfies_dataset = selfies.as_dataset()

# Dedicated generator so the shuffle order is driven by `random_seed`
# (`Generator.manual_seed` returns the generator itself).
dl_shuffler = torch.Generator().manual_seed(random_seed)

loader_options = {
    "batch_size": batch_size,
    "shuffle": True,
    "generator": dl_shuffler,
    "collate_fn": CollatorForCausalModeling(),
}
selfies_dl = DataLoader(selfies_dataset, **loader_options)

# --- Prior and model ---------------------------------------------------------
prior = GaussianPrior(dim=prior_dim)

# The model's prior width is read back from the prior object so the two
# cannot drift apart.
model_options = {
    "vocab_size": selfies.n_tokens,
    "embedding_dim": embedding_dim,
    "prior_dim": prior.dim,
    "model_dim": model_dim,
    "n_attn_heads": n_attn_heads,
    "n_encoder_layers": n_encoder_layers,
    "hidden_act": NewGELU(),
    "dropout": dropout,
    "padding_token_idx": selfies.pad_index,
}
model = CausalMolPAT(**model_options)
model.to(DEVICE)

# --- Optimiser ---------------------------------------------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Training loop.
#
# Each epoch iterates once over `selfies_dl`, accumulating gradients over
# `gradient_accumulation_steps` micro-batches per optimizer update, and then
# samples `n_test_samples` molecules from the model.
n_batches = len(selfies_dl)

for epoch in range(n_epochs):
    # Sampling at the end of the previous epoch switches the model to eval
    # mode, so re-enable training mode explicitly.
    model.train()

    with tqdm.tqdm(total=n_batches, desc="Training Model") as prog_bar:
        tensor_batch: TensorBatch
        for step, tensor_batch in enumerate(selfies_dl):
            tensor_batch.to(DEVICE)
            prior_samples = prior.generate(tensor_batch.batch_size).to(DEVICE)
            total_loss = causal_transformer_compute_losses(
                model, tensor_batch, prior_samples=prior_samples
            )

            # NOTE(review): gradients of the micro-batches are summed, not
            # averaged. If the loss is mean-reduced, consider dividing it by
            # `gradient_accumulation_steps` before backward() - confirm.
            total_loss.backward()

            # Update once a full accumulation window has been seen. `step` is
            # zero-based, so test `step + 1`; testing `step` would update after
            # the very first micro-batch. Also flush on the last batch so a
            # partial window is applied instead of leaking into the next epoch
            # (or being dropped after the final one).
            window_is_full = (step + 1) % gradient_accumulation_steps == 0
            is_last_batch = step + 1 == n_batches
            if window_is_full or is_last_batch:
                optimizer.step()
                optimizer.zero_grad()

            step_losses = {"total_loss": total_loss.item()}

            prog_bar.set_postfix(step_losses)
            prog_bar.update()

            # Move the batch back off the device. The matching call on
            # `prior_samples` was dropped: `Tensor.to` returns a new tensor
            # rather than moving in place, so its discarded result only cost a
            # device-to-host copy.
            tensor_batch.to("cpu")

    # The progress bar is closed once the `with` block exits, so it is not
    # touched again here.

    # Generate a few test molecules. Use eval mode so dropout is off while
    # sampling, and skip autograd bookkeeping since no backward pass follows.
    model.eval()
    with torch.no_grad():
        test_prior_samples = prior.generate(n_test_samples).to(DEVICE)
        start_tokens = torch.full(
            (n_test_samples, 1),
            fill_value=selfies.start_index,
            device=DEVICE,
            dtype=torch.int,
        )

        generated = model.generate(
            start_tokens, test_prior_samples, max_new_tokens=selfies.max_length
        )
    test_molecules = selfies.decode(generated.cpu().numpy())

    # UNCOMMENT to save samples as JSON
    # with open(f"test_molecules-{epoch}.json", "w") as f:
    #     json.dump(test_molecules, f)

Training Model: 100%|██████████| 32/32 [00:03<00:00,  8.81it/s, total_loss=1.09] 
Training Model: 100%|██████████| 32/32 [00:02<00:00, 13.54it/s, total_loss=0.833]


In [14]:
# Sample a handful of short sequences from the trained model for inspection.
n_preview_samples = 5
n_preview_new_tokens = 10

# The model was built with dropout and nothing above this cell is guaranteed to
# have left it in eval mode, so switch explicitly before sampling.
model.eval()

# Every sequence starts from the start token.
start_tokens = torch.full(
    (n_preview_samples, 1),
    fill_value=selfies.start_index,
    device=DEVICE,
    dtype=torch.int,
)

# One prior sample per sequence to be generated.
prior_samples = prior.generate(start_tokens.shape[0]).to(DEVICE)

# No backward pass follows, so autograd tracking is not needed.
with torch.no_grad():
    generated_samples = model.generate(
        start_tokens, prior_samples, max_new_tokens=n_preview_new_tokens
    )

In [15]:
# Show the generated token ids next to the padding index, so padding tokens in
# the samples are easy to spot.
generated_samples, selfies.pad_index

(tensor([[25,  9, 11,  9,  9, 11,  9,  9,  9,  9,  9],
         [25,  9,  9,  9,  9, 11,  9, 11,  9,  9,  9],
         [25,  9,  9,  9, 24, 11,  9,  9,  9,  9,  9],
         [25, 11,  9,  9, 11,  9,  9,  9,  9,  9, 11],
         [25,  9,  9, 11,  9,  9,  9,  9,  9, 11,  9]], device='cuda:0'),
 9)

In [16]:
# Decode the generated token ids back into strings, one per sampled sequence.
# NOTE(review): the training loop passes `generated.cpu().numpy()` to `decode`,
# while this cell passes the tensor directly - confirm both inputs are supported.
selfies.decode(generated_samples)

['', '', '', '[C]', '']