# Pretrain and Create Model for Classification Based Tasks

In [None]:
from stFormer.classifier.Classifier import GenericClassifier

## 1.1 Classify From Pretrained Model

We use the `subtype` metadata column as the label to demonstrate classification fine-tuning and evaluation.

1. Loads the model pretrained earlier with the masked learning objective
2. Adds a classification head onto the BERT model
3. Splits the data into train/test sets, with the option to evaluate on a separate dataset instead
4. Trains the classifier on the chosen metadata column
5. Evaluates on the held-out data
6. Lastly, offers an option for Ray Tune hyperparameter search to find the best model

In [None]:
# Classifier that uses the 'subtype' metadata column as its label.
# nproc is forwarded to GenericClassifier (presumably the number of worker
# processes -- confirm against the class documentation).
classifier = GenericClassifier(metadata_column="subtype", nproc=24)

# prepare_data() yields two values: the first is later passed to
# load_from_disk() and to train() as the dataset path, the second is later
# opened as a pickle holding the label mapping.
prepared_paths = classifier.prepare_data(
    input_data_file="output/spot/visium_spot.dataset",
    output_directory="tmp",
    output_prefix="visium_spot",
)
ds_path, map_path = prepared_paths

In this example we use the model that was trained with a masked learning objective. While this is certainly possible, we suggest starting from a BERT model that was already trained on a classification task and then fine-tuning it on your specific task.

In [None]:
# Checkpoint of the model pretrained with the masked learning objective.
pretrained_checkpoint = "output/spot/models/250422_102707_stFormer_L6_E3/final"
# Directory that receives the classification outputs.
classification_output_dir = "output/models/classification"

# Fine-tune for classification. The returned object is kept as `trainer`
# because section 1.3 calls trainer.predict() on it.
trainer = classifier.train(
    model_checkpoint=pretrained_checkpoint,
    dataset_path=ds_path,  # dataset path returned by prepare_data()
    output_directory=classification_output_dir,
    test_size=0.2,  # splits the dataset into test/train splits
    evaluation_dataset=None,  # set to the path of an outside dataset for external validation instead of test/train
)

## 1.2 Train and Evaluate Model with Hyperparameter search

In [None]:
# Hyperparameter search space handed to the classifier through `ray_config`.
# The per-key notes are the tutorial's own; how each list is actually sampled
# is decided inside GenericClassifier.
search_space = {
    "learning_rate": [1e-5, 5e-5],  # loguniform learning rate
    "num_train_epochs": [2, 3],  # choice
    "weight_decay": [0.0, 0.3],  # tune.uniform across values
    "lr_scheduler_type": ["linear", "cosine", "polynomial"],  # scheduler
    "seed": [0, 100],
}

# Same label column and process count as in section 1.1, now with a search
# space attached. This rebinds `classifier`.
classifier = GenericClassifier(
    metadata_column="subtype",
    ray_config=search_space,
    nproc=24,
)

In [None]:
# Re-run data preparation with the search-enabled classifier, using the same
# input file, output directory and prefix as in section 1.1.
# This rebinds ds_path and map_path.
prepared_paths = classifier.prepare_data(
    input_data_file="output/spot/visium_spot.dataset",
    output_directory="tmp",
    output_prefix="visium_spot",
)
ds_path, map_path = prepared_paths

In [None]:
# Checkpoint of the pretrained model (same one used in section 1.1).
pretrained_checkpoint = "output/spot/models/250422_102707_stFormer_L6_E3/final"
# Separate output directory so the tuned runs do not mix with section 1.1.
tuned_output_dir = "output/models/tuned_classification"

# Train with hyperparameter search; the result is stored as `best_run`.
# NOTE(review): the original cell carried a commented-out `stratify=True`
# argument -- confirm that train() accepts it before enabling.
best_run = classifier.train(
    model_checkpoint=pretrained_checkpoint,
    dataset_path=ds_path,
    output_directory=tuned_output_dir,
    n_trials=10,
    test_size=0.2,
)

## 1.3 Plot Predictions using Evaluation Utils

In [None]:
from stFormer.classifier.Classifier import GenericClassifier
from datasets import load_from_disk
from sklearn.metrics import confusion_matrix
import pickle
import os

# Produce & save raw predictions.
# `trainer` is the object returned by classifier.train() in section 1.1 and
# must still be in memory when this cell runs.
# NOTE(review): the sample below is drawn from the whole prepared dataset,
# which appears to include the rows section 1.1 trained on -- treat the
# resulting scores as a plotting demo, not a held-out estimate.
full_ds = load_from_disk(ds_path)
# Cap the sample at the dataset size: select() rejects out-of-range indices,
# so a fixed range(1000) fails on datasets with fewer than 1000 rows.
n_eval = min(1000, len(full_ds))
eval_ds = full_ds.shuffle(seed=42).select(range(n_eval))
preds = trainer.predict(eval_ds)
y_true = preds.label_ids
y_pred = preds.predictions.argmax(-1)  # highest-scoring class index per row

predictions_path = "output/models/classification/predictions.pkl"
# Create the target directory if it is missing so the dump cannot fail on it.
os.makedirs(os.path.dirname(predictions_path), exist_ok=True)
with open(predictions_path, "wb") as f:
    pickle.dump({"y_true": y_true, "y_pred": y_pred}, f)


# Load the label mapping written by prepare_data().
# The inversion below indexes the inverted map by integer position, so the
# stored mapping is treated as class name -> integer id.
with open(map_path, "rb") as f:
    id_map = pickle.load(f)

# We need a list of class names in label-index order:
inv_map = {v: k for k, v in id_map.items()}
class_order = [inv_map[i] for i in range(len(inv_map))]

# Order the matrix rows/columns by label index 0..n-1 so they line up with
# class_order (used as tick labels). Relying on id_map.values() would follow
# the dict's insertion order, which need not be ascending.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_order))))

In [None]:
import matplotlib.pyplot as plt 
import seaborn as sns
# Draw the confusion matrix as an annotated heatmap on a square 8x8 figure.
plt.figure(figsize=(8, 8))
heatmap_options = {
    "annot": True,  # print the count inside every cell
    "fmt": "d",  # integer formatting for the annotations
    "xticklabels": class_order,
    "yticklabels": class_order,
    "cmap": "Blues",
}
ax = sns.heatmap(cm, **heatmap_options)
# Columns are predictions, rows are ground truth.
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
plt.show()

In [None]:
#save heatmap with inbuilt plotting functionality
classifier.plot_predictions(
    predictions_file="output/models/classification/predictions.pkl",
    id_class_dict_file=map_path,
    title="Visium Spot Subtype Predictions",
    output_directory="output/models/classification",
    output_prefix="visium_spot",
    class_order=class_order
)