In [1]:
#!/usr/bin/env python3
"""
Adapted classifier using custom ProteinLM + advanced MLP with augmentation, PCA, and weighted sampling.
Achieves high accuracy by incorporating data augmentation, class balancing, and optimized training.
"""

import os
import time
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader, TensorDataset, WeightedRandomSampler
from torch.optim import AdamW
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import joblib  # for saving label encoder & PCA
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------------
# CONFIGURATION (from high-accuracy setup)
# -----------------------------
DATA_PATH = "/kaggle/input/datasetmlpclass/shuffled_standardized_add.xlsx"  # Adjusted for xlsx
SEQ_COL = "SubSequence"
LABEL_COL = "Clinical_Significance"
MAX_TOKEN_LENGTH = 512
BATCH_EMBED = 8
PCA_DIM = 512
TEST_SIZE = 0.2
RANDOM_SEED = 42

USE_AUGMENT = True
AUG_SHIFTS = [-2, 0, 2]
WINDOW_LEN = 50

HIDDEN_DIM_1 = 512
HIDDEN_DIM_2 = 256
DROPOUT = 0.4
LR = 1e-4
WEIGHT_DECAY = 1e-3
BATCH_TRAIN = 32
NUM_EPOCHS = 100
PATIENCE = 25
LR_PLATEAU_FACTOR = 0.5
LR_PLATEAU_PATIENCE = 5

USE_SAMPLER = True

# Custom ProteinLM checkpoint
CKPT_PATH = "/kaggle/input/progenmodel/final_ckpt.pt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -----------------------------
# 1) Tokenizer / Vocabulary (Custom)
# -----------------------------
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")
SPECIAL_TOKENS = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]
VOCAB = SPECIAL_TOKENS + AMINO_ACIDS
token2idx = {tok: idx for idx, tok in enumerate(VOCAB)}
idx2token = {idx: tok for tok, idx in token2idx.items()}
PAD_ID = token2idx["<PAD>"]
SOS_ID = token2idx["<SOS>"]
EOS_ID = token2idx["<EOS>"]
UNK_ID = token2idx["<UNK>"]
VOCAB_SIZE = len(VOCAB)

def tokenize(seq: str):
    ids = [SOS_ID]
    for ch in seq.strip():
        ids.append(token2idx.get(ch, UNK_ID))
    ids.append(EOS_ID)
    return ids

# -----------------------------
# 2) Load Data
# -----------------------------
df = pd.read_excel(DATA_PATH)
df.columns = [c.strip().replace(" ", "_") for c in df.columns]
assert SEQ_COL in df.columns and LABEL_COL in df.columns, "Required columns missing"
df = df.dropna(subset=[SEQ_COL, LABEL_COL]).reset_index(drop=True)
print("Total samples:", len(df))
print(df[LABEL_COL].value_counts())

sequences_orig = df[SEQ_COL].astype(str).tolist()
labels_text_orig = df[LABEL_COL].astype(str).tolist()

# -----------------------------
# 3) Augmentation (Sliding Windows)
# -----------------------------
def sliding_windows(seq, center_len=WINDOW_LEN, shifts=AUG_SHIFTS):
    L = len(seq)
    windows = []
    for s in shifts:
        mid = L//2 + s
        start = max(0, mid - center_len//2)
        end = start + center_len
        if end > L:
            end = L
            start = max(0, end - center_len)
        w = seq[start:end]
        if len(w) == center_len:
            windows.append(w)
    return windows

if USE_AUGMENT:
    sequences, labels_text = [], []
    for seq, lab in zip(sequences_orig, labels_text_orig):
        wins = sliding_windows(seq)
        if len(wins) == 0:
            s = seq[:WINDOW_LEN].ljust(WINDOW_LEN, 'A')  # Pad with 'A' if too short
            sequences.append(s); labels_text.append(lab)
        else:
            for w in wins:
                sequences.append(w); labels_text.append(lab)
    print("After augmentation: total samples =", len(sequences))
else:
    sequences = sequences_orig
    labels_text = labels_text_orig
    print("No augmentation: total samples =", len(sequences))

# -----------------------------
# 4) Label Encoding
# -----------------------------
label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(labels_text)
print("Classes:", label_encoder.classes_)

# -----------------------------
# 5) Custom ProteinLM (from previous)
# -----------------------------
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_head, attn_dropout=0.1):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(attn_dropout)
    def forward(self, x, mask=None):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        att = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        causal_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)
        att = att.masked_fill(causal_mask == 0, float("-inf"))
        if mask is not None:
            mask2 = mask.unsqueeze(1).unsqueeze(2)
            att = att.masked_fill(mask2 == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        out = torch.matmul(att, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
    def forward(self, x, mask):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x

class ProteinLM(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, d_model=256, nhead=8, num_layers=6, d_ff=1024, max_len=1024, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([TransformerBlock(d_model, nhead, d_ff, dropout) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.max_len = max_len
        self.d_model = d_model

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.size()
        positions = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        for layer in self.layers:
            x = layer(x, attention_mask)
        x = self.ln_f(x)
        return x  # (B, T, d_model)

# Load pre-trained custom ProteinLM
base_model = ProteinLM()
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
base_model.load_state_dict(ckpt["model_state"], strict=False)
base_model.to(DEVICE)
base_model.eval()

# -----------------------------
# 6) Embedding Extraction (Adapted for Custom Model)
# -----------------------------
@torch.no_grad()
def embed_batch(seqs, batch_size=BATCH_EMBED, max_length=MAX_TOKEN_LENGTH):
    all_embs = []
    base_model.eval()
    for i in range(0, len(seqs), batch_size):
        batch = seqs[i:i+batch_size]
        # Tokenize batch
        tokenized = []
        attn_masks = []
        for seq in batch:
            ids = tokenize(seq)
            if len(ids) > max_length:
                ids = ids[:max_length-1] + [EOS_ID]
            # Pad to max_length
            padded_ids = ids + [PAD_ID] * (max_length - len(ids))
            tokenized.append(torch.tensor(padded_ids, dtype=torch.long))
            mask = [1 if tok != PAD_ID else 0 for tok in padded_ids]
            attn_masks.append(torch.tensor(mask, dtype=torch.long))
        
        input_ids = torch.stack(tokenized).to(DEVICE)
        attn_mask = torch.stack(attn_masks).to(DEVICE)
        
        # Forward through custom model
        outputs = base_model(input_ids, attention_mask=attn_mask)
        mean_emb = outputs.mean(dim=1)  # Mean pooling
        all_embs.append(mean_emb.detach().cpu().float().numpy())
    return np.vstack(all_embs)

print("Generating embeddings...")
t0 = time.time()
embeddings = embed_batch(sequences, batch_size=BATCH_EMBED)
print(f"Embeddings shape: {embeddings.shape}   (took {time.time()-t0:.1f}s)")

# -----------------------------
# 7) PCA Dimensionality Reduction
# -----------------------------
if PCA_DIM is not None:
    pca = PCA(n_components=min(PCA_DIM, embeddings.shape[1]), random_state=RANDOM_SEED)  # Avoid expansion if needed
    embeddings = pca.fit_transform(embeddings)
    print("PCA reduced to:", embeddings.shape[1])

X = torch.tensor(embeddings, dtype=torch.float32)
y = torch.tensor(labels, dtype=torch.long)

# -----------------------------
# 8) Train/Test Split
# -----------------------------
X_train, X_test, y_train, y_test = train_test_split(
    X.numpy(), y.numpy(), test_size=TEST_SIZE, random_state=RANDOM_SEED, stratify=y.numpy()
)
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

# -----------------------------
# 9) DataLoaders with Weighted Sampling
# -----------------------------
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)

class_counts = torch.bincount(y_train)
inv_freq = 1.0 / (class_counts.float() + 1e-9)
class_weights = inv_freq / inv_freq.sum()

if USE_SAMPLER:
    sample_weights = class_weights[y_train].cpu().numpy()
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_ds, batch_size=BATCH_TRAIN, sampler=sampler, drop_last=True)
else:
    train_loader = DataLoader(train_ds, batch_size=BATCH_TRAIN, shuffle=True, drop_last=True)

test_loader = DataLoader(test_ds, batch_size=BATCH_TRAIN, shuffle=False, drop_last=False)

# -----------------------------
# 10) Advanced MLP Classifier
# -----------------------------
input_dim = X_train.shape[1]
num_classes = len(label_encoder.classes_)

class MLP(nn.Module):
    def __init__(self, input_dim, hidden1=HIDDEN_DIM_1, hidden2=HIDDEN_DIM_2, num_classes=num_classes, dropout=DROPOUT):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden1),
            nn.BatchNorm1d(hidden1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden1, hidden2),
            nn.BatchNorm1d(hidden2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden2, num_classes)
        )
    def forward(self, x):
        return self.net(x)

model_mlp = MLP(input_dim).to(DEVICE)

# -----------------------------
# 11) Training Setup
# -----------------------------
loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))
optimizer = AdamW(model_mlp.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=LR_PLATEAU_FACTOR,
                                                       patience=LR_PLATEAU_PATIENCE,
                                                       verbose=True)

best_val_loss = float("inf")
patience_counter = 0

def evaluate(model, loader):
    model.eval()
    total_loss = 0.0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            out = model(xb)
            loss = loss_fn(out, yb)
            total_loss += loss.item() * xb.size(0)
            preds = torch.argmax(out, dim=1)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(yb.cpu().numpy())
    avg_loss = total_loss / len(loader.dataset)
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    acc = (all_preds == all_labels).mean()
    return avg_loss, acc, all_preds, all_labels

# -----------------------------
# 12) Training Loop with Early Stopping & LR Scheduling
# -----------------------------
print("Training...")
for epoch in range(1, NUM_EPOCHS+1):
    model_mlp.train()
    epoch_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad()
        out = model_mlp(xb)
        loss = loss_fn(out, yb)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * xb.size(0)
    avg_train_loss = epoch_loss / len(train_loader.dataset)
    val_loss, val_acc, _, _ = evaluate(model_mlp, test_loader)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
    scheduler.step(val_loss)
    if val_loss < best_val_loss - 1e-6:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model_mlp.state_dict(), "best_mlp_custom_adv.pth")
        print("  ✅ Saved best model.")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("⏹ Early stopping.")
            break

# -----------------------------
# 13) Final Evaluation
# -----------------------------
model_mlp.load_state_dict(torch.load("best_mlp_custom_adv.pth"))
val_loss, val_acc, val_preds, val_labels = evaluate(model_mlp, test_loader)
print(f"Final Val Acc: {val_acc*100:.2f}%")
print(classification_report(val_labels, val_preds, target_names=label_encoder.classes_))

# -----------------------------
# 14) Save for Inference (e.g., Flask) - WITH EXPLICIT PATHS & VERIFICATION
# -----------------------------
OUTPUT_DIR = "/kaggle/working"
torch.save(model_mlp.state_dict(), os.path.join(OUTPUT_DIR, "best_mlp_custom_adv.pth"))
joblib.dump(label_encoder, os.path.join(OUTPUT_DIR, "label_encoder.pkl"))
if PCA_DIM is not None:
    joblib.dump(pca, os.path.join(OUTPUT_DIR, "pca_model.pkl"))

# Verify saves
print("\nSaved files:")
for f in ["best_mlp_custom_adv.pth", "label_encoder.pkl", "pca_model.pkl"]:
    path = os.path.join(OUTPUT_DIR, f)
    exists = os.path.exists(path)
    size = os.path.getsize(path) if exists else 0
    print(f" - {f}: {'✅ Exists' if exists else '❌ Missing'} ({size} bytes)")

print("\n📂 Files saved for inference:")
print(" - best_mlp_custom_adv.pth (MLP weights)")
print(" - label_encoder.pkl (class mapping)")
if PCA_DIM is not None:
    print(" - pca_model.pkl (PCA reducer)")

print("\nYou can now load these for predictions on new sequences.")

# Final verification: List all files in working dir
print("\nAll files in /kaggle/working/:")
!ls -la /kaggle/working/

Device: cuda
Total samples: 891
Clinical_Significance
Pathogenic           263
Likely Benign        217
Likely Pathogenic    217
Benign               194
Name: count, dtype: int64
After augmentation: total samples = 2671
Classes: ['Benign' 'Likely Benign' 'Likely Pathogenic' 'Pathogenic']
Generating embeddings...
Embeddings shape: (2671, 256)   (took 13.5s)
PCA reduced to: 256




Training...
Epoch 1 | Train Loss: 1.3915 | Val Loss: 1.2726 | Val Acc: 45.42%
  ✅ Saved best model.
Epoch 2 | Train Loss: 1.2517 | Val Loss: 1.1697 | Val Acc: 53.83%
  ✅ Saved best model.
Epoch 3 | Train Loss: 1.1732 | Val Loss: 1.0893 | Val Acc: 59.63%
  ✅ Saved best model.
Epoch 4 | Train Loss: 1.0897 | Val Loss: 1.0288 | Val Acc: 62.99%
  ✅ Saved best model.
Epoch 5 | Train Loss: 1.0248 | Val Loss: 0.9642 | Val Acc: 65.98%
  ✅ Saved best model.
Epoch 6 | Train Loss: 0.9699 | Val Loss: 0.9092 | Val Acc: 68.22%
  ✅ Saved best model.
Epoch 7 | Train Loss: 0.9031 | Val Loss: 0.8583 | Val Acc: 71.03%
  ✅ Saved best model.
Epoch 8 | Train Loss: 0.8529 | Val Loss: 0.8012 | Val Acc: 73.27%
  ✅ Saved best model.
Epoch 9 | Train Loss: 0.8036 | Val Loss: 0.7586 | Val Acc: 76.07%
  ✅ Saved best model.
Epoch 10 | Train Loss: 0.7817 | Val Loss: 0.7087 | Val Acc: 77.76%
  ✅ Saved best model.
Epoch 11 | Train Loss: 0.7100 | Val Loss: 0.6696 | Val Acc: 79.25%
  ✅ Saved best model.
Epoch 12 | Train L