In [1]:
"""
   Copyright 2021 Žarko Bulić, Boris Shminke

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
"""
# uncomment the `pip install` line below if you have just uploaded this
# notebook to Google Colaboratory; a GPU runtime is recommended
# (TPU runtimes are not supported by the package yet)

# !pip install neural-semigroups

'\n   Copyright 2021 Žarko Bulić, Boris Shminke\n\n   Licensed under the Apache License, Version 2.0 (the "License");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an "AS IS" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n'

In [2]:
# this is a simple example for semigroups of n = 4 elements
from neural_semigroups.datasets import RandomDataset

# number of elements (n) and mini-batch size used by the cells below
CARDINALITY, BATCH_SIZE = 4, 32
# the second argument is the same [n, n, n] shape repeated twice
# (presumably 2000 is the number of random samples -- confirm in the package docs)
data = RandomDataset(2000, tuple([[CARDINALITY] * 3] * 2))

In [3]:
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader

# all available data is split into three subsets: for training, for validation
# after each epoch and for testing the final model; the first subset also
# absorbs the remainder left by the integer division
test_size = len(data) // 3
split_sizes = [len(data) - 2 * test_size, test_size, test_size]
data_loaders = tuple(
    DataLoader(subset, batch_size=BATCH_SIZE)
    for subset in random_split(data, split_sizes)
)

In [4]:
from neural_semigroups.associator_loss import AssociatorLoss
from torch import Tensor
import torch
from neural_semigroups.constants import CURRENT_DEVICE

# reference rows for `loss`: zeros everywhere except the last CARDINALITY
# positions, which hold ones; the same row is repeated BATCH_SIZE times
# NOTE(review): a cube that is one-hot along its last axis contains
# CARDINALITY ** 2 ones, while this row has only CARDINALITY of them --
# confirm that this is intended
discrete = (
    torch.cat(
        [
            torch.zeros(CARDINALITY ** 3 - CARDINALITY),
            torch.ones(CARDINALITY),
        ]
    )
    .repeat(BATCH_SIZE, 1)
    .to(CURRENT_DEVICE)
)
# weight of the associator term in `loss`; the cross-entropy term gets 1 - ALPHA
ALPHA = 0.5

def loss(prediction: Tensor, target: Tensor) -> Tensor:
    """
    Weighted sum of the associator loss and a cross-entropy term computed
    on the sorted prediction values.

    :param prediction: a batch of model outputs; everything after the first
        (batch) dimension is flattened before sorting
    :param target: not used; present so that the function has the
        ``(prediction, target)`` signature of a loss criterion
    :returns: ``ALPHA * AssociatorLoss()(prediction)`` plus ``1 - ALPHA``
        times the summed cross-entropy between the sorted values and
        the reference row taken from ``discrete``
    """
    batch_size = prediction.shape[0]
    # ascending sort of every sample's flattened values; ``reshape`` (unlike
    # ``view``) also accepts non-contiguous predictions
    sorted_tensor, _ = torch.sort(prediction.reshape(batch_size, -1), dim=1)
    # all rows of ``discrete`` are identical, so broadcast one of them to the
    # actual batch size; slicing the BATCH_SIZE-row tensor would fail with a
    # shape mismatch for any batch that has more rows than BATCH_SIZE
    reference = discrete[0].expand(batch_size, -1)
    associator_term = AssociatorLoss()(prediction)
    cross_entropy_term = torch.nn.CrossEntropyLoss(reduction="sum")(
        sorted_tensor, reference
    )
    return ALPHA * associator_term + (1 - ALPHA) * cross_entropy_term

In [5]:
from neural_semigroups import MagmaDAE

# the model trained below: two hidden dimensions of CARDINALITY ** 3 each
dae = MagmaDAE(
    cardinality=CARDINALITY,
    hidden_dims=[CARDINALITY ** 3] * 2,
    do_reparametrization=True,
)

In [6]:
# load the TensorBoard notebook extension, which provides the %tensorboard magic
%load_ext tensorboard

In [7]:
# show TensorBoard inside the notebook, reading its logs from the `runs` directory
# (presumably the directory the training cell writes to -- confirm)
%tensorboard --logdir runs

In [8]:
# remove the `runs` directory (the TensorBoard log directory) so that logs of
# earlier runs are gone; done in pure Python instead of `!rm -rf`, so it also
# works where no POSIX shell is available, and a missing directory is ignored
import shutil

shutil.rmtree("runs", ignore_errors=True)

In [9]:
from neural_semigroups.training_helpers import learning_pipeline
from ignite.metrics.loss import Loss
from neural_semigroups.training_helpers import associative_ratio

# training hyper-parameters handed to `learning_pipeline`
params = dict(learning_rate=0.001, epochs=20)
# values tracked during training; both functions are wrapped in ignite's `Loss`
metrics = dict(
    loss=Loss(loss),
    associative_ratio=Loss(associative_ratio),
)
learning_pipeline(params, dae, loss, metrics, data_loaders)

[1/20]   5%|5          [00:00<?]

In [10]:
from neural_semigroups.utils import make_discrete
from torch.functional import einsum
from tqdm.notebook import tqdm
from collections import Counter

# pass a large sample of random inputs through the trained model and count
# the distinct tables among those discretised outputs that pass the check below
data_loader = DataLoader(
    RandomDataset(2000000, tuple(2 * [3 * [CARDINALITY]])),
    2 ** 12
)
counter = Counter()
# gradients are never needed here, so autograd must not record the forward
# passes (this also makes a separate `.detach()` unnecessary)
with torch.no_grad():
    for batch in tqdm(data_loader):
        cubes = make_discrete(dae(batch[0]))
        # two different contractions of a cube with itself
        # (presumably the two bracketings (xy)z and x(yz) -- confirm)
        one = einsum("biml,bjkm->bijkl", cubes, cubes)
        two = einsum("bmkl,bijm->bijkl", cubes, cubes)
        # keep a sample only if both contractions agree in every entry
        is_associative = (one == two).flatten(start_dim=1).all(dim=1)
        # position of the largest entry along the last axis: one 2-D table per sample
        associatives = cubes[is_associative].argmax(dim=3)
        # key of a table: its entries, row by row, written as one string of digits
        counter.update(
            "".join(map(str, table))
            for table in associatives.flatten(start_dim=1).tolist()
        )
print(len(counter))

  0%|          | 0/489 [00:00<?, ?it/s]

345


In [11]:
# textual form of the whole counter: table string -> number of occurrences
repr(counter)

"Counter({'3313333333333333': 265, '0000012002000300': 152, '0000010002000200': 126, '0000012002000000': 123, '0000010000200000': 114, '0000012000000300': 107, '3313333313333333': 101, '0000012002200000': 100, '0000010000000000': 85, '0000012000000000': 60, '0323313323333333': 48, '0333333323333333': 47, '0333313323333333': 46, '0323313333333333': 45, '0323333323333333': 43, '0000012002200300': 42, '0000010000300000': 42, '0100111101230133': 39, '1001011001231031': 38, '0000012002200330': 37, '0123111121113113': 33, '0000010000000300': 33, '0000012302300300': 33, '0000010000200300': 33, '0333333333333333': 32, '0000010002000000': 31, '0323333333333333': 29, '0333323333333333': 24, '0000000000000000': 24, '0000010002000300': 22, '0000012002100000': 21, '0000000000000200': 19, '0100011101230133': 18, '0323133323333333': 16, '2022012322222022': 16, '0323312322323323': 16, '0323313332333333': 16, '0333312323333333': 15, '0000012302000300': 15, '0333313333333333': 15, '0121111111111111': 15