In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit.Chem import Descriptors
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

from tqdm.notebook import tqdm  # Import tqdm for progress bars
from transformers import get_scheduler

from torch.nn.utils.rnn import pad_sequence
import torch.nn.utils.rnn as rnn

In [2]:
from rdkit import RDLogger  # Import RDKit Logger

# Suppress all RDKit warnings
RDLogger.DisableLog('rdApp.*')

In [3]:
print(torch.cuda.is_available())  # Should return True
print(torch.version.cuda) 

True
12.1


### DATASETS FOR TRAINING

In [4]:
# SMILES datasets from PubChem database (ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction.)

# path = '/home/medard/ChemBERTa Data Trial/pubchem_10m_ant_lab.txt'
chemberta_10m = pd.read_csv('/home/medard/ChemBERTa Data Trial/pubchem_10m_ant_lab.txt', sep=" ", header=None, names=["SMILES"])
chemberta_10m

Unnamed: 0,SMILES
0,CN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1
1,CC[NH+](CC)C1CCC([NH2+]C2CC2)(C(=O)[O-])C1
2,COCC(CNC(=O)c1ccc2c(c1)NC(=O)C2)OC
3,OCCn1cc(CNc2cccc3c2CCCC3)nn1
4,CCCCCCc1ccc(C#Cc2ccc(C#CC3=CC=C(CCC)CC3)c(C3CC...
...,...
9999995,CC(=O)C(C)Cc1ccc2c(c1)NC(=O)CO2
9999996,O=C(Cn1cc(C=C2NC(=O)N(Cc3ccccc3F)C2=O)c2ccccc2...
9999997,COc1cc(C(F)(F)F)cc(n2c(C)nc(C#Cc3ccnc(Cl)c3)c2...
9999998,O=C(NCc1ccccc1)N1CCC2(CC1)OCCc1c2[nH]c2ccccc12


In [5]:
chemberta_10m.describe()

Unnamed: 0,SMILES
count,10000000
unique,10000000
top,CN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1
freq,1


In [6]:
# SMILES datasets from MolBert and MolRoPE-Bert

molbert_smiles = pd.read_csv('/home/medard/Bert pretraining for SMILES/molrope_smiles_all.smiles', header=None, names=["SMILES"])
molbert_smiles

Unnamed: 0,SMILES
0,CCCC(=O)NNC(=O)Nc1ccccc1
1,CC(=O)NC1CCC2(C)C(CCC3(C)C2C(=O)C=C2C4C(C)C(C)...
2,CC(=O)NC(C)Cc1ccc(C#Cc2ccnc(N3CCCC(F)C3)n2)cc1
3,Cc1cccc(CCNC(=O)C2CCC(=O)N(Cc3ccc(Cl)cc3)C2)n1
4,CC1C=CN(N(C)C)C2=C1C(=O)c1cnccc1C2=O
...,...
1591373,O=C(O)CCc1ccc(-c2c(C=Cc3ccc4ccccc4n3)nc3c(N4CC...
1591374,CCOCCc1ccc(OCCNC(=O)c2cc(C(C)(C)C)nn2C)c(C)c1
1591375,NCCCC(=O)Nc1ccc(C(=O)Nc2nccs2)cc1
1591376,COc1ccc(C2C(C#N)=C(N=CN3CCN(C)CC3)OC3=C2C(=O)C...


In [7]:
molbert_smiles.describe()

Unnamed: 0,SMILES
count,1591378
unique,1591378
top,CCCC(=O)NNC(=O)Nc1ccccc1
freq,1


In [8]:
# SMILES datasets from PubChem database (By Medard)

base_path = "/home/medard/Bert pretraining for SMILES/"
# base_path = "/home/user/10TB/Data/SMILES/"

pubchem_list = []

for i in range(6):
    pubchem = pd.read_csv(base_path + f"PubChem_compound_list_{i+1}.csv",
                     usecols = ["isosmiles"],
                     )
    pubchem_list.append(pubchem)
    
all_pubchem = pd.concat([pubchem_list[0], pubchem_list[1],
                        pubchem_list[2], pubchem_list[3],
                        pubchem_list[4], pubchem_list[5]],
                       ignore_index = True,
                        axis=0,
                       )
all_pubchem

Unnamed: 0,isosmiles
0,C(C(=O)COP(=O)(O)O)N
1,C(=O)C(C(=O)O)N
2,C1C(C(C(OC1O)CO)O)O
3,C1=NC(=C2C(=N1)N(C=N2)C3C(C(C(O3)COP(=O)(O)O)O...
4,C(C(C(=O)O)N)Cl
...,...
6770340,C(C(=O)[O-])C(CC(=O)[O-])(C(=O)[O-])O.C(C(=O)[...
6770341,[C@@H]([C@H](C(=O)[O-])O)(C(=O)[O-])O.[PbH2+2]
6770342,CC[C@@H](C)[C@H]1C(=O)NCC(=O)N[C@H]2C[S@@](=O)...
6770343,C[C@]1(CCC23COC4(C2C1)CCC5[C@]6(CCC(C(C6CC[C@]...


In [9]:
# drop_duplicates after concatenation
all_pubchem.drop_duplicates(keep='first',
                           inplace=True,
                           ignore_index = True)
all_pubchem

Unnamed: 0,isosmiles
0,C(C(=O)COP(=O)(O)O)N
1,C(=O)C(C(=O)O)N
2,C1C(C(C(OC1O)CO)O)O
3,C1=NC(=C2C(=N1)N(C=N2)C3C(C(C(O3)COP(=O)(O)O)O...
4,C(C(C(=O)O)N)Cl
...,...
6362615,CC(C)(C)C1(CCN2CC3CC4=CC=CC=C4SC5=CC=CC(=C35)C...
6362616,CC[C@H](C)[C@@H]1[C@H](C=C[C@@]2(O1)C[C@@H]3C[...
6362617,[Be+2].[O-][Si]([O-])([O-])[O-].[Mn+2].[Zn+2]
6362618,C(N(CP(=O)(O)[O-])CP(=O)(O)[O-])P(=O)(O)[O-].O...


In [10]:
all_pubchem.describe()

Unnamed: 0,isosmiles
count,6362620
unique,6362620
top,C(C(=O)COP(=O)(O)O)N
freq,1


#### Concatenate all the dataset for Training

In [11]:
all_smiles_for_training = pd.concat([chemberta_10m["SMILES"], all_pubchem["isosmiles"], molbert_smiles["SMILES"],], axis=0)
all_smiles_for_training

0                     CN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1
1                 CC[NH+](CC)C1CCC([NH2+]C2CC2)(C(=O)[O-])C1
2                         COCC(CNC(=O)c1ccc2c(c1)NC(=O)C2)OC
3                               OCCn1cc(CNc2cccc3c2CCCC3)nn1
4          CCCCCCc1ccc(C#Cc2ccc(C#CC3=CC=C(CCC)CC3)c(C3CC...
                                 ...                        
1591373    O=C(O)CCc1ccc(-c2c(C=Cc3ccc4ccccc4n3)nc3c(N4CC...
1591374        CCOCCc1ccc(OCCNC(=O)c2cc(C(C)(C)C)nn2C)c(C)c1
1591375                    NCCCC(=O)Nc1ccc(C(=O)Nc2nccs2)cc1
1591376    COc1ccc(C2C(C#N)=C(N=CN3CCN(C)CC3)OC3=C2C(=O)C...
1591377    O=C(CC(NC(=O)c1ccccc1)c1ccccc1)Oc1ccc([N+](=O)...
Length: 17953998, dtype: object

In [12]:
all_smiles_for_training.describe()

count                         17953998
unique                        17857808
top       CCCCCCCCCC(=O)CC(=O)NC1CCCC1
freq                                 3
dtype: object

In [13]:
all_smiles_for_training = all_smiles_for_training.unique()
all_smiles_for_training

array(['CN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1',
       'CC[NH+](CC)C1CCC([NH2+]C2CC2)(C(=O)[O-])C1',
       'COCC(CNC(=O)c1ccc2c(c1)NC(=O)C2)OC', ...,
       'NCCCC(=O)Nc1ccc(C(=O)Nc2nccs2)cc1',
       'COc1ccc(C2C(C#N)=C(N=CN3CCN(C)CC3)OC3=C2C(=O)CC(C)(C)C3)cc1',
       'O=C(CC(NC(=O)c1ccccc1)c1ccccc1)Oc1ccc([N+](=O)[O-])cc1'],
      dtype=object)

In [14]:
type(all_smiles_for_training)

numpy.ndarray

In [15]:
len(all_smiles_for_training)

17857808

In [16]:
# Convert numpy array data to pandas dataFrame

training_smiles_pd = pd.DataFrame(data=all_smiles_for_training, columns=["unique_smiles"])
training_smiles_pd

Unnamed: 0,unique_smiles
0,CN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1
1,CC[NH+](CC)C1CCC([NH2+]C2CC2)(C(=O)[O-])C1
2,COCC(CNC(=O)c1ccc2c(c1)NC(=O)C2)OC
3,OCCn1cc(CNc2cccc3c2CCCC3)nn1
4,CCCCCCc1ccc(C#Cc2ccc(C#CC3=CC=C(CCC)CC3)c(C3CC...
...,...
17857803,O=C(O)CCc1ccc(-c2c(C=Cc3ccc4ccccc4n3)nc3c(N4CC...
17857804,CCOCCc1ccc(OCCNC(=O)c2cc(C(C)(C)C)nn2C)c(C)c1
17857805,NCCCC(=O)Nc1ccc(C(=O)Nc2nccs2)cc1
17857806,COc1ccc(C2C(C#N)=C(N=CN3CCN(C)CC3)OC3=C2C(=O)C...


In [17]:
training_smiles_pd.describe()

Unnamed: 0,unique_smiles
count,17857808
unique,17857808
top,CN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1
freq,1


In [18]:
smiles_list = training_smiles_pd['unique_smiles'].to_list()
len(smiles_list)

17857808

In [19]:
type(smiles_list)

list

In [20]:
len(smiles_list)

17857808

In [21]:
# Ensure there are at least 500,000 samples
num_samples = min(10000000, len(smiles_list))
smiles_sample = random.sample(smiles_list, num_samples)  # Randomly select 500,000 without replacement

In [22]:
len(smiles_sample)

10000000

In [23]:
type(smiles_sample)

list

### ENTROPY, PATCHING, MASKING, TRAINING, and VISUALIZATION

In [24]:
# 1. Compute rolling entropy for SMILES byte sequences
def compute_entropy(byte_seq, window_size=5): #, stride=1
    entropy_values = []

    # If sequence is shorter than window size, assign uniform entropy
    if len(byte_seq) < window_size:
        return [0] * len(byte_seq)  # Assign zero entropy if sequence is too short

    # Compute entropy over the sliding window
    for i in range(len(byte_seq) - window_size + 1): #, stride
        window = byte_seq[i : i + window_size]
        freq = Counter(window)
        probs = np.array(list(freq.values())) / sum(freq.values())
        entropy = -np.sum(probs * np.log2(probs))
        entropy_values.append(entropy)

    # **Ensure output length matches byte_seq length**
    # Ensure entropy_values length matches byte_seq length
    if len(entropy_values) < len(byte_seq):
        entropy_values += [entropy_values[-1]] * (len(byte_seq) - len(entropy_values))
        
    return entropy_values

In [25]:
smiles_test = "C1=CC=CC=C1"  # Example short SMILES
print(compute_entropy(list(smiles_test.encode("utf-8")), window_size=5))

[1.3709505944546687, 1.5219280948873621, 0.9709505944546686, 0.7219280948873623, 0.9709505944546686, 0.9709505944546686, 1.3709505944546687, 1.3709505944546687, 1.3709505944546687, 1.3709505944546687, 1.3709505944546687]


In [26]:
# 2. Identify chemical motifs: Updated
def identify_motifs(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return {}

    metal_atoms = {3, 4, 11, 12, 19, 20, 30, 31, 38, 39, 48, 49, 56, 57, 72, 73, 80, 81, 88, 89}  # Common metal atomic numbers

    motifs = {
        "aromatic": False, "aromatic_positions": [],
        "charged": False, "charged_positions": [],
        "large_ring": False,
        "functional_groups": False,
        "branches": False,
        "double_bonds": False,
        "triple_bonds": False,
        "metal": False, "metal_positions": [],
        "non_metal": False
    }

    # Identify atom-level motifs
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()  # Get the atom index

        # Aromatic atoms
        if atom.GetIsAromatic():
            motifs["aromatic"] = True
            motifs["aromatic_positions"].append(idx)

        # Charged atoms
        if atom.GetFormalCharge() != 0:
            motifs["charged"] = True
            motifs["charged_positions"].append(idx)

        # Metal atoms
        if atom.GetAtomicNum() in metal_atoms:
            motifs["metal"] = True
            motifs["metal_positions"].append(idx)
        else:
            motifs["non_metal"] = True  # Anything not in metal_atoms and not H

        # Functional groups (atoms with more than 2 neighbors)
        if len(atom.GetNeighbors()) > 2:
            motifs["functional_groups"] = True

        # Branches (atoms with more than 3 neighbors)
        if len(atom.GetNeighbors()) > 3:
            motifs["branches"] = True

    # Identify bond-level motifs
    for bond in mol.GetBonds():
        if bond.GetBondTypeAsDouble() == 2.0:
            motifs["double_bonds"] = True
        if bond.GetBondTypeAsDouble() == 3.0:
            motifs["triple_bonds"] = True

    # Identify large rings
    rings = Chem.GetSymmSSSR(mol)
    if any(len(ring) > 6 for ring in rings):
        motifs["large_ring"] = True

    return motifs


In [27]:
# 3. Dynamic Byte Patching with Entropy Thresholding
def dynamic_byte_patching(smiles, base_entropy_threshold=1.5, min_patch_size=2, max_patch_size=6, return_motif_sizes=False):    
    byte_seq = list(smiles.encode("utf-8"))

    # Define special tokens
    sos_token = 256
    eos_token = 257

    # If sequence is too short for entropy calculation, return single patch
    if len(byte_seq) < min_patch_size:
        return [[sos_token] + byte_seq + [eos_token]], [0], [], [], []  # Entire sequence as one patch with entropy 0 /Return empty lists for patch logs
        
    entropy_values = compute_entropy(byte_seq, window_size=5)

    # Ensure entropy_values length matches byte_seq length
    entropy_values = entropy_values + [entropy_values[-1]] * (len(byte_seq) - len(entropy_values))
    
    motifs = identify_motifs(smiles)
    # print("Motifs Detected:", motifs)

    # Dynamically adjust entropy threshold based on chemical properties
    entropy_threshold = base_entropy_threshold
    # if motifs["aromatic"]:
    if motifs.get("aromatic", False):  # Use .get() to avoid KeyError
        entropy_threshold *= 1.2  # Increase threshold for aromatic rings
    # if motifs["charged"]:
    if motifs.get("charged", False):  # Use .get() to avoid KeyError
        entropy_threshold *= 0.8  # Reduce threshold for charged groups
    # if motifs["functional_groups"]:
    if motifs.get("functional_groups", False):  # Safe access with get()
        entropy_threshold *= 1.1  # Slightly increase threshold for functional groups
    # if motifs["branches"]:
    if motifs.get("branches", False):
        entropy_threshold *= 0.9  # Reduce threshold for branched structures
    # if motifs["double_bonds"] or motifs["triple_bonds"]:
    if motifs.get("double_bonds", False) or motifs.get("triple_bonds", False):
        entropy_threshold *= 1.15  # Increase threshold for conjugated bonds
    # if motifs["metal"]:
    if motifs.get("metal", False):
        entropy_threshold *= 0.75  # Lower threshold for metals to preserve interactions
    # if motifs["non_metal"]:
    if motifs.get("non_metal", False):
        entropy_threshold *= 1.05  # Slightly increase for non-metal structures

    patches, entropy_log = [], []
    aromatic_patch_sizes, metal_patch_sizes, charged_patch_sizes = [], [], []
    
    i = 0
    while i < len(byte_seq):
        local_entropy = entropy_values[min(i, len(entropy_values)-1)]

        # Adjust patch size based on entropy
        if local_entropy < entropy_threshold * 0.8:
            patch_size = max_patch_size  # Low entropy → Larger patches
        elif local_entropy > entropy_threshold * 1.2:
            patch_size = min_patch_size  # High entropy → Smaller patches
        else:
            patch_size = (min_patch_size + max_patch_size) // 2  # Medium entropy → Medium patches
        
        # Ensure patch size does not exceed sequence length
        patch_size = min(patch_size, len(byte_seq) - i)

        # Ensure patch does NOT include `[SOS]` (0) or `[EOS]` (1)
        while i < len(byte_seq) and (byte_seq[i] == sos_token or byte_seq[i] == eos_token):
            i += 1  # Skip `[SOS]` and `[EOS]`

        # If patch starts with `[SOS]` or `[EOS]`, shift it
        if i + patch_size < len(byte_seq) and (byte_seq[i + patch_size - 1] == sos_token or byte_seq[i + patch_size - 1] == eos_token):
            patch_size -= 1

        # Check if patch contains any key motifs
        contains_aromatic = any(byte in motifs.get("aromatic_positions", []) for byte in range(i, i + patch_size))
        contains_metal = any(byte in motifs.get("metal_positions", []) for byte in range(i, i + patch_size))
        contains_charged = any(byte in motifs.get("charged_positions", []) for byte in range(i, i + patch_size))

        # print(f"Patch at {i}-{i+patch_size}: {byte_seq[i:i+patch_size]}")
        # print(f"Contains: Aromatic={contains_aromatic}, Metal={contains_metal}, Charged={contains_charged}")

        if contains_aromatic:
            aromatic_patch_sizes.append(patch_size)
        if contains_metal:
            metal_patch_sizes.append(patch_size)
        if contains_charged:
            charged_patch_sizes.append(patch_size)

        patches.append(byte_seq[i : i + patch_size])
        entropy_log.append(np.mean(entropy_values[i : i + patch_size]))
        i += patch_size  # Move to next patch

    # Add [SOS] to the first patch and [EOS] to the last patch
    if len(patches) > 0:
        patches[0] = [sos_token] + patches[0]  # Add [SOS] to first patch
        patches[-1] = patches[-1] + [eos_token]  # Add [EOS] to last patch

    # return patches, entropy_log, aromatic_patch_sizes, metal_patch_sizes, charged_patch_sizes
    if return_motif_sizes:
        return patches, entropy_log, aromatic_patch_sizes, metal_patch_sizes, charged_patch_sizes
    else:
        return patches, entropy_log  # Default case for dataset use: class SMILESDataset

In [28]:
# 4. Masking Function for MLM (BERT-style)
def mask_byte_patches(byte_patches, mask_ratio=0.15, vocab_size=256, sos_token=256, eos_token=257, mask_token=255):
    masked_patches, labels = [], []
    
    for patch in byte_patches:
        masked_patch, label = [], []
        
        for token in patch:
            if token == sos_token or token == eos_token:
                # Do NOT mask SOS or EOS tokens; keep them unchanged
                masked_patch.append(token)
                label.append(-100)  # Ignore in loss
            elif random.random() < mask_ratio:
                rand = random.random()
                if rand < 0.8:
                    masked_patch.append(mask_token)  # 80% replaced with MASK token (ID: 255)
                elif rand < 0.9:
                    masked_patch.append(random.randint(2, vocab_size - 1))  # 10% replaced randomly (excluding SOS/EOS)
                else:
                    masked_patch.append(token)  # 10% unchanged
                label.append(token)  # Store original token for loss
            else:
                masked_patch.append(token)
                label.append(-100)  # Ignore in loss
            
        masked_patches.append(masked_patch)
        labels.append(label)

    return masked_patches, labels

In [29]:
def apply_rope(embeddings):
    """
    Apply Rotary Position Embedding (RoPE) to the given embeddings.
    Args:
        embeddings: Tensor of shape (batch_size, seq_len, d_model)
    Returns:
        Tensor with RoPE applied
    """
    batch_size, seq_len, d_model = embeddings.shape
    assert d_model % 2 == 0, "Embedding dimension must be even for RoPE"

    # Create position indices
    positions = torch.arange(seq_len, dtype=torch.float32, device=embeddings.device).unsqueeze(1)

    # Compute theta
    theta = 10000 ** (-2 * (torch.arange(d_model // 2, dtype=torch.float32, device=embeddings.device) / d_model))
    theta = positions * theta  # Shape: (seq_len, d_model // 2)

    # Compute sin and cos values
    sin_vals = torch.sin(theta)
    cos_vals = torch.cos(theta)

    # Apply rotation
    embeddings_2d = embeddings.view(batch_size, seq_len, d_model // 2, 2)  # Reshape into 2D components
    x1, x2 = embeddings_2d[..., 0], embeddings_2d[..., 1]  # Split into two halves

    x1_new = x1 * cos_vals - x2 * sin_vals
    x2_new = x1 * sin_vals + x2 * cos_vals

    return torch.cat([x1_new, x2_new], dim=-1)  # Merge back into original shape

In [30]:
# 5. Byte Patch Embedding Layer
class DynamicBytePatchEmbedding(nn.Module):
    def __init__(self, vocab_size=256, embedding_dim=128):
        super().__init__()
        self.byte_embedding = nn.Embedding(vocab_size + 2, embedding_dim, padding_idx=0)

    def forward(self, byte_patches):
        """
        byte_patches: Tensor of shape (batch, seq_len, patch_size)
        """
        patch_embeds = self.byte_embedding(byte_patches) # (batch, seq_len, patch_size, embedding_dim)
        # print(f"Patch Embeddings Shape before Aggregation: {patch_embeds.shape}")
        patch_embeds = patch_embeds.sum(dim=-2) / torch.sqrt(torch.tensor(patch_embeds.shape[-2], dtype=torch.float))

        # print(f"Patch Embeddings Shape Aggregation: {patch_embeds.shape}")  # (batch, seq_len, embedding_dim)

        return patch_embeds # Ensure correct shape alignment

In [31]:
# 6. Transformer Encoder Model
class SMILESMLMTransformer(nn.Module):
    def __init__(self, vocab_size=256, embedding_dim=128, num_heads=8, num_layers=6, dropout=0.1):
        super().__init__()

        self.vocab_size = vocab_size + 2  # Expand vocab for special tokens
        self.embedding_dim = embedding_dim
        
        self.embedding = DynamicBytePatchEmbedding(self.vocab_size, embedding_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, dropout=dropout, dim_feedforward=512)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.lm_head = nn.Linear(embedding_dim, self.vocab_size)

    def forward(self, byte_patches):
        batch_size, seq_length, patch_size = byte_patches.shape # Ensure correct 3D shape
        # # byte_patches = byte_patches.view(batch_size, seq_length, patch_size)  # Keep structure 3D
        # print(f"Input Byte Patches Shape (MLM): {byte_patches.shape}")

        # Get embeddings
        x = self.embedding(byte_patches)  # Shape: (batch_size, seq_length, embedding_dim)
        # print(f"Embedding Output Shape (MLM): {x.shape}")

        # Apply RoPE before feeding into Transformer Encoder
        x = apply_rope(x)
        # print(f"Rope Embedding Output Shape (MLM): {x.shape}")

        # Pass through transformer
        x = self.encoder(x)  # Transformer processes sequence
        # print(f"Transformer Output Shape (MLM): {x.shape}")

        # Pass through language modeling head
        x = self.lm_head(x)
        # print(f"Logits Output Shape (MLM): {x.shape}")
    
        return x  # Output logits / Pass through linear layer (batch, seq_len, vocab_size)

In [32]:
# 7. Custom Dataset
class SMILESDataset(Dataset):
    def __init__(self, smiles_list, max_patch_len=128):
        self.data, self.labels, self.entropy_log = [], [], []

        # Define special token indices
        sos_token = 256  # Assign a unique integer for [SOS]
        eos_token = 257  # Assign a unique integer for [EOS]
        
        for smiles in tqdm(smiles_list, desc="Processing SMILES", leave=True):
            # Convert SMILES to byte patches
            # byte_patches, entropy_vals = dynamic_byte_patching(smiles, return_motif_sizes=False)
            byte_patches, entropy_vals, *motif_sizes = dynamic_byte_patching(smiles, return_motif_sizes=False)
            
            # ✅ Pass byte_patches as-is (no extra list wrapping)
            masked_patches, labels = mask_byte_patches(byte_patches)
            
            # Convert to tensor and pad sequences
            masked_patches = [torch.tensor(p, dtype=torch.long) for p in masked_patches]
            labels = [torch.tensor(l, dtype=torch.long) for l in labels]

            padded_patches = rnn.pad_sequence(masked_patches, batch_first=True, padding_value=0)
            padded_labels = rnn.pad_sequence(labels, batch_first=True, padding_value=-100)  # -100 for ignore_index in loss
            
            self.data.append(padded_patches) 
            self.labels.append(padded_labels) 
            self.entropy_log.append(torch.tensor(entropy_vals, dtype=torch.float)) # Convert to tensor / Store entropy per patch

        # ✅ Print debugging info AFTER tqdm completes  
        print("=====================###############==========================")
        print("Sample byte patches:", self.data[0] if self.data else "No data")
        print("Sample byte patches shape:", self.data[0].shape)
        print("=====================###############==========================")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx], self.entropy_log[idx]

In [33]:
def collate_fn(batch):
    data, labels, entropy_log = zip(*batch)  # Unzip batch data
    
    # Ensure all are padded to the longest sequence in batch
    max_seq_len = max([d.shape[0] for d in data])  # Get longest sequence
    max_patch_len = max([d.shape[1] for d in data])  # Get longest patch length
    
    data_padded = torch.zeros((len(batch), max_seq_len, max_patch_len), dtype=torch.long)
    labels_padded = torch.full((len(batch), max_seq_len, max_patch_len), -100, dtype=torch.long)
    entropy_padded = torch.zeros((len(batch), max_seq_len), dtype=torch.float)

    for i, (d, l, e) in enumerate(zip(data, labels, entropy_log)):
        seq_len, patch_len = d.shape
        data_padded[i, :seq_len, :patch_len] = d
        labels_padded[i, :seq_len, :patch_len] = l
        entropy_padded[i, :seq_len] = e

    return data_padded, labels_padded, entropy_padded

In [35]:
# 8. Training Loop with Logging
import sys

def train_with_entropy_tuning(model, dataloader, optimizer, criterion, scheduler, epochs=5, save_best_model=True):
    # Redirect print output to a log file
    log_file = open("training_log.txt", "w")
    sys.stdout = log_file

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Enable Multi-GPU with DataParallel
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        model = torch.nn.DataParallel(model)  # Wrap model with DataParallel
    
    model.to(device)  # Move model to GPU if available
    
    model.train()
    best_loss = float("inf")
    best_epoch = 0
    entropy_threshold = 1.5  # Initial base threshold
    log_losses, log_accuracies, log_entropy_thresholds = [], [], []

    try:
        for epoch in range(epochs):
            total_loss, correct, total = 0, 0, 0
            entropy_values = []

            # Use tqdm to track progress of batches within an epoch
            epoch_progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)

            for byte_patches, labels, entropy_log in epoch_progress:
                # Move the data to the appropriate device (GPU if available)
                byte_patches = byte_patches.to(device)
                labels = labels.to(device)
                entropy_log = [log.to(device) if isinstance(log, torch.Tensor) else log for log in entropy_log]

                optimizer.zero_grad()

                output = model(byte_patches) # Shape: (batch, seq_len, patch_size, vocab_size) / Forward pass
                output = output.reshape(-1, output.shape[-1])  # More robust
                
                labels = labels.view(output.shape[0], -1).squeeze(-1) # Ensure correct batch size
                labels = labels.argmax(dim=-1)  

                loss = criterion(output, labels)
                loss.backward()
                optimizer.step()
                scheduler.step()  # Update LR
                total_loss += loss.item()

                # Compute accuracy
                preds = torch.argmax(output, dim=-1)
                mask = labels != -100
                correct += (preds[mask] == labels[mask]).sum().item()
                total += mask.sum().item()

                # Track entropy values
                entropy_values.extend([np.mean(sublist.cpu().numpy()) for sublist in entropy_log if isinstance(sublist, torch.Tensor) and sublist.numel() > 0])

                # Update progress bar description with loss and accuracy
                epoch_progress.set_postfix(loss=total_loss / (epoch_progress.n + 1), accuracy=correct / total, entropy_threshold=entropy_threshold)

            # Adjust entropy threshold dynamically
            avg_entropy = np.mean(entropy_values) if len(entropy_values) > 0 else entropy_threshold
            entropy_threshold = 0.9 * entropy_threshold + 0.1 * avg_entropy  # Smooth update

            epoch_loss = total_loss / len(dataloader)
            epoch_accuracy = correct / total

            log_losses.append(epoch_loss)
            log_accuracies.append(epoch_accuracy)
            log_entropy_thresholds.append(entropy_threshold)

            print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, New Entropy Threshold: {entropy_threshold:.3f}")

            # Save best model
            if save_best_model and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_epoch = epoch + 1
                # torch.save(model.state_dict(), "best_model.pth")
                torch.save(model.module.state_dict() if torch.cuda.device_count() > 1 else model.state_dict(), "best_model.pth")
                print(f"Best model saved at epoch {best_epoch} with loss {best_loss:.4f}")

        print(f"Training complete. Best model was at epoch {best_epoch} with loss {best_loss:.4f}")

        # Plot entropy threshold evolution
        plot_entropy_threshold_evolution(log_entropy_thresholds, save_path="Entropy Threshold Evolution Across Training Epochs.png")

        # Plot Patch Size Distribution
        plot_patch_size_distribution(entropy_values, save_path="Distribution of Patch Entropies Across Molecules.png")

    except Exception as e:
        print(f"Training stopped due to error: {str(e)}", file=sys.stderr)  # Print error to stderr

    finally:
        # Reset stdout back to normal
        sys.stdout = sys.__stdout__
        log_file.close()
        print("Training log saved to training_log.txt")

In [36]:
# 9. Plot entropy threshold evolution
def plot_entropy_threshold_evolution(log_entropy_thresholds, save_path="Entropy Threshold Evolution Across Training Epochs.png"):
    """Visualize how entropy threshold evolves across epochs."""
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(log_entropy_thresholds) + 1), log_entropy_thresholds, marker='o', linestyle='-')
    plt.xlabel("Epoch")
    plt.ylabel("Entropy Threshold")
    plt.title("Entropy Threshold Evolution Across Training Epochs")
    plt.grid()
    plt.savefig(save_path, bbox_inches="tight", dpi=300) # Save the figure with high resolution
    print(f"Results on Entropy Threshold Evolution Across Training Epochs saved to {save_path}")
    plt.show()

In [37]:
# 10. Patch Size Distribution Visualization
def plot_patch_size_distribution(entropy_values, save_path="Distribution of Patch Entropies Across Molecules.png"):
    plt.figure(figsize=(8, 5))
    plt.hist(entropy_values, bins=20, color="green", alpha=0.7, edgecolor="black")
    plt.xlabel("Patch Entropy")
    plt.ylabel("Frequency")
    plt.title("Distribution of Patch Entropies Across Molecules")
    plt.savefig(save_path, bbox_inches="tight", dpi=300) # Save the figure with high resolution
    print(f"Results on Distribution of Patch Entropies Across Molecules saved to {save_path}")
    plt.show()

In [38]:
def plot_patch_size_comparison(aromatic_sizes, metal_sizes, charged_sizes, save_path="patch_size_comparison.png"):
    plt.figure(figsize=(8, 5))

    plt.hist(aromatic_sizes, bins=10, color="blue", alpha=0.6, label="Aromatic Patches", edgecolor="black")
    plt.hist(metal_sizes, bins=10, color="red", alpha=0.6, label="Metal Patches", edgecolor="black")
    plt.hist(charged_sizes, bins=10, color="green", alpha=0.6, label="Charged Patches", edgecolor="black")

    plt.xlabel("Patch Size")
    plt.ylabel("Frequency")
    plt.title("Comparison of Patch Sizes for Aromatic, Metal, and Charged Groups")
    plt.legend()
    
    plt.savefig(save_path, bbox_inches="tight", dpi=300)  # Save with high resolution
    print(f"Patch size comparison plot saved to {save_path}")
    print(f"Patch size comparison saved to {save_path}")
    
    plt.show()

In [39]:
def plot_patch_size_comparison_2(aromatic_sizes, metal_sizes, charged_sizes):
    """Plots a comparison of patch sizes for aromatic, metal, and charged groups."""
    
    # data = [
    #     ("Aromatic", size) for size in aromatic_sizes
    # ] + [
    #     ("Metal", size) for size in metal_sizes
    # ] + [
    #     ("Charged", size) for size in charged_sizes
    # ]
    data = []
    if aromatic_sizes:
        data.extend([("Aromatic", size) for size in aromatic_sizes])
    if metal_sizes:
        data.extend([("Metal", size) for size in metal_sizes])
    if charged_sizes:
        data.extend([("Charged", size) for size in charged_sizes])
    
    if not data:
        print("No patch sizes found for aromatic, metal, or charged groups.")
        return
    
    categories, patch_sizes = zip(*data)

    plt.figure(figsize=(8, 5))
    sns.boxplot(x=categories, y=patch_sizes, palette={"Aromatic": "green", "Metal": "red", "Charged": "blue"})
    
    plt.xlabel("Patch Type")
    plt.ylabel("Patch Size")
    plt.title("Comparison of Patch Sizes for Aromatic, Metal, and Charged Groups")
    
    plt.savefig("patch_size_comparison_2.png", dpi=300, bbox_inches="tight")
    plt.show()

In [40]:
# 11. Run Pretraining with Logging

# Ensure CUDA is used / Define device globally
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    smiles_list = smiles_sample
    
    dataset = SMILESDataset(smiles_list)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

    model = SMILESMLMTransformer().to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-2)

    # Compute training steps for scheduler
    num_training_steps = len(dataloader) * 5  # 5 epochs
    num_warmup_steps = int(0.1 * num_training_steps)  # 10% warmup steps
    
    # Learning rate scheduler with warmup
    scheduler = get_scheduler(
        "cosine",  # Can also use "linear"
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    train_with_entropy_tuning(model, dataloader, optimizer, criterion, scheduler, epochs=5, save_best_model=True)

    # Lists to store patch sizes for each motif
    all_aromatic_sizes = []
    all_metal_sizes = []
    all_charged_sizes = []

    for smiles in smiles_list:
        patches, entropy_vals, aromatic_sizes, metal_sizes, charged_sizes = dynamic_byte_patching(smiles, return_motif_sizes=True)
        
        all_aromatic_sizes.extend(aromatic_sizes)
        all_metal_sizes.extend(metal_sizes)
        all_charged_sizes.extend(charged_sizes)
        # print(f"Aromatic sizes: {all_aromatic_sizes}")
        # print(f"Metal sizes: {all_metal_sizes}")
        # print(f"Charged sizes: {all_charged_sizes}")

    # Plot the patch size comparison
    plot_patch_size_comparison(all_aromatic_sizes, all_metal_sizes, all_charged_sizes)

    plot_patch_size_comparison_2(all_aromatic_sizes, all_metal_sizes, all_charged_sizes)

Processing SMILES:   0%|          | 0/10000000 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Sample byte patches: tensor([[256, 255,  61,  67,  40,   0],
        [ 67,  40,  79,  41,   0,   0],
        [ 83,  40,  61, 255,   0,   0],
        [ 41,  40, 255,  79,   0,   0],
        [ 41,  67,  67,  79,  41,  78],
        [ 49,  67,  67,  61,  67,  40],
        [ 99,  50, 255,  40,   0,   0],
        [ 70,  41, 255,  99,   0,   0],
        [ 40,  78,  51,  67,   0,   0],
        [ 67,  40,  67,  79,   0,   0],
        [ 99,  52,  99, 255, 111, 195],
        [ 52,  41,  79,  67,   0,   0],
        [255, 255,  79,  41,   0,   0],
        [ 99, 255,  50,  70,   0,   0],
        [ 41,  67,  67,  49, 257,   0]])
Sample byte patches shape: torch.Size([15, 6])




Epoch 1/5:   0%|          | 0/156250 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter serve

DONE!!!!!!


In [41]:
print("DONE!!!!!!")

In [42]:
smiles_list

['O=C(C(O)S(=O)(=O)CCO)N1CC=C(c2c(F)cc(N3CC(COc4ccon4)OC3=O)cc2F)CC1',
 'C=CCOC(=O)NC(CC(=O)O)C(=O)CNCCCc1ccccc1',
 'CCOC(=O)C1CCCN(C1)C(=O)C2=CC=C(O2)CS(=O)C3=CC=CC=C3',
 'CC1=CC2=C(C=C1)N=C(N2)C3CN(C3)S(=O)(=O)C4=CC=C(C=C4)NC(=O)C',
 'CN1CCN(C(C1)C2=CC=CC=C2)C(=O)C3=C(C=CC(=C3)F)Br',
 'CNC(=O)C1=CC(=CN(C1=O)CC2=CC(=CC=C2)C#N)C(=O)NC3CC3',
 'CC(=O)Oc1ccc2c(c1)CC(CNS(=O)(=O)c1ccc(C)cc1)C2',
 'COc1ccc(C)c(-n2c(N)c(C(N)=O)c3nc4ccccc4nc32)c1',
 'COC1=C(C=C(C=C1)Cl)NC2=NC(=CS2)C3=CC4=C(C=CC(=C4)Br)OC3=O',
 'C1=CC=C2C(=C1)C(C3=CC=CC=C3O2)C(=O)NC4=CC(=C(C=C4)Cl)[N+](=O)[O-]',
 'Cc1nc(C2CCCCCN2S(=O)(=O)c2ccc3c(c2)CC(C)O3)no1',
 'CCC(=O)OC1=C(C=C(C=C1)C2C3=C(CC(CC3=O)(C)C)OC4=C2C(=O)CC(C4)(C)C)OC',
 'CC[NH+](CCc1ccncc1)C1=C(Sc2nc(C)cc(C)n2)C(=O)N(c2cccc(OC(F)(F)F)c2)C1=O',
 'CC(C)(CNC(=O)c1ccccc1OCC1CCCO1)[NH+]1CCOCC1',
 'CC(C)OCC(CN1CCN(CC1)C2=CC=CC=C2F)O',
 'CCC(C)C(=O)Nc1ccc(Cl)c(NC(=O)C(NC(=O)C2CC[NH+](CC(=O)Nc3cc(C(=O)N4CCOCC4)ccc3Cl)CC2)C(C)C)c1',
 'Cc1cc(C(=O)Nc2c(Br)cc(Br)cc2C(=O)[O-])

In [43]:
len(smiles_list)

10000000

In [44]:
smiles_sample

['O=C(C(O)S(=O)(=O)CCO)N1CC=C(c2c(F)cc(N3CC(COc4ccon4)OC3=O)cc2F)CC1',
 'C=CCOC(=O)NC(CC(=O)O)C(=O)CNCCCc1ccccc1',
 'CCOC(=O)C1CCCN(C1)C(=O)C2=CC=C(O2)CS(=O)C3=CC=CC=C3',
 'CC1=CC2=C(C=C1)N=C(N2)C3CN(C3)S(=O)(=O)C4=CC=C(C=C4)NC(=O)C',
 'CN1CCN(C(C1)C2=CC=CC=C2)C(=O)C3=C(C=CC(=C3)F)Br',
 'CNC(=O)C1=CC(=CN(C1=O)CC2=CC(=CC=C2)C#N)C(=O)NC3CC3',
 'CC(=O)Oc1ccc2c(c1)CC(CNS(=O)(=O)c1ccc(C)cc1)C2',
 'COc1ccc(C)c(-n2c(N)c(C(N)=O)c3nc4ccccc4nc32)c1',
 'COC1=C(C=C(C=C1)Cl)NC2=NC(=CS2)C3=CC4=C(C=CC(=C4)Br)OC3=O',
 'C1=CC=C2C(=C1)C(C3=CC=CC=C3O2)C(=O)NC4=CC(=C(C=C4)Cl)[N+](=O)[O-]',
 'Cc1nc(C2CCCCCN2S(=O)(=O)c2ccc3c(c2)CC(C)O3)no1',
 'CCC(=O)OC1=C(C=C(C=C1)C2C3=C(CC(CC3=O)(C)C)OC4=C2C(=O)CC(C4)(C)C)OC',
 'CC[NH+](CCc1ccncc1)C1=C(Sc2nc(C)cc(C)n2)C(=O)N(c2cccc(OC(F)(F)F)c2)C1=O',
 'CC(C)(CNC(=O)c1ccccc1OCC1CCCO1)[NH+]1CCOCC1',
 'CC(C)OCC(CN1CCN(CC1)C2=CC=CC=C2F)O',
 'CCC(C)C(=O)Nc1ccc(Cl)c(NC(=O)C(NC(=O)C2CC[NH+](CC(=O)Nc3cc(C(=O)N4CCOCC4)ccc3Cl)CC2)C(C)C)c1',
 'Cc1cc(C(=O)Nc2c(Br)cc(Br)cc2C(=O)[O-])

In [45]:
len(smiles_sample)

10000000