In [3]:
import time
import torch
from torch import Tensor
import torch.nn as nn

In [4]:
# Compute device for this notebook (CPU by default); switch to torch.device("cuda:0") for the first GPU.
device = torch.device("cpu")

In [None]:
class Permutation(nn.Module):
    """A permutation of ``n`` coordinates, applied by integer indexing.

    Calling the module returns ``x[perm]`` for a 1-D ``x`` and ``x[:, perm]``
    otherwise (i.e. dimension 1 is permuted).

    Instances compare equal when ``n`` and the index vector match, and hash
    consistently with that, so they can be stored in sets / used as dict keys.

    Args:
        n: Number of permuted coordinates.
        perm: Index vector of length ``n`` (tensor or sequence of ints). A
            detached int64 copy is stored.

    Raises:
        ValueError: If ``perm`` is not a 1-D vector of length ``n``.
    """

    def __init__(self, n: int, perm: torch.Tensor) -> None:
        super().__init__()
        self.n = n
        index = torch.as_tensor(perm, dtype=torch.long).detach().clone()
        if index.shape != (n,):
            raise ValueError(f"perm must have shape ({n},), got {tuple(index.shape)}")
        # Non-persistent buffer: follows .to(device) but stays out of state_dict().
        self.register_buffer("perm", index, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        """Permute ``x`` along dim 0 if it is 1-D, otherwise along dim 1."""
        if x.ndim == 1:
            return x[self.perm]
        return x[:, self.perm]

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Permutation):
            # torch.equal yields a Python bool (and False on a shape mismatch),
            # whereas torch.all would yield a 0-d tensor.
            return self.n == other.n and torch.equal(self.perm, other.perm)
        return False

    def __ne__(self, value: object) -> bool:
        return not self.__eq__(value)

    def __hash__(self) -> int:
        # tolist() returns an unhashable list, so hash it as a tuple.
        # Equal permutations (same n, same indices) produce equal hashes.
        return hash((self.n, tuple(self.perm.tolist())))

In [None]:
import itertools
from collections import deque
from typing import Iterator


def create_all_permutations(n: int) -> Iterator[Permutation]:
    """Lazily yield all ``n!`` permutations of ``range(n)`` in lexicographic order.

    Each ordering is wrapped in a ``Permutation`` backed by an int64 index tensor.
    """
    yield from (
        Permutation(n, torch.tensor(ordering, dtype=torch.long))
        for ordering in itertools.permutations(range(n))
    )


def create_permutations_from_generators(n: int, generators: list[Permutation]) -> Iterator[Permutation]:
    """Enumerate the permutation group generated by ``generators``.

    Breadth-first search from the identity: each element found is composed with
    every generator, and results not seen before are queued. Every group element
    is yielded exactly once, identity first.

    Args:
        n: Number of permuted coordinates; every generator must have ``.n == n``.
        generators: Generating permutations (any iterable). An empty collection
            yields only the identity.

    Yields:
        The distinct elements of the generated group, in BFS order.
    """
    # Materialise so one-shot iterables survive both the check and the search loop.
    generators = list(generators)
    assert all(perm.n == n for perm in generators), "all generators must act on n coordinates"

    def compose(p1: Permutation, p2: Permutation) -> Permutation:
        # Resulting index is p2.perm[p1.perm]; applied to a 1-D x it equals p1(p2(x)).
        return Permutation(n, p1.forward(p2.perm))

    def key(p: Permutation) -> tuple[int, ...]:
        # Deduplicate on the plain index tuple instead of hashing Permutation
        # objects, so the search only depends on hashable Python ints.
        return tuple(p.perm.tolist())

    # Keep the identity on the same device as the generators it is composed with.
    device = generators[0].perm.device if generators else None
    identity = Permutation(n, torch.arange(n, device=device))
    seen = {key(identity)}
    queue = deque([identity])

    while queue:
        current_perm = queue.popleft()
        yield current_perm
        for gen in generators:
            new_perm = compose(current_perm, gen)
            new_key = key(new_perm)
            if new_key not in seen:
                seen.add(new_key)
                queue.append(new_perm)

In [None]:
class GaussianDataset(torch.utils.data.Dataset):
    """In-memory dataset of i.i.d. zero-mean isotropic Gaussian vectors.

    Args:
        num_samples: Number of vectors drawn (once, at construction).
        d: Dimension of each vector.
        var: Per-coordinate variance of the samples; must be non-negative.

    Raises:
        ValueError: If ``var`` is negative.
    """

    def __init__(self, num_samples: int, d: int, var: float = 1.0) -> None:
        super().__init__()
        if var < 0:
            raise ValueError(f"var must be non-negative, got {var}")
        self.d = d
        self.num_samples = num_samples
        # randn has unit variance, so scale by the standard deviation sqrt(var);
        # multiplying by var itself would give variance var**2.
        self.data = torch.randn(num_samples, d) * var ** 0.5

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Tensor:
        """Return the ``idx``-th sample, a tensor of shape ``(d,)``."""
        return self.data[idx]

In [None]:
class CanonicalModel(nn.Module):
    """Permutation-invariant wrapper based on canonicalization.

    Each input is replaced by its entries sorted in descending order along the
    last dimension, a representative that is the same for every coordinate
    permutation of the input, and the result is passed to ``model``.

    Args:
        d: Input dimension (stored for reference).
        model: Network applied to the canonicalized input; moved to ``device``.
        device: Device the wrapped model is placed on.
    """

    def __init__(self, d: int, model: nn.Module, device: torch.device) -> None:
        super().__init__()
        self.d = d
        self.model = model.to(device)
        self.device = device

    def forward(self, x: Tensor) -> Tensor:
        # torch.sort's ``out=`` expects a (values, indices) tuple, so sort out of
        # place and keep only the values; this also leaves the caller's x untouched.
        canonical_x = torch.sort(x, dim=-1, descending=True).values
        return self.model(canonical_x)

In [None]:
class SymmetryModel(nn.Module):
    """Wrapper that symmetrizes ``model`` by averaging over permuted inputs.

    ``forward`` stacks ``perm(x)`` for every permutation, applies ``model`` to the
    stack, and averages over the permutation axis. When ``perms`` is a whole
    group, the result is invariant to those permutations of ``x``.

    Args:
        d: Input dimension (stored for reference).
        perms: Permutations to average over; any iterable, including a one-shot
            generator, is accepted and stored as a list.
        model: Network applied to each permuted copy; moved to ``device``.
        device: Device the wrapped model is placed on.

    Raises:
        ValueError: If ``perms`` is empty.
    """

    def __init__(self, d: int, perms: Iterator[Permutation], model: nn.Module, device: torch.device) -> None:
        super().__init__()
        self.d = d
        self.model = model.to(device)
        self.device = device
        # Materialise: a generator (e.g. from create_all_permutations) would be
        # exhausted after the first forward pass, leaving nothing to stack.
        self.perms = list(perms)
        if not self.perms:
            raise ValueError("perms must contain at least one permutation")

    def forward(self, x: Tensor) -> Tensor:
        permuted_x = torch.stack([perm(x) for perm in self.perms])
        outputs = self.model(permuted_x)
        # Average over the leading (permutation) axis introduced by torch.stack.
        return torch.mean(outputs, dim=0)

In [None]:
class IntrinsicModel(nn.Module):
    """Placeholder for a model that is permutation-invariant by construction.

    NOTE(review): no architecture is defined yet. ``__init__`` stores only ``d``
    and ``device``; there is no ``model`` or ``perms`` attribute for ``forward``
    to use, so ``forward`` raises until an implementation is added.
    """

    def __init__(self, d: int, device: torch.device) -> None:
        super().__init__()
        self.d = d
        self.device = device

    def forward(self, x: Tensor) -> Tensor:
        # Fail with a clear message instead of an AttributeError on the
        # never-assigned self.perms / self.model.
        raise NotImplementedError("IntrinsicModel.forward is not implemented yet")

In [None]:
class LinearEquivariant(nn.Module):
    """Permutation-equivariant linear layer on R^n, with n = ``in_features``.

    Computes ``x @ P + bias`` where ``P = theta1 * J + theta2 * (J - I)``
    (``J`` all-ones, ``I`` identity): diagonal ``theta1``, off-diagonal
    ``theta1 + theta2``. Any matrix in span{I, J} commutes with every
    permutation matrix, and the bias is a single scalar shared by all
    coordinates, so permuting the input's last dimension permutes the output's
    last dimension the same way.

    Args:
        in_features: Size n of the input's last dimension.
        out_features: Stored but currently unused; the output's last dimension
            has ``in_features`` entries.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features  # TODO: USE OUT FEATURES AS WELL, WE WANT TO BE EQUIVARIANT FROM R^M TO R^N
        # nn.Parameter so the scalars appear in .parameters(), get gradients and
        # follow .to(device). The bias is one shared scalar: a per-coordinate bias
        # vector would break equivariance as soon as it is trained.
        self.bias = nn.Parameter(torch.ones(1))
        self.theta1 = nn.Parameter(torch.ones(1))
        self.theta2 = nn.Parameter(torch.ones(1))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer to ``x`` of shape ``(..., in_features)``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``in_features``.
        """
        if x.shape[-1] != self.in_features:
            raise ValueError(f"expected last dim {self.in_features}, got {x.shape[-1]}")
        # x @ P without materialising P (works for 1-D and batched x, on any device):
        # (x @ P)_j = theta1 * sum(x) + theta2 * (sum(x) - x_j)
        #           = (theta1 + theta2) * sum(x) - theta2 * x_j
        total = x.sum(dim=-1, keepdim=True)
        return (self.theta1 + self.theta2) * total - self.theta2 * x + self.bias