In [None]:
# =============================================================
# Tutorial (Colab) — Fine-tuning step by step (English, no custom wrappers)
# - No pair-task support (only single-sequence classification: sequence,label)
# - Three modes: full fine-tuning, head-only, or LoRA
# - Expects train.csv, dev.csv and test.csv in data_path, each with a header row: sequence,label
# =============================================================

# If running on Google Colab, uncomment (%pip installs into the running kernel's environment):
# %pip install -q transformers peft datasets accelerate scikit-learn

# 1) Imports and configuration

In [None]:
import os
import subprocess

# Clone the tutorial repository and work from inside it. Guarded so that re-running
# this cell neither re-clones nor tries to descend into a nested copy.
repo_url = "https://github.com/fabianagoes/bc2_tutorial8.git"
repo_dir = "bc2_tutorial8"
if os.path.basename(os.getcwd()) != repo_dir:
    if not os.path.isdir(repo_dir):
        # check=True: stop here if the clone fails instead of continuing in the wrong directory
        subprocess.run(["git", "clone", repo_url], check=True)
    os.chdir(repo_dir)
print("Working directory:", os.getcwd())

# Google Drive can only be mounted on Colab. Elsewhere (e.g. on the cluster that the
# configured /ceph paths point to), skip the mount instead of failing this whole
# cell, which also holds the imports.
try:
    from google.colab import drive
except ImportError:
    drive = None
if drive is not None:
    drive.mount('/content/drive')

import os, csv, json, logging, random, warnings
from typing import List
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, precision_score, recall_score
from datasets import Dataset
from peft import LoraConfig, get_peft_model
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    set_seed,
)

warnings.filterwarnings("ignore", message=".*torch.cuda.amp.autocast.*")
logging.basicConfig(level=logging.INFO)


# 2) Hyperparameters

In [None]:
# Model checkpoint. Checkpoints used with this tutorial:
#   zhihan1996/DNABERT-2-117M
#   InstaDeepAI/nucleotide-transformer-500m-human-ref
model_name_or_path = "InstaDeepAI/nucleotide-transformer-500m-human-ref"  # example: "bert-base-uncased", "InstaDeepAI/nucleotide-transformer-v2-50m-1000g"

# Training strategy (choose ONE)
use_lora = False          # True for LoRA
train_head_only = False   # True to train only the classification head
# If both are False => full fine-tuning

# LoRA parameters (only used when use_lora=True)
lora_r = 8
lora_alpha = 32
lora_dropout = 0.05
lora_target_modules = "query,value"  # comma-separated module names; adjust for your model, e.g. "q_proj,v_proj"

# Data: directory that must contain train.csv, dev.csv and test.csv.
# The default is a cluster path; set the BC2_DATA_PATH environment variable to use
# another location (e.g. on Colab) without editing this cell.
data_path = os.environ.get(
    "BC2_DATA_PATH",
    "/ceph/groups/aibds/fabiana/workspace1/bc2_tutorial/datasets/GUE/prom/prom_300_notata",
)

# Training args
run_name = "run"
# Checkpoints and results go here; override with the BC2_OUTPUT_DIR environment variable.
output_dir = os.environ.get(
    "BC2_OUTPUT_DIR",
    "/ceph/groups/aibds/fabiana/workspace1/bc2_tutorial/output",
)
model_max_length = 512           # tokens; longer inputs are truncated during tokenization
per_device_train_batch_size = 8
per_device_eval_batch_size = 16
num_train_epochs = 1
fp16 = False
save_strategy = "epoch"
evaluation_strategy = "epoch"    # keep equal to save_strategy when load_best_model_at_end=True
eval_steps = 100                 # only used when evaluation_strategy == "steps"
warmup_steps = 50
weight_decay = 0.1
learning_rate = 1e-4
save_total_limit = None          # None = keep every checkpoint
load_best_model_at_end = True
seed = 42

# Save model and results
save_model = False
save_results = True

# 3) Seed and device

In [None]:
# Seed every random number generator the run touches. transformers.set_seed already
# seeds Python's `random`, NumPy and torch; the explicit calls keep the intent visible.
set_seed(seed)
random.seed(seed)
np.random.seed(seed)

cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if cuda_available else "cpu")
print("Device:", device)

# 4) Tokenizer

In [None]:
print("\nLoading tokenizer...")
# Right padding and a length cap matching the model's configured max input length.
tokenizer_kwargs = dict(
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_kwargs)

# For InstaDeepAI checkpoints whose tokenizer has a pad token but no EOS token,
# reuse the pad token as EOS.
is_instadeep_model = "InstaDeepAI" in model_name_or_path
if is_instadeep_model and tokenizer.eos_token is None and tokenizer.pad_token is not None:
    tokenizer.eos_token = tokenizer.pad_token

# 5) Read CSVs (single sequence classification)

In [None]:
print("\nReading data...")
print("Task: single sequence classification (text,label)")

def stratified_sample(df, frac, label_col="label", random_state=42):
    """Subsample ``df`` class by class, keeping about ``frac`` of the rows of each label.

    Parameters
    ----------
    df : pandas.DataFrame
        Split to subsample; must contain ``label_col``.
    frac : float
        Fraction of rows drawn from every label group (pandas rounds per group).
    label_col : str
        Column whose values define the strata.
    random_state : int
        Seed forwarded to pandas' sampler so the subsample is reproducible.

    Returns
    -------
    pandas.DataFrame
        The sampled rows with all original columns (including ``label_col``) and a
        fresh 0..n-1 index.
    """
    # GroupBy.sample keeps the grouping column in the result. The
    # groupby().apply(lambda g: g.sample(...)) form relies on apply() handing the
    # grouping column to the lambda, which pandas 2.2 deprecated; once removed, the
    # label column would silently disappear from the sample.
    return (
        df.groupby(label_col)
          .sample(frac=frac, random_state=random_state)
          .reset_index(drop=True)
    )


# Fraction of each split kept by stratified subsampling (keeps the tutorial fast).
# Set a fraction to 1.0 to use the full split.
train_frac = 0.01
dev_frac = 0.08
test_frac = 0.08


def _texts_and_labels(df):
    """Return ``(texts, integer labels)`` lists from one split's DataFrame.

    The text column is ``sequence`` (GUE-style CSVs) or, failing that, ``text``.
    Raises KeyError when neither column is present.
    """
    text_col = next((c for c in ("sequence", "text") if c in df.columns), None)
    if text_col is None:
        raise KeyError(f"Expected a 'sequence' or 'text' column, found {list(df.columns)}")
    return df[text_col].tolist(), df["label"].astype(int).tolist()


train_df = pd.read_csv(os.path.join(data_path, "train.csv"))
dev_df   = pd.read_csv(os.path.join(data_path, "dev.csv"))
test_df  = pd.read_csv(os.path.join(data_path, "test.csv"))

# Sample with the configured seed so the subset is reproducible together with the run.
train_df = stratified_sample(train_df, train_frac, random_state=seed)
dev_df   = stratified_sample(dev_df, dev_frac, random_state=seed)
test_df  = stratified_sample(test_df, test_frac, random_state=seed)

# A too-small fraction can round every class down to zero rows; fail with a clear message.
for split_name, split_df in (("train", train_df), ("dev", dev_df), ("test", test_df)):
    if split_df.empty:
        raise ValueError(f"{split_name} split is empty after stratified sampling; increase its fraction")

print("\nAfter stratified sampling:")
print("Train size:", len(train_df))
print("Dev size:", len(dev_df))
print("Test size:", len(test_df))
print("\nClass distribution in train:")
print(train_df["label"].value_counts())

train_texts, train_labels = _texts_and_labels(train_df)
dev_texts, dev_labels     = _texts_and_labels(dev_df)
test_texts, test_labels   = _texts_and_labels(test_df)


# 6) Tokenization

In [None]:
print("\nTokenizing train/dev/test...")
# No padding here: the DataCollatorWithPadding used for training pads per batch.
tokenize_kwargs = dict(truncation=True, padding=False, max_length=model_max_length)
enc_train, enc_dev, enc_test = (
    tokenizer(split_texts, **tokenize_kwargs)
    for split_texts in (train_texts, dev_texts, test_texts)
)

# 7) Datasets

In [None]:
def _to_dataset(encodings, labels):
    """Wrap tokenizer output (input_ids / attention_mask) and integer labels in a Dataset."""
    return Dataset.from_dict({
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "labels": labels,
    })


train_ds = _to_dataset(enc_train, train_labels)
dev_ds = _to_dataset(enc_dev, dev_labels)
test_ds = _to_dataset(enc_test, test_labels)

# The classification head gets `num_labels` outputs and the loss indexes them by label,
# so labels must be exactly 0..num_labels-1. Counting distinct values alone would
# mis-size the head for e.g. {1, 2} and fail later with an opaque index error.
label_ids = sorted(set(train_labels) | set(dev_labels) | set(test_labels))
if label_ids != list(range(len(label_ids))):
    raise ValueError(f"Labels must be contiguous integers starting at 0, got {label_ids}")
num_labels = len(label_ids)
print("num_labels:", num_labels)

# 8) Load base model

In [None]:
print("\nLoading model...")
# Head size comes from num_labels, computed from the data splits.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    num_labels=num_labels,
)
model.to(device)

# Echo the key run settings next to the model load.
for caption, value in (
    ("Num epochs:", num_train_epochs),
    ("Save strategy:", save_strategy),
    ("Load best model at end:", load_best_model_at_end),
    ("Train head only:", train_head_only),
    ("Use LoRA:", use_lora),
):
    print(caption, value)

# 9) Training strategy

In [None]:
def _freeze_all_but_head(model, head_substrings=("classifier", "score")):
    """Freeze every parameter of ``model``, then re-enable the classification head.

    A parameter is treated as part of the head when its name contains any of
    ``head_substrings``. Returns the names of the parameters left trainable.
    """
    for param in model.parameters():
        param.requires_grad = False
    unlocked = []
    for name, param in model.named_parameters():
        if any(s in name for s in head_substrings):
            param.requires_grad = True
            unlocked.append(name)
    return unlocked


if train_head_only and use_lora:
    raise ValueError("Choose only ONE strategy: head-only or LoRA (or both False for full FT)")

if train_head_only:
    print("\n[Head-only] Freezing backbone and leaving only classifier trainable...")
    unlocked = _freeze_all_but_head(model)
    if not unlocked:
        # With every parameter frozen there is nothing to train; stop here with an
        # actionable message rather than letting training fail later.
        raise ValueError("[Head-only] No classifier layer found (adjust head_substrings)")
    print("Reactivated classifier parameters:")
    for n in unlocked:
        print("  -", n)
    # Same summary the LoRA branch prints via print_trainable_parameters().
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {n_trainable:,} || all params: {n_total:,} || "
          f"trainable%: {100 * n_trainable / n_total:.4f}")
elif use_lora:
    print("\n[LoRA] Applying efficient adaptation...")
    lconf = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=[s.strip() for s in lora_target_modules.split(",") if s.strip()],
        lora_dropout=lora_dropout,
        bias="none",
        task_type="SEQ_CLS",
        inference_mode=False,
    )
    model = get_peft_model(model, lconf)
    model.print_trainable_parameters()
else:
    print("\n[Full fine-tuning] All parameters will be updated.")

# 10) Data collator and TrainingArguments

In [None]:
import inspect

# Pads each batch at collation time (tokenization used padding=False).
collator = DataCollatorWithPadding(tokenizer=tokenizer)

# transformers renamed TrainingArguments' `evaluation_strategy` to `eval_strategy`
# (deprecated in v4.41, later removed). Pass the configured value under whichever
# keyword the installed version accepts, so both old and new releases work.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir=output_dir,
    run_name=run_name,
    optim="adamw_torch",
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    num_train_epochs=num_train_epochs,
    fp16=fp16,
    save_strategy=save_strategy,
    eval_steps=eval_steps if evaluation_strategy == "steps" else None,
    warmup_steps=warmup_steps,
    weight_decay=weight_decay,
    learning_rate=learning_rate,
    save_total_limit=save_total_limit,
    load_best_model_at_end=load_best_model_at_end,
    dataloader_pin_memory=False,
    seed=seed,
    # Datasets already hold exactly the model inputs plus `labels`.
    remove_unused_columns=False,
    **{_eval_strategy_kw: evaluation_strategy},
)

# 11) Metrics function

In [None]:
def compute_metrics(eval_pred):
    """Convert Trainer evaluation output into classification metrics.

    ``eval_pred`` unpacks into ``(predictions, labels)``. When ``predictions`` is a
    tuple, its first element is used as the logits. The predicted class is the
    argmax over the last axis. F1, precision and recall are macro-averaged, and
    undefined ratios count as 0.
    """
    logits, labels = eval_pred
    logits = logits[0] if isinstance(logits, tuple) else logits
    preds = np.argmax(logits, axis=-1)

    macro = dict(average="macro", zero_division=0)
    metrics = {}
    metrics["accuracy"] = accuracy_score(labels, preds)
    metrics["f1"] = f1_score(labels, preds, **macro)
    metrics["matthews_correlation"] = matthews_corrcoef(labels, preds)
    metrics["precision"] = precision_score(labels, preds, **macro)
    metrics["recall"] = recall_score(labels, preds, **macro)
    return metrics

# 12) Trainer and training

In [None]:
import inspect

# transformers renamed Trainer's `tokenizer` argument to `processing_class` (v4.46);
# the old keyword is deprecated. Pass the tokenizer under whichever name the
# installed version accepts.
_tokenizer_kw = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=dev_ds,
    data_collator=collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kw: tokenizer},
)

print("\nStarting training...")
trainer.train()

# 13) Evaluate on test set

In [None]:
print("\nEvaluating on test set...")
# Runs the trainer's evaluation loop on the held-out test split.
results = trainer.evaluate(eval_dataset=test_ds)
results_json = json.dumps(results, indent=2)
print("\nMetrics (test):\n", results_json)

# 14) Save results and model (optional)

In [None]:
if save_results:
    # Metrics go to <output_dir>/results/<run_name>/eval_results.json
    results_path = os.path.join(output_dir, "results", run_name)
    os.makedirs(results_path, exist_ok=True)
    results_file = os.path.join(results_path, "eval_results.json")
    with open(results_file, "w") as fh:
        json.dump(results, fh, indent=2)
    print(f"\nResults saved at: {results_file}")

if save_model:
    final_model_dir = os.path.join(output_dir, "final_model")
    trainer.save_model(final_model_dir)
    print(f"Model saved at: {final_model_dir}")
