## Evo 2 Yarrowia Sparse Auto Encoder (Local Version)

This notebook shows a **minimal, class-free** workflow to:

1. **Load Yarrowia** sequence from a local FASTA (`GCF_001761485.1_ASM176148v1_genomic.fna`) and annotations from a local **GFF** (`genomic.gff`).
2. Use a **local Evo2 model** (via `transformers`) to get layer activations.
3. Load a pre-trained **Top-K tied Sparse Autoencoder (SAE)** from Hugging Face.
4. **Project activations into SAE features** and **plot a handful** of them with GFF annotations.

> [!TIP]
> **Setup:** Install dependencies with `pip install -U transformers huggingface_hub evo2 biopython flash-attn` and restart your kernel.
> This notebook defaults to the **7B model** which is compatible with the provided SAE weights.


### Set up imports

In [None]:
import os, io, base64, json, zipfile, time, yaml, pkgutil, gc, os
import numpy as np
import torch
import matplotlib.pyplot as plt
from Bio import SeqIO
# from huggingface_hub import hf_hub_download
# from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('CUDA_HOME', '/opt/apps/software/system/CUDA/12.2.0')}/lib64:/.singularity.d/libs:" + os.environ.get('LD_LIBRARY_PATH', '')
from evo2 import Evo2
from functools import partial


# Seed torch's global RNG, choose the best available backend, and disable
# autograd for the whole notebook (everything below is inference only).
torch.manual_seed(42)

device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
dtype = torch.bfloat16  # dtype used for the model and SAE tensors further down
torch.set_grad_enabled(False)

print("Device:", device)

### Load Evo2 API helper functions

In [None]:
# MODEL_CONFIG: Change to "arcinstitute/evo2_40b" if you have a compatible SAE model.
MODEL_ID = "arcinstitute/evo2_7b"

def evo2_forward(sequence, output_layers, model=MODEL_ID, device=device, dtype=dtype):
    """Run one forward pass and capture the outputs of the requested layers.

    Parameters
    ----------
    sequence : str
        Nucleotide sequence to run through the model.
    output_layers : iterable of str
        Dotted submodule paths (e.g. "blocks.26.mlp.l3"). Numeric parts index
        into containers, other parts are attribute lookups. A leading
        `backbone` / `model` attribute on the network is descended into first.
    model : str or object
        * str -- a Hugging Face model id. The model and tokenizer are loaded
          once through `transformers` and cached in the module globals
          `_local_model` / `_local_tokenizer`. The cache is NOT keyed on the
          id: a different id passed later reuses the first loaded model.
        * anything else -- an already-loaded model: either a `torch.nn.Module`
          or a wrapper exposing the network as `.model` (and optionally a
          `.tokenizer` with a `tokenize(str) -> ids` method).
    device, dtype
        Used when loading a model by id and for placing the tokenized inputs.
        For a preloaded model the inputs follow the network's first parameter.

    Returns
    -------
    (logits, acts)
        `logits` is a numpy array; `acts` maps "<layer_name>.output" to a numpy
        array. A leading batch axis of size 1 is dropped from arrays with three
        or more dimensions. bfloat16 tensors are upcast to float32 because
        NumPy has no bfloat16 dtype.

    Layers whose path cannot be resolved are reported with a warning and
    skipped. Forward hooks are always removed, even if the forward pass fails.
    """
    global _local_model, _local_tokenizer

    def to_numpy(tensor):
        # Tensor.numpy() raises on bfloat16, so upcast that one dtype only.
        tensor = tensor.detach()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.cpu().numpy()

    if isinstance(model, str):
        if '_local_model' not in globals():
            # Imported lazily: only this branch needs transformers.
            from transformers import AutoModelForCausalLM, AutoTokenizer

            print(f"Loading model {model}... ")
            _local_tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
            if device == 'cuda':
                _local_model = AutoModelForCausalLM.from_pretrained(
                    model,
                    trust_remote_code=True,
                    torch_dtype=dtype,
                    device_map="auto"
                )
            else:
                _local_model = AutoModelForCausalLM.from_pretrained(
                    model,
                    trust_remote_code=True,
                    torch_dtype=dtype
                ).to(device)
            _local_model.eval()

        net = _local_model
        inputs = _local_tokenizer(sequence, return_tensors="pt").to(device)

        def run_forward():
            return net(**inputs)
    else:
        # Preloaded model: a bare nn.Module, or a wrapper holding it as `.model`.
        net = model if isinstance(model, torch.nn.Module) else getattr(model, "model", None)
        if not isinstance(net, torch.nn.Module):
            raise TypeError(
                "`model` must be a Hugging Face model id, a torch.nn.Module, "
                "or a wrapper object with a `.model` torch.nn.Module attribute"
            )

        tokenizer = getattr(model, "tokenizer", None)
        if tokenizer is not None:
            token_ids = tokenizer.tokenize(sequence)
        else:
            # NOTE(review): assumes a byte-level vocabulary (token id == UTF-8
            # byte value), as used by Evo2's character-level tokenizer --
            # confirm before using with a different model family.
            token_ids = list(sequence.encode("utf-8"))

        first_param = next(net.parameters(), None)
        input_device = first_param.device if first_param is not None else device
        input_ids = torch.tensor(token_ids, dtype=torch.long, device=input_device).unsqueeze(0)

        def run_forward():
            return net(input_ids)

    acts = {}
    hooks = []

    def get_hook(name):
        def hook(module, input, output):
            data = output[0] if isinstance(output, tuple) else output
            acts[f"{name}.output"] = to_numpy(data)
        return hook

    try:
        # Register hooks for requested layers
        for layer_name in output_layers:
            try:
                module = net
                if hasattr(module, 'backbone'):
                    module = module.backbone
                elif hasattr(module, 'model'):
                    module = module.model

                for part in layer_name.split('.'):
                    module = module[int(part)] if part.isdigit() else getattr(module, part)
                hooks.append(module.register_forward_hook(get_hook(layer_name)))
            except Exception as e:
                print(f"Warning: Could not register hook for {layer_name}: {e}")

        with torch.no_grad():
            outputs = run_forward()
    finally:
        # Never leave hooks attached to a cached / shared model.
        for h in hooks:
            h.remove()

    # Hugging Face models return an object with `.logits`; other networks may
    # return a tuple whose first element is the logits, or the tensor itself.
    if hasattr(outputs, "logits"):
        raw_logits = outputs.logits
    elif isinstance(outputs, (tuple, list)):
        raw_logits = outputs[0]
    else:
        raw_logits = outputs
    lgt = to_numpy(raw_logits)

    def squeeze(x):
        return x[0] if (isinstance(x, np.ndarray) and x.ndim >= 3 and x.shape[0] == 1) else x

    return squeeze(lgt), {k: squeeze(v) for k, v in acts.items()}


### Load Yarrowia sequence (first 25,000 bases of chromosome 1)

In [None]:
# Input files. FASTA_PATH is an absolute, machine-specific location; GFF_PATH is
# resolved relative to the notebook's working directory.
FASTA_PATH = (
    "/home/jacobbw/Desktop/Code/yarrowia_language_modeling/yarrowia_notebooks/"
    "sparce_autoencoder/GCF_001761485.1_ASM176148v1_genomic.fna"
)
GFF_PATH = "genomic.gff"

# Genomic window to analyse, in 1-based coordinates with both ends inclusive.
WINDOW_START = 1
WINDOW_END = 25000

def pick_chr1_record(fasta_path):
    """Return the SeqRecord for chromosome 1 (best effort).

    Records are scanned in file order and the first one whose id/description
    (case-insensitive) names chromosome 1 is returned. If none matches, the
    first record of the file is returned instead.

    Raises
    ------
    ValueError
        If the FASTA file contains no records at all.
    """
    import re

    # The look-aheads stop "chromosome 1" from matching "chromosome 10" and
    # "chromosome i" from matching "chromosome ii" / "iv" / "ix", while still
    # accepting suffixed names such as "chromosome 1A".
    chr1_pattern = re.compile(r"chromosome 1(?!\d)|chromosome i(?![ivx])|chr ?1(?!\d)")

    # Only the first record is kept for the fallback, so the whole genome is
    # never held in memory at once.
    first_record = None
    for rec in SeqIO.parse(fasta_path, "fasta"):
        if first_record is None:
            first_record = rec
        header = (rec.id + " " + (rec.description or "")).lower()
        if chr1_pattern.search(header):
            return rec

    if first_record is None:
        raise ValueError(f"No FASTA records found in {fasta_path!r}")
    # fallback: first record
    return first_record

# Pull the configured window out of the chosen record. WINDOW_START/WINDOW_END
# are 1-based inclusive, hence the -1 on the slice start only.
chr1_rec = pick_chr1_record(FASTA_PATH)
seq_1based = chr1_rec.seq[WINDOW_START - 1 : WINDOW_END]
sequence = str(seq_1based)

print(f"{chr1_rec.id} length: {len(chr1_rec.seq)}")
print(f"Window length: {len(sequence)}")
print(f"{sequence[:80]} ...")


### Ask Evo2 for layer activations over this sequence
Potentially, we can look into trying other layer outputs to find different features

In [None]:
# TF32 is ENABLED by the two flags below (faster matmuls/convolutions at
# slightly reduced precision). Set both to False for maximum precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# === CONFIG ===
# NOTE(review): this points at a 40B checkpoint, while the notebook intro says
# it defaults to the 7B model that matches the SAE weights -- confirm which
# model is intended.
CHECKPOINT_PATH = "/home/jacobbw/roell_group/hf_cache/evo2_40b.pt"
# Dotted path of the submodule whose output is captured and fed to the SAE.
SAE_LAYER_NAME = "blocks.26.mlp.l3"
WANTED_LAYERS = [SAE_LAYER_NAME]
USE_MULTI_GPU = True  # Set False if you want single GPU
COMPILE_MODEL = True  # torch.compile for faster inference
# === IMPORTS ===
# NOTE(review): `Evo2` is also imported at the top of the notebook (from `evo2`);
# this re-import rebinds the same name.
from evo2.models import Evo2
from vortex.model.model import StripedHyena
def nuke_fp8(model):
    """Strip FP8 bookkeeping from every submodule of `model`, in place.

    For each module yielded by `model.modules()`:
      * `fp8_meta`, when present, is set to None;
      * `fp8`, when present, is set to False;
      * `_load_from_state_dict` is rebound to the stock `torch.nn.Module`
        implementation, so a subclass override is bypassed when weights are
        loaded afterwards.
    """
    stock_loader = torch.nn.Module._load_from_state_dict
    neutralised = (("fp8_meta", None), ("fp8", False))

    for submodule in model.modules():
        for attr_name, neutral_value in neutralised:
            if hasattr(submodule, attr_name):
                setattr(submodule, attr_name, neutral_value)
        if hasattr(submodule, "_load_from_state_dict"):
            # Bind the plain PyTorch loader to this specific module instance.
            submodule._load_from_state_dict = partial(stock_loader, submodule)
def load_state_dict_streaming(model, state_dict, device):
    """Copy checkpoint tensors into `model` one at a time.

    Each tensor found in `state_dict` under a key of `model.state_dict()` is
    moved to `device`, cast to the destination tensor's own dtype, and copied
    in place. Casting to the destination dtype (rather than always to
    bfloat16) keeps integer buffers exact and does not truncate float32
    destinations.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose parameters/buffers are overwritten in place.
    state_dict : mapping of str -> torch.Tensor
        Source tensors; keys the model does not have are ignored.
    device : str or torch.device
        Device each tensor is staged on before the in-place copy.

    Returns
    -------
    list of str
        Model keys that were absent from `state_dict`. Those tensors are left
        untouched, and a warning is printed when the list is non-empty.
    """
    target_state = model.state_dict()
    missing_keys = []

    for key, target in target_state.items():
        if key not in state_dict:
            missing_keys.append(key)
            continue
        source = state_dict[key]
        target.copy_(source.to(device=device, dtype=target.dtype, non_blocking=True))

    if missing_keys:
        preview = ", ".join(missing_keys[:5])
        suffix = " ..." if len(missing_keys) > 5 else ""
        print(
            f"   Warning: {len(missing_keys)} model tensor(s) not found in the "
            f"checkpoint and left as-is: {preview}{suffix}"
        )

    # Wait for the non_blocking transfers; only meaningful (and only legal)
    # when a CUDA device is actually in use.
    if torch.device(device).type == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()

    return missing_keys
def load_evo2_blazing_fast():
    """Build an Evo2 40B model on GPU and load its checkpoint with minimal host RAM.

    Steps, in order:
      1. Read the `configs/evo2_40b.yml` config shipped with the `evo2` package.
      2. Instantiate `StripedHyena` on the meta device, then replace every
         parameter/buffer with an uninitialised tensor on `cuda:0`.
      3. Remove FP8 state via `nuke_fp8`.
      4. Memory-map `CHECKPOINT_PATH` and copy tensors in with
         `load_state_dict_streaming`.
      5. Optionally move the second half of `model.blocks` to `cuda:1`
         (`USE_MULTI_GPU`).
      6. Optionally wrap the model with `torch.compile` (`COMPILE_MODEL`).

    Returns
    -------
    Evo2
        An `Evo2` instance created without running `__init__`, with `.model`
        set to the loaded network and `.tokenizer` set to None.

    Raises
    ------
    RuntimeError
        If the config cannot be located or parsed.
    """
    print("=" * 60)
    print("🚀 BLAZING FAST EVO2 LOADER")
    print("=" * 60)

    # Verify hardware
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPU(s): {[torch.cuda.get_device_name(i) for i in range(num_gpus)]}")

    primary_device = torch.device("cuda:0")
    torch.cuda.set_device(primary_device)

    # ─── 1. CONFIG ───────────────────────────────────────────
    print("\n[1/5] Loading config...")
    # pkgutil.get_data may either return None or raise when the resource is
    # missing, so each candidate package is tried independently.
    config_data = None
    last_error = None
    for package in ("evo2", "evo2.models"):
        try:
            config_data = pkgutil.get_data(package, "configs/evo2_40b.yml")
        except Exception as e:
            last_error = e
            continue
        if config_data is not None:
            break
    if config_data is None:
        raise RuntimeError(
            f"Config load failed: {last_error}. Check evo2 package installation."
        ) from last_error
    try:
        config = yaml.safe_load(config_data)
    except Exception as e:
        raise RuntimeError(f"Config load failed: {e}. Check evo2 package installation.") from e
    # NOTE(review): `config` is a plain dict here; confirm StripedHyena accepts
    # that rather than an attribute-style config object.

    # ─── 2. INIT MODEL ON GPU (ZERO RAM) ─────────────────────
    print("\n[2/5] Initializing model skeleton on GPU...")

    # Build on the meta device first (no storage is allocated), then
    # materialize on GPU with empty tensors.
    with torch.device('meta'):
        model = StripedHyena(config)

    def materialize_to_device(module, device, dtype):
        """Replace this module's own meta params/buffers with empty tensors on `device`."""
        def empty_like_meta(meta_tensor):
            # Only floating-point tensors take the model dtype; integer/bool
            # tensors keep theirs so they are not silently turned into floats.
            target_dtype = dtype if meta_tensor.is_floating_point() else meta_tensor.dtype
            return torch.empty(meta_tensor.shape, device=device, dtype=target_dtype)

        for name, param in module._parameters.items():
            if param is not None:
                module._parameters[name] = torch.nn.Parameter(
                    empty_like_meta(param),
                    requires_grad=False
                )
        for name, buf in module._buffers.items():
            if buf is not None:
                module._buffers[name] = empty_like_meta(buf)

    for module in model.modules():
        materialize_to_device(module, primary_device, torch.bfloat16)

    print(f"   Model structure allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")

    # ─── 3. NUKE FP8 ─────────────────────────────────────────
    print("\n[3/5] Patching FP8...")
    nuke_fp8(model)

    # ─── 4. STREAM WEIGHTS DIRECTLY ──────────────────────────
    print("\n[4/5] Streaming weights from disk → GPU...")

    # Memory map the checkpoint (doesn't load into RAM)
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", mmap=True, weights_only=True)
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # Drop the "_extra_state" entries (FP8 metadata) before loading.
    fp8_keys = [k for k in state_dict.keys() if "_extra_state" in k]
    for k in fp8_keys:
        del state_dict[k]
    print(f"   Filtered {len(fp8_keys)} FP8 metadata keys")

    # Stream load (one tensor at a time: disk → GPU). Tensors that the
    # checkpoint does not provide keep the uninitialised values from step 2.
    load_state_dict_streaming(model, state_dict, primary_device)

    del state_dict
    gc.collect()
    torch.cuda.empty_cache()

    print(f"   Post-load VRAM: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")

    # ─── 5. MULTI-GPU SPLIT (OPTIONAL) ───────────────────────
    if USE_MULTI_GPU and num_gpus >= 2:
        print("\n[5/5] Splitting model across GPUs...")
        # Naive layer split: first half on GPU 0, second half on GPU 1.
        # NOTE(review): only the blocks are moved; nothing here moves the
        # activations between devices, so this relies on the model's forward
        # handling cross-device blocks -- confirm.
        if hasattr(model, 'blocks'):
            n_layers = len(model.blocks)
            split_point = n_layers // 2
            for i, block in enumerate(model.blocks):
                target_device = torch.device(f"cuda:{0 if i < split_point else 1}")
                block.to(target_device)
            print(f"   Layers 0-{split_point-1} -> GPU 0")
            print(f"   Layers {split_point}-{n_layers-1} -> GPU 1")
        else:
            print("   Warning: 'blocks' attribute not found. Using single GPU.")
    else:
        print("\n[5/5] Single GPU mode - skipping split")

    # ─── 6. COMPILE (torch.compile for speed) ────────────────
    if COMPILE_MODEL:
        print("\n[BONUS] Compiling model with torch.compile...")
        try:
            model = torch.compile(model, mode="reduce-overhead", fullgraph=False)
            print("   ✓ Compiled successfully")
        except Exception as e:
            print(f"   ✗ Compile failed (non-fatal): {e}")

    model.eval()

    # ─── WRAP IN EVO2 ────────────────────────────────────────
    # __new__ skips Evo2.__init__, so only the attributes set here exist.
    wrapper = Evo2.__new__(Evo2)
    wrapper.model = model
    wrapper.tokenizer = None  # Init separately if needed

    print("\n" + "=" * 60)
    print("✅ MODEL READY")
    print(f"   GPU 0: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")
    if num_gpus >= 2:
        print(f"   GPU 1: {torch.cuda.memory_allocated(1)/1024**3:.2f} GB")
    print("=" * 60)

    return wrapper
# === RUN ===
model_obj = load_evo2_blazing_fast()
# === INFERENCE ===
# Warmup on a dummy all-zero input (the first run is slower due to CUDA kernels).
print("\nWarmup run...")
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = model_obj.model(torch.zeros(1, 1024, dtype=torch.long, device="cuda:0"))
print("\nActual inference...")
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits, acts = evo2_forward(sequence, output_layers=WANTED_LAYERS, model=model_obj)
print("Layers returned:", list(acts.keys()))

# The cells below read the SAE layer's activations from `layer_act`, so fail
# here with a clear message if the hook did not capture anything.
sae_act_key = f"{SAE_LAYER_NAME}.output"
if sae_act_key not in acts:
    raise KeyError(
        f"No activations captured for {SAE_LAYER_NAME!r}; layers returned: {list(acts.keys())}"
    )
layer_act = acts[sae_act_key]
print("Activations shape:", layer_act.shape)

### Inspect the activation vectors at the output of layer 26

In [None]:
print(f'There are {len(layer_act)} tokens in this sequence window.')
print(f'Each token has {len(layer_act[0])} features. This is the model hidden size or model dimension.')

# One row per token: position, base, feature_0, feature_1, ..., feature_n
df = pd.DataFrame(layer_act, columns=[f'feature_{i}' for i in range(layer_act.shape[1])])
df.insert(0, 'base', list(sequence))
# Derive positions from the actual row count so a record shorter than
# WINDOW_END does not cause a length mismatch.
df.insert(0, 'position', range(WINDOW_START, WINDOW_START + len(df)))

df.head(10)

### Load the Top-K tied SAE model from Hugging Face
https://huggingface.co/Goodfire/Evo-2-Layer-26-Mixed

In [None]:
# We'll load the weights and keep them as plain tensors.
# These weights were trained with expansion_factor=8 and k=64.
# Imported here because the top-of-notebook import is commented out.
from huggingface_hub import hf_hub_download

SAE_REPO = "Goodfire/Evo-2-Layer-26-Mixed"
SAE_FILE = "sae-layer26-mixed-expansion_8-k_64.pt"

sae_path = hf_hub_download(repo_id=SAE_REPO, filename=SAE_FILE, repo_type="model")
state = torch.load(sae_path, map_location="cpu", weights_only=True)

# Normalize potential key prefixes (strip '_orig_mod.' or 'module.')
clean = {
    k.replace("_orig_mod.", "").replace("module.", ""): v
    for k, v in state.items()
}

# Expected keys: 'W', 'b_enc', 'b_dec' -- fail early with the available keys
# listed if the checkpoint layout is different.
missing_sae_keys = [name for name in ("W", "b_enc", "b_dec") if name not in clean]
if missing_sae_keys:
    raise KeyError(
        f"SAE checkpoint is missing {missing_sae_keys}; available keys: {sorted(clean)}"
    )

W     = clean["W"].to(device=device, dtype=dtype)         # [d_in, d_hidden_expanded]
b_enc = clean["b_enc"].to(device=device, dtype=dtype)     # [d_hidden_expanded]
b_dec = clean["b_dec"].to(device=device, dtype=dtype)     # [d_in]
TOP_K = 64

print("W:", tuple(W.shape))
print("b_enc:", tuple(b_enc.shape))
print("b_dec:", tuple(b_dec.shape))


### Define function for running the model (encode/decode helpers)

In [None]:
def relu(x):
    """Element-wise max(x, 0), returned as a new tensor shaped like `x`."""
    floor = torch.zeros_like(x)
    return torch.maximum(floor, x)

def encode_topk(x, W, b_enc, k=TOP_K):
    """Encode activations into Top-K sparse SAE features.

    x: [T, d_in]. Returns f: [T, d_hidden_expanded] in which, for every row,
    only the k largest post-ReLU values are kept and all other entries are 0.
    `k` is capped at the feature dimension.
    """
    activated = relu(x @ W + b_enc)                    # [T, d_hidden_expanded]

    n_keep = min(k, activated.shape[-1])
    top_values, top_indices = torch.topk(activated, k=n_keep, dim=-1)   # both [T, n_keep]

    # Write the kept values back at their original column positions.
    return torch.zeros_like(activated).scatter(-1, top_indices, top_values)

def decode(f, W, b_dec):
    """Map SAE features back to the input space using the tied weights: f @ W.T + b_dec."""
    reconstruction = torch.matmul(f, W.T)
    return reconstruction + b_dec


### Project Evo2 activations through the SAE

In [None]:
# layer_act is a numpy array; copy it into a tensor on the SAE's device/dtype.
x = torch.tensor(layer_act, device=device, dtype=dtype)   # [T, d_in]

print(f'The input to the SAE has shape: {x.shape}')
print(f'The SAE weight (W) has shape: {W.shape}')

if x.shape[1] == W.shape[0]:
    with torch.no_grad():
        f_sparse = encode_topk(x, W, b_enc, k=TOP_K)          # [T, d_hidden_expanded]

    # Back to float32 numpy for the pandas / matplotlib cells below.
    sae_acts = f_sparse.float().cpu().numpy()
    print(f'The sae activations have shape: {sae_acts.shape}')
else:
    # The activation width must equal the SAE input width; report and skip.
    print("\n" + 40 * "=")
    print("DIMENSION MISMATCH ERROR")
    print(f"Model activations (dim {x.shape[1]}) do not match SAE weights (dim {W.shape[0]}).")
    print("To fix this, ensure you are using 'arcinstitute/evo2_7b'.")
    print(40 * "=" + "\n")


### Inspect the feats np object

In [None]:
print(f'There are {len(sae_acts)} tokens in this sequence window.')
print(f'Each token has {len(layer_act[0])} features. This is the model hidden size or model dimension.')
print(f'Each token embedding was expanded to {sae_acts.shape[1]} features by the SAE.')

# One row per token: position, base, then one column per SAE feature.
df_feats = pd.DataFrame(sae_acts, columns=[f'feature_{i}' for i in range(sae_acts.shape[1])])
df_feats.insert(0, 'base', list(sequence))
# Derive positions from the actual row count so a record shorter than
# WINDOW_END does not cause a length mismatch.
df_feats.insert(0, 'position', range(WINDOW_START, WINDOW_START + len(df_feats)))
df_feats.head(10)

### Which features are most active?

In [None]:
# Keep only the SAE feature columns (everything after position, base).
feature_only_df = df_feats.iloc[:, 2:]

display(feature_only_df.head())

# For every feature, count the tokens at which it is non-zero.
nonzero_counts = {
    col: (feature_only_df[col] != 0).sum()
    for col in feature_only_df.columns
}

nonzero_counts_series = pd.Series(nonzero_counts)
nonzero_counts_series.sort_values(ascending=False).head(20)  # show top 20 most active features


### Make a binary version of the feature data frame

In [None]:
# Binary (0/1) view of the SAE features: 1 wherever a feature is non-zero.
# The first two columns (position, base) are carried over unchanged.
df_feats_binary = pd.concat(
    [df_feats.iloc[:, :2], (df_feats.iloc[:, 2:] != 0).astype(int)],
    axis=1,
)

df_feats_binary.head()

### Parse GFF and collect annotations in the 1..25,000 window

In [None]:
# Minimal GFF parser: we keep rows overlapping our window on this chromosome.
def parse_gff_window(gff_path, target_seqid, start_1b, end_1b,
                     keep_types=None):
    """Collect GFF features on `target_seqid` that overlap [start_1b, end_1b].

    Parameters
    ----------
    gff_path : str
        Path to a tab-separated GFF file.
    target_seqid : str
        Only rows whose first column equals this value are kept.
    start_1b, end_1b : int
        Window bounds, 1-based and inclusive.
    keep_types : collection of str, optional
        Feature types (third column) to keep; defaults to a built-in set.

    Returns
    -------
    list of (start, end, ftype, strand, attrs)
        `start` / `end` are clipped to the window and expressed as 0-based
        offsets from `start_1b` (both inclusive).

    Comment lines, rows with fewer than nine columns, and rows whose
    coordinates are not integers are skipped. Columns beyond the ninth are
    ignored.
    """
    if keep_types is None:
        keep_types = {"gene","CDS","exon","mRNA","ncRNA","tRNA","rRNA","misc_feature","Regulatory","tmRNA","mobile_element"}

    out = []
    with open(gff_path, "r", newline="") as fh:
        for line in fh:
            if not line or line.startswith("#"):
                continue
            parts = line.strip().split("\t")
            if len(parts) < 9:
                continue
            # Slice so a row with extra trailing columns cannot break unpacking.
            seqid, _source, ftype, start, end, _score, strand, _phase, attrs = parts[:9]
            if seqid != target_seqid:
                continue
            try:
                s, e = int(start), int(end)
            except ValueError:
                continue
            if e < start_1b or s > end_1b:
                continue
            if ftype not in keep_types:
                continue
            # Clip to window for plotting
            s_clipped = max(start_1b, s) - start_1b
            e_clipped = min(end_1b,   e) - start_1b
            out.append((s_clipped, e_clipped, ftype, strand, attrs))
    return out

# The GFF is filtered by the FASTA record id, so both files must use the same
# sequence naming for any annotation to be kept.
gff_seqid = chr1_rec.id
ann = parse_gff_window(
    gff_path=GFF_PATH,
    target_seqid=gff_seqid,
    start_1b=WINDOW_START,
    end_1b=WINDOW_END,
)

print("Annotations kept:", len(ann))
print("First few:", ann[:5])


### Inspect the annotations

In [None]:
# Tabular view of the kept annotations, one row per feature.
df_ann = pd.DataFrame(
    data=ann,
    columns=["start", "end", "type", "strand", "attributes"],
)
df_ann

### Make a dataframe with columns for position, base, and each annotation as a binary value

In [None]:
# Annotation types that actually occur in this window.
annotation_types = df_ann["type"].unique()

print("Annotation types present in the sampled sequence:", annotation_types)

# One 0/1 column per annotation type, one row per token/base.
df_ann_binary = pd.DataFrame(0, index=df_feats_binary.index, columns=annotation_types)

# Mark every position covered by each annotation. `.loc` slicing is
# label-based and includes both ends, matching the inclusive clipped coordinates.
for ann_start, ann_end, ann_type in zip(df_ann["start"], df_ann["end"], df_ann["type"]):
    df_ann_binary.loc[ann_start:ann_end, ann_type] = 1
df_ann_binary.insert(0, 'base', list(sequence))
# Derive positions from the actual row count so a record shorter than
# WINDOW_END does not cause a length mismatch.
df_ann_binary.insert(0, 'position', range(WINDOW_START, WINDOW_START + len(df_ann_binary)))

display(df_ann_binary.iloc[1510:1530])

### Add annotations to the df_feats_binary dataframe

In [None]:
# Put the annotation columns and the binary SAE feature columns side by side.
# Both frames carry 'position' and 'base'; drop them from the second frame so
# the combined frame does not end up with duplicate column names.
df_feats_ann_binary = pd.concat(
    [df_ann_binary, df_feats_binary.drop(columns=["position", "base"])],
    axis=1,
)

df_feats_ann_binary

### Analyze the correlations between the features and the annotations

In [None]:
# Column groups: annotation types (everything after position, base) and SAE features.
annotation_cols = df_ann_binary.columns[2:]
feature_cols = [name for name in df_feats_ann_binary.columns if name.startswith("feature_")]

# Narrow views holding only the columns being correlated.
annotations = df_feats_ann_binary[annotation_cols]
features = df_feats_ann_binary[feature_cols]

# For each annotation type, the correlation of every feature column with it.
correlation_results = {
    ann_col: features.corrwith(annotations[ann_col])
    for ann_col in annotation_cols
}

# Rows = features, columns = annotation types.
corr_df = pd.DataFrame(correlation_results)

corr_df

### Find the features with the highest correlations to each annotation

In [None]:
# The five features most positively correlated with the 'gene' annotation.
corr_df['gene'].sort_values(ascending=False).head(5)

### Plot a handful of SAE features with GFF overlays

In [None]:
# Shape of the SAE activation matrix (rows = tokens, columns = SAE features).
sae_acts.shape

In [None]:
# Annotation-type columns found in this window.
annotation_cols

In [None]:

# Choose a few feature indices to visualize (change to taste)
# selected_features = [15680, 28339, 1050, 25666]
selected_features = [15809, 302, 7703, 26750, 25852]

# Background colour used for each annotation type's shaded span.
ANNOTATION_COLORS = dict(
    CDS='lightyellow',
    gene='lightgray',
    mobile_element='lightgreen',
    misc_feature='khaki',
    rRNA='#7AC8AC',
    tRNA='#662D91',
    ncRNA='white',
    Regulatory='lightcoral',
    tmRNA='salmon',
    exon='lightblue',
    mRNA='lavender',
)

T = sae_acts.shape[0]
fig, axes = plt.subplots(
    len(selected_features), 1,
    figsize=(24, 1.6*len(selected_features)),
    sharex=True,
)
# plt.subplots returns a bare Axes for a single panel; normalise to an array.
panel_axes = np.atleast_1d(axes)

for ax, feat_id in zip(panel_axes, selected_features):
    if feat_id >= sae_acts.shape[1]:
        ax.text(0.01, 0.5, f"Feature {feat_id} out of range", transform=ax.transAxes)
        continue
    trace = sae_acts[:, feat_id]
    ax.plot(trace, lw=0.7, label=f"feature {feat_id}", alpha=0.9)
    # Shade every annotation behind the activation trace.
    for s, e, ftype, strand, attrs in ann:
        ax.axvspan(s, e, color=ANNOTATION_COLORS.get(ftype, 'lightgray'), alpha=0.25)
    ax.set_xlim(0, T)
    ax.set_yticks([0, max(0.1, trace.max())])
    ax.legend(loc="upper right", frameon=False)

# The x axis is shared, so only the bottom panel is labelled.
panel_axes[-1].set_xlabel("Position (bp) in 1..25,000 window")
plt.show()



### Tips / Next steps
- Change `SAE_LAYER_NAME` to probe different hidden layers.
- Adjust `selected_features` to scan for interesting peaks.
- Expand `WINDOW_END` or slide `WINDOW_START` to explore more of chromosome 1.
- Save `sae_acts` and overlay other tracks (GC%, motif hits, etc.).
