In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torchinfo
from jaxtyping import Float
from torch import Tensor

from metamodels import utils
from tasks import gen_eigenvalues

# Report whether PyTorch can see a CUDA device in this environment.
print(torch.cuda.is_available())

# Draw a small sample from the task generator and show the shape and the
# contents of both returned tensors.
# NOTE(review): judging by the later `lambda bs: gen_eigenvalues(bs, N)` calls,
# the arguments are presumably (batch size, matrix size) — confirm in `tasks`.
x, y = gen_eigenvalues(2, 3)
for label, value in (("x:", x), ("y:", y)):
    print(label, value.shape)
    print(value)

In [None]:
# Side length of the square matrices fed to the network.
N = 4

# Baseline model: flatten each N x N matrix into a vector of N * N entries and
# pass it through an MLP that emits N values per matrix.
# NOTE(review): the three 16s are presumably hidden-layer widths — confirm
# against `utils.MLP`.
net = nn.Sequential(
    nn.Flatten(),
    utils.MLP(N * N, 16, 16, 16, out_dim=N),
)

# Summarise the model for a batch of 100 matrices; as the last expression of
# the cell, the summary is what gets displayed.
torchinfo.summary(net, (100, N, N))

In [None]:
# Keyword options forwarded to `utils.train` unchanged.
training_options = dict(
    lr=1e-4,
    batch_size=1000,
    log_every=100,
    loss_fn=nn.MSELoss(),
    reuse_perfs=True,
)

# Train the current `net` on batches produced on demand by `gen_eigenvalues`.
# NOTE(review): the third positional argument (10_000) is presumably the number
# of training steps — the later call in this notebook passes `steps=` by
# keyword; confirm against `utils.train`.
utils.train(net, lambda bs: gen_eigenvalues(bs, N), 10_000, **training_options)

In [None]:
# Hand the current `net` and a batch generator for N x N matrices to
# `utils.show_example`.
# NOTE(review): presumably this displays predictions next to the true
# eigenvalues — confirm against `utils.show_example`.
utils.show_example(
    net,
    lambda bs: gen_eigenvalues(bs, N),
)

In [None]:
# Checking whether eigenvalues depend on the permutation of input rows: they do
x, y = gen_eigenvalues(1, 3)
print("x:", x.shape)
print(x)
print("y:", y.shape)
print(y)
# Permute the rows (dim 1) with a cyclic shift by one position.
# A random permutation from torch.randperm can be the identity (1 chance in 6
# for three rows), which would leave the matrix untouched and make this check
# vacuous; a one-step roll always moves every row when there is more than one.
# The result is kept under its own name so the original `x` stays available
# for comparison.
x_perm = torch.roll(x, shifts=1, dims=1)
print("x:", x_perm.shape)
print(x_perm)
# Compute the eigenvalues of the row-permuted matrix.
# torch.linalg.eigvals returns complex values in no guaranteed order, so
# compare them with `y` as sets rather than position by position.
new_y = torch.linalg.eigvals(x_perm)
print("y:", new_y.shape)
print(new_y)

In [None]:
class MatrixTokenizer(nn.Module):
    """Turn a batch of square matrices into a sequence of row and column tokens.

    For an input of shape ``(batch, n, n)`` the output holds ``2 * n`` tokens
    per matrix: the ``n`` rows followed by the ``n`` columns. Every token is
    right-padded with ``extra_space`` zeros, giving an output of shape
    ``(batch, 2 * n, n + extra_space)``.

    Args:
        extra_space: Number of zero entries appended to the end of each token.
    """

    def __init__(self, extra_space: int):
        # Initialise nn.Module before assigning any attribute: the module's
        # internal registries only exist after this call.
        super().__init__()
        self.extra_space = extra_space

    def forward(self, x: Float[Tensor, "batch row col"]):
        """Tokenize ``x`` into padded row and column tokens.

        Args:
            x: Batch of square matrices, shape ``(batch, n, n)``.

        Returns:
            Tensor of shape ``(batch, 2 * n, n + extra_space)`` with the same
            dtype and device as ``x``.

        Raises:
            AssertionError: If the matrices are not square.
        """
        _, rows, cols = x.shape
        assert rows == cols, f"expected square matrices, got {rows}x{cols}"
        # Swapping the last two axes turns every column into a row-shaped token.
        columns = x.transpose(1, 2)
        # Row tokens first, then column tokens, along the token axis.
        tokens = torch.cat([x, columns], dim=1)
        # new_zeros inherits dtype and device from x, so half/double inputs are
        # not silently promoted or mixed with float32 padding.
        padding = x.new_zeros((*tokens.shape[:2], self.extra_space))
        return torch.cat([tokens, padding], dim=2)
       
# Side length of the square input matrices.
N = 4
# Number of zero entries appended to every token by MatrixTokenizer.
free_space = N
# Width of one token: N matrix entries plus the zero padding.
d_model = N + free_space

# MatrixTokenizer emits 2 * N tokens (rows, then columns) of width d_model, so
# the flattened representation fed to the final MLP has 2 * N * d_model entries.
# NOTE(review): the meaning of the numeric arguments of MultiAttention, MLP and
# DotProductLayer is not visible here — confirm against `metamodels.utils`.
# Variants that were tried and are currently disabled:
#   * a second `utils.MLP(d_model, 4)` at the end of the SkipSequential stack;
#   * `utils.PoolingLayer("max")` followed by `utils.MLP(d_model, 4, out_dim=N)`
#     in place of the flatten pooling + MLP head.
net = nn.Sequential(
    MatrixTokenizer(free_space),
    utils.SkipSequential(
        utils.MultiAttention(d_model, 4, N),
        utils.MLP(d_model, 4),
        utils.DotProductLayer(d_model, 8, N),
        utils.MultiAttention(d_model, 4, N),
    ),
    utils.PoolingLayer("flatten"),
    utils.MLP(2 * N * d_model, 4, out_dim=N),
)

# Summarise the model for a batch of 100 matrices; as the last expression of
# the cell, the summary is what gets displayed.
torchinfo.summary(net, (100, N, N))

In [None]:
# Train the current `net` on freshly generated N x N batches with an MSE loss.
# No learning rate is passed here, so whatever default `utils.train` defines
# applies.
utils.train(
    net,
    lambda bs: gen_eigenvalues(bs, N),
    steps=2000,
    batch_size=100,
    log_every=200,
    loss_fn=nn.MSELoss(),
)