In [1]:
import math
import torch
from torch import Tensor
import torch.nn as nn

## Dataset

In [2]:
from src.gaussian_dataset import GaussianDataset
from torch.utils.data import DataLoader

N = 100
D = 5

train_size = 1000
test_size = 1000

ds_train = GaussianDataset(num_samples=train_size, shape=(N, D), var1=1.0, var2=0.8, static=False)

ds_test = GaussianDataset(num_samples=test_size, shape=(N, D), var1=1.0, var2=0.8, static=True)

dl_train = DataLoader(dataset=ds_train, batch_size=32, shuffle=False)

dl_test = DataLoader(dataset=ds_test, batch_size=32, shuffle=False)

## Models

### General Purpose

In [3]:
from src.models import test_invariant, test_equivariant

device = torch.device("cpu")

In [4]:
from src.training import BinaryTrainer
from src.layers import LinearEquivariant, LinearInvariant, PositionalEncoding


def create_mlp_model(n: int, d: int) -> nn.Module:
    return nn.Sequential(
        nn.Flatten(start_dim=1),
        nn.Linear(in_features=n * d, out_features=10 * d),
        nn.ReLU(),
        nn.Linear(in_features=10 * d, out_features=10 * d),
        nn.ReLU(),
        nn.Linear(in_features=10 * d, out_features=1),
        nn.Sigmoid(),
    )


def create_transformer_model(n: int, d: int) -> nn.Module:
    return nn.Sequential(
        PositionalEncoding(d_model=d, max_len=n),
        nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(batch_first=True, d_model=d, nhead=1),
            norm=nn.LayerNorm(normalized_shape=d),
            num_layers=1,
        ),
        nn.Flatten(start_dim=1),
        nn.Linear(in_features=n * d, out_features=1),
        nn.Sigmoid(),
    )


def create_invariant_model(n: int, d: int) -> nn.Module:
    return nn.Sequential(
        LinearEquivariant(in_channels=d, out_channels=10),
        nn.ReLU(),
        LinearEquivariant(in_channels=10, out_channels=10),
        nn.ReLU(),
        LinearInvariant(in_channels=10, out_channels=1),
        nn.Sigmoid(),
    )

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from src.train_results import FitResult


def train_model(model: nn.Module, log_dir: str | None = None) -> FitResult:
    trainer = BinaryTrainer(
        model=model,
        criterion=nn.BCELoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
        device=device,
        log=True,
        log_dir=log_dir,
    )

    return trainer.fit(
        dl_train=dl_train,
        dl_test=dl_test,
        num_epochs=10000,
        print_every=25,
        time_limit=60 * 30,
        early_stopping=300,
    )

### Canonization Based

#### MLP-Based

In [6]:
from src.models import CanonicalModel

model = CanonicalModel(create_mlp_model(N, D))

test_invariant(model, input=torch.randn(32, N, D))

True

In [7]:
train_model(model, log_dir="runs/canonical-mlp")

--- EPOCH 1/10000 --- (time: 00:00:00)
train_batch (Avg. Loss 0.693, Accuracy 53.00%): 100%|██████████| 32/32 [00:00<00:00, 178.18it/s]
test_batch (Avg. Loss 0.691, Accuracy 59.80%): 100%|██████████| 32/32 [00:00<00:00, 388.80it/s]
--- EPOCH 26/10000 --- (time: 00:00:05)
train_batch (Avg. Loss 0.682, Accuracy 58.00%): 100%|██████████| 32/32 [00:00<00:00, 148.99it/s]
test_batch (Avg. Loss 0.681, Accuracy 49.90%): 100%|██████████| 32/32 [00:00<00:00, 261.15it/s]
--- EPOCH 51/10000 --- (time: 00:00:10)
train_batch (Avg. Loss 0.664, Accuracy 56.90%): 100%|██████████| 32/32 [00:00<00:00, 103.34it/s]
test_batch (Avg. Loss 0.659, Accuracy 77.40%): 100%|██████████| 32/32 [00:00<00:00, 255.19it/s]
--- EPOCH 76/10000 --- (time: 00:00:14)
train_batch (Avg. Loss 0.601, Accuracy 73.50%): 100%|██████████| 32/32 [00:00<00:00, 150.85it/s]
test_batch (Avg. Loss 0.591, Accuracy 83.20%): 100%|██████████| 32/32 [00:00<00:00, 260.83it/s]
--- EPOCH 101/10000 --- (time: 00:00:19)
train_batch (Avg. Loss 0.490

FitResult(num_epochs=1099, train_loss=[0.6902720332145691, 0.6475526094436646, 0.6109548211097717, 0.6525806784629822, 0.8150444030761719, 0.8434545993804932, 0.6731957793235779, 0.6810147166252136, 0.6905162930488586, 0.6408487558364868, 0.7357474565505981, 0.6590905785560608, 0.690438985824585, 0.6800671815872192, 0.7125382423400879, 0.6899153590202332, 0.6890438795089722, 0.6888217926025391, 0.6987063884735107, 0.694175660610199, 0.6907956600189209, 0.6920245289802551, 0.6872631311416626, 0.6968593001365662, 0.6894164681434631, 0.6921457648277283, 0.6928277611732483, 0.6903718709945679, 0.6869746446609497, 0.6936500668525696, 0.6898740530014038, 0.6960240006446838, 0.690488874912262, 0.690040111541748, 0.6836337447166443, 0.681727409362793, 0.6908680200576782, 0.6992005109786987, 0.6743867993354797, 0.677019476890564, 0.6814888119697571, 0.6534305810928345, 0.736005961894989, 0.6491286158561707, 0.692682683467865, 0.6773411631584167, 0.7473115921020508, 0.694404661655426, 0.69401907

#### Attention-Based

In [8]:
model = CanonicalModel(create_transformer_model(N, D))

test_invariant(model, input=torch.randn(32, N, D))



True

In [9]:
train_model(model, log_dir="runs/canonical-attn")

--- EPOCH 1/10000 --- (time: 00:00:00)
train_batch (Avg. Loss 0.718, Accuracy 52.30%): 100%|██████████| 32/32 [00:02<00:00, 11.59it/s]
test_batch (Avg. Loss 0.692, Accuracy 51.50%): 100%|██████████| 32/32 [00:00<00:00, 56.08it/s]
--- EPOCH 26/10000 --- (time: 00:01:19)
train_batch (Avg. Loss 0.263, Accuracy 89.50%): 100%|██████████| 32/32 [00:02<00:00, 14.66it/s]
test_batch (Avg. Loss 0.233, Accuracy 89.80%): 100%|██████████| 32/32 [00:00<00:00, 52.33it/s]
--- EPOCH 51/10000 --- (time: 00:02:34)
train_batch (Avg. Loss 0.216, Accuracy 91.20%): 100%|██████████| 32/32 [00:02<00:00, 13.74it/s]
test_batch (Avg. Loss 0.265, Accuracy 88.80%): 100%|██████████| 32/32 [00:00<00:00, 55.28it/s]
--- EPOCH 76/10000 --- (time: 00:03:59)
train_batch (Avg. Loss 0.218, Accuracy 90.10%): 100%|██████████| 32/32 [00:02<00:00, 13.54it/s]
test_batch (Avg. Loss 0.158, Accuracy 93.40%): 100%|██████████| 32/32 [00:00<00:00, 43.74it/s]
--- EPOCH 101/10000 --- (time: 00:05:07)
train_batch (Avg. Loss 0.216, Accura

FitResult(num_epochs=629, train_loss=[0.8936654925346375, 0.7152407765388489, 0.6287521719932556, 0.6849594116210938, 0.9562020897865295, 0.9346592426300049, 0.6796101927757263, 0.6621783971786499, 0.7033883929252625, 0.7549343109130859, 0.6669254302978516, 0.7702639102935791, 0.7397125959396362, 0.730010986328125, 0.6570054292678833, 0.6948411464691162, 0.7248270511627197, 0.6934640407562256, 0.7143498659133911, 0.6954349279403687, 0.7065883874893188, 0.6702858209609985, 0.686933159828186, 0.7144218683242798, 0.6722450256347656, 0.6926038265228271, 0.6939974427223206, 0.6787756681442261, 0.6835915446281433, 0.698376476764679, 0.6645684242248535, 0.7100903987884521, 0.6913476586341858, 0.6758277416229248, 0.6620659828186035, 0.6809163093566895, 0.75576251745224, 0.7774873971939087, 0.6634427309036255, 0.6687943935394287, 0.756853461265564, 0.6557490825653076, 0.7382183074951172, 0.6703120470046997, 0.7222234606742859, 0.683743417263031, 0.7053150534629822, 0.6873010993003845, 0.7013373

### Symmetrization Network

#### MLP-Based

In [10]:
from src.permutation import Permutation, create_all_permutations, create_permutations_from_generators
from src.models import SymmetryModel

shift_perm = Permutation((torch.arange(N) + 1) % N)

model = SymmetryModel(
    model=create_mlp_model(N, D),
    perm_creator=lambda: create_permutations_from_generators([shift_perm]),
    chunksize=10,
)

test_invariant(model, torch.randn(32, N, D))

False

In [11]:
#train_model(model, log_dir="runs/symmetry-mlp")

#### Attention-Based

In [12]:
shift_perm = Permutation((torch.arange(N) + 1) % N)

model = SymmetryModel(
    model=create_transformer_model(N, D),
    perm_creator=lambda: create_permutations_from_generators([shift_perm]),
    chunksize=10,
)

test_invariant(model, torch.randn(32, N, D))

False

In [13]:
#train_model(model, log_dir="runs/symmetry-mlp")

#### MLP-Based Sampled Symmetrization 

In [14]:
num = int(math.factorial(N) * 0.05)
num = 30

model = SymmetryModel(
    model=create_mlp_model(N, D),
    perm_creator=lambda: (Permutation(torch.randperm(N)) for _ in range(num)),
    chunksize=10,
)

test_invariant(model, torch.randn(32, N, D))

False

In [15]:
train_model(model, log_dir="runs/symmetry-sampling-mlp")

  self.hash = hash(tuple(perm.tolist()))
	%perm.1 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%34, %35, %36, %37, %38) # /tmp/ipykernel_72518/1933901993.py:6:0
	%perm.5 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%45, %46, %47, %48, %49) # /tmp/ipykernel_72518/1933901993.py:6:0
	%perm.9 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%56, %57, %58, %59, %60) # /tmp/ipykernel_72518/1933901993.py:6:0
	%perm.13 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%67, %68, %69, %70, %71) # /tmp/ipykernel_72518/1933901993.py:6:0
	%perm.17 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%78, %79, %80, %81, %82) # /tmp/ipykernel_72518/1933901993.py:6:0
	%perm.21 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%89, %90, %91, %92, %93) # /tmp/ipykernel_72518/1933901993.py:6:0
	%perm.25 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm

--- EPOCH 1/10000 --- (time: 00:00:00)
train_batch (Avg. Loss 0.691, Accuracy 50.50%): 100%|██████████| 32/32 [00:00<00:00, 52.88it/s]
test_batch (Avg. Loss 0.685, Accuracy 53.50%): 100%|██████████| 32/32 [00:00<00:00, 69.61it/s]
--- EPOCH 26/10000 --- (time: 00:00:18)
train_batch (Avg. Loss 0.647, Accuracy 64.30%): 100%|██████████| 32/32 [00:00<00:00, 73.19it/s]
test_batch (Avg. Loss 0.640, Accuracy 77.70%): 100%|██████████| 32/32 [00:00<00:00, 102.93it/s]
--- EPOCH 51/10000 --- (time: 00:00:31)
train_batch (Avg. Loss 0.548, Accuracy 80.40%): 100%|██████████| 32/32 [00:00<00:00, 79.46it/s]
test_batch (Avg. Loss 0.541, Accuracy 84.50%): 100%|██████████| 32/32 [00:00<00:00, 121.86it/s]
--- EPOCH 76/10000 --- (time: 00:00:48)
train_batch (Avg. Loss 0.474, Accuracy 89.10%): 100%|██████████| 32/32 [00:00<00:00, 41.55it/s]
test_batch (Avg. Loss 0.475, Accuracy 89.00%): 100%|██████████| 32/32 [00:00<00:00, 119.18it/s]
--- EPOCH 101/10000 --- (time: 00:01:03)
train_batch (Avg. Loss 0.446, Acc

FitResult(num_epochs=1370, train_loss=[0.6934600472450256, 0.6893404722213745, 0.6837733387947083, 0.6849282383918762, 0.6920647621154785, 0.7011557817459106, 0.681333601474762, 0.680396556854248, 0.6850090026855469, 0.6653058528900146, 0.715277910232544, 0.6622262001037598, 0.6902521848678589, 0.6746068000793457, 0.7285387516021729, 0.6905863881111145, 0.6906815767288208, 0.6896626353263855, 0.7217775583267212, 0.7095928192138672, 0.7026039361953735, 0.6641322374343872, 0.7245184183120728, 0.6578350067138672, 0.7049815654754639, 0.6820147633552551, 0.6851122379302979, 0.6914372444152832, 0.7195252180099487, 0.6786820888519287, 0.7107754945755005, 0.6595553159713745, 0.6763572692871094, 0.6677396297454834, 0.6691532731056213, 0.6774218082427979, 0.6939362287521362, 0.701644778251648, 0.6739710569381714, 0.679203450679779, 0.6857841610908508, 0.6650086045265198, 0.7180466055870056, 0.6593151688575745, 0.6891088485717773, 0.6811921000480652, 0.7165300250053406, 0.6873329281806946, 0.6906

#### Attention-Based Sampled Symmetrization 

In [16]:
num = int(math.factorial(N) * 0.05)
num = 30

model = SymmetryModel(
    model=create_transformer_model(N, D),
    perm_creator=lambda: (Permutation(torch.randperm(N)) for _ in range(num)),
    chunksize=10,
)

test_invariant(model, torch.randn(32, N, D))

False

In [17]:
train_model(model, log_dir="runs/symmetry-sampling-attn")

	%perm.1 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%75, %76, %77, %78, %79) # /tmp/ipykernel_72518/3788656092.py:6:0
	%perm.5 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%86, %87, %88, %89, %90) # /tmp/ipykernel_72518/3788656092.py:6:0
	%perm.9 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%97, %98, %99, %100, %101) # /tmp/ipykernel_72518/3788656092.py:6:0
	%perm.13 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%108, %109, %110, %111, %112) # /tmp/ipykernel_72518/3788656092.py:6:0
	%perm.17 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%119, %120, %121, %122, %123) # /tmp/ipykernel_72518/3788656092.py:6:0
	%perm.21 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%130, %131, %132, %133, %134) # /tmp/ipykernel_72518/3788656092.py:6:0
	%perm.25 : Long(100, strides=[1], requires_grad=0, device=cpu) = aten::randperm(%141, %142, %143, %144,

--- EPOCH 1/10000 --- (time: 00:00:00)
train_batch (Avg. Loss 0.695, Accuracy 50.90%): 100%|██████████| 32/32 [00:58<00:00,  1.84s/it]
test_batch (Avg. Loss 0.676, Accuracy 59.70%): 100%|██████████| 32/32 [00:16<00:00,  1.98it/s]
--- EPOCH 26/10000 --- (time: 00:27:24)
train_batch (Avg. Loss 0.423, Accuracy 82.30%): 100%|██████████| 32/32 [00:52<00:00,  1.63s/it]
test_batch (Avg. Loss 0.407, Accuracy 81.00%): 100%|██████████| 32/32 [00:15<00:00,  2.13it/s]
--- Stopping after 28 epochs :: time limit exceeded 00:31:01 --- 


FitResult(num_epochs=28, train_loss=[0.6864078044891357, 0.6611539125442505, 0.6284066438674927, 0.6594600677490234, 0.7286078929901123, 0.6977515816688538, 0.7147784233093262, 0.734311044216156, 0.7029995918273926, 0.7360846996307373, 0.676027774810791, 0.6924319863319397, 0.6946494579315186, 0.6897488236427307, 0.7461787462234497, 0.7033585906028748, 0.6904783248901367, 0.6966283917427063, 0.7080164551734924, 0.6904377937316895, 0.6817857623100281, 0.7514521479606628, 0.6464172601699829, 0.7676995396614075, 0.6677509546279907, 0.694831371307373, 0.6929336786270142, 0.6793301701545715, 0.689302921295166, 0.6665120124816895, 0.6851911544799805, 0.6846842765808105, 0.6628112196922302, 0.6437892913818359, 0.6373413801193237, 0.659019410610199, 0.7814176082611084, 0.8161959648132324, 0.6611182689666748, 0.691798746585846, 0.6822149753570557, 0.6548230648040771, 0.6744604706764221, 0.6931190490722656, 0.6724733710289001, 0.6865214705467224, 0.6546481847763062, 0.6695317625999451, 0.6702057

### Intrinsic Invariant

In [18]:
# TODO: FIX INVARIANT MODEL, IT BARELY TRAINS
model = create_invariant_model(N, D)

test_invariant(model, torch.randn(32, N, D))

True

In [19]:
train_model(model, log_dir="runs/intrinsic")

  assert x.shape[-1] == self.in_channels
  assert x.shape[-1] == self.in_channels


--- EPOCH 1/10000 --- (time: 00:00:00)
train_batch (Avg. Loss 51.320, Accuracy 47.60%): 100%|██████████| 32/32 [00:00<00:00, 70.93it/s]
test_batch (Avg. Loss 51.190, Accuracy 47.90%): 100%|██████████| 32/32 [00:00<00:00, 133.34it/s]
--- EPOCH 26/10000 --- (time: 00:00:09)
train_batch (Avg. Loss 50.850, Accuracy 49.10%): 100%|██████████| 32/32 [00:00<00:00, 120.60it/s]
test_batch (Avg. Loss 51.314, Accuracy 47.40%): 100%|██████████| 32/32 [00:00<00:00, 229.40it/s]
--- EPOCH 51/10000 --- (time: 00:00:17)
train_batch (Avg. Loss 50.765, Accuracy 49.20%): 100%|██████████| 32/32 [00:00<00:00, 111.42it/s]
test_batch (Avg. Loss 52.758, Accuracy 46.50%): 100%|██████████| 32/32 [00:00<00:00, 247.75it/s]
--- EPOCH 76/10000 --- (time: 00:00:23)
train_batch (Avg. Loss 51.173, Accuracy 49.30%): 100%|██████████| 32/32 [00:00<00:00, 101.58it/s]
test_batch (Avg. Loss 52.989, Accuracy 46.60%): 100%|██████████| 32/32 [00:00<00:00, 219.96it/s]
--- EPOCH 101/10000 --- (time: 00:00:30)
train_batch (Avg. Los

FitResult(num_epochs=307, train_loss=[56.25000762939453, 54.161865234375, 59.2497673034668, 50.001285552978516, 42.777618408203125, 42.445274353027344, 56.25, 55.69387435913086, 47.757930755615234, 50.73086166381836, 43.75, 55.61744689941406, 53.125, 56.25, 35.21044921875, 60.627567291259766, 52.62239456176758, 56.25, 56.875179290771484, 40.625, 39.328365325927734, 68.32148742675781, 50.361724853515625, 53.125, 40.625, 43.75, 56.62916564941406, 56.25, 40.625, 60.69572830200195, 43.75, 62.5, 57.014347076416016, 59.375, 68.75, 40.69826126098633, 46.875, 33.233421325683594, 45.19318389892578, 59.375, 50.0, 56.286773681640625, 40.650299072265625, 59.401756286621094, 48.870704650878906, 56.25, 35.77605056762695, 45.612571716308594, 43.75, 33.088829040527344, 50.0, 44.226539611816406, 55.52495193481445, 52.30609893798828, 31.25, 50.0, 43.75010299682617, 53.125213623046875, 49.381065368652344, 50.0, 43.65264892578125, 49.12038803100586, 31.25, 75.0, 58.77573776245117, 59.375, 65.625, 59.37503

### Standard with Augmentation

In [20]:
# TODO: IS THE AUGMENTATION DIMENSION CORRECT?
class Augmentation(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """
        Randomly permute the input tensor along the channel dimension.

        Args:
            x (Tensor): Input tensor of shape (batch_size, d, channel)
        """
        if self.training is False:
            return x

        rnd = torch.randn_like(x)
        indices = rnd.argsort(dim=-1)
        result = torch.gather(x, -1, indices)
        return result

#### MLP-Based 

In [21]:
model = nn.Sequential(
    Augmentation(),
    create_mlp_model(N, D),
)

test_invariant(model, torch.randn(32, N, D))

False

In [22]:
train_model(model, log_dir="runs/augmented-mlp")

Tensor-likes are not close!

Mismatched elements: 32 / 32 (100.0%)
Greatest absolute difference: 0.050576984882354736 at index (31, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.10315909305594746 at index (31, 0) (up to 1e-05 allowed)
  _check_trace(


--- EPOCH 1/10000 --- (time: 00:00:00)
train_batch (Avg. Loss 0.691, Accuracy 50.70%): 100%|██████████| 32/32 [00:00<00:00, 142.77it/s]
test_batch (Avg. Loss 0.686, Accuracy 53.90%): 100%|██████████| 32/32 [00:00<00:00, 290.79it/s]
--- EPOCH 26/10000 --- (time: 00:00:05)
train_batch (Avg. Loss 0.680, Accuracy 56.10%): 100%|██████████| 32/32 [00:00<00:00, 150.88it/s]
test_batch (Avg. Loss 0.681, Accuracy 56.40%): 100%|██████████| 32/32 [00:00<00:00, 237.74it/s]
--- EPOCH 51/10000 --- (time: 00:00:09)
train_batch (Avg. Loss 0.687, Accuracy 54.20%): 100%|██████████| 32/32 [00:00<00:00, 146.50it/s]
test_batch (Avg. Loss 0.677, Accuracy 57.00%): 100%|██████████| 32/32 [00:00<00:00, 257.39it/s]
--- EPOCH 76/10000 --- (time: 00:00:14)
train_batch (Avg. Loss 0.667, Accuracy 60.90%): 100%|██████████| 32/32 [00:00<00:00, 141.84it/s]
test_batch (Avg. Loss 0.669, Accuracy 59.70%): 100%|██████████| 32/32 [00:00<00:00, 253.47it/s]
--- EPOCH 101/10000 --- (time: 00:00:19)
train_batch (Avg. Loss 0.663

FitResult(num_epochs=984, train_loss=[0.6922390460968018, 0.6958283185958862, 0.6833241581916809, 0.6648596525192261, 0.6796028017997742, 0.7031815648078918, 0.6957847476005554, 0.6928430199623108, 0.6950574517250061, 0.6675847768783569, 0.6996409893035889, 0.6679630279541016, 0.686610996723175, 0.6855564713478088, 0.7132988572120667, 0.6863387823104858, 0.6900065541267395, 0.6938770413398743, 0.7306508421897888, 0.7111431360244751, 0.6836928725242615, 0.6653380393981934, 0.7184361815452576, 0.6627237796783447, 0.7142767906188965, 0.6975323557853699, 0.6868209838867188, 0.6851239204406738, 0.7244538068771362, 0.6820785999298096, 0.7095323801040649, 0.6441780924797058, 0.6829084753990173, 0.6775268316268921, 0.657153844833374, 0.6683919429779053, 0.6807570457458496, 0.7044534683227539, 0.679029643535614, 0.6739048361778259, 0.6820704936981201, 0.6761788129806519, 0.7275882363319397, 0.675166666507721, 0.6893475651741028, 0.6842344999313354, 0.7183050513267517, 0.6815492510795593, 0.6836

#### Attention-Based 

In [23]:
model = nn.Sequential(
    Augmentation(),
    create_transformer_model(N, D),
)

test_invariant(model, torch.randn(32, N, D))



False

In [24]:
train_model(model, log_dir="runs/augmented-attn")

Tensor-likes are not close!

Mismatched elements: 32 / 32 (100.0%)
Greatest absolute difference: 0.45218876004219055 at index (14, 0) (up to 1e-05 allowed)
Greatest relative difference: 1.3826454151267624 at index (28, 0) (up to 1e-05 allowed)
  _check_trace(


--- EPOCH 1/10000 --- (time: 00:00:00)
train_batch (Avg. Loss 0.703, Accuracy 54.70%): 100%|██████████| 32/32 [00:02<00:00, 11.08it/s]
test_batch (Avg. Loss 0.719, Accuracy 49.20%): 100%|██████████| 32/32 [00:00<00:00, 43.22it/s]
--- EPOCH 26/10000 --- (time: 00:01:11)
train_batch (Avg. Loss 0.547, Accuracy 73.90%): 100%|██████████| 32/32 [00:02<00:00, 15.73it/s]
test_batch (Avg. Loss 0.537, Accuracy 72.50%): 100%|██████████| 32/32 [00:00<00:00, 54.14it/s]
--- EPOCH 51/10000 --- (time: 00:02:15)
train_batch (Avg. Loss 0.478, Accuracy 76.50%): 100%|██████████| 32/32 [00:01<00:00, 16.24it/s]
test_batch (Avg. Loss 0.453, Accuracy 78.80%): 100%|██████████| 32/32 [00:00<00:00, 53.75it/s]
--- EPOCH 76/10000 --- (time: 00:03:27)
train_batch (Avg. Loss 0.425, Accuracy 81.30%): 100%|██████████| 32/32 [00:03<00:00,  8.34it/s]
test_batch (Avg. Loss 0.384, Accuracy 83.20%): 100%|██████████| 32/32 [00:00<00:00, 34.06it/s]
--- EPOCH 101/10000 --- (time: 00:04:45)
train_batch (Avg. Loss 0.401, Accura

FitResult(num_epochs=679, train_loss=[0.7534112334251404, 0.5987116694450378, 0.6003865599632263, 0.6472458839416504, 0.814667820930481, 0.8919265866279602, 0.6499448418617249, 0.7109752297401428, 0.7162401676177979, 0.7251662015914917, 0.7261738777160645, 0.6755481958389282, 0.7028063535690308, 0.6685490012168884, 0.679072916507721, 0.7176507115364075, 0.7198669910430908, 0.6998026371002197, 0.6625788807868958, 0.7717240452766418, 0.652584433555603, 0.7154320478439331, 0.7111121416091919, 0.7387382984161377, 0.7179195880889893, 0.7652599215507507, 0.6474567651748657, 0.7159289717674255, 0.6424065232276917, 0.6820570230484009, 0.6938815712928772, 0.6869893670082092, 0.7542162537574768, 0.7395488023757935, 0.7183284163475037, 0.7350982427597046, 0.7634751200675964, 0.725506603717804, 0.6987865567207336, 0.7087991237640381, 0.7175453901290894, 0.700645387172699, 0.781825602054596, 0.6515195369720459, 0.7398138642311096, 0.7163290977478027, 0.7606008052825928, 0.7558640241622925, 0.718043