In [1]:
import torch.nn as nn
import torchvision
from torchvision.transforms import v2
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from random import choices
import numpy as np
import random
from sklearn.model_selection import train_test_split
import warnings
from sklearn.feature_extraction.image import extract_patches_2d

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
warnings.filterwarnings("ignore")

In [3]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [4]:
training_dataset = torchvision.datasets.CIFAR10(root="./", download=True, train=True)
test_dataset = torchvision.datasets.CIFAR10(root="./", download=True, train=False)

100%|██████████| 170M/170M [00:01<00:00, 105MB/s]  


In [5]:
training_images = []
test_images = []
transformation = v2.Compose([
    v2.ToTensor(),
    v2.Pad(padding=2)

])

for img in training_dataset.data:
  training_images.append(transformation(img))

for img in test_dataset.data:
  test_images.append(transformation(img))

training_images = torch.stack(training_images)
test_images = torch.stack(test_images)

In [6]:
print(training_images.shape)

torch.Size([50000, 3, 36, 36])


In [7]:
channel_1 = training_images[:, 0, :, :]
channel_2 = training_images[:, 1, :, :]
channel_3 = training_images[:, 2, :, :]

In [8]:
print(channel_1.shape)

torch.Size([50000, 36, 36])


In [9]:
# https://stackoverflow.com/a/16858283
def blockshaped(arr, nrows, ncols):
    """
    Return an array of shape (n, nrows, ncols) where
    n * nrows * ncols = arr.size

    If arr is a 2D array, the returned array should look like n subblocks with
    each subblock preserving the "physical" layout of arr.
    """
    h, w = arr.shape
    assert h % nrows == 0, f"{h} rows is not evenly divisible by {nrows}"
    assert w % ncols == 0, f"{w} cols is not evenly divisible by {ncols}"
    return (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1,2)
               .reshape(-1, nrows, ncols))

In [10]:

channel_1_patches = []
channel_2_patches = []
channel_3_patches = []
for i in range(channel_1.shape[0]):
  channel_1_patches.append(blockshaped(channel_1[i], 12, 12))
channel_1_patches = torch.stack(channel_1_patches)
for i in range(channel_2.shape[0]):
  channel_2_patches.append(blockshaped(channel_2[i], 12, 12))
channel_2_patches = torch.stack(channel_2_patches)
for i in range(channel_3.shape[0]):
  channel_3_patches.append(blockshaped(channel_3[i], 12, 12))
channel_3_patches = torch.stack(channel_3_patches)

In [11]:

patches = torch.stack([channel_1_patches, channel_2_patches, channel_3_patches], dim=1).permute(0, 2, 1, 3, 4)

In [12]:
train_labels = training_dataset.targets
test_labels = test_dataset.targets
y_train = torch.Tensor([[1 if i == el else 0 for i in range(10)] for el in train_labels])
y_test = torch.Tensor([[1 if i == el else 0 for i in range(10)] for el in test_labels])

In [13]:


class MLP(nn.Module):

  def __init__(self, in_features, hidden_size, out_features, dropout, output=False):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_size)
    self.fc2 = nn.Linear(hidden_size, out_features)
    self.dropout = nn.Dropout(dropout)
    self.output = output

  def forward(self, x):
    x = self.fc1(x)
    x = self.dropout(x)
    x = nn.functional.gelu(x)
    x = self.fc2(x)
    if not self.output:
      x = self.dropout(x)
    return x

class TransformerEncoder(nn.Module):
  def __init__(self, n_patches, d_model, n_heads, mlp_size, dropout):
    super().__init__()
    self.ln1 = nn.LayerNorm([n_patches + 1, d_model])
    self.ln2 = nn.LayerNorm([n_patches + 1, d_model])
    self.msa = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
    self.mlp = MLP(d_model, mlp_size, d_model, dropout)

  def forward(self, x):
    x_skip = x
    x = self.ln1(x)
    x = self.msa(x, x, x)[0]
    x = x + x_skip
    x_skip = x
    x = self.ln2(x)
    x = self.mlp(x)
    return x + x_skip

class TransformerEncoderLayer(nn.Module):

  def __init__(self, n_patches, d_model, n_heads, mlp_size, layer, dropout):
    super().__init__()
    self.transformer_layers = nn.Sequential(*[TransformerEncoder(n_patches, d_model, n_heads, mlp_size, dropout) for i in range(layer)])

  def forward(self, x):
    return self.transformer_layers(x)


class VisionTransformer(nn.Module):
  def __init__(self, layer, d_model, n_heads, mlp_size, patch_size, channels, n_patches, n_classes, dropout=0):
    super().__init__()

    self.linear_projection = nn.Parameter(torch.zeros((patch_size**2 * channels, d_model)))
    self.class_token = nn.Parameter(torch.zeros(1, 1, d_model))
    self.pos_embedding = nn.Parameter(torch.zeros(n_patches+1, d_model))
    self.transformer_layers = TransformerEncoderLayer(n_patches, d_model, n_heads, mlp_size, layer, dropout)
    self.mlp_head = MLP(d_model, mlp_size, n_classes, dropout, output=True)
    self.dropout = nn.Dropout(dropout)


    nn.init.normal_(self.linear_projection)
    nn.init.normal_(self.class_token)
    nn.init.normal_(self.pos_embedding)

  def forward(self, x):
    x = torch.flatten(x, 2, -1)
    z0 = torch.einsum("nkp,pd -> nkd", x, self.linear_projection)
    class_token = self.class_token.expand(x.shape[0], -1, -1)
    z0 = torch.cat([class_token, z0], dim=1) + self.pos_embedding
    z0 = self.dropout(z0)
    z0 = self.transformer_layers(z0)
    y = z0[:, 0]
    out = self.mlp_head(y)
    return out

In [13]:
#transformer = VisionTransformer(layer=12, d_model=768, n_heads=12, mlp_size=3072, patch_size=12, channels=3, n_patches=9, n_classes=10)

In [14]:
X_train, X_val, y_train, y_val = train_test_split(patches, y_train, test_size=0.02, random_state=seed)

In [15]:
def test_model(lr):
  torch.random.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  torch.manual_seed(seed)
  np.random.seed(seed)
  random.seed(seed)
  model = VisionTransformer(layer=12, d_model=768, n_heads=12, mlp_size=3072, patch_size=12, channels=3, n_patches=9, n_classes=10).to(device)
  train_batches = DataLoader([*zip(X_train, y_train)], batch_size=512, shuffle=True)
  val_batches = DataLoader([*zip(X_val, y_val)], batch_size=512, shuffle=True)
  loss_fn = nn.CrossEntropyLoss()
  val_loss_fn = nn.CrossEntropyLoss(reduction="sum")
  optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 1000)
  iterations = 0
  over=False
  while True:
    for batch in train_batches:
      model.train()
      iterations += 1
      optimizer.zero_grad()
      features, target = batch[:-1], batch[-1]
      features = features[0].to(device)
      target = target.to(device)
      outputs = model(features)
      perte = loss_fn(outputs, target)
      perte.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
      optimizer.step()
      scheduler.step()
      if iterations % 100 == 0:
        model.eval()
        total_loss = 0
        for batch in val_batches:
          features, target = batch[:-1], batch[-1]
          features = features[0].to(device)
          target = target.to(device)
          outputs = model(features)
          perte = val_loss_fn(outputs, target)
          total_loss += perte.item()
        print(f"Iteration {iterations}: Loss {total_loss/X_val.shape[0]}")
      if iterations == 20000:
        del model, features, target
        return total_loss


In [None]:
min_error = 10**6
min_lr = None

for lr in [0.001, 0.003, 0.01, 0.03]:
  print("Currently testing lr:", lr)
  err = test_model(lr)
  if err < min_error:
    min_error = err
    min_lr = lr

print(f"The learning rate chosen is: {min_lr}")

Currently testing lr: 0.001
Iteration 100: Loss 2.17852294921875
Iteration 200: Loss 2.0801536865234374
Iteration 300: Loss 2.0406788330078127
Iteration 400: Loss 2.0003892211914063
Iteration 500: Loss 1.9865011596679687
Iteration 600: Loss 1.9752722778320313
Iteration 700: Loss 1.9727693481445312
Iteration 800: Loss 1.9691539916992187
Iteration 900: Loss 1.967044189453125
Iteration 1000: Loss 1.9669603881835938
Iteration 1100: Loss 1.9667037963867187
Iteration 1200: Loss 1.9653912963867188
Iteration 1300: Loss 1.96433984375
Iteration 1400: Loss 1.9592341918945313
Iteration 1500: Loss 1.956561767578125
Iteration 1600: Loss 1.9523311157226562
Iteration 1700: Loss 1.9519368896484375
Iteration 1800: Loss 1.944051025390625
Iteration 1900: Loss 1.9343546752929688
Iteration 2000: Loss 1.9359947509765625
Iteration 2100: Loss 1.9245740966796876
Iteration 2200: Loss 1.9311153564453125
Iteration 2300: Loss 1.9175234985351564
Iteration 2400: Loss 1.9175988159179687
Iteration 2500: Loss 1.91234631

In [18]:

lr = 0.01

print("Currently testing lr:", lr)
err = test_model(lr)

Currently testing lr: 0.01
Iteration 100: Loss 1.982242431640625
Iteration 200: Loss 1.9592686157226562
Iteration 300: Loss 1.9076957397460939
Iteration 400: Loss 1.892860595703125
Iteration 500: Loss 1.8807515258789063
Iteration 600: Loss 1.8674205932617187
Iteration 700: Loss 1.8692074584960938
Iteration 800: Loss 1.8518797607421875
Iteration 900: Loss 1.851481201171875
Iteration 1000: Loss 1.8524962158203124
Iteration 1100: Loss 1.8515333251953126
Iteration 1200: Loss 1.8477855224609374
Iteration 1300: Loss 1.8623739624023437
Iteration 1400: Loss 1.846939208984375
Iteration 1500: Loss 1.850939697265625
Iteration 1600: Loss 1.8290136108398438
Iteration 1700: Loss 1.8207162475585938
Iteration 1800: Loss 1.792864013671875
Iteration 1900: Loss 1.8232410278320312
Iteration 2000: Loss 1.7804913940429687
Iteration 2100: Loss 1.7964934692382812
Iteration 2200: Loss 1.7578609619140626
Iteration 2300: Loss 1.7287144165039063
Iteration 2400: Loss 1.7257636108398438
Iteration 2500: Loss 1.69236

In [19]:

lr = 0.03

print("Currently testing lr:", lr)
err = test_model(lr)

Currently testing lr: 0.03
Iteration 100: Loss 1.9483055419921875
Iteration 200: Loss 1.9030689697265626
Iteration 300: Loss 1.8389835815429687
Iteration 400: Loss 1.7852335815429687
Iteration 500: Loss 1.7535896606445311
Iteration 600: Loss 1.7310767822265625
Iteration 700: Loss 1.6822145385742187
Iteration 800: Loss 1.6629968872070313
Iteration 900: Loss 1.652461669921875
Iteration 1000: Loss 1.65121875
Iteration 1100: Loss 1.6529755859375
Iteration 1200: Loss 1.6461273803710939
Iteration 1300: Loss 1.6409962768554687
Iteration 1400: Loss 1.6303756103515625
Iteration 1500: Loss 1.6188065185546876
Iteration 1600: Loss 1.6315648803710938
Iteration 1700: Loss 1.5675325927734376
Iteration 1800: Loss 1.562041015625
Iteration 1900: Loss 1.5080671997070312
Iteration 2000: Loss 1.4927706298828125
Iteration 2100: Loss 1.4659063110351562
Iteration 2200: Loss 1.4345357055664063
Iteration 2300: Loss 1.375076904296875
Iteration 2400: Loss 1.3707401123046874
Iteration 2500: Loss 1.3455325317382814

KeyboardInterrupt: 

In [21]:
X_train = torch.concat([X_train, X_val])

In [23]:
y_train = torch.concat([y_train, y_val])

In [24]:
X_train.shape
y_train.shape

torch.Size([50000, 10])

In [25]:
lr = 0.03  
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model = VisionTransformer(layer=12, d_model=768, n_heads=12, mlp_size=3072, patch_size=12, channels=3, n_patches=9, n_classes=10).to(device)
train_batches = DataLoader([*zip(X_train, y_train)], batch_size=512, shuffle=True)
val_batches = DataLoader([*zip(X_val, y_val)], batch_size=512, shuffle=True)
loss_fn = nn.CrossEntropyLoss()
val_loss_fn = nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 1000)
iterations = 0
over=False
while True:
    for batch in train_batches:
        model.train()
        iterations += 1
        optimizer.zero_grad()
        features, target = batch[:-1], batch[-1]
        features = features[0].to(device)
        target = target.to(device)
        outputs = model(features)
        perte = loss_fn(outputs, target)
        perte.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        scheduler.step()
        if iterations % 100 == 0:
            print(f"Iteration {iterations}: DONE")
        if iterations == 3400:
            over=True
            break
    if over:
        break

Iteration 100: DONE
Iteration 200: DONE
Iteration 300: DONE
Iteration 400: DONE
Iteration 500: DONE
Iteration 600: DONE
Iteration 700: DONE
Iteration 800: DONE
Iteration 900: DONE
Iteration 1000: DONE
Iteration 1100: DONE
Iteration 1200: DONE
Iteration 1300: DONE
Iteration 1400: DONE
Iteration 1500: DONE
Iteration 1600: DONE
Iteration 1700: DONE
Iteration 1800: DONE
Iteration 1900: DONE
Iteration 2000: DONE
Iteration 2100: DONE
Iteration 2200: DONE
Iteration 2300: DONE
Iteration 2400: DONE
Iteration 2500: DONE
Iteration 2600: DONE
Iteration 2700: DONE
Iteration 2800: DONE
Iteration 2900: DONE
Iteration 3000: DONE
Iteration 3100: DONE
Iteration 3200: DONE
Iteration 3300: DONE
Iteration 3400: DONE


In [26]:
channel_1 = test_images[:, 0, :, :]
channel_2 = test_images[:, 1, :, :]
channel_3 = test_images[:, 2, :, :]
channel_1_patches = []
channel_2_patches = []
channel_3_patches = []
for i in range(channel_1.shape[0]):
  channel_1_patches.append(blockshaped(channel_1[i], 12, 12))
channel_1_patches = torch.stack(channel_1_patches)
for i in range(channel_2.shape[0]):
  channel_2_patches.append(blockshaped(channel_2[i], 12, 12))
channel_2_patches = torch.stack(channel_2_patches)
for i in range(channel_3.shape[0]):
  channel_3_patches.append(blockshaped(channel_3[i], 12, 12))
channel_3_patches = torch.stack(channel_3_patches)
patches = torch.stack([channel_1_patches, channel_2_patches, channel_3_patches], dim=1).permute(0, 2, 1, 3, 4)

In [27]:
patches.shape

torch.Size([10000, 9, 3, 12, 12])

In [28]:
model.eval()
test_loader = DataLoader([*zip(patches, test_labels)], batch_size=512, shuffle=False)
correct = 0
for batch in test_loader:
  features, target = batch[:-1], batch[-1]
  features = features[0].to(device)
  target = target.to(device)
  outputs = model(features)
  correct += torch.where(torch.argmax(torch.softmax(outputs, dim=1), dim=1)==target, 1, 0).sum()
print((correct/test_images.shape[0])*100)

tensor(55.2700, device='cuda:0')


In [None]:
del model, features, target

Have to pretrain or augment data to get better results