In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from RnnClass import RNNGenerator
from utils import return_vocabulary
import os
import numpy as np
import time
from rdkit import Chem
import json
from collections import Counter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = True  # Automatic Mixed Precision for L4
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Configuration
PRETRAINED_MODEL = "rnn_model.pth" 
CHEMBL_DATA_PATH = './data/chembl.mini_cleaned.smi'  
OUTPUT_MODEL_PATH = "rnn_model_chembl.pth"
BATCH_SIZE = 128  
LEARNING_RATE = 0.0005
EPOCHS = 20

def validate_smiles(smiles, min_length=3):
    """Return True if *smiles* parses into an RDKit molecule.

    Inputs that are falsy (``None``, ``""``) or shorter than ``min_length``
    characters are rejected before RDKit is called.

    Args:
        smiles: Candidate SMILES string.
        min_length: Minimum accepted string length (default 3, as before).
            Note this rejects short but valid SMILES such as ``"C"``.

    Returns:
        bool: ``True`` when ``Chem.MolFromSmiles`` returns a molecule,
        ``False`` when it returns ``None`` or raises.
    """
    if not smiles or len(smiles) < min_length:
        return False

    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # Invalid SMILES are normally reported by a None return (checked below),
        # but RDKit may raise on unexpected input. Catch Exception rather than a
        # bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
        return False
    return mol is not None

def create_new_vocabulary(smiles_list, min_freq=2):
    """Build a character-level vocabulary from a corpus of SMILES strings.

    Index 0 is reserved for the ``'<PAD>'`` token. Every character that occurs
    at least ``min_freq`` times in total across ``smiles_list`` follows, in
    sorted order.

    Args:
        smiles_list: Iterable of SMILES strings.
        min_freq: Minimum total occurrence count for a character to be kept.

    Returns:
        tuple[dict, dict]: ``(char_to_idx, idx_to_char)``.
    """
    print("Creating new vocabulary...")
    frequencies = Counter(ch for smi in smiles_list for ch in smi)
    kept = sorted(ch for ch, n in frequencies.items() if n >= min_freq)
    tokens = ['<PAD>', *kept]

    idx_to_char = dict(enumerate(tokens))
    char_to_idx = {tok: pos for pos, tok in idx_to_char.items()}

    print(f"Created vocabulary with {len(char_to_idx)} characters")
    return char_to_idx, idx_to_char

def save_vocabulary(char_to_idx, idx_to_char, filename="chembl_vocabulary.json"):
    """Write both vocabulary mappings to *filename* as indented JSON.

    The file holds an object with keys ``"char_to_idx"`` and ``"idx_to_char"``.
    JSON object keys are always strings, so the integer indices of
    ``idx_to_char`` come back as strings when the file is read.

    Args:
        char_to_idx: Mapping from character to index.
        idx_to_char: Mapping from index (int-convertible) to character.
        filename: Destination path; overwritten if it exists.
    """
    payload = dict(
        char_to_idx=char_to_idx,
        idx_to_char={int(idx): ch for idx, ch in idx_to_char.items()},
    )

    with open(filename, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"Vocabulary saved to {filename}")

print(f"Loading ChEMBL data from {CHEMBL_DATA_PATH}...")
with open(CHEMBL_DATA_PATH, 'r', encoding='utf-8') as f:
    # A SMILES string never contains whitespace, so keeping only the first
    # field of each non-blank line also handles .smi files that append a
    # molecule ID column (which would otherwise leak into the vocabulary).
    chembl_smiles = [fields[0] for fields in (line.split() for line in f) if fields]
print(f"Loaded {len(chembl_smiles)} molecules from ChEMBL")
if not chembl_smiles:
    raise ValueError(f"No SMILES strings found in {CHEMBL_DATA_PATH}")

# Option to create new vocabulary or use existing one
# NOTE: a freshly built vocabulary can differ in size/order from the one
# PRETRAINED_MODEL was trained with (a 35- vs 33-token mismatch has been seen).
# When that happens the pretrained weights cannot be loaded and training
# starts from scratch.
create_new_vocab = True  # Set to True if you want to create a new vocabulary

if create_new_vocab:
    char_to_idx, idx_to_char = create_new_vocabulary(chembl_smiles)
    save_vocabulary(char_to_idx, idx_to_char, "chembl_vocabulary.json")
else:
    print("Loading vocabulary...")
    char_to_idx, idx_to_char = return_vocabulary("../cleaned_smiles.csv")

vocab_size = len(char_to_idx)
print(f"Vocabulary size: {vocab_size}")

# Padded sequence length: the longest SMILES in the corpus, capped at MAX_SEQ_LEN.
MAX_SEQ_LEN = 100
max_length = min(MAX_SEQ_LEN, max(len(smi) for smi in chembl_smiles))
print(f"Maximum SMILES length: {max_length}")

def smiles_to_sequence(smiles, max_len, vocab=None):
    """Encode a SMILES string as a list of exactly ``max_len`` token indices.

    Characters missing from the vocabulary are silently dropped, so the
    encoded string may no longer equal *smiles*. The result is truncated to
    ``max_len``, or right-padded with index 0 (``<PAD>``).

    Args:
        smiles: SMILES string to encode.
        max_len: Length of the returned list (for ``max_len >= 0``).
        vocab: Optional char -> index mapping. Defaults to the module-level
            ``char_to_idx`` so existing call sites keep working.

    Returns:
        list[int]: Token indices of length ``max_len``.
    """
    table = char_to_idx if vocab is None else vocab
    seq = [table[ch] for ch in smiles if ch in table]
    # Slicing truncates; a negative repeat count yields [], so this one
    # expression covers the truncate, pad and exact-length cases.
    return seq[:max_len] + [0] * (max_len - len(seq))

# Convert SMILES to sequences with fixed length.
# SMILES longer than max_length, or containing characters absent from the
# vocabulary (e.g. pruned by min_freq), are skipped rather than truncated or
# stripped: both operations turn a valid molecule into an invalid string that
# would then be used as a training target.
# NOTE(review): a SMILES of exactly max_length characters gets no trailing
# <PAD>, so its end is never shown to the model — consider max_length + 1.
print("Converting SMILES to sequences...")
known_chars = set(char_to_idx)
sequences = []
skipped = 0
for smi in tqdm(chembl_smiles):
    if len(smi) > max_length or not known_chars.issuperset(smi):
        skipped += 1
        continue
    sequences.append(smiles_to_sequence(smi, max_length))

if not sequences:
    raise ValueError(
        f"No SMILES from {CHEMBL_DATA_PATH} fit within {max_length} characters "
        "using the current vocabulary"
    )

sequences = np.array(sequences, dtype=np.int64)
print(f"Processed {len(sequences)} valid sequences (skipped {skipped})")
train_data = torch.tensor(sequences, dtype=torch.long)
print(f"Loading pretrained model from {PRETRAINED_MODEL}...")
model = RNNGenerator(vocab_size=vocab_size, embed_dim=128, hidden_dim=256)

if os.path.exists(PRETRAINED_MODEL):
    try:
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute code on load.
        state_dict = torch.load(PRETRAINED_MODEL, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print("Pre-trained weights loaded successfully!")
    except Exception as e:
        # Deliberate best-effort: e.g. a checkpoint trained with a different
        # vocabulary size (embedding/fc shape mismatch) fails here, and the
        # freshly initialised model is used instead.
        print(f"Error loading weights: {e}")
        print("Starting from scratch instead.")
else:
    print(f"No pretrained model found at {PRETRAINED_MODEL}")

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"Starting training for {EPOCHS} epochs...")
# Mixed precision only helps on CUDA. Gating on the device also keeps the
# device-generic torch.autocast from switching on CPU autocast.
amp_enabled = use_amp and device.type == "cuda"
for epoch in range(EPOCHS):
    model.train()
    # Accumulate the loss on the device: calling loss.item() every step would
    # force a GPU->CPU synchronisation per batch.
    total_loss = torch.zeros((), device=device)
    start_time = time.perf_counter()
    indices = torch.randperm(len(train_data))
    shuffled_data = train_data[indices]

    for i in tqdm(range(0, len(shuffled_data), BATCH_SIZE), desc=f"Epoch {epoch+1}/{EPOCHS}"):
        # range() stops before len(shuffled_data), so every slice is non-empty.
        batch = shuffled_data[i : i + BATCH_SIZE].to(device)

        # Teacher forcing: predict token t+1 from the tokens up to t.
        inputs, targets = batch[:, :-1], batch[:, 1:]
        optimizer.zero_grad(set_to_none=True)

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type=device.type, enabled=amp_enabled):
            outputs = model(inputs)
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the epoch average is per sequence even when
        # the last batch is smaller.
        total_loss += loss.detach().float() * batch.size(0)

    avg_loss = total_loss.item() / len(train_data)
    epoch_time = time.perf_counter() - start_time

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.6f}, Time: {epoch_time:.2f}s")

    if (epoch + 1) % 5 == 0:
        checkpoint_path = f"rnn_model_chembl_ep{epoch+1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

torch.save(model.state_dict(), OUTPUT_MODEL_PATH)
print(f"Final model saved to {OUTPUT_MODEL_PATH}")
model.eval()
print("\nGenerating and validating molecules:")

def generate_molecule(model, temperature=0.7, max_len=100, start_char='C'):
    """Sample one SMILES string from *model* by autoregressive decoding.

    Decoding is seeded with ``start_char`` (carbon by default, as before). At
    each step the whole prefix is fed through the model, and the next token is
    sampled from the temperature-scaled softmax at the last position.
    Decoding stops at the ``<PAD>`` token (index 0) or after ``max_len``
    sampled tokens.

    Args:
        model: Network called on a ``(1, T)`` index tensor. Its output is
            indexed as ``output[0, -1, :]``, i.e. presumably
            ``(1, T, vocab_size)`` logits — TODO confirm against RNNGenerator.
        temperature: Softmax temperature; must be > 0. Lower is greedier.
        max_len: Maximum number of tokens sampled after the seed token.
        start_char: Vocabulary character used to seed generation.

    Returns:
        str: Decoded SMILES, with PAD and indices missing from ``idx_to_char``
        omitted.

    Raises:
        ValueError: If ``temperature`` is not positive (division by it would
            otherwise produce inf/NaN probabilities).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    model.eval()
    with torch.no_grad():
        current_seq = torch.tensor([[char_to_idx[start_char]]], device=device)

        for _ in range(max_len):
            # The whole prefix is re-run every step (quadratic in length); the
            # forward call used here passes no recurrent state to reuse.
            output = model(current_seq)
            next_logits = output[0, -1, :] / temperature
            next_probs = torch.softmax(next_logits, dim=0)
            next_char_idx = torch.multinomial(next_probs, 1)

            current_seq = torch.cat([current_seq, next_char_idx.view(1, 1)], dim=1)

            if next_char_idx.item() == 0:  # PAD token doubles as end-of-sequence
                break

        smiles = ''.join(idx_to_char[i] for i in current_seq[0].tolist()
                         if i > 0 and i in idx_to_char)
        return smiles

num_to_generate = 50
valid_molecules = []
invalid_molecules = []

# Sample molecules one at a time and sort each into the valid/invalid bucket.
for _ in tqdm(range(num_to_generate), desc="Generating molecules"):
    candidate = generate_molecule(model)
    bucket = valid_molecules if validate_smiles(candidate) else invalid_molecules
    bucket.append(candidate)

n_valid, n_invalid = len(valid_molecules), len(invalid_molecules)
print("\nGeneration Results:")
print(f"  Total generated: {num_to_generate}")
print(f"  Valid molecules: {n_valid} ({n_valid / num_to_generate * 100:.1f}%)")
print(f"  Invalid molecules: {n_invalid} ({n_invalid / num_to_generate * 100:.1f}%)")

# Show up to five examples from each bucket.
for heading, examples in (("\nValid molecule examples:", valid_molecules),
                          ("\nInvalid molecule examples:", invalid_molecules)):
    print(heading)
    for rank, smi in enumerate(examples[:5], start=1):
        print(f"  {rank}. {smi}")

# Persist every sampled molecule, valid ones first.
report_lines = [
    "Valid molecules:\n",
    *(f"{smi}\n" for smi in valid_molecules),
    "\nInvalid molecules:\n",
    *(f"{smi}\n" for smi in invalid_molecules),
]
with open("generated_molecules.txt", "w") as out:
    out.writelines(report_lines)

print("Generated molecules saved to generated_molecules.txt")

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


Using device: cuda
GPU: NVIDIA L4
Loading ChEMBL data from ./data/chembl.mini_cleaned.smi...
Loaded 171548 molecules from ChEMBL
Creating new vocabulary...
Created vocabulary with 33 characters
Vocabulary saved to chembl_vocabulary.json
Vocabulary size: 33
Maximum SMILES length: 100
Converting SMILES to sequences...


100%|████████████████████████████████████████████████████████████████| 171548/171548 [00:01<00:00, 132381.46it/s]


Processed 171548 valid sequences
Loading pretrained model from rnn_model.pth...
Error loading weights: Error(s) in loading state_dict for RNNGenerator:
	size mismatch for embedding.weight: copying a param with shape torch.Size([35, 128]) from checkpoint, the shape in current model is torch.Size([33, 128]).
	size mismatch for fc.weight: copying a param with shape torch.Size([35, 256]) from checkpoint, the shape in current model is torch.Size([33, 256]).
	size mismatch for fc.bias: copying a param with shape torch.Size([35]) from checkpoint, the shape in current model is torch.Size([33]).
Starting from scratch instead.
Starting training for 20 epochs...


  with torch.cuda.amp.autocast(enabled=use_amp):
Epoch 1/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 181.53it/s]


Epoch 1/20, Loss: 0.788257, Time: 7.43s


Epoch 2/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 181.94it/s]


Epoch 2/20, Loss: 0.500666, Time: 7.41s


Epoch 3/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 179.64it/s]


Epoch 3/20, Loss: 0.442593, Time: 7.51s


Epoch 4/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 180.20it/s]


Epoch 4/20, Loss: 0.410127, Time: 7.49s


Epoch 5/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 179.67it/s]


Epoch 5/20, Loss: 0.389988, Time: 7.51s
Checkpoint saved to rnn_model_chembl_ep5.pth


Epoch 6/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 178.77it/s]


Epoch 6/20, Loss: 0.376991, Time: 7.54s


Epoch 7/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 179.23it/s]


Epoch 7/20, Loss: 0.367904, Time: 7.52s


Epoch 8/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 179.05it/s]


Epoch 8/20, Loss: 0.361091, Time: 7.53s


Epoch 9/20: 100%|███████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 179.66it/s]


Epoch 9/20, Loss: 0.355735, Time: 7.50s


Epoch 10/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 177.90it/s]


Epoch 10/20, Loss: 0.351370, Time: 7.58s
Checkpoint saved to rnn_model_chembl_ep10.pth


Epoch 11/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 179.14it/s]


Epoch 11/20, Loss: 0.347810, Time: 7.52s


Epoch 12/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 180.07it/s]


Epoch 12/20, Loss: 0.344587, Time: 7.48s


Epoch 13/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 181.21it/s]


Epoch 13/20, Loss: 0.341938, Time: 7.44s


Epoch 14/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 181.06it/s]


Epoch 14/20, Loss: 0.339504, Time: 7.45s


Epoch 15/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 181.64it/s]


Epoch 15/20, Loss: 0.337378, Time: 7.43s
Checkpoint saved to rnn_model_chembl_ep15.pth


Epoch 16/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 180.62it/s]


Epoch 16/20, Loss: 0.335381, Time: 7.47s


Epoch 17/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 180.80it/s]


Epoch 17/20, Loss: 0.333587, Time: 7.46s


Epoch 18/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 179.83it/s]


Epoch 18/20, Loss: 0.331896, Time: 7.50s


Epoch 19/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 181.17it/s]


Epoch 19/20, Loss: 0.330495, Time: 7.45s


Epoch 20/20: 100%|██████████████████████████████████████████████████████████| 1341/1341 [00:07<00:00, 180.73it/s]


Epoch 20/20, Loss: 0.328932, Time: 7.47s
Checkpoint saved to rnn_model_chembl_ep20.pth
Final model saved to rnn_model_chembl.pth

Generating and validating molecules:


Generating molecules:   0%|                                                               | 0/50 [00:00<?, ?it/s][15:56:38] SMILES Parse Error: unclosed ring for input: 'Cc1cccc2cc3n(c12)OCC(C)(C)C(C)(C)C32'
[15:56:38] Explicit valence for atom # 13 C, 5, is greater than permitted
Generating molecules:  18%|█████████▉                                             | 9/50 [00:00<00:01, 26.13it/s][15:56:39] Can't kekulize mol.  Unkekulized atoms: 5 7 8 9 10 11 12
Generating molecules:  26%|██████████████                                        | 13/50 [00:00<00:01, 25.96it/s][15:56:39] SMILES Parse Error: unclosed ring for input: 'CC(C)CC(N)C(=O)NC(CC(C)C)C(O)C1CCC(C(O)C2C3CCC4CCC3(C)C2CCC2(C)C3CCC12C)C(C)C'
Generating molecules:  32%|█████████████████▎                                    | 16/50 [00:00<00:01, 25.35it/s][15:56:39] Can't kekulize mol.  Unkekulized atoms: 10 11 12 13 15 27 28
Generating molecules:  68%|████████████████████████████████████▋                 | 34/50 [00:01<00:00, 


Generation Results:
  Total generated: 50
  Valid molecules: 43 (86.0%)
  Invalid molecules: 7 (14.0%)

Valid molecule examples:
  1. CCOC(=O)c1c(O)c(Cc2ccc(C(=O)O)cc2)c(=O)c(C)c(C)c1O
  2. CC(C)(C)c1ccc(C=CC(=O)c2ccc(C(=O)O)cc2C)cc1
  3. Cc1cc(C)c(C(=O)N=c2[nH]nc(C(F)(F)F)nc2C)cc1F
  4. CC(=O)Nc1cc(S(=O)(=O)N2CCCCC2)c(F)cc1Cl
  5. COc1ccccc1C(=O)NCCc1nc(-c2ccc(C)cc2)ccc1Cl

Invalid molecule examples:
  1. Cc1cccc2cc3n(c12)OCC(C)(C)C(C)(C)C32
  2. Cc1cccc(CC2=NC(C)(C)C2=C(Cl)=C(C#N)(C2CC2)C(C)C)c1
  3. COC(O)=c1n[nH]c2ccccc12
  4. CC(C)CC(N)C(=O)NC(CC(C)C)C(O)C1CCC(C(O)C2C3CCC4CCC3(C)C2CCC2(C)C3CCC12C)C(C)C
  5. COc1ccc(-n2c(COc3ccc(C)c4[nH]c(=NCc5ccccc5)[nH]cc34)sc2)cc1
Generated molecules saved to generated_molecules.txt



