In [2]:
import json
from pathlib import Path
from typing import List, Tuple, Union, Optional
import os
import numpy as np
import torch
from scipy import sparse
from tqdm import tqdm
import scanpy as sc
import pandas as pd

import sys
sys.path.append('/maiziezhou_lab2/yunfei/Projects/interpTFM/sae')
from dictionary import AutoEncoder
from inference import get_sae_feats_in_batches, load_sae

In [3]:
def count_unique_nonzero_dense(matrix: torch.Tensor) -> List[int]:
    """
    Count unique non-zero values in each column of a dense matrix.

    Args:
        matrix: Dense PyTorch tensor to analyze

    Returns:
        List of counts of unique non-zero values for each column
    """
    # Initialize list to store counts
    unique_counts = []

    # Iterate through each column
    for col in range(matrix.shape[1]):
        # Get unique values in the column
        unique_values = torch.unique(matrix[:, col])
        # Count how many unique values are non-zero
        count = torch.sum(unique_values != 0).item()
        unique_counts.append(count)

    return unique_counts


def calc_metrics_dense(
    sae_feats: torch.Tensor,
    per_token_labels_sparse: Union[np.ndarray, sparse.spmatrix],
    threshold_percents: List[float],
    is_aa_level_concept: List[bool],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Optimized GPU-compatible metric computation for dense matrices.
    """
    device = sae_feats.device
    labels = torch.tensor(per_token_labels_sparse.astype(np.float32), device=device)  # [N, C]
    N, F = sae_feats.shape
    C = labels.shape[1]
    T = len(threshold_percents)

    # Thresholds as tensor [T, 1, 1] for broadcasting
    thresholds = torch.tensor(threshold_percents, dtype=torch.float32, device=device).view(T, 1, 1)

    # Expand and binarize features: [T, N, F]
    feats_exp = sae_feats.unsqueeze(0)  # [1, N, F]
    bin_feats = (feats_exp > thresholds).float()  # [T, N, F]

    # Labels: [1, N, C]
    labels_exp = labels.T.unsqueeze(0)  # [1, C, N]

    # Calculate TP: [T, C, F]
    tp = torch.matmul(labels_exp, bin_feats)  # [T, C, F]
    tp = tp.permute(1, 2, 0).contiguous()  # [C, F, T]

    # Calculate FP: [T, C, F]
    not_labels_exp = (1.0 - labels.T).unsqueeze(0)  # [1, C, N]
    fp = torch.matmul(not_labels_exp, bin_feats)  # [T, C, F]
    fp = fp.permute(1, 2, 0).contiguous()  # [C, F, T]

    # Calculate TP per domain for non-AA-level only
    tp_per_domain = torch.zeros_like(tp)

    # non_aa_indices = [i for i, flag in enumerate(is_aa_level_concept) if not flag]
    # if non_aa_indices:
    #     non_aa_mask = torch.zeros(C, dtype=torch.bool, device=device)
    #     non_aa_mask[non_aa_indices] = True

    #     # For non-AA concepts: compute domain-level TP (number of examples with ≥1 positive feature)
    #     for t_idx in range(T):
    #         # For each threshold: binary_feats [N, F], labels [N, C]
    #         bf = bin_feats[t_idx]  # [N, F]
    #         l = labels  # [N, C]

    #         # Multiply elementwise [N, F] * [N, C] -> [N, C, F]
    #         combined = (bf.unsqueeze(1) * l.unsqueeze(2))  # [N, C, F]
    #         per_domain_tp = (combined.sum(dim=0) > 0).float()  # [C, F]
    #         tp_per_domain[:, :, t_idx] = per_domain_tp

    positive_labels = labels.sum(dim=0)  # [C]
    positive_labels = positive_labels.view(-1, 1, 1)  # [C, 1, 1]

    fn = positive_labels - tp  # [C, F, T]
    fn = torch.clamp(fn, min=0)  # Optional, to avoid negative values due to float precision

    return tp.cpu().numpy(), fp.cpu().numpy(), fn.cpu().numpy(), tp_per_domain.cpu().numpy()

In [9]:
def load_concept_names(concept_name_path: Path) -> List[str]:
    """Load concept names from a file."""
    with open(concept_name_path, "r") as f:
        return f.read().strip().split("\n")


def process_shard(
    sae: AutoEncoder,
    device: torch.device,
    esm_embeddings_pt_path: str,
    per_token_labels: Union[np.ndarray, sparse.spmatrix],
    threshold_percents: List[float],
    is_aa_concept_list: List[bool],
    feat_chunk_max: int = 512,
    is_sparse: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Process a shard of data by splitting it into manageable chunks for feature calculation.

    Args:
        sae: Normalized SAE model
        device: PyTorch device to use for computation
        esm_embeddings_pt_path: Path to ESM embeddings file
        per_token_labels: Label matrix
        threshold_percents: List of threshold values to evaluate
        is_aa_concept_list: Boolean flags indicating if each concept is AA-level
        feat_chunk_max: Maximum chunk size for feature processing
        is_sparse: Whether to use sparse matrix operations

    Returns:
        Tuple of arrays (tp, fp, tp_per_domain) containing calculated metrics
    """
    # Load embeddings to specified device
    esm_acts = torch.load(
        esm_embeddings_pt_path, map_location=device, weights_only=True
    )

    # Calculate chunking parameters
    feature_chunk_size = min(feat_chunk_max, sae.dict_size)
    total_features = sae.dict_size
    num_chunks = int(np.ceil(total_features / feature_chunk_size))
    print(f"Calculating over {total_features} features in {num_chunks} chunks")

    # Initialize result arrays
    n_concepts = per_token_labels.shape[1]
    n_thresholds = len(threshold_percents)
    n_features = sae.dict_size
    tp = np.zeros((n_concepts, n_features, n_thresholds))
    fp = np.zeros((n_concepts, n_features, n_thresholds))
    fn = np.zeros((n_concepts, n_features, n_thresholds))
    tp_per_domain = np.zeros((n_concepts, n_features, n_thresholds))

    # Convert labels to appropriate format
    # per_token_labels = (
    #     sparse.csr_matrix(per_token_labels) if is_sparse else per_token_labels.toarray()
    # )

    # Process each chunk of features
    for feature_list in tqdm(np.array_split(range(total_features), num_chunks)):
        # Get SAE features for current chunk
        sae_feats = get_sae_feats_in_batches(
            sae=sae,
            device=device,
            esm_embds=esm_acts,
            chunk_size=1024,
            feat_list=feature_list,
        )

        # Calculate metrics using either sparse or dense implementation
        # if is_sparse:
        #     sae_feats_sparse = sparse.csr_matrix(sae_feats.cpu().numpy())
        #     metrics = calc_metrics_sparse(
        #         sae_feats_sparse,
        #         per_token_labels,
        #         threshold_percents,
        #         is_aa_concept_list,
        #     )
        # else:
        metrics = calc_metrics_dense(
            sae_feats, per_token_labels, threshold_percents, is_aa_concept_list
        )

        # Update results arrays with computed metrics
        tp_subset, fp_subset, fn_subset, tp_per_domain_subset = metrics
        tp[:, feature_list] = tp_subset
        fp[:, feature_list] = fp_subset
        fn[:, feature_list] = fn_subset
        tp_per_domain[:, feature_list] = tp_per_domain_subset

    return (tp, fp, fn, tp_per_domain)


# def analyze_concepts(
#     adata_path: Path,
#     gene_ids_path: Path,
#     concepts_path: Path,
#     gene_ignore: List,
#     sae_dir: Path,
#     esm_embds_dir: Path = Path("../../data/processed/embeddings"),
#     eval_set_dir: Path = Path("../../data/processed/valid"),
#     output_dir: Path = "concept_results",
#     threshold_percents: List[float] = [0, 0.15, 0.5, 0.6, 0.8],
#     shard: Optional[str] = None, # 'shard_55'
#     is_sparse: bool = True,
# ):
#     """
#     Analyzes concepts in protein sequences using a Sparse Autoencoder (SAE) model.

#     Args:
#         sae_dir (Path): Directory containing the normalized SAE model file 'ae_normalized.pt'
#         esm_embds_dir (Path, optional): Directory containing ESM embeddings.
#         eval_set_dir (Path, optional): Directory containing validation dataset and metadata.
#         output_dir (Path, optional): Directory where results will be saved.
#         threshold_percents (List[float], optional): List of threshold values for concept detection.
#         shard (int | None): Specific shard number to process. Must exist in evaluation set.
#         is_sparse (bool, optional): Whether to use sparse matrix operations.

#     Returns:
#         None: Results are saved to disk as NPZ file with following arrays:
#             - tp: True positives counts
#             - fp: False positives counts
#             - tp_per_domain: True positives counts per domain

#     Raises:
#         ValueError: If normalized SAE model is not found in sae_dir
#         ValueError: If specified shard is not in the evaluation set
#     """


#     # Verify that the normalized SAE model exists
#     if not (sae_dir / "ae_normalized.pt").exists():
#         raise ValueError(f"Normalized SAE model not found in {sae_dir}")
    
#     ad_ = sc.read_h5ad(adata_path)

#     # keep for only current shard
#     ad_subset = ad_[ad_.obs["shards"] == shard].copy()

#     # remove genes not in vocab
#     if "index" in ad_subset.var.columns:
#         genes_to_keep = ~ad_subset.var["index"].isin(gene_ignore)
#     ad_subset = ad_subset[:, genes_to_keep]

#     # print(shard)
#     # file_path = "/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/scgpt/gene_ids/shard_0/all_input_gene_ids.txt"
#     with open(gene_ids_path / shard / "cell_gene_pairs.txt", "r") as f:
#         gene_ids = f.read().split()
    
#     # Load the binary concept-gene matrix
#     df = pd.read_csv(concepts_path / 'gene_concepts.csv', index_col=0)  # Set index_col=0 if the first column is a concept name

#     # Create a dictionary: gene (column) ➜ list of 0/1 for all concepts
#     gene_to_concepts = {gene: df[gene].tolist() for gene in df.columns}
#     concept_names = load_concept_names(concepts_path / "gprofiler_gene_concepts_columns.txt")
#     # print(concept_names)
#     per_token_labels = np.zeros((len(gene_ids), len(concept_names)))

#     # print(per_token_labels.shape)

    

#     for i, gene in enumerate(gene_ids):
#         if gene in gene_to_concepts:
#             # print(len(gene_to_concepts[gene]))
#             # print(gene_to_concepts[gene])
#             per_token_labels[i] = gene_to_concepts[gene]
#         else:
#             # Keep all-zero vector (already initialized)
#             pass

#     # Set up device (GPU if available, otherwise CPU)
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     # Load the normalized SAE model
#     sae = load_sae(model_path=sae_dir / "ae_normalized.pt", device=device)

#     # Process the shard and get results (true positives, false positives, and true positives per domain)
#     (tp, fp, fn, tp_per_domain) = process_shard(
#         sae,
#         device,
#         esm_embds_dir / f"{shard}" / "activations.pt",
#         per_token_labels,
#         threshold_percents,
#         concept_names,
#         feat_chunk_max=250,
#         is_sparse=is_sparse,
#     )

#     # Create output directory if it doesn't exist and save results
#     output_dir.mkdir(parents=True, exist_ok=True)
#     np.savez_compressed(
#         output_dir / f"{shard}_counts.npz",
#         tp=tp,
#         fp=fp,
#         fn=fn,
#         tp_per_domain=tp_per_domain,
#     )

#     # Save per_token_labels to npz in f"shard_{i}/aa_concepts.npz"
#     shard_dir = output_dir / f"{shard}"
#     shard_dir.mkdir(parents=True, exist_ok=True)

#     # Convert to sparse matrix (recommended if data is mostly zeros)
#     per_token_labels_matrix = sparse.csr_matrix(per_token_labels)
#     sparse.save_npz(shard_dir / "gene_concepts.npz", per_token_labels_matrix)


def analyze_concepts(
    adata_path: Path,
    gene_ids_path: Path,
    concepts_path: Path,
    gene_ignore: List,
    sae_dir: Path,
    esm_embds_dir: Path = Path("../../data/processed/embeddings"),
    eval_set_dir: Path = Path("../../data/processed/valid"),
    output_dir: Path = Path("concept_results"),
    threshold_percents: List[float] = [0, 0.15, 0.5, 0.6, 0.8],
    shard: Optional[str] = None,  # e.g., 'shard_55'
    is_sparse: bool = True,
):
    """
    Analyze concepts using an SAE for one shard.

    Saves NPZ:
      - tp, fp, fn: [C, F, T]
      - tp_per_domain: [C, F, T] (placeholder zeros)
      - gene_concepts.npz: sparse CSR of per-token labels (N x C)
    """

    # --- Preconditions ---
    if not (sae_dir / "ae_normalized.pt").exists():
        raise ValueError(f"Normalized SAE model not found in {sae_dir}")

    # Load AnnData (not used for N-alignment; keep as-is for any side checks you need)
    ad_ = sc.read_h5ad(adata_path)
    ad_subset = ad_[ad_.obs["shards"] == shard].copy()
    if "index" in ad_subset.var.columns:
        genes_to_keep = ~ad_subset.var["index"].isin(gene_ignore)
    else:
        genes_to_keep = np.ones(ad_subset.var.shape[0], dtype=bool)
    ad_subset = ad_subset[:, genes_to_keep]

    # --- Load (cell, token) pairs and keep only the token column for labels ---
    pair_path = gene_ids_path / shard / "cell_gene_pairs.txt"
    if not pair_path.exists():
        raise FileNotFoundError(f"Missing {pair_path}")

    toks = (pair_path.read_text()).split()
    if len(toks) % 2 != 0:
        raise ValueError(
            f"{shard}: cell_gene_pairs.txt has odd token count ({len(toks)}). "
            "Expected (cell, token) pairs."
        )

    # (cell, token) pairs -> take only tokens for label rows
    pairs = list(zip(toks[0::2], toks[1::2]))
    _, tokens = zip(*pairs)  # tokens length == N expected for activations

    # --- Load concepts and build token->label rows ---
    df = pd.read_csv(concepts_path / "gene_concepts.csv", index_col=0)
    gene_to_concepts = {gene: df[gene].to_numpy(dtype=np.float32) for gene in df.columns}

    concept_names = load_concept_names(concepts_path / "gprofiler_gene_concepts_columns.txt")
    C = len(concept_names)
    per_token_labels = np.zeros((len(tokens), C), dtype=np.float32)

    # special tokens to zero out (no label)
    special_tokens = {"<cls>", "<pad>", "<unk>", "[CLS]", "[PAD]", "[UNK]"}
    ignore_set = set(gene_ignore) if gene_ignore is not None else set()

    for i, tok in enumerate(tokens):
        # keep row but zero it out for specials / ignored genes to maintain N alignment
        if tok in special_tokens or tok in ignore_set:
            continue
        row = gene_to_concepts.get(tok)
        if row is not None:
            per_token_labels[i] = row  # shape (C,)

    # --- Device & model ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sae = load_sae(model_path=sae_dir / "ae_normalized.pt", device=device)

    # --- OPTIONAL sanity check for N against activations ---
    acts_pt = esm_embds_dir / f"{shard}" / "activations.pt"
    if not acts_pt.exists():
        raise FileNotFoundError(f"Missing activations at {acts_pt}")
    # Only to peek at N; process_shard will load again to the right device
    acts_preview = torch.load(acts_pt, map_location="cpu", weights_only=True)
    if isinstance(acts_preview, torch.Tensor):
        N_acts = acts_preview.shape[0]
    elif isinstance(acts_preview, dict):
        # try common key or first tensor value
        if "activations" in acts_preview and isinstance(acts_preview["activations"], torch.Tensor):
            N_acts = acts_preview["activations"].shape[0]
        else:
            N_acts = next(v for v in acts_preview.values() if isinstance(v, torch.Tensor)).shape[0]
    else:
        raise ValueError(f"Unrecognized activations format at {acts_pt}")

    if len(tokens) != N_acts:
        raise RuntimeError(
            f"N mismatch for {shard}: tokens={len(tokens)} vs activations={N_acts}. "
            "The token list must correspond exactly to the activations generation."
        )

    print(f"[check] shard={shard}  N={len(tokens)}  C={C}")

    # --- Run metrics over feature chunks ---
    (tp, fp, fn, tp_per_domain) = process_shard(
        sae=sae,
        device=device,
        esm_embeddings_pt_path=str(acts_pt),
        per_token_labels=per_token_labels,
        threshold_percents=threshold_percents,
        is_aa_concept_list=[False] * C,  # placeholder; not used in current dense path
        feat_chunk_max=250,
        is_sparse=is_sparse,
    )

    # --- Save outputs ---
    output_dir.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        output_dir / f"{shard}_counts.npz",
        tp=tp,
        fp=fp,
        fn=fn,
        tp_per_domain=tp_per_domain,
    )

    shard_dir = output_dir / f"{shard}"
    shard_dir.mkdir(parents=True, exist_ok=True)

    # Save per-token labels (sparse) for later inspection
    per_token_labels_matrix = sparse.csr_matrix(per_token_labels)
    sparse.save_npz(shard_dir / "gene_concepts.npz", per_token_labels_matrix)



def analyze_all_shards_in_set(
    adata_path: Path,
    sae_dir: Path,
    embds_dir: Path,
    concepts_path: Path,
    eval_set_dir: Path,
    gene_ids_path: Path,
    gene_ignore: List,
    output_dir: Path = "concept_results",
    threshold_percents: List[float] = [0, 0.15, 0.5, 0.6, 0.8],
    is_sparse: bool = True,
):
    """Wrapper to scan calculate metrics across all shards in an evaluation set.

    Args:
        sae_dir (Path): Directory containing the normalized SAE model file 'ae_normalized.pt'
        embds_dir (Path): Directory containing ESM embeddings
        eval_set_dir (Path): Directory containing validation dataset and metadata
        output_dir (Path, optional): Directory where results will be saved.
        threshold_percents (List[float], optional): List of threshold values for concept detection.
        is_sparse (bool, optional): Whether to use sparse matrix operations.

    Returns:
        None: Results for each shard are saved to disk in the output_dir

    Raises:
        FileNotFoundError: If metadata.json is not found in eval_set_dir
        ValueError: If any individual shard analysis fails (inherited from analyze_concepts)
    """
    # Load list of shards to evaluate from metadata
    print(eval_set_dir)
    # with open(eval_set_dir / "metadata.json", "r") as f:
    shards_to_eval = os.listdir(eval_set_dir)
    print(f"Analyzing set {eval_set_dir.stem} with {shards_to_eval} shards")

    # Process each shard sequentially
    for shard in shards_to_eval:
        analyze_concepts(
            adata_path,
            gene_ids_path,
            concepts_path,
            gene_ignore,
            sae_dir,
            embds_dir,
            eval_set_dir,
            output_dir,
            threshold_percents,
            shard,
            is_sparse,
        )

In [10]:
sae_dir=Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/sae_latents/sae_output_layer4')
embds_dir=Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/activations/layer_4')
eval_set_dir=Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/activations/layer_4_test')
output_dir=Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output')
gene_ids_path = Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/gene_ids')
concepts_path = Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/gprofiler_annotation')

adata_path = Path('/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/ge_shards/cosmx_human_lung_sec8.h5ad')

filtered_genes = ['RGS5', 'CCL3L3']

analyze_all_shards_in_set(
        adata_path=adata_path,
        sae_dir=sae_dir,
        embds_dir=embds_dir,
        concepts_path=concepts_path,
        eval_set_dir=eval_set_dir,
        output_dir=output_dir,
        gene_ids_path=gene_ids_path,
        gene_ignore=filtered_genes
    )

/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/activations/layer_4_test
Analyzing set layer_4_test with ['shard_55', 'shard_57', 'shard_56', 'shard_58', 'shard_59'] shards
[check] shard=shard_55  N=511688  C=2885
Calculating over 4096 features in 17 chunks


100%|██████████| 17/17 [00:25<00:00,  1.49s/it]


[check] shard=shard_57  N=530938  C=2885
Calculating over 4096 features in 17 chunks


100%|██████████| 17/17 [00:25<00:00,  1.53s/it]


[check] shard=shard_56  N=485231  C=2885
Calculating over 4096 features in 17 chunks


100%|██████████| 17/17 [00:23<00:00,  1.36s/it]


[check] shard=shard_58  N=521344  C=2885
Calculating over 4096 features in 17 chunks


100%|██████████| 17/17 [00:50<00:00,  2.99s/it]


[check] shard=shard_59  N=521464  C=2885
Calculating over 4096 features in 17 chunks


100%|██████████| 17/17 [00:25<00:00,  1.48s/it]


In [8]:
from pathlib import Path
import os, json, torch
import numpy as np
from collections import Counter

def infer_N_from_activations(acts):
    """Infer N (#examples/tokens) from an activations.pt structure."""
    if isinstance(acts, torch.Tensor):
        return acts.shape[0], acts.shape
    if isinstance(acts, dict):
        # try common keys / first tensor
        for k in ['activations', 'acts', 'X', 'z', 'hidden_states']:
            if k in acts and isinstance(acts[k], torch.Tensor):
                t = acts[k]
                return t.shape[0], t.shape
        for v in acts.values():
            if isinstance(v, torch.Tensor):
                return v.shape[0], v.shape
    raise ValueError("Cannot infer N from activations structure")

def read_gene_ids(gene_ids_file: Path):
    with open(gene_ids_file, "r") as f:
        ids = f.read().split()
    return ids

def show_shard_alignment(shard: str,
                         embds_dir: Path,
                         gene_ids_path: Path):
    print(f"\n=== SHARD: {shard} ===")
    activations_file = embds_dir / shard / "activations.pt"
    gene_ids_file    = gene_ids_path / shard / "cell_gene_pairs.txt"

    print("activations.pt exists:", activations_file.exists())
    print("cell_gene_pairs.txt exists:", gene_ids_file.exists())

    if activations_file.exists():
        acts = torch.load(activations_file, map_location="cpu")
        try:
            N_acts, shape_acts = infer_N_from_activations(acts)
            print("N_activations:", N_acts, "  full shape:", shape_acts)
        except Exception as e:
            print("Could not infer N from activations:", repr(e))
    else:
        print("Missing activations file.")

    if gene_ids_file.exists():
        ids = read_gene_ids(gene_ids_file)
        print("N_gene_ids:", len(ids))
        # Quick duplicate/format check
        n_unique = len(set(ids))
        print("unique gene_ids:", n_unique, "  duplicates:", len(ids) - n_unique)
        if len(ids) > 0:
            print("first 5 gene_ids:", ids[:5])
            # If entries look like "cell,gene" pairs, show a split preview
            if "," in ids[0] or "\t" in ids[0]:
                delim = "," if "," in ids[0] else "\t"
                preview = [x.split(delim)[:2] for x in ids[:5]]
                print("first 5 parsed pairs:", preview)
    else:
        print("Missing gene_ids file.")

def list_shards_clean(dir_path: Path):
    """List shard names but only include directories (avoid files like .DS_Store)."""
    items = [d for d in os.listdir(dir_path) if (dir_path / d).is_dir()]
    print(f"\nShards under {dir_path}: {len(items)}")
    if items:
        print("first 10:", items[:10])
    return items

# ---- Run the checks ----
print(">>> CHECK 1: Are you iterating over actual shard directories only?")
shards_eval = list_shards_clean(eval_set_dir)
shards_emb  = list_shards_clean(embds_dir)
shards_ids  = list_shards_clean(gene_ids_path)

print("\n>>> CHECK 2: Intersections and mismatches")
S_eval, S_emb, S_ids = set(shards_eval), set(shards_emb), set(shards_ids)
print("eval ∩ embds:", len(S_eval & S_emb))
print("eval ∩ gene_ids:", len(S_eval & S_ids))
print("embds ∩ gene_ids:", len(S_emb & S_ids))
print("Missing in embds:", sorted(S_eval - S_emb)[:10])
print("Missing in gene_ids:", sorted(S_eval - S_ids)[:10])

print("\n>>> CHECK 3: Inspect a few shards in detail")
for s in sorted((S_eval & S_emb & S_ids))[:3]:  # inspect up to 3 shards
    show_shard_alignment(s, embds_dir, gene_ids_path)

>>> CHECK 1: Are you iterating over actual shard directories only?

Shards under /maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/activations/layer_4_test: 5
first 10: ['shard_55', 'shard_57', 'shard_56', 'shard_58', 'shard_59']

Shards under /maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/activations/layer_4: 60
first 10: ['shard_18', 'shard_28', 'shard_40', 'shard_34', 'shard_55', 'shard_12', 'shard_29', 'shard_1', 'shard_54', 'shard_46']

Shards under /maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/gene_ids: 60
first 10: ['shard_18', 'shard_28', 'shard_40', 'shard_34', 'shard_55', 'shard_12', 'shard_29', 'shard_1', 'shard_54', 'shard_46']

>>> CHECK 2: Intersections and mismatches
eval ∩ embds: 5
eval ∩ gene_ids: 5
embds ∩ gene_ids: 60
Missing in embds: []
Missing in gene_ids: []

>>> CHECK 3: Inspect a few shards in detail

=== SHARD: shard_55 ===
activations.pt exists: True
cell_gene_pairs.txt exists: True
N