In [None]:
from convnext import get_convnext
from interp_utils import heatmap, seriate
from imagenet_val import simple_class_labels as class_labels

# Load the ConvNeXt model via the project helper; later cells read its weights
# directly (model.blocks[...].pwconv2 and model.head) rather than running it.
model = get_convnext()

In [None]:
# Index of the block whose write-out we inspect; the plot titles below call it "Layer 17".
LAST_LAYER = 17
# This analysis is meant to look at the *final* block. Fail loudly if the model has a
# different depth instead of silently plotting an intermediate layer as "last layer".
assert LAST_LAYER == len(model.blocks) - 1, (
    f"expected block {LAST_LAYER} to be the last block, "
    f"but model.blocks has {len(model.blocks)} entries"
)

# Write-out matrix of the last block: pwconv2's weight transposed, so that rows index the
# block's hidden neurons (the 'neuron' axis of the logit-lens heatmaps below).
# NOTE(review): only the raw pwconv2 weight is used here. If the block rescales pwconv2's
# output per channel (e.g. a layer-scale gamma) or a norm sits between the last block and
# the head, the true direct effect on the logits differs -- confirm against the model code.
w_out = model.blocks[LAST_LAYER].pwconv2.weight.data.T

# Unembedding matrix: the classifier head's weight
# (presumably shaped (num_classes, d_model) -- TODO confirm).
unembed = model.head.weight.data

In [None]:
# Logit lens for the last layer: multiply each neuron's write-out vector by the class
# unembedding, giving one row per neuron and one column per class logit.
logit_lens_arr = w_out @ unembed.T

# info_1 attaches the class names to the logit axis and include_idx shows neuron indices
# (but not class indices) on hover.
heatmap(
    logit_lens_arr,
    title='Layer 17 logit lens',
    dim_names=('neuron', 'logit'),
    info_1={'class': class_labels},
    include_idx=(True, False),
).show()

# Note: heatmap also accepts the bare product, e.g. heatmap(w_out @ unembed.T),
# just without the title and hover labels.

# The plot shows positive/negative streaks and dots. They appear because the ImageNet
# classes are not in random order (e.g. the dogs sit together and most objects are grouped
# together), and neurons tend to respond to related classes. Hover over the heatmap to
# explore this structure, and click-and-drag to zoom in.

In [None]:
# Seriate the logit-lens matrix: find orderings of the neurons (rows) and classes (columns)
# that put similar rows next to each other. This is slow -- about 20 s on the author's machine.
(
    seriated_logit_lens_arr,
    neuron_perm,
    logit_perm,
) = seriate(logit_lens_arr)

In [None]:
# Re-plot the logit lens with neurons and classes reordered by the seriation.
# Pass the permutations (instead of the already-permuted array) so that heatmap reorders the
# neuron indices and class labels together with the values.
# Zoom in on a few of the clusters -- do the grouped classes make sense together?
heatmap(
    logit_lens_arr,
    perm_0=neuron_perm,
    perm_1=logit_perm,
    title='Layer 17 logit lens (seriated)',
    dim_names=('neuron', 'logit'),
    info_1={'class': class_labels},
    include_idx=(True, False),
).show()