In [None]:
# !pip install --force-reinstall transformers datasets evaluate scikit-learn accelerate --no-build-isolation

In [None]:
import os

# os.path.abspath resolves the relative name against the current working
# directory, so notebook_path is <cwd>/trainer.ipynb no matter where the
# notebook file really lives; only its directory is used below, to locate
# the dataset under ./datasets/.
notebook_path = os.path.abspath("trainer.ipynb")
train_csv = os.path.join(
    os.path.dirname(notebook_path),
    "datasets/clinvar_sequence_disease_clean.csv",
)

# Colab alternative: upload the CSV into the working directory instead.
# from google.colab import files
# uploaded = files.upload()

In [None]:
# !pip uninstall -y torch torchvision torchaudio
# !pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# !pip install causal-conv1d==1.4.0 && pip install mamba-ssm==2.2.2

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import evaluate
import numpy as np

In [None]:
# Prefer the path configured in the setup cell (<cwd>/datasets/...); fall back
# to the bare filename in the working directory, e.g. after a Colab upload.
csv_path = train_csv if os.path.exists(train_csv) else "clinvar_sequence_disease_clean.csv"
df = pd.read_csv(csv_path)

# Keep only rows that have both a sequence and a label string.
df = df.dropna(subset=["sequence", "disease_labels"])
df["disease_labels"] = df["disease_labels"].astype(str)

# Comma-separated label string -> list of stripped, non-empty label names.
df["disease_labels"] = df["disease_labels"].apply(
    lambda x: [d.strip() for d in x.split(",") if d.strip()]
)

# Drop rows whose label string contained only separators/whitespace.
df = df[df["disease_labels"].map(len) > 0]

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Fit the label vocabulary on the full label set (label names only, no
# features), so labels that occur only in the test split are kept and scored
# as misses instead of being silently dropped by transform() with an
# "unknown class" warning, which would inflate test metrics.
mlb = MultiLabelBinarizer()
mlb.fit(df["disease_labels"])
train_labels = mlb.transform(train_df["disease_labels"])
test_labels = mlb.transform(test_df["disease_labels"])

In [None]:
class SequenceDiseaseDataset(Dataset):
    """Map-style dataset pairing raw sequences with multi-hot label vectors.

    Each item is tokenized on access, truncated/padded to ``max_length`` and
    returned as the dict of tensors the ``Trainer`` feeds to the model.

    Args:
        sequences: indexable collection of sequences; each entry is passed
            through ``str()`` before tokenization.
        labels: indexable collection of label vectors aligned with
            ``sequences`` (in this notebook, rows of the MultiLabelBinarizer
            output).
        tokenizer: Hugging Face-style callable tokenizer.
        max_length: token length every example is truncated/padded to.
    """

    def __init__(self, sequences, labels, tokenizer, max_length=512):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of examples (one per sequence)."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and float ``labels`` for one example."""
        raw_label = self.labels[idx]
        tokens = self.tokenizer(
            str(self.sequences[idx]),
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # return_tensors="pt" yields a leading batch dimension of 1;
        # flatten() removes it so the collator can stack examples.
        item = {key: tokens[key].flatten() for key in ("input_ids", "attention_mask")}
        # Float targets, as expected by a BCE-with-logits style loss.
        item["labels"] = torch.tensor(raw_label, dtype=torch.float)
        return item

In [None]:
model_name = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
# Caduceus is not a built-in transformers architecture: its tokenizer (and
# model) code is hosted in the Hub repo, so loading requires
# trust_remote_code=True. This executes code from that repo; consider pinning
# `revision=` to a reviewed commit.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

train_dataset = SequenceDiseaseDataset(train_df["sequence"].tolist(), train_labels, tokenizer)
test_dataset = SequenceDiseaseDataset(test_df["sequence"].tolist(), test_labels, tokenizer)

In [None]:
# trust_remote_code=True is required because the Caduceus architecture is
# defined by code in the Hub repo rather than in transformers itself.
# problem_type="multi_label_classification" selects an independent
# per-label (sigmoid) objective, matching the multi-hot float labels.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(mlb.classes_),
    problem_type="multi_label_classification",
    trust_remote_code=True,
)

In [None]:
# The default "f1"/"accuracy" configs expect a single integer per example;
# the "multilabel" config accepts one 0/1 vector per example, which is what
# compute_metrics passes (shape: n_examples x n_labels).
metric_f1 = evaluate.load("f1", "multilabel")
metric_accuracy = evaluate.load("accuracy", "multilabel")

def compute_metrics(eval_pred):
    """Compute multi-label F1 and accuracy from Trainer evaluation output.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``. Logits are
            turned into 0/1 predictions with ``sigmoid(logit) > 0.5``; labels
            are the multi-hot target vectors.

    Returns:
        dict with ``"f1"`` (micro-averaged over all label slots) and
        ``"accuracy"`` (multilabel accuracy, i.e. the whole label vector of an
        example must match).
    """
    logits, labels = eval_pred
    # Some models return extra outputs next to the logits; keep the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    labels = labels.astype(int)
    f1 = metric_f1.compute(predictions=preds, references=labels, average="micro")["f1"]
    acc = metric_accuracy.compute(predictions=preds, references=labels)["accuracy"]
    return {"f1": f1, "accuracy": acc}

In [None]:
from transformers import DataCollatorWithPadding

# SequenceDiseaseDataset already pads every example to max_length, so the
# dynamic padding here is effectively a no-op; the collator batches tensors.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # key returned by compute_metrics
    save_total_limit=1,
    report_to="none",
    # Let Trainer drop batch keys that model.forward() does not accept. The
    # dataset always emits "attention_mask"; with False it is forwarded to the
    # model unconditionally, which raises TypeError for a forward() without
    # that parameter. NOTE(review): believed to apply to Caduceus' remote
    # forward() (no attention_mask argument) - confirm against its signature.
    remove_unused_columns=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

In [None]:
# Save the fine-tuned weights and tokenizer into one directory so they can be
# reloaded together with from_pretrained().
for artifact in (model, tokenizer):
    artifact.save_pretrained("caduceus_clinvar_multilabel")

# Row i of this CSV is the label name for output index i of the classifier
# (the column order of mlb's multi-hot vectors).
pd.Series(mlb.classes_).to_csv("disease_labels_mapping.csv", index=False)