In [None]:
# Environment setup (shell commands run from the notebook; intended for a fresh runtime).
# Step 1: force-reinstall the Hugging Face / evaluation stack.
!pip install --force-reinstall transformers datasets evaluate scikit-learn accelerate --no-build-isolation
# Step 2: remove whatever torch build is present so that only the pinned
# CUDA 12.1 wheels installed on the next line remain.
!pip uninstall -y torch torchvision torchaudio
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# Step 3: causal-conv1d is installed first, then mamba-ssm.
# NOTE(review): presumably these provide the kernels used by the Mamba-based
# backbone loaded later, and must match the torch build above — confirm.
!pip install causal-conv1d==1.4.0 && pip install mamba-ssm==2.2.2

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from collections import Counter
import numpy as np
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Colab-only step: ask the user to upload file(s) into the runtime.
# NOTE(review): presumably this is where the CSV read by the next cell comes
# from; outside Colab this import fails and the file must be placed manually.
from google.colab import files
uploaded = files.upload()

In [None]:
# Load the variant table that every later cell works from.
df = pd.read_csv("clinvar_sequence_disease_clean.csv")

# Fail early, with a readable message, if the columns used by the label
# encoding and dataset cells are not present.
missing_columns = {"sequence", "disease_labels"} - set(df.columns)
if missing_columns:
    raise KeyError(f"CSV is missing required column(s): {sorted(missing_columns)}")

print("Dataset shape:", df.shape)
print("Columns:", df.columns.tolist())
# The header is printed immediately before the rows it announces; the former
# duplicate shape line labelled "Test dataset shape" described the full frame,
# not a test split, and has been removed.
print("Sample rows:")
print(df.head())

In [None]:
from collections import Counter

# Flatten every comma-separated label string into one list of disease names
# and count how often each name occurs.
all_labels = []
for label_field in df['disease_labels']:
    all_labels.extend(label_field.split(","))
label_counts = Counter(all_labels)

# Keep only the 100 most frequent diseases; a disease's rank is its class id.
top_labels = [name for name, _count in label_counts.most_common(100)]
label2id = dict(zip(top_labels, range(len(top_labels))))
id2label = dict(enumerate(top_labels))
num_labels = len(top_labels)

print("Number of diseases kept:", num_labels)
print("Most common ones:", top_labels[:10])

In [None]:
def encode_labels(label_str):
    """Turn a comma-separated label string into a multi-hot list.

    The result has ``num_labels`` entries; position ``label2id[name]`` is 1
    for every name in ``label_str`` that is a key of ``label2id``, and all
    other positions are 0. Names not present in ``label2id`` are ignored.
    """
    multi_hot = [0] * num_labels
    known_ids = (label2id[name] for name in label_str.split(",") if name in label2id)
    for class_id in known_ids:
        multi_hot[class_id] = 1
    return multi_hot

# Add a "label_vec" column holding the multi-hot list produced by
# encode_labels for each row's "disease_labels" string (mutates df in place).
df["label_vec"] = df["disease_labels"].apply(encode_labels)

In [None]:
# Two-stage random split with a fixed seed: hold out 30% of the rows, then
# cut that held-out part in half to obtain validation and test frames.
train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    random_state=42,
    stratify=None,
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=None,
)

print(
    f"Train size: {len(train_df)}\n"
    f"Val size: {len(val_df)}\n"
    f"Test size: {len(test_df)}"
)

In [None]:
# Prefer the GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model_name = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
# This checkpoint is served through custom modeling/tokenizer code hosted in
# the model repository, so the Auto* loaders need trust_remote_code=True;
# without it the load prompts interactively or fails, which breaks
# Restart & Run All.
# NOTE(review): this executes code downloaded from the Hub — pin a `revision`
# if the repository contents must not change under you.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
backbone = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)

print(f"Backbone hidden size: {backbone.config.d_model}")
print(f"Backbone vocab size: {backbone.config.vocab_size}")

In [None]:
class DNADataset(Dataset):
    """Map-style dataset yielding tokenized DNA sequences with multi-hot labels.

    Args:
        sequences: indexable collection of sequence strings.
        labels: indexable collection of label vectors, aligned with ``sequences``.
        tokenizer: callable invoked as ``tokenizer(seq, truncation=..., padding=...,
            max_length=..., return_tensors='pt')`` whose result exposes
            ``'input_ids'`` and optionally ``'attention_mask'``.
        max_len: every item is truncated / padded to this many positions.

    Each item is a dict with ``input_ids`` (long), ``attention_mask`` (long)
    and ``labels`` (float) tensors.
    """

    # Character ids used ONLY when the tokenizer call fails.
    # NOTE(review): these ids are ad-hoc and are not taken from the backbone's
    # vocabulary, so a model fed fallback ids may see meaningless tokens —
    # confirm against the tokenizer before relying on this path.
    _FALLBACK_VOCAB = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
    _FALLBACK_UNKNOWN_ID = 4  # unknown characters and padding share the "N" id

    def __init__(self, sequences, labels, tokenizer, max_len=512):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        # The fallback notice is printed once per dataset instead of per item.
        self._fallback_reported = False

    def __len__(self):
        return len(self.sequences)

    def _fallback_encode(self, seq):
        """Character-level encoding of an already-truncated sequence."""
        token_ids = [
            self._FALLBACK_VOCAB.get(char.upper(), self._FALLBACK_UNKNOWN_ID)
            for char in seq
        ]
        pad_count = self.max_len - len(token_ids)
        attention = [1] * len(token_ids) + [0] * pad_count
        token_ids += [self._FALLBACK_UNKNOWN_ID] * pad_count
        return (
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(attention, dtype=torch.long),
        )

    def __getitem__(self, idx):
        # Slicing is a no-op for sequences that are already short enough.
        seq = self.sequences[idx][:self.max_len]
        label = self.labels[idx]

        try:
            encoding = self.tokenizer(
                seq,
                truncation=True,
                padding='max_length',
                max_length=self.max_len,
                return_tensors='pt'
            )

            # Drop only the leading batch axis. A bare .squeeze() would also
            # remove the sequence axis when max_len == 1 and return a 0-d tensor.
            input_ids = encoding['input_ids'].squeeze(0)

            if 'attention_mask' in encoding:
                attention_mask = encoding['attention_mask'].squeeze(0)
            else:
                # No mask supplied: derive one from the pad id. A tokenizer
                # whose pad_token_id is None is treated like one without the
                # attribute (id 0) instead of silently dropping to the fallback.
                pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
                if pad_token_id is None:
                    pad_token_id = 0
                attention_mask = (input_ids != pad_token_id).long()

        except Exception as e:
            # Best-effort fallback is kept, but it is no longer silent: the
            # notebook suppresses warnings globally, so print the notice.
            if not self._fallback_reported:
                print(
                    "DNADataset: tokenizer call failed "
                    f"({type(e).__name__}: {e}); using character-level fallback ids."
                )
                self._fallback_reported = True
            input_ids, attention_mask = self._fallback_encode(seq)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(label, dtype=torch.float)
        }

In [None]:
MAX_LEN = 512
BATCH_SIZE = 8

# One dataset per split, all sharing the tokenizer and the length cap.
train_ds, val_ds, test_ds = (
    DNADataset(
        split_frame["sequence"].tolist(),
        split_frame["label_vec"].tolist(),
        tokenizer,
        MAX_LEN,
    )
    for split_frame in (train_df, val_df, test_df)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_ds, batch_size=BATCH_SIZE) for eval_ds in (val_ds, test_ds)
)

In [None]:
print("\nTesting backbone model compatibility...")
sample_batch = next(iter(train_loader))
ids = sample_batch["input_ids"][:2].to(device)
mask = sample_batch["attention_mask"][:2].to(device)

# Call conventions tried in order. Each entry is:
# (method tag stored in backbone_call_method, text shown while probing,
#  zero-argument callable performing the call, success message, failure prefix)
call_attempts = [
    (
        "keyword_args",
        "backbone(input_ids=ids, attention_mask=mask)",
        lambda: backbone(input_ids=ids, attention_mask=mask),
        "Keyword arguments work!",
        "Keyword args failed",
    ),
    (
        "input_ids_only",
        "backbone(ids)",
        lambda: backbone(ids),
        "Input IDs only works!",
        "Input IDs only failed",
    ),
    (
        "positional_args",
        "backbone(ids, mask)",
        lambda: backbone(ids, mask),
        "Positional arguments work!",
        "All methods failed",
    ),
]

backbone_call_method = None
outputs = None
last_error = None

# Probing only inspects shapes, so run it without autograd: no graph is kept
# alive on the device, and the successful call's output is reused below rather
# than running the backbone a second time.
with torch.no_grad():
    for method, call_text, call, ok_msg, fail_msg in call_attempts:
        print(f"Testing {call_text}...")
        try:
            outputs = call()
        except Exception as e:
            print(f"{fail_msg}: {e}")
            last_error = e
            continue
        backbone_call_method = method
        print(ok_msg)
        break
    else:
        # Every convention failed: surface the last error, as before.
        raise last_error

print("\nDetecting actual backbone output dimensions...")
# Accept a model-output object, a tuple, or a bare tensor.
if hasattr(outputs, 'last_hidden_state'):
    hidden = outputs.last_hidden_state
elif isinstance(outputs, tuple):
    hidden = outputs[0]
else:
    hidden = outputs

# Pool over dim 1 and read the feature width from the result, since the
# classifier head is sized from this measured value rather than the config.
pooled = hidden.mean(dim=1)  # Average pooling
actual_hidden_size = pooled.shape[-1]

print(f"Config says d_model: {backbone.config.d_model}")
print(f"Actual output size: {actual_hidden_size}")
print(f"Hidden state shape: {hidden.shape}")
print(f"Pooled shape: {pooled.shape}")

In [None]:
class DiseaseClassifier(nn.Module):
    """Backbone encoder followed by mean pooling and an MLP multi-label head.

    Args:
        backbone: module producing per-position hidden states; it may return an
            object with ``last_hidden_state``, a tuple (first element used), or
            a tensor.
        num_labels: number of output logits.
        actual_hidden_size: feature width of the backbone output, as measured
            on a real forward pass (used instead of the config value).
        backbone_call_method: ``"keyword_args"``, ``"input_ids_only"``, or any
            other value for positional ``backbone(input_ids, attention_mask)``.
        dropout_rate: dropout applied to the pooled vector and inside the head.
    """

    def __init__(self, backbone, num_labels, actual_hidden_size, backbone_call_method, dropout_rate=0.3):
        super().__init__()
        self.backbone = backbone
        self.backbone_call_method = backbone_call_method
        self.dropout = nn.Dropout(dropout_rate)

        # Use the ACTUAL hidden size, not config
        hidden_size = actual_hidden_size

        # Three-layer MLP that narrows hidden -> hidden/2 -> hidden/4 -> num_labels.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.LayerNorm(hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 4, num_labels)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal weights and zero biases for the head's Linear layers."""
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_normal_(module.weight)
                torch.nn.init.zeros_(module.bias)

    def _run_backbone(self, input_ids, attention_mask):
        """Call the backbone with the configured convention; return hidden states."""
        if self.backbone_call_method == "keyword_args":
            outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        elif self.backbone_call_method == "input_ids_only":
            outputs = self.backbone(input_ids)
        else:
            outputs = self.backbone(input_ids, attention_mask)

        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        if isinstance(outputs, tuple):
            return outputs[0]
        return outputs

    @staticmethod
    def _pool(hidden, attention_mask):
        """Mean over dim 1, restricted to positions where the mask is non-zero.

        Falls back to a plain mean when no mask is given or when the mask's
        leading two dimensions do not line up with ``hidden``.
        """
        usable_mask = (
            attention_mask is not None
            and attention_mask.dim() == 2
            and tuple(attention_mask.shape) == tuple(hidden.shape[:2])
        )
        if not usable_mask:
            return hidden.mean(dim=1)

        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        sum_hidden = torch.sum(hidden * mask_expanded, dim=1)
        # Clamp so an all-zero mask row divides by a tiny number, not zero.
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        return sum_hidden / sum_mask

    def forward(self, input_ids, attention_mask=None):
        """Return logits with ``num_labels`` entries per input row.

        The attention mask is used for pooling even when the backbone itself is
        called with ids only: a backbone that cannot take a mask still emits a
        hidden state for every padded position, and averaging those in would
        make the pooled vector depend on the padding.
        """
        hidden = self._run_backbone(input_ids, attention_mask)
        pooled = self._pool(hidden, attention_mask)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        return logits

In [None]:
# Wrap the backbone with the classification head and place it on the device.
model = DiseaseClassifier(
    backbone,
    num_labels,
    actual_hidden_size,
    backbone_call_method,
)
model = model.to(device)
print(f"\nFixed model created with actual hidden size: {actual_hidden_size}")
print(f"Model parameters: {sum(map(torch.numel, model.parameters()))}")

# Smoke test: push the two probe sequences through the full model and compare
# the logits' shape with the expected one; any failure is reported and re-raised.
try:
    with torch.no_grad():
        test_logits = model(ids, mask)
    print(f"SUCCESS! Output shape: {test_logits.shape}")
    print(f"Expected shape: ({ids.shape[0]}, {num_labels})")
except Exception as err:
    print(f"Model forward pass failed: {err}")
    raise

In [None]:
def evaluate(model, loader, device, threshold=None,
             thresholds=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7)):
    """Score a multi-label model on every batch of ``loader``.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning logits.
        loader: iterable of dict batches with ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` tensors.
        device: device the input tensors are moved to.
        threshold: decision threshold applied to the sigmoid probabilities.
            When ``None`` (default, the original behaviour) the threshold is
            picked from ``thresholds`` by best micro-F1 *on this loader*.
            Pass the threshold chosen on validation data when scoring a test
            set, otherwise the reported test metrics are tuned on test data.
        thresholds: candidates swept when ``threshold`` is ``None``; the first
            candidate reaching the best micro-F1 wins, and 0.5 is used if no
            candidate scores above zero.

    Returns:
        dict with ``threshold``, ``f1_micro``, ``f1_macro``, macro ``precision``
        and ``recall``, and ``accuracy`` (fraction of individual label
        decisions that match, not exact-match accuracy).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    prob_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            logits = model(ids, mask)
            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
            # Labels are only needed on the host, so skip the device round trip.
            label_chunks.append(batch["labels"].cpu().numpy())

    if not prob_chunks:
        raise ValueError("evaluate() received an empty loader; nothing to score")

    all_probs = np.vstack(prob_chunks)
    all_labels = np.vstack(label_chunks)

    if threshold is None:
        best_f1 = 0
        best_threshold = 0.5
        for thresh in thresholds:
            candidate_preds = (all_probs >= thresh).astype(int)
            f1 = f1_score(all_labels, candidate_preds, average="micro", zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = thresh
    else:
        best_threshold = threshold

    bin_preds = (all_probs >= best_threshold).astype(int)

    return {
        "threshold": best_threshold,
        "f1_micro": f1_score(all_labels, bin_preds, average="micro", zero_division=0),
        "f1_macro": f1_score(all_labels, bin_preds, average="macro", zero_division=0),
        "precision": precision_score(all_labels, bin_preds, average="macro", zero_division=0),
        "recall": recall_score(all_labels, bin_preds, average="macro", zero_division=0),
        "accuracy": (bin_preds == all_labels).mean()
    }

In [None]:
# Per-label class balance, read straight from the label vectors the training
# dataset already holds. Iterating the DataLoader for this would tokenize every
# training sequence just to collect labels; the column sums are the same.
all_labels = torch.tensor(train_ds.labels, dtype=torch.float)

pos_counts = all_labels.sum(dim=0)
neg_counts = all_labels.size(0) - pos_counts
# neg/pos ratio per label, clamped to [0.1, 10] so rare labels do not dominate.
pos_weight = torch.clamp(neg_counts / (pos_counts + 1e-5), min=0.1, max=10.0).to(device)

print(f"Positive weights range: {pos_weight.min():.2f} - {pos_weight.max():.2f}")

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)

num_epochs = 10
total_steps = len(train_loader) * num_epochs
# Linear warmup over the first 10% of steps, then linear decay.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

print(f"\nTraining setup:")
print(f"  Epochs: {num_epochs}")
print(f"  Steps per epoch: {len(train_loader)}")
print(f"  Total steps: {total_steps}")
print(f"  Backbone call method: {backbone_call_method}")
print(f"  Actual hidden size: {actual_hidden_size}")

best_val_f1 = 0
checkpoint_path = 'best_model.pth'
# Set once a checkpoint has actually been written during this run, so the
# final evaluation never loads a stale file from an earlier run.
saved_checkpoint = False

print("\n" + "="*60)
print("STARTING TRAINING - ALL ISSUES FIXED!")
print("="*60)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    num_batches = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device).float()

        optimizer.zero_grad()

        logits = model(ids, mask)
        loss = criterion(logits, labels)
        loss.backward()

        # Clip before stepping; the scheduler advances once per optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

        progress_bar.set_postfix({
            'loss': f'{total_loss/num_batches:.4f}',
            'lr': f'{scheduler.get_last_lr()[0]:.2e}'
        })

    print(f"\nEpoch {epoch+1} - Average Loss: {total_loss/num_batches:.4f}")

    val_metrics = evaluate(model, val_loader, device)
    print(f"Validation: {val_metrics}")

    # Keep the weights of the epoch with the best validation micro-F1.
    if val_metrics["f1_micro"] > best_val_f1:
        best_val_f1 = val_metrics["f1_micro"]
        torch.save(model.state_dict(), checkpoint_path)
        saved_checkpoint = True
        print(f"New best F1: {best_val_f1:.4f}")

print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)

# Score the checkpoint that did best on validation, not whatever weights the
# last epoch happened to leave in memory.
if saved_checkpoint:
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True)
    )
    print(f"Loaded best checkpoint (val F1: {best_val_f1:.4f})")

test_metrics = evaluate(model, test_loader, device)
print(f"Test metrics: {test_metrics}")
print(f"F1:     {test_metrics['f1_micro']:.1%}")

print("="*60)