# End-to-End Conditional SMILES Generation Using a GPT Model

**Author:** Mahmoud Ebrahimkhani 

**Date:** 2025-04-11

**Course:** Applied Deep Learning and Generative AI in Healthcare

**Reference:** DOI: [10.1021/acs.jcim.1c00600](https://doi.org/10.1021/acs.jcim.1c00600)

---

## Introduction

In this notebook, you will train a transformer-based language model to design novel drug-like small molecules. We'll walk through an end-to-end pipeline in which a GPT-style model (a decoder-only transformer) is trained to generate valid SMILES (Simplified Molecular Input Line Entry System) strings, *conditioned on desired molecular properties*. These properties (scaffold, LogP, QED, and TPSA) are relevant in early drug discovery for assessing bioavailability, drug-likeness, and molecular complexity.
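
One common way to condition a decoder-only model is to prepend the target properties to each SMILES string as a text prefix, so that at generation time the model is prompted with the prefix alone. The sketch below illustrates that idea. The tag names (`<scaffold>`, `<logp>`, `<qed>`, `<tpsa>`, `<smiles>`), the rounding precision, and both helper functions are assumptions made for illustration, not necessarily the format used later in this notebook.

```python
def format_conditional_example(smiles, scaffold, logp, qed, tpsa):
    """Build one training string: a property prefix followed by the target SMILES.

    The tag layout is hypothetical; any consistent, parseable layout would do.
    """
    prefix = (
        f"<scaffold>{scaffold}"
        f"<logp>{logp:.2f}"
        f"<qed>{qed:.2f}"
        f"<tpsa>{tpsa:.1f}"
    )
    return f"{prefix}<smiles>{smiles}"


def make_prompt(scaffold, logp, qed, tpsa):
    """Build the generation-time prompt: the same prefix with the SMILES left open."""
    return format_conditional_example("", scaffold, logp, qed, tpsa)


# Placeholder property values for illustration (not computed with RDKit)
example = format_conditional_example(
    smiles="Cc1ccccc1", scaffold="c1ccccc1", logp=2.0, qed=0.5, tpsa=0.0
)
prompt = make_prompt(scaffold="c1ccccc1", logp=2.0, qed=0.5, tpsa=0.0)

print(example)
print(prompt)
```

Because every training string starts with its own prompt, the model sees the same prefix layout during training and generation.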

---

## What You'll Do

1. **Load and explore** the MOSES dataset ([github.com/molecularsets/moses](https://github.com/molecularsets/moses)), which contains pre-filtered drug-like molecules.
2. **Compute molecular descriptors**, including:
    - Scaffold (core structure of a molecule): The rigid, central framework that defines the molecule's basic shape and connectivity.
    - LogP (lipophilicity): The logarithm of the octanol-water partition coefficient, a measure of how readily a molecule dissolves in fats versus water; important for drug absorption and distribution.
    - QED (quantitative estimate of drug-likeness): A score between 0 and 1 that indicates how similar a molecule's properties are to known drugs.
    - TPSA (topological polar surface area): The total surface area of all polar atoms (mainly oxygen and nitrogen) in the molecule, which helps predict drug absorption.
3. **Format the data** to enable *conditional* generation, so that the model learns to generate SMILES strings based on specified property values.
4. **Train a GPT-like transformer model** to generate molecules.
5. **Evaluate the quality of the generated molecules** using:
    - Validity of SMILES strings: Using RDKit to parse each generated SMILES string and verify that it represents a valid chemical structure.
    - Uniqueness of generated molecules: Computing the ratio of unique SMILES strings to the total number of generated molecules.
    - Tanimoto similarity to training molecules: Calculating molecular fingerprint similarity between generated and training-set molecules to assess novelty.
    - Alignment with the conditional property distributions: Comparing the statistical distributions of properties (LogP, QED, TPSA) between generated and training molecules.
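
Two of the metrics listed above reduce to simple set arithmetic, which the following minimal sketch makes explicit. It assumes fingerprints are represented as plain Python sets of on-bit indices; in practice you would compute them with RDKit (for example Morgan fingerprints compared with `DataStructs.TanimotoSimilarity`). The function names and toy inputs are hypothetical.

```python
def uniqueness(smiles_list):
    """Fraction of distinct strings among the generated SMILES (0.0 for an empty list)."""
    if not smiles_list:
        return 0.0
    return len(set(smiles_list)) / len(smiles_list)


def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    union = bits_a | bits_b
    if not union:
        return 0.0
    return len(bits_a & bits_b) / len(union)


def max_similarity_to_training(generated_bits, training_bits_list):
    """Highest Tanimoto similarity between one generated molecule and any training molecule."""
    return max((tanimoto(generated_bits, t) for t in training_bits_list), default=0.0)


# Toy inputs: four generated strings, one of them a duplicate
generated = ["CCO", "CCO", "c1ccccc1", "CCN"]
unique_ratio = uniqueness(generated)

# Toy fingerprints as sets of on-bit indices (assumption, not real RDKit output)
fp_generated = {1, 4, 7, 9}
fp_training = [{1, 4, 8}, {2, 3}, {1, 4, 7, 9, 12}]
nearest = max_similarity_to_training(fp_generated, fp_training)

print(unique_ratio, nearest)
```

Uniqueness is only meaningful on canonicalized SMILES, since the same molecule can be written as several different strings. A low maximum similarity to the training set suggests a more novel molecule.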

## Environment Setup

You may need to install a few dependencies:

In [None]:
%pip install pandas==1.2.5
%pip install rdkit==2022.3.5
%pip install scikit-learn matplotlib tqdm
%pip install torch --index-url https://download.pytorch.org/whl/cu118
%pip install transformers
%pip install git+https://github.com/molecularsets/moses.git

## 1. Import packages

In [1]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import AllChem, Draw, Descriptors, QED
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol
from rdkit.Chem import rdMolDescriptors, rdDistGeom
from rdkit.Chem.MolStandardize import rdMolStandardize

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    GPT2Config,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

from moses.dataset import get_dataset

## 2. Data Loading and Filtering
 
We'll use the MOSES dataset, a curated set of drug-like molecules designed specifically for machine learning applications. It is much smaller than [ChEMBL](https://www.ebi.ac.uk/chembl/) ([interface documentation](https://chembl.gitbook.io/chembl-interface-documentation/)) but still contains high-quality, drug-like compounds.

In [None]:
import os
import shutil
from pathlib import Path

def setup_moses_data():
    current_dir = Path.cwd() 
    moses_data_dir = current_dir / 'moses_data'
    external_moses_dir = current_dir.parent / 'external' / 'moses' / 'data'
    
    moses_data_dir.mkdir(exist_ok=True)
    
    files_to_copy = [
        'dataset_v1.csv',
        'test.csv',
        'test_scaffolds.csv'
    ]
    
    for file_name in files_to_copy:
        source_path = external_moses_dir / file_name
        dest_path = moses_data_dir / file_name
        
        if source_path.exists():
            print(f"Copying {file_name}...")
            shutil.copy2(source_path, dest_path)
            print(f"Successfully copied {file_name}")
        else:
            print(f"Warning: {file_name} not found in {external_moses_dir}")
    
    return moses_data_dir

def load_moses_data():
    try:
        data_dir = setup_moses_data()
        
        train_path = data_dir / 'dataset_v1.csv'
        print(f"Reading dataset from {train_path}")
        
        df = pd.read_csv(train_path)
        print("\nDataFrame info:")
        print(df.info())
        print("\nFirst few rows:")
        print(df.head())
        
        if 'smiles' in df.columns:
            df = df.rename(columns={'smiles': 'SMILES'})
        elif 'SMILES' not in df.columns:
            print("Available columns:", df.columns.tolist())
            raise ValueError("SMILES column not found in dataset")
        
        smiles_list = df['SMILES'].values
        print(f"\nFound {len(smiles_list)} SMILES strings")
        
        # Filter 2 patterns: problematic functional groups, compiled once up front
        # rather than once per molecule
        bad_group_patterns = [
            Chem.MolFromSmarts(patt) for patt in (
                '[N+]([O-])=O',  # Nitro groups: Highly reactive, can cause DNA damage and carcinogenicity
                '[S](=[O])(=[O])',  # Sulfonyl groups: Can be chemically reactive and cause skin/eye irritation
                '[P](=[O])',  # Phosphoryl groups: Potential toxicity and instability in biological systems
                '[As]',  # Arsenic: Highly toxic metalloid with severe health risks and carcinogenic properties
            )
        ]

        valid_mols = []
        for smi in tqdm(smiles_list, desc="Validating and filtering SMILES"):
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                continue  # skip SMILES that RDKit cannot parse

            # Filter 1: Lipinski's criteria for oral bioavailability
            # (the fraction of an orally administered drug that reaches the bloodstream)
            # Note: ExactMolWt is the monoisotopic mass; Descriptors.MolWt gives the average molecular weight
            mw = Descriptors.ExactMolWt(mol)
            logp = Descriptors.MolLogP(mol)
            hbd = rdMolDescriptors.CalcNumHBD(mol)
            hba = rdMolDescriptors.CalcNumHBA(mol)
            if not (mw <= 500 and logp <= 5 and hbd <= 5 and hba <= 10):
                continue

            # Filter 2: Remove molecules with problematic functional groups
            if any(mol.HasSubstructMatch(patt) for patt in bad_group_patterns):
                continue

            valid_mols.append(smi)
        
        df = pd.DataFrame(valid_mols, columns=['SMILES'])
        print(f"Loaded {len(df)} valid drug-like molecules after filtering")
        return df
    
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        raise

filtered_df = load_moses_data()
filtered_df.head()

Copying dataset_v1.csv...
Successfully copied dataset_v1.csv
Copying test.csv...
Successfully copied test.csv
Copying test_scaffolds.csv...
Successfully copied test_scaffolds.csv
Reading dataset from /Users/mebrahimkhani/Documents/northeastern_teaching_appointment/session-13/moses_data/dataset_v1.csv

DataFrame info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1936962 entries, 0 to 1936961
Data columns (total 2 columns):
 #   Column  Dtype 
---  ------  ----- 
 0   SMILES  object
 1   SPLIT   object
dtypes: object(2)
memory usage: 29.6+ MB
None

First few rows:
                                   SMILES  SPLIT
0  CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1  train
1    CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1  train
2  CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1   test
3     Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO  train
4        Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C  train

Found 1936962 SMILES strings


Validating and filtering SMILES: 100%|██████████| 1936962/1936962 [10:30<00:00, 3070.81it/s] 


Loaded 1735494 valid drug-like molecules after filtering


|   | SMILES |
|---|--------|
| 0 | `CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1` |
| 1 | `CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1` |
| 2 | `CC1C2CCC(C2)C1CN(CCO)C(=O)c1ccc(Cl)cc1` |
| 3 | `Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO` |
| 4 | `Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C` |
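
The Lipinski check inside `load_moses_data` can be read in isolation as a single predicate over four precomputed descriptor values. The sketch below restates it that way and recomputes the retention rate from the counts printed in the output above. The helper name `passes_lipinski` and the three candidate entries are hypothetical values chosen only to exercise each branch.

```python
def passes_lipinski(mw, logp, hbd, hba):
    """Return True when all four thresholds used in the loading cell are satisfied."""
    return mw <= 500 and logp <= 5 and hbd <= 5 and hba <= 10


# Hypothetical descriptor values (not taken from the dataset)
candidates = {
    "small_polar": dict(mw=180.2, logp=1.2, hbd=1, hba=4),
    "too_heavy": dict(mw=612.7, logp=3.1, hbd=2, hba=8),
    "too_lipophilic": dict(mw=410.5, logp=6.3, hbd=0, hba=3),
}
kept = [name for name, d in candidates.items() if passes_lipinski(**d)]
print(kept)

# Counts reported in the cell output: molecules read vs. molecules kept by both filters
total, retained = 1_936_962, 1_735_494
print(f"{retained / total:.1%} of molecules kept")
```

Note that the cell requires all four criteria to hold at once; the rule is often applied more leniently, with a single violation tolerated.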


## 3. Descriptor Calculation (Scaffolds, logP, QED, TPSA)
We compute the additional descriptors needed for conditioning:
- Murcko scaffold: the core molecular framework, obtained by removing side chains and keeping only the ring systems and the linkers between them
- QED
- TPSA
- LogP

We'll store these in the DataFrame alongside the SMILES.

In [None]:
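def calculate_descriptors(smiles: str):
    """Return (scaffold SMILES, LogP, QED, TPSA) for one molecule, or four Nones on failure."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, None, None, None

    try:
        # Murcko scaffold: ring systems plus the linkers between them
        scaffold = GetScaffoldForMol(mol)
        scaffold_smiles = Chem.MolToSmiles(scaffold)

        qed_val = QED.qed(mol)
        tpsa_val = rdMolDescriptors.CalcTPSA(mol)
        logp_val = Descriptors.MolLogP(mol)

        return scaffold_smiles, logp_val, qed_val, tpsa_val
    except Exception:
        # A bare `except:` would also swallow KeyboardInterrupt during this long loop
        return None, None, None, None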

print("Calculating molecular descriptors...")
scaffolds = []
logps = []
qeds = []
tpsas = []

for smi in tqdm(filtered_df['SMILES'], desc="Calculating descriptors"):
    scaffold, lp, qd, tp = calculate_descriptors(smi)
    scaffolds.append(scaffold)
    logps.append(lp)
    qeds.append(qd)
    tpsas.append(tp)

filtered_df['Scaffold'] = scaffolds
filtered_df['LogP'] = logps
filtered_df['QED'] = qeds
filtered_df['TPSA'] = tpsas

filtered_df = filtered_df.dropna(subset=['Scaffold', 'LogP', 'QED', 'TPSA'])
print(f"Final dataset size after descriptor calculation: {len(filtered_df)}")

print("\nDescriptor Statistics:")
print(f"LogP range: {filtered_df['LogP'].min():.2f} to {filtered_df['LogP'].max():.2f}")
print(f"QED range: {filtered_df['QED'].min():.2f} to {filtered_df['QED'].max():.2f}")
print(f"TPSA range: {filtered_df['TPSA'].min():.2f} to {filtered_df['TPSA'].max():.2f}")

filtered_df.head()
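
Before trusting descriptors computed over more than a million rows, it can help to sanity-check the same RDKit calls on one familiar molecule. The sketch below does this for ibuprofen, which is chosen here only as a well-known example and is not necessarily part of the dataset. It is self-contained and assumes RDKit is installed as in the setup cell.

```python
from rdkit import Chem
from rdkit.Chem import Descriptors, QED, rdMolDescriptors
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol

# Ibuprofen: one benzene ring carrying an isobutyl and a propanoic acid side chain
mol = Chem.MolFromSmiles("CC(C)Cc1ccc(cc1)C(C)C(=O)O")

# Stripping the side chains should leave only the benzene ring
scaffold_smiles = Chem.MolToSmiles(GetScaffoldForMol(mol))

tpsa_val = rdMolDescriptors.CalcTPSA(mol)  # polar surface area from the carboxylic acid
logp_val = Descriptors.MolLogP(mol)        # Crippen LogP estimate
qed_val = QED.qed(mol)                     # drug-likeness score in [0, 1]

print(f"Scaffold: {scaffold_smiles}")
print(f"TPSA: {tpsa_val:.2f}  LogP: {logp_val:.2f}  QED: {qed_val:.2f}")
```

If the scaffold is not the bare ring or the QED falls outside [0, 1], the problem lies in the environment or the RDKit version rather than in the dataset.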