In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

  Referenced from: <F0D48035-EF9E-3141-9F63-566920E60D7C> /Users/cztomsik/miniconda3/lib/python3.10/site-packages/torchvision/image.so
  Expected in:     <07E453B6-4998-32DD-94DC-FA4A3B20022C> /Users/cztomsik/miniconda3/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [5]:
class Model(pl.LightningModule):
    def __init__(self, hidden_dim = 256):
        super().__init__()
        self.inner = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 10),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.inner(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def predict_step(self, batch, batch_idx, dataloader_idx: int = 0):
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

model = Model()
trainer = pl.Trainer(max_epochs=10, accelerator='cpu')
trainer.fit(model, DataLoader(train, batch_size=200), DataLoader(val, batch_size=200))

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | inner | Sequential | 203 K 
-------------------------------------
203 K     Trainable params
0         Non-trainable params
203 K     Total params
0.814     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 275/275 [00:02<00:00, 134.79it/s, loss=0.247, v_num=64]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 275/275 [00:02<00:00, 134.58it/s, loss=0.247, v_num=64]


In [7]:
for i in range(10):
    print(torch.argmax(model(val[i][0])), val[i][1])

tensor(3) 3
tensor(3) 3
tensor(8) 8
tensor(3) 3
tensor(9) 9
tensor(6) 0
tensor(1) 1
tensor(2) 2
tensor(9) 9
tensor(2) 2
