In [19]:
import logging
import torch
import pytorch_lightning

# reduce log noise
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

class RandomAssociator(pytorch_lightning.LightningModule):
    """Small fully-connected network trained to map input patterns to targets.

    The network is ``Linear(io_size, hidden_size) -> ReLU ->
    Linear(hidden_size, io_size)`` in double precision. Besides training, the
    module tracks how many epochs ended with every per-step loss below a
    fixed threshold (``n_zero``) and the first epoch at which that happened
    (``first_zero``).

    Args:
        io_size: Width of both the input and the output vectors.
        hidden_size: Width of the single hidden layer.
        data_loader: Loader returned by ``train_dataloader``; it must yield
            ``(x, y)`` batches whose leading (batch) dimension is 1.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, io_size: int, hidden_size: int,
                 data_loader: torch.utils.data.DataLoader,
                 learning_rate: float = 1e-3):
        super().__init__()
        # the name of this attribute is important to work with
        # pytorch_lightning.Trainer(auto_lr_find=True)
        self.learning_rate = learning_rate
        self.data_loader = data_loader
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(io_size, hidden_size, dtype=torch.double),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, io_size, dtype=torch.double),
        )
        # Number of epochs so far whose worst per-step loss was under the
        # threshold used in training_epoch_end.
        self.n_zero = 0
        # Epoch index of the first such epoch, or None if there was none yet.
        self.first_zero = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (unclamped, unrounded) output of the network."""
        # Conceptually we want the network to output only 0s or 1s.
        # So something like:
        # torch.ceil(torch.clamp(self.layers(x), min=0, max=1))
        # But this results in a failure to make progress during training.
        # Not sure why. For now handle it in training_epoch_end.
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the L1 loss for a single ``(x, y)`` pair.

        Raises:
            ValueError: If the leading dimension of ``x`` or ``y`` is not 1.
        """
        x, y = batch
        # enforce batch size == 1 for equivalence to Leabra model
        if x.size(0) != 1 or y.size(0) != 1:
            # Build one readable message and report both sizes, since either
            # tensor can be the one that violates the constraint.
            raise ValueError(
                f"expected batch size == 1, got x: {x.size(0)}, y: {y.size(0)}"
            )
        x, y = torch.squeeze(x, 0), torch.squeeze(y, 0)
        preds = self(x)
        loss = torch.nn.functional.l1_loss(preds, y)
        self.log("loss", loss, on_step=False, on_epoch=True)
        return loss

    def training_epoch_end(self, outputs) -> None:
        """Update and log the ``n_zero`` / ``first_zero`` bookkeeping.

        Args:
            outputs: Per-step results for the epoch; each item is indexed
                with ``"loss"``.
        """
        max_loss = max(output["loss"] for output in outputs)
        # Because the output is not forced to be 0 or 1, the
        # loss will never be zero, so we have a threshold.
        # This is arbitrary. TODO: pick something to make it a fair
        # comparison with Leabra.
        # NOTE(review): n_zero is cumulative here -- it is never reset when an
        # epoch misses the threshold. Confirm whether the Leabra comparison
        # expects *consecutive* zero-error epochs instead.
        if max_loss < 0.05:
            self.n_zero += 1
            if self.first_zero is None:
                self.first_zero = self.current_epoch
        self.log("n_zero", float(self.n_zero))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        """Return the data loader supplied at construction time."""
        return self.data_loader

In [20]:
import numpy as np


class RandomAssociationDataset(torch.utils.data.Dataset):
    """In-memory dataset of random binary (input, target) pattern pairs.

    Every input and every target is a 1-D double tensor of length
    ``datum_size`` containing exactly ``num_nonzero`` ones and zeros
    elsewhere. All patterns are generated once in ``__init__``.

    Args:
        num_nonzero: Number of entries set to 1 in each pattern.
        datum_size: Length of each pattern.
        size: Number of (input, target) pairs.
        seed: Optional seed for ``numpy.random.default_rng``. ``None`` (the
            default) draws fresh entropy, so patterns differ on every run.
    """

    @staticmethod
    def random_datum(
        rng: np.random.Generator, num_nonzero: int, size: int
    ) -> torch.Tensor:
        """Return a length-``size`` tensor with ``num_nonzero`` distinct ones.

        Raises:
            ValueError: Propagated from NumPy if ``num_nonzero > size``.
        """
        # replace=False is essential: NumPy samples *with* replacement by
        # default, so duplicate indices would silently produce patterns with
        # fewer than num_nonzero active entries.
        nonzero_idx = rng.choice(size, num_nonzero, replace=False)
        ret = torch.zeros(size, dtype=torch.double)
        ret[nonzero_idx] = 1
        return ret

    def __init__(self, num_nonzero: int, datum_size: int, size: int, seed=None):
        super().__init__()
        # A single generator feeds both lists so one seed fixes the whole
        # dataset.
        rng = np.random.default_rng(seed)
        self.xs = [
            self.random_datum(rng, num_nonzero, datum_size) for _ in range(size)
        ]
        self.ys = [
            self.random_datum(rng, num_nonzero, datum_size) for _ in range(size)
        ]

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.xs)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at position ``idx``."""
        return self.xs[idx], self.ys[idx]


In [None]:
# Dataset dimensions follow the linked reference implementation (ra25.go).
# Pattern length:
# https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L26
datum_size = 6 * 6

# Active entries per pattern (num_nonzero):
# https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L27
# Number of pattern pairs (size):
# https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L28
dataset = RandomAssociationDataset(num_nonzero=6, datum_size=datum_size, size=30)

# Show the first (input, target) pair as a quick sanity check.
dataset[0]

In [None]:
import sys

# Stop training based on the "n_zero" metric logged by the model rather than
# on a lack of improvement: patience is the largest int (effectively
# infinite), so only the threshold below is meant to end the run.
# Threshold taken from the reference implementation:
# https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L188
# NOTE(review): for mode="max" the threshold comparison appears to be a
# strict ">" -- confirm whether training should stop at n_zero == 5 or only
# once it exceeds 5.
early_stopping = pytorch_lightning.callbacks.EarlyStopping(
    monitor="n_zero",
    mode="max",
    stopping_threshold=5,
    patience=sys.maxsize,
)

In [None]:
trainer = pytorch_lightning.Trainer(
    auto_lr_find=True,
    # https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25_test.go#L42
    max_epochs=100,
    callbacks=[early_stopping])
# batch_size=1 is required: RandomAssociator.training_step rejects anything else.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
model = RandomAssociator(datum_size, hidden_size=64, data_loader=data_loader)

# auto_lr_find=True only configures the learning-rate finder; it runs when
# trainer.tune() is called, not inside trainer.fit(). Without this call the
# model would silently train with its default learning_rate.
trainer.tune(model)

# The LR sweep drives the model through its normal training hooks, and
# n_zero / first_zero are plain Python attributes rather than checkpointed
# state, so clear them to make sure the real run starts from a clean count.
model.n_zero = 0
model.first_zero = None

trainer.fit(model)

print("first_zero", model.first_zero)
# stopped_epoch is the epoch at which EarlyStopping ended training.
print("last_zero", early_stopping.stopped_epoch)