In [None]:
"""Evaluate a trained spectrogram classifier on a single PNG input and visualize its intermediate activation maps."""

from __future__ import annotations

from pathlib import Path

from PIL import Image
import torch
import matplotlib.pyplot as plt
from contextlib import contextmanager

from conv2d_model import DEFAULT_IMAGE_SIZE, build_base_transform, load_weights

In [None]:
def load_image_tensor(image_path: Path, transform, device: torch.device) -> torch.Tensor:
    """Load one image from disk as a single-item batch on ``device``.

    The image is converted to PIL mode ``"L"`` (8-bit grayscale) before
    ``transform`` is applied, and a leading dimension of size 1 is added so the
    result can be fed straight into the model.

    Args:
        image_path: Path to the image file (e.g. a spectrogram PNG).
        transform: Callable mapping a PIL image to a tensor, e.g. the pipeline
            built by ``build_base_transform``.
        device: Device the returned tensor is moved to.

    Returns:
        ``transform(image)`` with an extra leading dimension of size 1, on ``device``.
    """
    # Open inside a context manager so the file handle is released promptly;
    # ``convert`` returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as raw_image:
        grayscale = raw_image.convert("L")
    batch = transform(grayscale).unsqueeze(0)
    return batch.to(device)

# Prefer Apple-silicon MPS, then CUDA, then CPU, so the notebook also runs on
# machines where the MPS backend is unavailable.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
file = Path("./data/training/wash_03/wash_03_2020.png") # 5460
model_path = Path("./conv2d_model.pt")

# Preprocessing pipeline and trained model for single-image inference.
transform = build_base_transform(DEFAULT_IMAGE_SIZE)
model, classes = load_weights(model_path, device)
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

tensor = load_image_tensor(file, transform, device)
with torch.no_grad():
    logits = model(tensor)
    # Independent per-class sigmoid scores.
    # NOTE(review): this assumes a multi-label head; if the model was trained
    # with softmax / cross-entropy, torch.softmax(logits, dim=-1) would be the
    # matching transform — confirm against the training code.
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()

# Every class label needs exactly one score. Without this check, a mismatch
# would either drop scores silently or fail with a bare IndexError.
assert len(classes) == len(probs), (
    f"model produced {len(probs)} scores but {len(classes)} class labels were loaded"
)
for class_name, prob in zip(classes, probs):
    print(f"{class_name}: {prob:0.2f}")

In [None]:
# Written to by the forward hooks from ``register_hooks``: probe label ->
# that layer's most recent output (detached and moved to CPU).
activation_maps: dict[str, torch.Tensor] = dict()

def register_hooks(model, layers, store=None):
    """Attach a forward hook to each module in ``layers`` that records its output.

    On every forward pass, each hooked module's output is detached, moved to
    CPU, and written to ``store[name]``. An existing entry for the same name is
    overwritten.

    Args:
        model: Unused; kept so existing call sites keep working.
        layers: Mapping of label -> ``torch.nn.Module`` to hook.
        store: Dict that receives the captured outputs. When ``None`` (the
            default), the notebook-level ``activation_maps`` dict is used.

    Returns:
        List of hook handles; pass it to ``remove_hooks`` to detach them.
    """
    handles = []

    def make_hook(name):
        def hook(module, _input, output):
            # Resolve the default target at call time, so a rebinding of the
            # global ``activation_maps`` (re-running its cell) is honoured.
            target = activation_maps if store is None else store
            target[name] = output.detach().cpu()
        return hook

    for name, layer in layers.items():
        handles.append(layer.register_forward_hook(make_hook(name)))
    return handles


def remove_hooks(handles):
    """Detach every hook referenced by ``handles``.

    Args:
        handles: Iterable of hook handles, such as the list returned by
            ``register_hooks``.
    """
    for registered in handles:
        registered.remove()


def inspect_layers(model, layers, inp):
    """Run one gradient-free forward pass and capture the outputs of ``layers``.

    Hooks are installed with ``register_hooks``, so each probed layer's output
    lands in ``activation_maps``. The hooks are always removed afterwards, even
    if the forward pass raises, so a failed probe cannot leave them attached to
    ``model``.

    Args:
        model: Module to run.
        layers: Mapping of label -> submodule of ``model`` to capture.
        inp: Input passed directly to ``model``.
    """
    hooks = register_hooks(model, layers)
    try:
        with torch.no_grad():
            model(inp)
    finally:
        remove_hooks(hooks)


# Submodules whose outputs are captured for visualisation, keyed by plot label.
# NOTE(review): for layer1-layer3 the "block0" labels point at index [1],
# while "layer4_block0" points at index [0]. Either these labels or these
# indices look off by one — confirm against the model in conv2d_model.
layers_to_probe = dict(
    stem=model.stem[0],
    layer1_block0=model.layer1[1],
    layer2_block0=model.layer2[1],
    layer3_block0=model.layer3[1],
    layer4_block0=model.layer4[0],
    layer4_block1=model.layer4[1],
    layer4_block2=model.layer4[2],
)

inspect_layers(model, layers_to_probe, tensor)

In [None]:
# At most this many leading channels are drawn per probed layer.
MAX_CHANNELS_PER_LAYER = 6

fig_rows = []
for layer_name, activations in activation_maps.items():
    # Assumes conv-style activations indexed as [batch, channel, ...]; only
    # batch item 0 is drawn.
    num_maps = min(MAX_CHANNELS_PER_LAYER, activations.shape[1])
    if num_maps == 0:
        continue  # plt.subplots cannot build a zero-column grid
    # squeeze=False always returns a 2-D array of axes, so the single-channel
    # case needs no special handling.
    fig, axes = plt.subplots(1, num_maps, figsize=(3 * num_maps, 3), squeeze=False)
    for idx, ax in enumerate(axes[0]):
        ax.imshow(activations[0, idx], cmap="inferno")
        ax.set_title(f"{layer_name}\nchannel {idx}")
        ax.axis("off")
    fig.tight_layout()
    fig_rows.append((layer_name, fig))

# Show the figures explicitly rather than echoing ``fig_rows``, whose repr is
# only a list of (name, <Figure>) tuples. ``fig_rows`` stays available for reuse.
plt.show()