In [20]:
import json
import math
import os
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Iterator, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

PROJECT_DIR = "/gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER"
# The project's importable packages live under <PROJECT_DIR>/src; prepend it
# (only once) so `multiomic_transformer` resolves to this checkout.
SRC_DIR = str(Path(PROJECT_DIR).joinpath("src"))
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from multiomic_transformer.models.model import MultiomicTransformer
from multiomic_transformer.datasets.dataset_refactor import MultiChromosomeDataset, SimpleScaler, fit_simple_scalers
import multiomic_transformer.utils.experiment_loader as experiment_loader

def tanh_scaled_figsize(
    rows: int,
    cols: int,
    short_in: float = 3.0,
    max_ratio: float = 5.0,
    alpha: float = 0.85,
    min_w: float = 2.5,
    max_w: float = 10.0,
    min_h: float = 2.0,
    max_h: float = 10.0,
):
    """
    Pick a (width, height) in inches whose aspect follows the matrix shape.

    The short side is ``short_in``; the long side is ``short_in`` times a
    stretch factor that grows with ``tanh(alpha * log(long/short))`` and
    saturates at ``max_ratio``. Both sides are then clipped to
    ``[min_w, max_w]`` / ``[min_h, max_h]``.

    Returns ``(short_in, short_in)`` unchanged (no clipping) when either
    dimension is non-positive.
    """
    if min(rows, cols) <= 0:
        return (short_in, short_in)

    aspect = max(rows, cols) / max(1, min(rows, cols))  # always >= 1
    stretch = 1.0 + (max_ratio - 1.0) * math.tanh(alpha * math.log(aspect))
    long_side = short_in * stretch

    # Tall matrices get a tall figure, wide matrices a wide one.
    width, height = (short_in, long_side) if rows >= cols else (long_side, short_in)

    return float(np.clip(width, min_w, max_w)), float(np.clip(height, min_h, max_h))


def downsample_2d_mean(data: np.ndarray, max_rows=1500, max_cols=1500):
    """
    Block-average a 2-D array so it has at most ``max_rows`` x ``max_cols`` cells.

    Each output cell is the mean of a ``row_bin`` x ``col_bin`` tile. Trailing
    rows/columns that do not fill a whole tile are dropped.

    Returns:
        (pooled, row_bin, col_bin)
    """
    def _bin_size(length, limit):
        # Smallest integer factor that brings `length` down to <= `limit`.
        return max(1, int(np.ceil(length / limit)))

    n_rows, n_cols = data.shape
    rb = _bin_size(n_rows, max_rows)
    cb = _bin_size(n_cols, max_cols)

    out_rows, out_cols = n_rows // rb, n_cols // cb
    trimmed = data[: out_rows * rb, : out_cols * cb]
    pooled = trimmed.reshape(out_rows, rb, out_cols, cb).mean(axis=(1, 3))
    return pooled, rb, cb

def save_input_heatmaps_from_batch(exp, out_dir: Union[str, Path], batch_idx: int = 0):
    """
    Save heatmaps of the raw model inputs for one sample of the first test batch.

    Draws one batch from ``exp.test_loader`` and writes one SVG per input into
    ``out_dir``: TF expression, true TG expression, the rows of
    ``exp.model.tg_identity_emb`` selected by the batch's ``tg_ids``, bias,
    motif mask and ATAC windows.

    Args:
        exp: experiment object exposing ``test_loader`` (batches unpack as
            ``(atac_wins, tf_tensor, tg_expr_true, bias, tf_ids, tg_ids,
            motif_mask)``) and ``model.tg_identity_emb``.
        out_dir: output directory (created if missing). All files go here.
        batch_idx: which sample within the batch to plot.

    ``bias`` / ``motif_mask`` are skipped when the loader yields ``None`` for
    them (the forward-pass cell also allows ``None`` for both).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    atac_wins, tf_tensor, tg_expr_true, bias, tf_ids, tg_ids, motif_mask = next(iter(exp.test_loader))

    def _drop_trailing_singleton(t):
        # [n, 1] -> [n]; anything else is returned unchanged.
        return t.squeeze(-1) if t.ndim == 2 and t.shape[-1] == 1 else t

    # ---- TF expression vector (usually [B, n_tf] or [B, n_tf, 1]) ----
    save_heatmap_svg_rasterized_downsampled(
        _drop_trailing_singleton(tf_tensor[batch_idx]),
        out_dir / "input_tf_expr.svg",
        title="Input TF Expr",
    )

    # ---- TG true expression (often [B, n_tg] or [B, n_tg, 1]) ----
    save_heatmap_svg_rasterized_downsampled(
        _drop_trailing_singleton(tg_expr_true[batch_idx]),
        out_dir / "input_tg_expr_true.svg",
        title="Input TG Expr True",
    )

    # ---- TG identity embedding rows for the TGs in this batch ----
    # Written into out_dir like every other input (previously a hard-coded
    # ./dev/model_heatmaps_svg path that ignored out_dir).
    emb_weight = exp.model.tg_identity_emb.weight.detach()
    tg_id_slice = emb_weight[tg_ids.to(emb_weight.device)]
    save_heatmap_svg_rasterized_downsampled(
        tg_id_slice,
        out_dir / "tg_identity_emb_slice_batch.svg",
        title="TGs in Batch",
    )

    # ---- Bias (batched; per-sample shape varies) ----
    if bias is not None:
        save_heatmap_svg_rasterized_downsampled(
            bias[batch_idx],
            out_dir / "input_bias.svg",
            title="Input Bias",
        )

    # ---- Motif mask ----
    if motif_mask is not None:
        # The forward-pass cell treats a 2-D mask as unbatched ([G, T]) and
        # only indexes the batch dim of a 3-D one; mirror that here.
        mm0 = motif_mask[batch_idx] if motif_mask.ndim == 3 else motif_mask
        save_heatmap_svg_rasterized_downsampled(
            mm0,
            out_dir / "input_motif_mask.svg",
            title="Input Motif Mask",
        )

    # ---- ATAC windows: keep the first dim, flatten the rest to 2-D ----
    a0 = atac_wins[batch_idx]
    if a0.ndim > 2:
        a0 = a0.reshape(a0.shape[0], -1)
    save_heatmap_svg_rasterized_downsampled(
        a0,
        out_dir / "input_atac_wins.svg",
        title="Input ATAC Wins",
        max_rows=1500,
        max_cols=1500,
    )

    print(f"Saved input heatmaps to: {out_dir}")

def save_heatmap_svg_rasterized_downsampled(
    weight_tensor: Union[torch.Tensor, np.ndarray],
    out_path: Union[str, Path],
    title: str = "",
    cmap: str = "viridis",
    short_in: float = 3.0,
    max_ratio: float = 5.0,
    alpha: float = 0.85,
    dpi: int = 200,
    rasterize: bool = True,
    vmin=None,
    vmax=None,
    max_rows: int = 1500,
    max_cols: int = 1500,
    is_data: bool = False,
):
    """
    Render a tensor/array as a label-free heatmap and save it to ``out_path``.

    The input is coerced to 2-D (0-D -> ``(1, 1)``, 1-D -> ``(1, N)``,
    N-D -> ``(shape[0], -1)``), block-mean downsampled via
    ``downsample_2d_mean`` when it exceeds ``max_rows`` x ``max_cols``, and
    drawn with seaborn on a figure sized by ``tanh_scaled_figsize``. The x/y
    axis labels show the (post-downsampling) column/row counts.

    Args:
        weight_tensor: torch tensor (detached, moved to CPU) or array-like.
        out_path: output file; its suffix selects the matplotlib format
            (matplotlib's default format is used when there is no suffix).
            Parent directories are created.
        title: plot title; underscores are rendered as spaces.
        cmap: colormap, ignored when ``is_data`` is True.
        short_in, max_ratio, alpha: forwarded to ``tanh_scaled_figsize``.
        dpi: resolution used for rasterised content.
        rasterize: rasterise the heatmap mesh (keeps vector files small).
        vmin, vmax: optional colour limits forwarded to ``sns.heatmap``.
        max_rows, max_cols: downsampling thresholds.
        is_data: use seaborn's viridis colormap object instead of ``cmap``.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    if torch.is_tensor(weight_tensor):
        # Cast to float32 first: .numpy() rejects bfloat16, and bool/int
        # tensors (masks, ids) plot more reliably as floats.
        data = weight_tensor.detach().cpu().float().numpy()
    else:
        data = np.asarray(weight_tensor, dtype=np.float32)

    if data.ndim == 0:
        data = data.reshape(1, 1)
    elif data.ndim == 1:
        data = data[None, :]  # shape (1, N)
    elif data.ndim > 2:
        data = data.reshape(data.shape[0], -1)

    if data.shape[0] > max_rows or data.shape[1] > max_cols:
        data, _, _ = downsample_2d_mean(data, max_rows=max_rows, max_cols=max_cols)
    rows, cols = data.shape

    fig_w, fig_h = tanh_scaled_figsize(
        rows, cols,
        short_in=short_in,
        max_ratio=max_ratio,
        alpha=alpha,
        min_w=3.0, max_w=9.0,
        min_h=2.0, max_h=7.0,
    )

    if is_data:
        cmap = sns.color_palette("viridis", as_cmap=True)

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        sns.heatmap(
            data,
            cmap=cmap,
            xticklabels=False,
            yticklabels=False,
            cbar=False,
            ax=ax,
            vmin=vmin,
            vmax=vmax,
        )

        if rasterize and ax.collections:
            ax.collections[0].set_rasterized(True)

        ax.set_title(title.replace("_", " "), fontsize=24)
        ax.set_xlabel(f"{cols}", fontsize=32)
        ax.set_ylabel(f"{rows}", fontsize=32)

        fig.tight_layout(pad=0.15)
        # An empty suffix would yield format="", which matplotlib rejects.
        fmt = out_path.suffix.lstrip(".") or None
        fig.savefig(out_path, format=fmt, dpi=dpi, bbox_inches="tight", facecolor=fig.get_facecolor())
    finally:
        # Always release the figure so failures inside a plotting loop don't
        # accumulate open figures.
        plt.close(fig)


def iter_weight_matrices(model: torch.nn.Module) -> Iterator[Tuple[str, torch.Tensor]]:
    """
    Lazily yield ``(name, detached_tensor)`` for every 2-D parameter of *model*.

    Linear weights and embedding tables both qualify. 1-D parameters such as
    biases and LayerNorm scales are skipped.
    """
    two_dim_params = (
        (pname, tensor.detach())
        for pname, tensor in model.named_parameters()
        if isinstance(tensor, torch.Tensor) and tensor.ndim == 2
    )
    yield from two_dim_params


def split_in_proj_weight_qkv(
    in_proj_weight: torch.Tensor,
    d_model: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Slice a packed MultiheadAttention ``in_proj_weight`` into ``(Wq, Wk, Wv)``.

    The packed matrix stacks the three projections row-wise, giving shape
    ``(3*d_model, d_model)``. ``d_model`` defaults to the column count.
    Returned tensors are row-slices (views) of the input.

    Raises:
        ValueError: if the tensor is not 2-D or its row count is not
            ``3 * d_model``.
    """
    if in_proj_weight.ndim != 2:
        raise ValueError(f"in_proj_weight must be 2D, got {in_proj_weight.shape}")

    n_rows, n_cols = in_proj_weight.shape
    d_model = n_cols if d_model is None else d_model

    if n_rows != 3 * d_model:
        raise ValueError(
            f"Expected in_proj_weight shape (3*d_model, d_model) = ({3*d_model}, {d_model}), "
            f"got {in_proj_weight.shape}"
        )

    # Block k (0=Q, 1=K, 2=V) occupies rows [k*d_model, (k+1)*d_model).
    return tuple(in_proj_weight[k * d_model:(k + 1) * d_model, :] for k in range(3))


def save_model_weight_heatmaps(
    model: torch.nn.Module,
    out_dir: Union[str, Path],
    include: Optional[List[str]] = None,   # OR logic
    exclude: Optional[List[str]] = None,   # drop if any match
    cmap: str = "viridis",
    max_rows: int = 1500,
    max_cols: int = 1500,
    dpi: int = 200,
    short_in: float = 3.0,
    max_ratio: float = 5.0,
    alpha: float = 0.85,
    split_qkv: bool = True,
):
    """
    Save one SVG heatmap per 2-D parameter of *model* into *out_dir*.

    Args:
        model: module whose 2-D parameters (see ``iter_weight_matrices``) are plotted.
        out_dir: output directory (created if missing). Files are named after
            the parameter name (``"/"`` replaced by ``"_"``).
        include: keep a parameter only if its name contains ANY of these substrings.
        exclude: drop a parameter if its name contains ANY of these substrings.
        cmap, dpi, short_in, max_ratio, alpha: forwarded to the plotting helper.
        max_rows, max_cols: downsampling thresholds forwarded to the plotting
            helper. They were previously accepted but silently ignored.
        split_qkv: write packed ``*self_attn.in_proj_weight`` matrices as three
            files (``_Q``, ``_K``, ``_V``) instead of one combined matrix.

    Prints the number of files written.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    plot_kwargs = dict(
        cmap=cmap,
        short_in=short_in,
        max_ratio=max_ratio,
        alpha=alpha,
        dpi=dpi,
        max_rows=max_rows,
        max_cols=max_cols,
    )

    saved = 0
    for name, w in iter_weight_matrices(model):
        if include is not None and not any(s in name for s in include):
            continue
        if exclude is not None and any(s in name for s in exclude):
            continue

        safe_name = name.replace("/", "_")

        # Packed self-attention QKV, e.g. encoder.layers.0.self_attn.in_proj_weight
        if split_qkv and name.endswith("self_attn.in_proj_weight"):
            qkv = split_in_proj_weight_qkv(w, d_model=w.shape[1])
            for tag, mat in zip(("Q", "K", "V"), qkv):
                save_heatmap_svg_rasterized_downsampled(
                    mat,
                    out_path=out_dir / f"{safe_name}_{tag}.svg",
                    title=f"{name} [{tag}]",
                    **plot_kwargs,
                )
                saved += 1
        else:
            save_heatmap_svg_rasterized_downsampled(
                w,
                out_path=out_dir / f"{safe_name}.svg",
                title=name,
                **plot_kwargs,
            )
            saved += 1

    print(f"Saved {saved} heatmaps to: {out_dir}")

def run_encoder_with_intermediates(encoder: torch.nn.TransformerEncoder,
                                  x: torch.Tensor,
                                  need_weights: bool = True):
    """
    Re-run a TransformerEncoder layer by layer, recording intermediate tensors.

    Follows ``torch.nn.TransformerEncoderLayer``'s forward for both post-norm
    (``norm_first=False``) and pre-norm (``norm_first=True``) layers, and
    applies the encoder's optional final ``norm``. The returned output is
    therefore intended to match ``encoder(x)`` when no masks are used.
    Attention / key-padding masks are NOT supported.

    Args:
        encoder: e.g. ``exp.model.encoder``. Its layers must expose
            ``self_attn``, ``linear1``/``linear2``, ``norm1``/``norm2`` and
            ``dropout``/``dropout1``/``dropout2`` like ``TransformerEncoderLayer``.
        x: ``[B, W, d_model]``. It is transposed around the attention call when a
            layer's ``self_attn`` is not ``batch_first``.
        need_weights: also record per-head attention weights.

    Returns:
        x_out: final ``[B, W, d_model]``.
        enc_dict: OrderedDict of intermediates keyed ``encoder.layers.{i}.<stage>``,
            plus ``encoder.norm`` when the encoder has a final norm.
    """
    enc_dict = OrderedDict()
    h = x

    for li, layer in enumerate(encoder.layers):
        prefix = f"encoder.layers.{li}"
        mha = layer.self_attn
        norm_first = bool(getattr(layer, "norm_first", False))
        is_batch_first = getattr(mha, "batch_first", False)

        # ---- Input to the self-attention sub-block ----
        # Pre-norm layers normalise before attention; post-norm attend on h.
        if norm_first:
            sa_in = layer.norm1(h)
            enc_dict[f"{prefix}.norm1"] = sa_in
        else:
            sa_in = h

        # =========================
        # Compute Q, K, V explicitly
        # =========================
        d_model = mha.embed_dim
        num_heads = mha.num_heads
        head_dim = d_model // num_heads

        if mha.in_proj_weight is not None:
            Wq, Wk, Wv = mha.in_proj_weight.chunk(3, dim=0)   # each [d_model, d_model]
        else:
            # MHA stores separate projection weights when kdim/vdim != embed_dim.
            Wq, Wk, Wv = mha.q_proj_weight, mha.k_proj_weight, mha.v_proj_weight

        if mha.in_proj_bias is not None:
            bq, bk, bv = mha.in_proj_bias.chunk(3, dim=0)
        else:
            bq = bk = bv = None

        # sa_in is [B, W, d_model] regardless of the MHA's batch_first flag.
        Q = torch.nn.functional.linear(sa_in, Wq, bq)   # [B, W, d_model]
        K = torch.nn.functional.linear(sa_in, Wk, bk)
        V = torch.nn.functional.linear(sa_in, Wv, bv)

        enc_dict[f"{prefix}.self_attn.Q"] = Q
        enc_dict[f"{prefix}.self_attn.K"] = K
        enc_dict[f"{prefix}.self_attn.V"] = V

        # ---- Per-head views ----
        B, W, _ = Q.shape
        enc_dict[f"{prefix}.self_attn.Q_heads"] = Q.view(B, W, num_heads, head_dim).transpose(1, 2)  # [B, H, W, d_head]
        enc_dict[f"{prefix}.self_attn.K_heads"] = K.view(B, W, num_heads, head_dim).transpose(1, 2)
        enc_dict[f"{prefix}.self_attn.V_heads"] = V.view(B, W, num_heads, head_dim).transpose(1, 2)

        # ---- Self-attention (MHA expects [W, B, d] unless batch_first) ----
        q = sa_in if is_batch_first else sa_in.transpose(0, 1)
        attn_out, attn_w = mha(
            q, q, q,
            need_weights=need_weights,
            average_attn_weights=not need_weights,  # per-head [B, H, W, W] when requested
        )
        if not is_batch_first:
            attn_out = attn_out.transpose(0, 1)  # back to [B, W, d]

        enc_dict[f"{prefix}.self_attn.out"] = attn_out
        if need_weights and attn_w is not None:
            enc_dict[f"{prefix}.self_attn.weights"] = attn_w

        # ---- Residual 1 (+ Norm 1 for post-norm) ----
        h_attn = h + layer.dropout1(attn_out)
        enc_dict[f"{prefix}.resid1"] = h_attn
        if norm_first:
            ff_in = layer.norm2(h_attn)
            enc_dict[f"{prefix}.norm2"] = ff_in
        else:
            ff_in = layer.norm1(h_attn)
            enc_dict[f"{prefix}.norm1"] = ff_in

        # ---- FFN ----
        ffn1 = layer.linear1(ff_in)
        enc_dict[f"{prefix}.ffn.linear1"] = ffn1

        if hasattr(layer, "activation"):
            ffn_act = layer.activation(ffn1)
        else:
            # Fallback for layers without an `activation` attribute.
            ffn_act = torch.relu(ffn1)
        enc_dict[f"{prefix}.ffn.act"] = ffn_act

        ffn_drop = layer.dropout(ffn_act)
        enc_dict[f"{prefix}.ffn.dropout"] = ffn_drop

        ffn2 = layer.linear2(ffn_drop)
        enc_dict[f"{prefix}.ffn.linear2"] = ffn2

        # ---- Residual 2 (+ Norm 2 for post-norm) ----
        if norm_first:
            h = h_attn + layer.dropout2(ffn2)
            enc_dict[f"{prefix}.resid2"] = h
        else:
            h_ffn = ff_in + layer.dropout2(ffn2)
            enc_dict[f"{prefix}.resid2"] = h_ffn
            h = layer.norm2(h_ffn)
            enc_dict[f"{prefix}.norm2"] = h

    # TransformerEncoder applies its optional final norm after the last layer.
    final_norm = getattr(encoder, "norm", None)
    if final_norm is not None:
        h = final_norm(h)
        enc_dict["encoder.norm"] = h

    return h, enc_dict


### Load the experiment

In [21]:
# Bind the loader to one trained experiment run, then restore its checkpoint.
exp = experiment_loader.ExperimentLoader(
    experiment_dir="/gpfs/Labs/Uzun/DATA/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/experiments/",
    experiment_name="mESC_E7.5_rep1_hvg_filter_disp_0.5",
    model_num=1,
)
exp.load_trained_model("trained_model.pt")

In [22]:
# Last expression in the cell: Jupyter echoes the module repr (architecture summary).
exp.model

MultiomicTransformer(
  (tf_identity_emb): Embedding(256, 128)
  (tg_query_emb): Embedding(25090, 128)
  (tg_identity_emb): Embedding(25090, 128)
  (tf_expr_dense_input_layer): Sequential(
    (0): Linear(in_features=1, out_features=512, bias=True)
    (1): SiLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=128, bias=False)
    (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (atac_acc_dense_input_layer): Sequential(
    (0): Linear(in_features=1, out_features=512, bias=True)
    (1): SiLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=128, bias=False)
    (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (posenc): PositionalEmbedding()
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
 

### Save the Heatmaps of the Model Weights

In [37]:
# One SVG per 2-D parameter; packed self-attention QKV split into _Q/_K/_V files.
save_model_weight_heatmaps(
    model=exp.model,
    out_dir="./dev/model_heatmaps_svg/model_weights",
    split_qkv=True,
)

Saved 41 heatmaps to: dev/model_heatmaps_svg/model_weights


### Save Heatmaps of the Attention-Pool Queries, Inputs, and Intermediate Activations

In [25]:
# Learned query parameters of the two attention-pooling modules.
save_heatmap_svg_rasterized_downsampled(
    weight_tensor=exp.model.tf_to_atac_cross_attn_pool.query,
    out_path="./dev/model_heatmaps_svg/tf_to_atac_cross_attn_pool.query.svg",
    title="TF→ATAC Attention Pool Query",
)
save_heatmap_svg_rasterized_downsampled(
    weight_tensor=exp.model.atac_to_tf_cross_attn_pool.query,
    out_path="./dev/model_heatmaps_svg/atac_to_tf_cross_attn_pool.query.svg",
    title="ATAC→TF Attention Pool Query",
)

In [40]:
from collections import OrderedDict
import math
import torch

# Step-by-step re-implementation of the model forward pass on one test batch,
# keeping every intermediate tensor so it can be plotted afterwards.
batch = next(iter(exp.test_loader))
atac_wins, tf_tensor, tg_expr_true, bias, tf_ids, tg_ids, motif_mask = batch
device = next(exp.model.parameters()).device

# Dropout layers are called explicitly below; eval mode makes them identity so
# the recorded intermediates are deterministic (a no-op if already in eval).
exp.model.eval()

# move tensors you actually compute with
atac_wins  = atac_wins.to(device)
tf_tensor  = tf_tensor.to(device)
tf_ids     = tf_ids.to(device)
tg_ids     = tg_ids.to(device)
bias       = None if bias is None else bias.to(device)
motif_mask = None if motif_mask is None else motif_mask.to(device)

with torch.no_grad():

    # -------------------- ATAC dense input layer --------------------
    atac_lin1        = exp.model.atac_acc_dense_input_layer[0](atac_wins)        # [B, W, d_ff]
    atac_silu        = exp.model.atac_acc_dense_input_layer[1](atac_lin1)        # [B, W, d_ff]
    atac_dropout     = exp.model.atac_acc_dense_input_layer[2](atac_silu)        # [B, W, d_ff]
    atac_lin2        = exp.model.atac_acc_dense_input_layer[3](atac_dropout)     # [B, W, d_model]
    atac_layer_norm  = exp.model.atac_acc_dense_input_layer[4](atac_lin2)        # [B, W, d_model]

    win_emb = atac_layer_norm                                                   # [B, W, d_model]

    # positional encoding + add
    W = win_emb.shape[1]
    pos = torch.arange(W, device=device, dtype=torch.float32)
    pos_emb = exp.model.posenc(pos, bsz=win_emb.shape[0]).transpose(0, 1)        # [B, W, d_model]
    win_emb_pos = win_emb + pos_emb                                             # [B, W, d_model]

    # -------------------- ATAC Transformer encoder --------------------
    win_enc, encoder_dict = run_encoder_with_intermediates(exp.model.encoder, win_emb_pos, need_weights=True)  # [B, W, d_model]

    # -------------------- TF embeddings --------------------
    tf_id_emb = exp.model.tf_identity_emb(tf_ids)                                # [T, d_model]

    # TF expr dense input layer (step-by-step)
    tf_expr_in = tf_tensor.unsqueeze(-1)                                         # [B, T, 1]
    tf_lin1    = exp.model.tf_expr_dense_input_layer[0](tf_expr_in)              # [B, T, d_ff]
    tf_silu    = exp.model.tf_expr_dense_input_layer[1](tf_lin1)                 # [B, T, d_ff]
    tf_dropout = exp.model.tf_expr_dense_input_layer[2](tf_silu)                 # [B, T, d_ff]
    tf_lin2    = exp.model.tf_expr_dense_input_layer[3](tf_dropout)              # [B, T, d_model]
    tf_ln      = exp.model.tf_expr_dense_input_layer[4](tf_lin2)                 # [B, T, d_model]
    tf_expr_emb = tf_ln

    # combined TF embedding used for cross-attn
    tf_emb = tf_expr_emb + tf_id_emb.unsqueeze(0)                                # [B, T, d_model]

    # -------------------- TF<->ATAC cross attention --------------------
    # CrossAttention forward: out = norm(query + 0.1*dropout(attn(query, kv)))
    tf_cross   = exp.model.cross_tf_to_atac(tf_emb,  win_enc)                    # [B, T, d_model]
    atac_cross = exp.model.cross_atac_to_tf(win_enc, tf_emb)                     # [B, W, d_model]

    # -------------------- Attention pooling --------------------
    tf_repr,   tf_pool_weights   = exp.model.tf_to_atac_cross_attn_pool(tf_cross)     # [B, d_model], [B, T, 1]
    atac_repr, atac_pool_weights = exp.model.atac_to_tf_cross_attn_pool(atac_cross)  # [B, d_model], [B, W, 1]

    # -------------------- Pooled cross-attn dense layer --------------------
    pooled_cat = torch.cat([tf_repr, atac_repr], dim=-1)                         # [B, 2*d_model]
    pooled_lin1    = exp.model.pooled_cross_attn_dense_layer[0](pooled_cat)      # [B, d_ff]
    pooled_gelu    = exp.model.pooled_cross_attn_dense_layer[1](pooled_lin1)     # [B, d_ff]
    pooled_dropout = exp.model.pooled_cross_attn_dense_layer[2](pooled_gelu)     # [B, d_ff]
    pooled_lin2    = exp.model.pooled_cross_attn_dense_layer[3](pooled_dropout)  # [B, d_model]
    pooled_ln      = exp.model.pooled_cross_attn_dense_layer[4](pooled_lin2)     # [B, d_model]
    tf_atac_cross_attn_output = pooled_ln                                        # [B, d_model]

    # -------------------- TG query / identity embeddings --------------------
    tg_query_emb = exp.model.tg_query_emb(tg_ids)                                # [G, d_model]
    tg_base = tg_query_emb.unsqueeze(0).expand(win_enc.shape[0], -1, -1)         # [B, G, d_model]

    tg_id_emb = exp.model.tg_identity_emb(tg_ids)                                # [G, d_model]

    # -------------------- distance bias shaping (as in forward) --------------------
    attn_bias = None
    if exp.model.use_bias and (bias is not None):
        attn_bias = bias
        if attn_bias.dim() == 3:
            attn_bias = attn_bias.unsqueeze(1)                                   # [B, 1, G, W]
        if attn_bias.shape[1] == 1:
            attn_bias = attn_bias.expand(win_enc.shape[0], exp.model.num_heads, tg_base.size(1), win_enc.size(1))
        attn_bias = torch.nan_to_num(attn_bias, nan=0.0, posinf=1e4, neginf=-1e4)
        attn_bias = (exp.model.bias_scale * attn_bias).clamp_(-20.0, 20.0)       # [B, H, G, W]

    # -------------------- TG->ATAC cross attention --------------------
    tg_cross = exp.model.cross_tg_to_atac(tg_base, win_enc, attn_bias=attn_bias) # [B, G, d_model]

    # add pooled TF/ATAC summary to each TG (scaled by 1/sqrt(n_tgs))
    n_tgs = tg_cross.size(1)
    scale = 1.0 / math.sqrt(max(1, n_tgs))
    tf_atac_expand = tf_atac_cross_attn_output.unsqueeze(1).expand(-1, n_tgs, -1) * scale  # [B, G, d_model]
    tg_cross_attn_repr = tg_cross + tf_atac_expand                                # [B, G, d_model]

    # -------------------- TG identity dot + gene_pred_dense --------------------
    tg_similarity_to_attn_output = (tg_cross_attn_repr * tg_id_emb.unsqueeze(0)).sum(dim=-1)  # [B, G]

    gene_lin1    = exp.model.gene_pred_dense[0](tg_cross_attn_repr)               # [B, G, d_ff]
    gene_relu    = exp.model.gene_pred_dense[1](gene_lin1)                        # [B, G, d_ff]
    gene_dropout = exp.model.gene_pred_dense[2](gene_relu)                        # [B, G, d_ff]
    gene_lin2    = exp.model.gene_pred_dense[3](gene_dropout)                     # [B, G, 1]
    gene_pred_term = gene_lin2.squeeze(-1)                                        # [B, G]

    tg_pred_pre_shortcut = tg_similarity_to_attn_output + gene_pred_term          # [B, G]

    # -------------------- Optional TF->TG shortcut --------------------
    shortcut_out = None
    shortcut_attn = None
    tg_pred = tg_pred_pre_shortcut

    if getattr(exp.model, "use_shortcut", False) and hasattr(exp.model, "shortcut_layer"):
        # shortcut_layer expects: tg_emb [G,d], tf_id_emb [T,d], tf_expr [B,T], motif_mask [G,T]
        # If your motif_mask is batched, slice batch 0; otherwise pass as-is.
        mm_for_shortcut = None
        if motif_mask is not None:
            mm_for_shortcut = motif_mask[0] if motif_mask.dim() == 3 else motif_mask

        shortcut_out, shortcut_attn = exp.model.shortcut_layer(
            tg_id_emb, tf_id_emb, tf_tensor, motif_mask=mm_for_shortcut
        )  # shortcut_out: [B,G], shortcut_attn: [G,T]
        tg_pred = tg_pred + shortcut_out                                          # [B, G]


# -------------------- model_dict (grouped outputs) --------------------
# Ordered (part_name, [(item_name, tensor_or_None), ...]) pairs. Insertion
# order drives the numeric prefixes of the saved heatmap folders/files.
model_dict = OrderedDict([
    ("inputs", [
        ("atac_windows", atac_wins),
        ("tf_expr", tf_tensor),
        ("tg_expr_true", tg_expr_true),
        ("bias", bias),
        ("tf_ids", tf_ids),
        ("tg_ids", tg_ids),
        ("motif_mask", motif_mask),
    ]),
    ("atac_dense_input_layer", [
        ("lin1", atac_lin1),
        ("silu", atac_silu),
        ("dropout", atac_dropout),
        ("lin2", atac_lin2),
        ("layer_norm", atac_layer_norm),
        ("pos_emb", pos_emb),
        ("win_emb_pos", win_emb_pos),
    ]),
    # Final encoder output first, then every per-layer intermediate.
    ("window_encoder", [("output", win_enc), *encoder_dict.items()]),
    ("tf_id_emb", [
        ("embedding", tf_id_emb),
    ]),
    ("tf_dense_input_layer", [
        ("expr_in", tf_expr_in),
        ("lin1", tf_lin1),
        ("silu", tf_silu),
        ("dropout", tf_dropout),
        ("lin2", tf_lin2),
        ("layer_norm", tf_ln),
    ]),
    ("tf_emb_combined", [
        ("tf_expr_emb", tf_expr_emb),
        ("tf_emb", tf_emb),
    ]),
    ("cross_tf_to_atac", [
        ("tf_cross", tf_cross),
    ]),
    ("cross_atac_to_tf", [
        ("atac_cross", atac_cross),
    ]),
    ("attention_pooling", [
        ("tf_repr", tf_repr),
        ("tf_pool_weights", tf_pool_weights),
        ("atac_repr", atac_repr),
        ("atac_pool_weights", atac_pool_weights),
    ]),
    ("pooled_cross_attn_dense_layer", [
        ("cat_tf_atac", pooled_cat),
        ("lin1", pooled_lin1),
        ("gelu", pooled_gelu),
        ("dropout", pooled_dropout),
        ("lin2", pooled_lin2),
        ("layer_norm", pooled_ln),
        ("tf_atac_cross_attn_output", tf_atac_cross_attn_output),
    ]),
    ("tg_embeddings", [
        ("tg_query_emb", tg_query_emb),
        ("tg_base", tg_base),
        ("tg_id_emb", tg_id_emb),
    ]),
    ("tg_to_atac_bias", [
        ("attn_bias", attn_bias),
    ]),
    ("cross_tg_to_atac", [
        ("tg_cross", tg_cross),
    ]),
    ("tg_cross_attn_fusion", [
        ("tf_atac_expand", tf_atac_expand),
        ("tg_cross_attn_repr", tg_cross_attn_repr),
    ]),
    ("tg_prediction_head", [
        ("tg_similarity_to_attn_output", tg_similarity_to_attn_output),
        ("gene_lin1", gene_lin1),
        ("gene_relu", gene_relu),
        ("gene_dropout", gene_dropout),
        ("gene_lin2", gene_lin2),
        ("gene_pred_term", gene_pred_term),
        ("tg_pred_pre_shortcut", tg_pred_pre_shortcut),
        ("tg_pred_final", tg_pred),
    ]),
    ("tf_tg_shortcut", [
        ("shortcut_out", shortcut_out),
        ("shortcut_attn", shortcut_attn),
    ]),
])


In [41]:
import re
import torch
from collections.abc import Mapping

def _slug(s: str) -> str:
    s = s.strip().lower()
    s = re.sub(r"[^a-z0-9]+", "_", s)
    return s.strip("_")

def _select_plottable(x: torch.Tensor, batch_idx: int = 0, head_idx: int = 0):
    """
    Convert common model tensors to plottable 1D/2D by slicing batch/head.

    Rules:
    - 1D/2D: keep
    - 3D: assume [B, *, *] -> x[batch_idx]  -> [*, *] or [*]
    - 4D: assume [B, H, *, *] or [B, H, *, d] -> x[batch_idx, head_idx] -> [*, *]
    - >4D: slice batch then flatten remaining dims to 2D.
    """
    if not torch.is_tensor(x):
        return x

    if x.ndim <= 2:
        return x

    if x.ndim == 3:
        return x[batch_idx]

    if x.ndim == 4:
        return x[batch_idx, head_idx]

    y = x[batch_idx]
    if y.ndim > 2:
        y = y.reshape(y.shape[0], -1)
    return y

def save_grouped_heatmaps(
    model_dict,
    out_root: Union[str, Path],
    batch_idx: int = 0,
    head_idx: int = 0,
    dpi: int = 200,
    cmap: str = "viridis",
    skip_none: bool = True,
):
    """
    Save one heatmap SVG per tensor of a grouped ``{part_name: items}`` mapping.

    Layout: ``out_root/<NN>_<part_slug>/<MM>_<item_slug>.svg``. The zero-padded
    numeric prefixes keep insertion order when the directory is listed. A
    nested mapping value becomes a sub-directory with the same layout.

    Args:
        model_dict: mapping ``part_name -> items``. ``items`` is either an
            iterable of ``(name, obj)`` pairs or a mapping ``{name: obj}``.
        out_root: root output directory (created if missing).
        batch_idx, head_idx: slices applied by ``_select_plottable`` to 3-D/4-D tensors.
        dpi, cmap: forwarded to the plotting helper.
        skip_none: if True, silently skip ``None`` values and non-tensor
            values; if False, raise ``ValueError`` / ``TypeError`` for them.
    """
    out_root = Path(out_root)
    out_root.mkdir(parents=True, exist_ok=True)

    def _as_pairs(items):
        # Iterating a Mapping yields only its keys, which would break the
        # (name, obj) unpacking below, so use its items instead.
        return list(items.items()) if isinstance(items, Mapping) else list(items)

    def _prefix_digits(n):
        # Zero-pad width for indices 0..n-1 (at least 2 digits).
        return max(2, len(str(n - 1)))

    def _save_items(items, part_name: str, part_dir: Path):
        pairs = _as_pairs(items)
        file_digits = _prefix_digits(len(pairs))
        saved_i = 0

        for item_name, obj in pairs:
            if obj is None:
                if skip_none:
                    continue
                raise ValueError(f"{part_name}/{item_name} is None")

            # --- recurse if nested dict/OrderedDict ---
            if isinstance(obj, Mapping):
                sub_dir = part_dir / f"{saved_i:0{file_digits}d}_{_slug(item_name)}"
                sub_dir.mkdir(parents=True, exist_ok=True)
                _save_items(obj, f"{part_name}/{item_name}", sub_dir)
                saved_i += 1
                continue

            # --- tensor case ---
            if not torch.is_tensor(obj):
                if skip_none:
                    continue
                raise TypeError(f"{part_name}/{item_name} is not a tensor: {type(obj)}")

            x = _select_plottable(obj, batch_idx=batch_idx, head_idx=head_idx)

            fname = f"{saved_i:0{file_digits}d}_{_slug(item_name)}.svg"
            save_heatmap_svg_rasterized_downsampled(
                x,
                part_dir / fname,
                title=f"{part_name}\n{item_name}",
                cmap=cmap,
                dpi=dpi,
            )
            saved_i += 1

    part_digits = _prefix_digits(len(model_dict))
    for part_i, (part_name, items) in enumerate(model_dict.items()):
        part_dir = out_root / f"{part_i:0{part_digits}d}_{_slug(part_name)}"
        part_dir.mkdir(parents=True, exist_ok=True)
        _save_items(items, part_name, part_dir)

    print(f"Saved grouped heatmaps to: {out_root}")



# Write every recorded tensor, slicing batch 0 / head 0 from 3-D and 4-D tensors.
save_grouped_heatmaps(
    model_dict=model_dict,
    out_root="./dev/model_heatmaps_svg",
    dpi=200,
    batch_idx=0,
    head_idx=0,
)


Saved grouped heatmaps to: dev/model_heatmaps_svg
