In [1]:
# !pip3 install rdkit-pypi
# !pip3 install deepchem

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score, f1_score, mean_squared_error, r2_score
from transformers import get_scheduler
import pandas as pd
import numpy as np
import os
import random
from rdkit import Chem
from tqdm import tqdm

from torch.nn.utils.rnn import pad_sequence

In [3]:
from rdkit import RDLogger  # Import RDKit Logger

# Suppress all RDKit warnings
RDLogger.DisableLog('rdApp.*')

In [4]:
# Load your pretrained model
import nbimporter
from Exp_6l_8h_192d_1024ff_100M_molbert_SMILES import SMILESMLMTransformer, dynamic_byte_patching, compute_entropy

In [5]:
class DownstreamDataset(Dataset):
    def __init__(self, smiles_list, labels, max_len=128):
        self.smiles_list = smiles_list
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        label = self.labels[idx]
        patches, entropies = dynamic_byte_patching(smiles)

        # Basic safeguards
        if not patches:
            patches = [[0]]
            entropies = [0.0]

        # # Optional truncation
        # patches = patches[:self.max_len]
        # entropies = entropies[:self.max_len]

        return patches, entropies, label

In [6]:
# def collate_finetune(batch, pad_value=0):
#     # Sort batch by patch length (descending) — optional for packed sequence use
#     batch.sort(key=lambda x: len(x[0]), reverse=True)

#     # # Debug print: shape inspection
#     # for p, _, _ in batch:
#     #     print(f"[DEBUG] Patch count: {len(p)} | Patch shapes: {[len(i) if hasattr(i, '__len__') else type(i) for i in p]}")

#     # Convert patches (List[List[int]]) into padded 3D tensor [B, P, L]
#     patch_tensors = [torch.tensor(p, dtype=torch.long) for p, _, _ in batch]
#     max_patches = max([len(p) for p in patch_tensors])
#     max_patch_len = max([len(subpatch) for p in patch_tensors for subpatch in p])
    
#     padded_patches = torch.full((len(batch), max_patches, max_patch_len), pad_value, dtype=torch.long)
#     for i, patch_seq in enumerate(patch_tensors):
#         for j, patch in enumerate(patch_seq):
#             padded_patches[i, j, :len(patch)] = patch

#     # Convert and pad entropy tensors to [B, P]
#     entropy_tensors = [torch.tensor(e, dtype=torch.float32) for _, e, _ in batch]
#     entropy_padded = pad_sequence(entropy_tensors, batch_first=True, padding_value=0.0)

#     # Stack label tensors [B, ...]
#     label_tensors = [torch.tensor(l, dtype=torch.float32) for _, _, l in batch]
#     label_tensor = torch.stack(label_tensors)

#     return padded_patches, entropy_padded, label_tensor

In [7]:
def collate_finetune(batch, pad_value=0):
    # Extract the components
    patch_lists = [item[0] for item in batch]  # List of List[List[int]]
    entropy_lists = [item[1] for item in batch]  # List of List[float]
    label_list = [item[2] for item in batch]  # List of labels (float or vector)

    # Determine max patch count and max patch length
    max_num_patches = max(len(p) for p in patch_lists)
    max_patch_len = max(len(patch) for p in patch_lists for patch in p)

    # Prepare padded patch tensor: [B, P, L]
    padded_patches = torch.full(
        (len(batch), max_num_patches, max_patch_len), pad_value, dtype=torch.long
    )
    for i, patch_seq in enumerate(patch_lists):
        for j, patch in enumerate(patch_seq):
            padded_patches[i, j, :len(patch)] = torch.tensor(patch, dtype=torch.long)

    # Pad entropy tensors to [B, P]
    entropy_tensors = [torch.tensor(e, dtype=torch.float32) for e in entropy_lists]
    entropy_padded = pad_sequence(entropy_tensors, batch_first=True, padding_value=0.0)

    # Stack labels to [B, ...]
    label_tensor = torch.stack([torch.tensor(l, dtype=torch.float32) for l in label_list])

    return padded_patches, entropy_padded, label_tensor

In [8]:
class FinetuneHead(nn.Module):
    def __init__(self, base_model, task="classification"):
        super().__init__()
        self.base_model = base_model
        self.task = task
        self.head = None  # delayed init

    def forward(self, x):
        x = self.base_model(x)  # [B, L, D]
        x = x.mean(dim=1)       # [B, D]
        if self.head is None:
            self.head = nn.Linear(x.size(-1), 1).to(x.device)  # lazy init
        return self.head(x).squeeze(-1)

In [9]:
def train_fold(model, optimizer, scheduler, criterion, dataloader, device):
    model.train()
    all_losses = []
    for xb, _, yb in dataloader:
        xb, yb = xb.to(device), yb.to(device)
        preds = model(xb)

        if criterion._get_name() in ["BCEWithLogitsLoss", "MSELoss"]:
            loss = criterion(preds, yb.float())
        else:
            loss = criterion(preds, yb.long())

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()
        all_losses.append(loss.item())

    return np.mean(all_losses)

In [10]:
def evaluate_fold(model, dataloader, task, device, multi_task=False):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for xb, _, yb in dataloader:
            xb = xb.to(device)
            outputs = model(xb).detach().cpu()
            preds.append(outputs)
            labels.append(yb)

    y_true = torch.cat(labels).numpy()
    y_pred = torch.cat(preds).numpy()

    if multi_task:
        aucs, f1s = [], []
        for i in range(y_true.shape[1]):
            if np.sum(~np.isnan(y_true[:, i])) == 0:
                aucs.append(np.nan)
                f1s.append(np.nan)
                continue
            mask = ~np.isnan(y_true[:, i])
            aucs.append(roc_auc_score(y_true[mask, i], y_pred[mask, i]))
            f1s.append(f1_score(y_true[mask, i], y_pred[mask, i] > 0.5))
        return np.array(aucs), np.array(f1s)

    if task == "classification":
        if y_pred.ndim == 1 or y_pred.shape[1] == 1:
            # Binary classification
            y_prob = torch.sigmoid(torch.tensor(y_pred)).numpy()
            y_pred_cls = (y_prob >= 0.5).astype(int)
        else:
            # Multiclass classification
            y_prob = torch.softmax(torch.tensor(y_pred), dim=1).numpy()
            y_pred_cls = np.argmax(y_pred, axis=1)
            y_prob = y_prob[:, 1]  # Assuming class 1 is the positive class

        auc = roc_auc_score(y_true, y_prob)
        f1 = f1_score(y_true, y_pred_cls)
        return auc, f1

    else:
        # Regression
        rmse = mean_squared_error(y_true, y_pred, squared=False)
        r2 = r2_score(y_true, y_pred)
        return rmse, r2

In [11]:
# class SMILESMLMTransformer(nn.Module):
#     def __init__(self, vocab_size=256, embedding_dim=192, num_heads=8, num_layers=6, dropout=0.1):
#         super().__init__()

#         self.vocab_size = vocab_size + 2  # Expand vocab for special tokens
#         self.embedding_dim = embedding_dim
        
#         self.embedding = DynamicBytePatchEmbedding(self.vocab_size, embedding_dim)
#         self.encoder_layer = nn.TransformerEncoderLayer(
#             d_model=embedding_dim,
#             nhead=num_heads,
#             dropout=dropout,
#             dim_feedforward=1024,
#             batch_first=True
#         )
#         self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
#         self.lm_head = nn.Linear(embedding_dim, self.vocab_size)  # for pretraining

#     def forward(self, x, entropy=None, attention_mask=None):
#         x = self.embedding(x)

#         if attention_mask is not None:
#             x = self.encoder(x, attention_mask=attention_mask)
#         else:
#             x = self.encoder(x)

#         if entropy is not None:
#             entropy = entropy.unsqueeze(-1)
#             x = x * (1 + entropy)  # entropy-aware reweighting

#         return x  # let FinetuneHead handle classification

In [12]:
def run_finetuning(task_name="BBBP", path="best_model_100M.pth", num_folds=10, batch_size=64, epochs=10, dataset_path="path_to_your_dataset.csv"):
    # Load the dataset directly from CSV
    dataset = pd.read_csv(dataset_path, encoding='ISO-8859-1')
    dataset.head()
    
    # Extract SMILES and labels (assuming the dataset has columns 'smiles' and 'label')
    smiles = dataset['smiles'].tolist()
    labels = dataset['label'].tolist()

    # Define task type (classification or regression)
    if task_name in {"BBBP", "Tox21"}:
        task_type = "classification"
    else:
        task_type = "regression"

    # Handle multi-task classification (e.g., Tox21)
    multi_task = True if task_name == "Tox21" else False

    # Set up K-Fold cross-validation
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_scores = []

    for fold, (train_idx, test_idx) in enumerate(kf.split(smiles)):
        print(f"\n--- Fold {fold + 1}/{num_folds} ---")

        # Initialize the base model and load the pretrained weights
        base_model = SMILESMLMTransformer()
        state_dict = torch.load(path, weights_only=True)
        base_model.load_state_dict(state_dict)
        # base_model.load_state_dict(torch.load(path))
        model = FinetuneHead(base_model, task=task_type).to(device)

        # Get train and test labels
        train_labels = [labels[i] for i in train_idx]
        test_labels = [labels[i] for i in test_idx]

        # Prepare datasets and dataloaders
        train_dataset = DownstreamDataset([smiles[i] for i in train_idx], train_labels)
        test_dataset = DownstreamDataset([smiles[i] for i in test_idx], test_labels)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_finetune)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_finetune)

        # Optimizer, scheduler, and loss function
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
        scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=epochs * len(train_loader))
        criterion = nn.BCEWithLogitsLoss() if task_type == "classification" else nn.MSELoss()

        # Training and evaluation loop
        for epoch in range(epochs):
            loss = train_fold(model, optimizer, scheduler, criterion, train_loader, device)
            metrics = evaluate_fold(model, test_loader, task_type, device, multi_task=multi_task)
            print(f"Epoch {epoch+1:02d} | Loss: {loss:.4f} | Metrics: {metrics}")

        all_scores.append(metrics)

    # Final results
    print(f"\nFinal {task_name} {task_type.upper()} Results over {num_folds} folds:")
    scores_np = np.array(all_scores)
    if multi_task:
        auc_mean = np.nanmean(scores_np[:, 0])
        f1_mean = np.nanmean(scores_np[:, 1])
        print(f"Mean AUC (macro): {auc_mean:.4f}")
        print(f"Mean F1  (macro): {f1_mean:.4f}")
    elif task_type == "classification":
        print(f"AUC: {scores_np[:, 0].mean():.4f} ± {scores_np[:, 0].std():.4f}")
        print(f"F1 : {scores_np[:, 1].mean():.4f} ± {scores_np[:, 1].std():.4f}")
    else:
        print(f"RMSE: {scores_np[:, 0].mean():.4f} ± {scores_np[:, 0].std():.4f}")
        print(f"R²  : {scores_np[:, 1].mean():.4f} ± {scores_np[:, 1].std():.4f}")

In [13]:
if __name__ == "__main__":
    run_finetuning("BBBP", dataset_path="Covid_19_data.csv")  # Or "ESOL", "Tox21", "Lipophilicity"


--- Fold 1/10 ---
Epoch 01 | Loss: 1.0989 | Metrics: (0.6765723833393006, 0.608)
Epoch 02 | Loss: 0.5730 | Metrics: (0.7478816087838644, 0.5494505494505495)
Epoch 03 | Loss: 0.5220 | Metrics: (0.7963360782909655, 0.7031250000000001)
Epoch 04 | Loss: 0.4630 | Metrics: (0.8156701277002028, 0.7203065134099617)
Epoch 05 | Loss: 0.4368 | Metrics: (0.8215180809165771, 0.736)
Epoch 06 | Loss: 0.4129 | Metrics: (0.8264709392528943, 0.73109243697479)
Epoch 07 | Loss: 0.3972 | Metrics: (0.8305883757011577, 0.7398373983739839)
Epoch 08 | Loss: 0.3721 | Metrics: (0.8261129013008712, 0.7394957983193277)
Epoch 09 | Loss: 0.3413 | Metrics: (0.8243227115407566, 0.7389558232931727)
Epoch 10 | Loss: 0.3393 | Metrics: (0.8243227115407566, 0.7398373983739839)

--- Fold 2/10 ---
Epoch 01 | Loss: 2.6930 | Metrics: (0.7244378698224853, 0.6753246753246753)
Epoch 02 | Loss: 0.6175 | Metrics: (0.7923076923076923, 0.6916666666666667)
Epoch 03 | Loss: 0.5550 | Metrics: (0.8199408284023669, 0.7272727272727272)
Ep