# Meta models investigation

The goal of this notebook is to understand which architecture can learn things about neural networks (NN).
Having NNs as input is very different from working with images, text or tabular data, so we might need different processing layers to extract meaning from them.

## Permutation of matrices
In NN, one can permute a lot of the hidden activations (all activations in an MLP, channels in a CNN, ...) and still have the same network.

The goal here is to train a NN with:
- two matrices as input, $A$ and $B$
- output is whether the second is a permutation of the rows and columns of the first, that is, whether there are permutation matrices $P, Q$ such that $A = PBQ$.

In [None]:
import torch
import torch.nn as nn
from torch import Tensor
import plotly.express as px

from jaxtyping import Float, Int, Bool

In [None]:
def permute_rows_and_cols(matrix: Float[Tensor, "*batch rows cols"]) -> Float[Tensor, "*batch rows cols"]:
    """
    Applies one random row permutation and one random column permutation to the input.

    A single pair of permutations is drawn and used for every matrix in the
    leading batch dimensions. The result is a new tensor of the same shape.
    """
    n_rows, n_cols = matrix.shape[-2:]
    # The row order is drawn before the column order.
    row_order = torch.randperm(n_rows)
    col_order = torch.randperm(n_cols)
    rows_reordered = matrix[..., row_order, :]
    return rows_reordered[..., col_order]


def shuffle_matrix(matrix: Float[Tensor, "*batch rows cols"]) -> Float[Tensor, "*batch rows cols"]:
    """
    Shuffles all the elements of the matrix.

    One random permutation of the ``rows * cols`` positions is drawn and applied
    to every matrix in the leading batch dimensions. The result is a new tensor
    with the same shape as the input.
    """
    *batch, rows, cols = matrix.shape
    perm = torch.randperm(rows * cols)
    # `reshape` instead of `view`: `view` raises on non-contiguous inputs (e.g. a transposed matrix).
    flat = matrix.reshape(*batch, rows * cols)
    return flat[..., perm].reshape(*batch, rows, cols)

def gen_training_data(batch_size: int, mat_size: int, noise: float = 0.0) -> tuple[Float[Tensor, "batch 2 rows cols"], Bool[Tensor, "batch"]]:
    """
    Builds one batch of (A, B) matrix pairs together with their labels.

    A is drawn with `torch.rand`. For each sample, B is either the row/column
    permuted version of A (label True) or the fully shuffled version of A
    (label False), chosen by comparing a uniform draw to 0.5. When `noise` is
    positive, independent Gaussian noise scaled by `noise` is added to A and B.

    Returns the pairs stacked along dimension 1 (A at index 0, B at index 1)
    and the boolean labels.
    """
    A = torch.rand(batch_size, mat_size, mat_size)

    # Both candidates for B are computed for the whole batch; the label picks one per sample.
    row_col_permuted = permute_rows_and_cols(A)
    fully_shuffled = shuffle_matrix(A)
    is_permutation = torch.rand(batch_size) > 0.5
    B = torch.where(is_permutation[:, None, None], row_col_permuted, fully_shuffled)

    if noise > 0:
        A = A + noise * torch.randn_like(A)
        B = B + noise * torch.randn_like(B)

    return torch.stack((A, B), dim=1), is_permutation

def get_accuracy(model: nn.Module, mat_size: int, n_epochs: int = 1000, batch_size: int = 100) -> float:
    """
    Computes the accuracy of the model on the task of predicting if two matrices are permutations of each other.

    Fresh noise-free batches are generated with `gen_training_data`; a model
    output above 0.5 counts as predicting "permutation".

    Args:
        model: maps a batch of matrix pairs to one score per sample.
        mat_size: side length of the generated square matrices.
        n_epochs: number of batches to evaluate on.
        batch_size: number of samples per batch.

    Returns:
        The fraction of correctly classified samples over all batches.
    """
    n_correct = 0
    # Evaluation only: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        for _ in range(n_epochs):
            x, y = gen_training_data(batch_size, mat_size)
            y_pred = model(x)
            n_correct += ((y_pred > 0.5) == y.view(-1, 1)).sum().item()

    return n_correct / (n_epochs * batch_size)

def find_closed_row_col_perm(matrix: Float[Tensor, "row col"], target: Float[Tensor, "row col"]) -> Float[Tensor, "row col"]:
    """
    Finds the row and column permutations that minimize the distance between the matrix and the target.

    Every combination of a row permutation and a column permutation is tried
    (``rows! * cols!`` candidates), so this is only practical for small matrices.
    The distance is the sum of absolute element-wise differences.

    Args:
        matrix: 2D tensor whose rows and columns are permuted.
        target: 2D tensor of the same shape that the permuted matrix is compared to.

    Returns:
        The permuted copy of `matrix` with the smallest distance to `target`
        (the first one found in case of ties).
    """
    # Imported locally so the function does not rely on an import made in a later cell.
    import itertools

    # Rows and columns are counted separately so non-square matrices work too.
    n_rows, n_cols = matrix.shape

    best = None
    dist_best = float("inf")

    for row_perm in itertools.permutations(range(n_rows)):
        # The row reordering does not depend on the column permutation: do it once per row_perm.
        row_permuted = matrix[list(row_perm)]
        for col_perm in itertools.permutations(range(n_cols)):
            permuted = row_permuted[:, list(col_perm)]
            dist = (permuted - target).abs().sum().item()
            if dist < dist_best:
                best = permuted
                dist_best = dist
    return best


In [None]:
# Quick visual check of the helpers on a small matrix whose entries are easy to track.
m = torch.arange(9).reshape(3, 3)
for demo_output in (
    m,
    permute_rows_and_cols(m),
    shuffle_matrix(m),
    gen_training_data(2, 3, 0.1),
):
    print(demo_output)

In [None]:
def make_mlp(hidden_multipliers: list[int], mat_size: int) -> nn.Module:
    """
    Builds an MLP classifier for a pair of `mat_size` x `mat_size` matrices.

    The input is flattened to ``2 * mat_size ** 2`` features. Each hidden layer
    has that width times the corresponding entry of `hidden_multipliers` and is
    followed by a ReLU. The final linear layer has a single output that goes
    through a sigmoid.
    """
    input_width = 2 * mat_size ** 2
    widths = [input_width, *(input_width * mult for mult in hidden_multipliers), 1]

    modules: list[nn.Module] = [nn.Flatten()]
    for fan_in, fan_out in zip(widths, widths[1:]):
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(nn.ReLU())
    # The last linear layer is followed by a sigmoid rather than a ReLU.
    modules[-1] = nn.Sigmoid()
    return nn.Sequential(*modules)

make_mlp([4], 4)

In [None]:
# Training loop
from tqdm.notebook import tqdm
from torch.optim import AdamW
import wandb


# Hyper-parameters of the run.
MAT_SIZE = 3
batch_size = 100
n_epochs = 300_000
lr = 1e-3
use_wandb = True
noise = 0.01
hidden_multipliers = [4, 4]
weight_decay = 1e-4
loss_fn = nn.BCELoss()

# When False, the loop below reuses whatever `model` and `optimizer` already exist
# (presumably to continue a previous run of this cell).
new_run = True
if new_run:
    model = make_mlp(hidden_multipliers, MAT_SIZE)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    if use_wandb:
        wandb.init(
            project="meta-models",
            config={
                "batch_size": batch_size,
                "n_epochs": n_epochs,
                "lr": lr,
                "weight_decay": weight_decay,
                "mat_size": MAT_SIZE,
                "noise": noise,
                "hidden_multipliers": hidden_multipliers,
                "nb_hidden_layers": len(hidden_multipliers),
                "nb_params": sum(p.numel() for p in model.parameters()),
            },
        )

# Each "epoch" is a single optimisation step on a freshly generated batch.
for epoch in tqdm(range(n_epochs)):
    optimizer.zero_grad()
    x, y = gen_training_data(batch_size, MAT_SIZE, noise=noise)
    y_pred = model(x)
    loss = loss_fn(y_pred, y.view(-1, 1).float())
    loss.backward()
    optimizer.step()

    metrics = {"loss": loss.item()}

    if epoch % 1_000 == 0:
        accuracy = get_accuracy(model, MAT_SIZE, 20)
        metrics["accuracy"] = accuracy
        if not use_wandb:
            print(f"Epoch {epoch}: accuracy {accuracy:.3f}")

    if use_wandb:
        # Log everything for this step in a single call: every `wandb.log` call advances
        # wandb's step counter, so a separate accuracy call would shift it away from `epoch`.
        wandb.log(metrics)

In [None]:
# Final evaluation, using get_accuracy's default number of batches and batch size.
accuracy = get_accuracy(model=model, mat_size=MAT_SIZE)
print(accuracy)

In [None]:
from pathlib import Path

# Directory collecting the saved models; created if it does not exist yet.
models_dir = Path("models")
models_dir.mkdir(exist_ok=True)

# Running index: one more than the number of `.pt` files already in the directory.
idx = 1 + sum(1 for _ in models_dir.glob("*.pt"))
arch_str = "_".join(str(mult) for mult in hidden_multipliers)
name = f"models/permutation-of-matrices-noisy-{MAT_SIZE}x{MAT_SIZE}-{accuracy*100:.0f}acc-{arch_str}arch-{idx}.pt"

# The whole module object is saved, not only its weights.
torch.save(model, name)
print(f"Saved model to {name}")

Now we compute the "per-permutation accuracy", that is, how frequently the model is correct for each specific pair of row and column permutations.

In [None]:
import itertools
from dataclasses import dataclass

@dataclass
class Result:
    """Tally of the model's predictions for one fixed pair of row and column permutations."""

    # Orderings applied to the rows / columns (tuples as produced by `itertools.permutations`).
    row_perm: tuple[int, ...]
    col_perm: tuple[int, ...]
    # Number of samples classified correctly / incorrectly so far.
    correct: int
    incorrect: int
    # Raw model outputs (scores, not thresholded decisions) for every evaluated sample.
    predictions: list[float]

    @property
    def accuracy(self) -> float:
        """Fraction of correct predictions, or NaN when nothing has been tallied yet."""
        total = self.correct + self.incorrect
        if total == 0:
            return float("nan")
        return self.correct / total

# NOTE: these rebind the `n_epochs` / `batch_size` globals set by the training cell.
n_epochs = 10
batch_size = 100
results = []
# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    for row_perm in tqdm(itertools.permutations(range(MAT_SIZE))):
        for col_perm in itertools.permutations(range(MAT_SIZE)):
            result = Result(row_perm, col_perm, 0, 0, [])
            for _ in range(n_epochs):
                x, y = gen_training_data(batch_size, MAT_SIZE)
                # Overwrite B with A permuted by this exact (row_perm, col_perm) pair,
                # so every sample in the batch is a positive example.
                x[..., 1, :, :] = x[..., 0, row_perm, :]
                x[..., 1, :, :] = x[..., 1, :, col_perm]
                y[...] = True
                y_pred = model(x)
                y = y.view(-1, 1)
                correct = ((y_pred > 0.5) == y).sum().item()
                result.correct += correct
                result.incorrect += batch_size - correct
                result.predictions.extend(y_pred.view(-1).tolist())
            results.append(result)
            print(f"Accuracy for {row_perm} and {col_perm}: {result.accuracy:.3f}")



In [None]:
import math
# Number of distinct matrices reachable by permuting rows and columns
# (assuming all entries are distinct), and by shuffling every element freely.
nb_permutations = math.factorial(MAT_SIZE) ** 2
nb_shuffles = math.factorial(MAT_SIZE ** 2)

# Probability that a full shuffle happens to be a row/column permutation, i.e. that a
# sample labelled "not a permutation" is in fact one.
shuffle_is_permutation_prob = nb_permutations / nb_shuffles

# About half of the generated samples are shuffles; the ambiguous ones among them are
# unavoidable mistakes for a classifier that answers "permutation" whenever one exists
# (noise-free data assumed).
max_possible_accuracy = 1 - shuffle_is_permutation_prob / 2

print(f"Max possible accuracy: {max_possible_accuracy}")
print(f"Probability that a shuffle is a permutation: {shuffle_is_permutation_prob}")
print(f"Nb permutations: {nb_permutations}")
print(f"Nb shuffles: {nb_shuffles}")

In [None]:
# Computing the baseline from `find_closed_row_col_perm`

In [None]:
# Find input on which the model is wrong
x, y = gen_training_data(10000, MAT_SIZE)
# Inference only: no gradients are needed.
with torch.no_grad():
    y_pred = model(x)
wrong = (y_pred > 0.5).flatten() != y
idx_wrong = wrong.nonzero(as_tuple=True)[0]
x = x[idx_wrong]

if len(x) == 0:
    # Nothing to display; the reductions below would fail on an empty batch.
    print("The model made no mistake on this sample.")
else:
    # Extend each pair (A, B) with two extra slots:
    # C = row/column permutation of A closest to B, and D = B - C.
    x = torch.cat((x, torch.zeros_like(x)), dim=1)
    for a, b, c, d in x:
        c[...] = find_closed_row_col_perm(a, b)
        d[...] = b - c

    # Keep only the misclassified sample whose residual D has the largest absolute entry.
    worst_per_sample = x[:, 3].flatten(1).abs().max(dim=1).values
    worst = worst_per_sample.max(dim=0).indices.item()
    x = x[worst:worst+1]

    px.imshow(
        x.flatten(0, 1),
        facet_col=0, facet_col_wrap=4,
        height=400 * len(x), width=1200
    ).show()

