In [None]:
import time
import torch
from torch import Tensor
import torch.nn as nn

In [None]:
# Device handed to the model wrappers and trainers later in this notebook; fixed to the CPU.
device = torch.device("cpu")

In [None]:
class Permutation(nn.Module):
    """A fixed reordering of indices, applied to tensors by integer indexing.

    Instances compare equal when their ``n`` and index tensors match, and are
    hashable on the same data, so they can be stored in sets and used as dict
    keys.

    Args:
        n: Number of elements the permutation acts on. It is stored as given
            and is not checked against the length of ``perm``.
        perm: Integer index tensor; position ``i`` of the output is taken from
            position ``perm[i]`` of the input.
    """

    def __init__(self, n: int, perm: torch.Tensor) -> None:
        super().__init__()
        self.n = n
        # Keep a private copy: equality and hashing depend on these values, so
        # they must not change if the caller later mutates its own tensor.
        self.perm = perm.clone().detach()

    def forward(self, x: Tensor) -> Tensor:
        """Reorder ``x`` with the stored indices.

        A 1-D input is indexed directly; any other input is indexed along
        dimension 1, with dimension 0 left untouched.
        """
        if x.ndim == 1:
            return x[self.perm]
        return x[:, self.perm]

    def __eq__(self, other: object) -> bool:
        """Return ``True`` only for a ``Permutation`` with the same ``n`` and indices."""
        if not isinstance(other, Permutation):
            return False
        # Comparing shapes first avoids a broadcasting error (or a spurious
        # broadcast match) when the two index tensors differ in shape.
        if self.n != other.n or self.perm.shape != other.perm.shape:
            return False
        # torch.all yields a 0-dim tensor; convert so callers get a real bool.
        return bool(torch.all(self.perm == other.perm))

    def __ne__(self, value: object) -> bool:
        """Logical negation of ``__eq__``."""
        return not self.__eq__(value)

    def __hash__(self) -> int:
        """Hash on ``n`` and the index values, consistent with ``__eq__``."""
        # tolist() returns a list, which is unhashable; freeze it into a tuple.
        return hash((self.n, tuple(self.perm.flatten().tolist())))

In [None]:
import itertools
from collections import deque
from typing import Iterator


def create_all_permutations(n: int) -> Iterator[Permutation]:
    """Lazily yield every ordering of ``range(n)`` wrapped as a ``Permutation``.

    Orderings come out in the order ``itertools.permutations`` produces them.
    There are ``n!`` of them, so fully consuming the iterator is only
    practical for small ``n``.

    Args:
        n: Number of elements to permute.

    Yields:
        One ``Permutation`` per ordering, holding a ``torch.long`` index tensor.
    """
    index_orderings = itertools.permutations(range(n))
    yield from (
        Permutation(n, torch.tensor(ordering, dtype=torch.long))
        for ordering in index_orderings
    )


def create_permutations_from_generators(n: int, generators: list[Permutation]) -> Iterator[Permutation]:
    """Enumerate the permutations reachable from the identity via ``generators``.

    Performs a breadth-first search: starting from the identity, every
    discovered permutation is composed with each generator, and each result
    not seen before is yielded and queued for further expansion. The identity
    is always yielded first.

    Args:
        n: Number of elements every permutation acts on.
        generators: Permutations to compose with; each must have ``.n == n``.

    Yields:
        Each distinct reachable ``Permutation`` exactly once, in discovery order.

    Raises:
        AssertionError: If a generator's ``n`` differs from ``n``. The check
            runs when iteration starts, because this is a generator function.
    """
    assert all(perm.n == n for perm in generators), "every generator must act on n elements"

    def compose(p1: Permutation, p2: Permutation) -> Permutation:
        # Indexes p2's index tensor by p1's indices.
        return Permutation(n, p1.forward(p2.perm))

    def index_key(perm: Permutation) -> tuple:
        # Immutable snapshot of the indices. The visited set is keyed on this
        # rather than on Permutation objects so that membership tests do not
        # depend on Permutation's own __hash__/__eq__ implementation.
        return tuple(perm.perm.tolist())

    identity = Permutation(n, torch.arange(n))
    seen_keys = {index_key(identity)}
    queue = deque([identity])

    yield identity

    while queue:
        current_perm = queue.popleft()
        for gen in generators:
            new_perm = compose(current_perm, gen)
            new_key = index_key(new_perm)
            if new_key not in seen_keys:
                seen_keys.add(new_key)
                queue.append(new_perm)
                yield new_perm

In [None]:
class GaussianDataset(torch.utils.data.Dataset):
    """Synthetic binary task: zero-mean Gaussian vectors whose scale encodes the label.

    Every sample receives a label drawn uniformly from {0, 1}. Its ``d``
    features are independent standard-normal draws multiplied by 0.1 when the
    label is 0 and by 10.0 when the label is 1. Those factors are standard
    deviations, so the resulting feature variances are 0.01 and 100.

    Args:
        num_samples: Number of samples to generate.
        d: Number of features per sample.
        generator: Optional ``torch.Generator`` used for all sampling, for
            reproducible datasets. ``None`` (the default) uses torch's global
            random state.

    Attributes are plain tensors: ``data`` has shape ``(num_samples, d)`` and
    ``labels`` has shape ``(num_samples, 1)``; both are ``float32``.
    """

    def __init__(self, num_samples: int, d: int, generator: "torch.Generator | None" = None) -> None:
        super().__init__()

        self.d = d
        self.num_samples = num_samples

        # Labels are drawn before features; keep this order so that seeded
        # runs consume random numbers in a stable sequence.
        labels = torch.randint(0, 2, (num_samples, 1), generator=generator)
        data = torch.randn(num_samples, d, generator=generator)

        # Per-sample standard deviation, shape (num_samples, 1); it broadcasts
        # across the d feature columns.
        std = 0.1 * (labels == 0) + 10.0 * (labels != 0)
        data = data * std

        self.data = data.to(torch.float32)
        self.labels = labels.to(torch.float32)

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.num_samples

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]

In [None]:
class CanonicalModel(nn.Module):
    """Wrap a network so it always sees its input in a canonical (sorted) order.

    Sorting the last dimension in descending order maps every reordering of
    the same feature values to one representative, so the wrapped model's
    output does not depend on the order of the features.

    Args:
        d: Feature dimension; stored for reference only.
        model: Network applied to the sorted input; moved to ``device``.
        device: Device the wrapped network is placed on. Inputs are not moved
            by this class.
    """

    def __init__(self, d: int, model: nn.Module, device: torch.device) -> None:
        super().__init__()
        self.d = d
        self.device = device
        self.model = model.to(device)

    def forward(self, x: Tensor) -> Tensor:
        """Sort ``x`` along its last dimension (largest first) and run the model."""
        canonical, _ = x.sort(dim=-1, descending=True)
        return self.model(canonical)

In [None]:
class SymmetryModel(nn.Module):
    def __init__(self, d: int, perms: Iterator[Permutation], model: nn.Module, device: torch.device) -> None:
        super().__init__()
        self.d = d
        self.model = model.to(device)
        self.device = device
        self.perms = perms

    def forward(self, x: Tensor) -> Tensor:
        permuted_x = torch.stack([perm(x) for perm in self.perms])
        outputs = self.model(permuted_x)
        return torch.mean(outputs, dim=0)

In [None]:
class LinearEquivariant(nn.Module):
    """Three-parameter layer computing ``beta * x + alpha * sum(x) + b``.

    The sum runs over the last dimension and is broadcast back to every
    position, so reordering the features of the input reorders the output in
    the same way. All three parameters are scalars of shape ``(1,)`` drawn
    from a standard normal distribution.
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order (b, alpha, beta) determines both the order of random
        # draws and the order reported by parameters().
        self.b = nn.Parameter(torch.randn(1))
        self.alpha = nn.Parameter(torch.randn(1))
        self.beta = nn.Parameter(torch.randn(1))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer; the output has the same shape as ``x``."""
        # NOTE(review): the pooled term is an unnormalised sum, so its
        # magnitude grows with the feature count - confirm this is intended.
        pooled = x.sum(dim=-1, keepdim=True)
        return self.beta * x + self.alpha * pooled + self.b


class LinearInvariant(nn.Module):
    """Two-parameter layer computing ``alpha * sum(x) + b``.

    The sum runs over the last dimension, which is kept with size 1, so the
    output does not depend on the order of the features. Both parameters are
    scalars of shape ``(1,)`` drawn from a standard normal distribution.
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order (b, alpha) determines both the order of random draws
        # and the order reported by parameters().
        self.b = nn.Parameter(torch.randn(1))
        self.alpha = nn.Parameter(torch.randn(1))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer; the last dimension of the output has size 1."""
        pooled = x.sum(dim=-1, keepdim=True)
        return self.alpha * pooled + self.b

In [None]:
from torch.utils.data import DataLoader

# Experiment configuration. `hidden_dim` is the number of features per sample
# and is also used as the width of the MLPs defined in later cells.
num_samples = 1000
hidden_dim = 500
batch_size = 32

# 80/20 split; the two datasets are sampled independently of each other.
train_size = int(0.8 * num_samples)
test_size = num_samples - train_size

ds_train = GaussianDataset(train_size, hidden_dim)
ds_test = GaussianDataset(test_size, hidden_dim)

dl_train = DataLoader(
    dataset=ds_train,
    batch_size=batch_size,
    shuffle=True,
)

# Evaluation does not need shuffling. A fixed order keeps batch composition
# (including the smaller final batch) identical from epoch to epoch, so
# per-batch metrics are comparable across evaluations.
dl_test = DataLoader(
    dataset=ds_test,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
from training import RegularTrainer

In [None]:
# TODO: training is currently broken.
# None of the networks below exhibits any sign of learning.

In [None]:
# Network that is permutation-invariant by construction: two equivariant
# layers followed by an invariant pooling layer, squashed to (0, 1) for BCELoss.
intrinsic_stages = [
    LinearEquivariant(),
    nn.ReLU(),
    LinearEquivariant(),
    nn.ReLU(),
    LinearInvariant(),
    nn.Sigmoid(),
]
model = nn.Sequential(*intrinsic_stages)

intrinsic_loss_fn = nn.BCELoss()
intrinsic_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

intrinsic_trainer = RegularTrainer(
    model=model,
    loss_fn=intrinsic_loss_fn,
    optimizer=intrinsic_optimizer,
    device=device,
)

intrinsic_trainer.fit(dl_train=dl_train, dl_test=dl_test, num_epochs=100)

In [None]:
# Plain MLP used by later cells: two hidden Linear+ReLU stages of width
# hidden_dim, then a single-unit Linear head with a Sigmoid output.
layers = nn.Sequential(
    *(
        stage
        for _ in range(2)
        for stage in (nn.Linear(hidden_dim, hidden_dim, bias=True), nn.ReLU())
    ),
    nn.Linear(hidden_dim, 1, bias=True),
    nn.Sigmoid(),
)

In [None]:
# Canonicalisation approach: the MLP in `layers` is fed inputs sorted along
# the feature dimension. Note that `layers` is trained in place here.
model = CanonicalModel(hidden_dim, layers, device)

canonical_loss_fn = nn.BCELoss()
canonical_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

canonical_trainer = RegularTrainer(
    model=model,
    loss_fn=canonical_loss_fn,
    optimizer=canonical_optimizer,
    device=device,
)

canonical_trainer.fit(dl_train=dl_train, dl_test=dl_test, num_epochs=100)

In [None]:
# Fresh MLP for the symmetrised model. It gets its own name and is the network
# actually wrapped below; previously the new network was overwritten unused
# and the `layers` instance already trained by the canonical cell was passed.
symmetry_layers = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_dim, 1, bias=True),
    nn.Sigmoid(),
)

# Averaging over every permutation is not feasible: materialising
# create_all_permutations(hidden_dim) means enumerating hidden_dim! orderings
# (500! for the value set above), which can never finish. Instead, average
# over a fixed random sample of permutations.
# NOTE(review): with a sample the model is only approximately permutation
# invariant - confirm this approximation is acceptable for the experiment.
num_symmetry_perms = 64
perms = [
    Permutation(hidden_dim, torch.randperm(hidden_dim))
    for _ in range(num_symmetry_perms)
]

model = SymmetryModel(hidden_dim, perms, symmetry_layers, device)

symmetry_trainer = RegularTrainer(
    model=model,
    loss_fn=nn.BCELoss(),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    device=device,
)

symmetry_trainer.fit(dl_train=dl_train, dl_test=dl_test, num_epochs=100)