# 🤠 MolGPT Cowboy Chronicle 🤠

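This notebook demonstrates the end-to-end workflow for MolGPT, a GPT-based model for molecular generation. We'll cover:

1. **Setup & Environment** - Importing dependencies, loading the data, and building the SMILES tokenizer
2. **Training** - Configuring the MolGPT model, datasets, and trainer on molecular data (the training call itself is disabled by default)
3. **Generation** - Generating new molecules unconditionally, conditioned on a target property, and conditioned on a scaffold
4. **Evaluation** - Checking the validity of the generated SMILES with RDKit (a fuller quality assessment is not included here)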

Let's saddle up and ride! 🐎

## 1. Setup & Environment

First, let's make sure we have all the necessary imports and dependencies.

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import re
import json
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import QED
from rdkit.Chem import Crippen
from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.rdMolDescriptors import CalcTPSA

# Import directly from the files instead of using package imports
sys.path.insert(0, '.')
from train.model import GPT, GPTConfig
from train.trainer import Trainer, TrainerConfig
from train.dataset import SmileDataset
from train.utils import set_seed
from generate.utils import sample, check_novelty, canonic_smiles
from moses.utils import get_mol

# Seed the project's RNG helper (train.utils.set_seed) so runs are repeatable.
set_seed(42)

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

### 1.1 Load Dataset

Let's load the MOSES dataset, drop rows with missing values, and take a first look at it.

In [None]:
# Read the MOSES-derived CSV, discard any row with a missing value, renumber the
# index from 0, and normalise the column names to lower case.
data_name = 'moses2'
data = (
    pd.read_csv(f'{data_name}.csv')
    .dropna(axis=0)
    .reset_index(drop=True)
)
data.columns = data.columns.str.lower()

# Quick look at the size and the first rows (last expression is rendered by Jupyter).
print(f"Dataset shape: {data.shape}")
data.head()

### 1.2 Tokenization

We need to tokenize the SMILES strings for the model.

In [None]:
# SMILES tokenizer: one regex alternative per token. Bracket atoms ("[nH]", "[N+]", ...)
# are kept whole, "Br?"/"Cl?" match the two-letter halogens greedily, '<' is the padding
# token, and ring-closure labels are a single digit or '%' followed by two digits.
# A raw string is required: in a plain literal "\\" collapses to one backslash, so the
# backslash bond token could never match, and escapes like "\[" raise invalid-escape
# warnings (SyntaxWarning on Python 3.12+).
pattern = r"(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
regex = re.compile(pattern)

# Fixed token vocabulary; a token's position in this list is its integer id.
# NOTE(review): the list has no '/', '\\' or '@' entries, so SMILES with stereo marks
# tokenize but then raise KeyError on the stoi lookup — presumably moses2 is
# stereo-free; TODO confirm.
whole_string = ['#', '%10', '%11', '%12', '(', ')', '-', '1', '2', '3', '4', '5', '6', '7', '8', '9', '<', '=', 'B', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', '[B-]', '[BH-]', '[BH2-]', '[BH3-]', '[B]', '[C+]', '[C-]', '[CH+]', '[CH-]', '[CH2+]', '[CH2]', '[CH]', '[F+]', '[H]', '[I+]', '[IH2]', '[IH]', '[N+]', '[N-]', '[NH+]', '[NH-]', '[NH2+]', '[NH3+]', '[N]', '[O+]', '[O-]', '[OH+]', '[O]', '[P+]', '[PH+]', '[PH2+]', '[PH]', '[S+]', '[S-]', '[SH+]', '[SH]', '[Se+]', '[SeH+]', '[SeH]', '[Se]', '[Si-]', '[SiH-]', '[SiH2]', '[SiH]', '[Si]', '[b-]', '[bH-]', '[c+]', '[c-]', '[cH+]', '[cH-]', '[n+]', '[n-]', '[nH+]', '[nH]', '[o+]', '[s+]', '[sH+]', '[se+]', '[se]', 'b', 'c', 'n', 'o', 'p', 's']

# Token <-> id mappings used to encode prompts and decode generated sequences.
stoi = {ch: i for i, ch in enumerate(whole_string)}
itos = {i: ch for i, ch in enumerate(whole_string)}

# Persist the token -> id mapping next to the data (JSON keys are the token strings).
with open(f'{data_name}_stoi.json', 'w', encoding='utf-8') as f:
    json.dump(stoi, f)

print(f"Vocabulary size: {len(stoi)}")

## 2. Training

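Now let's set up training for the MolGPT model on the MOSES dataset. The cells below build the model, datasets, and trainer. The actual training call is disabled by default because it is slow.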

In [None]:
# Architecture hyper-parameters for the conditional GPT.
vocab_size = len(stoi)    # size of the token vocabulary
block_size = 54           # maximum sequence length, in tokens
n_layer = 8               # number of layers
n_head = 8                # number of attention heads
n_embd = 256              # embedding width
scaffold_max_len = 48     # padded scaffold length used for the Moses data

# One conditioning property plus scaffold conditioning; the LSTM option is disabled.
mconf = GPTConfig(
    vocab_size,
    block_size,
    num_props=1,
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    scaffold=True,
    scaffold_maxlen=scaffold_max_len,
    lstm=False,
    lstm_layers=0,
)

# Instantiate the network and move its parameters to the chosen device.
model = GPT(mconf)
model.to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model created with {n_params} parameters")

In [None]:
# The CSV carries its own 'split' column; the 'test' split is used for validation here.
train_data = data[data['split'] == 'train']
val_data = data[data['split'] == 'test']

print(f"Training data size: {len(train_data)}")
print(f"Validation data size: {len(val_data)}")

# Per-split arrays: SMILES, scaffold SMILES, and the conditioning property (QED here).
train_smiles, train_scaffolds, train_props = (
    train_data[column].values for column in ('smiles', 'scaffold_smiles', 'qed')
)
val_smiles, val_scaffolds, val_props = (
    val_data[column].values for column in ('smiles', 'scaffold_smiles', 'qed')
)

# Every molecule and scaffold string, space-separated, handed to SmileDataset as
# `content` (used for vocabulary creation — presumably to derive the character set).
all_smiles = np.concatenate([train_smiles, val_smiles])
all_scaffolds = np.concatenate([train_scaffolds, val_scaffolds])
content = ' '.join([*all_smiles.tolist(), *all_scaffolds.tolist()])

In [None]:
class Args:
    """Minimal stand-in for the argparse namespace that SmileDataset receives.

    Only a ``debug`` attribute (always False) is provided.
    """

    def __init__(self):
        self.debug = False


args = Args()

# Training dataset, built with aug_prob=0.5.
train_dataset = SmileDataset(
    args, train_smiles, content, block_size,
    aug_prob=0.5,
    prop=train_props,
    scaffold=train_scaffolds,
    scaffold_maxlen=scaffold_max_len,
)

# Validation dataset: same construction, but with aug_prob=0.0.
val_dataset = SmileDataset(
    args, val_smiles, content, block_size,
    aug_prob=0.0,
    prop=val_props,
    scaffold=val_scaffolds,
    scaffold_maxlen=scaffold_max_len,
)

In [None]:
# Number of passes over the training set (small, for demonstration). Held in its own
# variable so the learning-rate schedule below is always derived from it.
max_epochs = 2

# Configure training.
# final_tokens was previously hard-coded as 2 * len(train_smiles) * block_size, which
# only agrees with max_epochs == 2. It is presumably (as in minGPT) the token budget
# over which the learning rate decays — TODO confirm in train/trainer.py.
train_config = TrainerConfig(
    max_epochs=max_epochs,
    batch_size=64,
    learning_rate=3e-4,
    lr_decay=True,
    warmup_tokens=512 * 20,
    final_tokens=max_epochs * len(train_smiles) * block_size,
    ckpt_path='trained_model.pt',
    num_workers=4,
    generate=True,
)

# Create the trainer; the vocabulary mappings are passed along with the datasets.
trainer = Trainer(model, train_dataset, val_dataset, train_config, stoi, itos)

In [None]:
# Flip to True to actually train (slow); the demo keeps it off so the notebook runs
# quickly. Replaces the previous commented-out call.
RUN_TRAINING = False

if RUN_TRAINING:
    df = trainer.train(wandb=None)
else:
    print("Training skipped for demonstration purposes.")
    print("In a real scenario, you would run the training process for more epochs and with larger batch sizes.")

## 3. Generation

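Now let's load a pre-trained checkpoint and use it to generate new molecules. If the checkpoint cannot be loaded, the notebook falls back to an untrained model, so the outputs are then only illustrative.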

In [None]:
# Architecture hyper-parameters; they must match the checkpoint that is loaded below.
vocab_size = len(stoi)
block_size = 54  # maximum sequence length, in tokens
n_layer = 8
n_head = 8
n_embd = 256
scaffold_max_len = 48  # For Moses dataset

# Choose the model type (selects which property-conditioned checkpoint to load).
model_type = "qed"  # Options: qed, sas, logp, tpsa

# The checkpoint directory was a hard-coded absolute path on one machine. That path
# stays the default, but MOLGPT_WEIGHTS_DIR can override it without editing the cell.
weights_dir = os.environ.get("MOLGPT_WEIGHTS_DIR", "/home/ubuntu/molgpt/datasets/weights")
model_weight = os.path.join(weights_dir, f"moses_scaf_wholeseq_{model_type}.pt")

# Define model configuration (one conditioning property + scaffold conditioning).
mconf = GPTConfig(vocab_size, block_size,
                  num_props=1,  # Using 1 property for conditioning
                  n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                  scaffold=True, scaffold_maxlen=scaffold_max_len,
                  lstm=False, lstm_layers=0)

# Create the model
model = GPT(mconf)

# Load pre-trained weights. A failure is deliberately non-fatal: the notebook carries
# on with an untrained model so the generation workflow can still be demonstrated.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
try:
    model.load_state_dict(torch.load(model_weight, map_location=device))
    print(f"Model loaded from {model_weight}")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Using untrained model for demonstration purposes.")

model.to(device)
model.eval()

In [None]:
# Sampling settings, reused by the conditioned-generation cells further down.
context = "C"        # prompt every sequence starts from (a carbon atom)
batch_size = 5       # number of sequences sampled in parallel
temperature = 1.0    # sampling temperature (higher = more diverse)
top_k = None         # None disables top-k filtering

# Encode the prompt and replicate it once per sequence in the batch.
prompt_ids = [stoi[tok] for tok in regex.findall(context)]
x = torch.tensor(prompt_ids, dtype=torch.long)[None, ...].repeat(batch_size, 1).to(device)

# Unconditional sampling: no property target and no scaffold.
# NOTE(review): the loaded model was configured with num_props=1 and scaffold=True;
# confirm that GPT.forward accepts prop=None / scaffold=None in that configuration.
with torch.no_grad():
    y = sample(model, x, block_size, temperature=temperature, sample=True, top_k=top_k, prop=None, scaffold=None)

# Decode each row of token ids back to text and strip the '<' padding token.
generated_smiles = [''.join(itos[int(tok_id)] for tok_id in row).replace('<', '') for row in y]

# Keep entries for which get_mol returns a molecule, storing RDKit's canonical SMILES.
molecules = []
valid_smiles = []
for candidate in generated_smiles:
    parsed = get_mol(candidate)
    if not parsed:
        continue
    molecules.append(parsed)
    valid_smiles.append(Chem.MolToSmiles(parsed))

# Report validity and list the valid SMILES, numbered from 1.
print(f"Generated {len(generated_smiles)} molecules, {len(molecules)} are valid.")
print("\nGenerated SMILES:")
for rank, smi in enumerate(valid_smiles, start=1):
    print(f"{rank}. {smi}")

In [None]:
# Draw the valid molecules from the previous cell in a 3-column grid.
# `display` is not imported here; it comes from the IPython/Jupyter namespace.
if molecules:
    legends = [f"Mol {n}" for n in range(1, len(molecules) + 1)]
    display(Draw.MolsToGridImage(molecules, molsPerRow=3, subImgSize=(300, 300), legends=legends))

## 4. Property-Conditioned Generation

In [None]:
# Property targets to condition on (QED values, i.e. drug-likeness, when model_type == "qed").
prop_values = [0.6, 0.75, 0.9]

# Accumulators across all targets; only valid molecules are recorded.
all_molecules = []
all_smiles = []
all_props = []

for target in prop_values:
    print(f"\nGenerating molecules with {model_type} = {target}")

    # Encode the prompt and replicate it once per sequence in the batch.
    prompt = torch.tensor([stoi[tok] for tok in regex.findall(context)], dtype=torch.long)
    x = prompt[None, ...].repeat(batch_size, 1).to(device)

    # One copy of the target value per sequence: shape (batch_size, 1).
    p = torch.tensor([[target]]).repeat(batch_size, 1).to(device)

    # Property-only conditioning: no scaffold is supplied.
    with torch.no_grad():
        y = sample(model, x, block_size, temperature=temperature, sample=True, top_k=top_k, prop=p, scaffold=None)

    # Decode token ids to text, dropping the '<' padding token.
    generated_smiles = [''.join(itos[int(tok_id)] for tok_id in row).replace('<', '') for row in y]

    # Keep entries for which get_mol returns a molecule.
    molecules = []
    valid_smiles = []
    for candidate in generated_smiles:
        parsed = get_mol(candidate)
        if parsed:
            molecules.append(parsed)
            valid_smiles.append(Chem.MolToSmiles(parsed))

    all_molecules += molecules
    all_smiles += valid_smiles
    all_props += [target] * len(molecules)

    # Report validity and preview at most three valid SMILES.
    print(f"Generated {len(generated_smiles)} molecules, {len(molecules)} are valid.")
    for rank, smi in enumerate(valid_smiles[:3], start=1):
        print(f"{rank}. {smi}")

In [None]:
# Show at most nine property-conditioned molecules, each labelled with its target value.
if all_molecules:
    display_mols = all_molecules[:9]
    display_props = all_props[:len(display_mols)]

    legends = [f"{model_type}={p}" for p in display_props]
    img = Draw.MolsToGridImage(display_mols, molsPerRow=3, subImgSize=(300, 300), legends=legends)
    display(img)

## 5. Scaffold-Conditioned Generation

In [None]:
# Scaffold SMILES to condition on. Each one is tokenized with the same regex as the
# training data and right-padded with the '<' token to exactly scaffold_max_len tokens.
scaffolds = [
    'c1ccccc1',  # Benzene
    'c1ccncc1',  # Pyridine
    'c1ccccc1N'  # Aniline
]

# Accumulators across all scaffolds; only valid molecules are recorded.
all_scaffold_molecules = []
all_scaffold_smiles = []
all_scaffold_conditions = []

for scaffold in scaffolds:
    print(f"\nGenerating molecules with scaffold: {scaffold}")

    # Tokenize the context and replicate it once per sequence in the batch.
    x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None, ...].repeat(batch_size, 1).to(device)

    # Tokenize the scaffold. re.findall silently skips characters the pattern does not
    # recognise, so make sure the tokens reassemble into the original string.
    scaffold_tokens = regex.findall(scaffold)
    if ''.join(scaffold_tokens) != scaffold:
        raise ValueError(f"Scaffold {scaffold!r} contains characters the SMILES tokenizer does not recognise.")

    # Pad at token level. An explicit length check is needed: with a negative pad count
    # '<' * n is '', so an over-long scaffold would otherwise pass through unpadded and
    # longer than scaffold_max_len.
    if len(scaffold_tokens) > scaffold_max_len:
        raise ValueError(
            f"Scaffold {scaffold!r} has {len(scaffold_tokens)} tokens; "
            f"at most {scaffold_max_len} are supported."
        )
    padded_tokens = scaffold_tokens + ['<'] * (scaffold_max_len - len(scaffold_tokens))
    sca = torch.tensor([stoi[s] for s in padded_tokens], dtype=torch.long)[None, ...].repeat(batch_size, 1).to(device)

    # Scaffold-only conditioning: no property target.
    # NOTE(review): the loaded model was configured with num_props=1; confirm that
    # GPT.forward accepts prop=None in that configuration.
    with torch.no_grad():
        y = sample(model, x, block_size, temperature=temperature, sample=True, top_k=top_k, prop=None, scaffold=sca)

    # Convert generated tokens to SMILES strings, removing the '<' padding token.
    generated_smiles = []
    for gen_mol in y:
        completion = ''.join([itos[int(i)] for i in gen_mol])
        completion = completion.replace('<', '')
        generated_smiles.append(completion)

    # Keep entries for which get_mol returns a molecule, storing canonical SMILES.
    molecules = []
    valid_smiles = []
    for smiles in generated_smiles:
        mol = get_mol(smiles)
        if mol:
            molecules.append(mol)
            valid_smiles.append(Chem.MolToSmiles(mol))

    # Store results
    all_scaffold_molecules.extend(molecules)
    all_scaffold_smiles.extend(valid_smiles)
    all_scaffold_conditions.extend([scaffold] * len(molecules))

    # Report validity and preview at most three valid SMILES.
    print(f"Generated {len(generated_smiles)} molecules, {len(molecules)} are valid.")
    for i, smiles in enumerate(valid_smiles[:3]):
        print(f"{i+1}. {smiles}")

In [None]:
# Show at most nine scaffold-conditioned molecules, each labelled with its scaffold.
if all_scaffold_molecules:
    display_mols = all_scaffold_molecules[:9]
    display_scaffolds = all_scaffold_conditions[:len(display_mols)]

    legends = [f"Scaffold: {s}" for s in display_scaffolds]
    img = Draw.MolsToGridImage(display_mols, molsPerRow=3, subImgSize=(300, 300), legends=legends)
    display(img)

## 6. Summary

In this notebook, we've demonstrated how to use MolGPT for molecular generation with different conditioning strategies:

1. **Unconditional Generation**: Generate molecules without any constraints
2. **Property-Conditioned Generation**: Generate molecules with specific property values (QED, LogP, SAS, TPSA)
3. **Scaffold-Conditioned Generation**: Generate molecules containing specific molecular scaffolds

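MolGPT offers a flexible approach to molecular generation: you can steer the generated structures through property and scaffold conditioning. Note that this notebook only reports SMILES validity. It does not measure how closely the generated molecules match the requested property values or scaffolds.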