In [None]:
import math
import torch
from torch import Tensor
import torch.nn as nn


In [None]:
# Single device handle for the whole notebook (CPU). It is read by CanonicalModel
# and passed to every BinaryTrainer created below.
device = torch.device("cpu")

In [None]:
class Permutation(nn.Module):
    """A fixed index permutation that can be applied to tensors and used as a set/dict key.

    The index tensor is stored as a plain attribute (it is not registered as a
    parameter or buffer). Instances compare equal when their index tensors have
    the same shape and the same entries, and hash by those entries, so they can
    be de-duplicated in sets.

    Args:
        perm: Index tensor describing the permutation. A detached copy is kept,
            so later changes to the caller's tensor do not affect this object.
    """

    def __init__(self, perm: torch.Tensor) -> None:
        super().__init__()
        self.perm = perm.clone().detach()
        # Hash is computed once from the index values; `self.perm` must not be
        # mutated afterwards or equality and hash would disagree.
        self.hash = hash(tuple(perm.tolist()))

    def forward(self, x: Tensor) -> Tensor:
        """Reorder the entries of ``x`` according to the stored indices.

        A 1-D input is indexed along its only dimension; any other input is
        indexed along dimension 1 (the feature axis of a ``(batch, features)``
        tensor).
        """
        if x.ndim == 1:
            return x[self.perm]
        return x[:, self.perm]

    def __len__(self) -> int:
        """Return the number of indices in the permutation."""
        return len(self.perm)

    def __eq__(self, other: object) -> bool:
        """Return True iff ``other`` is a Permutation with identical indices.

        Always returns a real ``bool``. Permutations of different sizes compare
        unequal instead of raising on the non-broadcastable elementwise compare.
        """
        if not isinstance(other, Permutation):
            return False
        if self.perm.shape != other.perm.shape:
            return False
        return bool(torch.all(self.perm == other.perm))

    def __ne__(self, value: object) -> bool:
        """Logical negation of ``__eq__``."""
        return not self.__eq__(value)

    def __hash__(self) -> int:
        """Return the hash precomputed from the index values at construction."""
        return self.hash

In [None]:
import itertools
from collections import deque
from typing import Iterator


def create_all_permutations(perm_length: int) -> Iterator[Permutation]:
    """Lazily yield one ``Permutation`` for every ordering of ``range(perm_length)``.

    Orderings are produced in the order given by ``itertools.permutations`` and
    each is wrapped as a ``torch.long`` index tensor.
    """
    indices = range(perm_length)
    yield from (
        Permutation(torch.tensor(ordering, dtype=torch.long))
        for ordering in itertools.permutations(indices)
    )


def create_permutations_from_generators(generators: list[Permutation]) -> Iterator[Permutation]:
    """Lazily enumerate every permutation reachable by composing the generators.

    Starts from the identity permutation (sized like ``generators[0]``) and
    explores breadth-first: each dequeued permutation is combined with every
    generator, and any result not seen before is recorded, enqueued and
    yielded. The identity is always yielded first.

    Args:
        generators: Non-empty list of generating permutations.

    Yields:
        Each distinct reachable permutation exactly once, in discovery order.
    """
    size = len(generators[0])
    identity = Permutation(torch.arange(size))
    seen = {identity}
    pending = deque([identity])

    yield identity

    while pending:
        base = pending.popleft()
        for gen in generators:
            # Applying `base` to the generator's index tensor gives the
            # index tensor of the combined permutation.
            candidate = Permutation(base.forward(gen.perm))
            if candidate in seen:
                continue
            seen.add(candidate)
            pending.append(candidate)
            yield candidate

In [None]:
class CanonicalModel(nn.Module):
    """Wrap a model so it only ever sees inputs sorted along the last dimension.

    Every reordering of the same values along the last dimension sorts to the
    same tensor, so the wrapped model's output does not depend on that order.

    Args:
        model: Network applied to the sorted input. It is moved to the
            notebook-level ``device`` at construction.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        placed = model.to(device)
        self.model = placed

    def forward(self, x: Tensor) -> Tensor:
        """Sort ``x`` in descending order along the last dimension, then apply the model."""
        canonical, _ = torch.sort(x, dim=-1, descending=True)
        return self.model(canonical)

In [None]:
class SymmetryModel(nn.Module):
    """Average a model's output over a fixed collection of input permutations.

    Args:
        model: Network applied to every permuted copy of the input.
        perms: Permutations (callables mapping a tensor to a tensor) to average
            over. Any iterable is accepted, including a one-shot generator; it
            is consumed once at construction.
    """

    def __init__(self, model: nn.Module, perms: Iterator["Permutation"]) -> None:
        super().__init__()
        self.model = model
        # Materialise the permutations: a one-shot iterator would be exhausted
        # by the first forward pass, leaving nothing to stack on later calls.
        self.perms = list(perms)

    def forward(self, x: Tensor) -> Tensor:
        """Apply every permutation to ``x``, run the model, and average the results.

        The permuted copies are stacked along a new leading dimension, passed
        through the model in one call, and the mean is taken over that leading
        dimension.
        """
        permuted_x = torch.stack([perm(x) for perm in self.perms])
        outputs = self.model(permuted_x)
        return torch.mean(outputs, dim=0)

In [None]:
# TODO ADD in_features and out_features params
class LinearEquivariant(nn.Module):
    """Scalar-weight layer computing ``beta * x + alpha * sum(x) + b``.

    The sum is taken over the last dimension and broadcast back, so the output
    has the same shape as ``x``. All three weights are single-element
    parameters initialised from a standard normal.
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order (b, alpha, beta) fixes both the RNG draws and the
        # parameter registration order.
        self.b = nn.Parameter(torch.randn(1))
        self.alpha = nn.Parameter(torch.randn(1))
        self.beta = nn.Parameter(torch.randn(1))

    def forward(self, x: Tensor) -> Tensor:
        """Return the per-element term plus the broadcast pooled term plus the bias."""
        elementwise = self.beta * x
        pooled = self.alpha * x.sum(dim=-1, keepdim=True)
        return elementwise + pooled + self.b


# TODO ADD in_features and out_features params
class LinearInvariant(nn.Module):
    """Scalar-weight layer computing ``alpha * sum(x) + b``.

    The sum is taken over the last dimension, which is kept with size 1. Both
    weights are single-element parameters initialised from a standard normal.
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order (b, alpha) fixes both the RNG draws and the
        # parameter registration order.
        self.b = nn.Parameter(torch.randn(1))
        self.alpha = nn.Parameter(torch.randn(1))

    def forward(self, x: Tensor) -> Tensor:
        """Pool ``x`` over its last dimension, then scale and shift the result."""
        pooled = x.sum(dim=-1, keepdim=True)
        return self.alpha * pooled + self.b

In [None]:
# TODO: GENERATE SYNTHETIC DATA IN REAL TIME TO AVOID OVERFITTING ON TRAIN
class GaussianDataset(torch.utils.data.Dataset):
    """In-memory binary dataset of zero-mean Gaussian vectors whose scale depends on the label.

    Each sample gets a label drawn from {0, 1}; its ``dim`` features are
    standard-normal draws multiplied by ``sqrt(var1)`` when the label is 0 and
    by ``sqrt(var2)`` otherwise. Everything is generated once at construction.

    Args:
        length: Number of samples.
        dim: Number of features per sample.
        var1: Variance of the features for label 0.
        var2: Variance of the features for label 1.
    """

    def __init__(self, length: int, dim: int, var1: float = 1.0, var2: float = 0.8) -> None:
        super().__init__()

        # Labels are drawn before features; keeping this order keeps seeded
        # runs reproducible.
        labels = torch.randint(0, 2, (length, 1))
        samples = torch.randn(length, dim)

        # Per-sample standard deviation, shape (length, 1), broadcast over features.
        is_class_zero = labels == 0
        std = is_class_zero * math.sqrt(var1) + ~is_class_zero * math.sqrt(var2)

        self.data = (samples * std).float()
        self.labels = labels.float()

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.data.shape[0]

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """Return the ``(features, label)`` pair stored at ``idx``."""
        features = self.data[idx]
        label = self.labels[idx]
        return features, label

In [None]:
from torch.utils.data import DataLoader

# Synthetic-data configuration: total number of samples and features per sample.
data_size = 1000
hidden_dim = 5

# 80/20 split of the sample budget between training and evaluation.
train_size = int(0.8 * data_size)
test_size = data_size - train_size

# Two independently generated datasets; the training set is drawn first.
ds_train, ds_test = (
    GaussianDataset(n_samples, hidden_dim, var1=1.0, var2=0.8)
    for n_samples in (train_size, test_size)
)

# Both loaders serve shuffled mini-batches of 32 samples.
dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)
dl_test = DataLoader(ds_test, batch_size=32, shuffle=True)

In [None]:
from training import BinaryTrainer

# TODO: training does not work yet.
# None of the networks below exhibit any form of learning; the cause still needs to be investigated.

In [None]:
# Experiment 1: network built only from the permutation-symmetric layers.
# Two LinearEquivariant layers (each followed by ReLU) feed a LinearInvariant
# layer; the final sigmoid maps the output into (0, 1) for nn.BCELoss.
model = nn.Sequential(
    LinearEquivariant(),
    nn.ReLU(),
    LinearEquivariant(),
    nn.ReLU(),
    LinearInvariant(),
    nn.Sigmoid(),
)

# BinaryTrainer comes from the project-local `training` module; its exact
# behaviour is not visible in this notebook.
trainer = BinaryTrainer(
    model=model,
    loss_fn=nn.BCELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    device=device,
)

# 100 epochs, with print_every=1.
trainer.fit(
    dl_train=dl_train,
    dl_test=dl_test,
    num_epochs=100,
    print_every=1,
)

In [None]:
# Experiment 2: plain MLP (hidden_dim -> 64 -> 64 -> 1, sigmoid output) that is
# made order-independent by sorting its input first.
layers = nn.Sequential(
    nn.Linear(hidden_dim, 64, bias=True),
    nn.ReLU(),
    nn.Linear(64, 64, bias=True),
    nn.ReLU(),
    nn.Linear(64, 1, bias=True),
    nn.Sigmoid(),
)

# CanonicalModel sorts each input along its last dimension before calling `layers`.
# NOTE: this rebinds the notebook-level name `model` used by the previous cell.
model = CanonicalModel(layers)

# BinaryTrainer comes from the project-local `training` module; the meaning of
# `log=True` is defined there and is not visible in this notebook.
trainer = BinaryTrainer(
    model=model,
    loss_fn=nn.BCELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    device=device,
    log=True,
)

# 1000 epochs, with print_every=10.
trainer.fit(
    dl_train=dl_train,
    dl_test=dl_test,
    num_epochs=1000,
    print_every=10,
)

In [None]:
# Experiment 3: small MLP (sigmoid output) whose predictions are averaged over
# every permutation of the input features.
layers = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim, bias=True),
    nn.ReLU(),
    nn.Linear(hidden_dim, 1, bias=True),
    nn.Sigmoid(),
)

# Generators of the full symmetric group on `hidden_dim` elements: a swap of
# the first two positions and a cyclic shift of all positions. Using only the
# identity permutation as a generator produces just the identity, in which case
# SymmetryModel averages over a single unpermuted pass and adds no symmetry.
# Requires hidden_dim >= 2.
generators = [
    Permutation(torch.tensor([1, 0, *range(2, hidden_dim)])),
    Permutation(torch.tensor([*range(1, hidden_dim), 0])),
]
perms = list(create_permutations_from_generators(generators))

# Sanity check: the two generators must reach all hidden_dim! permutations.
assert len(perms) == math.factorial(hidden_dim), len(perms)

model = SymmetryModel(layers, perms)

# BinaryTrainer comes from the project-local `training` module; its exact
# behaviour is not visible in this notebook.
trainer = BinaryTrainer(
    model=model,
    loss_fn=nn.BCELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    device=device,
)

trainer.fit(dl_train=dl_train, dl_test=dl_test, num_epochs=100)