In [101]:
import logging
import torch
import pytorch_lightning
import random

# reduce log noise
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

class OneToManyAssociator(pytorch_lightning.LightningModule):
    def __init__(self, io_size: int, hidden_size: int,
                 data_loader: torch.utils.data.DataLoader,
                 learning_rate:float=1e-3):
        super().__init__()
        # the name of this attribute is important to work with
        # pytorch_lightning.Trainer(auto_lr_find=True)
        self.learning_rate = learning_rate
        self.data_loader = data_loader
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(io_size, hidden_size, dtype=torch.double),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, io_size, dtype=torch.double),
        )
        self.n_zero = 0
        self.first_zero = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conceptually we want the network to output only 0s or 1s.
        # So something like:
        # torch.ceil(torch.clamp(self.layers(x), min=0, max=1))
        # But this results in a failure to make progress during training.
        # Not sure why. For now handle it in training_epoch_end.
        return self.layers(x) + torch.tensor(0.5)
    
    def training_step(self, batch, batch_idx):
        x, ys = batch
        y0, y1 = ys
        y = random.choice((y0, y1))
        
        # enforce batch size == 1 for equivalence to Leabra model
        if x.size()[0] != 1 or y.size()[0] != 1:
            raise ValueError("expected batch size == 1, got", x.size()[0])
        x, y = torch.squeeze(x, 0), torch.squeeze(y, 0)
        preds = self(x)
        loss = torch.pow(torch.nn.functional.l1_loss(preds, y), 1/4)
        self.log("loss", loss, on_step=False, on_epoch=True)
        return loss

    def training_epoch_end(self, outputs) -> None:
        max_loss = max(output["loss"] for output in outputs)
        # Because the output is not forced to be 0 or 1, the
        # loss will never be zero, so we have a threshold.
        # This is arbitrary. TODO: pick something to make it a fair
        # comparison with Leabra.
        if max_loss < 0.05:
          self.n_zero += 1
          if self.first_zero is None:
            self.first_zero = self.current_epoch
        self.log("n_zero", float(self.n_zero))
        
    def test_step(self, batch, batch_idx):
        x, ys = batch
        y0, y1 = ys
        
        # enforce batch size == 1 for equivalence to Leabra model
        if x.size()[0] != 1 or y0.size()[0] != 1:
            raise ValueError("expected batch size == 1, got", x.size()[0])
        x, y0, y1 = torch.squeeze(x, 0), torch.squeeze(y0, 0), torch.squeeze(y1, 0)
        preds = self(x)
        loss0, loss1 = torch.nn.functional.l1_loss(preds, y0), torch.nn.functional.l1_loss(preds, y1)
        loss = torch.min(loss0, loss1)
        self.log("loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        return self.data_loader
    
    def test_dataloader(self):
        return self.data_loader

In [102]:
import numpy as np


class OneToManyDataset(torch.utils.data.Dataset):
    @staticmethod
    def random_datum(
        rng: np.random.Generator, num_nonzero: int, size: int
    ) -> torch.Tensor:
        nonzero_idx = rng.choice(size, num_nonzero)
        ret = torch.zeros(size, dtype=torch.double)
        ret[nonzero_idx] = 1
        return ret

    def __init__(self, num_nonzero: int, datum_size: int, size: int):
        super().__init__()
        rng = np.random.default_rng()
        def rd():
            return self.random_datum(rng, num_nonzero, datum_size)
        self.xs = [
            rd() for _ in range(size)
        ]
        self.ys = [
            (rd(), rd()) for _ in range(size)
        ]

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return self.xs[idx], self.ys[idx]


In [103]:
# https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L26
datum_size = 6 * 6
dataset = OneToManyDataset(
    # https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L27
    num_nonzero=6,
    datum_size=datum_size,
    # https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L28
    size=30)

dataset[0]

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        dtype=torch.float64),
 (tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         dtype=torch.float64),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.],
         dtype=torch.float64)))

In [104]:
import sys

early_stopping = pytorch_lightning.callbacks.EarlyStopping(
    "n_zero",
    patience=sys.maxsize, # effectively infinite
    mode="max",
    # https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25.go#L188
    stopping_threshold=5
)

In [108]:
trainer = pytorch_lightning.Trainer(
    #accelerator='gpu', devices=1,
    auto_lr_find=True,
    # https://github.com/Astera-org/models/blob/0de0c8005cdc57c28b2c663c89b3741508d013d2/mechs/ra25/ra25_test.go#L42
    max_epochs=400,
    callbacks=[early_stopping])
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
model = OneToManyAssociator(datum_size, hidden_size=64, data_loader=data_loader)

trainer.fit(model)

print("first_zero", model.first_zero)
print("last_zero", early_stopping.stopped_epoch)

Training: 0it [00:00, ?it/s]

first_zero None
last_zero 0


In [109]:
trainer.test(model)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          loss             0.056163542822281455
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'loss': 0.056163542822281455}]

In [110]:
onesample = dataset[10]
pred = model(onesample[0])
print(pred)
print("")
print(torch.where(pred > 0.2, 1.0, 0.0))
print("and\n", torch.logical_and(*onesample[1]).to(torch.float))
print(onesample[1][0], "\n", onesample[1][1])


tensor([-2.2628e-03, -1.1025e-03,  1.2184e-02, -2.2784e-02,  4.1489e-03,
         6.7394e-03,  1.8130e-02, -2.5453e-02,  8.7493e-03, -1.6770e-02,
         1.6138e-03, -5.9646e-03, -2.3347e-02,  9.4242e-01,  3.4339e-02,
        -1.6437e-02, -7.6369e-03,  9.5718e-01, -6.3848e-03,  1.1124e-02,
        -2.4374e-02,  1.4364e-03,  3.1065e-04,  9.6909e-01, -1.0165e-02,
         1.4589e-02,  8.6016e-03,  1.6176e-03,  9.3904e-01,  2.6515e-03,
         9.7517e-01,  1.7261e-02,  9.4192e-01,  3.7791e-02, -8.7301e-03,
        -1.1589e-02], dtype=torch.float64, grad_fn=<AddBackward0>)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0.])
and
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.