<a href="https://colab.research.google.com/github/jonbaer/googlecolab/blob/master/Keras_30_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install lightning einops --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.9/776.9 kB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch, torchvision
from torch import nn

In [None]:
from torchvision.transforms import v2 as T

In [None]:
import pytorch_lightning as ptl

In [None]:
import torchmetrics

In [None]:
from einops.layers.torch import Rearrange

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
# Training-time preprocessing: a small random rotation for augmentation,
# followed by conversion to a float32 image tensor (scale=True rescales the
# integer pixel range to floats).
train_transforms = T.Compose(
    [
        T.RandomRotation(degrees=10),
        T.ToImage(),
        T.ToDtype(dtype=torch.float32, scale=True),
    ]
)

In [None]:
# Evaluation-time preprocessing: no augmentation, only conversion to a
# float32 image tensor (scale=True rescales the integer pixel range to floats).
test_transforms = T.Compose(
    [
        T.ToImage(),
        T.ToDtype(dtype=torch.float32, scale=True),
    ]
)

In [None]:
# CIFAR-10 splits. `train` is passed explicitly for both datasets: it defaults
# to True, so leaving it out for the evaluation set would silently load the
# 50k training images a second time and validation would run on training data.
train_data = torchvision.datasets.CIFAR10('/cifar10/', train=True, download=True, transform=train_transforms)
test_data = torchvision.datasets.CIFAR10('/cifar10/', train=False, download=True, transform=test_transforms)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 38868204.47it/s]


Extracting /cifar10/cifar-10-python.tar.gz to /cifar10/
Files already downloaded and verified


In [None]:
# Mini-batch iterators: the training set is reshuffled every epoch, the
# evaluation set is read in a fixed order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=32,
    shuffle=False,
)

In [None]:
from einops import rearrange

In [None]:
class InputEmbedding(nn.Module):
    """Convert an image batch into a sequence of embedded patch tokens.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches,
    each patch is flattened and linearly projected to ``embedding_dim``, a
    learned class token is prepended, and learned positional embeddings are
    added.

    Args:
        patch_size: Side length of each square patch, in pixels.
        input_shape: Side length of the (square) input image, in pixels.
            Used only to size the positional embeddings.
        embedding_dim: Width of every output token.
        in_channels: Number of image channels (3 for RGB). Previously this
            was hard-coded to 3 in the projection layer.

    Forward input:  ``(batch, in_channels, input_shape, input_shape)``.
    Forward output: ``(batch, num_patches + 1, embedding_dim)`` where index 0
    along the sequence axis is the class token.
    """

    def __init__(self, patch_size=4, input_shape=32, embedding_dim=32, in_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        self.in_channels = in_channels
        patches_per_side = input_shape // patch_size
        num_patches = patches_per_side * patches_per_side
        # (b, c, H, W) -> (b, num_patches, patch_size * patch_size * c)
        self.patchifier = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
        self.proj = nn.Linear(patch_size * patch_size * in_channels, embedding_dim)
        # One positional vector per patch plus one for the class token.
        self.pos_embeddings = nn.Parameter(torch.randn((1, num_patches + 1, embedding_dim)))
        self.cls_token = nn.Parameter(torch.randn((1, 1, embedding_dim)))

    def forward(self, x):
        x = self.proj(self.patchifier(x))
        # expand() broadcasts the single class token over the batch without
        # materialising a copy per sample; the concat below copies it anyway.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.concat([cls_token, x], dim=1)
        return x + self.pos_embeddings

In [None]:
class RecurrentMHA(nn.Module):
  """Attention layer that walks the token sequence one position at a time.

  A learned set of ``n_latents`` query vectors is carried as recurrent state.
  At each sequence position the state attends to that single position's
  projected key/value, the attention output becomes the new state, and the
  first latent of the state is written to the output at that position.

  Args:
    num_heads: Number of heads for the inner ``nn.MultiheadAttention``.
    embedding_dim: Token width; input and output share it.
    n_latents: Number of learned latent query vectors.

  Forward input and output both have shape ``(batch, seq_len, embedding_dim)``.

  NOTE(review): every attention call sees a key/value sequence of length 1,
  so the softmax over keys can only produce a weight of 1. That would make
  the result independent of the query state and of the keys -- confirm
  whether this is intended.
  """

  def __init__(self, num_heads=4, embedding_dim=32, n_latents=8):
    super().__init__()
    self.Q = nn.Parameter(torch.randn((1, n_latents, embedding_dim)))
    self.keys = nn.Linear(embedding_dim, embedding_dim)
    self.values = nn.Linear(embedding_dim, embedding_dim)
    self.mha = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)

  def forward(self, x):
    batch_size, seq_len = x.shape[0], x.shape[1]
    projected_keys = self.keys(x)
    projected_values = self.values(x)
    # Recurrent state: one copy of the learned latents per batch element.
    state = self.Q.repeat(batch_size, 1, 1)
    # Output buffer with the same shape, dtype and device as the input.
    outputs = torch.zeros_like(x)
    for step in range(seq_len):
      window = slice(step, step + 1)  # keeps the sequence axis (length 1)
      state, _ = self.mha(
          state,
          projected_keys[:, window],
          projected_values[:, window],
          need_weights=False,
      )
      outputs[:, step] = state[:, 0]
    return outputs

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention sub-layer, then MLP sub-layer.

    Each sub-layer is applied to a LayerNorm of its input and added back to
    that input as a residual:

        h   = mha(ln1(x)) + x
        out = mlp(ln2(h)) + h

    Args:
        embedding_dim: Token width; also forwarded to the attention layer so
            the block works for widths other than the default.

    Forward input and output both have shape ``(batch, seq_len, embedding_dim)``.
    """

    def __init__(self, embedding_dim=32):
        super().__init__()
        self.mha = RecurrentMHA(embedding_dim=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 3),
            nn.GELU(),
            nn.Linear(embedding_dim * 3, embedding_dim),
        )
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        h = self.mha(self.ln1(x)) + x
        # MLP sub-layer operates on the attention result (not on the raw
        # input) and uses its own LayerNorm. The result is returned;
        # returning `x` here would turn the whole block into an identity.
        return self.mlp(self.ln2(h)) + h

In [None]:
class VisionTransformer(nn.Module):
  """Small ViT-style classifier: patch embedding, three blocks, linear head.

  Args:
    embedding_dim: Token width. It is forwarded to the embedding layer and to
      every block so that all sub-modules agree with the classifier's input
      width (previously only the classifier received it).
    num_classes: Number of output logits.

  Forward input is an image batch accepted by ``InputEmbedding``; the output
  has shape ``(batch, num_classes)``.
  """

  def __init__(self, embedding_dim=32, num_classes=10):
    super().__init__()
    self.inp_embedding = InputEmbedding(embedding_dim=embedding_dim)
    self.block1 = TransformerBlock(embedding_dim=embedding_dim)
    self.block2 = TransformerBlock(embedding_dim=embedding_dim)
    self.block3 = TransformerBlock(embedding_dim=embedding_dim)
    self.classifier = nn.Linear(embedding_dim, num_classes)

  def forward(self, x):
    x = self.inp_embedding(x)
    for block in (self.block1, self.block2, self.block3):
      x = block(x)
    # Classify from the token at sequence index 0 (the class token that the
    # embedding layer prepends).
    return self.classifier(x[:, 0, :])

In [None]:
# NOTE(review): `vit` is not referenced again in the visible part of this
# notebook (training uses a separate `cnn` instance) -- confirm whether this
# instance is still needed.
vit = VisionTransformer()

In [None]:
class PTLWrapper(ptl.LightningModule):
    """Lightning wrapper that trains and evaluates a 10-class image classifier.

    Args:
        cnn: The model to train. It is called on the image batch and its
            output is fed to cross-entropy, so it is expected to return one
            logit per class.

    Training logs ``train_loss`` to the progress bar on each step; validation
    accumulates multiclass accuracy and logs ``val_accuracy`` once per epoch.
    Neither value is sent to the experiment logger.
    """

    def __init__(self, cnn):
        super().__init__()
        self.cnn = cnn
        self.acc = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.loss = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one training batch."""
        images, labels = batch
        logits = self.cnn(images)
        batch_loss = self.loss(logits, labels)
        self.log("train_loss", batch_loss, prog_bar=True, logger=False)
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Update the accuracy metric with one validation batch."""
        images, labels = batch
        logits = self.cnn(images)
        self.acc(logits, labels)
        # Passing the metric object lets Lightning aggregate it over the epoch.
        self.log(
            'val_accuracy',
            self.acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=False,
        )

    def configure_optimizers(self):
        """Optimise the wrapped model's parameters with Adam (default settings)."""
        return torch.optim.Adam(self.cnn.parameters())

In [None]:
# Build the model and hand it to the Lightning wrapper.
# To try a compiled model instead, use: cnn = torch.compile(VisionTransformer())
cnn = VisionTransformer()
ptl_wrapper = PTLWrapper(cnn=cnn)

In [None]:
# 'auto' selects the GPU when one is present and falls back to another
# available accelerator (e.g. CPU) otherwise, so the notebook still runs on a
# runtime without a GPU instead of raising at Trainer construction.
trainer = ptl.Trainer(accelerator='auto', devices=1, max_epochs=5)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Run training; validation runs on `test_loader`.
trainer.fit(
    model=ptl_wrapper,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type               | Params
--------------------------------------------
0 | cnn  | VisionTransformer  | 43.0 K
1 | acc  | MulticlassAccuracy | 0     
2 | loss | CrossEntropyLoss   | 0     
--------------------------------------------
43.0 K    Trainable params
0         Non-trainable params
43.0 K    Total params
0.172     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
