In [None]:
import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import pytorch_lightning as pl
import numpy as np

pl.seed_everything(42)

def encode(batch):
    batch["img"] = [(torch.tensor(np.array(img), dtype=torch.float32).transpose(0, 2) / 255) - 0.5 for img in batch["img"]]
    return batch

cifar = datasets.load_dataset("cifar10").with_transform(encode)


In [2]:
cifar["train"][0]["img"].shape

torch.Size([3, 32, 32])

In [3]:
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def training_step(self, batch, batch_idx):
        return F.cross_entropy(self(batch["img"]), batch["label"])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = Model()
trainer = pl.Trainer(max_epochs=10, accelerator="cpu")
trainer.fit(model, data.DataLoader(cifar["train"], batch_size=200), data.DataLoader(cifar["test"], batch_size=200))

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

  | Name  | Type      | Params
------------------------------------
0 | conv1 | Conv2d    | 456   
1 | pool  | MaxPool2d | 0     
2 | conv2 | Conv2d    | 2.4 K 
3 | fc1   | Linear    | 48.1 K
4 | fc2   | Linear    | 10.2 K
5 | fc3   | Linear    | 850   
------------------------------------
62.0 K    Trainable params
0         Non-trainable params
62.0 K    Total params
0.248     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 9: 100%|██████████| 250/250 [00:13<00:00, 18.06it/s, loss=1.04, v_num=160]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 250/250 [00:13<00:00, 18.05it/s, loss=1.04, v_num=160]


In [4]:
for i in range(20):
    print(cifar["test"][i]["label"], model(cifar["test"][i]["img"].unsqueeze(0)).argmax())

3 tensor(3)
8 tensor(1)
8 tensor(1)
0 tensor(8)
6 tensor(6)
6 tensor(6)
1 tensor(1)
6 tensor(6)
3 tensor(3)
1 tensor(1)
0 tensor(4)
9 tensor(9)
5 tensor(6)
7 tensor(7)
9 tensor(9)
8 tensor(1)
5 tensor(5)
7 tensor(5)
8 tensor(8)
6 tensor(6)
