# Pretrain and Create Model for Classification Based Tasks

In [None]:
from stFormer.classifier.Classifier import GenericClassifier

## 1.1 Classify From Pretrained Model

We use the `subtype` annotation as the class label to demonstrate classification fine-tuning and evaluation. Training proceeds through the following steps:

1. **Data Loading & Splitting**  
   - Load `train_ds` from `dataset_path`.  
   - If `eval_dataset_path` is provided, load `eval_ds`;  
     otherwise do a `train_test_split(test_size, seed=42)`.

2. **`model_init` Function**  
   - Loads base model & config from `model_checkpoint`.  
   - Overrides `num_labels` to match `self.label_mapping`.  
   - Optionally freezes the first `self.freeze_layers` encoder layers.
   - Adds a classification head onto the pretrained BERT model if the checkpoint was trained with a masked-learning objective.

3. **Tokenizer & Data Collator**  
   - `AutoTokenizer.from_pretrained(...)` with `padding="max_length"`  
   - `DataCollatorWithPadding` to pad to `tokenizer.model_max_length`.

4. **Classification**
    - Evaluation metrics: computes loss and accuracy on the training and test splits.
    - Training arguments: takes a dictionary of BERT training arguments for hyperparameter selection and model updating.

5. **Best Checkpoint Selection and Saving**
    - Saves model checkpoints to the output directory according to the evaluation strategy.
    - Returns the final `trainer` and saves the final model to `output_directory`.
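
The evaluation-metrics step can be sketched as a function of the kind the Hugging Face `Trainer` accepts as `compute_metrics`. This is a minimal illustration, not stFormer's actual implementation: the function body, the macro-F1 metric, and the toy logits are assumptions added here; only accuracy is mentioned above.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Compute accuracy and macro-F1 from a (logits, labels) pair."""
    logits, labels = eval_pred
    labels = np.asarray(labels)
    # The predicted class is the index of the largest logit in each row
    preds = np.argmax(np.asarray(logits), axis=-1)
    accuracy = float((preds == labels).mean())

    # Macro-F1: unweighted mean of per-class F1 over classes present in labels
    f1_scores = []
    for cls in np.unique(labels):
        tp = int(np.sum((preds == cls) & (labels == cls)))
        fp = int(np.sum((preds == cls) & (labels != cls)))
        fn = int(np.sum((preds != cls) & (labels == cls)))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)

    return {"accuracy": accuracy, "macro_f1": float(np.mean(f1_scores))}


# Toy inputs (illustrative only): four samples, two classes
toy_logits = np.array([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [0.0, 1.0]])
toy_labels = np.array([0, 1, 1, 1])
example = compute_metrics((toy_logits, toy_labels))
print(example)
```

Reporting macro-F1 next to accuracy can be useful when subtypes are imbalanced, because accuracy alone can look high even if a rare class is never predicted.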


In [None]:
classifier = GenericClassifier(
    metadata_column = 'subtype',
    nproc=24)
    
ds_path, map_path = classifier.prepare_data(
    input_data_file = 'output/spot/visium_spot.dataset',
    output_directory = 'tmp',
    output_prefix = 'visium_spot'
    )

In this example we fine-tune the model that was pretrained with a masked-learning objective. This works, but we suggest starting from a BERT model that was already trained on a classification task and then fine-tuning it on the specific task.

In [None]:
trainer = classifier.train(
    model_checkpoint='output/spot/models/250422_102707_stFormer_L6_E3/final', # pretrained model path
    dataset_path = ds_path, # dataset path from prepare data
    output_directory = 'output/models/classification', #output evaluation 
    test_size=0.2, # splits dataset into test/train splits
    evaluation_dataset = None # set path to outside dataset for external validation instead of test/train
)
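
When no external evaluation dataset is supplied, `test_size` drives a seeded train/test split. The pure-Python sketch below only illustrates why a fixed seed makes the split reproducible. `split_indices` is a hypothetical helper, and its rounding rule and shuffling algorithm are assumptions that need not match what the `datasets` library does internally.

```python
import math
import random


def split_indices(n_rows, test_size=0.2, seed=42):
    """Return (train_indices, test_indices) for a reproducible random split."""
    indices = list(range(n_rows))
    # A dedicated Random instance keeps the shuffle independent of global state
    random.Random(seed).shuffle(indices)
    # Assumption: round the test share up to a whole number of rows
    n_test = math.ceil(n_rows * test_size)
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(10, test_size=0.2, seed=42)
print(len(train_idx), len(test_idx))
```

Calling the helper again with the same seed yields the same partition, which is what allows evaluation numbers from separate runs to be compared.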

## 1.2 Train and Evaluate Model with Hyperparameter search

In this example we pass a Ray configuration (`ray_config`) to search over a set of hyperparameters and find the best configuration for a classification task.

With `ray_config` set, training performs an end-to-end hyperparameter search for a sequence-classification head using Ray Tune and the Hugging Face `Trainer`:
1. **Define Hyperparameter Search Space**  
   - Pull ranges/choices from `self.ray_config` for  
     `learning_rate`, `num_train_epochs`, `weight_decay`, etc.  

2. **CLI Reporter**  
   - `CLIReporter` shows per-trial metrics (`eval_loss`, `eval_accuracy`)  
     and hyperparameter values in the console.

3. **Trainer & Hyperparameter Search**  
   - Instantiate `Trainer` with `model_init`, datasets, collator, and `compute_metrics`.  
   - Run `trainer.hyperparameter_search(...)` with Ray backend and `HyperOptSearch`.

4. **Best Checkpoint Selection & Saving**  
    - Use `ExperimentAnalysis` to find best trial/checkpoint by `eval_loss`.  
    - Load that checkpoint into a fresh `BertForSequenceClassification`.  
    - Save model & tokenizer under `output_directory/best_model`.

In [None]:
classifier = GenericClassifier(
    metadata_column = 'subtype',
    ray_config={
        "learning_rate":[1e-5,5e-5], #loguniform learning rate
        "num_train_epochs":[2,3], 
        "weight_decay": [0.0, 0.3], #tune.uniform across values
        'lr_scheduler_type': ["linear","cosine","polynomial"], #scheduler
        'seed':[0,100]
        },
    nproc = 24
    )
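
To make the search space concrete, the sketch below draws one trial's hyperparameters from the same `ray_config` values using only the standard library. It is an illustration, not how stFormer or Ray Tune implement sampling. The log-uniform and uniform treatments follow the inline comments above; treating `num_train_epochs` and `lr_scheduler_type` as categorical choices and `seed` as an integer range are assumptions, and `sample_trial` is a hypothetical helper.

```python
import math
import random

ray_config = {
    "learning_rate": [1e-5, 5e-5],
    "num_train_epochs": [2, 3],
    "weight_decay": [0.0, 0.3],
    "lr_scheduler_type": ["linear", "cosine", "polynomial"],
    "seed": [0, 100],
}


def sample_trial(config, rng):
    """Draw one hyperparameter set from the two-element ranges and choice lists."""
    lr_lo, lr_hi = config["learning_rate"]
    wd_lo, wd_hi = config["weight_decay"]
    seed_lo, seed_hi = config["seed"]
    return {
        # Log-uniform: sample uniformly in log space, then exponentiate
        "learning_rate": math.exp(rng.uniform(math.log(lr_lo), math.log(lr_hi))),
        # Assumption: epochs are picked from the listed values
        "num_train_epochs": rng.choice(config["num_train_epochs"]),
        # Uniform over the closed interval
        "weight_decay": rng.uniform(wd_lo, wd_hi),
        "lr_scheduler_type": rng.choice(config["lr_scheduler_type"]),
        # Assumption: seed is an integer drawn from the inclusive range
        "seed": rng.randint(seed_lo, seed_hi),
    }


rng = random.Random(0)
trial = sample_trial(ray_config, rng)
print(trial)
```

Sampling the learning rate in log space spreads trials evenly across orders of magnitude, which is why it is commonly preferred over a plain uniform draw for this parameter.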

In [None]:
ds_path, map_path = classifier.prepare_data(
    input_data_file = 'output/spot/visium_spot.dataset',
    output_directory = 'tmp',
    output_prefix = 'visium_spot'
    )

In [None]:
best_run = classifier.train(
    model_checkpoint='output/spot/models/250422_102707_stFormer_L6_E3/final',
    dataset_path = ds_path,
    output_directory = 'output/models/tuned_classification',
    n_trials=10,
    test_size=0.2,
    #stratify=True
)

## 1.3 Plot Predictions using Evaluation Utils

Use the true and predicted labels to build a confusion matrix, then plot it with seaborn.

In [None]:
from stFormer.classifier.Classifier import GenericClassifier
from datasets import load_from_disk
from sklearn.metrics import confusion_matrix
import pickle
import os

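# Produce and save raw predictions on a random sample of up to 1,000 rows.
# Note: ds_path is the full prepared dataset, so this sample can include rows seen during training.
eval_ds = load_from_disk(ds_path).shuffle(seed=42)
eval_ds = eval_ds.select(range(min(1000, len(eval_ds))))
preds = trainer.predict(eval_ds)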
y_true = preds.label_ids
y_pred = preds.predictions.argmax(-1)

with open("output/models/classification/predictions.pkl", "wb") as f:
    pickle.dump({"y_true": y_true, "y_pred": y_pred}, f)


# Load the label mapping saved by prepare_data().
# The inversion below assumes it maps class name -> integer label id.
with open(map_path, "rb") as f:
    id_map = pickle.load(f)

# Build the list of class names in label-index order
inv_map = {v: k for k, v in id_map.items()}
class_order = [inv_map[i] for i in range(len(inv_map))]

# Pass labels in index order so the matrix rows and columns line up with class_order
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_order))))

In [None]:
import matplotlib.pyplot as plt 
import seaborn as sns
plt.figure(figsize=(8, 8))
sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_order, yticklabels=class_order, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
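
Raw counts in a confusion matrix can be hard to compare when subtypes differ in size. As a hedged add-on to the plot above, the sketch below derives per-class recall from a matrix laid out like the one returned by `confusion_matrix` (rows are true labels, columns are predictions). `per_class_recall` is a hypothetical helper and the toy matrix is made up for illustration.

```python
import numpy as np


def per_class_recall(cm):
    """Return the fraction of each true class that was predicted correctly."""
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1)
    # Leave recall at 0.0 for classes with no true samples instead of dividing by zero
    return np.divide(
        np.diag(cm),
        row_totals,
        out=np.zeros_like(row_totals),
        where=row_totals > 0,
    )


# Toy matrix (illustrative only): the third class has no true samples
toy_cm = [[8, 2, 0],
          [1, 9, 0],
          [0, 0, 0]]
recall = per_class_recall(toy_cm)
print(recall)
```

With the real matrix, `per_class_recall(cm)` could be zipped with `class_order` to list recall by subtype name.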

In [None]:
# Save the heatmap using the built-in plotting function
classifier.plot_predictions(
    predictions_file="output/models/classification/predictions.pkl",
    id_class_dict_file=map_path,
    title="Visium Spot Subtype Predictions",
    output_directory="output/models/classification",
    output_prefix="visium_spot",
    class_order=class_order
)