# Import required libraries

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

# Load Dataset

In [None]:
# Preprocessing applied to every FashionMNIST image, in order:
#   1. resize to 224x224,
#   2. expand the grayscale image to 3 output channels,
#   3. convert to a tensor,
#   4. normalise with mean 0.5 and std 0.5.
# NOTE(review): the pretrained backbone loaded later in this notebook was
# presumably trained with ImageNet mean/std rather than 0.5/0.5 — confirm
# whether the normalisation here should match it.
preprocessing_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

# Both splits share the storage location and the preprocessing; only `train` differs.
shared_dataset_args = dict(root='./data', transform=transform, download=True)
train_dataset = datasets.FashionMNIST(train=True, **shared_dataset_args)
test_dataset = datasets.FashionMNIST(train=False, **shared_dataset_args)

# Mini-batches of 64 samples; only the training split is shuffled.
loader_batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=loader_batch_size, shuffle=False)


# Load and define the model


In [None]:
!pip install timm



In [None]:
# Fetch the `deit_base_patch16_224` entry point from the DeiT hub repository,
# asking for its pretrained weights.
vit_model = torch.hub.load(
    repo_or_dir='facebookresearch/deit:main',
    model='deit_base_patch16_224',
    pretrained=True,
)

# Replace the classification head with a fresh linear layer that keeps the
# original input width but emits one logit per target class.
num_labels = 10
head_in_features = vit_model.head.in_features
vit_model.head = nn.Linear(in_features=head_in_features, out_features=num_labels)

Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main
  def deit_tiny_patch16_224(pretrained=False, **kwargs):
  def deit_small_patch16_224(pretrained=False, **kwargs):
  def deit_base_patch16_224(pretrained=False, **kwargs):
  def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
  def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
  def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
  def deit_base_patch16_384(pretrained=False, **kwargs):
  def deit_base_distilled_patch16_384(pretrained=False, **kwargs):


In [None]:
# vit_model.parameters()

In [None]:
# Turn gradients off for every parameter of the network ...
vit_model.requires_grad_(False)

# ... then switch them back on for the classification head only, so the head is
# the sole trainable part. The trailing `;` keeps the cell from displaying the
# module returned by `requires_grad_`.
vit_model.head.requires_grad_(True);

# Define the training configurations

In [None]:
# TensorBoard writer that receives the per-epoch loss and accuracy scalars.
writer = SummaryWriter()

# Multi-class classification loss computed on the raw logits of the model.
criterion = nn.CrossEntropyLoss()

# Give the optimizer only the parameters that still require gradients instead
# of the whole network. Parameters with `requires_grad=False` never receive a
# gradient, so the resulting updates are the same; this just makes explicit
# which weights are being trained and keeps untouched tensors out of the
# optimizer's parameter groups.
trainable_params = [param for param in vit_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

# Number of full passes over the training set.
num_epochs = 2

# Train the model

In [None]:
# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Move the model to the chosen device; as the last expression of the cell the
# returned module is displayed.
vit_model.to(device)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [None]:
# Sanity check: shape of the first preprocessed training image.
first_image, _first_label = train_dataset[0]
first_image.shape

torch.Size([3, 224, 224])

In [None]:
from tqdm import tqdm

In [None]:
# (number of training samples, number of mini-batches per epoch)
(len(train_dataset), len(train_loader))

(60000, 938)

In [None]:
# Train for `num_epochs` passes over `train_loader`, reporting the mean batch
# loss and the training accuracy of each epoch on stdout and to TensorBoard.
for epoch in range(num_epochs):
  vit_model.train()
  running_loss = 0
  correct_train = 0
  total_train = 0

  # Wrap the loader itself instead of `enumerate(train_loader)`: enumerate hides
  # the loader's length, so tqdm could only print a bare iteration counter with
  # no bar, percentage or ETA. The enumerate index was never used.
  for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
    inputs, labels = inputs.to(device), labels.to(device)

    # Standard optimisation step: clear old gradients, forward, backward, update.
    optimizer.zero_grad()
    outputs = vit_model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

    # The predicted class is the index of the largest logit in each row.
    _, predicted = torch.max(outputs, 1)
    correct_train += (predicted == labels).sum().item()
    total_train += labels.size(0)

  # Loss is averaged over batches, accuracy over individual samples.
  epoch_loss = running_loss / len(train_loader)
  train_accuracy = correct_train / total_train
  print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Training Accuracy: {100 * train_accuracy:.2f}%")

  writer.add_scalar('Loss/train', epoch_loss, epoch)
  writer.add_scalar('Accuracy/train', train_accuracy, epoch)

# All epochs are logged; close the TensorBoard writer.
writer.close()

938it [12:30,  1.25it/s]


Epoch 1/2, Loss: 0.4087, Training Accuracy: 86.33%


938it [12:29,  1.25it/s]

Epoch 2/2, Loss: 0.2935, Training Accuracy: 89.54%





In [None]:
# Measure top-1 accuracy on the held-out test split.
vit_model.eval()

total_test = 0
correct_test = 0

# No gradients are needed while evaluating.
with torch.no_grad():
  for inputs, labels in tqdm(test_loader):
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = vit_model(inputs)
    # The index of the highest logit per row is the predicted class.
    predicted = outputs.argmax(dim=1)
    correct_test += predicted.eq(labels).sum().item()
    total_test += labels.size(0)

test_accuracy = correct_test / total_test
print(f"Test Accuracy: {100 * test_accuracy:.2f}%")

100%|██████████| 157/157 [02:05<00:00,  1.25it/s]

Test Accuracy: 89.10%





# ONNX Export

In [None]:
# Colab-only: mount Google Drive at /content/drive. The export and tracing
# cells further down write their model files underneath this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install onnx
!pip install onnxscript

Collecting onnx
  Downloading onnx-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.9/15.9 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.16.0
Collecting onnxscript
  Downloading onnxscript-0.1.0.dev20240505-py3-none-any.whl (594 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m594.8/594.8 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnxscript
Successfully installed onnxscript-0.1.0.dev20240505


In [None]:
# Export the fine-tuned model to ONNX and store it on the mounted Drive.
onnx_model_path = "/content/drive/MyDrive/dlops/vit_model_fashionmnist.onnx"

# Create the destination folder up front so a missing directory cannot make the
# save fail after the model has already been trained.
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

# Set inference mode explicitly instead of relying on whichever cell last
# changed the train/eval mode of the model.
vit_model.eval()

# Example input of batch size 1 with the same 3x224x224 layout as the dataset
# samples, placed on the device the model lives on.
dummy_input = torch.randn(1, 3, 224, 224).to(device)
onnx_program = torch.onnx.dynamo_export(vit_model, dummy_input)

onnx_program.save(onnx_model_path)



# Import ONNX model and run inference

In [None]:
!pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.17.3-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m62.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: humanfriendly, coloredlogs, onnxruntime
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.17.3


In [None]:
import onnxruntime
import numpy as np

# Load the exported ONNX model and classify one randomly drawn test image.
onnx_model_path = "/content/drive/MyDrive/dlops/vit_model_fashionmnist.onnx"
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# Use a dedicated loader for the single-sample draw. Rebinding `test_loader`
# here would leave a shuffled, batch-size-1 loader behind for any later re-run
# of the evaluation cell.
onnx_sample_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
inputs, labels = next(iter(onnx_sample_loader))
input_tensor = inputs.numpy()

# Feed the sample under the name of the session's first declared input.
ort_inputs = {ort_session.get_inputs()[0].name: input_tensor}
ort_outputs = ort_session.run(None, ort_inputs)

# The first output holds the logits; the largest one selects the class name.
predicted_label_index = np.argmax(ort_outputs[0])
predicted_label = test_dataset.classes[predicted_label_index]
# Ground truth of the same sample, so the prediction can be checked by eye.
true_label = test_dataset.classes[labels.item()]

print(f"Predicted Label: {predicted_label}")
print(f"True Label: {true_label}")

Predicted Label: Trouser


# Save the traced model

In [None]:
# Bring the model back to the CPU before it is traced with a CPU input.
vit_model = vit_model.cpu()

In [None]:
# Trace the model with a random example of batch size 1 and 3x224x224 layout,
# then store the traced module on the mounted Drive.
# NOTE(review): the file name says "scripted" although the module is produced
# by torch.jit.trace — confirm the name is intentional before renaming.
traced_model_path = "/content/drive/MyDrive/dlops/fashion_mnist_vit_scripted_cpu.pt"

sample_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(vit_model, example_inputs=sample_input)

traced_model.save(traced_model_path)

  assert condition, message
