In [1]:
import re, sys
import pandas as pd
import numpy as np
import json
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import rdChemReactions, AllChem, Draw, PandasTools
RDLogger.DisableLog('rdApp.*')
import warnings
warnings.filterwarnings('ignore')

In [2]:
def remove_bb_smi_label(smi):
    """Strip the numeric label from two-digit labelled dummy atoms.

    Every ``[NN*]`` (exactly two digits) becomes a bare ``[*]``. Dummy atoms
    with a different number of digits, e.g. ``[1*]`` or ``[123*]``, are left
    untouched.
    """
    label_pattern = re.compile(r"\[\d{2}\*\]")
    return label_pattern.sub("[*]", smi)

#### Extract

In [3]:
# Read the raw molecule records and the building-block catalogue.
with open("../data/raw/DORA_Lactam_mols_bbs.json", "r") as f:
    mols = json.loads(f.read())
with open("../data/raw/DORA_Lactam_bbs.json", "r") as f:
    bbs = json.loads(f.read())

# Flatten the JSON records into one DataFrame column per (possibly nested) field.
df_mols, df_bbs = pd.json_normalize(mols), pd.json_normalize(bbs)

#### Transform (cleaning)

In [4]:
# Keep only the columns used downstream.
df_bbs = df_bbs[["bb_smi", "bb_id"]]
df_mols = df_mols[["mol_smi", "mol_id", "A_id", "B_id", "C_id"]]

# Look up each building block's SMILES by id: adds A_smi, B_smi and C_smi.
for col in ["A", "B", "C"]:
    df_mols = df_mols.merge(
        df_bbs,
        left_on=f"{col}_id",
        right_on="bb_id",
        how="left"
    ).rename(columns={"bb_smi": f"{col}_smi"}).drop(columns=["bb_id"])
# NOTE(review): a bb_id appearing more than once in df_bbs would duplicate the
# matching molecule rows in this merge; presumably ids are unique — confirm.

# Drop incomplete rows. Checking the *_id columns alone is not enough: with a
# left merge, an id that has no entry in df_bbs leaves NaN in the matching
# *_smi column, and remove_bb_smi_label (re.sub) raises TypeError on NaN in
# the next cell.
df_mols = df_mols.dropna(
    subset=["A_id", "B_id", "C_id", "A_smi", "B_smi", "C_smi"]
)

In [5]:
# Replace two-digit attachment labels ([NN*]) with a bare [*] in each building block.
for smi_col in ("A_smi", "B_smi", "C_smi"):
    df_mols[smi_col] = df_mols[smi_col].apply(remove_bb_smi_label)

# Combined building-block string: A.B.C as one dot-separated, multi-fragment SMILES.
df_mols["bbs_smi"] = (
    df_mols["A_smi"]
    + "."
    + df_mols["B_smi"]
    + "."
    + df_mols["C_smi"]
)

#### Save

In [57]:
# Fixed-seed hold-out of 25 molecules; every remaining row is train/validation.
df_test = df_mols.sample(25, random_state=42)
df_trainval = df_mols.drop(index=df_test.index)

In [58]:
# The test split goes straight to transformed/; trainval goes to staging/
# because it is augmented further below.
df_test.to_pickle("../data/transformed/test_dataset.pkl")
df_trainval.to_pickle("../data/staging/trainval_dataset.pkl")

In [10]:
# Export only the train/validation molecules. Writing df_mols here would leak
# the 25 held-out df_test molecules into a file named "trainval".
df_trainval[["mol_smi", "bbs_smi"]].to_csv(
    "../data/staging/trainval_dataset.csv", index=False
)

#### Augment

In [59]:
import pandas as pd
from rdkit import Chem
import random

def augment_smiles(smiles, num_aug=10):
    """
    Generate randomized (non-canonical) SMILES strings for the same molecule.

    Args:
        smiles (str): input SMILES.
        num_aug (int): number of random SMILES to draw. Duplicates are
            collapsed, so fewer than ``num_aug`` strings may be returned.

    Returns:
        list of str: the distinct random SMILES, in the order they were first
        drawn. ``[]`` if ``smiles`` is not a string (e.g. a NaN DataFrame
        cell), if ``num_aug <= 0``, or if RDKit cannot parse ``smiles``.
    """
    # MolFromSmiles expects a str; NaN/None cells from a DataFrame would raise.
    if not isinstance(smiles, str) or num_aug <= 0:
        return []

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return []

    # dict.fromkeys de-duplicates while keeping first-seen order. Iterating a
    # set of str depends on per-process hash randomization (PYTHONHASHSEED), so
    # a caller taking element [0] would get run-dependent results.
    variants = dict.fromkeys(
        Chem.MolToSmiles(mol, doRandom=True) for _ in range(num_aug)
    )
    return list(variants)
    
def augment_dataset(df, smiles_cols, n_aug=10):
    """
    Augment a dataset by randomizing SMILES in selected columns.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    smiles_cols : list of str
        Column names that contain SMILES to augment. All other columns are
        copied unchanged into every augmented row.
    n_aug : int
        Number of augmented versions per row (excluding original).

    Returns
    -------
    pd.DataFrame
        Augmented dataframe with (n_aug+1) rows per original row (original
        first), a fresh RangeIndex, and the same columns as ``df`` (also when
        ``df`` is empty). A cell that is not a string, or whose SMILES RDKit
        cannot parse, keeps its original value in the augmented rows.
    """
    augmented_rows = []

    for _, row in df.iterrows():
        original = row.to_dict()
        # keep the original row
        augmented_rows.append(original)

        # generate augmented rows
        for _ in range(n_aug):
            new_row = dict(original)
            for col in smiles_cols:
                smi = original[col]
                if not isinstance(smi, str):
                    continue  # e.g. NaN: nothing to randomize, keep as-is
                # Draw a single random SMILES per copy; drawing several and
                # keeping one only multiplies the RDKit work.
                variants = augment_smiles(smi, num_aug=1)
                # An unparseable SMILES yields []; fall back to the original
                # string instead of raising IndexError, so every input row
                # still produces n_aug + 1 output rows.
                new_row[col] = variants[0] if variants else smi
            augmented_rows.append(new_row)

    # Passing columns keeps the schema even when df has no rows.
    return pd.DataFrame(augmented_rows, columns=df.columns)

In [60]:
# NOTE(review): "bbs_smi" is not in smiles_cols, so it keeps the un-augmented
# A.B.C string while A_smi/B_smi/C_smi are randomized; downstream cells use
# bbs_smi as the target. Confirm that a fixed target is intended (if so,
# randomizing A/B/C here is unused work).
df_aug = augment_dataset(
    df_trainval,
    smiles_cols=["mol_smi", "A_smi", "B_smi", "C_smi"],
    n_aug=20,  # 20 randomized copies + the original -> 21 rows per molecule
)

In [62]:
# Cache the augmented frame; augmentation is the slow step of this notebook.
df_aug.to_pickle(
    path="../data/transformed/trainval_dataset_augmented.pkl",
)

In [85]:
# Reload the cached augmented frame so the cells below can run without
# re-running augmentation.
df_aug = pd.read_pickle(
    "../data/transformed/trainval_dataset_augmented.pkl"
)

#### Transform (encode/decode)

In [100]:
# get tokens
import pickle
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Character-level vocabulary. Special tokens come first, so <pad> = 0,
# <s> = 1, </s> = 2. The remaining ids are the distinct characters of the
# augmented trainval inputs (mol_smi) and targets (bbs_smi), sorted so the id
# assignment is deterministic.
special_tokens = ["<pad>", "<s>", "</s>"]
all_smiles = df_aug['mol_smi'].tolist() + df_aug['bbs_smi'].tolist()
chars = set("".join(all_smiles))
itos = special_tokens + sorted(chars)          # index-to-string
stoi = {tok: i for i, tok in enumerate(itos)}  # string-to-index
pad_id = stoi["<pad>"]
# NOTE(review): only trainval characters are in the vocabulary; a test SMILES
# containing an unseen character will make encode() raise KeyError. Confirm
# the held-out set is covered.

# open(..., "wb") fails when ../tokens does not exist yet (e.g. fresh checkout).
os.makedirs("../tokens", exist_ok=True)
with open("../tokens/stoi.pkl", "wb") as f:
    pickle.dump(stoi, f)
with open("../tokens/itos.pkl", "wb") as f:
    pickle.dump(itos, f)

In [None]:
# Reload both vocabulary directions so encode() and decode() can run without
# rebuilding the vocabulary. pickle.load executes arbitrary code from the
# file: only load token files this notebook produced.
with open("../tokens/stoi.pkl", "rb") as f:
    stoi = pickle.load(f)
with open("../tokens/itos.pkl", "rb") as f:
    itos = pickle.load(f)
pad_id = stoi["<pad>"]
assert all(stoi[tok] == i for i, tok in enumerate(itos)), (
    "stoi.pkl and itos.pkl are out of sync"
)

In [87]:
def encode(smiles, stoi, max_len=128):
    """Encode a SMILES string as a fixed-length list of token ids.

    Each character of ``smiles`` is one token. The sequence is wrapped as
    ``<s> ... </s>`` and right-padded with ``<pad>`` to exactly ``max_len``
    ids. If the wrapped sequence is longer than ``max_len``, it is truncated,
    but the last position is always ``</s>`` so every sequence is terminated.

    Args:
        smiles: input string; every character must be a key of ``stoi``.
        stoi: token-to-id mapping containing "<pad>", "<s>" and "</s>".
        max_len: length of the returned list; must be >= 2 to hold <s> and </s>.

    Returns:
        list[int] of length ``max_len``.

    Raises:
        KeyError: if a character of ``smiles`` is not in ``stoi``.
        ValueError: if ``max_len < 2``.
    """
    if max_len < 2:
        raise ValueError(f"max_len must be >= 2 to hold <s> and </s>, got {max_len}")

    bos, eos, pad = stoi["<s>"], stoi["</s>"], stoi["<pad>"]
    ids = [bos] + [stoi[ch] for ch in smiles] + [eos]
    if len(ids) > max_len:
        # Truncate, but keep the terminator: plain ids[:max_len] would cut off
        # </s>, leaving the sequence with no end marker.
        ids = ids[: max_len - 1] + [eos]
    return ids + [pad] * (max_len - len(ids))

def decode(ids, itos, stop_at_eos=True):
    """Turn a sequence of token ids back into a SMILES string.

    ``<pad>`` and ``<s>`` are skipped. By default decoding stops at the first
    ``</s>``, so tokens a model emits after its end token are not appended.

    Args:
        ids: iterable of ids, each usable as a list index into ``itos``.
        itos: id-to-token list.
        stop_at_eos: if False, ``</s>`` is skipped like the other special
            tokens and every id in ``ids`` is read.

    Returns:
        str: the concatenated non-special tokens.
    """
    special = ("<pad>", "<s>", "</s>")
    chars = []
    for i in ids:
        tok = itos[i]
        if stop_at_eos and tok == "</s>":
            break
        if tok not in special:
            chars.append(tok)
    return "".join(chars)

def encode_target(smi, stoi, max_len=128):
    """Build the teacher-forcing decoder input and loss target for one SMILES.

    The SMILES is encoded once with ``encode`` (``<s> ... </s>`` plus padding),
    and the result is shifted by one position:

    * decoder input (tgt_in)  = every id except the last position,
    * loss target   (tgt_out) = every id except the leading ``<s>``.

    When the wrapped sequence is shorter than ``max_len``, the dropped last
    position is padding, so ``</s>`` is still present in the decoder input.

    Returns:
        tuple(list[int], list[int]): ``(tgt_in, tgt_out)``, each one id shorter
        than the encoded sequence.
    """
    encoded = encode(smi, stoi, max_len)
    return encoded[:-1], encoded[1:]

In [88]:
# encode df
# Teacher-forcing pairs for every target SMILES: tgt_in / tgt_out are the
# encoded bbs_smi shifted against each other by one position.
df_aug["tgt_in"], df_aug["tgt_out"] = zip(
    *(encode_target(smi, stoi) for smi in df_aug["bbs_smi"])
)

In [103]:
# Persist the encoded dataset; the list-valued tgt_in/tgt_out columns are
# stored as parquet list columns.
df_aug.to_parquet(
    "../data/transformed/trainval_dataset_augmented_encoded.parquet",
    index=False,  # do not write the row index as an extra column
)