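Local setup: install the dependencies with `pip install -r requirements.txt`, download the Kaggle dataset into `../data/` (the notebook reads `rxn_train.csv`, `rxn_val.csv`, and `rxn_test.csv` from there), then run the cells below in order.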

In [None]:
from rdkit import Chem
import torch, pandas as pd
# Report library versions so results can be tied to a specific environment.
for label, version in (("RDKit version:", Chem.rdBase.rdkitVersion),
                       ("PyTorch version:", torch.__version__)):
    print(label, version)

In [None]:
from pathlib import Path
import pandas as pd

# Load the training split and show its schema plus a few rows.
DATA_DIR = Path('../data')
train_df = pd.read_csv(DATA_DIR.joinpath('rxn_train.csv'))
for preview in (train_df.columns, train_df.head(3)):
    print(preview)


In [None]:
# Reaction SMILES of the first training row (positional column 2).
example_rxn = train_df.iat[0, 2]
print("Example reaction SMILES:", example_rxn)

In [None]:
# Reproducible 100-row sample; positional column 2 holds the reaction SMILES.
sample_df = train_df.sample(n=100, random_state=42)
sample_smiles = sample_df.iloc[:, 2]
print(sample_smiles.head(n=5))

In [None]:
# Split the first sampled reaction into its reactant and product sides.
rxn_str = sample_smiles.iloc[0]
reactants_str, product_str = rxn_str.split(">>")
for label, side in (("Reactants:", reactants_str), ("Product:", product_str)):
    print(label, side)

In [None]:
from rdkit import Chem
# Canonicalize a SMILES string with RDKit and remove atom mapping.
# RDKit canonicalization on its own keeps atom-map labels such as "[CH3:1]",
# so the map numbers are cleared explicitly before writing the SMILES.
def canonicalize_smiles(smi):
  """Return RDKit canonical SMILES for ``smi`` with atom-map numbers cleared.

  Returns None when RDKit cannot parse ``smi``.
  """
  mol = Chem.MolFromSmiles(smi)
  if mol is None:
    return None
  for atom in mol.GetAtoms():
    atom.SetAtomMapNum(0)  # 0 means "no map number"
  return Chem.MolToSmiles(mol, canonical=True)
# Show the reactant string before and after canonicalization.
print("Before:", reactants_str)
canonical_reactants = canonicalize_smiles(reactants_str)
print("After:", canonical_reactants)

In [None]:
# Every sampled entry must contain the ">>" reactants/product separator.
assert sample_smiles.str.contains(">>", regex=False).all()

In [None]:
# Sanity-check RDKit by parsing a trivial molecule (ethanol).
test_smi = "CCO"
mol = Chem.MolFromSmiles(test_smi)
print("RDKit Mol object:", mol)

In [None]:
# Round-trip the parsed Mol back to a SMILES string.
canonical = Chem.MolToSmiles(mol)
print("Canonical SMILES from Mol:", canonical)

In [None]:
# Parse both sides of the example reaction into RDKit Mol objects.
react_mol, prod_mol = (Chem.MolFromSmiles(side) for side in (reactants_str, product_str))
print("Reactant Mol:", react_mol, "Product Mol:", prod_mol)

In [None]:
from rdkit.Chem import Descriptors
# Molecular weight of the parsed product molecule.
mw = Descriptors.MolWt(prod_mol)
print("Product molecular weight:", mw)

In [None]:
import pandas as pd
import numpy as np
from rdkit import Chem # RDKit for chemical informatics
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
# Pick the fastest available backend: CUDA GPU, then Apple MPS, then CPU.
# (Previously only MPS was probed, so CUDA machines silently fell back to CPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    # Free any cached MPS allocator memory before training.
    torch.mps.empty_cache()
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
from pathlib import Path

DATA_DIR = Path('../data')
train_file = DATA_DIR / 'rxn_train.csv'
val_file = DATA_DIR / 'rxn_val.csv'
test_file = DATA_DIR / 'rxn_test.csv'

# Read the three splits in train/val/test order.
train_df, val_df, test_df = (pd.read_csv(path) for path in (train_file, val_file, test_file))
# The reaction SMILES is taken from the LAST column of each split
# (e.g. columns ['ID', 'Class', 'reactants>>product'] -> index 2).
train_rxns, val_rxns, test_rxns = (df.iloc[:, -1].values for df in (train_df, val_df, test_df))
print('Number of training reactions:', len(train_rxns))
print('Example reaction (raw):', train_rxns[0])


In [None]:
def canonicalize_smiles(smiles):
  """Return canonical SMILES for ``smiles`` with atom-map numbers cleared.

  Returns None if RDKit cannot parse the input, or if RDKit raises while
  processing it, so callers can filter failures by checking for None.
  """
  try:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
      return None
    # MolToSmiles keeps atom-map labels (e.g. "[CH3:1]"). Clearing them stops
    # reaction-specific numbering from leaking into the character vocabulary
    # and from breaking canonical-string comparisons between molecules.
    for atom in mol.GetAtoms():
      atom.SetAtomMapNum(0)
    return Chem.MolToSmiles(mol, canonical=True)
  except Exception:
    # Best-effort: treat any RDKit failure as "could not canonicalize".
    return None

def _canonicalize_reactions(rxns):
  """Split reaction SMILES and canonicalize each side molecule by molecule.

  Entries are expected as "reactants>reagents>products", usually with an empty
  reagents field ("reactants>>products"). The reagents field is ignored.
  Entries are skipped when they:
    - are not strings (e.g. NaN from the CSV),
    - do not have exactly three '>'-separated fields, or
    - contain a '.'-separated molecule that canonicalize_smiles rejects.

  Returns:
    (reactants, products): parallel lists of '.'-joined canonical SMILES,
    with molecule order preserved within each side.
  """
  reactants, products = [], []
  for rxn in rxns:
    if not isinstance(rxn, str):
      continue
    fields = rxn.split(">")
    if len(fields) != 3:
      continue
    react_str, _, prod_str = fields
    react_parts = [canonicalize_smiles(r) for r in react_str.split('.')]
    prod_parts = [canonicalize_smiles(p) for p in prod_str.split('.')]
    if None in react_parts or None in prod_parts:
      continue
    reactants.append('.'.join(react_parts))
    products.append('.'.join(prod_parts))
  return reactants, products


train_reactants, train_products = _canonicalize_reactions(train_rxns)
val_reactants, val_products = _canonicalize_reactions(val_rxns)
test_reactants, test_products = _canonicalize_reactions(test_rxns)

print(f"After canonicalization: {len(train_reactants)} training reactions, {len(val_reactants)} validation, {len(test_reactants)} test.")
# Print a sample to see the effect of canonicalization (every row may have been filtered out)
if train_reactants:
  print("Sample reactants (canonical):", train_reactants[0])
  print("Sample product (canonical):", train_products[0])

In [None]:
# Character vocabulary over every split, so val/test never hit an unknown token.
chars = set()
for smi in (train_reactants + train_products
            + val_reactants + val_products
            + test_reactants + test_products):
  chars.update(smi)

# Special tokens always occupy the first ids: <pad>=0, <sos>=1, <eos>=2.
PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
special_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]
# Sorting gives a stable id assignment across runs.
char_list = sorted(chars)
vocab = special_tokens + char_list
token_to_idx = {tok: i for i, tok in enumerate(vocab)}
idx_to_token = dict(enumerate(vocab))
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)
print("Tokens:", vocab)

In [None]:
# Character-level tokenizer: one vocabulary id per SMILES character.
def encode_smiles(smi, token_to_idx):
  """Return the list of ids for the characters of ``smi``.

  Raises KeyError if a character is missing from ``token_to_idx``.
  """
  return list(map(token_to_idx.__getitem__, smi))

# Sample a smaller subset of the dataframes for debugging
sample_size = 200  # reduced for stability on Mac
sample_size = min(sample_size, len(train_df))
val_size = max(1, int(sample_size * 0.2))
train_df_sampled = train_df.sample(n=sample_size, random_state=42)
val_df_sampled = val_df.sample(n=min(val_size, len(val_df)), random_state=42)  # smaller validation set
test_df_sampled = test_df.sample(n=min(val_size, len(test_df)), random_state=42)  # smaller test set

train_rxns = train_df_sampled.iloc[:, -1].values  # numpy array of reaction strings
val_rxns = val_df_sampled.iloc[:, -1].values
test_rxns = test_df_sampled.iloc[:, -1].values

print("Number of training reactions (sampled):", len(train_rxns))
print("Example reaction (raw):", train_rxns[0])


def _canonicalize_reactions(rxns):
  """Split reaction SMILES and canonicalize each side molecule by molecule.

  Entries are expected as "reactants>reagents>products", usually with an empty
  reagents field ("reactants>>products"). The reagents field is ignored.
  Entries are skipped when they:
    - are not strings (e.g. NaN from the CSV),
    - do not have exactly three '>'-separated fields, or
    - contain a '.'-separated molecule that canonicalize_smiles rejects.

  Returns:
    (reactants, products): parallel lists of '.'-joined canonical SMILES,
    with molecule order preserved within each side.
  """
  reactants, products = [], []
  for rxn in rxns:
    if not isinstance(rxn, str):
      continue
    fields = rxn.split(">")
    if len(fields) != 3:
      continue
    react_str, _, prod_str = fields
    react_parts = [canonicalize_smiles(r) for r in react_str.split('.')]
    prod_parts = [canonicalize_smiles(p) for p in prod_str.split('.')]
    if None in react_parts or None in prod_parts:
      continue
    reactants.append('.'.join(react_parts))
    products.append('.'.join(prod_parts))
  return reactants, products


train_reactants, train_products = _canonicalize_reactions(train_rxns)
val_reactants, val_products = _canonicalize_reactions(val_rxns)
test_reactants, test_products = _canonicalize_reactions(test_rxns)

print(f"After canonicalization: {len(train_reactants)} training reactions, {len(val_reactants)} validation, {len(test_reactants)} test.")
# Print a sample to see the effect of canonicalization (every row may have been filtered out)
if train_reactants:
  print("Sample reactants (canonical):", train_reactants[0])
  print("Sample product (canonical):", train_products[0])

# Build the character vocabulary over all splits (avoids unknown tokens in val/test)
chars = set()
for smi in (train_reactants + train_products
            + val_reactants + val_products
            + test_reactants + test_products):
  chars.update(smi)
PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
special_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]
# Sorting is not strictly necessary, but it ensures a stable id assignment
char_list = sorted(chars)
vocab = special_tokens + char_list
token_to_idx = {tok: i for i, tok in enumerate(vocab)}
idx_to_token = {i: tok for i, tok in enumerate(vocab)}
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)

# Encode all reactants and products; products get a trailing <eos>
eos_id = token_to_idx[EOS_TOKEN]
train_enc_reactants = [encode_smiles(s, token_to_idx) for s in train_reactants]
train_enc_products = [encode_smiles(s, token_to_idx) + [eos_id] for s in train_products]
val_enc_reactants = [encode_smiles(s, token_to_idx) for s in val_reactants]
val_enc_products = [encode_smiles(s, token_to_idx) + [eos_id] for s in val_products]
test_enc_reactants = [encode_smiles(s, token_to_idx) for s in test_reactants]
test_enc_products = [encode_smiles(s, token_to_idx) + [eos_id] for s in test_products]

# Determine max lengths for padding. default=1 keeps an empty dataset from
# crashing here, so it reaches the explicit empty-loader error in training.
max_len_react = max((len(seq) for seq in train_enc_reactants + val_enc_reactants + test_enc_reactants), default=1)
max_len_prod = max((len(seq) for seq in train_enc_products + val_enc_products + test_enc_products), default=1)
# Cap sequence lengths to avoid overly long sequences that can destabilize training on CPU/MPS.
# Note: a product longer than the cap is truncated and loses its <eos>.
MAX_LEN_CAP = 256
max_len_react = min(max_len_react, MAX_LEN_CAP)
max_len_prod = min(max_len_prod, MAX_LEN_CAP)
print("Max reactant length (capped):", max_len_react)
print("Max product length (including <eos>) (capped):", max_len_prod)

# Fix a token-id list to exactly max_len entries.
def pad_sequence(seq, max_len, pad_idx):
  """Truncate ``seq`` to ``max_len`` items, then right-pad with ``pad_idx``.

  Returns a new list; ``seq`` itself is not modified.
  """
  head = seq[:max_len]
  n_missing = max_len - len(head)
  return head + [pad_idx] * n_missing

# Pad/truncate every split to the shared max lengths and convert straight to
# LongTensors of shape (num_examples, max_len).
train_enc_reactants = torch.tensor(
    [pad_sequence(s, max_len_react, token_to_idx[PAD_TOKEN]) for s in train_enc_reactants],
    dtype=torch.long)
train_enc_products = torch.tensor(
    [pad_sequence(s, max_len_prod, token_to_idx[PAD_TOKEN]) for s in train_enc_products],
    dtype=torch.long)
val_enc_reactants = torch.tensor(
    [pad_sequence(s, max_len_react, token_to_idx[PAD_TOKEN]) for s in val_enc_reactants],
    dtype=torch.long)
val_enc_products = torch.tensor(
    [pad_sequence(s, max_len_prod, token_to_idx[PAD_TOKEN]) for s in val_enc_products],
    dtype=torch.long)
test_enc_reactants = torch.tensor(
    [pad_sequence(s, max_len_react, token_to_idx[PAD_TOKEN]) for s in test_enc_reactants],
    dtype=torch.long)
test_enc_products = torch.tensor(
    [pad_sequence(s, max_len_prod, token_to_idx[PAD_TOKEN]) for s in test_enc_products],
    dtype=torch.long)

# Batching: only the training loader shuffles.
batch_size = 64  # tune as needed
train_dataset = TensorDataset(train_enc_reactants, train_enc_products)
val_dataset = TensorDataset(val_enc_reactants, val_enc_products)
test_dataset = TensorDataset(test_enc_reactants, test_enc_products)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print("Data preparation done.")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Rebuild datasets and loaders from the padded tensors (training data shuffled).
train_dataset = TensorDataset(train_enc_reactants, train_enc_products)
val_dataset = TensorDataset(val_enc_reactants, val_enc_products)
test_dataset = TensorDataset(test_enc_reactants, test_enc_products)

train_loader, val_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(val_dataset, batch_size=batch_size),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)

# Model hyperparameters shared by the baseline and attention models
EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS = 256, 512, 1

class Encoder(nn.Module):
    """Character-level LSTM encoder for reactant token sequences.

    Args:
        input_size: vocabulary size.
        embed_size: embedding dimension.
        hidden_size: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        pad_idx: embedding padding index. When omitted, it falls back to the
            notebook's ``token_to_idx[PAD_TOKEN]``, as before.
    """

    def __init__(self, input_size, embed_size, hidden_size, num_layers=1, pad_idx=None):
        super().__init__()
        if pad_idx is None:
            pad_idx = token_to_idx[PAD_TOKEN]
        self.embedding = nn.Embedding(input_size, embed_size, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)

    def forward(self, input_seq):
        """Encode a (batch, seq_len) LongTensor of token ids.

        Returns ``(outputs, (h, c))`` from the batch-first LSTM:
            - outputs: (batch, seq_len, hidden_size)
            - h, c: (num_layers, batch, hidden_size)

        Trailing pad positions are fed through the LSTM too; no packing is done.
        """
        embedded = self.embedding(input_seq)
        outputs, hidden = self.lstm(embedded)
        return outputs, hidden

class Decoder(nn.Module):
    """Single-step LSTM decoder that emits vocabulary logits.

    Args:
        output_size: vocabulary size (number of output logits).
        embed_size: embedding dimension.
        hidden_size: LSTM hidden size. It must match the encoder's hidden
            size, because the encoder's final state is used as the initial
            decoder state.
        num_layers: number of stacked LSTM layers.
        pad_idx: embedding padding index. When omitted, it falls back to the
            notebook's ``token_to_idx[PAD_TOKEN]``, as before.
    """

    def __init__(self, output_size, embed_size, hidden_size, num_layers=1, pad_idx=None):
        super().__init__()
        if pad_idx is None:
            pad_idx = token_to_idx[PAD_TOKEN]
        self.embedding = nn.Embedding(output_size, embed_size, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, input_token, hidden):
        """Run one decoding step.

        Args:
            input_token: (batch, 1) LongTensor with the previous token ids.
            hidden: LSTM state ``(h, c)``, each (num_layers, batch, hidden_size).

        Returns:
            (logits, hidden): logits has shape (batch, 1, output_size), and
            hidden is the updated ``(h, c)``.
        """
        embedded = self.embedding(input_token)
        output, hidden = self.lstm(embedded, hidden)
        output_logits = self.fc_out(output)
        return output_logits, hidden

# Baseline seq2seq model, loss and optimizer
encoder = Encoder(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS).to(device)
decoder = Decoder(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS).to(device)

pad_idx = token_to_idx[PAD_TOKEN]
sos_idx = token_to_idx[SOS_TOKEN]
# reduction='none' returns per-token losses, so each batch loss can be
# normalised by its number of real (non-pad) target tokens.
xent = nn.CrossEntropyLoss(ignore_index=pad_idx, reduction='none')
model_params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(model_params, lr=5e-4)

# Fail fast rather than after a full epoch of work.
if len(train_loader) == 0:
    raise RuntimeError("Training loader is empty after preprocessing")
if len(val_loader) == 0:
    raise RuntimeError("Validation loader produced no valid targets")

num_epochs = 5
for epoch in range(1, num_epochs + 1):
    encoder.train(); decoder.train()
    total_loss = 0.0
    steps_taken = 0  # batches that actually produced an optimizer step

    for batch_idx, (reactant_batch, product_batch) in enumerate(train_loader):
        reactant_batch = reactant_batch.to(device)
        product_batch = product_batch.to(device)

        # The encoder's final (h, c) initialises the decoder.
        _, decoder_hidden = encoder(reactant_batch)
        decoder_input = torch.full((reactant_batch.size(0), 1), sos_idx, dtype=torch.long, device=device)

        optimizer.zero_grad()
        loss_accum = 0.0
        token_count = 0
        logits_ok = True

        # Teacher forcing: the ground-truth token at step t is the input for step t+1.
        for t in range(max_len_prod):
            output_logits, decoder_hidden = decoder(decoder_input, decoder_hidden)
            output_logits = output_logits.squeeze(1)  # (batch, vocab)
            if not torch.isfinite(output_logits).all():
                print(f"Skipping batch {batch_idx} due to non-finite logits at step {t}")
                logits_ok = False
                break
            target_token = product_batch[:, t]
            mask = target_token != pad_idx
            if mask.any():
                per_token = xent(output_logits, target_token)
                loss_accum += (per_token * mask).sum()
                token_count += mask.sum().item()
            decoder_input = target_token.unsqueeze(1)

        if not logits_ok:
            continue
        if token_count == 0:
            print(f"Skipping batch {batch_idx} (no valid targets)")
            continue

        loss = loss_accum / token_count
        if not torch.isfinite(loss):
            print(f"Skipping batch {batch_idx} due to invalid loss value")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model_params, max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
        steps_taken += 1

    # Average only over batches that contributed a loss (skipped batches add nothing).
    avg_loss = total_loss / steps_taken if steps_taken else float("nan")
    print(f"Epoch {epoch}, Training Loss: {avg_loss:.4f}")

    # Validation (teacher-forced, token-weighted mean loss)
    encoder.eval(); decoder.eval()
    val_loss = 0.0
    val_tokens = 0

    with torch.no_grad():
        for reactant_batch, product_batch in val_loader:
            reactant_batch = reactant_batch.to(device)
            product_batch = product_batch.to(device)

            _, decoder_hidden = encoder(reactant_batch)
            decoder_input = torch.full((reactant_batch.size(0), 1), sos_idx, dtype=torch.long, device=device)

            for t in range(max_len_prod):
                output_logits, decoder_hidden = decoder(decoder_input, decoder_hidden)
                output_logits = output_logits.squeeze(1)
                target_token = product_batch[:, t]
                # Drop the loss of a non-finite step, but still advance teacher
                # forcing so later steps see the correct previous token.
                if torch.isfinite(output_logits).all():
                    mask = target_token != pad_idx
                    if mask.any():
                        per_token = xent(output_logits, target_token)
                        val_loss += (per_token * mask).sum().item()
                        val_tokens += mask.sum().item()
                decoder_input = target_token.unsqueeze(1)

    if val_tokens == 0:
        raise RuntimeError("Validation loader produced no valid targets")
    val_loss = val_loss / val_tokens
    print(f"Validation Loss: {val_loss:.4f}")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem

# -------- helper: quick token utilities ----------
# Ids of the <pad>/<sos>/<eos> tokens. Deliberately empty here; a later cell
# reassigns it from token_to_idx. ids_to_smiles() skips every id in this set.
SPECIAL_IDS = set()

def ids_to_smiles(id_list, idx_to_token):
    """Concatenate the token strings for ``id_list``, skipping ids in SPECIAL_IDS."""
    pieces = []
    for token_id in id_list:
        if token_id in SPECIAL_IDS:
            continue
        pieces.append(idx_to_token[token_id])
    return "".join(pieces)

def is_balanced_brackets(s):
    """Return False if ``s`` closes a '(' or '[' it never opened, or closes the wrong kind.

    Openers left unclosed are tolerated, because this check runs on partially
    decoded strings.
    """
    openers = []
    matching = {')': '(', ']': '['}
    for ch in s:
        if ch == '(' or ch == '[':
            openers.append(ch)
        elif ch in matching:
            # Over-closing (empty stack) or a mismatched pair both fail.
            if not openers or openers.pop() != matching[ch]:
                return False
    return True

def partial_valid(s):
    """Cheap syntactic gate for a partially decoded SMILES string.

    RDKit is deliberately not used, since it rejects most incomplete strings.
    The string is accepted only if:
      - no bracket or parenthesis is closed without a matching opener
        (see is_balanced_brackets), and
      - it contains no empty fragment ("..").
    """
    return is_balanced_brackets(s) and '..' not in s

class AdditiveAttention(nn.Module):
    """
    Bahdanau-style attention: score(h_t, h_s) = v^T tanh(W1*h_s + W2*h_t)
    Shapes:
      encoder_outputs: (B, S, H_e)
      dec_hidden_t:    (B, H_d)
      mask (optional): (B, S) bool, True where a position may be attended
      returns: (B, S) attention weights, each row summing to 1
    """
    def __init__(self, enc_hidden, dec_hidden):
        super().__init__()
        self.W1 = nn.Linear(enc_hidden, dec_hidden, bias=False)
        self.W2 = nn.Linear(dec_hidden, dec_hidden, bias=False)
        self.v  = nn.Linear(dec_hidden, 1, bias=False)

    def forward(self, encoder_outputs, dec_hidden_t, mask=None):
        encoder_proj = self.W1(encoder_outputs)            # (B, S, H_d)
        dec_proj = self.W2(dec_hidden_t).unsqueeze(1)      # (B, 1, H_d)
        energies = self.v(torch.tanh(encoder_proj + dec_proj)).squeeze(-1)  # (B, S)
        if mask is not None:
            # Use the most negative finite value rather than -inf, so a
            # fully masked row gives uniform weights instead of NaN.
            energies = energies.masked_fill(~mask, torch.finfo(energies.dtype).min)
        return F.softmax(energies, dim=1)  # (B, S)


class AttnDecoder(nn.Module):
    """
    LSTM decoder with additive attention over encoder states.

    Encoder outputs are always brought to ``dec_hidden`` width: through a
    bias-free linear projection when ``enc_hidden`` differs from
    ``dec_hidden``, otherwise unchanged. Because of this, the output layer
    always sees [decoder output ; context] of width 2 * dec_hidden. (It was
    previously sized dec_hidden + enc_hidden, which crashed whenever the two
    hidden sizes differed.)
    """
    def __init__(self, vocab_size, embed_size, dec_hidden, num_layers=1, pad_idx=0, enc_hidden=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_size, dec_hidden, num_layers=num_layers, batch_first=True)
        self.fc_out = nn.Linear(2 * dec_hidden, vocab_size)

        if enc_hidden is None or enc_hidden == dec_hidden:
            # encoder and decoder hidden sizes match
            self.attn = AdditiveAttention(dec_hidden, dec_hidden)
            self.enc_to_dec = nn.Identity()
        else:
            # project encoder outputs to decoder hidden size
            self.enc_to_dec = nn.Linear(enc_hidden, dec_hidden, bias=False)
            self.attn = AdditiveAttention(dec_hidden, dec_hidden)
        # Width of the (projected) encoder states used by attention.
        self.enc_hidden_dim = dec_hidden

    def forward(self, input_token, hidden, encoder_outputs, src_mask=None):
        """
        input_token:     (B, 1) previous token ids
        hidden:          (h, c), each (num_layers, B, dec_hidden)
        encoder_outputs: (B, S, enc_hidden)
        src_mask:        optional (B, S) bool, True at non-pad source positions
        returns: logits (B, 1, vocab), new (h, c), attention weights (B, S)
        """
        enc_proj = self.enc_to_dec(encoder_outputs)  # (B, S, dec_hidden)

        emb = self.embedding(input_token)            # (B, 1, E)
        out, hidden = self.lstm(emb, hidden)         # out: (B, 1, dec_hidden)
        dec_h_t = hidden[0][-1]                      # last layer's h: (B, dec_hidden)

        attn_w = self.attn(enc_proj, dec_h_t, mask=src_mask)   # (B, S)
        context = torch.bmm(attn_w.unsqueeze(1), enc_proj)     # (B, 1, dec_hidden)

        combined = torch.cat([out, context], dim=-1)  # (B, 1, 2 * dec_hidden)
        logits = self.fc_out(combined)                # (B, 1, vocab)
        return logits, hidden, attn_w

In [None]:
# Fill in the ids of tokens that carry no SMILES text (<pad>, <sos>, <eos>).
SPECIAL_IDS = {token_to_idx[tok] for tok in (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN)}


In [None]:
# Greedy decoding for one reactant sequence with the baseline encoder/decoder.
def predict_greedy(reactant_seq):
    """Greedily decode product token ids for a single reactant sequence.

    Args:
        reactant_seq: LongTensor of reactant token ids. The cells below call
            it with shape (1, seq_len) already on ``device``.

    Returns:
        list[int]: predicted token ids, excluding the terminating <eos>.
        Decoding stops at <eos> or after ``max_len_prod`` steps.
    """
    eos_idx = token_to_idx[EOS_TOKEN]
    predicted_tokens = []

    # The encoder pass is inside no_grad as well, so inference never builds
    # an autograd graph (previously only the decoder loop was covered).
    with torch.no_grad():
        _, decoder_hidden = encoder(reactant_seq)
        decoder_input = torch.full((1, 1), token_to_idx[SOS_TOKEN], dtype=torch.long, device=device)

        for _ in range(max_len_prod):
            output_logits, decoder_hidden = decoder(decoder_input, decoder_hidden)
            top_token = output_logits.squeeze(1).argmax(dim=1)  # (1,)
            top_idx = top_token.item()
            if top_idx == eos_idx:
                break
            predicted_tokens.append(top_idx)
            decoder_input = top_token.unsqueeze(1)  # feed the prediction back

    return predicted_tokens

# Evaluation loop: greedy-decode every test reaction and compare canonical SMILES.
encoder.eval()
decoder.eval()
top1_correct = 0
total = 0
valid_predictions = 0
pad_idx = token_to_idx[PAD_TOKEN]
eos_idx = token_to_idx[EOS_TOKEN]
sos_idx = token_to_idx[SOS_TOKEN]

with torch.no_grad():
    for reactant_seq, true_product_seq in test_loader:
        reactant_seq = reactant_seq.to(device)
        true_product_seq = true_product_seq.to(device)

        # predict_greedy decodes one sequence at a time, so iterate through the batch.
        for i in range(reactant_seq.size(0)):
            single_reactant_seq = reactant_seq[i].unsqueeze(0)  # (1, seq_len)

            # Strip padding and the trailing <eos> from the target. Calling
            # .tolist() on the 1-D row always yields a list; the old
            # .squeeze() could collapse a length-1 row to a bare int.
            true_tokens = true_product_seq[i].tolist()
            if pad_idx in true_tokens:
                true_tokens = true_tokens[:true_tokens.index(pad_idx)]
            if true_tokens and true_tokens[-1] == eos_idx:
                true_tokens = true_tokens[:-1]
            true_smiles = "".join(idx_to_token[t] for t in true_tokens)
            true_canon = canonicalize_smiles(true_smiles)

            predicted_token_seq = predict_greedy(single_reactant_seq)
            # Join the character tokens verbatim, dropping <pad>/<sos> (they
            # carry no text). "][" must NOT be rewritten to "].[": adjacent
            # bracket atoms are bonded (e.g. "[O-][n+]1ccccc1"), and inserting
            # a dot would split a correct molecule into fragments.
            pred_smiles = "".join(
                idx_to_token[t] for t in predicted_token_seq if t not in (pad_idx, sos_idx)
            )

            # Empty or bracket-unbalanced output is invalid; skipping it avoids
            # RDKit parse noise. RDKit parses "" as an empty molecule, so the
            # empty string must be excluded explicitly. The same canonicalizer
            # is used for the target and the prediction, so they compare alike.
            pred_canon = None
            if pred_smiles and pred_smiles.count('[') == pred_smiles.count(']'):
                pred_canon = canonicalize_smiles(pred_smiles)
                if pred_canon is not None:
                    valid_predictions += 1

            if pred_canon is not None and true_canon is not None and pred_canon == true_canon:
                top1_correct += 1

            total += 1

# Final stats
if total == 0:
    raise RuntimeError("Test loader produced no examples to evaluate")
top1_acc = top1_correct / total * 100
valid_percent = valid_predictions / total * 100

print(f"Baseline Seq2Seq Greedy Top-1 Accuracy: {top1_acc:.2f}%")
print(f"Valid product predictions: {valid_percent:.2f}%")

In [None]:
@torch.no_grad()
def greedy_decode_attention_with_constraints(
    encoder_attn, decoder_attn, input_seq,
    sos_idx, eos_idx, max_len,
    idx_to_token,
    top_n=10
):
    """Greedy decoding with a lightweight syntax gate on every emitted token.

    At each step the ``top_n`` most likely tokens are tried in order. The first
    candidate that is <eos>, or whose text keeps the partial string passing
    ``partial_valid``, is emitted. Other special ids (those in SPECIAL_IDS, plus
    ``sos_idx``) produce no text, so they are never chosen as candidates.
    Without this, <pad>/<sos> could be emitted and fed back to the decoder as
    if they were real tokens. If every candidate is rejected, the plain argmax
    is used as a fallback.

    Args:
        encoder_attn, decoder_attn: encoder and attention decoder. The decoder
            is called as ``decoder(input, hidden, enc_outputs)`` and returns
            ``(logits, hidden, attn)``.
        input_seq: (1, seq_len) LongTensor of reactant ids.
        sos_idx, eos_idx: start/end token ids.
        max_len: maximum number of decoding steps.
        idx_to_token: id -> token string mapping.
        top_n: number of highest-scoring candidates tried per step.

    Returns:
        list[int]: emitted token ids, without the terminating <eos>.
    """
    encoder_attn.eval(); decoder_attn.eval()
    # Build decoder inputs on the same device as the input sequence.
    dev = input_seq.device
    enc_outputs, dec_hidden = encoder_attn(input_seq)
    dec_input = torch.tensor([[sos_idx]], dtype=torch.long, device=dev)
    silent_ids = (set(SPECIAL_IDS) | {sos_idx}) - {eos_idx}

    pred_ids = []
    partial = ""

    for _ in range(max_len):
        logits, dec_hidden, _attn = decoder_attn(dec_input, dec_hidden, enc_outputs)
        log_probs = F.log_softmax(logits.squeeze(1), dim=-1)[0]  # (V,)
        candidates = torch.topk(log_probs, k=min(top_n, log_probs.numel())).indices.tolist()

        chosen = None
        for cand in candidates:
            if cand == eos_idx:
                chosen = cand
                break
            if cand in silent_ids:
                continue
            if partial_valid(partial + idx_to_token[cand]):
                chosen = cand
                break

        if chosen is None:
            chosen = int(torch.argmax(log_probs).item())

        if chosen == eos_idx:
            break

        pred_ids.append(chosen)
        if chosen not in SPECIAL_IDS:
            partial += idx_to_token[chosen]

        dec_input = torch.tensor([[chosen]], dtype=torch.long, device=dev)

    return pred_ids


In [None]:
# Print a few example predictions vs. true products (baseline model).
# Bounded by the test-set size so a tiny split cannot raise IndexError.
num_examples = min(3, len(test_products))
with torch.no_grad():
    for i in range(num_examples):
        reactant = test_enc_reactants[i].unsqueeze(0).to(device)  # (1, seq_len)
        true_prod = test_products[i]

        pred_tokens = predict_greedy(reactant)
        pred_smiles = "".join([idx_to_token[t] for t in pred_tokens])

        print(f"Reactants: {test_reactants[i]}")
        print(f"True Product: {true_prod}")
        print(f"Predicted Product: {pred_smiles}")

        # RDKit parses "" as an empty (non-None) molecule, so an empty
        # prediction must not be reported as valid; require at least one atom.
        mol = Chem.MolFromSmiles(pred_smiles) if pred_smiles else None
        print(f"Valid SMILES: {mol is not None and mol.GetNumAtoms() > 0}")
        print("-" * 60)


# Decoder with (unscaled) dot-product attention over encoder outputs
class DecoderWithAttention(nn.Module):
    """Single-step LSTM decoder with dot-product attention.

    Args:
        output_size: vocabulary size (number of output logits).
        embed_size: embedding dimension.
        hidden_size: LSTM hidden size. It must equal the encoder output
            width, because the scores are dot products with encoder_outputs.
        num_layers: number of stacked LSTM layers.
        pad_idx: embedding padding index. When omitted, it falls back to the
            notebook's ``token_to_idx[PAD_TOKEN]``, as before.
    """

    def __init__(self, output_size, embed_size, hidden_size, num_layers=1, pad_idx=None):
        super().__init__()
        if pad_idx is None:
            pad_idx = token_to_idx[PAD_TOKEN]
        self.embedding = nn.Embedding(output_size, embed_size, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        # Dot-product scores have no learnable parameters; this layer mixes
        # [lstm_out ; context] back down to hidden_size.
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, input_token, hidden, encoder_outputs, src_mask=None):
        """Run one decoding step with attention.

        Args:
            input_token: (batch, 1) previous token ids.
            hidden: LSTM state ``(h, c)``.
            encoder_outputs: (batch, seq_len, hidden_size).
            src_mask: optional (batch, seq_len) bool, True at positions that
                may be attended (e.g. non-pad source tokens). When omitted,
                every position is attended, as before.

        Returns:
            (logits, hidden, attn_weights) with shapes (batch, 1, output_size),
            ``(h, c)`` and (batch, seq_len, 1).
        """
        embedded = self.embedding(input_token)           # (batch, 1, embed_size)
        lstm_out, hidden = self.lstm(embedded, hidden)   # (batch, 1, hidden_size)

        # Scores: dot product of every encoder state with the decoder output.
        attn_scores = torch.bmm(encoder_outputs, lstm_out.transpose(1, 2))  # (batch, seq_len, 1)
        if src_mask is not None:
            # Most negative finite value (not -inf): no NaN on fully masked rows.
            attn_scores = attn_scores.masked_fill(
                ~src_mask.unsqueeze(-1), torch.finfo(attn_scores.dtype).min
            )
        attn_weights = torch.softmax(attn_scores, dim=1)  # normalised over seq_len

        # Context vector: attention-weighted sum of encoder outputs.
        context = torch.bmm(attn_weights.transpose(1, 2), encoder_outputs)  # (batch, 1, hidden_size)

        combined = torch.cat([lstm_out, context], dim=2)     # (batch, 1, hidden_size*2)
        combined = torch.tanh(self.attn_combine(combined))   # (batch, 1, hidden_size)
        output_logits = self.fc_out(combined)                # (batch, 1, output_size)

        return output_logits, hidden, attn_weights


In [None]:
# Fresh encoder plus dot-product-attention decoder, with their own loss and optimizer.
encoder_attn = Encoder(
    input_size=vocab_size, embed_size=EMBEDDING_DIM,
    hidden_size=HIDDEN_DIM, num_layers=NUM_LAYERS,
).to(device)
decoder_attn = DecoderWithAttention(
    output_size=vocab_size, embed_size=EMBEDDING_DIM,
    hidden_size=HIDDEN_DIM, num_layers=NUM_LAYERS,
).to(device)

# Per-token losses (reduction='none') so the loop can normalise by real tokens.
criterion_attn = nn.CrossEntropyLoss(ignore_index=token_to_idx[PAD_TOKEN], reduction='none')
optimizer_attn = optim.Adam([*encoder_attn.parameters(), *decoder_attn.parameters()], lr=5e-4)

pad_idx = token_to_idx[PAD_TOKEN]
sos_idx = token_to_idx[SOS_TOKEN]
attn_params = list(encoder_attn.parameters()) + list(decoder_attn.parameters())

# Fail fast rather than after a full epoch of work.
if len(train_loader) == 0:
    raise RuntimeError("Training loader is empty after preprocessing")
if len(val_loader) == 0:
    raise RuntimeError("Validation loader produced no valid targets")

num_epochs_attn = 5
for epoch in range(1, num_epochs_attn + 1):
    encoder_attn.train(); decoder_attn.train()
    total_loss = 0.0
    steps_taken = 0  # batches that actually produced an optimizer step

    for batch_idx, (reactant_batch, product_batch) in enumerate(train_loader):
        reactant_batch = reactant_batch.to(device)
        product_batch = product_batch.to(device)

        encoder_outputs, decoder_hidden = encoder_attn(reactant_batch)
        decoder_input = torch.full(
            (reactant_batch.size(0), 1), sos_idx, dtype=torch.long, device=device
        )

        optimizer_attn.zero_grad()
        loss_accum = 0.0
        token_count = 0
        logits_ok = True

        for t in range(max_len_prod):
            output_logits, decoder_hidden, attn_weights = decoder_attn(
                decoder_input, decoder_hidden, encoder_outputs
            )
            output_logits = output_logits.squeeze(1)  # (batch, vocab)
            if not torch.isfinite(output_logits).all():
                print(f"[Attn] Skipping batch {batch_idx} due to non-finite logits at step {t}")
                logits_ok = False
                break
            target_token = product_batch[:, t]
            mask = target_token != pad_idx
            if mask.any():
                per_token = criterion_attn(output_logits, target_token)
                loss_accum += (per_token * mask).sum()
                token_count += mask.sum().item()

            # Teacher forcing
            decoder_input = target_token.unsqueeze(1)

        if not logits_ok:
            continue
        if token_count == 0:
            print(f"[Attn] Skipping batch {batch_idx} (no valid targets)")
            continue
        loss = loss_accum / token_count
        if not torch.isfinite(loss):
            print(f"[Attn] Skipping batch {batch_idx} due to invalid loss value")
            continue
        loss.backward()
        nn.utils.clip_grad_norm_(attn_params, max_norm=1.0)
        optimizer_attn.step()
        total_loss += loss.item()
        steps_taken += 1

    # Average only over batches that contributed a loss (skipped batches add nothing).
    avg_loss = total_loss / steps_taken if steps_taken else float("nan")
    print(f"[Attn] Epoch {epoch}/{num_epochs_attn}, Training loss: {avg_loss:.4f}")

    # Validation loop (teacher-forced, token-weighted mean loss)
    encoder_attn.eval(); decoder_attn.eval()
    val_loss = 0.0
    val_tokens = 0

    with torch.no_grad():
        for reactant_batch, product_batch in val_loader:
            reactant_batch = reactant_batch.to(device)
            product_batch = product_batch.to(device)

            encoder_outputs, decoder_hidden = encoder_attn(reactant_batch)
            decoder_input = torch.full(
                (reactant_batch.size(0), 1), sos_idx, dtype=torch.long, device=device
            )

            for t in range(max_len_prod):
                output_logits, decoder_hidden, attn_weights = decoder_attn(
                    decoder_input, decoder_hidden, encoder_outputs
                )
                output_logits = output_logits.squeeze(1)
                target_token = product_batch[:, t]
                # Drop the loss of a non-finite step, but still advance teacher
                # forcing so later steps see the correct previous token.
                if torch.isfinite(output_logits).all():
                    mask = target_token != pad_idx
                    if mask.any():
                        per_token = criterion_attn(output_logits, target_token)
                        val_loss += (per_token * mask).sum().item()
                        val_tokens += mask.sum().item()
                decoder_input = target_token.unsqueeze(1)

    if val_tokens == 0:
        raise RuntimeError("Validation loader produced no valid targets")
    val_loss = val_loss / val_tokens
    print(f"[Attn] Validation loss: {val_loss:.4f}")


In [None]:
def detach_hidden_state(hidden):
    """Return a detached copy of a decoder hidden state.

    Accepts either a tuple/list of tensors (e.g. an LSTM-style ``(h, c)`` pair)
    or a single tensor (e.g. a GRU-style ``h``), so beam search does not depend
    on which recurrent cell the decoder uses.
    """
    if isinstance(hidden, (tuple, list)):
        return tuple(h.detach().clone() for h in hidden)
    return hidden.detach().clone()


def predict_beam_search(reactant_seq, beam_width=5, max_len=100):
    """Decode product token sequences for one encoded reactant with beam search.

    Relies on the notebook-level ``encoder_attn``, ``decoder_attn``,
    ``token_to_idx``, ``SOS_TOKEN``, ``EOS_TOKEN`` and ``device``; the models'
    train/eval mode is left to the caller.

    Args:
        reactant_seq: encoded reactant tensor passed straight to ``encoder_attn``.
            The decoder input built per step has batch size 1, so this presumably
            must hold a single sequence — TODO confirm shape (likely ``(1, L)``).
        beam_width: hypotheses kept after each step, and the maximum returned.
        max_len: maximum number of decoding steps.

    Returns:
        Up to ``beam_width`` lists of token ids, best cumulative log-probability
        first, with the leading SOS removed and everything from the first EOS on
        dropped.
    """
    sos_idx = token_to_idx[SOS_TOKEN]
    eos_idx = token_to_idx[EOS_TOKEN]

    with torch.no_grad():
        # Encode inside no_grad too: inference needs no autograd graph.
        encoder_outputs, encoder_hidden = encoder_attn(reactant_seq)

        # Each beam is (sequence_so_far, hidden_state, cumulative_log_prob)
        beams = [([sos_idx], encoder_hidden, 0.0)]
        completed_sequences = []

        for _ in range(max_len):
            new_beams = []

            for seq, hidden, log_prob in beams:
                # A hypothesis that already ended with EOS is moved to the
                # results instead of being expanded further.
                if seq[-1] == eos_idx:
                    completed_sequences.append((seq, log_prob))
                    continue

                last_token = torch.tensor([[seq[-1]]], device=device)
                output_logits, new_hidden, _attn_weights = decoder_attn(
                    last_token, hidden, encoder_outputs
                )
                output_logits = output_logits.squeeze(1)  # (1, vocab_size)

                # Log-probabilities as a flat NumPy array over the vocabulary
                log_probs = torch.log_softmax(output_logits, dim=1)
                log_probs = log_probs.cpu().numpy().flatten()

                # Indices of the beam_width most likely next tokens, best first
                top_indices = np.argsort(log_probs)[-beam_width:][::-1]

                for idx in top_indices:
                    new_beams.append((
                        seq + [int(idx)],
                        detach_hidden_state(new_hidden),
                        log_prob + float(log_probs[idx]),
                    ))

            # Nothing left to expand: every beam had already ended
            if not new_beams:
                break

            # Keep only the best beams (highest cumulative log-prob first)
            new_beams.sort(key=lambda beam: beam[2], reverse=True)
            beams = new_beams[:beam_width]

            # Stop early once every surviving beam has produced EOS
            if all(seq[-1] == eos_idx for seq, _, _ in beams):
                completed_sequences.extend((seq, lp) for seq, _, lp in beams)
                break
        else:
            # max_len reached without an early stop: keep what is left
            completed_sequences.extend((seq, lp) for seq, _, lp in beams)

    completed_sequences.sort(key=lambda item: item[1], reverse=True)

    top_sequences = []
    for seq, _log_prob in completed_sequences[:beam_width]:
        # Remove the initial SOS
        if seq and seq[0] == sos_idx:
            seq = seq[1:]
        # Drop EOS and anything after it
        if eos_idx in seq:
            seq = seq[:seq.index(eos_idx)]
        top_sequences.append(seq)

    return top_sequences


In [None]:
# Evaluate the attention model on the test set with beam search. Reports
# top-1/3/5 exact-match accuracy on RDKit-canonical SMILES, plus the share of
# top-1 predictions that RDKit can parse.
encoder_attn.eval()
decoder_attn.eval()

top1_correct = 0
top3_correct = 0
top5_correct = 0
valid_top1 = 0
total = 0

pad_idx = token_to_idx[PAD_TOKEN]
eos_idx = token_to_idx[EOS_TOKEN]

with torch.no_grad():
    for reactant_batch, product_batch in test_loader:
        # Treat an unbatched (1-D) sample as a batch of one
        if reactant_batch.dim() == 1:
            reactant_batch = reactant_batch.unsqueeze(0)
        if product_batch.dim() == 1:
            product_batch = product_batch.unsqueeze(0)

        # predict_beam_search decodes a single reactant, so score every sample
        # of the batch individually rather than only the first one.
        for reactant_row, true_product_seq in zip(reactant_batch, product_batch):
            reactant_seq = reactant_row.unsqueeze(0).to(device)

            # True product tokens up to (excluding) the first EOS or PAD
            true_tokens = []
            for tok in true_product_seq.tolist():
                if tok == eos_idx or tok == pad_idx:
                    break
                true_tokens.append(tok)
            true_smiles = "".join(idx_to_token[t] for t in true_tokens)
            true_canon = canonicalize_smiles(true_smiles)

            # Beam search predictions -> canonical SMILES (None when unparsable)
            top_sequences = predict_beam_search(
                reactant_seq, beam_width=5, max_len=max_len_prod
            )
            pred_smiles_list = [
                canonicalize_smiles("".join(idx_to_token[t] for t in seq))
                for seq in top_sequences
            ]

            # Count valid top-1 predictions
            if pred_smiles_list and pred_smiles_list[0] is not None:
                valid_top1 += 1

            # Rank of the first exact match decides which top-k buckets it counts in
            if true_canon is not None and true_canon in pred_smiles_list:
                rank = pred_smiles_list.index(true_canon)
                if rank < 1:
                    top1_correct += 1
                if rank < 3:
                    top3_correct += 1
                if rank < 5:
                    top5_correct += 1

            total += 1

if total == 0:
    raise RuntimeError("Test loader produced no samples to evaluate")

# Final metrics
top1_acc = top1_correct / total * 100
top3_acc = top3_correct / total * 100
top5_acc = top5_correct / total * 100
valid_percent = valid_top1 / total * 100

print(f"Attention Seq2Seq Top-1 Accuracy: {top1_acc:.2f}%")
print(f"Attention Seq2Seq Top-3 Accuracy: {top3_acc:.2f}%")
print(f"Attention Seq2Seq Top-5 Accuracy: {top5_acc:.2f}%")
print(f"Valid Top-1 Predictions: {valid_percent:.2f}%")


In [None]:
# Compare baseline greedy and attention greedy predictions on the first few
# test examples.
encoder_attn.eval()
decoder_attn.eval()

sos_idx = token_to_idx[SOS_TOKEN]
eos_idx = token_to_idx[EOS_TOKEN]
# Do not index past the end of a small test split
n_examples = min(3, len(test_enc_reactants), len(test_products), len(test_reactants))

for i in range(n_examples):
    reactant = test_enc_reactants[i].unsqueeze(0).to(device)
    true_prod = test_products[i]

    # Baseline greedy
    base_tokens = predict_greedy(reactant)
    base_smiles = "".join([idx_to_token[t] for t in base_tokens])

    # Attention greedy (fair comparison): pick the argmax token each step
    # until EOS or max_len_prod steps.
    attn_tokens = []
    with torch.no_grad():
        enc_out, enc_h = encoder_attn(reactant)
        dec_h = enc_h
        dec_input = torch.full((1, 1), sos_idx, dtype=torch.long, device=device)

        for _ in range(max_len_prod):
            out_logits, dec_h, _attn_w = decoder_attn(dec_input, dec_h, enc_out)
            top_token = out_logits.squeeze(1).argmax(dim=1)
            token_id = top_token.item()
            if token_id == eos_idx:
                break
            attn_tokens.append(token_id)
            dec_input = top_token.unsqueeze(1)

    attn_smiles = "".join([idx_to_token[t] for t in attn_tokens])

    # Print comparison
    print(f"Reactants: {test_reactants[i]}")
    print(f"True Product: {true_prod}")
    print(f"Baseline Pred (greedy): {base_smiles}")
    print(f"Attention Pred (greedy): {attn_smiles}")
    print("-" * 70)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Illustrative attention heatmap built from a dummy weight matrix.
# Tokenization is naive character splitting, so "Br" becomes the two tokens
# 'B' and 'r'. A real SMILES tokenizer would keep it as a single token.
reactant_smiles = "CCBr.O"  # bromoethane and water
product_smiles = "CCO"      # ethanol

input_tokens = list(reactant_smiles)   # ['C','C','B','r','.','O']
output_tokens = list(product_smiles)   # ['C','C','O']

# Dummy attention matrix: one row per output token, one column per input token
attention_weights = np.array([
    [0.8, 0.2, 0.0, 0.0, 0.0, 0.0],  # attention for output token 1 ('C')
    [0.1, 0.7, 0.2, 0.0, 0.0, 0.0],  # attention for output token 2 ('C')
    [0.0, 0.0, 0.1, 0.0, 0.0, 0.9]   # attention for output token 3 ('O')
])

# Fail with a clear message if the matrix and the token lists drift apart
expected_shape = (len(output_tokens), len(input_tokens))
if attention_weights.shape != expected_shape:
    raise ValueError(
        f"attention_weights has shape {attention_weights.shape}, "
        f"expected {expected_shape} (output tokens x input tokens)"
    )

# Plot heatmap
plt.figure(figsize=(6, 4))
sns.heatmap(attention_weights, annot=True, cmap="YlGnBu",
            xticklabels=input_tokens, yticklabels=output_tokens)
plt.xlabel("Reactant SMILES Tokens")
plt.ylabel("Product SMILES Tokens")
plt.title("Attention Weight Heatmap")
plt.tight_layout()
plt.show()


In [None]:
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors

# Inspect a (hypothetical) failure case: is the prediction parsable, and do
# its formula and atoms match the true product?
true_smiles = "CCO"      # true product: ethanol
pred_smiles = "CCCl"     # model prediction: chloroethane (wrong product)

# Parse SMILES to molecules (None when RDKit cannot parse them)
true_mol = Chem.MolFromSmiles(true_smiles)
pred_mol = Chem.MolFromSmiles(pred_smiles)

# 1. Check if predicted SMILES was chemically parsable
if pred_mol is None:
    print("Predicted SMILES is invalid and could not be parsed by RDKit.")
else:
    print("Predicted SMILES parsed successfully (chemically valid molecule).")

# The comparisons below need both molecules; the RDKit calls cannot take None.
if true_mol is None:
    print("True SMILES could not be parsed by RDKit; skipping formula/atom comparison.")
elif pred_mol is not None:
    # 2. Compare molecular formulas for atom conservation
    true_formula = rdMolDescriptors.CalcMolFormula(true_mol)
    pred_formula = rdMolDescriptors.CalcMolFormula(pred_mol)
    print(f"True product formula: {true_formula}")
    print(f"Predicted product formula: {pred_formula}")
    if true_formula != pred_formula:
        print("Atom count mismatch – the model's prediction has a different formula (atoms lost or gained).")
    else:
        print("Formulas match – predicted product has the same atoms as the true product (but possibly arranged differently).")

    # 3. Per-atom comparison. GetAtoms() lists only explicit atoms, so implicit
    #    hydrogens are not included here (they are counted in the formulas).
    true_atoms = sorted(atom.GetSymbol() for atom in true_mol.GetAtoms())
    pred_atoms = sorted(atom.GetSymbol() for atom in pred_mol.GetAtoms())
    print(f"True atoms: {true_atoms}")
    print(f"Predicted atoms: {pred_atoms}")


In [None]:
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors, AllChem, DataStructs

# Example SMILES lists (stand-ins for true/predicted test-set products)
true_products = ["CCCO", "CCO", "invalid_smiles"]
predicted_products = ["CCCO", "CC", "C(C)(C)O"]

# zip() would silently drop unmatched entries, and an empty list would divide by zero
if len(true_products) != len(predicted_products):
    raise ValueError("true_products and predicted_products must have the same length")
if not true_products:
    raise ValueError("No SMILES pairs to score")


def _tanimoto_or_zero(true_smi, pred_smi):
    """Return Morgan (radius 2, 1024-bit) Tanimoto similarity of two SMILES.

    Returns 0.0 when either side is unusable. A side is unusable when it is not a
    non-empty string, RDKit cannot parse it, or it parses to a molecule with no
    atoms (e.g. an empty prediction).
    """
    mols = []
    for smi in (true_smi, pred_smi):
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) and smi else None
        if mol is None or mol.GetNumAtoms() == 0:
            return 0.0
        mols.append(mol)
    fp_true, fp_pred = (
        rdMolDescriptors.GetMorganFingerprintAsBitVect(m, radius=2, nBits=1024)
        for m in mols
    )
    return DataStructs.TanimotoSimilarity(fp_true, fp_pred)


sims = [_tanimoto_or_zero(t, p) for t, p in zip(true_products, predicted_products)]

# Summary stats (unusable pairs count as 0.0 and pull the average down)
average_sim = sum(sims) / len(sims)
print(f"Average Tanimoto similarity on test set = {average_sim:.3f}")
print("Sample similarities:", sims[:10])



In [None]:
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors, DataStructs

# Example SMILES lists with placeholder "..." entries. These count as invalid:
# RDKit either rejects them or they yield no atoms, and both cases score 0.0.
true_products = ["CCCO", "CCO", "..."]
predicted_products = ["CCCO", "CC", "..."]

# zip() would silently drop unmatched entries, and an empty list would divide by zero
if len(true_products) != len(predicted_products):
    raise ValueError("true_products and predicted_products must have the same length")
if not true_products:
    raise ValueError("No SMILES pairs to score")


def _tanimoto_or_zero(true_smi, pred_smi):
    """Return Morgan (radius 2, 1024-bit) Tanimoto similarity of two SMILES.

    Returns 0.0 when either side is unusable. A side is unusable when it is not a
    non-empty string, RDKit cannot parse it, or it parses to a molecule with no
    atoms.
    """
    mols = []
    for smi in (true_smi, pred_smi):
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) and smi else None
        if mol is None or mol.GetNumAtoms() == 0:
            return 0.0
        mols.append(mol)
    fp_true, fp_pred = (
        rdMolDescriptors.GetMorganFingerprintAsBitVect(m, radius=2, nBits=1024)
        for m in mols
    )
    return DataStructs.TanimotoSimilarity(fp_true, fp_pred)


sims = [_tanimoto_or_zero(t, p) for t, p in zip(true_products, predicted_products)]

# Summary statistics (unusable pairs count as 0.0 and pull the average down)
average_sim = sum(sims) / len(sims)
print(f"Average Tanimoto similarity on test set = {average_sim:.3f}")
print("Sample similarities:", sims[:10])



In [None]:
from rdkit.Chem import Draw

# Two example reactions as (true product, predicted product) SMILES pairs.
# Example 1 is predicted correctly (1-propanol). Example 2 is not: the true
# product is ethanol, but the model predicts ethane.
true_smiles1, pred_smiles1 = "CCCO", "CCCO"
true_smiles2, pred_smiles2 = "CCO", "CC"

# Parse in grid order (row = example, column = true | predicted) and keep the
# individual Mol handles available by name.
mols = [
    Chem.MolFromSmiles(smi)
    for smi in (true_smiles1, pred_smiles1, true_smiles2, pred_smiles2)
]
mol_true1, mol_pred1, mol_true2, mol_pred2 = mols

# One caption per grid cell, in the same order as `mols`
legends = [
    "True Product (Example 1)", "Predicted Product (Example 1)",
    "True Product (Example 2)", "Predicted Product (Example 2)"
]
img = Draw.MolsToGridImage(
    mols,
    molsPerRow=2,
    subImgSize=(300, 200),
    legends=legends,
)
img.show()


In [None]:
import numpy as np

def compute_class_weights(labels):
    """Return inverse-frequency class weights ``{label: n_samples / count}``.

    Keys and values are plain Python scalars (via ``ndarray.tolist``), so the
    dict prints cleanly and does not leak NumPy scalar types. An empty
    ``labels`` gives an empty dict.
    """
    classes, counts = np.unique(labels, return_counts=True)
    n_samples = len(labels)
    return {
        cls: n_samples / count
        for cls, count in zip(classes.tolist(), counts.tolist())
    }


train_classes = []  # fill with class labels if available
if train_classes:
    # Computed without binding a module-level ``total``: the evaluation cell
    # uses that name for its sample count.
    class_weight = compute_class_weights(train_classes)
    print("Class weights:", class_weight)
else:
    print("No class labels provided; skipping class-weight demo.")


In [None]:
import pandas as pd
from rdkit.Chem import AllChem
from rdkit import DataStructs, Chem

true_smiles_list = []
pred_smiles_list = []
class_list = []


def _mol_or_none(smi):
    """Parse *smi* with RDKit; return None for missing, unparsable or atom-less input."""
    if not isinstance(smi, str) or not smi:
        return None
    mol = Chem.MolFromSmiles(smi)
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    return mol


def _canonical_or_none(smi):
    """Return RDKit canonical SMILES for *smi*, or None if it is unusable."""
    mol = _mol_or_none(smi)
    return None if mol is None else Chem.MolToSmiles(mol, canonical=True)


def tanimoto_smiles(smi1, smi2):
    """Morgan (radius 2, 1024-bit) Tanimoto similarity; 0.0 if either SMILES is unusable."""
    mol1 = _mol_or_none(smi1)
    mol2 = _mol_or_none(smi2)
    if mol1 is None or mol2 is None:
        return 0.0
    fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2, nBits=1024)
    fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2, nBits=1024)
    return DataStructs.TanimotoSimilarity(fp1, fp2)


if not true_smiles_list:
    print("Add true/pred SMILES and class labels to compute per-class metrics.")
else:
    df = pd.DataFrame({
        "reaction_class": class_list,
        "true_smiles": true_smiles_list,
        "pred_smiles": pred_smiles_list
    })
    # Compare canonical forms, so different spellings of the same molecule still
    # match. An unusable prediction is never counted as correct.
    df["true_canon"] = df["true_smiles"].map(_canonical_or_none)
    df["pred_canon"] = df["pred_smiles"].map(_canonical_or_none)
    df["correct"] = df["pred_canon"].notna() & (df["true_canon"] == df["pred_canon"])
    acc_per_class = df.groupby("reaction_class")["correct"].mean()
    print("Top-1 Accuracy by class:")
    print(acc_per_class)

    df["tanimoto_sim"] = [
        tanimoto_smiles(true_smi, pred_smi)
        for true_smi, pred_smi in zip(df["true_smiles"], df["pred_smiles"])
    ]
    sim_per_class = df.groupby("reaction_class")["tanimoto_sim"].mean()
    print("\nAverage Tanimoto similarity by class:")
    print(sim_per_class)