In [1]:
import math
import torch
from torch import Tensor
import torch.nn as nn

## Dataset

In [2]:
from src.gaussian_dataset import GaussianDataset
from torch.utils.data import DataLoader

# Reproducibility: seed torch's global RNG before any random data or model weights are drawn,
# so Restart Kernel -> Run All reproduces the same datasets, initializations and results.
SEED = 0
torch.manual_seed(SEED)

N = 10  # number of set elements per sample (the axis the permutations below act on)
D = 5  # features per element
data_samples = 1000
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32

train_size = int(TRAIN_FRACTION * data_samples)
test_size = data_samples - train_size

ds_train = GaussianDataset(num_samples=train_size, shape=(N, D), var1=1.0, var2=0.8, static=False)
ds_test = GaussianDataset(num_samples=test_size, shape=(N, D), var1=1.0, var2=0.8, static=True)

# NOTE(review): the training loader does not shuffle; presumably unnecessary because ds_train
# is built with static=False (samples regenerated on access?) — confirm against GaussianDataset.
dl_train = DataLoader(dataset=ds_train, batch_size=BATCH_SIZE, shuffle=False)
dl_test = DataLoader(dataset=ds_test, batch_size=BATCH_SIZE, shuffle=False)

## Models

### General Purpose

In [3]:
from src.models import test_invariant, test_equivariant

# Device handed to the trainer; this notebook runs everything on the CPU.
device = torch.device(type="cpu")

In [4]:
from src.training import BinaryTrainer
from src.layers import LinearEquivariant, LinearInvariant, PositionalEncoding


def create_mlp_model(n: int, d: int) -> nn.Module:
    """Build a plain (non-invariant) MLP binary classifier.

    Each input sample is flattened to a vector of length ``n * d``, passed through
    two ReLU hidden layers of width ``10 * d``, and mapped to a single sigmoid output.

    Args:
        n: Size of the first per-sample axis of the input.
        d: Size of the second per-sample axis of the input.

    Returns:
        An ``nn.Sequential`` mapping a (batch, n, d) tensor to a (batch, 1) tensor in [0, 1].
    """
    width = 10 * d
    layers = [
        nn.Flatten(start_dim=1),
        nn.Linear(in_features=n * d, out_features=width),
        nn.ReLU(),
        nn.Linear(in_features=width, out_features=width),
        nn.ReLU(),
        nn.Linear(in_features=width, out_features=1),
        nn.Sigmoid(),
    ]
    return nn.Sequential(*layers)


def create_transformer_model(n: int, d: int) -> nn.Module:
    """Build a single-layer Transformer-encoder binary classifier.

    Positional encodings are applied to the (batch, n, d) input. The result goes through
    one encoder layer (a single attention head, model width ``d``, batch-first) followed by
    a LayerNorm, is flattened to (batch, n * d), and is mapped to one sigmoid output.

    NOTE(review): TransformerEncoderLayer keeps PyTorch's default dropout, so outputs are
    stochastic in train mode. Invariance checks presumably need ``model.eval()``; confirm
    that test_invariant does this.

    Args:
        n: Sequence length (also used as ``max_len`` of the positional encoding).
        d: Model / embedding width.

    Returns:
        An ``nn.Sequential`` producing a (batch, 1) output.
    """
    return nn.Sequential(
        PositionalEncoding(d_model=d, max_len=n),
        nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(batch_first=True, d_model=d, nhead=1),
            norm=nn.LayerNorm(normalized_shape=d),
            num_layers=1,
            # With an odd number of heads PyTorch turns the nested-tensor fast path off
            # anyway and warns about it. Disabling it explicitly keeps notebook output clean
            # without changing results.
            enable_nested_tensor=False,
        ),
        nn.Flatten(start_dim=1),
        nn.Linear(in_features=n * d, out_features=1),
        nn.Sigmoid(),
    )


def create_invariant_model(n: int, d: int) -> nn.Module:
    """Build a classifier from equivariant layers with an invariant read-out.

    Two ``LinearEquivariant`` layers (d -> 10 -> 10 channels, each followed by ReLU)
    feed a ``LinearInvariant`` layer with one output channel, then a sigmoid.

    Args:
        n: Not used by this factory (presumably kept for signature parity with
            ``create_mlp_model`` / ``create_transformer_model``).
        d: Number of input channels per element.

    Returns:
        The assembled ``nn.Sequential``.
    """
    hidden_channels = 10
    blocks = (
        LinearEquivariant(in_channels=d, out_channels=hidden_channels),
        nn.ReLU(),
        LinearEquivariant(in_channels=hidden_channels, out_channels=hidden_channels),
        nn.ReLU(),
        LinearInvariant(in_channels=hidden_channels, out_channels=1),
        nn.Sigmoid(),
    )
    return nn.Sequential(*blocks)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from src.train_results import FitResult


def train_model(
    model: nn.Module,
    *,
    lr: float = 0.001,
    num_epochs: int = 10000,
    print_every: int = 25,
    time_limit: int = 60 * 30,
    early_stopping: int = 50,
) -> FitResult:
    """Train ``model`` as a binary classifier on the notebook's Gaussian data.

    Uses ``nn.BCELoss``, so the model must output probabilities (every factory in this
    notebook ends with ``nn.Sigmoid``). Optimizes with Adam over ``model.parameters()``,
    fits on the global ``dl_train``, evaluates on the global ``dl_test``, and uses the
    global ``device``.

    All tuning knobs are keyword-only; their defaults reproduce the original settings.

    Args:
        model: Network to train; its parameters are updated by the optimizer.
        lr: Adam learning rate.
        num_epochs: Maximum number of epochs, forwarded to ``BinaryTrainer.fit``.
        print_every: Forwarded to ``BinaryTrainer.fit``.
        time_limit: Forwarded to ``BinaryTrainer.fit``. The default 60 * 30 is 30 minutes
            if the trainer interprets it in seconds (TODO confirm).
        early_stopping: Forwarded to ``BinaryTrainer.fit`` (presumably a patience in epochs).

    Returns:
        Whatever ``BinaryTrainer.fit`` returns (annotated as ``FitResult``).
    """
    trainer = BinaryTrainer(
        model=model,
        criterion=nn.BCELoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=lr),
        device=device,
        log=True,
    )

    return trainer.fit(
        dl_train=dl_train,
        dl_test=dl_test,
        num_epochs=num_epochs,
        print_every=print_every,
        time_limit=time_limit,
        early_stopping=early_stopping,
    )

ModuleNotFoundError: No module named 'train_results'

### Canonization-Based

#### MLP-Based

In [None]:
from src.models import CanonicalModel

# Wrap a plain MLP in CanonicalModel, then check invariance on a random batch of 32 samples.
model = CanonicalModel(create_mlp_model(n=N, d=D))

test_invariant(model, input=torch.randn(32, N, D))

In [None]:
train_model(model=model)

#### Attention-Based

In [None]:
# Same canonicalization wrapper, this time around the single-layer Transformer classifier.
model = CanonicalModel(create_transformer_model(n=N, d=D))

test_invariant(model, input=torch.randn(32, N, D))

In [None]:
train_model(model=model)

### Symmetrization Network

#### MLP-Based

In [None]:
from src.permutation import Permutation, create_all_permutations, create_permutations_from_generators
from src.models import SymmetryModel

# Cyclic-shift generator: index array [1, 2, ..., N-1, 0].
shift_perm = Permutation(torch.roll(torch.arange(N), shifts=-1))

# NOTE(review): create_permutations_from_generators presumably closes [shift_perm] into the
# group it generates. For a single N-cycle that is the N cyclic shifts, not all N! orderings,
# so this model is symmetrized over rotations only. If test_invariant checks arbitrary
# permutations, confirm that is the intended comparison.
model = SymmetryModel(
    model=create_mlp_model(n=N, d=D),
    perm_creator=lambda: create_permutations_from_generators([shift_perm]),
    chunksize=10,
)

test_invariant(model, torch.randn(32, N, D))

In [None]:
# NOTE(review): training of the symmetrized MLP is disabled here and the notebook does not
# record why. Either re-enable it (uncomment) or delete this cell.
# train_model(model)

#### Attention-Based

In [None]:
# Same cyclic-shift generator as the MLP variant: index array [1, 2, ..., N-1, 0].
shift_perm = Permutation(torch.roll(torch.arange(N), shifts=-1))

model = SymmetryModel(
    model=create_transformer_model(n=N, d=D),
    perm_creator=lambda: create_permutations_from_generators([shift_perm]),
    chunksize=10,
)

test_invariant(model, torch.randn(32, N, D))

In [None]:
# NOTE(review): training of the symmetrized Transformer is disabled here and the notebook
# does not record why. Either re-enable it (uncomment) or delete this cell.
# train_model(model)

#### MLP-Based Sampled Symmetrization

In [None]:
# Number of random permutations produced each time SymmetryModel calls perm_creator.
# (Sampling 5% of all N! orderings would mean int(10! * 0.05) = 181,440 for N = 10,
# presumably why a small fixed budget is used instead.)
num = 30

model = SymmetryModel(
    model=create_mlp_model(N, D),
    # Bind `num` and `N` now via default arguments, so later rebinding of these notebook
    # globals cannot silently change this model's permutation sampler.
    perm_creator=lambda num=num, n=N: (Permutation(torch.randperm(n)) for _ in range(num)),
    chunksize=10,
)

test_invariant(model, torch.randn(32, N, D))

In [None]:
train_model(model=model)

### Intrinsically Invariant

In [None]:
# No wrapper here: the architecture itself is built from equivariant/invariant layers.
model = create_invariant_model(n=N, d=D)

test_invariant(model, torch.randn(32, N, D))

In [None]:
train_model(model=model)

### Standard with Augmentation

In [None]:
class Augmentation(nn.Module):
    """Training-time augmentation that randomly reorders the set elements of each sample.

    Each sample in the batch gets its own random permutation of the elements along
    ``dim`` (dim 1 for the (batch, N, D) inputs in this notebook). All features of an
    element move together, so every output element is an intact input element.
    (The previous version shuffled every row's entries independently along the last
    axis, which scrambled the features inside each element instead of reordering
    elements.)
    """

    def __init__(self, dim: int = 1) -> None:
        """
        Args:
            dim: Axis holding the set elements to permute. Negative values count from
                the end; it must not resolve to the batch axis 0.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """
        Randomly permute the elements of each sample along ``self.dim``.

        In eval mode (``self.training`` is False) the input is returned unchanged, so
        evaluation is deterministic.

        Args:
            x (Tensor): Batched input with the batch on axis 0, e.g. (batch_size, n, d).

        Returns:
            Tensor with the same shape, dtype and device as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than 2 dims or ``self.dim`` is the batch axis.
        """
        if not self.training:
            return x

        ndim = x.dim()
        if ndim < 2:
            raise ValueError("Augmentation expects a batched input with at least 2 dimensions")
        dim = self.dim % ndim
        if dim == 0:
            raise ValueError("Augmentation cannot permute along the batch dimension (0)")

        batch_size, set_size = x.shape[0], x.shape[dim]
        # One independent random ordering per sample (argsort of uniform noise).
        perm = torch.rand(batch_size, set_size, device=x.device).argsort(dim=1)

        # Broadcast each sample's ordering over all other axes, so gather moves whole
        # elements rather than shuffling individual entries.
        view_shape = [1] * ndim
        view_shape[0] = batch_size
        view_shape[dim] = set_size
        index = perm.view(view_shape).expand_as(x)
        return torch.gather(x, dim, index)

In [None]:
# NOTE(review): the heading "Standard with Augmentation" suggests a standard (non-invariant)
# backbone such as create_mlp_model. create_invariant_model is already built to be invariant,
# so permutation augmentation may be redundant here. Confirm the intended backbone.
backbone = create_invariant_model(N, D)
model = nn.Sequential(Augmentation(), backbone)

test_invariant(model, torch.randn(32, N, D))

In [None]:
train_model(model=model)