In [None]:
import math
import torch
from torch import Tensor
import torch.nn as nn

In [None]:
# Everything in this notebook runs on the CPU; this handle is used by
# CanonicalModel and is passed to the trainers further down.
device = torch.device("cpu")

In [None]:
class Permutation(nn.Module):
    """Fixed index permutation that can be applied to tensors.

    The permutation is stored as an index tensor ``perm``. Instances compare by
    value and are hashable, so they can be stored in sets and used as dict keys.
    """

    def __init__(self, perm: torch.Tensor) -> None:
        """
        Args:
            perm (Tensor): Index tensor describing the permutation. A detached
                copy is stored, so later changes to the argument have no effect.
        """
        super().__init__()
        self.perm = perm.clone().detach()
        # Computed once from the index values; equal permutations hash equally.
        self.hash = hash(tuple(perm.tolist()))

    def forward(self, x: Tensor) -> Tensor:
        """
        Reorder ``x`` with the stored indices.

        Args:
            x (Tensor): A 1-D tensor is indexed along its only axis; any other
                input is indexed along axis 1 (axis 0 is left untouched).

        Returns:
            Tensor: The reordered tensor.
        """
        if x.ndim == 1:
            return x[self.perm]
        return x[:, self.perm]

    def __len__(self) -> int:
        return len(self.perm)

    def __eq__(self, other: object) -> bool:
        """Return True only for a ``Permutation`` with identical indices."""
        if not isinstance(other, Permutation):
            return False
        # Compare shapes first: an element-wise ``==`` on different lengths would
        # either raise or silently broadcast (e.g. length 1 against length n).
        if self.perm.shape != other.perm.shape:
            return False
        # Convert to a plain bool so callers never receive a Tensor.
        return bool(torch.all(self.perm == other.perm))

    def __ne__(self, value: object) -> bool:
        return not self.__eq__(value)

    def __hash__(self) -> int:
        return self.hash

In [None]:
import itertools
from collections import deque
from typing import Iterator


def create_all_permutations(perm_length: int) -> Iterator[Permutation]:
    """
    Lazily enumerate every permutation of ``range(perm_length)``.

    Args:
        perm_length (int): Number of elements being permuted.

    Yields:
        Permutation: One module per ordering, in the order produced by
        ``itertools.permutations``.
    """
    indices = range(perm_length)
    for ordering in itertools.permutations(indices):
        index_tensor = torch.tensor(ordering, dtype=torch.long)
        yield Permutation(index_tensor)


def create_permutations_from_generators(generators: list[Permutation]) -> Iterator[Permutation]:
    """
    Enumerate the permutations reachable from the identity by repeatedly
    composing with the given generators (breadth-first).

    Args:
        generators (list[Permutation]): Generating permutations. The length of
            the first one determines the size of the identity permutation.

    Yields:
        Permutation: The identity first, then every newly discovered
        permutation exactly once, in breadth-first discovery order.
    """
    identity = Permutation(torch.arange(len(generators[0])))
    seen = {identity}
    pending = deque((identity,))

    yield identity

    while len(pending) > 0:
        base = pending.popleft()
        for generator in generators:
            # Composition: reorder the generator's index tensor with ``base``.
            candidate = Permutation(base.forward(generator.perm))
            if candidate in seen:
                continue
            seen.add(candidate)
            pending.append(candidate)
            yield candidate

In [None]:
class CanonicalModel(nn.Module):
    """Wrapper that sorts its input before handing it to the wrapped model.

    The input is sorted in descending order along its last axis, so the wrapped
    model always sees the same arrangement regardless of the original order of
    the values on that axis.
    """

    def __init__(self, model: nn.Module) -> None:
        """
        Args:
            model (nn.Module): Network applied to the sorted input. It is moved
                to the module-level ``device``.
        """
        super().__init__()
        self.model = model.to(device)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of any rank.

        Returns:
            Tensor: Output of the wrapped model on the sorted input.
        """
        # NOTE(review): sorting happens on the LAST axis; Permutation in this
        # file reorders axis 1 of batched inputs -- confirm this is the intended axis.
        canonical, _ = torch.sort(x, dim=-1, descending=True)
        return self.model(canonical)

In [None]:
class SymmetryModel(nn.Module):
    """Symmetrise a model by averaging its outputs over a set of permutations."""

    def __init__(self, model: nn.Module, perms: "Iterator[Permutation]") -> None:
        """
        Args:
            model (nn.Module): Network evaluated on every permuted copy of the input.
            perms: Permutations (callables mapping a tensor to a reordered
                tensor) to average over. Any iterable is accepted.
        """
        super().__init__()
        self.model = model
        # Materialise once: a generator/iterator would otherwise be exhausted
        # after the first forward pass and every later call would see no
        # permutations at all.
        self.perms = list(perms)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Either a single 1-D sample, or a batch whose first axis
                is the batch axis.

        Returns:
            Tensor: Model output averaged over all permutations. For batched
            input the batch axis is preserved.
        """
        # shape (num_perms, *x.shape)
        permuted_x = torch.stack([perm(x) for perm in self.perms])

        if x.ndim == 1:
            # The permutation axis already plays the role of the batch axis.
            return torch.mean(self.model(permuted_x), dim=0)

        # Fold the permutation axis into the batch axis so the wrapped model
        # receives an ordinary batch (same rank as ``x``) instead of a tensor
        # with an extra leading dimension.
        num_perms, batch_size = permuted_x.shape[:2]
        folded = permuted_x.reshape(num_perms * batch_size, *permuted_x.shape[2:])
        outputs = self.model(folded)

        # shape (num_perms, batch_size, ...) -> average over the permutations
        outputs = outputs.reshape(num_perms, batch_size, *outputs.shape[1:])
        return torch.mean(outputs, dim=0)

In [None]:
class LinearEquivariant(nn.Module):
    """Linear layer that is equivariant to permutations of the set axis.

    For every (input channel, output channel) pair the layer computes
    ``alpha * x + beta * sum_over_set(x) + b`` and then averages the result
    over the input channels.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """
        Initializes a LinearEquivariant module.
        This module is a custom linear layer that is equivariant to permutations of the input.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # One bias / per-element weight / sum weight per channel pair.
        self.b = torch.nn.Parameter(torch.randn(in_channels, out_channels))
        self.alpha = torch.nn.Parameter(torch.randn(in_channels, out_channels))
        self.beta = torch.nn.Parameter(torch.randn(in_channels, out_channels))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the LinearEquivariant module.

        Args:
            x (Tensor): Input tensor of shape (batch_size, d, in_channels).

        Returns:
            Tensor: Output tensor of shape (batch_size, d, out_channels).
        """

        assert x.ndim == 3
        assert x.shape[-1] == self.in_channels

        # shape (batch_size, d, in_channels, 1)
        x = x.unsqueeze(-1)

        # shape (batch_size, 1, in_channels, 1)
        x_sum = torch.sum(x, dim=1, keepdim=True)

        # shape (batch_size, d, in_channels, out_channels)
        per_pair = x * self.alpha + x_sum * self.beta + self.b

        # Average (rather than sum) over the input channels.
        # shape (batch_size, d, out_channels)
        reduced = torch.mean(per_pair, dim=2)

        return reduced

    def _forward_manual(self, x: Tensor) -> Tensor:
        """
        Performs the forward pass manually using hard-coded loops.
        * FOR TESTING PURPOSES ONLY. DO NOT USE IN PRODUCTION.

        Args:
            x (Tensor): Input tensor of shape (batch_size, hdim, in_channels).

        Returns:
            Tensor: Output tensor of shape (batch_size, hdim, out_channels).
        """

        assert x.ndim == 3
        assert x.shape[-1] == self.in_channels

        batch_size, hdim, in_channels = x.shape
        out_channels = self.out_channels

        # shape (batch_size, in_channels)
        x_sum = torch.sum(x, dim=1)

        # Allocate the accumulator with the input's dtype and device so the
        # reference result matches ``forward`` for float64 / non-CPU inputs.
        # shape (batch_size, hdim, out_channels)
        result = x.new_zeros(batch_size, hdim, out_channels)

        for b in range(batch_size):
            for c_out in range(out_channels):
                for c_in in range(in_channels):
                    alpha = self.alpha[c_in, c_out]
                    beta = self.beta[c_in, c_out]
                    bias = self.b[c_in, c_out]
                    res = x[b, :, c_in] * alpha + x_sum[b, c_in] * beta + bias
                    result[b, :, c_out] += res

        # Same averaging over the input channels as in ``forward``.
        result = result / in_channels
        return result


class LinearInvariant(nn.Module):
    """Linear read-out that is invariant to permutations of the set axis.

    The set axis is summed out first; every (input channel, output channel)
    pair then contributes ``alpha * sum + b`` and the contributions are
    averaged over the input channels.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """
        Build the invariant layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # One bias and one weight per (input channel, output channel) pair.
        self.b = torch.nn.Parameter(torch.randn(in_channels, out_channels))
        self.alpha = torch.nn.Parameter(torch.randn(in_channels, out_channels))

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the invariant layer.

        Args:
            x (Tensor): Input tensor of shape (batch_size, d, in_channels).

        Returns:
            Tensor: Output tensor of shape (batch_size, 1, out_channels).
        """
        assert x.ndim == 3
        assert x.shape[-1] == self.in_channels

        # (batch_size, d, in_channels, 1) -> summed over d:
        # (batch_size, 1, in_channels, 1)
        set_totals = x.unsqueeze(-1).sum(dim=1, keepdim=True)

        # (batch_size, 1, in_channels, out_channels)
        per_pair = set_totals * self.alpha + self.b

        # Average over the input channels: (batch_size, 1, out_channels)
        return per_pair.mean(dim=2)

In [None]:
class LinearMultiChannel(nn.Module):
    """Dense layer with a separate weight matrix per channel pair.

    Every (input channel, output channel) pair owns an
    ``out_features x in_features`` weight matrix and a bias tensor of the same
    shape; the per-pair results are averaged over the input channels.
    """

    def __init__(self, in_channels: int, out_channels: int, in_features: int, out_features: int) -> None:
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            in_features (int): Number of input features.
            out_features (int): Number of output features.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_features = in_features
        self.out_features = out_features
        self.b = torch.nn.Parameter(torch.randn(in_channels, out_channels, out_features, in_features))
        self.weights = torch.nn.Parameter(torch.randn(in_channels, out_channels, out_features, in_features))

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (batch_size, in_features, in_channels).

        Returns:
            Tensor: Output tensor of shape (batch_size, out_features, out_channels).
        """
        assert x.ndim == 3
        assert x.shape[1] == self.in_features
        assert x.shape[2] == self.in_channels

        # (batch_size, in_channels, in_features)
        channels_first = x.transpose(1, 2)

        # One column vector per input channel:
        # (batch_size, in_channels, 1, in_features, 1)
        columns = channels_first.view(-1, self.in_channels, 1, self.in_features, 1)

        # Batched matrix-vector product:
        # (batch_size, in_channels, out_channels, out_features, 1)
        projected = torch.matmul(self.weights, columns)

        # The bias carries a trailing in_features axis, so the sum broadcasts to
        # (batch_size, in_channels, out_channels, out_features, in_features).
        biased = projected + self.b

        # Average over input channels and the trailing axis:
        # (batch_size, out_channels, out_features)
        pooled = biased.mean(dim=(1, 4))

        # (batch_size, out_features, out_channels)
        return pooled.transpose(1, 2)

In [None]:
class GaussianDataset(torch.utils.data.Dataset):
    """Binary dataset of zero-mean Gaussian samples that differ only in variance.

    Labels are drawn uniformly from {0, 1} once, at construction time. Samples
    with label 0 have variance ``var1``; samples with label 1 have variance
    ``var2``. In static mode all samples are drawn up front; otherwise a fresh
    sample is drawn on every access.
    """

    def __init__(
        self,
        num_samples: int,
        shape: tuple[int, ...],
        var1: float = 1.0,
        var2: float = 0.8,
        static: bool = True,
    ) -> None:
        """
        Args:
            num_samples (int): Number of (sample, label) pairs.
            shape (tuple[int, ...]): Shape of a single sample.
            var1 (float): Variance of samples labelled 0.
            var2 (float): Variance of samples labelled 1.
            static (bool): If True, pre-generate and store all samples.
        """
        super().__init__()
        self.shape = shape
        self.var1 = var1
        self.var2 = var2
        self.static = static

        self.labels = torch.randint(0, 2, (num_samples,)).to(torch.float32)

        if not static:
            # Dynamic mode stores labels only; samples are drawn in __getitem__.
            return

        # Per-sample variance chosen by label, shape (num_samples,)
        per_sample_var = var1 * (self.labels == 0) + var2 * (self.labels == 1)
        # Reshape to (num_samples, 1, ..., 1) so it broadcasts over each sample.
        per_sample_std = torch.sqrt(per_sample_var).reshape((num_samples,) + (1,) * len(shape))
        self.data = torch.randn(num_samples, *shape) * per_sample_std

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """Return the ``(sample, label)`` pair at position ``idx``."""
        target = self.labels[idx]

        if not self.static:
            # Draw a fresh sample with the variance that matches the stored label.
            variance = self.var1 if target == 0 else self.var2
            sample = torch.randn(self.shape) * math.sqrt(variance)
            return sample, target

        return self.data[idx], target

In [None]:
from torch.utils.data import DataLoader

# Every sample has shape (N, D). The networks built further down use
# in_channels=D (and, for the multi-channel layers, in_features=N).
N = 10
D = 5
data_size = 1000

# 80/20 split of the sample budget between training and test.
train_size = int(0.8 * data_size)
test_size = data_size - train_size

# static=False: labels are fixed at construction, but the training samples are
# re-drawn on every access.
ds_train = GaussianDataset(
    num_samples=train_size,
    shape=(N, D),
    var1=1.0,
    var2=0.8,
    static=False,
)

# static=True: the test samples are generated once and stay fixed.
ds_test = GaussianDataset(
    num_samples=test_size,
    shape=(N, D),
    var1=1.0,
    var2=0.8,
    static=True,
)

# Both loaders iterate in index order (shuffle=False) with batches of 32.
dl_train = DataLoader(
    dataset=ds_train,
    batch_size=32,
    shuffle=False,
)

dl_test = DataLoader(
    dataset=ds_test,
    batch_size=32,
    shuffle=False,
)

In [None]:
from training import BinaryTrainer

# TODO: training does not work yet.
# None of the networks below exhibit any form of learning.

In [None]:
# Experiment 1: two equivariant layers with ReLU activations, followed by an
# invariant read-out and a sigmoid that maps the result to (0, 1).
model = nn.Sequential(
    LinearEquivariant(in_channels=D, out_channels=10),
    nn.ReLU(),
    LinearEquivariant(in_channels=10, out_channels=10),
    nn.ReLU(),
    LinearInvariant(in_channels=10, out_channels=1),
    nn.Sigmoid(),
)

# BinaryTrainer comes from the project-local ``training`` module; its behaviour
# is not visible here.
# NOTE(review): the model output has shape (batch, 1, 1) while the dataset
# labels are scalars per sample -- presumably the trainer reconciles the
# shapes before calling BCELoss; confirm.
trainer = BinaryTrainer(
    model=model,
    loss_fn=nn.BCELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    device=device,
    log=True,
)

trainer.fit(
    dl_train=dl_train,
    dl_test=dl_test,
    num_epochs=300,
    print_every=10,
)

In [None]:
# Experiment 2: an unconstrained multi-channel network. The last layer reduces
# the N input features to a single output feature and a single channel.
layers = nn.Sequential(
    LinearMultiChannel(in_channels=D, out_channels=10, in_features=N, out_features=N),
    nn.ReLU(),
    LinearMultiChannel(in_channels=10, out_channels=10, in_features=N, out_features=N),
    nn.ReLU(),
    LinearMultiChannel(in_channels=10, out_channels=1, in_features=N, out_features=1),
    nn.Sigmoid(),
)

# CanonicalModel sorts the input along its last axis before calling ``layers``.
model = CanonicalModel(layers)

# BinaryTrainer comes from the project-local ``training`` module; its behaviour
# is not visible here.
trainer = BinaryTrainer(
    model=model,
    loss_fn=nn.BCELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    device=device,
    log=True,
)

trainer.fit(
    dl_train=dl_train,
    dl_test=dl_test,
    num_epochs=300,
    print_every=10,
)

In [None]:
# Experiment 3: an unconstrained multi-channel network whose outputs are
# averaged over a set of input permutations by SymmetryModel.
layers = nn.Sequential(
    LinearMultiChannel(in_channels=D, out_channels=10, in_features=N, out_features=N),
    nn.ReLU(),
    LinearMultiChannel(in_channels=10, out_channels=10, in_features=N, out_features=N),
    nn.ReLU(),
    LinearMultiChannel(in_channels=10, out_channels=1, in_features=N, out_features=1),
    nn.Sigmoid(),
)

# Permutation reorders axis 1 of a batched input, which has length N for
# samples of shape (N, D), so the permutations must have length N.
# NOTE(review): the only generator is the identity, so the generated set holds
# the identity alone -- confirm which generators were intended.
perms = list(create_permutations_from_generators([Permutation(torch.arange(N))]))

model = SymmetryModel(layers, perms)

# BinaryTrainer comes from the project-local ``training`` module; its behaviour
# is not visible here.
trainer = BinaryTrainer(
    model=model,
    loss_fn=nn.BCELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    device=device,
)

trainer.fit(dl_train=dl_train, dl_test=dl_test, num_epochs=100)

## Challenges encountered during Implementation:

### Numeric Errors:

The first challenge we encountered was in the implementation of the invariant and equivariant layers.
The main difficulty arose from the fact that, in the lecture, the equivariant layer is formulated as follows:

$$ F(x) : \mathbb{R}^{n \times d} \rightarrow \mathbb{R}^{n \times d'} $$

$$ F(x)_j = \sum _{i=1} ^ {d} L_{ij}(x) $$ 
where $L_{ij}(x)$ is a single feature linear equivariant layer.

Technically, this implementation is indeed correct, but the summation over all $L_{ij}(x)$ might cause the layer outputs to blow up.  
As a result, the outputs of the composition $F \circ a \circ F \circ \dots$ become very large.

Our network is composed of these layers, $\phi \circ F \circ a \circ F \circ \dots$, where $\phi$ is the sigmoid function, which returns values between 0 and 1.

Since the last layer of the network is a sigmoid function, and the outputs of the previous layers are very large in absolute value, the sigmoid function saturates and returns either 0.0 or 1.0. Because the sigmoid is saturated, the propagated gradients become 0, hence the network does not learn.

To resolve this issue we defined the equivariant layer as follows:

$$ F(x)_j = \frac{1}{d} \sum _{i=1} ^ {d} L_{ij}(x) $$ 

This formulation still retains the equivariance property, but it prevents the layer outputs from blowing up.

*Note: We applied the same averaging technique to the invariant layers as well.*

### Overfitting:

Another big issue we encountered was overfitting. To overcome it, we added an option to dynamically generate the data every time the `Dataset` is accessed. 
This way, the model never sees the same data twice and is not able to overfit. That indeed completely resolved the overfitting issue.

To compare across different dataset sizes, we trained all our models for the same number of epochs with dynamically generated data.