In [2]:
import torch
from pathlib import Path

MPID_CKPT = Path("/home/hep/an1522/dark_tridents_wspace/outputs/weights/DM-CNN_model_20260116-10_22_PM_epoch_4_batch_id_1961_labels_2_title_0.001_AG_GN_LM_TRAINING_step_9821.pwf")
# MPID_CKPT2 = Path("/home/hep/an1522/dark_tridents_wspace/outputs/weights/DM-CNN_model_20251030-09_18_PM_epoch_4_batch_id_1961_labels_2_title_0.001_AG_GN_LM_TRAINING_step_9821.pwf")
RN34_CKPT = Path("/home/hep/an1522/dark_tridents_wspace/outputs/weights/resnet34_gn/resnet34_gn_model_20260123-12_20_AM_epoch_4_batch_id_1961_labels_2_step_9821.pwf")

def extract_state(ckpt_obj):
    # matches your CLI helper logic (common key names) [file:3]
    if isinstance(ckpt_obj, dict):
        for k in ["state_dict", "model_state_dict", "model", "net", "weights"]:
            if k in ckpt_obj and isinstance(ckpt_obj[k], dict):
                return ckpt_obj[k]
        return ckpt_obj
    raise ValueError("Unsupported checkpoint format")

def load_state(path: Path):
    ckpt = torch.load(str(path), map_location="cpu")
    state = extract_state(ckpt)
    print(f"{path.name}: {len(state)} tensors")
    return state

mpid_state = load_state(MPID_CKPT)
# mpid_state2 = load_state(MPID_CKPT2)
rn34_state = load_state(RN34_CKPT)


DM-CNN_model_20260116-10_22_PM_epoch_4_batch_id_1961_labels_2_title_0.001_AG_GN_LM_TRAINING_step_9821.pwf: 46 tensors
resnet34_gn_model_20260123-12_20_AM_epoch_4_batch_id_1961_labels_2_step_9821.pwf: 110 tensors


In [3]:
from collections import Counter

def summarize_prefixes(state, n=20):
    keys = list(state.keys())
    top = Counter(k.split(".")[0] for k in keys)
    print("Top-level prefixes:", top.most_common(n))

print("MPID:")
summarize_prefixes(mpid_state)

# print("MPID2:")
# summarize_prefixes(mpid_state2)

print("\nResNet34_gn:")
summarize_prefixes(rn34_state)


MPID:
Top-level prefixes: [('features', 40), ('classifier', 6)]

ResNet34_gn:
Top-level prefixes: [('net', 110)]


In [4]:
SUFFIXES = {
    "weight", "bias",
    "running_mean", "running_var", "num_batches_tracked"
}

def module_name_from_key(k: str):
    parts = k.split(".")
    if parts[-1] in SUFFIXES:
        return ".".join(parts[:-1])
    return k

def candidate_modules(state, must_contain=None):
    mods = sorted({module_name_from_key(k) for k in state.keys()})
    if must_contain:
        mods = [m for m in mods if must_contain in m]
    return mods

mpid_modules = candidate_modules(mpid_state)
# mpid_modules2 = candidate_modules(mpid_state2)
rn34_modules = candidate_modules(rn34_state)

# print("Example MPID module names:", mpid_modules[:40])
# print("\nExample ResNet module names:", [m for m in rn34_modules if m.startswith("net.layer")][:40])

print("Example MPID module names:", mpid_modules)
# print("Example MPID2 module names:", mpid_modules2)
print("\nExample ResNet module names:", [m for m in rn34_modules if m.startswith("net.layer")])

Example MPID module names: ['classifier.1', 'classifier.3', 'classifier.4', 'features.0', 'features.10', 'features.12', 'features.14', 'features.16', 'features.17', 'features.19', 'features.2', 'features.21', 'features.23', 'features.24', 'features.26', 'features.28', 'features.3', 'features.30', 'features.31', 'features.32', 'features.5', 'features.7', 'features.9']

Example ResNet module names: ['net.layer1.0.bn1', 'net.layer1.0.bn2', 'net.layer1.0.conv1', 'net.layer1.0.conv2', 'net.layer1.1.bn1', 'net.layer1.1.bn2', 'net.layer1.1.conv1', 'net.layer1.1.conv2', 'net.layer1.2.bn1', 'net.layer1.2.bn2', 'net.layer1.2.conv1', 'net.layer1.2.conv2', 'net.layer2.0.bn1', 'net.layer2.0.bn2', 'net.layer2.0.conv1', 'net.layer2.0.conv2', 'net.layer2.0.downsample.0', 'net.layer2.0.downsample.1', 'net.layer2.1.bn1', 'net.layer2.1.bn2', 'net.layer2.1.conv1', 'net.layer2.1.conv2', 'net.layer2.2.bn1', 'net.layer2.2.bn2', 'net.layer2.2.conv1', 'net.layer2.2.conv2', 'net.layer2.3.bn1', 'net.layer2.3.bn2

In [5]:
def filter_cam_layers(mods):
    # Typical hook targets are convs or whole residual blocks (layerX / layerX.N) [file:3]
    keep = []
    for m in mods:
        if any(s in m for s in ["conv", "layer", "features"]):
            keep.append(m)
    return keep

mpid_cam = filter_cam_layers(mpid_modules)
mpid_cam2 = filter_cam_layers(mpid_modules)
rn34_cam = filter_cam_layers(rn34_modules)

print("MPID CAM-ish candidates (first 80):")
for m in mpid_cam[:80]:
    print(" ", m)

# print("MPID2 CAM-ish candidates (first 80):")
# for m in mpid_cam2[:80]:
#     print(" ", m)

print("\nResNet34 CAM-ish candidates (layer* ones):")
for m in rn34_cam:
    if m.startswith("net.layer"):
        print(" ", m)


MPID CAM-ish candidates (first 80):
  features.0
  features.10
  features.12
  features.14
  features.16
  features.17
  features.19
  features.2
  features.21
  features.23
  features.24
  features.26
  features.28
  features.3
  features.30
  features.31
  features.32
  features.5
  features.7
  features.9

ResNet34 CAM-ish candidates (layer* ones):
  net.layer1.0.bn1
  net.layer1.0.bn2
  net.layer1.0.conv1
  net.layer1.0.conv2
  net.layer1.1.bn1
  net.layer1.1.bn2
  net.layer1.1.conv1
  net.layer1.1.conv2
  net.layer1.2.bn1
  net.layer1.2.bn2
  net.layer1.2.conv1
  net.layer1.2.conv2
  net.layer2.0.bn1
  net.layer2.0.bn2
  net.layer2.0.conv1
  net.layer2.0.conv2
  net.layer2.0.downsample.0
  net.layer2.0.downsample.1
  net.layer2.1.bn1
  net.layer2.1.bn2
  net.layer2.1.conv1
  net.layer2.1.conv2
  net.layer2.2.bn1
  net.layer2.2.bn2
  net.layer2.2.conv1
  net.layer2.2.conv2
  net.layer2.3.bn1
  net.layer2.3.bn2
  net.layer2.3.conv1
  net.layer2.3.conv2
  net.layer3.0.bn1
  net.layer

In [6]:
import torch
import torch.nn as nn
from contextlib import contextmanager

def get_module_by_name(model: nn.Module, name: str) -> nn.Module:
    name = str(name).strip()
    for n, m in model.named_modules():
        if n == name:
            return m
    raise KeyError(f"Layer '{name}' not found in model.named_modules()")

@contextmanager
def capture_activations_and_grads(layer: nn.Module):
    cache = {"acts": None, "grads": None}

    def fwd_hook(m, inp, out):
        cache["acts"] = out

    def bwd_hook(m, grad_input, grad_output):
        # grad_output is a tuple; grad wrt module output is grad_output[0]
        cache["grads"] = grad_output[0]

    h1 = layer.register_forward_hook(fwd_hook)
    h2 = layer.register_full_backward_hook(bwd_hook)
    try:
        yield cache
    finally:
        h1.remove()
        h2.remove()


In [7]:
def probe_layer(model: nn.Module, x: torch.Tensor, layer_name: str, class_idx: int = 0):
    model.eval()
    layer = get_module_by_name(model, layer_name)

    with capture_activations_and_grads(layer) as cache:
        # Forward
        logits = model(x)

        # Backward for Grad-CAM-style methods
        score = logits[:, class_idx].sum()
        model.zero_grad(set_to_none=True)
        score.backward()

    acts = cache["acts"]
    grads = cache["grads"]

    print(f"\n=== Layer: {layer_name} ({layer.__class__.__name__}) ===")
    print("logits:", tuple(logits.shape), " sample:", logits.detach().cpu()[0].tolist())

    if acts is None:
        print("No activations captured (unexpected).")
        return

    if not isinstance(acts, torch.Tensor):
        print("Activation is not a tensor:", type(acts))
        return

    print("acts shape:", tuple(acts.shape), "dtype:", acts.dtype, "device:", acts.device)

    # Expect [B,C,H,W] for CAM layers
    if acts.ndim == 4:
        B,C,H,W = acts.shape
        a = acts.detach()
        print(f"spatial grid: {H} x {W}  (CAM pixels)")
        print("acts stats:",
              "min", float(a.min()), "max", float(a.max()),
              "mean", float(a.mean()),
              "frac>0", float((a>0).float().mean()))
    else:
        print("acts ndim is", acts.ndim, "(CAM methods typically want 4D [B,C,H,W])")

    if grads is None:
        print("No gradients captured -> Grad-CAM/++ will fail for this layer.")
    else:
        g = grads.detach()
        print("grads shape:", tuple(g.shape))
        print("grads stats:",
              "min", float(g.min()), "max", float(g.max()),
              "mean", float(g.mean()),
              "L2", float(torch.sqrt(torch.mean(g*g))))


In [8]:
import torch
from pathlib import Path
from mpid_net import mpid_net_binary

def extract_state(ckpt_obj):
    # same pattern your scripts use (state_dict/model_state_dict/model/net/weights) [file:1]
    if isinstance(ckpt_obj, dict):
        for k in ["state_dict", "model_state_dict", "model", "net", "weights"]:
            if k in ckpt_obj and isinstance(ckpt_obj[k], dict):
                return ckpt_obj[k]
        return ckpt_obj
    raise ValueError("Unsupported checkpoint format")

ckpt = torch.load(str(MPID_CKPT), map_location="cpu")
state = extract_state(ckpt)

mpid_model = mpid_net_binary.MPID()
mpid_model.load_state_dict(state, strict=True)
mpid_model.eval()

print(mpid_model)

# ckpt2 = torch.load(str(MPID_CKPT2), map_location="cpu")
# state2 = extract_state(ckpt2)

# mpid_model2 = mpid_net_binary.MPID()
# mpid_model2.load_state_dict(state, strict=True)
# mpid_model2.eval()

# print(mpid_model2)

MPID(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): GroupNorm(64, 64, eps=1e-05, affine=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): GroupNorm(64, 64, eps=1e-05, affine=True)
    (6): AvgPool2d(kernel_size=2, stride=2, padding=1)
    (7): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): GroupNorm(96, 96, eps=1e-05, affine=True)
    (10): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1))
    (11): ReLU()
    (12): GroupNorm(96, 96, eps=1e-05, affine=True)
    (13): AvgPool2d(kernel_size=2, stride=2, padding=1)
    (14): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
    (16): GroupNorm(128, 128, eps=1e-05, affine=True)
    (17): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (18): ReLU()
    (19): GroupNorm(128, 128, eps=1e-05, affine=True)
    (20): AvgPool2d(kernel_size=2, str

In [9]:
import torch
import torch.nn as nn
from pathlib import Path

def extract_state(ckpt_obj):
    if isinstance(ckpt_obj, dict):
        for k in ["state_dict", "model_state_dict", "model", "net", "weights"]:
            if k in ckpt_obj and isinstance(ckpt_obj[k], dict):
                return ckpt_obj[k]
        return ckpt_obj
    raise ValueError("Unsupported checkpoint format")

def infer_norm_from_state(state):
    # BatchNorm has running_mean/running_var buffers; GroupNorm doesn't. [file:1]
    for k in state.keys():
        if k.endswith("running_mean") or k.endswith("running_var"):
            return "bn"
    return "gn"

def make_norm_layer(kind="gn", gngroups=32):
    kind = kind.lower()
    if kind == "bn":
        return lambda c: nn.BatchNorm2d(c)
    if kind == "gn":
        def gn(c):
            g = min(int(gngroups), int(c))
            while g > 1 and (c % g) != 0:
                g -= 1
            return nn.GroupNorm(g, c)
        return gn
    raise ValueError("kind must be 'bn' or 'gn'")

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1   = norm_layer(planes)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2   = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.relu(out + identity)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, norm_layer, inchannels=1, numclasses=2, dropout=0.0):
        super().__init__()
        self.inplanes = 64

        self.conv1 = nn.Conv2d(inchannels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1   = norm_layer(64)
        self.relu  = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block,  64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(nn.Dropout(p=float(dropout)), nn.Linear(512 * block.expansion, numclasses))

    def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None, norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

class ResNetBinaryWrapperLocal(nn.Module):
    # Matches your checkpoint prefix style: net.layer4..., net.fc.1.weight, etc. [file:3]
    def __init__(self, norm="gn", gngroups=32, dropout=0.0, numclasses=2, inchannels=1):
        super().__init__()
        norm_layer = make_norm_layer(norm, gngroups=gngroups)
        self.net = ResNet(BasicBlock, [3, 4, 6, 3], norm_layer=norm_layer,
                          inchannels=inchannels, numclasses=numclasses, dropout=dropout)
    def forward(self, x):
        return self.net(x)

def load_resnet34_from_ckpt(path, device="cpu"):
    ckpt = torch.load(str(path), map_location="cpu")
    state = extract_state(ckpt)
    norm = infer_norm_from_state(state)
    model = ResNetBinaryWrapperLocal(norm=norm, gngroups=32, dropout=0.0, numclasses=2, inchannels=1)
    model.load_state_dict(state, strict=True)
    model.to(device).eval()
    return model, norm

RESNET_CKPT = Path("/home/hep/an1522/dark_tridents_wspace/outputs/weights/resnet34_gn/resnet34_gn_model_20260123-12_20_AM_epoch_4_batch_id_1961_labels_2_step_9821.pwf")
resnet_model, inferred_norm = load_resnet34_from_ckpt(RESNET_CKPT)

print(resnet_model)

ResNetBinaryWrapperLocal(
  (net): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): GroupNorm(32, 64, eps=1e-05, affine=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [10]:
# x must be [1,1,512,512] on same device as model
# class_idx: pick 0/1 depending which logit you treat as "signal"/"background"
candidates_resnet = ["net.layer1", "net.layer2", "net.layer3", 
                     "net.layer4", "net.layer4.2.conv2", "net.layer4.2.bn2"]
# candidates_mpid = ["features.19", "features.20", "features.21", 
#                     "features.22", "features.23", "features.28", 
#                     "features.31", "features.32"]

candidates_mpid = ['features.7', 'features.10', 'features.12', 
                     'features.14', 'features.16', 'features.17', 
                     'features.19', 'features.2', 'features.21', 
                     'features.23', 'features.24', 'features.26', 
                     'features.28', 'features.29', 'features.3', 
                     'features.30', 'features.31', 'features.32', 
                     'features.5', 'features.7', 'features.9']

# candidates_resnet = ['net.layer1.0.bn1', 'net.layer1.0.bn2', 'net.layer1.0.conv1', 
#                      'net.layer1.0.conv2', 'net.layer1.1.bn1', 'net.layer1.1.bn2', 
#                      'net.layer1.1.conv1', 'net.layer1.1.conv2', 'net.layer1.2.bn1', 
#                      'net.layer1.2.bn2', 'net.layer1.2.conv1', 'net.layer1.2.conv2', 
#                      'net.layer2.0.bn1', 'net.layer2.0.bn2', 'net.layer2.0.conv1', 
#                      'net.layer2.0.conv2', 'net.layer2.0.downsample.0', 'net.layer2.0.downsample.1', 
#                      'net.layer2.1.bn1', 'net.layer2.1.bn2', 'net.layer2.1.conv1', 
#                      'net.layer2.1.conv2', 'net.layer2.2.bn1', 'net.layer2.2.bn2', 
#                      'net.layer2.2.conv1', 'net.layer2.2.conv2', 'net.layer2.3.bn1',
#                      'net.layer2.3.bn2', 'net.layer2.3.conv1', 'net.layer2.3.conv2', 
#                      'net.layer3.0.bn1', 'net.layer3.0.bn2', 'net.layer3.0.conv1', 
#                      'net.layer3.0.conv2', 'net.layer3.0.downsample.0', 'net.layer3.0.downsample.1',
#                      'net.layer3.1.bn1', 'net.layer3.1.bn2', 'net.layer3.1.conv1', 
#                      'net.layer3.1.conv2', 'net.layer3.2.bn1', 'net.layer3.2.bn2', 
#                      'net.layer3.2.conv1', 'net.layer3.2.conv2', 'net.layer3.3.bn1', 
#                      'net.layer3.3.bn2', 'net.layer3.3.conv1', 'net.layer3.3.conv2', 
#                      'net.layer3.4.bn1', 'net.layer3.4.bn2', 'net.layer3.4.conv1', 
#                      'net.layer3.4.conv2', 'net.layer3.5.bn1', 'net.layer3.5.bn2', 
#                      'net.layer3.5.conv1', 'net.layer3.5.conv2', 'net.layer4.0.bn1', 
#                      'net.layer4.0.bn2', 'net.layer4.0.conv1', 'net.layer4.0.conv2', 
#                      'net.layer4.0.downsample.0', 'net.layer4.0.downsample.1', 
#                      'net.layer4.1.bn1', 'net.layer4.1.bn2', 'net.layer4.1.conv1', 
#                      'net.layer4.1.conv2', 'net.layer4.2.bn1', 'net.layer4.2.bn2', 
#                      'net.layer4.2.conv1', 'net.layer4.2.conv2']
# candidates_mpid   = ['classifier.1', 'classifier.3', 'classifier.4',
#                      'features.0', 'features.10', 'features.12', 
#                      'features.14', 'features.16', 'features.17', 
#                      'features.19', 'features.2', 'features.21', 
#                      'features.23', 'features.24', 'features.26', 
#                      'features.28', 'features.3', 'features.30', 
#                      'features.31', 'features.32', 'features.5', 
#                      'features.7', 'features.9']

x = torch.randn(1, 1, 512, 512, requires_grad=False)

print("")
print("mpid:")

for ln in candidates_mpid:
    # print("")
    # print("mpid:")
    probe_layer(mpid_model, x, ln, class_idx=0)
    # print("")
    # print("mpid2:")
    # probe_layer(mpid_model2, x, ln, class_idx=0)

print("")
print("rn34_gn:")

for ln in candidates_resnet:
    probe_layer(resnet_model, x, ln, class_idx=0)



mpid:

=== Layer: features.7 (Conv2d) ===
logits: (1, 2)  sample: [-1.540178656578064, 1.5494192838668823]
acts shape: (1, 96, 128, 128) dtype: torch.float32 device: cpu
spatial grid: 128 x 128  (CAM pixels)
acts stats: min -19.41550064086914 max 15.316720008850098 mean -0.032182786613702774 frac>0 0.5053749084472656
grads shape: (1, 96, 128, 128)
grads stats: min -0.0056997025385499 max 0.004124149214476347 mean -5.8910273992296425e-08 L2 0.00018462615844327956

=== Layer: features.10 (Conv2d) ===
logits: (1, 2)  sample: [-1.540178656578064, 1.5494192838668823]
acts shape: (1, 96, 126, 126) dtype: torch.float32 device: cpu
spatial grid: 126 x 126  (CAM pixels)
acts stats: min -71.86093139648438 max 50.46554946899414 mean -1.3355437517166138 frac>0 0.4587053656578064
grads shape: (1, 96, 126, 126)
grads stats: min -0.0012752788607031107 max 0.0007818647427484393 mean 3.5845580725890613e-08 L2 3.9940434362506494e-05

=== Layer: features.12 (GroupNorm) ===
logits: (1, 2)  sample: [-1.54