In [3]:
"""
Script for training SAE from a cache of scGPT activations and loading the custom fidelity function.
Assumes specific dataset structure for activations, as created by interp/data_processing/embed_fasta.py.
"""

import warnings
from pathlib import Path

import torch
from torch.utils.data import DataLoader

import sys
sys.path.append('/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm')

from sae.dictionary import AutoEncoder
# from train.fidelity import get_loss_recovery_fn
from train.load_sharded_acts import GFMDataset
from train.trainer import StandardTrainer
from train.training import train_run
from utils import get_device

warnings.filterwarnings("ignore", message="TypedStorage is deprecated")

In [11]:
def train_SAE_on_gfm_embeds(
    # Data paths and sources
    embd_dir: Path,
    eval_seq_path: Path,
    layer='',
    # Core model architecture
    expansion_factor: int = 8,
    # Training configuration
    batch_size: int = 32,
    steps: int = 1_000,
    seed: int = 0,
    # Optimization parameters
    lr: float = 1e-3,
    warmup_steps: int = 50,
    resample_steps: int = 0,  # 0 to disable
    # Regularization
    l1_penalty: float = 1e-1,
    l1_annealing_pct: float = 0.05,
    # Evaluation settings
    eval_batch_size: int = 128,
    eval_steps: int = 1_000,
    # Logging and checkpointing
    save_dir: str = "models",
    log_steps: int = 100,
    save_steps: int = 50,
    max_ckpts_to_keep: int = 3,
    # Weights & Biases configuration
    use_wandb: bool = False,
    wandb_entity: str = "",
    wandb_project: str = "test_logging",
    wandb_name: str = "SAE",
):
    """
    Train a Sparse Autoencoder (SAE) using cached activation data from a language model.

    Args:
        # Data paths and sources
        embd_dir: Directory containing cached model embeddings
        eval_seq_path: Path to sequences for fidelity evaluation, if None, fidelity evaluation is disabled

        # Core model architecture
        expansion_factor: Factor by which to expand the dictionary size relative to input dimension

        # Training configuration
        batch_size: Number of samples per training batch
        steps: Total number of training steps
        seed: Random seed for reproducibility

        # Optimization parameters
        lr: Learning rate for optimizer
        warmup_steps: Number of warmup steps for learning rate scheduler
        resample_steps: Steps between dictionary resampling (0 to disable)

        # Regularization
        l1_penalty: Coefficient for L1 regularization
        l1_annealing_pct: Percentage of training during which to anneal L1 penalty

        # Evaluation settings
        eval_batch_size: Batch size for evaluation
        eval_steps: Frequency of evaluation steps

        # Logging and checkpointing
        save_dir: Directory to save model checkpoints and outputs
        log_steps: Frequency of logging
        save_steps: Frequency of saving checkpoints

        # Weights & Biases configuration
        use_wandb: Whether to use Weights & Biases logging
        wandb_entity: W&B username or team name
        wandb_project: W&B project name
        wandb_name: W&B run name
    """
    # device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:2" if torch.cuda.is_available() else "cpu")
    device='cuda:2'

    def collate_fn(batch):
        return torch.stack(batch).to(device)

    # Initialize dataset and dataloader
    acts_dataset = GFMDataset(embd_dir)

    # Determine layer from dataset metadata
    # layer = acts_dataset.datasets[0]["layer"]
    # plm_name = acts_dataset.datasets[0]["plm_name"]
    # print(f"Using activations from layer {layer} of {plm_name}")
    

    dataloader = DataLoader(
        acts_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )
    print(f"Loaded dataset with {len(acts_dataset):,} tokens")

    # Configure resampling
    if resample_steps == 0:
        resample_steps = None

    # Setup trainer configuration
    trainer = StandardTrainer(
        activation_dim=acts_dataset.d_model,
        dict_size=acts_dataset.d_model * expansion_factor,
        warmup_steps=warmup_steps,
        resample_steps=resample_steps,
        lr=lr,
        l1_penalty=l1_penalty,
        l1_annealing_pct=l1_annealing_pct,
        seed=seed,
        wandb_name=wandb_name,
        layer=layer,
        plm_name='scgpt',
        device=device,
        steps=min(steps, len(dataloader)),
    )
    print(f"Training with config: {trainer.config}")

    # Initialize fidelity function if evaluation sequences provided
    if eval_seq_path is not None:
        fidelity_fn = get_loss_recovery_fn(
            esm_model_name=plm_name,
            layer_idx=int(layer),
            eval_seq_path=eval_seq_path,
            device=device,
            batch_size=eval_batch_size,
        )
    else:
        fidelity_fn = None

    # Train the SAE
    train_run(
        # Core training components
        data=dataloader,
        trainer=trainer,
        # Evaluation settings
        fidelity_fn=fidelity_fn,
        eval_steps=eval_steps,
        # Logging and checkpointing
        save_dir=save_dir,
        log_steps=log_steps,
        save_steps=save_steps,
        max_ckpts_to_keep=3,
        # Weights & Biases configuration
        use_wandb=use_wandb,
        wandb_entity=wandb_entity,
        wandb_project=wandb_project,
        additional_wandb_args={
            "eval_seq_path": eval_seq_path,
            "eval_steps": eval_steps,
            "batch_size": batch_size,
            "save_dir": save_dir,
        },
    )

In [12]:
# layer 0

train_SAE_on_gfm_embeds(
    # Data paths and sources
    embd_dir=Path('/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/scgpt/activations/layer_0'),
    eval_seq_path=None,
    # Core model architecture
    expansion_factor = 8,
    layer='layer_0',
    # Training configuration
    batch_size = 32,
    steps = 1_000,
    seed = 0,
    # Optimization parameters
    lr = 1e-3,
    warmup_steps = 50,
    resample_steps = 0,  # 0 to disable
    # Regularization
    l1_penalty = 1e-1,
    l1_annealing_pct = 0.05,
    # Evaluation settings
    eval_batch_size = 128,
    eval_steps = 1_000,
    # Logging and checkpointing
    save_dir=Path('/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/scgpt/sae_output/layer_0'),
    log_steps = 100,
    save_steps = 50,
    max_ckpts_to_keep = 3,
    # Weights & Biases configuration
    use_wandb = False,
    wandb_entity = "",
    wandb_project = "test_logging",
    wandb_name = "SAE"
)

Loading dataset metadata


100%|██████████| 766/766 [00:00<00:00, 610689.39it/s]


Loaded dataset with 766 tokens
Training with config: {'dict_class': 'AutoEncoder', 'trainer_class': 'StandardTrainer', 'activation_dim': 512, 'dict_size': 4096, 'lr': 0.001, 'l1_penalty': 0.1, 'l1_annealing_steps': 1, 'steps': 24, 'warmup_steps': 50, 'resample_steps': None, 'device': 'cuda:2', 'layer': 'layer_0', 'plm_name': 'scgpt', 'wandb_name': 'SAE', 'submodule_name': None}


  0%|          | 0/24 [00:01<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [264, 512] at entry 0 and [375, 512] at entry 1

In [23]:
import scanpy as sc

ad = sc.read_h5ad('/maiziezhou_lab2/yunfei/Projects/FM_temp/datasets/cosmx/lung/cosmx_human_lung.h5ad')

In [27]:
ad.obs

Unnamed: 0,AspectRatio,CenterX_global_px,CenterY_global_px,Width,Height,Mean.MembraneStain,Max.MembraneStain,Mean.PanCK,Max.PanCK,Mean.CD45,...,assay,organism,sex,tissue,dataset,x,y,nicheformer_split,_scvi_batch,_scvi_labels
1_1,1.34,4215.888889,158847.666667,47,35,3473,7354,715,5755,361,...,NanoString digital spatial profiling,Homo sapiens,female,lung,nanostring_cosmx_human_lung,4215.888889,158847.666667,train,0,0
2_1,1.45,6092.888889,158834.666667,87,60,3895,13832,18374,53158,260,...,NanoString digital spatial profiling,Homo sapiens,female,lung,nanostring_cosmx_human_lung,6092.888889,158834.666667,train,0,0
3_1,1.62,7214.888889,158843.666667,68,42,2892,6048,3265,37522,378,...,NanoString digital spatial profiling,Homo sapiens,female,lung,nanostring_cosmx_human_lung,7214.888889,158843.666667,train,0,0
4_1,0.47,7418.888889,158813.666667,48,102,6189,16091,485,964,679,...,NanoString digital spatial profiling,Homo sapiens,female,lung,nanostring_cosmx_human_lung,7418.888889,158813.666667,train,0,0
5_1,1.00,7446.888889,158845.666667,38,38,8138,19281,549,874,566,...,NanoString digital spatial profiling,Homo sapiens,female,lung,nanostring_cosmx_human_lung,7446.888889,158845.666667,train,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3999_20-2,1.02,-18135.000000,10850.777778,45,44,7430,9572,370,962,450,...,NanoString digital spatial profiling,Homo sapiens,male,lung,nanostring_cosmx_human_lung,-18135.000000,10850.777778,train,0,0
4000_20-2,1.71,-18088.000000,10846.777778,60,35,8362,11209,161,2024,572,...,NanoString digital spatial profiling,Homo sapiens,male,lung,nanostring_cosmx_human_lung,-18088.000000,10846.777778,train,0,0
4001_20-2,2.75,-19112.000000,10846.777778,99,36,5158,10180,634,2166,41,...,NanoString digital spatial profiling,Homo sapiens,male,lung,nanostring_cosmx_human_lung,-19112.000000,10846.777778,train,0,0
4003_20-2,2.12,-19551.000000,10841.777778,55,26,6339,9804,211,570,488,...,NanoString digital spatial profiling,Homo sapiens,male,lung,nanostring_cosmx_human_lung,-19551.000000,10841.777778,train,0,0


In [31]:
import numpy as np

# Get the index of the cell by name
cell_name = '2359_14-6'
cell_idx = ad.obs_names.get_loc(cell_name)

# Get the row corresponding to that cell
cell_data = ad.X[cell_idx]

# If it's sparse, convert to dense first
if hasattr(cell_data, 'toarray'):
    cell_data = cell_data.toarray().flatten()

# Get non-zero values
non_zero_values = cell_data[cell_data != 0]

print(len(non_zero_values))

113


In [13]:
tensor = torch.load(
            '/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/scgpt/activations/layer_0/4984_6__T_CD4_naive.pt', map_location="cpu", weights_only=True)

In [15]:
tensor.shape

torch.Size([264, 512])

In [None]:
import os
# dir_ = '/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/scgpt/activations/layer_0'
dir_ = '/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/scgpt/activations/layer_11'
for p in os.listdir(dir_):
    print(p)
    if p.endswith('.pt'):
        tensor = torch.load(os.path.join(dir_, p), map_location="cpu", weights_only=True)
        print(tensor.shape)

3621_17-1_macrophage.pt
torch.Size([352, 512])
4393_7_fibroblast.pt
torch.Size([320, 512])
4211_7_plasmablast.pt
torch.Size([334, 512])
2161_14-1_NK.pt
torch.Size([332, 512])
1349_15-3_macrophage.pt
torch.Size([311, 512])
1869_3-5_tumor_9.pt
torch.Size([437, 512])
999_14-2_epithelial.pt
torch.Size([333, 512])
256_6-3_tumor_6.pt
torch.Size([332, 512])
1377_20-1_T_CD4_naive.pt
torch.Size([345, 512])
1966_11-1_tumor_5.pt
torch.Size([408, 512])
3419_9_fibroblast.pt
torch.Size([264, 512])
2083_12-2_tumor_5.pt
torch.Size([356, 512])
628_3-4_endothelial.pt
torch.Size([327, 512])
2665_8-2_tumor_5.pt
torch.Size([435, 512])
384_4-4_endothelial.pt
torch.Size([334, 512])
224_19-7_T_CD4_memory.pt
torch.Size([383, 512])
3620_25-1_plasmablast.pt
torch.Size([474, 512])
875_15-2_NK.pt
torch.Size([276, 512])
372_9-5_tumor_9.pt
torch.Size([474, 512])
3092_27-1_endothelial.pt
torch.Size([324, 512])
1060_19-1_neutrophil.pt
torch.Size([438, 512])
98_21-5_tumor_12.pt
torch.Size([344, 512])
976_20-1_fibroblas

KeyboardInterrupt: 