In [None]:
import math
import torch
from torch import Tensor
import torch.nn as nn

## Dataset

In [None]:
from src.gaussian_dataset import GaussianDataset
from torch.utils.data import DataLoader

# Fix the global torch RNG before any data or model is created so that a
# "Restart & Run All" reproduces the same datasets, initializations and results.
SEED = 0
torch.manual_seed(SEED)

N = 10  # number of feature vectors per sample
D = 5  # length of each feature vector
data_samples = 1000
BATCH_SIZE = 32

train_size = int(0.8 * data_samples)
test_size = data_samples - train_size

# NOTE(review): per the write-up at the end of the notebook, `static=False` looks like the
# "regenerate data on every access" option, which the write-up says was NOT used
# for the comparative analysis — confirm which setting the reported results used.
ds_train = GaussianDataset(num_samples=train_size, shape=(N, D), var1=1.0, var2=0.8, static=False)

ds_test = GaussianDataset(num_samples=test_size, shape=(N, D), var1=1.0, var2=0.8, static=True)

# Reshuffle the training set every epoch; keep the evaluation order fixed.
dl_train = DataLoader(dataset=ds_train, batch_size=BATCH_SIZE, shuffle=True)

dl_test = DataLoader(dataset=ds_test, batch_size=BATCH_SIZE, shuffle=False)

## Models

### General Purpose

In [None]:
from src.models import test_invariant, test_equivariant

# Device handed to `BinaryTrainer` in `train_model`; training in this notebook runs on the CPU.
device: torch.device = torch.device("cpu")

In [None]:
from src.training import BinaryTrainer
from src.layers import LinearEquivariant, LinearInvariant, PositionalEncoding


def create_mlp_model(n: int, d: int) -> nn.Module:
    """Build a plain (not structurally invariant) MLP binary classifier.

    Each ``(n, d)`` sample is flattened to ``n * d`` features and passed through
    two ReLU hidden layers of width ``10 * d``. A final linear layer followed by
    a sigmoid yields one probability per sample.

    Args:
        n: Number of feature vectors per sample.
        d: Length of each feature vector.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, n, d)`` to ``(batch, 1)``.
    """
    hidden = 10 * d
    layers = [
        nn.Flatten(start_dim=1),
        nn.Linear(in_features=n * d, out_features=hidden),
        nn.ReLU(),
        nn.Linear(in_features=hidden, out_features=hidden),
        nn.ReLU(),
        nn.Linear(in_features=hidden, out_features=1),
        nn.Sigmoid(),
    ]
    return nn.Sequential(*layers)


def create_transformer_model(n: int, d: int) -> nn.Module:
    """Build a one-layer, single-head Transformer-encoder binary classifier.

    The ``n`` input tokens of width ``d`` first go through ``PositionalEncoding``
    (a project layer that presumably injects token positions). On its own, the
    model is therefore not expected to be permutation invariant; the notebook
    wraps it in canonicalization/symmetrization. The encoded tokens are
    flattened and fed to a linear layer plus sigmoid.

    Args:
        n: Number of tokens (feature vectors) per sample; also the positional-encoding length.
        d: Token width (``d_model``).

    Returns:
        An ``nn.Sequential`` producing one probability per sample.
    """
    # Sub-modules are created in forward order.
    positional = PositionalEncoding(d_model=d, max_len=n)
    encoder = nn.TransformerEncoder(
        encoder_layer=nn.TransformerEncoderLayer(batch_first=True, d_model=d, nhead=1),
        norm=nn.LayerNorm(normalized_shape=d),
        num_layers=1,
    )
    head = [
        nn.Flatten(start_dim=1),
        nn.Linear(in_features=n * d, out_features=1),
        nn.Sigmoid(),
    ]
    return nn.Sequential(positional, encoder, *head)


def create_invariant_model(n: int, d: int) -> nn.Module:
    """Build an intrinsically permutation-invariant binary classifier.

    The model stacks two ``LinearEquivariant`` layers (``d -> 10 -> 10``
    channels, each followed by a ReLU). A ``LinearInvariant`` layer then
    reduces to a single channel, and a sigmoid produces the probability.

    Args:
        n: Unused. The layers are parameterized only by channel counts, so
            presumably any number of feature vectors is accepted.
        d: Length of each feature vector (input channels).

    Returns:
        An ``nn.Sequential`` producing one probability per sample.
    """
    width = 10
    return nn.Sequential(
        LinearEquivariant(in_channels=d, out_channels=width),
        nn.ReLU(),
        LinearEquivariant(in_channels=width, out_channels=width),
        nn.ReLU(),
        LinearInvariant(in_channels=width, out_channels=1),
        nn.Sigmoid(),
    )

In [None]:
def train_model(
    model: nn.Module,
    epochs: int = 200,
    log: bool = True,
    lr: float = 1e-3,
    print_every: int = 25,
    early_stopping: int = 100,
) -> None:
    """Train ``model`` as a binary classifier on the notebook's train/test loaders.

    Uses binary cross-entropy, since every model factory above ends in a sigmoid,
    together with the Adam optimizer. Reads the module-level ``dl_train``,
    ``dl_test`` and ``device``.

    Args:
        model: Model to train; its parameters are updated in place by Adam.
        epochs: Maximum number of training epochs.
        log: Forwarded to ``BinaryTrainer``; presumably toggles progress logging.
        lr: Adam learning rate.
        print_every: Forwarded to ``BinaryTrainer.fit``; presumably the epoch
            interval between progress reports.
        early_stopping: Forwarded to ``BinaryTrainer.fit``; presumably the
            early-stopping patience in epochs.
    """
    trainer = BinaryTrainer(
        model=model,
        loss_fn=nn.BCELoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=lr),
        device=device,
        log=log,
    )

    trainer.fit(
        dl_train=dl_train,
        dl_test=dl_test,
        num_epochs=epochs,
        print_every=print_every,
        early_stopping=early_stopping,
    )

### Canonization Based

#### MLP-Based

In [None]:
from src.models import CanonicalModel

mlp_backbone = create_mlp_model(N, D)
model = CanonicalModel(mlp_backbone)

# Sanity check on a random batch: canonicalization should make the MLP invariant.
probe_batch = torch.randn(32, N, D)
test_invariant(model, input=probe_batch)

In [None]:
train_model(model=model)

#### Attention-Based

In [None]:
transformer_backbone = create_transformer_model(N, D)
model = CanonicalModel(transformer_backbone)

# Sanity check on a random batch: canonicalization should make the transformer invariant.
probe_batch = torch.randn(32, N, D)
test_invariant(model, input=probe_batch)

In [None]:
train_model(model=model)

### Symmetrization Network

#### MLP-Based

In [None]:
from src.permutation import Permutation, create_all_permutations, create_permutations_from_generators
from src.models import SymmetryModel

# Generator: cyclic shift of the N feature vectors by one position (i -> i + 1 mod N).
# NOTE(review): presumably `create_permutations_from_generators` closes the generators
# under composition, which here gives the N cyclic shifts rather than all of S_N — confirm.
shift_perm = Permutation((torch.arange(N) + 1) % N)

model = SymmetryModel(
    model=create_mlp_model(N, D),
    perm_creator=lambda: create_permutations_from_generators([shift_perm]),
    chunksize=10,
)

test_invariant(model, input=torch.randn(32, N, D))

In [None]:
train_model(model=model)

#### Attention-Based

In [None]:
# Same cyclic-shift generator as in the MLP symmetrization cell, redefined here so
# this cell does not depend on that cell's variable.
shift_perm = Permutation((torch.arange(N) + 1) % N)

model = SymmetryModel(
    model=create_transformer_model(N, D),
    perm_creator=lambda: create_permutations_from_generators([shift_perm]),
    chunksize=10,
)

test_invariant(model, input=torch.randn(32, N, D))

In [None]:
train_model(model=model)

#### MLP-Based Sampled Symmetrization 

In [None]:
# Number of random permutations drawn each time `perm_creator` is called.
# Sampling 5% of all N! orderings would mean int(0.05 * 10!) = 181,440
# permutations for N = 10, which is impractical, so a small fixed budget is used.
num = 10

model = SymmetryModel(
    model=create_mlp_model(N, D),
    # Bind `num` at definition time (default argument) so that reassigning the
    # global `num` later cannot silently change this model's sampling budget.
    perm_creator=lambda num=num: (Permutation(torch.randperm(N)) for _ in range(num)),
    chunksize=10,
)

test_invariant(model, input=torch.randn(32, N, D))

In [None]:
train_model(model=model)

### Intrinsic Invariant

In [None]:
# No wrapper needed: invariance is built into the architecture itself.
model = create_invariant_model(N, D)

test_invariant(model, input=torch.randn(32, N, D))

In [None]:
train_model(model=model)

### Standard with Augmentation

In [None]:
class Augmentation(nn.Module):
    """Random-permutation data augmentation for set-structured inputs.

    In training mode, each sample in the batch is reordered along ``dim`` by its
    own uniformly random permutation. The same permutation is applied across all
    remaining dimensions, so every element (e.g. each length-``d`` feature
    vector) moves intact. In eval mode the input is returned unchanged.

    Args:
        dim: Dimension to permute. Defaults to 1, the channel (set) dimension of
            a ``(batch_size, n, d)`` input. Must not be the batch dimension.
    """

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Randomly permute ``x`` along ``self.dim`` (training mode only).

        Args:
            x (Tensor): Input of shape ``(batch_size, ...)``; in this notebook
                ``(batch_size, n, d)``.

        Returns:
            Tensor: A tensor with the same shape, dtype and device as ``x``.

        Raises:
            ValueError: If ``self.dim`` refers to the batch dimension.
        """
        if not self.training:
            return x

        dim = self.dim % x.dim()
        if dim == 0:
            raise ValueError("Augmentation cannot permute the batch dimension")

        batch_size, size = x.shape[0], x.shape[dim]
        # Argsort of i.i.d. uniform noise yields one uniform permutation per sample.
        perm = torch.rand(batch_size, size, device=x.device).argsort(dim=1)

        # View as (batch, 1, ..., size, ..., 1) and broadcast, so every slice
        # orthogonal to `dim` is reordered by the same per-sample permutation.
        view_shape = [1] * x.dim()
        view_shape[0] = batch_size
        view_shape[dim] = size
        index = perm.view(view_shape).expand_as(x)
        return torch.gather(x, dim, index)

In [None]:
# NOTE(review): this section is titled "Standard with Augmentation", but the
# augmentation is placed in front of the intrinsically invariant model rather than
# a standard (e.g. MLP) model — confirm this is the intended comparison.
augmented_stack = [Augmentation(), create_invariant_model(N, D)]
model = nn.Sequential(*augmented_stack)

test_invariant(model, input=torch.randn(32, N, D))

In [None]:
train_model(model=model)

---
---

### Question 4: Challenges Encountered During Implementation

##### Numeric Errors:

The first challenge we encountered was in the implementation of the invariant and equivariant layers.
The main difficulty arose from the fact that, in the lecture, the equivariant layer is formulated as follows:

$$ F(x) : \mathbb{R}^{n \times d} \rightarrow \mathbb{R}^{n \times d'} $$

$$ F(x)_j = \sum_{i=1}^{d} L_{ij}(x), \qquad j = 1, \dots, d' $$
where $L_{ij}(x)$ is a single-feature linear equivariant layer that acts on input feature $i$ and contributes to output feature $j$.

Technically, this formulation is correct, but summing over all $d$ terms $L_{ij}(x)$ can cause the layer outputs to blow up.  
As a result, the outputs of the composition $F \circ a \circ F \circ \dots$ become very large.

Our network is the composition $\phi \circ F \circ a \circ F \circ \dots$, where $a$ is the activation between layers (ReLU in our models) and $\phi$ is the sigmoid function, which returns values between 0 and 1.

Since the last layer of the network is a sigmoid and the outputs of the previous layers are very large in absolute value, the sigmoid saturates and returns (numerically) either 0.0 or 1.0. Once the sigmoid is saturated, the gradients propagated through it vanish, so the network does not learn.

To resolve this issue, we defined the equivariant layer as follows:

$$ F(x)_j = \frac{1}{d} \sum _{i=1} ^ {d} L_{ij}(x) $$ 

This formulation still retains the equivariance property (scaling by the constant $1/d$ does not affect equivariance), but it prevents the layer outputs from blowing up.

*Note: We applied the same averaging technique to the invariant layers as well.*

##### Overfitting:

Another major issue we encountered was overfitting. To overcome it, we added an option to generate the data dynamically every time the `Dataset` is accessed. 
This way, the model never sees the same sample twice and is therefore unable to overfit; this completely resolved the overfitting issue.
For the comparative analysis, we did not use this option.

##### Symmetrization Network:

The symmetrization network is a very powerful tool for learning equivariant (and invariant) functions. However, it is computationally expensive
and tricky to implement efficiently. Our implementation balances speed and memory utilization by forwarding multiple
permuted versions of the input data through the network at once (by creating a super-batch). We can control the number of permutations forwarded at once to trade speed against memory (more permutations per pass: faster, but higher memory utilization).

### Question 8:

Currently, we use the symmetry group $S_n$ acting on the channel dimension.
A better symmetry group to use would be $S_n \times S_d$, where $S_n$ acts on the channel dimension and $S_d$ acts on the feature dimension. This symmetry group is suitable because each feature is a vector of length $d$ drawn from a normal distribution. For an isotropic Gaussian such as $\mathcal{N}(0, I)$ (or $\mathcal{N}(0, \sigma^2 I)$), no permutation of the vector's coordinates changes the probability of it being generated, nor the underlying distribution that generated it. Since the model tries to detect the underlying distribution, it should be invariant to permutations of the feature dimension.

Formally:

$$ \Pr\big(x_1, x_2, \dots, x_n \sim \mathcal{N}(0, I) \;\big|\; x_1, x_2, \dots, x_n\big) =
\Pr\big(\sigma \cdot x_1, \sigma \cdot x_2, \dots, \sigma \cdot x_n \sim \mathcal{N}(0, I) \;\big|\; \sigma \cdot x_1, \sigma \cdot x_2, \dots, \sigma \cdot x_n\big) \qquad \forall \sigma \in S_d $$

where $x_i$ is a feature vector of length $d$ and $\sigma$ is a permutation of the feature dimension
(recall that each input sample is composed of $n$ feature vectors of length $d$).