In [1]:
!pip install transformers torchmetrics rdkit-pypi


Collecting torchmetrics
  Downloading torchmetrics-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting rdkit-pypi
  Downloading rdkit_pypi-2022.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.0-py3-none-any.whl.metadata (5.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0

In [2]:
from google.colab import drive
drive.mount('/content/drive')

import pandas as pd
from pathlib import Path

DATA_DIR = Path('/content/drive/MyDrive/ProteinMO/Dataset/Cleaned')

train_df = pd.read_csv(DATA_DIR /'train.csv')
val_df   = pd.read_csv(DATA_DIR / 'val.csv')
test_df  = pd.read_csv(DATA_DIR / 'test.csv')

print("Train:", train_df.shape, "Val:", val_df.shape, "Test:", test_df.shape)
print("Columns:", train_df.columns)


Mounted at /content/drive
Train: (1033794, 7) Val: (121461, 7) Test: (107098, 7)
Columns: Index(['smiles', 'ic50_nM', 'protein_seq', 'protein_name', 'uniprot_id',
       'protein_desc', 'label'],
      dtype='object')


In [3]:
from transformers import AutoTokenizer, AutoModel
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1) Protein sequence encoder (ESM2)
esm_model_name = "facebook/esm2_t6_8M_UR50D"
esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name, do_lower_case=False)
esm_model = AutoModel.from_pretrained(esm_model_name).to(device)
esm_model.eval()  # freeze for now

# 2) Molecule encoder (ChemBERTa)
chem_model_name = "seyonec/ChemBERTa-zinc-base-v1"
chem_tokenizer = AutoTokenizer.from_pretrained(chem_model_name)
chem_model = AutoModel.from_pretrained(chem_model_name).to(device)
chem_model.eval()  # freeze for now

# 3) Protein description encoder (BioBERT / PubMedBERT)
desc_model_name = "dmis-lab/biobert-base-cased-v1.1"   # or "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
desc_tokenizer = AutoTokenizer.from_pretrained(desc_model_name)
desc_model = AutoModel.from_pretrained(desc_model_name).to(device)
desc_model.eval()  # freeze for now



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/166 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/501 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/179M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/179M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [4]:
from torch.utils.data import Dataset
import torch

class ProteinMolDataset(Dataset):
    def __init__(self, df, esm_tokenizer, chem_tokenizer, desc_tokenizer,
                 max_prot_len=512, max_smiles_len=128, max_desc_len=128):
        self.df = df.reset_index(drop=True)
        self.esm_tokenizer = esm_tokenizer
        self.chem_tokenizer = chem_tokenizer
        self.desc_tokenizer = desc_tokenizer
        self.max_prot_len = max_prot_len
        self.max_smiles_len = max_smiles_len
        self.max_desc_len = max_desc_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        prot_seq = row['protein_seq']
        smiles = row['smiles']
        prot_desc = row['protein_desc']
        label = torch.tensor(row['label'], dtype=torch.float)

        prot_inputs = self.esm_tokenizer(
            prot_seq, truncation=True, max_length=self.max_prot_len,
            padding='max_length', return_tensors='pt'
        )
        chem_inputs = self.chem_tokenizer(
            smiles, truncation=True, max_length=self.max_smiles_len,
            padding='max_length', return_tensors='pt'
        )
        desc_inputs = self.desc_tokenizer(
            prot_desc if isinstance(prot_desc, str) else "",
            truncation=True, max_length=self.max_desc_len,
            padding='max_length', return_tensors='pt'
        )

        return prot_inputs, chem_inputs, desc_inputs, label



In [5]:
import torch
import torch.nn as nn

class FusionModel(nn.Module):
    def __init__(self, esm_model, chem_model, desc_model, hidden_size=256, freeze_backbones=True):
        super().__init__()
        self.esm = esm_model
        self.chem = chem_model
        self.desc = desc_model

        if freeze_backbones:  # freeze by default
            for p in self.esm.parameters():  p.requires_grad = False
            for p in self.chem.parameters(): p.requires_grad = False
            for p in self.desc.parameters(): p.requires_grad = False

        esm_dim  = self.esm.config.hidden_size
        chem_dim = self.chem.config.hidden_size
        desc_dim = self.desc.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Linear(esm_dim + chem_dim + desc_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, prot_inputs, chem_inputs, desc_inputs):
        # if frozen, wrapping in no_grad saves memory; if you unfreeze, remove the context manager
        with torch.no_grad():
            prot_emb = self.esm(**prot_inputs).last_hidden_state[:, 0, :]   # [CLS]
            chem_emb = self.chem(**chem_inputs).last_hidden_state[:, 0, :]
            desc_emb = self.desc(**desc_inputs).last_hidden_state[:, 0, :]

        x = torch.cat([prot_emb, chem_emb, desc_emb], dim=1)
        logits = self.classifier(x).squeeze(1)
        return logits




In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FusionModel(esm_model, chem_model, desc_model, hidden_size=256, freeze_backbones=True).to(device)


In [7]:
from torch.utils.data import DataLoader
import torch

def collate_fn(batch):
    prot_batch = {k: torch.cat([item[0][k] for item in batch], dim=0) for k in batch[0][0]}
    chem_batch = {k: torch.cat([item[1][k] for item in batch], dim=0) for k in batch[0][1]}
    desc_batch = {k: torch.cat([item[2][k] for item in batch], dim=0) for k in batch[0][2]}
    labels = torch.stack([item[3] for item in batch])
    return prot_batch, chem_batch, desc_batch, labels

BATCH_SIZE = 64  # drop to 8/4 if you hit OOM

train_dataset = ProteinMolDataset(train_df, esm_tokenizer, chem_tokenizer, desc_tokenizer)
val_dataset   = ProteinMolDataset(val_df,   esm_tokenizer, chem_tokenizer, desc_tokenizer)
test_dataset  = ProteinMolDataset(test_df,  esm_tokenizer, chem_tokenizer, desc_tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)



In [8]:
import torch.nn as nn
import torch

# class weights from TRAIN ONLY
pos = (train_df['label'] == 1).sum()
neg = (train_df['label'] == 0).sum()
pos_weight_value = neg / pos  # BCEWithLogitsLoss wants weight for the positive class
print("pos_weight:", pos_weight_value)

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value], device=device))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


pos_weight: 0.1737002426206206


In [9]:
import os
import re
import torch
from tqdm import tqdm
from torchmetrics.classification import BinaryAUROC
from torch.cuda.amp import autocast, GradScaler

# ----------------- Config -----------------
EPOCHS = 3
SAVE_DIR = '/content/drive/MyDrive/ProteinMO/checkpoints'
BEST_PATH = os.path.join(SAVE_DIR, 'best_model.pt')
SAVE_EVERY_STEPS = 5000   # save mid-epoch
ACCUM_STEPS = 8           # gradient accumulation steps
os.makedirs(SAVE_DIR, exist_ok=True)

# Mixed precision scaler
scaler = GradScaler()

# ----------------- Checkpoint Utils -----------------
CKPT_REGEX = re.compile(r'checkpoint_epoch(\d+)_step(\d+)\.pt')

def find_latest_ckpt(save_dir: str):
    last_epoch, last_step = -1, -1
    last_path = None
    for fname in os.listdir(save_dir):
        m = CKPT_REGEX.match(fname)
        if not m:
            continue
        epoch_i, step_i = int(m.group(1)), int(m.group(2))
        if (epoch_i > last_epoch) or (epoch_i == last_epoch and step_i > last_step):
            last_epoch, last_step = epoch_i, step_i
            last_path = os.path.join(save_dir, fname)
    return last_path, last_epoch, last_step

def save_ckpt(path, epoch, step_in_epoch, global_step, best_val_auc, model, optimizer, scaler):
    torch.save({
        'epoch': epoch,
        'step_in_epoch': step_in_epoch,
        'global_step': global_step,
        'best_val_auc': best_val_auc,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scaler_state': scaler.state_dict(),
    }, path)

def load_ckpt(path, model, optimizer, scaler, device):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    optimizer.load_state_dict(ckpt['optimizer_state'])
    scaler.load_state_dict(ckpt['scaler_state'])
    return (ckpt['epoch'], ckpt['step_in_epoch'], ckpt['global_step'], ckpt.get('best_val_auc', 0.0))

# ----------------- Train / Eval -----------------
def train_one_epoch(model, loader, optimizer, criterion, device,
                    epoch, start_step_in_epoch=0, global_step=0):
    model.train()
    total_loss = 0.0
    step_in_epoch = 0

    optimizer.zero_grad()

    for step_in_epoch, (prot_inputs, chem_inputs, desc_inputs, labels) in enumerate(
        tqdm(loader, desc=f"Train E{epoch}", leave=False), start=1):

        # Skip already-completed steps when resuming
        if step_in_epoch <= start_step_in_epoch:
            continue

        prot_inputs = {k: v.to(device) for k, v in prot_inputs.items()}
        chem_inputs = {k: v.to(device) for k, v in chem_inputs.items()}
        desc_inputs = {k: v.to(device) for k, v in desc_inputs.items()}
        labels = labels.to(device)

        with autocast():
            logits = model(prot_inputs, chem_inputs, desc_inputs)
            loss = criterion(logits, labels) / ACCUM_STEPS  # scale for accumulation

        scaler.scale(loss).backward()
        total_loss += loss.item() * labels.size(0) * ACCUM_STEPS

        if (step_in_epoch % ACCUM_STEPS) == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1

            if global_step % (SAVE_EVERY_STEPS // ACCUM_STEPS) == 0:
                ckpt_path = os.path.join(SAVE_DIR, f"checkpoint_epoch{epoch}_step{step_in_epoch}.pt")
                save_ckpt(ckpt_path, epoch, step_in_epoch, global_step, best_val_auc, model, optimizer, scaler)
                print(f"[E{epoch} S{step_in_epoch}] checkpoint -> {ckpt_path}")

    avg_loss = total_loss / len(loader.dataset)
    return avg_loss, step_in_epoch, global_step

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    auroc = BinaryAUROC().to(device)

    with torch.no_grad():
        for prot_inputs, chem_inputs, desc_inputs, labels in tqdm(loader, desc="Eval", leave=False):
            prot_inputs = {k: v.to(device) for k, v in prot_inputs.items()}
            chem_inputs = {k: v.to(device) for k, v in chem_inputs.items()}
            desc_inputs = {k: v.to(device) for k, v in desc_inputs.items()}
            labels = labels.to(device)

            with autocast():
                logits = model(prot_inputs, chem_inputs, desc_inputs)
                loss = criterion(logits, labels)

            total_loss += loss.item() * labels.size(0)
            auroc.update(torch.sigmoid(logits), labels.int())

    return total_loss / len(loader.dataset), auroc.compute().item()

# ----------------- Auto Resume -----------------
start_epoch, start_step_in_epoch, global_step, best_val_auc = 1, 0, 0, 0.0
latest_ckpt, _, _ = find_latest_ckpt(SAVE_DIR)
if latest_ckpt:
    print(f"Resuming from: {latest_ckpt}")
    start_epoch, start_step_in_epoch, global_step, best_val_auc = load_ckpt(latest_ckpt, model, optimizer, scaler, device)
    print(f" -> epoch={start_epoch}, step_in_epoch={start_step_in_epoch}, global_step={global_step}, best_val_auc={best_val_auc}")
else:
    print("No checkpoint found. Starting fresh.")

# ----------------- Train Loop -----------------
for epoch in range(start_epoch, EPOCHS + 1):
    train_loss, last_step_in_epoch, global_step = train_one_epoch(
        model, train_loader, optimizer, criterion, device,
        epoch, start_step_in_epoch=start_step_in_epoch, global_step=global_step
    )
    start_step_in_epoch = 0  # reset after resuming

    val_loss, val_auc = evaluate(model, val_loader, criterion, device)
    print(f"[Epoch {epoch}/{EPOCHS}] train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_auc={val_auc:.4f}")

    end_ckpt_path = os.path.join(SAVE_DIR, f"checkpoint_epoch{epoch}_step{last_step_in_epoch}.pt")
    save_ckpt(end_ckpt_path, epoch, last_step_in_epoch, global_step, best_val_auc, model, optimizer, scaler)
    print(f"Epoch checkpoint saved -> {end_ckpt_path}")

    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), BEST_PATH)
        print(f"✓ New best AUC={best_val_auc:.4f}. Saved -> {BEST_PATH}")

# ----------------- Final Test -----------------
if os.path.exists(BEST_PATH):
    model.load_state_dict(torch.load(BEST_PATH, map_location=device))
    print("Loaded best model for test eval.")

test_loss, test_auc = evaluate(model, test_loader, criterion, device)
print(f"TEST loss={test_loss:.4f} | auc={test_auc:.4f}")


  scaler = GradScaler()


Resuming from: /content/drive/MyDrive/ProteinMO/checkpoints/checkpoint_epoch3_step7696.pt
 -> epoch=3, step_in_epoch=7696, global_step=5000, best_val_auc=0.812853991985321


  with autocast():
Train E3:  79%|███████▊  | 12696/16154 [56:18<41:25,  1.39it/s]

[E3 S12696] checkpoint -> /content/drive/MyDrive/ProteinMO/checkpoints/checkpoint_epoch3_step12696.pt


  with autocast():


[Epoch 3/3] train_loss=0.0712 | val_loss=0.1671 | val_auc=0.8128
Epoch checkpoint saved -> /content/drive/MyDrive/ProteinMO/checkpoints/checkpoint_epoch3_step16154.pt
Loaded best model for test eval.


                                                         

TEST loss=0.1551 | auc=0.8110




In [13]:
import pandas as pd
import torch
from tqdm import tqdm

# 🔹 Load your original test dataset
test_df = pd.read_csv("/content/drive/MyDrive/ProteinMO/Dataset/Cleaned/test.csv")

# 🔹 Make predictions
model.eval()
all_probs = []
all_labels = []

with torch.no_grad():
    for prot_inputs, chem_inputs, desc_inputs, labels in tqdm(test_loader, desc="Predicting"):
        prot_inputs = {k: v.to(device) for k, v in prot_inputs.items()}
        chem_inputs = {k: v.to(device) for k, v in chem_inputs.items()}
        desc_inputs = {k: v.to(device) for k, v in desc_inputs.items()}
        labels = labels.to(device)

        with autocast():
            logits = model(prot_inputs, chem_inputs, desc_inputs)
            probs = torch.sigmoid(logits)

        all_probs.extend(probs.cpu().numpy().tolist())
        all_labels.extend(labels.cpu().numpy().tolist())

# 🔹 Sanity check: lengths must match
assert len(test_df) == len(all_probs), "Mismatch between test_df and predictions!"

# 🔹 Append to original DataFrame
test_df["predicted_prob"] = all_probs
test_df["true_label"] = all_labels

# 🔹 Save
test_df.to_csv("/content/drive/MyDrive/ProteinMO/test_with_predictions.csv", index=False)
print("✅ Saved: test_with_predictions.csv")


  with autocast():
Predicting: 100%|██████████| 1674/1674 [07:45<00:00,  3.60it/s]


✅ Saved: test_with_predictions.csv


In [17]:
from sklearn.metrics import classification_report

y_prob = test_df["predicted_prob"].values
y_true = test_df["true_label"].values

for threshold in [0.5, 0.4, 0.3]:
    y_pred = (y_prob >= threshold).astype(int)
    print(f"\n🔎 Threshold = {threshold}")
    print(classification_report(y_true, y_pred, digits=4))



🔎 Threshold = 0.5
              precision    recall  f1-score   support

         0.0     0.3300    0.7388    0.4562     14848
         1.0     0.9475    0.7585    0.8425     92250

    accuracy                         0.7558    107098
   macro avg     0.6387    0.7486    0.6494    107098
weighted avg     0.8619    0.7558    0.7890    107098


🔎 Threshold = 0.4
              precision    recall  f1-score   support

         0.0     0.3881    0.6425    0.4839     14848
         1.0     0.9357    0.8370    0.8836     92250

    accuracy                         0.8100    107098
   macro avg     0.6619    0.7397    0.6837    107098
weighted avg     0.8598    0.8100    0.8282    107098


🔎 Threshold = 0.3
              precision    recall  f1-score   support

         0.0     0.4313    0.5217    0.4722     14848
         1.0     0.9203    0.8893    0.9045     92250

    accuracy                         0.8383    107098
   macro avg     0.6758    0.7055    0.6884    107098
weighted avg     

In [14]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve, confusion_matrix, ConfusionMatrixDisplay

# Load predictions
df = pd.read_csv("/content/drive/MyDrive/ProteinMO/test_with_predictions.csv")
y_true = df["true_label"]
y_prob = df["predicted_prob"]

# Create a directory for plots
os.makedirs("/content/drive/MyDrive/ProteinMO/plots", exist_ok=True)

# 1. Histogram of predicted probabilities
plt.figure()
plt.hist(y_prob, bins=50, alpha=0.7, color='blue')
plt.title("Histogram of Predicted Probabilities")
plt.xlabel("Predicted Probability")
plt.ylabel("Frequency")
plt.savefig("/content/drive/MyDrive/ProteinMO/plots/histogram_probs.png")
plt.close()

# 2. ROC Curve
fpr, tpr, _ = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
plt.plot([0, 1], [0, 1], 'k--')
plt.title("ROC Curve")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.savefig("/content/drive/MyDrive/ProteinMO/plots/roc_curve.png")
plt.close()

# 3. Precision-Recall Curve
precision, recall, _ = precision_recall_curve(y_true, y_prob)
plt.figure()
plt.plot(recall, precision)
plt.title("Precision-Recall Curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.savefig("/content/drive/MyDrive/ProteinMO/plots/pr_curve.png")
plt.close()

# 4. Confusion Matrix at threshold=0.5
y_pred = (y_prob >= 0.5).astype(int)
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.title("Confusion Matrix (Threshold = 0.5)")
plt.savefig("/content/drive/MyDrive/ProteinMO/plots/confusion_matrix.png")
plt.close()

print("✅ Saved plots to /ProteinMO/plots/")


✅ Saved plots to /ProteinMO/plots/


In [23]:
readme_path = "/content/drive/MyDrive/ProteinMO/README.md"

readme_content = """# 🧪 Protein-Molecule Binding Prediction

This project uses a multimodal deep learning model to predict protein–molecule interactions using protein sequences, molecule SMILES, and textual descriptions.

## 📁 Dataset

- Source: BindingDB + UniProt
- Size:
  - Training samples: `964,884`
  - Validation samples: `120,000`
  - Test samples: 107,098

## 🧠 Model Architecture

- Three encoders (frozen):
  - Protein: ESM2 (e.g., esm2_t6_8M_UR50D)
  - Molecule: ChemBERTa
  - Description: BioBERT
- Fusion: concatenation → Linear → ReLU → Dropout → Linear
- Loss: Binary Cross Entropy **with class weights** (because the datset is heavily labeled "1")
- Optimization: Adam with gradient accumulation
- Mixed precision training enabled

## 📊 Evaluation Metrics

| Threshold | Accuracy | F1 (Class 0) | F1 (Class 1) | Macro F1 | Weighted F1 | AUC |
|-----------|----------|--------------|--------------|----------|-------------|------|
| **0.5**   | 0.7558   | 0.4562       | 0.8425       | 0.6494   | 0.7890      | 0.8129 |
| **0.4**   | 0.8100   | 0.4839       | 0.8836       | 0.6837   | 0.8282      |        |
| **0.3**   | 0.8383   | 0.4722       | 0.9045       | 0.6884   | 0.8446      |        |

## 📈 Visualizations

All plots are saved under `/plots/`:

- [x] Histogram of predicted probabilities
- [x] ROC Curve
- [x] Precision–Recall Curve
- [x] Confusion Matrix

![ROC Curve](plots/roc_curve.png)
![PR Curve](plots/pr_curve.png)

## 🛠 Training Info

- Trained on Google Colab (T4 / A100)
- Epochs: 3
- Batch size: 32–64 with accumulation
- Checkpoints every 5000 steps with resume support

## 📎 Files

- `test_with_predictions.csv`: includes predicted probability and ground truth
- `checkpoints/`: saved models
- `README.md`: project overview

## 🔮 Future Work

- Unfreeze transformer layers for fine-tuning
- Try attention-based fusion
- Evaluate on external datasets

---

## 💡 Citation

If you use this work, please consider citing or referencing it in your own projects. This model aims to support downstream tasks in drug discovery.
"""

# Save to README.md
with open(readme_path, "w") as f:
    f.write(readme_content)

print("✅ README.md created at:", readme_path)


✅ README.md created at: /content/drive/MyDrive/ProteinMO/README.md
