In [1]:
from neural_semigroups.smallsemi_dataset import Smallsemi
from os import environ, path
from neural_semigroups import Magma
from neural_semigroups.utils import corrupt_input

def transform(x):
    """Build one ``(corrupted_input, clean_target)`` training pair from a dataset item.

    ``x[0]`` is presumably the Cayley table of one semigroup from the dataset —
    TODO confirm against ``Smallsemi.__getitem__``. The table is relabelled by
    ``Magma.random_isomorphism`` (presumably as data augmentation) and converted
    to its ``probabilistic_cube``, which is the clean target. A copy of that cube
    is then corrupted by ``corrupt_input`` to form the model input.

    Note: this reads the module-level globals ``cardinality`` and
    ``dropout_rate`` at call time, so both must be defined before the dataset
    is first indexed.

    Returns:
        ``(corrupted, target)``: ``corrupted`` is viewed as a
        ``(cardinality, cardinality, cardinality)`` tensor; ``target`` is the
        uncorrupted probabilistic cube it was derived from.
    """
    relabelled_table = Magma(x[0]).random_isomorphism()
    target = Magma(relabelled_table).probabilistic_cube
    cube_shape = (cardinality, cardinality, cardinality)
    # corrupt_input is handed a leading batch axis of size 1; the result is
    # reshaped back to a single cube afterwards.
    corrupted = corrupt_input(
        target.view(1, *cube_shape),
        dropout_rate=dropout_rate,
    ).view(*cube_shape)
    return corrupted, target

# Number of elements of the semigroups loaded from the smallsemi dataset;
# also fixes the cube shape used by ``transform`` and the model below.
cardinality = 6
# Corruption rate used by ``transform`` (via ``corrupt_input``) and by the model.
# NOTE(review): the rationale for ``1 - 1 / cardinality`` is not stated here —
# document it.
dropout_rate = 1 - 1 / cardinality
# ``path.expanduser("~")`` rather than ``environ["HOME"]``: it returns HOME on
# POSIX when set, and falls back to the platform's home lookup elsewhere instead
# of raising KeyError when HOME is missing (e.g. a stock Windows kernel).
data = Smallsemi(
    root=path.join(path.expanduser("~"), "neural-semigroups-data"),
    cardinality=cardinality,
    download=True,
    transform=transform,
)

In [2]:
# Total number of semigroup tables available before the train/val/test split.
dataset_size = len(data)
dataset_size

15973

In [3]:
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
from torch import Generator

# Split parameters. The seed makes the train/validation/test membership
# identical across Restart & Run All; without it ``random_split`` draws from
# torch's global RNG and produces a different split on every run.
SPLIT_SEED = 0
BATCH_SIZE = 32
VALIDATION_SIZE = 100
TEST_SIZE = 100

train_split, validation_split, test_split = random_split(
    data,
    [len(data) - VALIDATION_SIZE - TEST_SIZE, VALIDATION_SIZE, TEST_SIZE],
    generator=Generator().manual_seed(SPLIT_SEED),
)
# Ordered (train, validation, test), matching the split order above. Only the
# training loader is shuffled, so each epoch sees batches in a new order;
# evaluation loaders keep a fixed order.
data_loaders = (
    DataLoader(train_split, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(validation_split, batch_size=BATCH_SIZE),
    DataLoader(test_split, batch_size=BATCH_SIZE),
)

In [4]:
from neural_semigroups.associator_loss import AssociatorLoss
from torch import Tensor

def loss(prediction: Tensor, target: Tensor) -> Tensor:
    """Training objective: ``AssociatorLoss`` evaluated on ``prediction`` alone.

    ``target`` is accepted only so that the function has the
    ``(prediction, target)`` signature used when it is passed to
    ``learning_pipeline`` and wrapped in ignite's ``Loss`` metric. It is
    deliberately ignored: the objective presumably scores how far the predicted
    operation is from being associative, independent of the clean target.
    """
    criterion = AssociatorLoss()
    return criterion(prediction)

In [11]:
import shutil

# Clear the ``runs`` directory left by earlier training sessions so that old and
# new logs are not mixed. NOTE(review): presumably the TensorBoard log directory
# written by ``learning_pipeline`` — confirm.
# Pure Python instead of ``!rm -rf`` so the cell works outside POSIX shells;
# ``ignore_errors=True`` mirrors ``rm -f`` (no error if the directory is absent).
shutil.rmtree("runs", ignore_errors=True)

In [None]:
from neural_semigroups.training_helpers import learning_pipeline
from ignite.metrics.loss import Loss
from neural_semigroups import MagmaDAE
from neural_semigroups.training_helpers import associative_ratio, guessed_ratio
from torch import manual_seed

# Training hyper-parameters, named so readers can find and tune them.
TRAINING_SEED = 0
LEARNING_RATE = 0.0001
EPOCHS = 1000
# Number of hidden layers; each has cardinality ** 3 units (the size of one cube).
HIDDEN_LAYER_COUNT = 3

# Seed torch's global RNG before building the model so weight initialisation
# (and any torch-based randomness afterwards) is reproducible across
# Restart & Run All. NOTE(review): the data transform may also draw from other
# RNGs (e.g. Python's ``random`` or numpy) — confirm and seed those too.
manual_seed(TRAINING_SEED)

params = {"learning_rate": LEARNING_RATE, "epochs": EPOCHS}
# Each callable is wrapped in ignite's ``Loss`` metric, which averages its
# output over the evaluated samples; ``associative_ratio`` and ``guessed_ratio``
# are therefore used with the same (prediction, target) calling convention.
metrics = {
    "loss": Loss(loss),
    "associative_ratio": Loss(associative_ratio),
    "guessed_ratio": Loss(guessed_ratio),
}
dae = MagmaDAE(
    cardinality=cardinality,
    hidden_dims=HIDDEN_LAYER_COUNT * [cardinality ** 3],
    dropout_rate=dropout_rate,
)
learning_pipeline(params, dae, loss, metrics, data_loaders)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))