In [12]:
import json
from pathlib import Path
from typing import List, Tuple, Union, Optional
import os
import numpy as np
import torch
from scipy import sparse
from tqdm import tqdm
import scanpy as sc
import pandas as pd

import sys
sys.path.append('/maiziezhou_lab2/yunfei/Projects/interpTFM/sae')
from dictionary import AutoEncoder
from inference import get_sae_feats_in_batches, load_sae


def process_shard_activations(
    device: torch.device,
    esm_embeddings_pt_path: str,
    per_token_labels: Union[np.ndarray, "sparse.spmatrix"],
    threshold_percents: List[float],
    is_aa_concept_list: List[bool],
    feat_chunk_max: int = 512,
    is_sparse: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute concept-vs-feature confusion counts for raw activations (no SAE).

    Every activation column is treated as a 'feature'. Columns are processed in
    chunks of at most ``feat_chunk_max`` columns to bound the size of the
    intermediate tensors built by ``calc_metrics_dense``.

    Args:
        device: Device the activations are loaded onto (metrics are computed there).
        esm_embeddings_pt_path: Path to a ``torch.save``-d payload: either a 2-D
            tensor/array ``[N, F]``, or a dict holding one (the ``"activations"``
            key is preferred, otherwise the first tensor value is used).
        per_token_labels: ``[N, C]`` label matrix, dense or scipy sparse.
        threshold_percents: Thresholds forwarded to ``calc_metrics_dense``.
        is_aa_concept_list: Per-concept flags forwarded to ``calc_metrics_dense``.
        feat_chunk_max: Maximum number of feature columns per chunk (values < 1
            are treated as 1).
        is_sparse: Not used by this function; kept for signature compatibility.

    Returns:
        ``(tp, fp, fn, tp_per_domain)``, each a float64 array of shape ``[C, F, T]``.

    Raises:
        ValueError: If the payload holds no tensor, is not 2-D, or its row count
            differs from ``per_token_labels``.
    """
    # 1) Load activations. Older PyTorch releases reject ``weights_only``.
    try:
        payload = torch.load(esm_embeddings_pt_path, map_location=device, weights_only=True)
    except TypeError:
        payload = torch.load(esm_embeddings_pt_path, map_location=device)

    # Accept a bare tensor/array or a dict holding the tensor.
    if isinstance(payload, dict):
        acts = payload.get("activations")
        if not torch.is_tensor(acts):
            acts = next((v for v in payload.values() if torch.is_tensor(v)), None)
        if acts is None:
            raise ValueError(f"No tensor found in activations dict at {esm_embeddings_pt_path}")
    else:
        acts = payload
    acts = torch.as_tensor(acts, device=device)
    if acts.ndim != 2:
        raise ValueError(
            f"Expected 2-D activations [N, F] at {esm_embeddings_pt_path}, "
            f"got shape {tuple(acts.shape)}"
        )

    # 2) Densify labels once, outside the chunk loop: calc_metrics_dense builds
    #    its label tensor via .astype()/torch conversion, which needs a dense array.
    labels = (
        per_token_labels.toarray()
        if sparse.issparse(per_token_labels)
        else np.asarray(per_token_labels)
    )
    n_rows, total_features = acts.shape
    if labels.shape[0] != n_rows:
        raise ValueError(
            f"Row mismatch: activations have {n_rows} rows but "
            f"per_token_labels has {labels.shape[0]}."
        )

    # 3) Allocate outputs [C, F, T].
    n_concepts = labels.shape[1]
    n_thresholds = len(threshold_percents)
    tp, fp, fn, tp_per_domain = (
        np.zeros((n_concepts, total_features, n_thresholds)) for _ in range(4)
    )

    # 4) Chunk over feature columns. max(1, ...) avoids a zero chunk size (and a
    #    division by zero) when there are no features or feat_chunk_max < 1.
    chunk_size = max(1, min(feat_chunk_max, total_features))
    num_chunks = -(-total_features // chunk_size)  # ceil division
    print(f"Calculating over {total_features} activation features in {num_chunks} chunks")

    for start in tqdm(range(0, total_features, chunk_size)):
        cols = slice(start, min(start + chunk_size, total_features))
        chunk_metrics = calc_metrics_dense(
            acts[:, cols], labels, threshold_percents, is_aa_concept_list
        )
        for out, chunk in zip((tp, fp, fn, tp_per_domain), chunk_metrics):
            out[:, cols] = chunk

    return (tp, fp, fn, tp_per_domain)

def _read_shard_tokens(pair_path: Path, shard: str) -> List[str]:
    """Return the token column of a whitespace-separated (cell, token) pair file."""
    if not pair_path.exists():
        raise FileNotFoundError(f"Missing {pair_path}")
    toks = pair_path.read_text().split()
    if len(toks) % 2 != 0:
        raise ValueError(
            f"{shard}: cell_gene_pairs.txt has odd token count ({len(toks)}). "
            "Expected (cell, token) pairs."
        )
    # Entries alternate cell, token, cell, token, ...; keep the tokens.
    return toks[1::2]


def _build_per_token_labels(
    tokens: List[str],
    concept_df: pd.DataFrame,
    n_concepts: int,
    ignore_set: set,
) -> np.ndarray:
    """Build an [N, C] float32 label matrix whose row i is the concept column of token i.

    Special tokens, ignored genes and genes absent from ``concept_df`` get an
    all-zero row, so row i always lines up with activation row i.
    """
    if concept_df.shape[0] != n_concepts:
        raise ValueError(
            f"gene_concepts.csv has {concept_df.shape[0]} concept rows but "
            f"{n_concepts} concept names were loaded."
        )
    special_tokens = {"<cls>", "<pad>", "<unk>", "[CLS]", "[PAD]", "[UNK]"}
    skip = special_tokens | ignore_set

    # [G + 1, C]: one row per gene column of the CSV plus a trailing zero row
    # that every unlabeled token points at.
    lookup = np.vstack([
        concept_df.to_numpy(dtype=np.float32).T,
        np.zeros((1, n_concepts), dtype=np.float32),
    ])
    zero_row = lookup.shape[0] - 1
    gene_row = {gene: j for j, gene in enumerate(concept_df.columns)}
    row_idx = np.fromiter(
        (zero_row if tok in skip else gene_row.get(tok, zero_row) for tok in tokens),
        dtype=np.intp,
        count=len(tokens),
    )
    # A single gather allocates the [N, C] result once (no per-row Python loop).
    return lookup[row_idx]


def _peek_activation_shape(acts_pt: Path) -> Tuple[int, Optional[int]]:
    """Load an activation payload on CPU and return (N, F); F is None for 1-D tensors.

    The payload is released when this helper returns, so it is not kept alive
    while the metrics are computed.
    """
    try:
        payload = torch.load(acts_pt, map_location="cpu", weights_only=True)
    except TypeError:
        # Older PyTorch doesn't support weights_only
        payload = torch.load(acts_pt, map_location="cpu")

    if isinstance(payload, torch.Tensor):
        tensor = payload
    elif isinstance(payload, dict):
        # Prefer the "activations" key, otherwise the first tensor value.
        tensor = payload.get("activations")
        if not isinstance(tensor, torch.Tensor):
            tensor = next((v for v in payload.values() if isinstance(v, torch.Tensor)), None)
        if tensor is None:
            raise ValueError(f"Unrecognized activations format at {acts_pt}")
    else:
        raise ValueError(f"Unrecognized activations format at {acts_pt}")

    n_feats = tensor.shape[1] if tensor.ndim >= 2 else None
    return tensor.shape[0], n_feats


def analyze_concepts_with_activations(
    adata_path: Path,
    gene_ids_path: Path,
    concepts_path: Path,
    gene_ignore: List,
    esm_embds_dir: Path,
    eval_set_dir: Path = Path("../../data/processed/valid"),
    output_dir: Path = Path("concept_results_scgpt_acts"),
    threshold_percents: List[float] = [0, 0.15, 0.5, 0.6, 0.8],
    shard: Optional[str] = None,  # e.g., 'shard_55'
    is_sparse: bool = True,
):
    """
    Associate gene-level concepts with raw scGPT activations (no SAE).

    Saves tp/fp/fn/tp_per_domain for apples-to-apples plots with the SAE pipeline.

    Args:
        adata_path: Not read. The previous AnnData load/filter produced a result
            that was never used; the parameter is kept for signature compatibility.
        gene_ids_path: Directory containing ``<shard>/cell_gene_pairs.txt``.
        concepts_path: Directory containing ``gene_concepts.csv`` (concept rows,
            gene columns) and ``gprofiler_gene_concepts_columns.txt``.
        gene_ignore: Genes whose tokens get an all-zero label row (may be None).
        esm_embds_dir: Directory containing ``<shard>/activations.pt``.
        eval_set_dir: Not read; kept for signature compatibility.
        output_dir: Where ``<shard>_counts.npz`` and ``<shard>/`` are written.
        threshold_percents: Activation thresholds forwarded to the metric code.
        shard: Shard name (required despite the ``None`` default).
        is_sparse: Forwarded to ``process_shard_activations``.

    Raises:
        ValueError: If ``shard`` is None, the pair file is malformed, the concept
            CSV disagrees with the concept names, or the activations format is
            unrecognized.
        FileNotFoundError: If the pair file or the activations file is missing.
        RuntimeError: If the token count differs from the activation row count.
    """
    if shard is None:
        raise ValueError("`shard` is required (e.g. 'shard_55').")
    gene_ids_path = Path(gene_ids_path)
    concepts_path = Path(concepts_path)
    esm_embds_dir = Path(esm_embds_dir)

    # --- Token column of the (cell, token) pairs; its length must match N ---
    tokens = _read_shard_tokens(gene_ids_path / shard / "cell_gene_pairs.txt", shard)

    # --- Concepts and the per-token label matrix ---
    concept_df = pd.read_csv(concepts_path / "gene_concepts.csv", index_col=0)
    concept_names = load_concept_names(concepts_path / "gprofiler_gene_concepts_columns.txt")
    C = len(concept_names)
    ignore_set = set(gene_ignore) if gene_ignore is not None else set()
    per_token_labels = _build_per_token_labels(tokens, concept_df, C, ignore_set)

    # --- Verify N alignment against the activations ---
    acts_pt = esm_embds_dir / f"{shard}" / "activations.pt"
    if not acts_pt.exists():
        raise FileNotFoundError(f"Missing activations at {acts_pt}")
    N_acts, F = _peek_activation_shape(acts_pt)

    if len(tokens) != N_acts:
        raise RuntimeError(
            f"N mismatch for {shard}: tokens={len(tokens)} vs activations={N_acts}. "
            "Ensure the token list matches the activations generation."
        )

    print(f"[check] shard={shard}  N={len(tokens)}  C={C}  F={F}")

    # --- Prefer cuda:2, but only when that device ordinal exists ---
    use_cuda2 = torch.cuda.is_available() and torch.cuda.device_count() > 2
    device = torch.device("cuda:2" if use_cuda2 else "cpu")

    # --- Compute metrics directly on activations ---
    is_aa_concept_list = [False] * C  # placeholder flags; not used by the dense metric path
    (tp, fp, fn, tp_per_domain) = process_shard_activations(
        device=device,
        esm_embeddings_pt_path=acts_pt,
        per_token_labels=per_token_labels,
        threshold_percents=threshold_percents,
        is_aa_concept_list=is_aa_concept_list,
        feat_chunk_max=250,
        is_sparse=is_sparse,
    )

    # --- Save outputs ---
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        output_dir / f"{shard}_counts.npz",
        tp=tp, fp=fp, fn=fn, tp_per_domain=tp_per_domain,
    )

    shard_dir = output_dir / f"{shard}"
    shard_dir.mkdir(parents=True, exist_ok=True)

    # Labels are mostly zero rows/columns, so store them sparse.
    sparse.save_npz(shard_dir / "gene_concepts.npz", sparse.csr_matrix(per_token_labels))

    # Concept names, one per line, for later reference.
    with open(shard_dir / "concept_names.txt", "w") as f:
        f.write("\n".join(concept_names))


def analyze_all_shards_with_activations(
    adata_path: Path,
    embds_dir: Path,
    concepts_path: Path,
    eval_set_dir: Path,
    gene_ids_path: Path,
    gene_ignore: List,
    output_dir: Path = Path("concept_results_scgpt_acts"),
    threshold_percents: List[float] = [0, 0.15, 0.5, 0.6, 0.8],
    is_sparse: bool = True,
):
    """
    Run ``analyze_concepts_with_activations`` for every shard in ``eval_set_dir``.

    Shard names are the non-hidden entry names of ``eval_set_dir``, processed
    in sorted order. All other arguments are forwarded unchanged (``embds_dir``
    is passed as ``esm_embds_dir``).
    """
    eval_set_dir = Path(eval_set_dir)
    # os.listdir order is arbitrary, so sort to make runs reproducible. Hidden
    # entries (".ipynb_checkpoints", ".DS_Store", ...) are not shards.
    shards_to_eval = sorted(
        name for name in os.listdir(eval_set_dir) if not name.startswith(".")
    )
    print(f"Analyzing set {eval_set_dir.stem} with {len(shards_to_eval)} shards")

    for shard in shards_to_eval:
        analyze_concepts_with_activations(
            adata_path=adata_path,
            gene_ids_path=gene_ids_path,
            concepts_path=concepts_path,
            gene_ignore=gene_ignore,
            esm_embds_dir=embds_dir,
            eval_set_dir=eval_set_dir,
            output_dir=output_dir,
            threshold_percents=threshold_percents,
            shard=shard,
            is_sparse=is_sparse,
        )

def load_concept_names(concept_name_path: Path) -> List[str]:
    """Load concept names from a UTF-8 text file, one name per line.

    Leading/trailing whitespace of the whole file is stripped before splitting.
    An empty or whitespace-only file yields ``[]``, not ``[""]``, so it cannot
    masquerade as a single unnamed concept.
    """
    with open(concept_name_path, "r", encoding="utf-8") as f:
        text = f.read().strip()
    return text.split("\n") if text else []

In [7]:
def count_unique_nonzero_dense(matrix: torch.Tensor) -> List[int]:
    """
    For each column of a dense tensor, count the distinct values that are non-zero.

    Args:
        matrix: Dense PyTorch tensor, indexed column-wise as ``matrix[:, j]``.

    Returns:
        A list with one integer per column, in column order.
    """
    n_cols = matrix.shape[1]
    return [
        (torch.unique(matrix[:, j]) != 0).sum().item()
        for j in range(n_cols)
    ]


def calc_metrics_dense(
    sae_feats: torch.Tensor,
    per_token_labels_sparse: Union[np.ndarray, sparse.spmatrix],
    threshold_percents: List[float],
    is_aa_level_concept: List[bool],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Count per-(concept, feature, threshold) TP/FP/FN for dense features.

    A feature "fires" on a token when its value is strictly greater than the
    threshold (thresholds are compared directly against the raw values). With
    L the [N, C] label matrix and B_t the [N, F] binarised features at t:

        tp[c, f, t] = sum_n L[n, c] * B_t[n, f]
        fp[c, f, t] = sum_n (1 - L[n, c]) * B_t[n, f]
                    = (number of tokens where f fires at t) - tp[c, f, t]
        fn[c, f, t] = max(sum_n L[n, c] - tp[c, f, t], 0)

    Args:
        sae_feats: [N, F] feature values (a torch tensor, or anything
            ``torch.as_tensor`` accepts). Computation runs on its device.
        per_token_labels_sparse: [N, C] labels, dense array or scipy sparse
            matrix (densified here).
        threshold_percents: Thresholds, length T.
        is_aa_level_concept: Per-concept flags. Currently unused: domain-level
            TP is not implemented, so ``tp_per_domain`` is all zeros.

    Returns:
        ``(tp, fp, fn, tp_per_domain)`` as float32 numpy arrays of shape [C, F, T].
    """
    if not torch.is_tensor(sae_feats):
        sae_feats = torch.as_tensor(sae_feats)
    device = sae_feats.device

    # Sparse input must be densified; torch cannot consume a scipy matrix.
    if sparse.issparse(per_token_labels_sparse):
        labels_np = per_token_labels_sparse.toarray()
    else:
        labels_np = per_token_labels_sparse
    # np.asarray skips the host copy .astype() always made for float32 input.
    labels = torch.as_tensor(np.asarray(labels_np, dtype=np.float32), device=device)  # [N, C]
    T = len(threshold_percents)

    # Thresholds as [T, 1, 1] so they broadcast against [1, N, F].
    thresholds = torch.tensor(threshold_percents, dtype=torch.float32, device=device).view(T, 1, 1)
    bin_feats = (sae_feats.unsqueeze(0) > thresholds).float()  # [T, N, F]

    tp_tcf = torch.matmul(labels.T.unsqueeze(0), bin_feats)  # [T, C, F]
    # FP via the identity sum((1 - L) * B) = sum(B) - sum(L * B): avoids building
    # a second [N, C] tensor and a second large matmul.
    fire_counts = bin_feats.sum(dim=1, keepdim=True)  # [T, 1, F]
    fp_tcf = fire_counts - tp_tcf  # [T, C, F]

    tp = tp_tcf.permute(1, 2, 0).contiguous()  # [C, F, T]
    fp = fp_tcf.permute(1, 2, 0).contiguous()  # [C, F, T]

    positive_labels = labels.sum(dim=0).view(-1, 1, 1)  # [C, 1, 1]
    fn = torch.clamp(positive_labels - tp, min=0)  # [C, F, T]; clamp guards float noise

    # Placeholder: domain-level TP is not implemented.
    tp_per_domain = torch.zeros_like(tp)

    return tp.cpu().numpy(), fp.cpu().numpy(), fn.cpu().numpy(), tp_per_domain.cpu().numpy()

In [13]:
# Inputs/outputs for the raw-activation concept analysis of the CosMx lung-cancer run.
project_root = Path('/maiziezhou_lab2/yunfei/Projects/interpTFM')
run_root = project_root / 'activations_cosmx_lung_cancer'

embds_dir = run_root / 'activations' / 'layer_4'
eval_set_dir = run_root / 'activations' / 'layer_4_test'
output_dir = run_root / 'output_acts'
gene_ids_path = run_root / 'gene_ids'
concepts_path = project_root / 'gprofiler_annotation'
adata_path = Path('/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/ge_shards/cosmx_human_lung_sec8.h5ad')

# Tokens for these genes receive all-zero concept label rows.
filtered_genes = ['RGS5', 'CCL3L3']

analyze_all_shards_with_activations(
    adata_path=adata_path,
    embds_dir=embds_dir,
    concepts_path=concepts_path,
    eval_set_dir=eval_set_dir,
    output_dir=output_dir,
    gene_ids_path=gene_ids_path,
    gene_ignore=filtered_genes,
)

Analyzing set layer_4_test with 5 shards
[check] shard=shard_55  N=511688  C=2885  F=512
Calculating over 512 activation features in 3 chunks


100%|██████████| 3/3 [00:04<00:00,  1.54s/it]


[check] shard=shard_57  N=530938  C=2885  F=512
Calculating over 512 activation features in 3 chunks


100%|██████████| 3/3 [00:05<00:00,  1.68s/it]


[check] shard=shard_56  N=485231  C=2885  F=512
Calculating over 512 activation features in 3 chunks


100%|██████████| 3/3 [00:04<00:00,  1.45s/it]


[check] shard=shard_58  N=521344  C=2885  F=512
Calculating over 512 activation features in 3 chunks


100%|██████████| 3/3 [00:04<00:00,  1.57s/it]


[check] shard=shard_59  N=521464  C=2885  F=512
Calculating over 512 activation features in 3 chunks


100%|██████████| 3/3 [00:04<00:00,  1.57s/it]
