# Curriculum Finetuning + UMAP Embedding Demo (Colab Friendly)

This notebook demonstrates a compact, end-to-end workflow:
1. Load the ESKAPE genomic features dataset.
2. Prepare labels for curricular classification (using `contig_id`).
3. Plot UMAP embeddings before and after a short finetuning run.

The demo uses small samples to keep runtime reasonable on Colab.


## Setup and Installation

If you are running in Google Colab, install the required packages below.
If you are running locally from a checked-out repo where `prokbert` is already available, you can skip this cell.


In [None]:
# Core dependencies for the demo
!pip -q install git+https://github.com/nbrg-ppcu/prokbert.git umap-learn seaborn datasets transformers


## Imports and Seed

We use a fixed seed (applied with `set_seed` when the dataset is loaded) for reproducibility, but the UMAP projection can still vary slightly across runs.


In [None]:
import os
import numpy as np
import torch
from datasets import ClassLabel, load_dataset
from sklearn.metrics import accuracy_score
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)

from prokbert.curriculum_utils import plot_umap_embeddings
from prokbert.models import ProkBertForCurricularClassification


## GPU Check (Colab)

To enable GPU in Colab: `Runtime` -> `Change runtime type` -> `GPU`.


In [None]:
# Report whether a CUDA device is visible and, if so, which one will be used.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
if cuda_ok:
    print("CUDA device:", torch.cuda.get_device_name(0))


## Configuration

Adjust these values to trade off speed vs. quality.


In [None]:
# Model and dataset identifiers, plus where plots and Trainer output are written.
model_name = "neuralbioinfo/prokbert-mini-long"
dataset_name = "neuralbioinfo/ESKAPE-genomic-features"
dataset_split = "ESKAPE"
output_dir = "./curriculum_umap_demo"

# Run-size knobs: smaller values run faster but give rougher embeddings.
seed = 42
max_samples = 4_000  # cap on total samples before the train/eval/test split
num_train_epochs = 1
train_batch_size = eval_batch_size = 16
max_length = 256  # tokenizer truncation length, in tokens


## Helper Functions

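These helpers detect the sequence and label columns, encode labels as integer class IDs, tokenize the sequences, and compute accuracy. They mirror the logic of the accompanying Python example script, written with a focus on clarity.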


In [None]:
def resolve_columns(dataset):
    """Pick the sequence and label column names present in ``dataset``.

    Candidates are tried in priority order; the first one found in
    ``dataset.column_names`` wins.

    Returns:
        A ``(sequence_col, label_col)`` tuple of column names.

    Raises:
        ValueError: if no sequence column or no label column is present
            (the sequence column is checked first).
    """
    available = set(dataset.column_names)

    def first_present(candidates):
        for name in candidates:
            if name in available:
                return name
        return None

    sequence_col = first_present(("segment", "sequence", "seq"))
    if sequence_col is None:
        raise ValueError("No sequence column found. Expected one of: segment, sequence, seq.")

    label_col = first_present(("contig_id", "class_label", "label", "labels", "y"))
    if label_col is None:
        raise ValueError("No label column found. Expected one of: contig_id, class_label, label, labels, y.")

    return sequence_col, label_col


def encode_labels(dataset, label_col):
    """Ensure ``label_col`` is a ``ClassLabel`` feature and build label maps.

    Non-``ClassLabel`` columns are converted with ``Dataset.class_encode_column``.

    Returns:
        ``(dataset, id2label, label2id)``: the (possibly re-encoded) dataset,
        a dict from class index to class name, and its inverse.
    """
    already_encoded = isinstance(dataset.features[label_col], ClassLabel)
    if not already_encoded:
        dataset = dataset.class_encode_column(label_col)

    class_names = dataset.features[label_col].names
    id2label = dict(enumerate(class_names))
    label2id = {label: idx for idx, label in id2label.items()}
    return dataset, id2label, label2id


def tokenize_dataset(dataset, tokenizer, sequence_col, max_length):
    """Tokenize ``sequence_col`` in batches and drop all other raw columns.

    Sequences are truncated to ``max_length`` tokens. The ``labels`` column is
    always carried over (it must exist). ``sequence_id`` is carried over when
    present. The map runs on up to 8 worker processes.
    """
    def _encode_batch(batch):
        encoded = tokenizer(
            batch[sequence_col],
            max_length=max_length,
            truncation=True,
        )
        encoded["labels"] = batch["labels"]
        if "sequence_id" in batch:
            encoded["sequence_id"] = batch["sequence_id"]
        return encoded

    passthrough = ("labels", "sequence_id")
    dropped = [name for name in dataset.column_names if name not in passthrough]
    workers = min(os.cpu_count() or 1, 8)
    return dataset.map(
        _encode_batch,
        batched=True,
        remove_columns=dropped,
        num_proc=workers,
    )


def compute_metrics(eval_pred):
    """Compute classification accuracy for the Hugging Face ``Trainer``.

    Args:
        eval_pred: an ``EvalPrediction`` (with ``predictions`` / ``label_ids``)
            or a plain ``(predictions, labels)`` pair.

    Returns:
        ``{"accuracy": float}``: the fraction of rows whose argmax prediction
        equals the label.
    """
    if hasattr(eval_pred, "predictions"):
        logits, labels = eval_pred.predictions, eval_pred.label_ids
    else:
        logits, labels = eval_pred[0], eval_pred[1]

    # When the model returns extra outputs besides logits (e.g. embeddings), the
    # Trainer hands them over as a tuple. Assumes the logits are the first
    # element, as in standard HF model outputs; confirm for custom heads.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    preds = np.argmax(np.asarray(logits), axis=-1)
    labels = np.asarray(labels)
    return {"accuracy": float(np.mean(preds == labels))}


## Load and Prepare the Dataset

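We shuffle and subsample the dataset, keep only the sequence and label columns, add a per-row `sequence_id` (carried through tokenization for the UMAP step), and integer-encode the labels into a `labels` column.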


In [None]:
set_seed(seed)

# Shuffle deterministically, then cap the number of rows for a fast demo.
dataset = load_dataset(dataset_name, split=dataset_split).shuffle(seed=seed)
if max_samples:
    n_keep = min(max_samples, len(dataset))
    dataset = dataset.select(range(n_keep))

# Drop every column except the detected sequence and label columns.
sequence_col, label_col = resolve_columns(dataset)
unused_columns = [name for name in dataset.column_names if name not in (sequence_col, label_col)]
dataset = dataset.remove_columns(unused_columns)

# Per-row index. tokenize_dataset keeps it, presumably so the UMAP helper can use it.
dataset = dataset.add_column("sequence_id", list(range(len(dataset))))

# Integer-encode the labels and expose them under the `labels` column name.
dataset, id2label, label2id = encode_labels(dataset, label_col)
if label_col != "labels":
    dataset = dataset.rename_column(label_col, "labels")

dataset


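## Train/Eval/Test Split

We hold out 20% of the samples and split that portion in half, giving an 80/10/10 train/eval/test split.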


In [None]:
# 80% goes to training; the held-out 20% is halved into eval (10%) and test (10%).
split = dataset.train_test_split(test_size=0.2, seed=seed)
train_ds = split["train"]

temp = split["test"].train_test_split(test_size=0.5, seed=seed)
eval_ds, test_ds = temp["train"], temp["test"]


## Tokenization


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Tokenize train, eval and test (in that order) with identical settings.
train_ds, eval_ds, test_ds = [
    tokenize_dataset(split_ds, tokenizer, sequence_col, max_length)
    for split_ds in (train_ds, eval_ds, test_ds)
]


## Model Setup and UMAP (Before Training)

We plot embeddings from a shuffled sample of the training set (up to 1,000 examples).


In [None]:
# `use_bf16` turns on bf16 *mixed precision* (autocast) in the Trainer when the GPU
# supports it. The weights themselves stay in float32. If the weights were loaded
# in bf16 and then finetuned, the optimizer would update pure-bf16 parameters.
# Steps of ~lr=5e-5 are smaller than bf16's spacing near typical weight
# magnitudes, so most of those updates would be rounded away.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
model_dtype = torch.float32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Head hyperparameters (curricular_face_m / curricular_face_s, presumably the
# CurricularFace margin and scale) are passed through to the ProkBERT model;
# see the prokbert documentation for their exact meaning.
model = ProkBertForCurricularClassification.from_pretrained(
    model_name,
    curricular_num_labels=len(id2label),
    curricular_face_m=0.5,
    curricular_face_s=64.0,
    classification_dropout_rate=0.1,
    curriculum_hidden_size=128,
    torch_dtype=model_dtype,
    id2label=id2label,
    label2id=label2id,
)
model = model.to(device)

# Pads each batch to a common length at collation time.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Baseline embedding plot before any finetuning, written under output_dir.
plot_umap_embeddings(
    model,
    train_ds,
    data_collator,
    output_dir,
    "umap_before_training.png",
    eval_batch_size,
    seed,
)


## Training

We run a short finetuning pass to keep this demo fast.


In [None]:
import inspect

# The notebook installs an unpinned `transformers`, so resolve renamed keyword
# arguments against the installed release instead of hard-coding one spelling.
_training_arg_names = set(inspect.signature(TrainingArguments.__init__).parameters)
_trainer_arg_names = set(inspect.signature(Trainer.__init__).parameters)

training_kwargs = {
    "output_dir": output_dir,
    "report_to": "none",
    "save_strategy": "no",
    "logging_steps": 25,
    "per_device_train_batch_size": train_batch_size,
    "per_device_eval_batch_size": eval_batch_size,
    "num_train_epochs": num_train_epochs,
    "load_best_model_at_end": False,
    "bf16": use_bf16,
    # Trainer re-seeds from `args.seed` (default 42) when constructed. Forward the
    # notebook's seed so that changing `seed` above actually affects training.
    "seed": seed,
}
# `evaluation_strategy` was renamed to `eval_strategy`; recent releases reject the old name.
if "eval_strategy" in _training_arg_names:
    training_kwargs["eval_strategy"] = "epoch"
else:
    training_kwargs["evaluation_strategy"] = "epoch"
# Only pass `overwrite_output_dir` if this release still accepts it.
if "overwrite_output_dir" in _training_arg_names:
    training_kwargs["overwrite_output_dir"] = True

training_args = TrainingArguments(**training_kwargs)

# Trainer's `tokenizer=` argument was superseded by `processing_class`.
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_arg_names else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

trainer.train()


## UMAP (After Training)


In [None]:
# Re-plot the training-set embeddings with the finetuned model so the result can
# be compared against umap_before_training.png. `model` is the same object that
# was handed to the Trainer above. Arguments are positional and follow the same
# order as the pre-training call.
plot_umap_embeddings(
    model,
    train_ds,
    data_collator,
    output_dir,
    "umap_after_training.png",
    eval_batch_size,
    seed,
)


## Visualize the Plots


In [None]:
from IPython.display import Image, display

# Show the saved UMAP plots in order: baseline first, then after finetuning.
for caption, png_name in (
    ("Before training", "umap_before_training.png"),
    ("After training", "umap_after_training.png"),
):
    print(caption)
    display(Image(filename=os.path.join(output_dir, png_name)))
