In [1]:
import sys

import torch
import rsatoolbox.data
import rsatoolbox.rdm.calc
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names

sys.path.append("../src")

import utils  # noqa: E402
import models  # noqa: E402


def sample_images(data_loader, n=5, plot=False, n_classes=10):
    """Samples up to ``n`` images of every class from a data loader.

    Batches are consumed in order until each class has ``n`` images or the
    loader is exhausted. Within a class, images keep the order in which the
    loader produced them; classes are concatenated in ascending order.

    Args:
        data_loader: Iterable yielding ``(imgs, labels)`` batches. ``labels`` is
            compared against integer class indices with ``labels == value``.
        n: Number of images to collect per class. Default is 5.
        plot: If True, shows the sampled images in a grid with ``n`` images per
            row. Default is False.
        n_classes: Classes ``0 .. n_classes - 1`` are sampled. Default is 10.

    Outputs:
        imgs: Tensor with the sampled images concatenated along dim 0.
        targets: 1-D int64 tensor holding the class of each sampled image. It
            always has the same length as ``imgs``, even when a class has fewer
            than ``n`` images available.

    Raises:
        ValueError: If ``data_loader`` yields no batches at all.
    """
    import warnings

    chunks = [[] for _ in range(n_classes)]  # per-class list of image tensors
    counts = [0] * n_classes  # images collected so far per class
    first_batch = None

    for batch_imgs, batch_labels in data_loader:
        if first_batch is None:
            first_batch = batch_imgs

        for value in range(n_classes):
            missing = n - counts[value]
            if missing <= 0:
                continue
            picked = batch_imgs[batch_labels == value][:missing]
            if len(picked) > 0:
                chunks[value].append(picked)
                counts[value] += len(picked)

        # Stop reading batches as soon as every class is complete.
        if all(count >= n for count in counts):
            break

    if first_batch is None:
        raise ValueError("data_loader yielded no batches")

    short = {value: count for value, count in enumerate(counts) if count < n}
    if short:
        warnings.warn(
            f"Requested {n} images per class, but some classes have fewer "
            f"(class: found): {short}"
        )

    collected = [chunk for per_class in chunks for chunk in per_class]
    # Fall back to an empty slice so the result keeps the image shape and dtype.
    imgs = torch.cat(collected, dim=0) if collected else first_batch[:0]

    # Build targets from what was actually collected so they align with imgs.
    targets = torch.tensor(
        [value for value, count in enumerate(counts) for _ in range(count)],
        dtype=torch.long,
    )

    if plot:
        plt.imshow(
            torch.moveaxis(
                make_grid(imgs, nrow=n, padding=0, normalize=False, pad_value=0), 0, -1
            )
        )
        plt.title(f"Sampled Test Images ({n} of each class)")
        plt.axis("off")
        plt.show()

    return imgs, targets


def extract_features(model, imgs, return_layers):
    """Runs ``imgs`` through ``model`` and collects intermediate outputs.

    Args:
        model: Module handed to ``create_feature_extractor``.
        imgs: Batch of inputs fed to the feature extractor.
        return_layers: Node names to return, or the string "all" to use the
            first list returned by ``get_graph_node_names(model)``.

    Outputs:
        A dictionary whose first entry maps "input" to ``imgs`` (not a layer,
        but useful for RDM comparison), followed by the extractor's outputs.
    """
    nodes = return_layers
    if nodes == "all":
        # Let torchvision enumerate every traceable node of the model
        nodes = get_graph_node_names(model)[0]

    extractor = create_feature_extractor(model, return_nodes=nodes)
    layer_outputs = extractor(imgs)

    features = {"input": imgs}
    features.update(layer_outputs)
    return features


def calc_rdms(model_features, method="correlation"):
    """Calculates representational dissimilarity matrices (RDMs) for model features.

    Each entry of ``model_features`` is moved to the CPU, flattened to 2-D when
    it has more than two dimensions, and wrapped in an rsatoolbox ``Dataset``
    before all datasets are passed to ``calc_rdm`` together.

    Args:
        model_features: A dictionary where keys are layer names and values are
            features of the layers (a tensor, or a list/tuple of tensors, in
            which case the last element is used).
        method: The method to calculate RDMs, e.g., 'correlation'. Default is 'correlation'.

    Outputs:
        rdms: RDMs object containing dissimilarity matrices.
        rdms_dict: A dictionary with layer names as keys and their corresponding
            RDMs as values, in the same order as ``model_features``.
    """
    layer_names = list(model_features)

    ds_list = []
    for layer in layer_names:
        feats = model_features[layer]

        # Some nodes return a sequence of tensors; keep only the last one.
        if isinstance(feats, (list, tuple)):
            feats = feats[-1]

        feats = feats.detach().cpu()

        # Collapse everything after dim 0 into one feature axis
        # (assumes dim 0 indexes the input samples — TODO confirm).
        if feats.ndim > 2:
            feats = feats.flatten(1)

        ds_list.append(
            rsatoolbox.data.Dataset(feats.numpy(), descriptors=dict(layer=layer))
        )

    rdms = rsatoolbox.rdm.calc.calc_rdm(ds_list, method=method)

    # Convert to square matrices once instead of once per layer.
    matrices = rdms.get_matrices()
    rdms_dict = {layer: matrices[i] for i, layer in enumerate(layer_names)}

    return rdms, rdms_dict


def plot_maps(model_features, model_name):
    """Draws one heat map per entry of ``model_features`` in a single row.

    Args:
        model_features: A dictionary mapping layer names to the arrays to plot
            (e.g. the per-layer RDMs returned by ``calc_rdms``).
        model_name: Name appended to the figure title.
    """
    fig = plt.figure(figsize=(14, 4))
    fig.suptitle(f"RDMs across layers – {model_name}")
    grid = fig.add_gridspec(1, len(model_features))
    fig.subplots_adjust(wspace=0.2, hspace=0.2)

    for col, (layer, values) in enumerate(model_features.items()):
        heat = np.squeeze(values)

        # A flat vector with a perfect-square length is folded into a square.
        if heat.ndim < 2:
            n_items = heat.shape[0]
            side = int(np.sqrt(n_items))
            if side * side == n_items:
                heat = heat.reshape((side, side))

        # Scale each panel by its own maximum (only when that maximum is positive).
        peak = np.max(heat)
        if peak > 0:
            heat = heat / peak

        ax = plt.subplot(grid[0, col])
        image = ax.imshow(heat, cmap="magma_r")
        ax.set_title(f"Layer: {layer}")
        ax.set_xlabel("Input Index")
        if col == 0:
            ax.set_ylabel("Input Index")

    # NOTE(review): the colorbar is built from the last panel's image only,
    # while every panel is scaled independently — confirm this is intended.
    fig.subplots_adjust(right=0.9)
    colorbar_axes = fig.add_axes([0.92, 0.15, 0.01, 0.7])
    colorbar = fig.colorbar(image, cax=colorbar_axes)
    colorbar.set_label("Dissimilarity", rotation=270, labelpad=15)

    plt.show()

In [2]:
# Config
n_classes = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data
batch_size = 500  # Does not matter here
_, _, test_loader = utils.load_mnist_data(batch_size)
class_names = [str(i) for i in range(n_classes)]

# Load models
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on another device can still be loaded here (e.g. on a CPU-only machine).
bp_model = models.LeNet5(n_classes=n_classes, latent_dim=84, act_fn="relu")
bp_model.load_state_dict(
    torch.load("../results/backprop-model.pth", map_location=device)
)
bp_model.eval().to(device)

# NOTE(review): the recorded output of this cell shows load_state_dict failing
# for FFLeNet5 (checkpoint keys such as "conv1.conv.weight" vs. expected
# "conv1.weight"). The checkpoint and models.FFLeNet5 appear to be out of sync —
# confirm which side is current before relying on the cells below.
ff_model = models.FFLeNet5(n_classes=n_classes, latent_dim=84)
ff_model.load_state_dict(torch.load("../results/ff-model.pth", map_location=device))
ff_model.eval().to(device)

pc_model = models.PCLeNet5(n_classes=n_classes, latent_dim=84)
pc_model.load_state_dict(torch.load("../results/pc-model.pth", map_location=device))
pc_model.eval().to(device)

RuntimeError: Error(s) in loading state_dict for FFLeNet5:
	Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3.weight", "conv3.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias". 
	Unexpected key(s) in state_dict: "conv1.conv.weight", "conv1.conv.bias", "conv2.conv.weight", "conv2.conv.bias", "conv3.conv.weight", "conv3.conv.bias", "fc1.linear.weight", "fc1.linear.bias", "fc2.linear.weight", "fc2.linear.bias". 

In [None]:
# Print the first node-name list returned by get_graph_node_names for each
# model, in the order backprop, forward-forward, predictive coding.
for net in (bp_model, ff_model, pc_model):
    train_nodes, val_nodes = get_graph_node_names(net)
    print(f"{train_nodes=}")

In [None]:
# FIXME: it's not returning 5 samples per class
import train.ff

imgs, targets = sample_images(test_loader, n=5, plot=True)
# The forward-forward model is fed the label-overlaid images instead of the
# raw ones (see the extract_features call below).
imgs_overlay = train.ff.overlay_label(imgs, targets, n_classes, is_positive=True)

# NOTE: analyze after activation (captures full non-linear representation)
print("Extracting features (backprop)")
layer_names = ["pool1", "pool2", "relu_2", "relu_3", "fc2"]
bp_features = extract_features(bp_model, imgs.to(device), return_layers=layer_names)

print("\nExtracting features (forward-forward)")
layer_names = ["pool1", "pool2", "conv3.relu", "fc1.relu", "fc2.linear"]
ff_features = extract_features(
    ff_model, imgs_overlay.to(device), return_layers=layer_names
)

print("\nExtracting features (predictive coding)")
layer_names = ["0.2", "1.2", "2.2", "3.1", "4.0"]
pc_features = extract_features(pc_model, imgs.to(device), return_layers=layer_names)

# Rename to match the backpropagation model, so the same key refers to the
# corresponding stage in all three feature dictionaries.
ff_features = {
    "input": ff_features["input"],
    "pool1": ff_features["pool1"],
    "pool2": ff_features["pool2"],
    "relu_2": ff_features["conv3.relu"],
    "relu_3": ff_features["fc1.relu"],
    "fc2": ff_features["fc2.linear"],
}

pc_features = {
    "input": pc_features["input"],
    "pool1": pc_features["0.2"],
    "pool2": pc_features["1.2"],
    "relu_2": pc_features["2.2"],
    "relu_3": pc_features["3.1"],
    "fc2": pc_features["4.0"],
}

In [None]:
# Compute the per-layer RDMs of every model first ...
rdms_bp, rdms_dict_bp = calc_rdms(bp_features)
rdms_ff, rdms_dict_ff = calc_rdms(ff_features)
rdms_pc, rdms_dict_pc = calc_rdms(pc_features)

# ... then draw one figure per model, in the same order as before.
for title, layer_rdms in (
    ("Backpropagation", rdms_dict_bp),
    ("Forward-forward", rdms_dict_ff),
    ("Predictive Coding", rdms_dict_pc),
):
    plot_maps(layer_rdms, title)