In [1]:
import os
import torch
import numpy as np
from sklearn.preprocessing import normalize as sk_normalize
from k_means_constrained import KMeansConstrained
from taker import Model


def cluster_neurons(model: "Model",
        layer: int,
        split_num: int=96,
        method: str="kmeans",
    ) -> tuple:
    """Partition the MLP neurons of one layer into equally sized groups.

    Args:
        model: Model wrapper exposing ``cfg.d_mlp`` and, for the "kmeans"
            method, ``layers[layer]["mlp.W_in"]``.
        layer: Index of the layer whose MLP input weights are clustered.
            Only read by the "kmeans" method.
        split_num: Number of groups to create. Must evenly divide
            ``model.cfg.d_mlp``.
        method: Either "kmeans" or "random".
            - "kmeans": size-constrained k-means over the normalised
              ``mlp.W_in`` weights, with every cluster forced to contain
              exactly ``d_mlp // split_num`` members (``random_state=0``).
            - "random": despite the name, this is a deterministic
              round-robin assignment; neuron ``i`` gets label
              ``i % split_num``.

    Returns:
        A ``(labels, info)`` tuple.
        - "kmeans": ``labels`` is a list built from ``kmeans.labels_`` and
          ``info`` is the fitted ``KMeansConstrained`` estimator.
        - "random": ``labels`` is a numpy array and ``info`` is an empty dict.

    Raises:
        AssertionError: If ``split_num`` does not evenly divide
            ``model.cfg.d_mlp``.
        NotImplementedError: If ``method`` is not a supported method.
    """
    # First, get variables for which components are used
    assert model.cfg.d_mlp % split_num == 0, \
        "split_num should evenly divide model's mlp width"
    split_size = model.cfg.d_mlp // split_num

    # Dispatch on the method before touching any weights: the round-robin
    # path and the error path do not need them, so skip the fetch/normalise.
    if method == "random":
        labels = np.arange(model.cfg.d_mlp) % split_num
        return labels, {}

    if method != "kmeans":
        raise NotImplementedError(f"method {method} not implemented")

    # Collect the neurons we are clustering.
    # NOTE(review): assumes each row of W_in is one neuron (d_mlp rows), since
    # the equal-size constraint below needs exactly d_mlp samples - confirm.
    weights = model.layers[layer]["mlp.W_in"].detach().cpu()
    normed_weights = sk_normalize(weights)

    # Perform the clustering; size_min == size_max forces equal-sized clusters.
    kmeans = KMeansConstrained(
        n_clusters=split_num, size_min=split_size, size_max=split_size, random_state=0
    ).fit(normed_weights, None)
    labels = list(kmeans.labels_)
    return labels, kmeans

In [3]:
# Load the pretrained model whose MLP neurons are going to be clustered.
m = Model("nickypro/tinyllama-15m")

# Only layer 0 is clustered here; add more indices to cover further layers.
for layer in [0]:
    labels, _ = cluster_neurons(model=m, layer=layer, method="kmeans")
    print(labels)

- Loaded nickypro/tinyllama-15m
 - Registered 6 Attention Layers
[91, 35, 6, 33, 82, 28, 1, 76, 59, 25, 61, 55, 66, 91, 49, 55, 93, 37, 33, 24, 38, 92, 80, 32, 21, 26, 75, 76, 73, 31, 42, 4, 15, 11, 76, 0, 3, 91, 36, 72, 4, 59, 44, 35, 11, 23, 58, 54, 3, 50, 69, 94, 19, 56, 47, 92, 94, 62, 80, 66, 12, 23, 86, 32, 80, 79, 17, 86, 40, 13, 44, 17, 47, 1, 59, 25, 55, 72, 5, 43, 69, 27, 94, 76, 71, 45, 11, 22, 56, 69, 64, 29, 3, 2, 39, 24, 41, 32, 5, 90, 94, 63, 39, 38, 29, 8, 57, 88, 63, 86, 49, 87, 64, 7, 21, 16, 36, 87, 8, 3, 33, 11, 12, 93, 50, 81, 78, 85, 48, 27, 45, 91, 62, 70, 51, 47, 31, 52, 43, 70, 18, 55, 61, 68, 67, 55, 93, 17, 8, 41, 2, 81, 37, 62, 16, 41, 16, 73, 52, 46, 74, 5, 72, 65, 43, 42, 10, 59, 49, 69, 19, 52, 31, 18, 77, 84, 44, 24, 7, 21, 13, 3, 26, 14, 34, 87, 46, 75, 39, 57, 14, 75, 71, 68, 89, 54, 10, 30, 57, 85, 12, 20, 56, 10, 82, 57, 90, 51, 31, 15, 35, 58, 47, 62, 72, 65, 9, 76, 89, 22, 9, 19, 64, 49, 11, 45, 9, 82, 72, 72, 5, 43, 12, 91, 31, 79, 61, 29, 87, 54,