# Meta models investigation

The goal of this notebook is to understand which architectures can learn properties of neural networks (NNs).
NNs are a very different type of input from images, text or tabular data, so we might need different processing layers to extract meaning from them.

## Permutation of matrices
In a NN, one can permute many of the hidden units (all hidden activations in an MLP, channels in a CNN, ...), permuting the adjacent weights accordingly, and still have the same network.

The goal here is to train a NN with:
- two matrices as input, $A$ and $B$
- the output is whether $B$ is a permutation of the rows and columns of $A$, that is, whether there are permutation matrices $P, Q$ such that $A = PBQ$.

In [None]:
from __future__ import annotations

import random

import plotly.express as px
import torch
import torch.nn as nn
from jaxtyping import Float, Bool
from torch import Tensor

In [None]:
def permute_rows_and_cols(
    matrix: Float[Tensor, "*batch rows cols"]
) -> Float[Tensor, "*batch rows cols"]:
    """
    Return a copy of ``matrix`` with its rows and its columns randomly permuted.

    One row permutation and one column permutation are drawn, then applied to
    every matrix of the (optional) leading batch dimensions.
    """
    *_, n_rows, n_cols = matrix.shape
    # Draw the row order first, then the column order.
    row_order = torch.randperm(n_rows)
    col_order = torch.randperm(n_cols)
    rows_shuffled = matrix[..., row_order, :]
    return rows_shuffled[..., col_order]


def shuffle_matrix(
    matrix: Float[Tensor, "*batch rows cols"]
) -> Float[Tensor, "*batch rows cols"]:
    """
    Shuffles all the elements of the matrix.

    A single permutation of the ``rows * cols`` entries is drawn and applied to
    every matrix of the leading batch dimensions, so entries move freely across
    rows and columns. Non-contiguous inputs (e.g. transposed views) are accepted.
    """
    *batch, rows, cols = matrix.shape
    perm = torch.randperm(rows * cols)
    # reshape (not view): view raises on non-contiguous tensors such as transposes.
    flat = matrix.reshape(*batch, rows * cols)
    return flat[..., perm].reshape(*batch, rows, cols)


def gen_training_data(
    batch_size: int, mat_size: int, noise: float = 0.0
) -> tuple[Float[Tensor, "batch 2 rows cols"], Bool[Tensor, "batch"]]:
    """
    Build one batch for the "is B a row/column permutation of A?" task.

    Each sample pairs a random square matrix ``A`` with a matrix ``B`` that is
    either a row-and-column permutation of ``A`` (label ``True``) or a shuffle of
    all of ``A``'s entries (label ``False``), each with probability 1/2. The
    same permutation and the same shuffle are shared by every sample of the
    batch. A shuffle can, rarely, coincide with a row/column permutation while
    still being labelled ``False``. When ``noise > 0``, independent Gaussian
    noise of std ``noise`` is added to ``A`` and ``B`` after ``B`` is chosen.

    Returns ``(x, labels)`` with ``x`` of shape ``(batch, 2, mat_size, mat_size)``.
    """
    A = torch.rand(batch_size, mat_size, mat_size)

    # Both candidates are built for the whole batch; the mask picks one per sample.
    positive = permute_rows_and_cols(A)
    negative = shuffle_matrix(A)
    is_positive = torch.rand(batch_size) > 0.5
    mask = is_positive[:, None, None]
    B = torch.where(mask, positive, negative)

    if noise > 0:
        A = A + noise * torch.randn_like(A)
        B = B + noise * torch.randn_like(B)

    return torch.stack((A, B), dim=1), is_positive


def get_accuracy(
    model: nn.Module, mat_size: int, n_epochs: int = 1000, batch_size: int = 100
) -> float:
    """
    Computes the accuracy of the model on the task of predicting if two matrices are permutations of each other.

    Evaluates ``n_epochs`` fresh noise-free batches of ``batch_size`` samples and
    thresholds the model output at 0.5. ``gen_training_data`` is looked up at
    call time, so this uses whichever definition is current in the notebook.

    Args:
        model: maps the generated inputs to probabilities; the labels are
            reshaped to ``(batch, 1)``, so presumably the output has that
            shape — confirm for new architectures.
        mat_size: side length of the square matrices.
        n_epochs: number of evaluation batches (not training epochs).
        batch_size: samples per batch.

    Returns:
        Fraction of correctly classified samples.
    """
    n_correct = 0
    # Inference only: skip building autograd graphs (this is called mid-training).
    with torch.no_grad():
        for _ in range(n_epochs):
            x, y = gen_training_data(batch_size, mat_size)
            y_pred = model(x)
            n_correct += ((y_pred > 0.5) == y.view(-1, 1)).sum().item()

    return n_correct / (n_epochs * batch_size)


def find_closed_row_col_perm(
    matrix: Float[Tensor, "row col"], target: Float[Tensor, "row col"]
) -> Float[Tensor, "row col"]:
    """
    Return the row/column permutation of ``matrix`` that is closest to ``target``.

    Brute-forces every pair of row and column permutations. It returns the
    permuted matrix with the smallest L1 distance to ``target``, not the
    permutations themselves. Ties keep the first pair found (rows in the outer
    loop). Cost grows as ``rows! * cols!``, so this is only practical for tiny
    matrices. Returns ``None`` if every distance is NaN or infinite.
    """
    import itertools  # local: the notebook only imports itertools in a later cell

    rows, cols = matrix.shape

    best = None
    dist_best = float("inf")

    for row_perm in itertools.permutations(range(rows)):
        # Loop-invariant for the inner loop: apply the row permutation once.
        row_permuted = matrix[list(row_perm), :]
        for col_perm in itertools.permutations(range(cols)):
            permuted = row_permuted[:, list(col_perm)]
            dist = (permuted - target).abs().sum()
            if dist < dist_best:
                best = permuted
                dist_best = dist
    return best

In [None]:
# Sanity check of the data helpers on a small, easy-to-read integer matrix.
m = torch.arange(9).view(3, 3)
print(m)
# Rows and columns permuted.
print(permute_rows_and_cols(m))
# All 9 entries shuffled together.
print(shuffle_matrix(m))
# Two 3x3 samples; with the matrix version of gen_training_data the third
# argument is the noise std.
print(gen_training_data(2, 3, 0.1))

In [None]:
# Training loop
from tqdm.notebook import tqdm
from torch.optim import AdamW
import wandb


MAT_SIZE = 3
batch_size = 100
n_epochs = 300_000
lr = 1e-3
use_wandb = True
noise = 0.01
hidden_multipliers = [4, 4]
weight_decay = 1e-4
loss_fn = nn.BCELoss()
eval_every = 1_000  # training steps between two accuracy evaluations
n_eval_batches = 20  # batches passed to get_accuracy for each evaluation

# Set to False to continue training the `model` / `optimizer` of a previous run.
new_run = True
if new_run:
    # NOTE(review): `make_mlp` is not defined in this notebook; presumably it is
    # provided by a helper module — confirm.
    model = make_mlp(hidden_multipliers, MAT_SIZE)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    if use_wandb:
        wandb.init(
            project="meta-models",
            config={
                "batch_size": batch_size,
                "n_epochs": n_epochs,
                "lr": lr,
                "weight_decay": weight_decay,
                "mat_size": MAT_SIZE,
                "noise": noise,
                "hidden_multipliers": hidden_multipliers,
                "nb_hidden_layers": len(hidden_multipliers),
                "nb_params": sum(p.numel() for p in model.parameters()),
            },
        )

for epoch in tqdm(range(n_epochs)):
    optimizer.zero_grad()
    x, y = gen_training_data(batch_size, MAT_SIZE, noise=noise)
    y_pred = model(x)
    # BCELoss needs float targets shaped like y_pred.
    loss = loss_fn(y_pred, y.view(-1, 1).float())
    loss.backward()
    optimizer.step()

    # Send every metric of this step in a single wandb.log call. Each call
    # advances wandb's step counter, so logging loss and accuracy separately
    # would place them on different, drifting steps.
    metrics = {"loss": loss.item()}
    if epoch % eval_every == 0:
        accuracy = get_accuracy(model, MAT_SIZE, n_eval_batches)
        metrics["accuracy"] = accuracy
        if not use_wandb:
            print(f"Epoch {epoch}: accuracy {accuracy:.3f}")
    if use_wandb:
        wandb.log(metrics)

In [None]:
# Final accuracy on fresh noise-free data, with get_accuracy's default batch count.
accuracy = get_accuracy(model, MAT_SIZE)
print(accuracy)

In [None]:
from pathlib import Path

models_dir = Path("models")
models_dir.mkdir(exist_ok=True)
arch_str = "_".join(map(str, hidden_multipliers))
stem = f"permutation-of-matrices-noisy-{MAT_SIZE}x{MAT_SIZE}-{accuracy*100:.0f}acc-{arch_str}arch"

# Start from (number of saved checkpoints + 1) and skip indices already in use.
# This way a gap left by a deleted file never causes an existing checkpoint to
# be overwritten. Paths are built from `models_dir`, so the location is set once.
idx = len(list(models_dir.glob("*.pt"))) + 1
name = models_dir / f"{stem}-{idx}.pt"
while name.exists():
    idx += 1
    name = models_dir / f"{stem}-{idx}.pt"

# NOTE: this pickles the whole module. Loading it back needs the model's class
# to be importable, and on recent PyTorch `torch.load(..., weights_only=False)`.
torch.save(model, name)
print(f"Saved model to {name}")

Now we compute the "per-permutation accuracy": for each pair of row and column permutations, how often the model correctly predicts that $B$ is a permutation of $A$ when $B$ is obtained from $A$ with that pair.

In [None]:
import itertools
from dataclasses import dataclass


@dataclass
class Result:
    """Evaluation tally for one (row permutation, column permutation) pair."""

    row_perm: tuple[int, ...]  # row permutation applied to A (from itertools.permutations)
    col_perm: tuple[int, ...]  # column permutation applied after the rows
    correct: int  # number of correct predictions
    incorrect: int  # number of wrong predictions
    predictions: list[float]  # raw model outputs, one per evaluated sample

    @property
    def accuracy(self) -> float:
        """Fraction of correct predictions, or NaN when nothing was recorded."""
        total = self.correct + self.incorrect
        if total == 0:
            return float("nan")
        return self.correct / total


# For every (row_perm, col_perm) pair, measure how often the model answers
# "permutation" when B is exactly that permutation of A (all labels are True).
n_epochs = 10  # batches per permutation pair (overwrites the training-cell value)
batch_size = 100
results = []
for row_perm in tqdm(itertools.permutations(range(MAT_SIZE))):
    for col_perm in itertools.permutations(range(MAT_SIZE)):
        result = Result(row_perm, col_perm, 0, 0, [])
        for _ in range(n_epochs):
            x, y = gen_training_data(batch_size, MAT_SIZE)
            # Overwrite B with A permuted by this pair: rows first, then columns.
            x[..., 1, :, :] = x[..., 0, row_perm, :]
            x[..., 1, :, :] = x[..., 1, :, col_perm]
            # Every sample is now a positive example.
            y[...] = True
            y_pred = model(x)
            y = y.view(-1, 1)
            correct = ((y_pred > 0.5) == y).sum().item()
            result.correct += correct
            result.incorrect += batch_size - correct
            result.predictions.extend(y_pred.view(-1).tolist())
        results.append(result)
        print(f"Accuracy for {row_perm} and {col_perm}: {result.accuracy:.3f}")

In [None]:
import math

# Upper bound on the accuracy reachable on the matrix task.
# Negatives are full element shuffles of A. With probability
# nb_permutations / nb_shuffles such a shuffle happens to be a row/column
# permutation of A; for distinct entries, every (P, Q) pair gives a distinct
# arrangement. Those negatives are indistinguishable from positives, so even a
# perfect model misclassifies them. A sample is a negative with probability
# 0.5 (`torch.rand(batch_size) > 0.5`), so the best error rate is
# 0.5 * nb_permutations / nb_shuffles.
nb_permutations = math.factorial(MAT_SIZE) ** 2
nb_shuffles = math.factorial(MAT_SIZE**2)

p_shuffle_is_permutation = nb_permutations / nb_shuffles
p_negative = 0.5
max_possible_accuracy = 1 - p_negative * p_shuffle_is_permutation

print(f"Max possible accuracy: {max_possible_accuracy}")
print(f"Nb permutations: {nb_permutations}")
print(f"Nb shuffles: {nb_shuffles}")

In [None]:
# Computing the baseline from `find_closed_row_col_perm`

In [None]:
# Find inputs on which the model is wrong. For the worst one, show how far B is
# from the closest row/column permutation of A.
with torch.no_grad():  # inference only
    x, y = gen_training_data(10000, MAT_SIZE)
    y_pred = model(x)
wrong = (y_pred > 0.5).flatten() != y
idx_wrong = wrong.nonzero(as_tuple=True)[0]
x = x[idx_wrong]

if len(x) == 0:
    print("The model made no mistakes on this sample.")
else:
    # Channels: 0 = A, 1 = B, 2 = row/col permutation of A closest to B, 3 = B - channel 2.
    x: Tensor = torch.cat((x, torch.zeros_like(x)), dim=1)
    for a, b, c, d in x:
        c[...] = find_closed_row_col_perm(a, b)
        d[...] = b - c

    # Keep the sample whose residual has the largest absolute entry.
    residual_max: Tensor = x[:, 3].flatten(1).abs().max(dim=1).values
    worst = residual_max.max(dim=0).indices.item()
    x = x[worst : worst + 1]

    px.imshow(
        x.flatten(0, 1), facet_col=0, facet_col_wrap=4, height=400 * len(x), width=1200
    ).show()

In [None]:
# Layers that are invariant to permutations (here of the elements of a set;
# ultimately of the rows and cols of a matrix)

# We can process each element independently (equivariant to permutations)
# We can do equivariant mixing of values
# We can then combine everything with pooling (max, mean, sum, attention...) to get invariance


class SetPermutationInvariant(nn.Module):
    """Pool a set of tokens into one vector, independently of the token order.

    ``read_head`` projects every token to ``n_head`` scalars. Each head is then
    summarised over the token dimension with 7 order-independent statistics:
    mean, max, min, L2 norm, (unbiased) variance, median and a kurtosis-like
    ratio. ``write_head`` maps the ``7 * n_head`` statistics to ``d_out``
    features (``d_out`` defaults to ``d_in``).
    """

    def __init__(
        self,
        d_in: int,
        n_head: int = 4,
        d_out: int | None = None,
        eps: float = 1e-12,
    ):
        super().__init__()

        self.d_in = d_in
        self.n_head = n_head
        self.d_out = d_out or d_in
        # Floor for the kurtosis denominator. It avoids 0 / 0 = NaN when all
        # tokens of a head are equal (zero variance).
        self.eps = eps

        self.read_head = nn.Linear(d_in, n_head)
        self.write_head = nn.Linear(7 * self.n_head, self.d_out)

    def forward(
        self, x: Float[Tensor, "*batch token d_in"]
    ) -> Float[Tensor, "*batch d_out"]:
        x = self.read_head(x)  # (*batch, token, n_head)

        # Order-independent reductions over the token dimension, each (*batch, n_head).
        r1 = x.mean(dim=-2)
        r2 = x.max(dim=-2).values
        r3 = x.min(dim=-2).values
        r4 = x.norm(dim=-2)
        r5 = x.var(dim=-2)  # unbiased: undefined (NaN) for a single token
        r6 = x.median(dim=-2).values
        # Kurtosis is not implemented in PyTorch: 4th central moment over the
        # squared variance. The denominator is floored at eps, so constant sets
        # give 0 instead of NaN.
        r7 = (x - r1.unsqueeze(-2)).pow(4).mean(dim=-2) / r5.pow(2).clamp_min(self.eps)

        reductions = torch.stack([r1, r2, r3, r4, r5, r6, r7], dim=-2)  # (*batch, 7, n_head)
        return self.write_head(reductions.flatten(-2, -1))


from metamodels.utils import MLP


class NN(nn.Module):
    """Set classifier: per-token layers, permutation-invariant pooling, then a sigmoid head.

    The input is a set of ``token`` elements with ``in_dim`` features each. The
    output is one probability per set (presumably shape ``(*batch, 1)``, given
    ``out_dim=1`` — confirm against ``metamodels.utils.MLP``).
    """

    def __init__(self, in_dim: int, d_model: int):
        super().__init__()
        d_pooled = 7  # width of the pooled representation fed to the head

        # Acts on the last (feature) axis only, i.e. on each token separately,
        # assuming MLP does too — confirm. Despite the attribute name, this part is
        # permutation-equivariant; the name is kept for state_dict compatibility.
        self.invariant_layers = nn.Sequential(
            nn.Linear(in_dim, d_model),
            MLP(d_model, 4),
        )
        # Collapses the token dimension (permutation-invariant).
        self.reduction = SetPermutationInvariant(d_model, d_out=d_pooled)
        self.final = nn.Sequential(
            MLP(d_pooled, 4, out_dim=1),
            nn.Sigmoid(),
        )

    def forward(
        self, x: Float[Tensor, "*batch token in_dim"]
    ) -> Float[Tensor, "*batch 1"]:
        x = self.invariant_layers(x)
        x = self.reduction(x)
        return self.final(x)


# Smoke test: a batch of 2 sets, each with 10 scalar elements.
a_set = torch.rand(2, 10, 1)
n = NN(1, 12)
print(n)
pred = n(a_set)
pred

In [None]:
def is_permutation(x: Float[Tensor, "batch vec_size_x_2"]) -> Float[Tensor, "batch"]:
    """Return how far the second half of each vector is from a permutation of the first half.

    The distance is the L1 distance between the two halves after sorting each of
    them. It is 0 exactly when the second half is a permutation of the first.
    The last dimension must have even length.
    """
    *_, total = x.shape
    assert total % 2 == 0
    half = total // 2
    left, right = x[..., :half], x[..., half:]
    left_sorted, _ = torch.sort(left, dim=-1)
    right_sorted, _ = torch.sort(right, dim=-1)
    return torch.sum(torch.abs(left_sorted - right_sorted), dim=-1)


def gen_training_data(
    batch_size: int,
    vec_size: int,
    quantisation: float = 0.05,
    add_last_dim: bool = False,
) -> tuple[Float[Tensor, "batch token"], Bool[Tensor, "batch"]]:
    """
    Generates a batch of training data. All the elements of a batch use the same permutation.

    Each sample is the concatenation ``[A, B]`` of two vectors of length
    ``vec_size``, so ``token = 2 * vec_size``; with ``add_last_dim`` a trailing
    axis of size 1 is added. ``B`` is either a permutation of ``A`` or an
    independent draw from the same standard-normal distribution, each with
    probability 1/2. Both halves are rounded to multiples of ``quantisation``,
    and the label is recomputed from the rounded data, so a random ``B`` that
    happens to be a permutation of ``A`` is labelled ``True``.
    """
    A = torch.randn(batch_size, vec_size)

    permuted = A[:, torch.randperm(vec_size)]
    # Negatives must come from A's distribution. A different one (e.g. uniform in
    # [0, 1)) would let a model solve the task from B's value range alone.
    completely_random = torch.randn_like(A)

    to_permute = torch.rand(batch_size) > 0.5
    B = torch.where(to_permute.view(-1, 1), permuted, completely_random)

    # Quantize A and B: exact permutations stay exact, and any other B ends up at
    # a distance of (up to rounding) at least one quantisation step.
    A = (A / quantisation).round() * quantisation
    B = (B / quantisation).round() * quantisation

    # Concatenate them
    x = torch.cat((A, B), dim=-1)

    # Compute correct labels (the random ones might be a permutation of the first ones).
    # A permutation gives distance 0, anything else roughly k * quantisation with
    # k >= 1. Thresholding at half a step is robust to floating-point rounding,
    # which can put a one-step distance marginally below `quantisation`.
    y = is_permutation(x)

    if add_last_dim:
        x = x[..., None]
    return x, y < quantisation / 2


# Quick look at the vector task and at two untrained models on it.
vec_size = 20
x, y = gen_training_data(2, vec_size, 0.1)  # third argument: quantisation step
print(x.shape, y.shape)
print(x)
print(y)
# The set model takes a trailing feature axis of size 1.
n(x[..., None])

# Baseline: a plain MLP on the flat, concatenated 2 * vec_size inputs.
n2 = nn.Sequential(MLP(vec_size * 2, 4, out_dim=1), nn.Sigmoid())
n2(x)

In [None]:
# net = NN(1, 12)
# train(net, lambda batch_size: gen_training_data(batch_size, 5, add_last_dim=True))

In [None]:
def get_accuracy(
    net: nn.Module,
    data_generator,
    n_batches: int = 100,
    batch_size: int = 100,
) -> float:
    """
    Return the fraction of correct predictions of ``net`` on freshly generated data.

    Args:
        net: callable returning probabilities. Outputs are thresholded at 0.5
            and compared to the labels reshaped to ``(batch, 1)``.
        data_generator: callable ``batch_size -> (x, y)``, e.g. a lambda around
            ``gen_training_data``.
        n_batches: number of batches to evaluate.
        batch_size: size requested from ``data_generator`` for each batch.

    Returns:
        Accuracy over all evaluated samples. Samples are counted from the labels
        actually returned, so generators that ignore ``batch_size`` are handled.

    Raises:
        ValueError: if no sample was evaluated.
    """
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for _ in range(n_batches):
            x, y = data_generator(batch_size)
            y_pred = net(x)
            correct += ((y_pred > 0.5) == y.view(-1, 1)).sum().item()
            total += y.numel()
    if total == 0:
        raise ValueError("data_generator produced no samples")
    return correct / total


# NOTE(review): `net` and `last_dim` are only defined in a later cell; run that
# cell first when executing the notebook top to bottom.
get_accuracy(net, lambda bs: gen_training_data(bs, vec_size, add_last_dim=last_dim))

In [None]:
# Find adversarial examples on which the model is wrong


def adv_examples(
    net: nn.Module,
    batch_size: int = 1,
    steps: int = 100,
    vec_size: int = 5,
    add_last_dim: bool = False,
) -> tuple[Float[Tensor, "batch vec_size_x_2"], Bool[Tensor, "batch"]]:
    """
    Search for inputs on which ``net`` is wrong on the vector permutation task.

    Starts from uniform random vectors of length ``2 * vec_size``, with a
    trailing axis of size 1 if ``add_last_dim``. Runs ``steps`` steps of
    gradient descent on the inputs towards the *opposite* of their true label,
    plus a small penalty that keeps values inside ``[-1, 1]``. Only the inputs
    are updated: the gradient is taken w.r.t. the input alone, so the ``.grad``
    of ``net``'s parameters is left untouched and this can run inside a
    training loop.

    Returns:
        The detached inputs, and their true labels computed on those returned
        inputs: second half within L1 distance 0.1 of a permutation of the first.
    """
    # Same random values as drawing (batch, 2 * vec_size) and adding an axis, but
    # a genuine leaf tensor that can be updated in place.
    shape = (batch_size, vec_size * 2, 1) if add_last_dim else (batch_size, vec_size * 2)
    x = torch.rand(*shape, requires_grad=True)
    loss_fn = nn.BCELoss()
    step_size = 0.01
    loss = None

    for _ in range(steps):
        y_pred = net(x)
        dist_to_perm = is_permutation(x.squeeze(-1))
        y_true = dist_to_perm < 0.1
        # Target the wrong label: minimising this drives the net into a mistake.
        loss = loss_fn(y_pred, 1 - y_true.view(-1, 1).float())
        # Add a term to maximize the distance to a permutation
        # loss = loss + dist_to_perm.mean()
        # Add a term to minimize the norm of the input
        # loss = loss + x.norm(dim=1).mean() * 0
        # Penalise values outside of [-1, 1]
        loss = loss + (x.abs() - 1).clamp(min=0).mean() * 0.1
        # Gradient w.r.t. the input only (loss.backward() would also fill the
        # parameters' .grad of `net`).
        (grad,) = torch.autograd.grad(loss, x)
        with torch.no_grad():
            # Descent: minimises both the wrong-label loss and the range penalty.
            x -= grad * step_size
        # Clamp the values
        # x.data = torch.clamp(x.data, -1, 1)

    x = x.detach()
    # Label the inputs actually returned (in-loop labels predate the last update).
    dist_to_perm = is_permutation(x.squeeze(-1))
    y_true = dist_to_perm < 0.1
    if random.random() < 0.01 and loss is not None:
        print("Dist to perm:", dist_to_perm.mean().item())
        print("Loss:", loss.item())
    return x, y_true


# Generate one adversarial example and inspect it.
x, y = adv_examples(net, vec_size=vec_size, steps=100, add_last_dim=last_dim)

# Flatten, so the sort runs over the vector even when add_last_dim added a
# trailing size-1 axis (otherwise sort() would act on that axis and do nothing).
x1 = x[0, :vec_size].flatten()
x2 = x[0, vec_size:].flatten()
x1_sorted = x1.sort().values
x2_sorted = x2.sort().values

# If B is a permutation of A, the two sorted rows are identical.
px.imshow(torch.stack((x1_sorted, x2_sorted), dim=0), height=400, width=1200).show()

# Check accuracy
with torch.no_grad():
    y_pred = net(x)
print("* Pred:", y_pred.item())
print("* True:", y.item())
print(x.squeeze())

In [None]:
# Model configuration for the vector task.
vec_size = 4
# NOTE(review): this NN is immediately overwritten by the MLP on the next line.
net = NN(1, 12)
net = nn.Sequential(MLP(vec_size * 2, 4, 1, out_dim=1), nn.Sigmoid())
# The MLP consumes flat (batch, 2 * vec_size) inputs, hence no trailing axis.
last_dim = False
# Alternative: invariant pooling over scalar tokens (needs last_dim = True).
# net = nn.Sequential(SetPermutationInvariant(1, 2, 14), MLP(14, 1, out_dim=1), nn.Sigmoid())
# last_dim = True
# train(net, lambda batch_size: gen_training_data(batch_size, 5, add_last_dim=last_dim), lr=0.05)

In [None]:
# Train on generated batches, passing an adversarial-example generator too.
# NOTE(review): `train` is not defined in this notebook (`utils.train` is used
# later); presumably it comes from metamodels.utils — confirm how it uses
# `adversary`, `lr` and `steps`.
train(
    net,
    lambda batch_size: gen_training_data(batch_size, vec_size, add_last_dim=last_dim),
    adversary=lambda batch_size: adv_examples(
        net, batch_size, vec_size=vec_size, add_last_dim=last_dim
    ),
    lr=1e-4,
    steps=10000,
)

In [None]:
# Accuracy after training, with get_accuracy's default number of batches.
get_accuracy(net, lambda bs: gen_training_data(bs, vec_size, add_last_dim=last_dim))

In [None]:
# Display the architecture.
net

In [None]:
# Visualise the first linear layer: weights with the bias appended as a last column.
# NOTE(review): assumes metamodels.utils.MLP exposes `.layers[0]` as an nn.Linear — confirm.
w = net[0].layers[0].weight.detach().clone()
b = net[0].layers[0].bias.detach().clone()
w = torch.cat((w, b.unsqueeze(-1)), dim=-1)
# w[w.abs() < 0.1] = float('nan')
px.imshow(w, height=1200)

# Trying the dot product layer

In [None]:
import metamodels.utils as utils
import torchinfo

# Matrix task with the DotProductLayer from metamodels.utils.
N = 4
# Extra hidden dimensions on top of the flattened input.
free_space = N * 2
# NOTE(review): nn.Flatten yields 2 * N**2 features per (2, N, N) sample, but the
# layers are built for d_model = 2 * N**2 + free_space; presumably the utils
# layers pad the input — confirm.
d_model = N**2 * 2 + free_space

net = nn.Sequential(
    nn.Flatten(),
    utils.SkipSequential(
        utils.MLP(d_model, 4),
        utils.DotProductLayer(d_model, 8, N),
        utils.MLP(d_model, 4),
        utils.DotProductLayer(d_model, 8, N),
    ),
    # Pooling layer
    utils.MLP(d_model, 4, out_dim=1),
    nn.Sigmoid(),
)

torchinfo.summary(net, (100, 2, N, N))

In [None]:
# NOTE(review): at this point `gen_training_data` is the vector-task redefinition.
# Its samples have shape (batch, 2 * N), not the (batch, 2, N, N) matrices used
# for the summary above — confirm which generator is intended.
utils.train(
    net,
    lambda bs: gen_training_data(bs, N),
)