Training the VAE model

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import random
from tqdm.auto import tqdm

# Pin the Python, NumPy and PyTorch random generators to the same fixed seed
# so repeated runs start from the same random state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

class SimpleCharVAE(nn.Module):
    """Character-level sequence VAE with a bidirectional GRU encoder and a GRU decoder.

    One embedding table is shared by the encoder and the decoder.

    Args:
        vocab_size: Number of distinct token indices.
        hidden_size: Embedding width and GRU hidden width.
        latent_size: Dimensionality of the latent vector.
    """

    def __init__(self, vocab_size, hidden_size=256, latent_size=64):
        super().__init__()

        # Encoder
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.encoder_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.mu = nn.Linear(hidden_size * 2, latent_size)  # *2 for bidirectional
        self.logvar = nn.Linear(hidden_size * 2, latent_size)

        # Decoder
        self.latent_to_hidden = nn.Linear(latent_size, hidden_size)
        self.decoder_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size, vocab_size)

    def encode(self, x):
        """Encode token indices into the parameters of the latent Gaussian.

        Args:
            x: Integer tensor of token indices, batch first.

        Returns:
            Tuple ``(mu, logvar)``, each with ``latent_size`` features per row.
        """
        x = self.embedding(x)
        _, hidden = self.encoder_rnn(x)
        # The encoder is a single bidirectional layer, so hidden[0] is the
        # final forward state and hidden[1] the final backward state.
        hidden = torch.cat([hidden[0], hidden[1]], dim=1)

        mu = self.mu(hidden)
        logvar = self.logvar(hidden)
        return mu, logvar

    def decode(self, z, target=None, max_length=100, teacher_forcing=0.5):
        """Unroll the decoder for ``max_length`` steps starting from ``z``.

        Args:
            z: Latent vectors, one row per sequence.
            target: Optional token-index tensor used for teacher forcing.
            max_length: Number of decoding steps (always run in full; there
                is no early stop on an end token).
            teacher_forcing: Probability, drawn once per step for the whole
                batch, of feeding ``target[:, t]`` instead of the model's
                own greedy prediction as the next input.

        Returns:
            Logits of shape ``(batch, max_length, vocab_size)``.
        """
        batch_size = z.size(0)
        device = z.device

        # Initialize hidden state from latent vector
        hidden = self.latent_to_hidden(z).unsqueeze(0)

        # The first decoder input is token index 0 for every sequence.
        decoder_input = torch.zeros(batch_size, 1, device=device, dtype=torch.long)

        outputs = []

        for t in range(max_length):
            emb = self.embedding(decoder_input)

            # Run RNN for one step
            output, hidden = self.decoder_rnn(emb, hidden)

            # Project to vocabulary logits
            output = self.output(output)
            outputs.append(output)

            # Teacher forcing or use own prediction
            if target is not None and t < target.size(1) and random.random() < teacher_forcing:
                decoder_input = target[:, t:t+1]
            else:
                decoder_input = output.argmax(2)

        return torch.cat(outputs, dim=1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, target=None):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(logits, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        output = self.decode(z, target)
        return output, mu, logvar

    def sample(self, n_samples=1, device=None, max_length=100):
        """Greedily decode sequences from latent vectors drawn from N(0, I).

        Args:
            n_samples: Number of sequences to generate.
            device: Device for the latent draw. ``None`` uses the device the
                model's parameters live on, so the call also works on
                CPU-only machines.
            max_length: Number of tokens to decode per sequence.

        Returns:
            Integer tensor of token indices, shape ``(n_samples, max_length)``.
        """
        if device is None:
            device = self.mu.weight.device
        with torch.no_grad():
            # Sample from latent space
            z = torch.randn(n_samples, self.mu.out_features, device=device)
            # Decode without teacher forcing
            output = self.decode(z, max_length=max_length, teacher_forcing=0)
            # Get most likely tokens
            return output.argmax(dim=2)

def create_char_dict(smiles_list):
    """Build the vocabulary and both lookup tables for a SMILES collection.

    The three special tokens take indices 0-2, followed by every distinct
    character found in ``smiles_list`` in sorted order.

    Returns:
        Tuple ``(char_to_idx, idx_to_char, all_chars)``.
    """
    alphabet = {ch for entry in smiles_list for ch in entry}

    all_chars = ['<pad>', '<start>', '<end>', *sorted(alphabet)]
    char_to_idx = {token: position for position, token in enumerate(all_chars)}
    idx_to_char = {position: token for token, position in char_to_idx.items()}

    return char_to_idx, idx_to_char, all_chars

def tokenize_smiles(smiles, char_to_idx, max_length=100):
    """Encode one SMILES string as a fixed-length list of token indices.

    Layout is ``<start>``, the characters, ``<end>``, then ``<pad>`` up to
    ``max_length``. At most ``max_length - 2`` characters are kept, and a
    character missing from ``char_to_idx`` is encoded as ``<pad>``.
    """
    first = char_to_idx['<start>']
    middle = [
        char_to_idx.get(ch, char_to_idx['<pad>'])
        for ch in smiles[:max_length-2]  # leave room for start/end tokens
    ]
    encoded = [first, *middle, char_to_idx['<end>']]

    # Right-pad short sequences to the fixed length
    shortfall = max_length - len(encoded)
    if shortfall > 0:
        encoded.extend([char_to_idx['<pad>']] * shortfall)

    # Clip in case the sequence is still longer than max_length
    return encoded[:max_length]

def smiles_to_tensor(smiles_list, char_to_idx, max_length=100):
    """Tokenize every SMILES string and stack the rows into one long tensor."""
    rows = [tokenize_smiles(entry, char_to_idx, max_length) for entry in smiles_list]
    return torch.tensor(rows, dtype=torch.long)

def tensor_to_smiles(tensor, idx_to_char):
    """Convert rows of token indices back to SMILES strings.

    Decoding of a row stops at the first ``<end>`` token; ``<pad>`` and
    ``<start>`` tokens are skipped.

    Args:
        tensor: 2-D tensor of token indices (anything with ``tolist()``), or
            an already nested list of ints.
        idx_to_char: Mapping from token index to token string.

    Returns:
        One string per row.
    """
    # Pull all indices out in one call instead of calling .item() per token.
    rows = tensor.tolist() if hasattr(tensor, 'tolist') else tensor

    skipped = ('<pad>', '<start>')
    smiles = []
    for row in rows:
        chars = []
        for idx in row:
            c = idx_to_char[idx]
            if c == '<end>':
                break
            if c not in skipped:
                chars.append(c)
        smiles.append(''.join(chars))
    return smiles

def vae_loss(recon_x, x, mu, logvar, kl_weight=0.1, pad_idx=0):
    """Summed reconstruction cross-entropy plus a weighted KL term.

    The KL weight is a fixed factor; any annealing schedule has to be
    applied by the caller through ``kl_weight``.

    Args:
        recon_x: Logits of shape ``(batch, seq_len, vocab_size)``.
        x: Target token indices of shape ``(batch, seq_len)``.
        mu: Latent means.
        logvar: Latent log-variances.
        kl_weight: Multiplier applied to the KL divergence.
        pad_idx: Target index excluded from the reconstruction loss.

    Returns:
        Scalar tensor ``recon_loss + kl_weight * kl_loss``, summed (not
        averaged) over the batch.
    """
    logits = recon_x.reshape(-1, recon_x.size(2))
    targets = x.reshape(-1)

    # ignore_index drops padded positions from the sum.
    recon_loss = F.cross_entropy(logits, targets, ignore_index=pad_idx, reduction='sum')

    # KL divergence between N(mu, exp(logvar)) and N(0, I)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kl_weight * kl_loss

def train_model(model, train_data, optimizer, device, epochs=10, batch_size=64):
    """Train the VAE and return the per-epoch average loss.

    Batches are taken in order, and the final partial batch is included so
    no sample is dropped and small datasets still train.

    Args:
        model: Module called as ``model(batch, batch)`` and returning
            ``(recon, mu, logvar)``.
        train_data: Tensor of training rows, sliced along dim 0.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.
        epochs: Number of passes over ``train_data``.
        batch_size: Rows per optimisation step.

    Returns:
        List with one entry per epoch: summed loss divided by sample count.

    Raises:
        ValueError: If ``train_data`` is empty or ``batch_size`` is not positive.
    """
    num_samples = len(train_data)
    if num_samples == 0:
        raise ValueError("train_data is empty; nothing to train on")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Start offsets of every batch, including the trailing partial one.
    batch_starts = range(0, num_samples, batch_size)
    train_losses = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for start_idx in tqdm(batch_starts, desc=f"Epoch {epoch+1}/{epochs}"):
            batch = train_data[start_idx:start_idx + batch_size].to(device)

            # Forward pass; the input doubles as the teacher-forcing target
            recon_batch, mu, logvar = model(batch, batch)

            # Calculate loss
            loss = vae_loss(recon_batch, batch, mu, logvar)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # The loss is summed per batch, so normalise by samples actually seen.
        avg_loss = total_loss / num_samples
        train_losses.append(avg_loss)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}")

    return train_losses

def generate_smiles(model, idx_to_char, n_samples=100, device=None):
    """Generate up to ``n_samples`` distinct SMILES strings from the model.

    Sampling runs for ``ceil(n_samples / 50)`` batches. Duplicates are
    discarded, so fewer than ``n_samples`` strings may be returned. The
    strings are not checked for chemical validity.

    Args:
        model: Trained model exposing ``sample(n, device)``.
        idx_to_char: Mapping from token index to token string.
        n_samples: Upper bound on the number of strings returned.
        device: Device to sample on. ``None`` uses the device of the model's
            parameters, so the call also works on CPU-only machines.

    Returns:
        List of distinct strings in the order they were generated.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    generated = []
    seen = set()  # constant-time duplicate check instead of scanning the list

    # Generate in batches for efficiency
    batch_size = 50
    n_batches = (n_samples + batch_size - 1) // batch_size

    for _ in range(n_batches):
        with torch.no_grad():
            samples = model.sample(min(batch_size, n_samples - len(generated)), device)

        for smiles in tensor_to_smiles(samples, idx_to_char):
            if smiles in seen or len(smiles) > 100:
                continue
            seen.add(smiles)
            generated.append(smiles)
            if len(generated) >= n_samples:
                break

    return generated

# Main execution
if __name__ == "__main__":
    # Parameters
    max_length = 100
    hidden_size = 256
    latent_size = 64
    epochs = 100

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    data_file = "/content/dataset_with_descriptors.csv"
    df = pd.read_csv(data_file)

    # Filter for active compounds
    active_df = df[df['Activity'] == 'Active']
    print(f"Found {len(active_df)} active compounds")

    # Get SMILES (rows without a SMILES value are dropped)
    active_smiles = active_df['SMILES'].dropna().tolist()
    if not active_smiles:
        # Fail early with a clear message instead of deep inside training.
        raise ValueError(f"No active compounds with a SMILES string found in {data_file}")

    # Create character mappings
    char_to_idx, idx_to_char, all_chars = create_char_dict(active_smiles)
    print(f"Vocabulary size: {len(all_chars)}")

    # Convert to tensors
    all_data = smiles_to_tensor(active_smiles, char_to_idx, max_length)
    print(f"Training on {len(all_data)} examples")

    # Create model
    model = SimpleCharVAE(len(all_chars), hidden_size, latent_size).to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train model - use all data, no validation
    losses = train_model(model, all_data, optimizer, device, epochs=epochs)

    # Plot losses
    plt.figure(figsize=(10, 6))
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.savefig('vae_training_loss.png')
    plt.close()

    # Save model weights only; the vocabulary is not stored with them
    torch.save(model.state_dict(), 'simple_smiles_vae.pt')

    # Generate new SMILES
    print("Generating new molecules...")
    generated = generate_smiles(model, idx_to_char, n_samples=100, device=device)

    # Save all generated SMILES
    with open('generated_molecules.txt', 'w') as f:
        for smiles in generated:
            f.write(f"{smiles}\n")

    # Also save molecules validated by RDKit separately
    valid_mols = []
    for smiles in generated:
        try:
            mol = Chem.MolFromSmiles(smiles)
            # An atom-less result (RDKit appears to return one for an empty
            # string) is not a real molecule, so it is not counted as valid.
            if mol is not None and mol.GetNumAtoms() > 0:
                valid_mols.append(Chem.MolToSmiles(mol))
        except Exception:
            # Best effort: skip strings RDKit chokes on, but let
            # KeyboardInterrupt/SystemExit propagate.
            continue

    print(f"Generated {len(generated)} SMILES strings")
    print(f"Of which {len(valid_mols)} are chemically valid")

    # Save valid molecules separately
    with open('valid_molecules.txt', 'w') as f:
        for smiles in valid_mols:
            f.write(f"{smiles}\n")

Loading the trained model to generate new SMILES, and reloading the dataset to rebuild the vocabulary

In [None]:
import torch
import numpy as np
import random
from rdkit import Chem
import pandas as pd

# Define the model class (same as in your training code)
class SimpleCharVAE(torch.nn.Module):
    """Inference-side copy of the character-level VAE.

    The layers are declared with the same attribute names as at training
    time so a saved ``state_dict`` loads into this class. Only decoding and
    sampling are implemented here.

    Args:
        vocab_size: Number of distinct token indices.
        hidden_size: Embedding width and GRU hidden width.
        latent_size: Dimensionality of the latent vector.
    """

    def __init__(self, vocab_size, hidden_size=256, latent_size=64):
        super().__init__()

        # Encoder (unused in this cell, kept so the checkpoint keys match)
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.encoder_rnn = torch.nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.mu = torch.nn.Linear(hidden_size * 2, latent_size)
        self.logvar = torch.nn.Linear(hidden_size * 2, latent_size)

        # Decoder
        self.latent_to_hidden = torch.nn.Linear(latent_size, hidden_size)
        self.decoder_rnn = torch.nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.output = torch.nn.Linear(hidden_size, vocab_size)

    def decode(self, z, target=None, max_length=100, teacher_forcing=0.5):
        """Unroll the decoder for ``max_length`` steps starting from ``z``.

        Args:
            z: Latent vectors, one row per sequence.
            target: Optional token-index tensor used for teacher forcing.
            max_length: Number of decoding steps (always run in full).
            teacher_forcing: Probability, drawn once per step for the whole
                batch, of feeding ``target[:, t]`` instead of the model's
                own greedy prediction as the next input.

        Returns:
            Logits of shape ``(batch, max_length, vocab_size)``.
        """
        batch_size = z.size(0)
        device = z.device

        hidden = self.latent_to_hidden(z).unsqueeze(0)
        # The first decoder input is token index 0 for every sequence.
        decoder_input = torch.zeros(batch_size, 1, device=device, dtype=torch.long)

        outputs = []

        for t in range(max_length):
            emb = self.embedding(decoder_input)
            output, hidden = self.decoder_rnn(emb, hidden)
            output = self.output(output)
            outputs.append(output)

            if target is not None and t < target.size(1) and random.random() < teacher_forcing:
                decoder_input = target[:, t:t+1]
            else:
                decoder_input = output.argmax(2)

        return torch.cat(outputs, dim=1)

    def sample(self, n_samples=1, device=None, max_length=100):
        """Greedily decode sequences from latent vectors drawn from N(0, I).

        Args:
            n_samples: Number of sequences to generate.
            device: Device for the latent draw. ``None`` uses the device the
                model's parameters live on, so the call also works on
                CPU-only machines.
            max_length: Number of tokens to decode per sequence.

        Returns:
            Integer tensor of token indices, shape ``(n_samples, max_length)``.
        """
        if device is None:
            device = self.mu.weight.device
        with torch.no_grad():
            z = torch.randn(n_samples, self.mu.out_features, device=device)
            output = self.decode(z, max_length=max_length, teacher_forcing=0)
            return output.argmax(dim=2)

# Function to convert tensor to SMILES
def tensor_to_smiles(tensor, idx_to_char):
    """Convert rows of token indices back to SMILES strings.

    Decoding of a row stops at the first ``<end>`` token; ``<pad>`` and
    ``<start>`` tokens are skipped.

    Args:
        tensor: 2-D tensor of token indices (anything with ``tolist()``), or
            an already nested list of ints.
        idx_to_char: Mapping from token index to token string.

    Returns:
        One string per row.
    """
    # Pull all indices out in one call instead of calling .item() per token.
    rows = tensor.tolist() if hasattr(tensor, 'tolist') else tensor

    ignored = ('<pad>', '<start>')
    smiles = []
    for row in rows:
        chars = []
        for idx in row:
            c = idx_to_char[idx]
            if c == '<end>':
                break
            if c not in ignored:
                chars.append(c)
        smiles.append(''.join(chars))
    return smiles

# Create character mapping from the original data
def create_char_dict(smiles_list):
    """Rebuild the vocabulary and both lookup tables from SMILES strings.

    The three special tokens take indices 0-2, followed by every distinct
    character found in ``smiles_list`` in sorted order.

    Returns:
        Tuple ``(char_to_idx, idx_to_char, all_chars)``.
    """
    # Distinct characters across all strings
    symbols = set().union(*smiles_list)

    # Special tokens first, then the sorted characters
    all_chars = ['<pad>', '<start>', '<end>']
    all_chars.extend(sorted(symbols))

    char_to_idx = {}
    for index, symbol in enumerate(all_chars):
        char_to_idx[symbol] = index
    idx_to_char = {index: symbol for symbol, index in char_to_idx.items()}

    return char_to_idx, idx_to_char, all_chars

# Main execution
if __name__ == "__main__":
    # Set paths
    model_path = '/content/simple_smiles_vae.pt'  # Path to your saved model
    data_file = "/content/dataset_with_descriptors.csv"  # Path to original dataset

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # The checkpoint holds weights only, so the vocabulary is rebuilt from the
    # dataset; the model is sized from it before the weights are loaded.
    print("Loading original dataset to recreate character mapping...")
    df = pd.read_csv(data_file)

    # Filter for active compounds
    active_df = df[df['Activity'] == 'Active']
    print(f"Found {len(active_df)} active compounds")

    # Get SMILES
    active_smiles = active_df['SMILES'].dropna().tolist()

    # Create character mappings
    char_to_idx, idx_to_char, all_chars = create_char_dict(active_smiles)
    print(f"Vocabulary size: {len(all_chars)}")

    # Create model with correct vocabulary size
    model = SimpleCharVAE(len(all_chars), hidden_size=256, latent_size=64).to(device)

    # Load trained model
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    print(f"Loading model from {model_path}...")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Generate SMILES
    print("Generating molecules...")
    n_samples = 100

    with torch.no_grad():
        samples = model.sample(n_samples, device)
        generated = tensor_to_smiles(samples, idx_to_char)

    # Filter valid molecules
    valid_mols = []
    for smiles in generated:
        try:
            mol = Chem.MolFromSmiles(smiles)
            # An atom-less result (RDKit appears to return one for an empty
            # string) is not a real molecule, so it is not counted as valid.
            if mol is not None and mol.GetNumAtoms() > 0:
                valid_mols.append(Chem.MolToSmiles(mol))
        except Exception:
            # Best effort: skip strings RDKit chokes on, but let
            # KeyboardInterrupt/SystemExit propagate.
            continue

    # Save valid molecules
    with open('valid_molecules.txt', 'w') as f:
        for smiles in valid_mols:
            f.write(f"{smiles}\n")

    print(f"Generated {len(generated)} SMILES strings")
    print(f"Of which {len(valid_mols)} are chemically valid")

    # Save all generated molecules
    with open('all_generated_molecules.txt', 'w') as f:
        for smiles in generated:
            f.write(f"{smiles}\n")