# Tahoe-x1 Model Tutorial

Run this notebook in Google Colab with a free GPU:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/helicalAI/helical/blob/main/examples/notebooks/Tahoe-x1-Tutorial.ipynb)

This tutorial demonstrates how to use the Tahoe-x1 foundation model for single-cell RNA-seq data. Tahoe-x1 is a transformer-based model that can extract both cell and gene embeddings from raw count data.

**What you'll learn in this notebook:**
- How to load and configure the Tahoe-x1 model
- Processing single-cell RNA-seq data for Tahoe
- Extracting cell embeddings
- Extracting gene embeddings
- Visualizing embeddings with UMAP
- Extracting attention weights for interpretability

For more examples, check out our [GitHub](https://github.com/helicalAI/helical) and [documentation](https://helical.readthedocs.io/).

## Installation

Install or update Helical to get access to the Tahoe model:

In [None]:
!pip install helical --upgrade

## Imports and Setup

In [None]:
import logging
import warnings
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset

# Configure logging
logging.getLogger().setLevel(logging.INFO)
warnings.filterwarnings("ignore")

# Check device availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## Load Example Dataset

We'll use the human fetal yolk sac scRNA-seq dataset from Helical's Hugging Face repository:

In [None]:
# Load dataset from Hugging Face
dataset = load_dataset(
    "helical-ai/yolksac_human", 
    split="train[:10%]", 
    trust_remote_code=True, 
    download_mode="reuse_cache_if_exists"
)

# Store labels for visualization later
labels = dataset["LVL1"]

print(f"Loaded {len(dataset)} cells")

## Convert to AnnData Format

Tahoe works with AnnData objects, the standard format for single-cell data in Python:

In [None]:
from helical.utils import get_anndata_from_hf_dataset

ann_data = get_anndata_from_hf_dataset(dataset)
print(ann_data)

# For this tutorial, let's use a subset for faster processing
ann_data_subset = ann_data[:500]  # Use first 500 cells
labels_subset = labels[:500]
print(f"\nUsing subset: {ann_data_subset.n_obs} cells, {ann_data_subset.n_vars} genes")
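Slicing the first 500 cells is quick. However, if the file is ordered by cell type, the subset may miss some types entirely. A seeded random subsample avoids that. The sketch below is an assumption rather than part of the original tutorial: `random_subset_indices` is a hypothetical helper, and `toy_labels` stands in for `dataset["LVL1"]`.

```python
import numpy as np

def random_subset_indices(n_total, n_subset, seed=0):
    """Return sorted indices of a reproducible random subset of size n_subset."""
    rng = np.random.default_rng(seed)  # fixed seed -> same subset on every run
    n_subset = min(n_subset, n_total)
    idx = rng.choice(n_total, size=n_subset, replace=False)
    return np.sort(idx)  # keep the original cell order within the subset

# Toy labels ordered by type: taking the first 500 would return only "A"
toy_labels = ["A"] * 600 + ["B"] * 400
idx = random_subset_indices(len(toy_labels), 500, seed=42)
subset_labels = [toy_labels[i] for i in idx]
print(len(subset_labels), sorted(set(subset_labels)))
```

In the notebook, the same indices would select both the cells and their labels, e.g. `ann_data_subset = ann_data[idx]` and `labels_subset = [labels[i] for i in idx]`.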

## Initialize Tahoe Model

Tahoe-x1 comes in three sizes (70m, 1b, and 3b); currently, the 70m model is available. By default, the model uses Flash Attention for efficient inference.

In [None]:
from helical.models.tahoe import Tahoe, TahoeConfig

# Configure the Tahoe model
tahoe_config = TahoeConfig(
    model_size="70m",  # 12-layer transformer with 512 embedding dimensions
    batch_size=8,      # Adjust based on your GPU memory
    device=device,
)

# Initialize the model (will download weights on first use)
tahoe = Tahoe(configurer=tahoe_config)

print("\nTahoe model loaded successfully!")

## Process Data

Tahoe requires gene names to be mapped to Ensembl IDs. The `process_data` method handles this automatically:

In [None]:
# Process data - this will map gene symbols to Ensembl IDs
dataloader = tahoe.process_data(
    ann_data_subset,
    gene_names="gene_name",  # Column in ann_data.var containing gene symbols
    use_raw_counts=True
)

print("Data processed and ready for inference!")
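Conceptually, the mapping step is a lookup from gene symbol to Ensembl ID, with genes that have no match dropped. The sketch below is illustrative only and is not Helical's implementation: the two-entry `symbol_to_ensembl` table and the `map_symbols` helper are hypothetical, and `process_data` may handle unmatched symbols differently.

```python
import pandas as pd

# Hypothetical two-entry lookup table; a real pipeline would use a full
# genome annotation rather than a hand-written dictionary
symbol_to_ensembl = {
    "GAPDH": "ENSG00000111640",
    "ACTB": "ENSG00000075624",
}

def map_symbols(var: pd.DataFrame, symbol_col: str, lookup: dict) -> pd.DataFrame:
    """Add an 'ensembl_id' column and drop genes whose symbol has no match."""
    out = var.copy()
    out["ensembl_id"] = out[symbol_col].map(lookup)  # unmatched symbols become NaN
    return out.dropna(subset=["ensembl_id"])

var = pd.DataFrame({"gene_name": ["GAPDH", "ACTB", "NOT_A_GENE"]})
mapped = map_symbols(var, "gene_name", symbol_to_ensembl)
print(mapped)
```

Dropping unmatched genes is one design choice. An alternative is to keep them and let the model ignore tokens outside its vocabulary.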

## Extract Cell Embeddings

Cell embeddings capture the transcriptional state of each cell in a dense vector representation:

In [None]:
# Get cell embeddings
cell_embeddings = tahoe.get_embeddings(dataloader)

print(f"Cell embeddings shape: {cell_embeddings.shape}")
print(f"Each cell is represented by a {cell_embeddings.shape[1]}-dimensional vector")
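One direct use of these vectors is finding transcriptionally similar cells. A minimal sketch, assuming cosine similarity as the measure (a common choice, not one prescribed by Tahoe), with toy data standing in for `cell_embeddings`:

```python
import numpy as np

def cosine_similarity_matrix(x):
    """Pairwise cosine similarity between the rows of x (n_cells, dim)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    x_unit = x / np.clip(norms, 1e-12, None)  # guard against zero vectors
    return x_unit @ x_unit.T

def nearest_neighbor(x):
    """Index of the most similar *other* row for each row."""
    sim = cosine_similarity_matrix(x)
    np.fill_diagonal(sim, -np.inf)  # exclude self-matches
    return sim.argmax(axis=1)

# Toy stand-in for cell_embeddings: two tight groups of 3 cells in 512 dimensions
rng = np.random.default_rng(0)
center_a, center_b = rng.normal(size=512), rng.normal(size=512)
toy = np.vstack([center_a + 0.01 * rng.normal(size=(3, 512)),
                 center_b + 0.01 * rng.normal(size=(3, 512))])
nn = nearest_neighbor(toy)
print(nn)
```

With the real data, `nearest_neighbor(np.asarray(cell_embeddings))` would return one neighbour index per cell. The full similarity matrix is quadratic in the number of cells, which is fine for the 500-cell subset.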

## Visualize Cell Embeddings with UMAP

Let's visualize the cell embeddings in 2D using UMAP to see how cells cluster by cell type:

In [None]:
import umap
import seaborn as sns
import matplotlib.pyplot as plt

# Reduce dimensionality with UMAP
reducer = umap.UMAP(min_dist=0.1, n_components=2, n_neighbors=15, random_state=42)
umap_embedding = reducer.fit_transform(cell_embeddings)

# Create plot dataframe
plot_df = pd.DataFrame(umap_embedding, columns=['UMAP1', 'UMAP2'])
plot_df['Cell Type'] = labels_subset

# Plot
plt.figure(figsize=(12, 8))
sns.scatterplot(
    data=plot_df, 
    x='UMAP1', 
    y='UMAP2', 
    hue='Cell Type',
    palette='tab10',
    s=30,
    alpha=0.7
)
plt.title('UMAP Visualization of Tahoe Cell Embeddings', fontsize=14, fontweight='bold')
plt.xlabel('UMAP 1', fontsize=12)
plt.ylabel('UMAP 2', fontsize=12)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='Cell Type')
plt.tight_layout()
plt.show()
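UMAP is useful for exploration but can distort distances, so a numeric check can complement the plot. One simple option, sketched here as an assumption rather than a Helical feature, is k-nearest-neighbour label agreement. For each cell, it computes the fraction of its k nearest neighbours in the full embedding space (not the 2D UMAP) that share its label, then averages over cells. `knn_label_agreement` is a hypothetical helper.

```python
import numpy as np

def knn_label_agreement(embeddings, labels, k=15):
    """Mean fraction of each cell's k nearest neighbours (Euclidean) sharing its label."""
    x = np.asarray(embeddings, dtype=float)
    y = np.asarray(labels)
    sq = (x ** 2).sum(axis=1)
    dists = sq[:, None] + sq[None, :] - 2 * x @ x.T  # squared Euclidean distances
    np.fill_diagonal(dists, np.inf)  # never count a cell as its own neighbour
    nn_idx = np.argsort(dists, axis=1)[:, :k]
    return float((y[nn_idx] == y[:, None]).mean())

# Toy data: two well-separated groups in 8 dimensions
rng = np.random.default_rng(0)
x = np.vstack([rng.normal(0, 1, size=(20, 8)), rng.normal(10, 1, size=(20, 8))])
y = ["A"] * 20 + ["B"] * 20
score = knn_label_agreement(x, y, k=5)
print(score)
```

In the notebook, `knn_label_agreement(cell_embeddings, labels_subset, k=15)` gives a single number to compare across models or settings. It is most informative next to a baseline computed with shuffled labels.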

## Extract Gene Embeddings

Tahoe can also return gene embeddings: one vector per gene, summarizing that gene's representation across the cells in the dataset:

In [None]:
# Get both cell and gene embeddings
cell_embeddings, gene_embeddings = tahoe.get_embeddings(
    dataloader, 
    return_gene_embeddings=True
)

print(f"Cell embeddings shape: {cell_embeddings.shape}")
print(f"Gene embeddings shape: {gene_embeddings.shape}")
print(f"\nEach gene is represented by a {gene_embeddings.shape[1]}-dimensional vector")
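Genes that never appear in the data have no embedding, and the next section filters them out as NaN rows. One plausible way to build per-gene vectors is to average each gene's contextual token embeddings over every cell in which it appears. This is an assumption made for illustration; Helical's actual aggregation may differ. `aggregate_gene_embeddings` is a hypothetical helper.

```python
import numpy as np

def aggregate_gene_embeddings(token_embs, gene_ids, n_genes):
    """Average per-token embeddings by gene id; genes never seen get NaN rows.

    token_embs: (n_tokens, dim) contextual embeddings pooled from all cells
    gene_ids:   (n_tokens,) integer gene index of each token
    """
    dim = token_embs.shape[1]
    sums = np.zeros((n_genes, dim))
    counts = np.zeros(n_genes)
    np.add.at(sums, gene_ids, token_embs)  # unbuffered scatter-add handles repeated ids
    np.add.at(counts, gene_ids, 1)
    out = np.full((n_genes, dim), np.nan)  # default: no embedding
    seen = counts > 0
    out[seen] = sums[seen] / counts[seen][:, None]
    return out

token_embs = np.array([[1.0, 1.0], [3.0, 3.0], [5.0, 0.0]])
gene_ids = np.array([0, 0, 2])  # gene 1 never appears
agg = aggregate_gene_embeddings(token_embs, gene_ids, n_genes=3)
print(agg)
```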

## Visualize Gene Embeddings

Let's visualize a subset of gene embeddings to see how genes cluster based on their expression patterns:

In [None]:
# Remove genes with NaN embeddings (genes not present in the data)
valid_genes = ~np.isnan(gene_embeddings).any(axis=1)
gene_embeddings_valid = gene_embeddings[valid_genes]

print(f"Valid gene embeddings: {gene_embeddings_valid.shape[0]} genes")

# Visualize a random (seeded) subset of genes with UMAP
n_genes_to_plot = min(1000, gene_embeddings_valid.shape[0])
rng = np.random.default_rng(42)  # fixed seed so the gene subset is reproducible
gene_subset_idx = rng.choice(gene_embeddings_valid.shape[0], n_genes_to_plot, replace=False)
gene_subset = gene_embeddings_valid[gene_subset_idx]

# UMAP for genes
gene_reducer = umap.UMAP(min_dist=0.1, n_components=2, n_neighbors=15, random_state=42)
gene_umap = gene_reducer.fit_transform(gene_subset)

# Plot
plt.figure(figsize=(10, 8))
plt.scatter(gene_umap[:, 0], gene_umap[:, 1], s=10, alpha=0.5, c='steelblue')
plt.title(f'UMAP Visualization of {n_genes_to_plot} Gene Embeddings', fontsize=14, fontweight='bold')
plt.xlabel('UMAP 1', fontsize=12)
plt.ylabel('UMAP 2', fontsize=12)
plt.tight_layout()
plt.show()
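Beyond the scatter plot, gene embeddings can be queried for the genes closest to a gene of interest, a first step toward finding gene modules. The sketch assumes you have a list of gene names aligned with the rows of `gene_embeddings`. How to obtain that list is not shown in this tutorial, so treat `gene_names` and the helper `most_similar_genes` as hypothetical. Cosine similarity is an assumed choice of metric.

```python
import numpy as np

def most_similar_genes(gene_embs, gene_names, query, top_k=5):
    """Return (name, cosine similarity) for the top_k genes closest to `query`."""
    gene_embs = np.asarray(gene_embs, dtype=float)
    gene_names = np.asarray(gene_names)
    valid = ~np.isnan(gene_embs).any(axis=1)  # drop genes with NaN embeddings
    embs, names = gene_embs[valid], gene_names[valid]
    norms = np.linalg.norm(embs, axis=1, keepdims=True)
    unit = embs / np.clip(norms, 1e-12, None)
    matches = np.flatnonzero(names == query)
    if matches.size == 0:
        raise KeyError(f"{query!r} has no valid embedding")
    sims = unit @ unit[matches[0]]
    ranked = [i for i in np.argsort(-sims) if names[i] != query][:top_k]
    return [(str(names[i]), float(sims[i])) for i in ranked]

# Toy example with made-up gene names; G4 has a NaN embedding and is ignored
embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [np.nan, np.nan]])
names = ["G1", "G2", "G3", "G4"]
res = most_similar_genes(embs, names, "G1", top_k=2)
print(res)
```

Filtering NaN rows inside the helper keeps names and embeddings aligned. Filtering the embeddings alone, without also filtering the names, would silently shift them.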

## Extract Attention Weights

For interpretability, you can extract attention weights from the transformer layers. This requires using the PyTorch attention implementation instead of Flash Attention.

**Note:** This is slower and uses more memory than the default Flash Attention.
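To see what the extracted weights are, here is textbook scaled dot-product attention for a single head with a key-padding mask. Fused kernels such as Flash Attention compute the attention output without ever storing this full matrix, which is why weight extraction needs the plain PyTorch path. This is a generic sketch; Tahoe's layers may differ in details such as masking or scaling.

```python
import numpy as np

def attention_weights(q, k, key_padding_mask):
    """Scaled dot-product attention weights for one head.

    q, k: (seq_len, d_head); key_padding_mask: (seq_len,) True where the token is padding.
    Returns a (seq_len, seq_len) matrix whose rows sum to 1.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(key_padding_mask[None, :], -np.inf, scores)  # padded keys get zero weight
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq_len, d_head = 6, 4
q, k = rng.normal(size=(seq_len, d_head)), rng.normal(size=(seq_len, d_head))
pad = np.array([False, False, False, False, True, True])  # last two positions are padding
w = attention_weights(q, k, pad)
print(w.round(3))
```

The padded query rows (4 and 5) still sum to 1. Only the padded key columns are exactly zero, which matters when trimming padding from a heatmap.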

In [None]:
# Create a new model with torch attention implementation
tahoe_config_attn = TahoeConfig(
    model_size="70m",
    batch_size=4,  # Reduce batch size for memory efficiency
    device=device,
    attn_impl='torch'  # Required for attention extraction
)

tahoe_attn = Tahoe(configurer=tahoe_config_attn)
print("Tahoe model with attention extraction loaded!")

In [None]:
# Process a smaller subset for attention extraction
ann_data_tiny = ann_data[:50]  # Use only 50 cells

dataloader_attn = tahoe_attn.process_data(
    ann_data_tiny,
    gene_names="gene_name",
    use_raw_counts=True
)

In [None]:
# Extract attention weights
cell_embeddings_attn, attention_weights = tahoe_attn.get_embeddings(
    dataloader_attn, 
    output_attentions=True
)

print(f"Cell embeddings shape: {cell_embeddings_attn.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print("\nAttention weights dimensions: (n_cells, n_heads, seq_length, seq_length)")
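The attention tensor grows quadratically with sequence length, which is why this section uses only 50 cells. A quick size estimate is shown below. The head count and sequence length are hypothetical values for illustration; check `attention_weights.shape` for the real ones. The calculation also assumes 4-byte float32 values.

```python
def attention_tensor_bytes(n_cells, n_heads, seq_len, bytes_per_value=4):
    """Size in bytes of an (n_cells, n_heads, seq_len, seq_len) attention array."""
    return n_cells * n_heads * seq_len * seq_len * bytes_per_value

# Hypothetical values for illustration only
size = attention_tensor_bytes(n_cells=50, n_heads=8, seq_len=1024)
print(f"{size / 1024**3:.2f} GiB")  # 1.56 GiB
```

Doubling the sequence length quadruples the size, while doubling the number of cells only doubles it.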

## Visualize Attention Patterns

Let's visualize the attention pattern for one cell to see which genes the model pays attention to:

In [None]:
# Select first cell and average across attention heads
cell_idx = 0
cell_attention = attention_weights[cell_idx]  # Shape: (n_heads, seq_len, seq_len)

# Average across heads
avg_attention = cell_attention.mean(axis=0)  # Shape: (seq_len, seq_len)

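# Find the actual sequence length (excluding padding). With a key-padding mask,
# padded key positions receive zero attention from every query, so their columns
# sum to zero (padded query rows can still sum to 1). Assumes padding is at the end.
non_pad_mask = avg_attention.sum(axis=0) > 0
actual_seq_len = int(non_pad_mask.sum())
avg_attention_trimmed = avg_attention[:actual_seq_len, :actual_seq_len]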

# Plot attention heatmap
plt.figure(figsize=(12, 10))
sns.heatmap(
    avg_attention_trimmed[:50, :50],  # Show first 50x50 for visibility
    cmap='viridis',
    square=True,
    cbar_kws={'label': 'Attention Weight'}
)
plt.title(f'Attention Pattern for Cell {cell_idx} (averaged across heads)', fontsize=14, fontweight='bold')
plt.xlabel('Key Position (Gene Tokens)', fontsize=12)
plt.ylabel('Query Position (Gene Tokens)', fontsize=12)
plt.tight_layout()
plt.show()

print(f"Showing first 50x50 positions of {actual_seq_len} total sequence length")
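The heatmap shows where attention goes, but not which genes receive the most attention overall. One simple summary ranks key positions by the mean attention they receive across all queries (column means). This is a sketch under assumptions. Translating positions back to gene names requires the per-cell token-to-gene mapping, which this tutorial does not expose. `top_attended_positions` is a hypothetical helper.

```python
import numpy as np

def top_attended_positions(avg_attention, top_k=10):
    """Rank key positions by the mean attention they receive across all query positions."""
    received = np.asarray(avg_attention).mean(axis=0)  # column means: attention received per key
    order = np.argsort(-received)[:top_k]
    return order, received[order]

toy = np.array([[0.1, 0.7, 0.2],
                [0.2, 0.6, 0.2],
                [0.3, 0.3, 0.4]])
positions, scores = top_attended_positions(toy, top_k=2)
print(positions, scores.round(3))
```

In the notebook, `top_attended_positions(avg_attention_trimmed)` would operate on the padding-free matrix, so padded positions cannot appear in the ranking.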

## Summary

In this notebook, you learned how to:

1. ✅ Load and configure the Tahoe-x1 model for single-cell RNA-seq analysis
2. ✅ Process scRNA-seq data with automatic gene symbol to Ensembl ID mapping
3. ✅ Extract cell embeddings that capture cellular states
4. ✅ Extract gene embeddings that represent gene expression patterns
5. ✅ Visualize embeddings using UMAP for exploratory analysis
6. ✅ Extract and visualize attention weights for model interpretability

### Next Steps

- **Cell Type Annotation**: Use the embeddings for downstream tasks like cell type classification
- **Gene Analysis**: Identify gene modules or marker genes using gene embeddings
- **Integration**: Combine Tahoe embeddings with other analysis tools in the scRNA-seq ecosystem
- **Fine-tuning**: Adapt the model for specific downstream tasks (see other notebooks)
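As a starting point for cell type annotation, a nearest-centroid classifier on frozen embeddings gives a quick baseline before any fine-tuning. The sketch below is one simple option, not the method used in other Helical notebooks. The toy data and the labels `type_A`/`type_B` are made up.

```python
import numpy as np

def fit_centroids(embeddings, labels):
    """Mean embedding per label; returns (sorted class names, centroid matrix)."""
    x, y = np.asarray(embeddings, dtype=float), np.asarray(labels)
    classes = np.unique(y)
    centroids = np.stack([x[y == c].mean(axis=0) for c in classes])
    return classes, centroids

def predict(embeddings, classes, centroids):
    """Assign each embedding the label of the closest centroid (Euclidean)."""
    x = np.asarray(embeddings, dtype=float)
    d = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return classes[d.argmin(axis=1)]

rng = np.random.default_rng(0)
train_x = np.vstack([rng.normal(0, 1, (30, 16)), rng.normal(5, 1, (30, 16))])
train_y = ["type_A"] * 30 + ["type_B"] * 30
classes, centroids = fit_centroids(train_x, train_y)
test_x = np.vstack([rng.normal(0, 1, (5, 16)), rng.normal(5, 1, (5, 16))])
pred = predict(test_x, classes, centroids)
print(pred)
```

With real data, fit on one split of `cell_embeddings`/`labels_subset` and evaluate on held-out cells. Evaluating on the training cells overstates accuracy.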

### Model Information

- **Model**: Tahoe-x1 by Tahoe Therapeutics
- **Hugging Face**: [tahoebio/Tahoe-x1](https://huggingface.co/tahoebio/Tahoe-x1)
- **Architecture**: Transformer-based foundation model for scRNA-seq
- **Available sizes**: 70m (12 layers, 512d), 1b (24 layers, 1024d), 3b (36 layers, 1536d)

For more information and examples, visit the [Helical documentation](https://helical.readthedocs.io/).