In [10]:
import torch


In [None]:
import torch

# Dump every parameter/buffer name in the checkpoint next to its tensor shape.
# NOTE(review): torch.load unpickles the file, so only point it at checkpoints
# from a trusted source.
ckpt = torch.load("ckpt_t0.pt", map_location="cpu")
# The weights may sit under a "state_dict" entry or the file may be the bare
# state dict itself; fall back to the checkpoint object in the latter case
# instead of failing with a KeyError.
for k, v in ckpt.get("state_dict", ckpt).items():
    print(f"{k:40s}   {tuple(v.shape)}")


model.backbone.conv1.weight                (64, 3, 7, 7)
model.backbone.bn1.weight                  (64,)
model.backbone.bn1.bias                    (64,)
model.backbone.bn1.running_mean            (64,)
model.backbone.bn1.running_var             (64,)
model.backbone.bn1.num_batches_tracked     ()
model.backbone.layer1.0.conv1.weight       (64, 64, 3, 3)
model.backbone.layer1.0.bn1.weight         (64,)
model.backbone.layer1.0.bn1.bias           (64,)
model.backbone.layer1.0.bn1.running_mean   (64,)
model.backbone.layer1.0.bn1.running_var    (64,)
model.backbone.layer1.0.bn1.num_batches_tracked   ()
model.backbone.layer1.0.conv2.weight       (64, 64, 3, 3)
model.backbone.layer1.0.bn2.weight         (64,)
model.backbone.layer1.0.bn2.bias           (64,)
model.backbone.layer1.0.bn2.running_mean   (64,)
model.backbone.layer1.0.bn2.running_var    (64,)
model.backbone.layer1.0.bn2.num_batches_tracked   ()
model.backbone.layer1.1.conv1.weight       (64, 64, 3, 3)
model.backbone.layer1.1.bn1.w

In [16]:
#!/usr/bin/env python3
# inspect_activations.py  –  traza las activaciones de tu ResNet-18 (CPU)

import torch, torchvision
from torchvision import transforms
from collections import OrderedDict
from torchvision.models import resnet18

# ───────────────────────────── 1. Build the architecture ─────────────────────────────
def build_model(num_classes: int = 5):
    """Create a ResNet-18 without pretrained weights and a custom classifier head.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.

    Returns:
        The ResNet-18 module whose ``fc`` attribute has been swapped for a
        freshly initialised ``torch.nn.Linear``.
    """
    model = resnet18(weights=None)
    # Take the input width from the stock head (512 for ResNet-18) rather than
    # repeating the number here.
    feature_dim = model.fc.in_features
    model.fc = torch.nn.Linear(feature_dim, num_classes)
    return model

# Inference-only network, kept on the CPU.
net = build_model().eval().cpu()

# ─────────────────────────── 2. Load and clean the checkpoint ──────────────────────────
# NOTE(review): torch.load unpickles the file, so only use trusted checkpoints.
ckpt = torch.load("ckpt_t0.pt", map_location="cpu")
# Accept both a wrapper with a "state_dict" entry and a bare state dict.
state_dict = ckpt.get("state_dict", ckpt)

# Wrapper prefixes, tried in this order; the longest one comes first so it
# wins over the shorter "model." prefix.
_WRAPPER_PREFIXES = ("model.backbone.", "model.", "backbone.")


def _to_resnet_key(key):
    """Translate one checkpoint key into the naming used by ``net``.

    At most one wrapper prefix (the first that matches) is stripped; after
    that a leading ``head.`` is renamed to ``fc.``.
    """
    matched = next((p for p in _WRAPPER_PREFIXES if key.startswith(p)), "")
    key = key[len(matched):]
    if key.startswith("head."):
        key = "fc." + key[len("head."):]
    return key


# Later entries overwrite earlier ones if two source keys map to the same name.
fixed = {_to_resnet_key(k): v for k, v in state_dict.items()}

# strict=True requests an exact key match; the assert below re-checks the
# returned key lists.
missing, unexpected = net.load_state_dict(fixed, strict=True)
assert not (missing or unexpected), f"faltan {missing} · sobran {unexpected}"

# ───────────────────────────── 3. Activation hooks ───────────────────────────────
# Maps a caller-chosen tag to the most recent output captured under that tag.
activations = OrderedDict()
def save(tag):
    """Build a forward hook that records a module's output under ``tag``.

    Args:
        tag: Key under which the output is stored in ``activations``.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature. Each call stores a detached CPU copy of ``output`` in
        ``activations[tag]``, replacing any earlier entry for the same tag.
    """
    def _hook(_module, _inputs, out):
        # detach() shares storage with ``out`` and a plain .cpu() does not copy
        # a tensor that already lives on the CPU, so the stored value would
        # alias the live tensor and change with any later in-place op on it.
        # copy=True forces an independent snapshot.
        activations[tag] = out.detach().to("cpu", copy=True)
    return _hook

# Attach one recording hook per stage; the attribute name doubles as the tag
# under which that stage's output is stored.
for stage in ("conv1", "layer1", "layer2", "layer3", "layer4", "avgpool"):
    getattr(net, stage).register_forward_hook(save(stage))

# ─────────────────────────── 4. Three CIFAR-10 images ──────────────────────────────
# Resize to 224, convert to a tensor, then normalise per channel. The mean/std
# values match the commonly used ImageNet channel statistics.
preprocessing_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
]
transform = transforms.Compose(preprocessing_steps)

# Test split; downloaded into ~/data when not already present.
cifar_test = torchvision.datasets.CIFAR10(
    "~/data", train=False, download=True, transform=transform
)

# Each dataset item is indexed with [0] to keep only the image of the first
# three samples, which are then stacked along a new leading batch dimension.
first_images = [cifar_test[idx][0] for idx in range(3)]
batch = torch.stack(first_images)   # [3,3,224,224]

# ───────────────────────────────── 5. Forward pass ──────────────────────────────────────
# No gradients are needed for inspection; the registered hooks fill
# ``activations`` during this call.
with torch.no_grad():
    logits = net(batch)        # one row per image, one column per class

# ─────────────────────────────── 6. Results ─────────────────────────────────────────
print("\n=== Logits (sin softmax) ===")
print(logits)

print("\n=== Formas de las activaciones ===")
# Entries come out in the order the hooks stored them.
for name in activations:
    act = activations[name]
    print(f"{name:<7} → {tuple(act.shape)}")



=== Logits (sin softmax) ===
tensor([[ 14.9622,  -1.0243, -14.4369,   8.1287, -42.2056],
        [ 16.9558,  21.6087, -18.2771, -15.4169, -40.5104],
        [ -0.5901,  19.4824,  -7.8706, -10.4091, -26.1498]])

=== Formas de las activaciones ===
conv1   → (3, 64, 112, 112)
layer1  → (3, 64, 56, 56)
layer2  → (3, 128, 28, 28)
layer3  → (3, 256, 14, 14)
layer4  → (3, 512, 7, 7)
avgpool → (3, 512, 1, 1)
