# Pretrain and Create Model for Classification Based Tasks

In [1]:
from stFormer.classifier.Classifier import Classifier

## 1.1 Classify From Pretrained Model

We extract subtype-based labels from the dataset and use them to fine-tune a classifier and evaluate it. Training proceeds in five steps:

1. **Data Loading & Splitting**  
   - Load `train_ds` from `dataset_path`.  
   - If `eval_dataset_path` provided, load `eval_ds`;  
     otherwise do a `train_test_split(test_size, seed=42)`.

2. **`model_init` Function**  
   - Loads base model & config from `model_checkpoint`.  
   - Overrides `num_labels` to match `self.label_mapping`.  
   - Optionally freezes the first `self.freeze_layers` encoder layers.
   - Adds a classification head to the pretrained BERT model when the checkpoint was trained with a masked learning objective.

3. **Tokenizer & Data Collator**  
   - `AutoTokenizer.from_pretrained(...)` with `padding="max_length"`  
   - `DataCollatorWithPadding` to pad to `tokenizer.model_max_length`.

4. **Classification**  
   - **Evaluation metrics**: computes loss and accuracy on the training and evaluation splits.  
   - **Training arguments**: takes a dictionary of BERT training arguments that sets the hyperparameters and controls how the model is updated.

5. **Best Checkpoint Selection and Saving**  
   - Saves model checkpoints to the output directory based on `evaluation_strategy`.  
   - Returns the final `trainer` and saves the final model to `output_directory`.
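
The data-loading logic in step 1 can be sketched without any library. This is only an illustration under stated assumptions: `split_train_eval` is a hypothetical helper, not part of stFormer, and the real code splits a Hugging Face dataset rather than a Python list.

```python
import random

def split_train_eval(train_rows, eval_rows=None, test_size=0.2, seed=42):
    """Return (train, eval): use eval_rows if given, else a seeded random split.

    Hypothetical helper that mirrors step 1 above; not the stFormer code.
    """
    if eval_rows is not None:
        return list(train_rows), list(eval_rows)
    rows = list(train_rows)
    rng = random.Random(seed)          # fixed seed gives the same split every run
    rng.shuffle(rows)
    n_eval = int(round(len(rows) * test_size))
    return rows[n_eval:], rows[:n_eval]

rows = [{"id": i, "label": i % 3} for i in range(10)]
train, held_out = split_train_eval(rows, test_size=0.2)
print(len(train), len(held_out))  # 8 2
```

Because the seed is fixed, repeated runs evaluate on the same held-out examples, which keeps metrics comparable across experiments.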


In [None]:
classifier = Classifier(
    metadata_column = 'Tissue',
    mode='spot',
    classifier_type = 'sequence', #for class predictions
    token_dictionary_file='output/spot/token_dictionary.pickle',
    rare_threshold=0.1, #remove rare data types (less than 10% of samples)
    max_examples_per_class=10000, #option to downsample
    nproc=24,
)
ds_path, map_path = classifier.prepare_data(
    input_data = 'annotated.dataset/',
    output_directory = 'tmp/clasifier',
    output_prefix = 'Tissue_Classifier',
    )
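
The `rare_threshold` and `max_examples_per_class` options can be pictured with the sketch below. It assumes that a class is dropped when its share of samples is strictly below the threshold and that oversized classes are randomly subsampled; `filter_and_downsample` is a hypothetical helper, and the library's exact rules may differ.

```python
import random
from collections import Counter, defaultdict

def filter_and_downsample(labels, rare_threshold=0.1, max_per_class=10000, seed=42):
    """Return sorted indices kept after dropping rare classes and capping class size."""
    counts = Counter(labels)
    total = len(labels)
    # Assumption: keep a class only if it covers at least rare_threshold of all samples
    keep_classes = {c for c, n in counts.items() if n / total >= rare_threshold}
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        if lab in keep_classes:
            by_class[lab].append(idx)
    rng = random.Random(seed)
    kept = []
    for idxs in by_class.values():
        if len(idxs) > max_per_class:
            idxs = rng.sample(idxs, max_per_class)  # random downsampling without replacement
        kept.extend(idxs)
    return sorted(kept)

# Toy labels: "Skin" covers 5% of samples and falls below the 10% threshold
labels = ["Brain"] * 50 + ["Breast"] * 45 + ["Skin"] * 5
kept = filter_and_downsample(labels, rare_threshold=0.1, max_per_class=20)
print(Counter(labels[i] for i in kept))  # 20 Brain, 20 Breast, no Skin
```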

In this example we fine-tune the model that was pretrained with a masked learning objective. This works, but we suggest starting from a BERT model that was already trained on a classification task and then fine-tuning it on the specific task.

In [None]:
trainer = classifier.train(
    model_checkpoint='output/spot/spot_model', # pretrained model path
    dataset_path = ds_path, # dataset path from prepare data
    output_directory = 'output/models/classification', #output evaluation 
    test_size=0.2, # splits dataset into test/train splits
)

Using model checkpoint: output/spot/spot_model
Number of labels from data: 3
Label mapping: {'Brain': 0, 'Breast': 1, 'Skin': 2}
Linear(in_features=256, out_features=3, bias=True)
Max label: tensor(2)
Label mapping size: 3
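
The printed label mapping is consistent with assigning integer ids to the class names in alphabetical order. The sketch below shows that idea; sorting is an assumption about how the mapping is built, and `build_label_mapping` is a hypothetical helper.

```python
def build_label_mapping(labels):
    """Map each distinct class name to an integer id, in sorted order (assumed)."""
    classes = sorted(set(labels))
    return {name: idx for idx, name in enumerate(classes)}

tissues = ["Skin", "Brain", "Breast", "Brain"]
mapping = build_label_mapping(tissues)
encoded = [mapping[t] for t in tissues]
num_labels = len(mapping)

print(mapping)      # {'Brain': 0, 'Breast': 1, 'Skin': 2}
print(encoded)      # [2, 0, 1, 0]
# Every encoded label must be a valid output index of the classification head
print(max(encoded) < num_labels)  # True
```

The number of labels determines the output size of the classification head, which is why the log shows `out_features=3` and a maximum label of 2.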


| Step | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 |
|------|---------------|-----------------|----------|-----------|--------|----|
| 100 | 1.0837 | 1.001198 | 0.539333 | 0.416546 | 0.536904 | 0.43858 |
| 200 | 0.7867 | 0.569715 | 0.796833 | 0.823148 | 0.796179 | 0.780779 |
| 300 | 0.6895 | 0.648199 | 0.687167 | 0.729824 | 0.686503 | 0.641001 |
| 400 | 0.5867 | 0.423288 | 0.868833 | 0.878322 | 0.868952 | 0.868393 |
| 500 | 0.4903 | 0.976297 | 0.581667 | 0.611228 | 0.579323 | 0.547443 |
| 600 | 0.5053 | 0.524849 | 0.801667 | 0.823287 | 0.800729 | 0.798036 |
| 700 | 0.3652 | 0.249833 | 0.940833 | 0.942768 | 0.940911 | 0.940833 |
| 800 | 0.33 | 0.908484 | 0.667667 | 0.694001 | 0.66742 | 0.620839 |
| 900 | 0.2918 | 0.830796 | 0.706167 | 0.740024 | 0.704866 | 0.680967 |
| 1000 | 0.248 | 0.150817 | 0.967167 | 0.967391 | 0.967184 | 0.967184 |
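
The validation loss in this log fluctuates strongly between evaluations, so the checkpoint that is kept matters. As a small sketch, the best evaluation step can be picked from the logged values like this (the tuples are copied from the log above; choosing by lowest validation loss is an assumption about the selection criterion):

```python
# (step, validation_loss, accuracy) copied from the training log above
log = [
    (100, 1.001198, 0.539333),
    (200, 0.569715, 0.796833),
    (300, 0.648199, 0.687167),
    (400, 0.423288, 0.868833),
    (500, 0.976297, 0.581667),
    (600, 0.524849, 0.801667),
    (700, 0.249833, 0.940833),
    (800, 0.908484, 0.667667),
    (900, 0.830796, 0.706167),
    (1000, 0.150817, 0.967167),
]

# Select the evaluation step with the lowest validation loss
best_step, best_loss, best_acc = min(log, key=lambda row: row[1])
print(best_step, best_loss, best_acc)  # 1000 0.150817 0.967167
```

With the Hugging Face `Trainer`, the same effect is usually obtained through the `load_best_model_at_end` and `metric_for_best_model` training arguments, provided they are passed through to the trainer.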


## 1.2 Train Gene Classifier

In [2]:
from stFormer.classifier.Classifier import Classifier
import pandas as pd
import pickle
import numpy as np
import os

We replicate the publication analysis by training a classifier to predict responsive genes in TNBC:

1. Load genes upregulated in response to neoadjuvant care in TNBC.
2. Load a list of randomly shuffled genes as background.
3. Load the Ensembl-to-gene-name mapping dictionary.
4. Create a dictionary of responder and random genes.
5. Run gene classification to predict responsive genes in the dataset.
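
The dictionary-building steps above can be sketched as follows. All gene names and identifiers here are toy placeholders, `to_ensembl` is a hypothetical helper, and it is an assumption that gene lists need converting to Ensembl IDs and that a gene should belong to only one class.

```python
def to_ensembl(gene_names, name_to_ensembl):
    """Convert gene names to Ensembl IDs; also return the names that have no mapping."""
    mapped = [name_to_ensembl[g] for g in gene_names if g in name_to_ensembl]
    missing = [g for g in gene_names if g not in name_to_ensembl]
    return mapped, missing

# Toy placeholder mapping (not real genes or real Ensembl IDs)
name_to_ensembl = {
    "GENE_A": "ENSG_TOY_001",
    "GENE_B": "ENSG_TOY_002",
    "GENE_C": "ENSG_TOY_003",
    "GENE_D": "ENSG_TOY_004",
}
responder_names = ["GENE_A", "GENE_B", "GENE_X"]   # GENE_X has no mapping
random_names = ["GENE_B", "GENE_C", "GENE_D"]      # GENE_B overlaps with responders

responder_ids, unmapped = to_ensembl(responder_names, name_to_ensembl)
random_ids, _ = to_ensembl(random_names, name_to_ensembl)

# Keep the two classes disjoint so each gene gets a single label
responder_set = set(responder_ids)
random_ids = [g for g in random_ids if g not in responder_set]

gene_class_dict = {"Responder": responder_ids, "Random.genes": random_ids}
print(gene_class_dict)
print("unmapped:", unmapped)  # unmapped: ['GENE_X']
```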

In [4]:
os.chdir('analyses/models.to.test/Extended.model/')

In [5]:
file1 = "upregulated.top300"
file2 = "gene.shuffled.upregulated"
genes_responder = list(np.loadtxt(file1,dtype=str))
genes_random = list(np.loadtxt(file2, dtype=str))

training_args = {"num_train_epochs": 30.0, "weight_decay": 0.25, "learning_rate": 3e-6, "warmup_steps":1500, "lr_scheduler_type": "polynomial"}

### GeneClassifier Token Classification Overview

Instead of predicting a per-cell subtype label, this mode classifies **individual genes (tokens)** within each sequence.

1. **Data Loading & Splitting**  
   - Load `train_ds` from `dataset_path`.  
   - If `eval_dataset_path` is provided, load `eval_ds`.  
     Otherwise, perform `train_test_split(test_size, seed=42)`.

2. **Label Mapping (Gene Classes)**  
   - Use `classifier_utils.label_classes("gene", ...)` to map each **input token (gene)** to a class label.  
   - Generates a per-token `labels` field matching `input_ids` shape.

3. **Tokenizer & Data Collator**   
   - Uses `DataCollatorForGeneClassification` to pad both `input_ids` and `labels` in sync.

4. **Model Initialization**  
   - Loads base model & config from `model_checkpoint`.  
   - Creates a `TokenClassification` head on the pretrained model.

5. **Classification Training**  
   - Computes token-level metrics (e.g., F1 score, accuracy).  

6. **Best Checkpoint Selection and Saving**  
   - Saves model checkpoints to output directory based on `evaluation_strategy`.  
   - Final model and tokenizer are saved to `output_directory`.  
   - Predictions and evaluation metrics are returned for downstream analysis.
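
Steps 2 and 3 above (per-token labels, then padding `input_ids` and `labels` in sync) can be sketched as below. This is a library-free illustration, not the `DataCollatorForGeneClassification` implementation: the helper names, the pad token id of 0, and the use of -100 for unlabeled and padded positions are assumptions. The value -100 is the default `ignore_index` of PyTorch's cross-entropy loss, so such positions would not contribute to the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def label_tokens(input_ids, token_to_class):
    """Give each token its class id, or IGNORE_INDEX if the gene is in no class."""
    return [token_to_class.get(tok, IGNORE_INDEX) for tok in input_ids]

def pad_batch(examples, pad_token_id=0):
    """Pad input_ids and labels to the same length so they stay aligned."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch = {"input_ids": [], "labels": [], "attention_mask": []}
    for ex in examples:
        n_pad = max_len - len(ex["input_ids"])
        batch["input_ids"].append(ex["input_ids"] + [pad_token_id] * n_pad)
        batch["labels"].append(ex["labels"] + [IGNORE_INDEX] * n_pad)
        batch["attention_mask"].append([1] * len(ex["input_ids"]) + [0] * n_pad)
    return batch

# Toy token ids: 11 and 12 belong to class 0, 21 belongs to class 1
token_to_class = {11: 0, 12: 0, 21: 1}
cells = [[11, 5, 21], [12, 7]]
examples = [{"input_ids": ids, "labels": label_tokens(ids, token_to_class)} for ids in cells]
batch = pad_batch(examples)

print(batch["input_ids"])  # [[11, 5, 21], [12, 7, 0]]
print(batch["labels"])     # [[0, -100, 1], [0, -100, -100]]
```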

In [20]:

ray_config = {"num_train_epochs": [1.0,],
"learning_rate": (1e-3, 1e-2),
"weight_decay": (0.01, 0.05),
"lr_scheduler_type": ["linear", "cosine", "polynomial"],
"warmup_steps": (5, 50),
"seed": (100, 1000),
"per_device_train_batch_size": [10,],
}


In [None]:
# 1) Map each class name to its list of genes
gene_class_dict = {'Responder': genes_responder,'Random.genes': genes_random}

# 2) Instantiate for token-classification
gene_classifier = Classifier(
    metadata_column=None,             # no sequence-level label
    mode='extended',
    gene_class_dict=gene_class_dict, #specify this dictionary to use Gene Classifier
    classifier_type = 'gene',
    freeze_layers=4,                  # freeze the first four BERT layers (optional)
    forward_batch_size=50,
    max_examples=10_000,
    nproc=16,
    token_dictionary_file='SpatialModel/new_token_dictionary.pickle'
)

# 3) Prepare your dataset (must already contain `input_ids` for each cell)
ds_path, map_path = gene_classifier.prepare_data(
    input_data="STFormer_TNBC_neighbor.dataset",
    output_directory="tmp/gene_classifier",
    output_prefix="gene_classifier"
)


In [None]:
trainer = gene_classifier.train(
    model_checkpoint="run-8eb93bdf/checkpoint-1000",
    dataset_path=ds_path,
    output_directory="tmp/gene_classifier",
)

In [None]:
metrics = gene_classifier.evaluate(
    model_directory="models/visium_gene_classifier/final_model",
    eval_dataset_path=ds_path,
    id_class_dict_file=map_path,
    output_directory="output/gene_classifier",
)

## 1.3 Train and Evaluate Model with Hyperparameter Search

In this example we use a Ray configuration to search over hyperparameter ranges and find the best combination of arguments for a classification task.

The workflow performs an end-to-end hyperparameter search for a sequence-classification head using Ray Tune and the Hugging Face `Trainer`:

1. **Define Hyperparameter Search Space**  
   - Pull ranges/choices from `self.ray_config` for  
     `learning_rate`, `num_train_epochs`, `weight_decay`, etc.  

2. **CLI Reporter**  
   - `CLIReporter` shows per-trial metrics (`eval_loss`, `eval_accuracy`)  
     and hyperparameter values in the console.

3. **Trainer & Hyperparameter Search**  
   - Instantiate `Trainer` with `model_init`, datasets, collator, and `compute_metrics`.  
   - Run `trainer.hyperparameter_search(...)` with Ray backend and `HyperOptSearch`.

4. **Best Checkpoint Selection & Saving**  
    - Use `ExperimentAnalysis` to find best trial/checkpoint by `eval_loss`.  
    - Load that checkpoint into a fresh `BertForSequenceClassification`.  
    - Save model & tokenizer under `output_directory/best_model`.
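
The search loop can be pictured with the sketch below. It is not the Ray Tune or HyperOpt code path: plain random sampling stands in for `HyperOptSearch`, `toy_eval_loss` is a made-up objective rather than a training run, and reading tuples as ranges and lists as discrete choices is an assumption about the configuration format.

```python
import math
import random

# Assumed convention: tuple = (low, high) range, list = discrete choices
search_space = {
    "learning_rate": (1e-5, 1e-3),
    "weight_decay": (0.0, 0.3),
    "lr_scheduler_type": ["linear", "cosine", "polynomial"],
}

def sample_config(space, rng):
    """Draw one hyperparameter configuration from the search space."""
    config = {}
    for name, spec in space.items():
        if isinstance(spec, tuple):
            low, high = spec
            config[name] = rng.uniform(low, high)
        else:
            config[name] = rng.choice(spec)
    return config

def toy_eval_loss(config):
    """Stand-in objective for illustration only; a real trial would train a model."""
    return abs(math.log10(config["learning_rate"]) + 4) + config["weight_decay"]

rng = random.Random(0)
trials = [sample_config(search_space, rng) for _ in range(4)]   # 4 trials
results = [(toy_eval_loss(cfg), cfg) for cfg in trials]

# Keep the trial with the lowest evaluation loss, as in step 4 above
best_loss, best_config = min(results, key=lambda r: r[0])
print(best_loss, best_config)
```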

Load Datasets and Test/Train Split

In [1]:
from stFormer.classifier.Classifier import Classifier
import stFormer.classifier.classifier_utils as cu
from datasets import load_from_disk
import pandas as pd
import numpy as np
import random

random.seed(123)

In [None]:

ds = load_from_disk('annotated.dataset')
#ds_filt = cu.remove_rare(ds,rare_threshold=0.05,nproc=24,state_key='Tissue')

train1 = pd.read_csv('data/train1.csv').dropna()
test1 = pd.read_csv('data/test1.csv').dropna()
train2 = pd.read_csv('data/train2.csv').dropna()
test2 = pd.read_csv('data/test2.csv').dropna()

# Collect the sample IDs of each split as sets for fast membership tests
train_samples = set(train1['Sample'].unique())
test_samples = set(test1['Sample'].unique())
ds_train = ds.filter(lambda ex: ex['Sample ID'] in train_samples, num_proc=24)
ds_test = ds.filter(lambda ex: ex['Sample ID'] in test_samples, num_proc=24)
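
Splitting by sample rather than by individual spot keeps all spots from one sample on the same side of the split, which is a common way to avoid leakage between training and evaluation data. The sketch below illustrates the idea on toy records; the sample IDs and tissue labels are placeholders.

```python
# Toy records: several spots per sample (placeholder values)
records = [
    {"Sample ID": "S1", "Tissue": "Brain"},
    {"Sample ID": "S1", "Tissue": "Brain"},
    {"Sample ID": "S2", "Tissue": "Skin"},
    {"Sample ID": "S3", "Tissue": "Breast"},
    {"Sample ID": "S3", "Tissue": "Breast"},
]

train_samples = {"S1", "S2"}
test_samples = {"S3"}

# A sample must never appear in both splits
assert train_samples.isdisjoint(test_samples)

train_rows = [r for r in records if r["Sample ID"] in train_samples]
test_rows = [r for r in records if r["Sample ID"] in test_samples]
print(len(train_rows), len(test_rows))  # 3 2
```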



Set up the hyperparameters and classification information, then prepare the dataset for classification. For more hyperparameter options, please visit our docs: <https://cancerstformer.readthedocs.io/en/latest/>

In [None]:

hyperparameters ={
    "learning_rate":[1e-5,1e-3],
    "weight_decay": [0.0, 0.3],
    "warmup_ratio": [0,0.3]
    #'lr_scheduler_type': ["linear","cosine","polynomial"], 
    #'per_device_train_batch_size': [32]
    }

classifier = Classifier(
    metadata_column = 'Tissue',
    mode='spot',
    ray_config = hyperparameters,
    token_dictionary_file='output/spot/token_dictionary.pickle',
    nproc=24,
)

In [None]:
ds_path, map_path = classifier.prepare_data(
    input_data = ds_train, #takes Dataset Object or dataset file path
    output_directory = 'tmp_eval', #filtered dataset out location
    output_prefix = 'train_tissue' 
)

# Prepare the evaluation data exactly the same way
eval_ds_path, eval_map_path = classifier.prepare_data(
    input_data       = ds_test,
    output_directory = 'tmp_eval',
    output_prefix    = 'eval_tissue')

Train the model

In [None]:
trainer = classifier.train(
    model_checkpoint='output/models/tissue_classification/best_model', # pretrained model path
    dataset_path = ds_path, # dataset path from prepare data
    output_directory = 'output/eval_tissue_nohyperopt', #output evaluation 
    eval_dataset = eval_ds_path,
    n_trials = 4
    )


## 1.4 Plot Predictions using Evaluation Utils

Use the true and predicted labels to build a confusion matrix, then plot the results with seaborn.

In [1]:
from stFormer.classifier.Classifier import Classifier
from datasets import load_from_disk
from sklearn.metrics import confusion_matrix
import pickle
import os

In [None]:

#Produce & save raw predictions
eval_ds = load_from_disk(ds_path).shuffle(seed=42).select(range(1000))
preds = trainer.predict(eval_ds)
y_true = preds.label_ids
y_pred = preds.predictions.argmax(-1)

with open("output/models/classification/predictions.pkl", "wb") as f:
    pickle.dump({"y_true": y_true, "y_pred": y_pred}, f)

In [13]:
import numpy as np
with open('output/eval_tissue_nohyperopt/predictions.pkl','rb') as f:
    preds = pickle.load(f)
y_true = preds.label_ids
y_pred = preds.predictions.argmax(-1)

map_path = 'tmp_eval/train_tissue_id_class_dict.pkl'

In [14]:

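# Load the label mapping written by prepare_data().
# The code below assumes it maps class name -> integer label id.
with open(map_path, "rb") as f:
    id_map = pickle.load(f)

# Build the list of class names in label-index order
inv_map = {v: k for k, v in id_map.items()}
class_order = [inv_map[i] for i in range(len(inv_map))]

# Pass labels in the same index order as class_order so the matrix rows
# and columns line up with the tick labels used when plotting
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_order))))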
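
To make the layout of the matrix concrete, the sketch below counts a confusion matrix by hand on toy labels. It follows the same convention as scikit-learn's `confusion_matrix`: rows are true classes and columns are predicted classes. The helper `confusion_counts` and the toy labels are illustrative only.

```python
def confusion_counts(y_true, y_pred, n_classes):
    """Count matrix with cm[i][j] = number of examples of true class i predicted as j."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

# Toy labels for three classes
y_true_toy = [0, 0, 1, 1, 2, 2]
y_pred_toy = [0, 1, 1, 1, 2, 0]
cm_toy = confusion_counts(y_true_toy, y_pred_toy, 3)
print(cm_toy)  # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]

# Per-class recall is the diagonal divided by the row sums
recall = [row[i] / sum(row) for i, row in enumerate(cm_toy)]
print(recall)  # [0.5, 1.0, 0.5]
```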

In [None]:
import matplotlib.pyplot as plt 
import seaborn as sns
plt.figure(figsize=(8, 8))
sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_order, yticklabels=class_order, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

In [None]:
#save heatmap with inbuilt plotting functionality
classifier.plot_predictions(
    predictions_file="output/models/classification/predictions.pkl",
    id_class_dict_file=map_path,
    title="Visium Spot Subtype Predictions",
    output_directory="output/models/classification",
    output_prefix="visium_spot",
    class_order=class_order
)