In [97]:
import datasets
import torch
import torch.nn.functional as F
import torch.utils.data as data
import pytorch_lightning as pl
import numpy as np

def encode(batch):
    # np.array() because torch.tensor() doesn't support PIL.Image
    batch["image"] = [torch.tensor(np.array(img), dtype=torch.float32) / 255 for img in batch["image"]]
    return batch

mnist = datasets.load_dataset("mnist").with_transform(encode)

Found cached dataset mnist (/Users/cztomsik/.cache/huggingface/datasets/mnist/mnist/1.0.0/fda16c03c4ecfb13f165ba7e29cf38129ce035011519968cdaf74894ce91c9d4)
100%|██████████| 2/2 [00:00<00:00, 720.79it/s]


In [121]:
class Model(pl.LightningModule):
    def __init__(self, hidden_dim = 256):
        super().__init__()
        self.inner = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 10),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.inner(x)

    def training_step(self, batch, batch_idx):
        logits = self(batch['image'])
        return F.cross_entropy(logits, batch['label'])

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

model = Model()
trainer = pl.Trainer(max_epochs=10, accelerator='cpu')
trainer.fit(model, data.DataLoader(mnist['train'], batch_size=200), data.DataLoader(mnist['train'], batch_size=200))

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | inner | Sequential | 203 K 
-------------------------------------
203 K     Trainable params
0         Non-trainable params
203 K     Total params
0.814     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 300/300 [00:04<00:00, 68.35it/s, loss=0.0265, v_num=84]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 300/300 [00:04<00:00, 68.30it/s, loss=0.0265, v_num=84]


In [122]:
for i in range(10):
    print(torch.argmax(model.forward(mnist["test"][i]["image"].view(1, -1))), mnist["test"][i]["label"])

tensor(7) 7
tensor(2) 2
tensor(1) 1
tensor(0) 0
tensor(4) 4
tensor(1) 1
tensor(4) 4
tensor(9) 9
tensor(6) 5
tensor(9) 9
