In [None]:
# Colab setup: (re)install the Hugging Face / scikit-learn packages listed below.
# NOTE(review): no versions are pinned and --force-reinstall fetches the packages
# again on every run, so the environment can differ between runs — consider pinning.
# NOTE(review): a later cell uninstalls and reinstalls torch at a fixed version;
# confirm the packages installed here are compatible with that torch build.
!pip install --force-reinstall transformers datasets evaluate scikit-learn accelerate --no-build-isolation

In [None]:
# Opens the Colab file-upload widget and keeps the value returned by
# files.upload() in `uploaded` (not referenced again in the visible cells).
# NOTE(review): presumably this is how clinvar_sequence_disease_clean.csv, read by
# a later cell, reaches the working directory — confirm. This cell only works on Colab.
from google.colab import files
uploaded = files.upload()

In [None]:
# Remove the torch packages currently installed, then install pinned versions
# from the PyTorch cu121 wheel index.
!pip uninstall -y torch torchvision torchaudio
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# causal-conv1d is installed first; mamba-ssm is only installed if that succeeds (&&).
# NOTE(review): presumably both are needed by the Caduceus backbone loaded later — confirm.
# NOTE(review): if torch was already imported in this kernel, a runtime restart is
# likely needed before the new build is picked up — verify on Colab.
!pip install causal-conv1d==1.4.0 && pip install mamba-ssm==2.2.2

In [None]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from transformers import AutoModel, AutoTokenizer

In [None]:
# Read the table from the working directory; later cells use its
# "disease_labels" and "sequence" columns.
df = pd.read_csv("clinvar_sequence_disease_clean.csv")

# Heading and preview are emitted on separate lines.
print("Sample rows:", df.head(), sep="\n")

In [None]:
from collections import Counter

# Flatten every row's comma-separated label string into one list of label names.
all_labels = [
    name
    for label_str in df['disease_labels']
    for name in label_str.split(",")
]
label_counts = Counter(all_labels)

# Keep the 500 most frequent labels; a label's position in this list is its id.
top_labels = [name for name, _count in label_counts.most_common(500)]
label2id = dict(zip(top_labels, range(len(top_labels))))
id2label = dict(enumerate(top_labels))
num_labels = len(top_labels)

print("Number of diseases kept:", num_labels)
print("Most common ones:", top_labels[:10])

def encode_labels(label_str):
    """Turn a comma-separated label string into a multi-hot list of length num_labels.

    Names that are not keys of label2id are skipped, so a row whose labels are
    all outside the kept vocabulary encodes to all zeros.
    """
    multi_hot = [0] * num_labels
    known_positions = (
        label2id[name] for name in label_str.split(",") if name in label2id
    )
    for position in known_positions:
        multi_hot[position] = 1
    return multi_hot


# Store the encoded vector for every row in a new "label_vec" column.
df["label_vec"] = df["disease_labels"].apply(encode_labels)


In [None]:
from sklearn.model_selection import train_test_split

# Work on a random subset of at most 5000 rows to keep the run short.
# min(...) guards small tables: DataFrame.sample raises ValueError when n is
# larger than the number of rows (sampling without replacement).
df_small = df.sample(n=min(5000, len(df)), random_state=42)

# 80% train, then the remaining 20% is halved into validation and test.
train_df, temp_df = train_test_split(df_small, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print("Train size:", len(train_df))
print("Validation size:", len(val_df))
print("Test size:", len(test_df))
print(train_df.head())


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

# Token ids aligned with the character-level vocabulary shipped with the Caduceus
# checkpoints, where ids 0-6 are special tokens ("[PAD]" is 4) and the bases
# A, C, G, T, N follow from id 7. The previous mapping (A=0 ... T=3, N=4) fed
# special-token ids to the pretrained embedding table and made "N"
# indistinguishable from padding.
# NOTE(review): confirm these ids against `tokenizer.get_vocab()` for the
# checkpoint actually loaded.
DNA_VOCAB = {"A": 7, "C": 8, "G": 9, "T": 10, "N": 11}
PAD_TOKEN_ID = 4
# Anything that is not A/C/G/T (IUPAC ambiguity codes, gaps, ...) is treated as "N".
UNKNOWN_BASE_ID = DNA_VOCAB["N"]


def tokenize_dna(sequence, max_len):
    """Convert a DNA string into a fixed-length tensor of token ids.

    Args:
        sequence: string of bases; case-insensitive. Characters outside
            A/C/G/T/N are encoded as "N".
        max_len: output length. Longer sequences are truncated on the right,
            shorter ones are right-padded with PAD_TOKEN_ID.

    Returns:
        1-D torch.long tensor of length max_len.
    """
    # Upper-case per character so soft-masked (lowercase) bases keep their identity.
    ids = [DNA_VOCAB.get(base.upper(), UNKNOWN_BASE_ID) for base in sequence[:max_len]]
    ids.extend([PAD_TOKEN_ID] * (max_len - len(ids)))
    return torch.tensor(ids, dtype=torch.long)


def create_attention_mask(input_ids):
    """Return a long tensor that is 1 at real tokens and 0 at padding positions."""
    return (input_ids != PAD_TOKEN_ID).long()

class DNADataset(Dataset):
    """Map-style dataset pairing DNA strings with label vectors.

    Each item is a dict with "input_ids" (from tokenize_dna), "attention_mask"
    (from create_attention_mask) and "labels" (the label entry as a float tensor).
    """

    def __init__(self, sequences, labels, max_len=512):
        # Inputs are stored as given; tokenisation happens lazily per item.
        self.sequences, self.labels, self.max_len = sequences, labels, max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        raw_sequence, raw_label = self.sequences[idx], self.labels[idx]

        token_ids = tokenize_dna(raw_sequence, self.max_len)
        return dict(
            input_ids=token_ids,
            attention_mask=create_attention_mask(token_ids),
            labels=torch.tensor(raw_label, dtype=torch.float),
        )


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
# Caduceus is not part of the transformers library itself: its model and
# tokenizer classes are custom code stored with the checkpoint on the Hub.
# Without trust_remote_code=True, loading either asks for interactive
# confirmation or fails, which breaks "Run All".
# SECURITY: this executes code downloaded from the model repository — only use
# repositories you trust, and consider pinning `revision=` to a reviewed commit.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
backbone = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)


In [None]:
MAX_LEN = 512

# One DNADataset per split, each built from the "sequence" and "label_vec" columns.
train_ds, val_ds, test_ds = (
    DNADataset(
        split_df["sequence"].tolist(),
        split_df["label_vec"].tolist(),
        max_len=MAX_LEN,
    )
    for split_df in (train_df, val_df, test_df)
)

# Only the training loader shuffles; all three use batches of two.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=2)
val_loader, test_loader = (
    DataLoader(eval_ds, batch_size=2) for eval_ds in (val_ds, test_ds)
)

# Show the first training item as a quick sanity check of the encoding.
sample = train_ds[0]
print("Keys in dataset sample:", sample.keys())
print("Input IDs:", sample["input_ids"])
print("Attention mask:", sample["attention_mask"])
print("Labels:", sample["labels"])

In [None]:
class DiseaseClassifier(nn.Module):
    """Backbone encoder followed by a two-layer multilabel classification head.

    Args:
        backbone: module called as ``backbone(input_ids)``; its result is either
            an object with ``last_hidden_state`` or the hidden-state tensor itself.
        hidden_size: accepted for interface compatibility but not used to size
            the head (see ``in_features``).
        num_labels: number of output logits.
        in_features: channel width of the backbone's hidden states. Defaults to
            512, the value this head was originally hard-coded to.
            NOTE(review): presumably 2 * d_model for the Caduceus-PS checkpoint
            used in this notebook — confirm against the backbone output shape.
    """

    def __init__(self, backbone, hidden_size, num_labels, in_features=512):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_labels)
        )

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape (batch, num_labels).

        Args:
            input_ids: token ids, indexed as (batch, sequence).
            attention_mask: optional tensor with the same leading shape, non-zero
                at real tokens. When given, padded positions are excluded from
                the mean pooling; when omitted, all positions are averaged.
        """
        outputs = self.backbone(input_ids)

        hidden = (
            outputs.last_hidden_state
            if hasattr(outputs, "last_hidden_state")
            else outputs
        )

        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Masked mean: padding would otherwise dilute short sequences.
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp avoids 0/0 for a row that contains no real tokens.
            token_count = weights.sum(dim=1).clamp(min=1.0)
            pooled = (hidden * weights).sum(dim=1) / token_count
        return self.classifier(pooled)

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np

def evaluate(loader, threshold=0.5):
    """Score the global `model` on every batch of `loader` and return multilabel metrics.

    Args:
        loader: iterable of batches that are dicts with "input_ids",
            "attention_mask" and "labels" tensors.
        threshold: sigmoid probability at or above which a label counts as predicted.

    Returns:
        dict with keys
            "accuracy"    - fraction of individual (sample, label) cells predicted
                            correctly (element-wise, not exact-match),
            "f1_macro", "f1_micro" - scikit-learn F1 scores,
            "precision", "recall"  - macro-averaged scikit-learn scores,
            "auroc_macro" - mean per-label ROC AUC over the labels that have both
                            a positive and a negative example in `loader`;
                            NaN when no label qualifies.

    Raises:
        ValueError: if `loader` yields no batches.

    Side effect: leaves `model` in eval mode.
    """
    model.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            logits = model(ids, mask)
            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
            # Labels are only needed on the host, so they never go to `device`.
            label_chunks.append(batch["labels"].cpu().numpy())

    if not prob_chunks:
        raise ValueError("evaluate() received a loader with no batches")

    probs = np.vstack(prob_chunks)
    targets = np.vstack(label_chunks)

    bin_preds = (probs >= threshold).astype(int)

    metrics = {}
    metrics["accuracy"]  = (bin_preds == targets).mean()
    metrics["f1_macro"]  = f1_score(targets, bin_preds, average="macro", zero_division=0)
    metrics["f1_micro"]  = f1_score(targets, bin_preds, average="micro", zero_division=0)
    metrics["precision"] = precision_score(targets, bin_preds, average="macro", zero_division=0)
    metrics["recall"]    = recall_score(targets, bin_preds, average="macro", zero_division=0)

    # ROC AUC is undefined for a label whose targets are all 0 or all 1, so
    # only labels with both classes present contribute to the macro average.
    positives_per_label = targets.sum(axis=0)
    scorable = (positives_per_label > 0) & (positives_per_label < targets.shape[0])
    per_label_auc = [
        roc_auc_score(targets[:, j], probs[:, j]) for j in np.flatnonzero(scorable)
    ]
    metrics["auroc_macro"] = float(np.mean(per_label_auc)) if per_label_auc else float("nan")

    return metrics


In [None]:
import torch
from torch.optim import AdamW
import torch.nn as nn

hidden_size = backbone.config.d_model
model = DiseaseClassifier(backbone, hidden_size, num_labels).to(device)

# Per-label class balance of the training split. The label vectors are read
# straight from the dataset rather than by iterating train_loader, which would
# tokenise every sequence just to collect the labels.
train_labels = torch.tensor(train_ds.labels, dtype=torch.float)

pos_counts = train_labels.sum(dim=0)
neg_counts = train_labels.size(0) - pos_counts

# Weight positives by each label's negative:positive ratio. clamp(min=1) keeps
# the weight finite (at most the number of training rows) for labels that have
# no positive example in this split, instead of dividing by ~0.
pos_weight = (neg_counts / pos_counts.clamp(min=1)).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = AdamW(model.parameters(), lr=1e-4)

for epoch in range(3):
    model.train()
    epoch_loss = 0.0
    for batch in train_loader:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device).float()

        optimizer.zero_grad()
        logits = model(input_ids=ids, attention_mask=mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}: mean train loss {epoch_loss / max(len(train_loader), 1):.4f}")

    val_metrics = evaluate(val_loader)
    print(f"Epoch {epoch+1}: {val_metrics}")

    # Spot-check on the first validation batch.
    batch = next(iter(val_loader))
    ids, mask, labels = batch["input_ids"].to(device), batch["attention_mask"].to(device), batch["labels"].to(device)

    model.eval()  # explicit, so dropout is off regardless of what ran before
    with torch.no_grad():
        logits = model(input_ids=ids, attention_mask=mask)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).int()

    # Show the indices of the active labels per sample; the full 0/1 vectors
    # (one entry per label) are too long to read.
    print("Predictions:", [row.nonzero().flatten().tolist() for row in preds[:5]])
    print("Labels:", [row.nonzero().flatten().tolist() for row in labels[:5]])


In [None]:
# evaluate() returns a dict of several metrics rather than a single AUROC
# number, so name and label the result accordingly.
test_metrics = evaluate(test_loader)
test_auc = test_metrics  # original variable name, kept as an alias
print("Final test metrics:", test_metrics)