In [None]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    #target_transform=ToTensor()
)

In [None]:
print(ds[0][0].shape)

In [1]:
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import cv2
import torchvision.transforms as transforms

def patchify(images: torch.Tensor, patch_dimension: int | tuple) -> torch.Tensor:
    N, c, h, w = images.shape

    unfold = torch.nn.Unfold(patch_dimension, stride=patch_dimension)

    return unfold(images).view(N, c, patch_dimension[0], patch_dimension[1], -1).permute(0, 4, 1, 2, 3)

#Test the patchify function
def plot_patches(tensor: torch.Tensor, patch_count: tuple[int, int]):
    fig = plt.figure(figsize=(8, 8))
    grid = ImageGrid(fig, 111, nrows_ncols=patch_count, axes_pad=0.1)

    for i, ax in enumerate(grid):
        patch = tensor[i].permute(1, 2, 0).numpy()
        patch = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
        ax.imshow(patch)
        ax.axis('off')

    plt.show()

In [None]:
square_image = cv2.imread('data/test_patching/Squared_Papyrus_van_Ipoewer_600.jpg', cv2.COLOR_BGR2RGB)
square_image_tensor = transforms.ToTensor()(square_image)
square_image_tensor = square_image_tensor.unsqueeze(0)

#Shape is supposed to be N(batch), C, H, W
print(f"Dimension of image: {square_image_tensor.shape}")

square_patch_dimension = (50, 100)
square_patch_count = (12, 6)
square_as_patches = patchify(square_image_tensor, square_patch_dimension)
square_as_patches = square_as_patches.squeeze(0)
print(f"Size of patches: {square_as_patches.shape}")
plot_patches(square_as_patches, square_patch_count)

In [None]:
rectangle_image = cv2.imread('data/test_patching/Papyrus_van_Ipoewer_1280_924.jpg', cv2.COLOR_BGR2RGB)
rectangle_image_tensor = transforms.ToTensor()(rectangle_image)
rectangle_image_tensor = rectangle_image_tensor.unsqueeze(0)

rectangle_patch_dimension = (66, 64)
rectangle_patch_count = (14, 20)
rectangle_as_patches = patchify(rectangle_image_tensor, rectangle_patch_dimension)
print(rectangle_as_patches.shape)

rectangle_as_patches = rectangle_as_patches.squeeze(0)
plot_patches(rectangle_as_patches, rectangle_patch_count)

In [2]:
import torch.nn as nn
import numpy as np

#https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c

class VitEncoderBlock(nn.Module):
    mha: nn.MultiheadAttention
    n_heads: int
    hidden_d: int

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mha = nn.MultiheadAttention(hidden_d, n_heads, batch_first=True)

        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        x = self.norm1(x)
        out = x + self.mha(x, x, x)[0]

        out = self.norm2(out)
        out = out + self.mlp(out)

        return out

class BasicVit(nn.Module):
  #For classification
  n_classes: int

  #For patchification
  dimensions_chw: tuple
  patch_dimensions: tuple[int, int]
  unfolder: torch.nn.Unfold

  #Encoder-related
  hidden_d: int
  n_blocks: int
  n_heads: int
  blocks: nn.ModuleList

  #Classification
  mlp: nn.Sequential

  def patchify(self, images: torch.Tensor) -> torch.Tensor:
    """Patchify an image - either square or rectangle"""
    return self.unfolder(images).permute(0, 2, 1)
  
  def init_unfolder(self):
    """Init the unfolder kernel"""
    self.unfolder = torch.nn.Unfold(self.patch_dimensions, stride=self.patch_dimensions)

  def get_positional_embeddings(self, sequence_length: int, dimension: int) -> torch.Tensor:
      result = torch.ones(sequence_length, dimension)
      for i in range(sequence_length):
          for j in range(dimension):
              result[i][j] = np.sin(i / (10000 ** (j / dimension))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / dimension)))
      return result

  def __init__(self, dimensions_chw, patch_dimensions, hidden_dimension, n_classes,
              n_blocks = 2, n_heads = 2):
    # Super constructor
    super().__init__()

    self.hidden_d = hidden_dimension
    self.n_classes = n_classes
    self.n_blocks = n_blocks
    self.n_heads = n_heads

    # Channels, Height, Width
    self.dimensions_chw = dimensions_chw
    self.patch_dimensions = patch_dimensions

    #Init the unfolder for patchification
    self.init_unfolder()

    # The linear mapper has the flattened patch as input, so c*patch_dim
    c,h,w = self.dimensions_chw
    patch_col, patch_row = patch_dimensions
    self.input_d = int(c * patch_row * patch_col)
    self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

    # Learn v_class, the classification token
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

    batch_dimension = int((h/patch_col * w/patch_row) + 1)
    self.pos_embed = nn.Parameter(self.get_positional_embeddings(batch_dimension, self.hidden_d))
    self.pos_embed.requires_grad = False

    #Attention blocks
    self.blocks = nn.ModuleList([VitEncoderBlock(hidden_dimension, n_heads, 4) for i in range(n_blocks)])

    #Classification layer
    self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, n_classes),
            #nn.Softmax(dim=-1)
        )

  def forward(self, images):
    #Patchify the image batch and embed then through the linear layer
    patches = self.patchify(images)
    embeding_patch_vectors = self.linear_mapper(patches)

    # Adding classification token to the tokens
    embeding_patch_vectors = torch.stack([torch.vstack((self.class_token, embeding_patch_vectors[i])) for i in range(len(embeding_patch_vectors))])
    
    # Add positional encoding
    n = embeding_patch_vectors.shape[0]
    pos_embed = self.pos_embed.repeat(n, 1, 1)
    out = embeding_patch_vectors + pos_embed

    #Now we have our embeddings, we pass them through our blocks
    for block in self.blocks:
       out = block(out)
    
    return self.mlp(out[:, 0])

# Now train it !

In [3]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import DataLoader

transform = ToTensor()

train_set = MNIST(root='./../data', train=True, download=True, transform=transform, target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)))
test_set = MNIST(root='./../data', train=False, download=True, transform=transform, target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)))

train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

In [4]:
print(f"Lenth of train: {len(train_loader.dataset)}") 
print(f"Lenth of test: {len(test_loader.dataset)}")

Lenth of train: 60000
Lenth of test: 10000


In [7]:
import torch
from torch.nn import CrossEntropyLoss, NLLLoss
from torch.optim import Adam
from tqdm import tqdm, trange

np.random.seed(0)
torch.manual_seed(0)

def main():
    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    model = BasicVit(dimensions_chw=(1, 28, 28), patch_dimensions=(4, 4), hidden_dimension=20, n_blocks=2, n_heads=5, n_classes=10).to(device)
    
    N_EPOCHS = 50
    LR = 1e-3

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
                    
        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_pred, dim=1) == torch.argmax(y, dim=1)).detach().cpu().item()
            total += len(x)
            
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

main()

Using device:  cpu 


Training:   2%|▏         | 1/50 [00:16<13:39, 16.73s/it]

Epoch 1/50 loss: 1.98


Training:   4%|▍         | 2/50 [00:32<12:48, 16.00s/it]

Epoch 2/50 loss: 1.41


Training:   6%|▌         | 3/50 [00:48<12:29, 15.95s/it]

Epoch 3/50 loss: 1.07


Training:   8%|▊         | 4/50 [01:03<12:11, 15.91s/it]

Epoch 4/50 loss: 0.89


Training:  10%|█         | 5/50 [01:19<11:48, 15.75s/it]

Epoch 5/50 loss: 0.76


Training:  12%|█▏        | 6/50 [01:36<12:00, 16.37s/it]

Epoch 6/50 loss: 0.68


Training:  14%|█▍        | 7/50 [01:54<12:00, 16.75s/it]

Epoch 7/50 loss: 0.62


Training:  16%|█▌        | 8/50 [02:12<11:55, 17.05s/it]

Epoch 8/50 loss: 0.57


Training:  18%|█▊        | 9/50 [02:29<11:42, 17.14s/it]

Epoch 9/50 loss: 0.53


Training:  20%|██        | 10/50 [02:46<11:25, 17.15s/it]

Epoch 10/50 loss: 0.50


Training:  22%|██▏       | 11/50 [03:04<11:16, 17.36s/it]

Epoch 11/50 loss: 0.47


Training:  24%|██▍       | 12/50 [03:22<11:01, 17.40s/it]

Epoch 12/50 loss: 0.44


Training:  26%|██▌       | 13/50 [03:39<10:42, 17.38s/it]

Epoch 13/50 loss: 0.43


Training:  28%|██▊       | 14/50 [03:56<10:26, 17.39s/it]

Epoch 14/50 loss: 0.41


Training:  30%|███       | 15/50 [04:22<11:33, 19.82s/it]

Epoch 15/50 loss: 0.39


Training:  32%|███▏      | 16/50 [04:43<11:26, 20.20s/it]

Epoch 16/50 loss: 0.38


Training:  34%|███▍      | 17/50 [05:04<11:18, 20.55s/it]

Epoch 17/50 loss: 0.37


Training:  36%|███▌      | 18/50 [05:25<10:55, 20.49s/it]

Epoch 18/50 loss: 0.35


Training:  38%|███▊      | 19/50 [05:42<10:08, 19.62s/it]

Epoch 19/50 loss: 0.35


Training:  40%|████      | 20/50 [06:00<09:29, 18.98s/it]

Epoch 20/50 loss: 0.34


Training:  42%|████▏     | 21/50 [06:17<09:00, 18.63s/it]

Epoch 21/50 loss: 0.34


Training:  44%|████▍     | 22/50 [06:36<08:44, 18.74s/it]

Epoch 22/50 loss: 0.33


Training:  46%|████▌     | 23/50 [06:54<08:14, 18.33s/it]

Epoch 23/50 loss: 0.33


Training:  48%|████▊     | 24/50 [07:11<07:48, 18.02s/it]

Epoch 24/50 loss: 0.32


Training:  50%|█████     | 25/50 [07:28<07:25, 17.81s/it]

Epoch 25/50 loss: 0.32


Training:  52%|█████▏    | 26/50 [07:46<07:03, 17.64s/it]

Epoch 26/50 loss: 0.31


Training:  54%|█████▍    | 27/50 [08:03<06:42, 17.48s/it]

Epoch 27/50 loss: 0.31


Training:  56%|█████▌    | 28/50 [08:20<06:23, 17.44s/it]

Epoch 28/50 loss: 0.30


Training:  58%|█████▊    | 29/50 [08:38<06:06, 17.45s/it]

Epoch 29/50 loss: 0.30


Training:  60%|██████    | 30/50 [08:55<05:49, 17.47s/it]

Epoch 30/50 loss: 0.30


Training:  62%|██████▏   | 31/50 [09:13<05:32, 17.48s/it]

Epoch 31/50 loss: 0.29


Training:  64%|██████▍   | 32/50 [09:30<05:14, 17.45s/it]

Epoch 32/50 loss: 0.29


Training:  66%|██████▌   | 33/50 [09:47<04:56, 17.42s/it]

Epoch 33/50 loss: 0.28


Training:  68%|██████▊   | 34/50 [10:05<04:39, 17.44s/it]

Epoch 34/50 loss: 0.28


Training:  70%|███████   | 35/50 [10:23<04:22, 17.53s/it]

Epoch 35/50 loss: 0.28


Training:  72%|███████▏  | 36/50 [10:40<04:05, 17.56s/it]

Epoch 36/50 loss: 0.28


Training:  74%|███████▍  | 37/50 [10:58<03:47, 17.54s/it]

Epoch 37/50 loss: 0.27


Training:  76%|███████▌  | 38/50 [11:15<03:29, 17.47s/it]

Epoch 38/50 loss: 0.27


Training:  78%|███████▊  | 39/50 [11:32<03:11, 17.42s/it]

Epoch 39/50 loss: 0.27


Training:  80%|████████  | 40/50 [11:50<02:54, 17.47s/it]

Epoch 40/50 loss: 0.27


Training:  82%|████████▏ | 41/50 [12:07<02:36, 17.44s/it]

Epoch 41/50 loss: 0.26


Training:  84%|████████▍ | 42/50 [12:25<02:19, 17.40s/it]

Epoch 42/50 loss: 0.26


Training:  86%|████████▌ | 43/50 [12:42<02:01, 17.38s/it]

Epoch 43/50 loss: 0.26


Training:  88%|████████▊ | 44/50 [12:59<01:44, 17.39s/it]

Epoch 44/50 loss: 0.25


Training:  90%|█████████ | 45/50 [13:17<01:26, 17.40s/it]

Epoch 45/50 loss: 0.25


Training:  92%|█████████▏| 46/50 [13:34<01:09, 17.44s/it]

Epoch 46/50 loss: 0.25


Training:  94%|█████████▍| 47/50 [13:52<00:52, 17.49s/it]

Epoch 47/50 loss: 0.25


Training:  96%|█████████▌| 48/50 [14:09<00:34, 17.49s/it]

Epoch 48/50 loss: 0.24


Training:  98%|█████████▊| 49/50 [14:27<00:17, 17.44s/it]

Epoch 49/50 loss: 0.24


Training: 100%|██████████| 50/50 [14:44<00:00, 17.69s/it]


Epoch 50/50 loss: 0.24


Testing: 100%|██████████| 79/79 [00:01<00:00, 55.98it/s]

Test loss: 0.22
Test accuracy: 93.16%



