# Pretrain on a Hugging Face Dataset Using a Masked Learning Objective
Work through the dataset tokenization tutorial before running this code. It produces the prerequisites you need:
1. token_dictionary
2. Hugging Face Tokenized Dataset
3. Example Lengths

If you are using one of our pretrained models, this step is unnecessary; you only need to tokenize your own h5ad or loom file.

In [None]:
from stFormer.pretrain.stFormer_pretrainer import run_pretraining

## 1.1 Create Example Lengths File

This file records the number of tokens (genes) in each spot of your tokenized data. No value should exceed the `max_length` set during tokenization, because longer sequences are truncated to that length.

In [None]:
from datasets import load_from_disk
import pickle
ds = load_from_disk('output/spot/visium_spot.dataset')
lengths = [len(example['input_ids']) for example in ds]
with open('output/example_lengths.pickle','wb') as file:
    pickle.dump(lengths,file)
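
Before pretraining, it can help to sanity-check the lengths against the truncation limit. The following is a minimal sketch, not part of stFormer: `summarize_lengths` is a hypothetical helper, and the toy list and the `max_length=2048` value are assumptions standing in for your own pickled lengths and tokenization setting.

```python
def summarize_lengths(lengths, max_length):
    """Return basic statistics for a list of per-example token counts."""
    if not lengths:
        raise ValueError("lengths is empty")
    return {
        "n_examples": len(lengths),
        "min": min(lengths),
        "max": max(lengths),
        "mean": sum(lengths) / len(lengths),
        # Examples longer than max_length suggest truncation was not applied.
        "n_over_max": sum(1 for n in lengths if n > max_length),
    }


# Toy values (assumed) standing in for the pickled list written above.
stats = summarize_lengths([5, 12, 2048, 700], max_length=2048)
print(stats)
```

A non-zero `n_over_max` would indicate that the dataset was tokenized with a different limit than the one you pass here.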

## 1.2 Run Pretraining using BERT Framework


1. **Masking Objective**  
   - Randomly mask out a fraction of tokens (genes) in each sequence.  
   - Task: Predict the original gene ID at each masked position.  
   - Loss: Cross-entropy between predicted token distribution and true gene ID.


2. **Benefits**  
   - Learns rich, unsupervised representations of spatial gene expression patterns.  
   - In neighbor mode, captures co-expression and neighborhood relationships without labels.  
   - Provides strong initialization for downstream tasks (e.g., cell-type classification).

3. **Configuration**  
   - A standard BERT-style architecture (hidden size, layers, heads, etc.).  
   - Dropout, layer-norm, and positional embeddings adapted for gene sequences.  
   - Grouped batching by sequence length for efficient GPU utilization.

4. **Training Loop**  
   - Iterates over masked sequences, computing MLM loss.  
   - Periodic checkpointing of model weights.  
   - Final model and tokenizer saved for later fine-tuning or inference.
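
The masking objective can be illustrated with a small, library-free sketch. It is a simplification and not necessarily what `run_pretraining` does internally: the function name `mask_tokens`, the `mask_token_id` value, and the masking probability are assumptions chosen for illustration. The original BERT recipe also replaces some selected tokens with random tokens or leaves them unchanged, which this sketch omits. The label value `-100` is the index that PyTorch's cross-entropy loss ignores by default, so only masked positions contribute to the loss.

```python
import random

IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def mask_tokens(input_ids, mask_token_id, mask_prob=0.15, rng=None):
    """Return (masked_ids, labels) for one token sequence.

    Each position is masked independently with probability mask_prob.
    Labels hold the original token at masked positions and IGNORE_INDEX
    everywhere else.
    """
    if rng is None:
        rng = random.Random()
    masked_ids, labels = [], []
    for token_id in input_ids:
        if rng.random() < mask_prob:
            masked_ids.append(mask_token_id)  # hide the gene token
            labels.append(token_id)           # the model must predict it
        else:
            masked_ids.append(token_id)       # token left visible
            labels.append(IGNORE_INDEX)       # no loss at this position
    return masked_ids, labels


# Toy gene-token sequence; mask_token_id=1 is an assumed placeholder value.
ids = [17, 42, 8, 301, 56, 9, 77, 120]
masked, labels = mask_tokens(ids, mask_token_id=1, mask_prob=0.5,
                             rng=random.Random(0))
print(masked)
print(labels)
```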

In [None]:
run_pretraining(
   dataset_path='output/spot/visium_spot.dataset',
   token_dict_path='output/token_dictionary.pickle',
   example_lengths_path='output/example_lengths.pickle',
   rootdir='output/spot',
   seed=42,
   num_layers=6,
   num_heads=4,
   embed_dim=256,
   max_input=2048,
   batch_size=12,
   lr=1e-3,
   warmup_steps=10000,
   epochs=3,
   weight_decay=0.001
)
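
The example lengths file is presumably what allows batches to be grouped by sequence length, as listed under Configuration. The sketch below shows one common way to do this and is an assumption about the general technique, not a copy of stFormer's sampler: `length_grouped_batches` and `chunk_factor` are hypothetical names. Indices are shuffled, split into large chunks, sorted by length within each chunk, and then cut into batches, so that sequences in a batch need little padding while batch composition still varies with the seed.

```python
import random


def length_grouped_batches(lengths, batch_size, chunk_factor=50, seed=42):
    """Group example indices into batches of similar sequence length."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)  # randomize which examples share a chunk

    chunk_size = batch_size * chunk_factor
    batches = []
    for start in range(0, len(indices), chunk_size):
        chunk = indices[start:start + chunk_size]
        # Sort the chunk from longest to shortest so neighbours are similar.
        chunk.sort(key=lambda i: lengths[i], reverse=True)
        for b in range(0, len(chunk), batch_size):
            batches.append(chunk[b:b + batch_size])
    return batches


# Toy lengths (assumed); in practice load them from example_lengths.pickle.
toy_lengths = [10, 2000, 15, 1990, 12, 2048, 9, 1800]
batches = length_grouped_batches(toy_lengths, batch_size=4)
print(batches)
```

In this toy run the four long sequences end up in one batch and the four short ones in the other, so neither batch is padded up to the length of an outlier.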

## 1.3 Extract Embeddings
This module provides helper functions and a high-level class for turning
tokenized gene sequences into fixed-size embedding vectors using a
pretrained transformer model.


The `EmbeddingExtractor` class encapsulates the end-to-end process of:
1. **Loading**  
   - A pretrained model from `model_directory` (with `output_hidden_states=True`)  
   - A Hugging Face dataset stored on disk
   - A token dictionary (gene ↔ token ID mapping)
2. **Batching**  
   - Iterating in chunks of `forward_batch_size`  
   - Extracting `input_ids` and their lengths  
3. **Preprocessing**  
   - Applying `pad_sequences` to each batch  
   - Generating the `attention_mask`
4. **Model Forward Pass**  
   - Running the model in `eval()` mode without gradients  
   - Gathering all hidden states
6. **Saving**  
   - Concatenate all batch embeddings into one tensor of shape `(N, hidden_dim)`  
   - Write to disk as `output_prefix + ".pt"`
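
The batching and preprocessing steps can be sketched in plain Python. This is an illustrative stand-in, not stFormer's own `pad_sequences`, whose signature is not shown here: `iter_batches`, `pad_batch`, right-side padding, and `pad_token_id=0` are all assumptions. The real extractor works on tensors, but the logic of padding to the longest sequence in the batch and marking real tokens with 1 in the attention mask is the same idea.

```python
def iter_batches(examples, forward_batch_size):
    """Yield consecutive chunks of at most forward_batch_size examples."""
    for start in range(0, len(examples), forward_batch_size):
        yield examples[start:start + forward_batch_size]


def pad_batch(batch_input_ids, pad_token_id=0):
    """Right-pad sequences to the batch maximum and build attention masks."""
    max_len = max(len(ids) for ids in batch_input_ids)
    padded, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        padded.append(list(ids) + [pad_token_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return padded, attention_mask


# Toy tokenized examples (assumed values).
examples = [[5, 9, 3], [7], [4, 4, 8, 2], [6, 1]]
for batch in iter_batches(examples, forward_batch_size=2):
    padded, mask = pad_batch(batch)
    print(padded, mask)
```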

In [None]:
from stFormer.tokenization.embedding_extractor import EmbeddingExtractor
from pathlib import Path

In [None]:
extractor = EmbeddingExtractor(
    token_dict_path=Path('output/token_dictionary.pickle'),
    emb_mode='cls',
    emb_layer=-1,
    forward_batch_size=64
)
# Embeddings are also written to disk as output_prefix + ".pt"
embeddings = extractor.extract_embs(
    model_directory='output/spot/models/250422_102707_stFormer_L6_E3/final',
    dataset_path='output/spot/visium_spot.dataset',
    output_directory='output/spot/embeddings',
    output_prefix='visium_spot'
)