In [19]:
import torch
import torch.nn.functional as F
import pytorch_lightning


class RandomAssociator(pytorch_lightning.LightningModule):
    def __init__(self, io_size: int, hidden_size: int, data_loader: torch.utils.data.DataLoader, learning_rate:float=1e-3):
        super().__init__()
        # the name of this attribute is important to work with
        # pytorch_lightning.Trianer(auto_lr_find=True)
        self.learning_rate = learning_rate
        self.data_loader = data_loader
        # TODO: datatypes don't make sense here.
        # need to figure out how to represent a bit vector in input and output
        # properly.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(io_size, hidden_size, dtype=torch.uint8),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, io_size, dtype=torch.uint8),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
    
    def training_step(self, batch, unused_batch_idx):
        x, y = batch
        # enforce batch size == 1 for equivalence to Leabra model
        if x.size()[0] != 1 or y.size()[0] != 1:
            raise ValueError("expected batch size == 1, got", x.size()[0])
        x, y = torch.squeeze(x, 0), torch.squeeze(y, 0)
        return F.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        return self.data_loader

In [20]:
import numpy as np


class RandomAssociationDataset(torch.utils.data.Dataset):
    @staticmethod
    def random_datum(
        rng: np.random.Generator, num_nonzero: int, size: int
    ) -> torch.Tensor:
        nonzero_idx = rng.choice(size, num_nonzero)
        ret = torch.zeros(size, dtype=torch.uint8)
        ret[nonzero_idx] = 1
        return ret

    def __init__(self, num_nonzero: int, datum_size: int, size: int):
        super().__init__()
        rng = np.random.default_rng()
        self.xs = [
            self.random_datum(rng, num_nonzero, datum_size) for _ in range(size)
        ]
        self.ys = [
            self.random_datum(rng, num_nonzero, datum_size) for _ in range(size)
        ]

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return self.xs[idx], self.ys[idx]


In [21]:
datum_size = 25
dataset = RandomAssociationDataset(num_nonzero=6, datum_size=datum_size, size=100)

trainer = pytorch_lightning.Trainer(auto_lr_find=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
model = RandomAssociator(datum_size, hidden_size=64, data_loader=data_loader)
# find the learning rate
trainer.tune(model)

trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(


RuntimeError: expected scalar type Byte but found Float