# Training & Evaluation Code

In [None]:
import os
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.datasets import Flowers102

In [None]:
# Shell escape: print the NVIDIA driver / GPU status table for this runtime.
! nvidia-smi

Thu Sep  8 02:39:30 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   64C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from torch.utils.data import Dataset, DataLoader
# Preprocessing applied to every image of every split: resize, take a
# 224-pixel center crop, then convert the PIL image to a tensor.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
image_preprocess = transforms.Compose(_preprocess_steps)

# Flowers102 splits, all rooted in the current directory. Only the first
# construction passes download=True; the other two are built without it.
train_dataset = Flowers102(root='.', split="train", transform=image_preprocess, download=True)
validation_dataset = Flowers102(root='.', split="val", transform=image_preprocess)
test_dataset = Flowers102(root='.', split="test", transform=image_preprocess)

# Training batches are shuffled mini-batches of 8; validation and test are
# read one sample at a time in dataset order.
_eval_loader_options = {"batch_size": 1, "shuffle": False, "num_workers": 2}
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
validation_dataloader = DataLoader(validation_dataset, **_eval_loader_options)
test_dataloader = DataLoader(test_dataset, **_eval_loader_options)

In [None]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  # Was misspelled `torch.deivce`, which raised AttributeError on CPU-only machines.
  device = torch.device('cpu')

from torchvision.models.mobilenetv2 import MobileNetV2

# MobileNetV2 is instantiated directly from its class with a 102-way output;
# no weights are loaded in this cell.
model = MobileNetV2(num_classes=102).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch_num in tqdm(range(100)):
  # train loop
  loss_sum = 0
  for images, labels in train_dataloader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = criterion(outputs, labels)

    # Clear gradients left over from the previous step before backpropagating.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_sum += loss.item()

  # Reported value is the mean of the per-batch losses for this epoch.
  print(f"{epoch_num:04d} epoch train loss: {loss_sum / len(train_dataloader)}")

  # Validate on every 10th epoch (0, 10, 20, ...).
  if epoch_num % 10 == 0:
    model.eval()
    # validation loop
    loss_sum = 0
    correct = 0
    total = 0
    with torch.no_grad():
      for images, labels in validation_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss_sum += loss.item()

        # Top-1 prediction: index of the largest logit for each sample.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'{epoch_num:04d} epoch validation loss: {loss_sum / len(validation_dataloader)}')
    print(f'{epoch_num:04d} epoch validation accuracy: {100 * correct / float(total)} %')
    # Back to training mode before the next epoch's train loop.
    model.train()



  0%|          | 0/100 [00:00<?, ?it/s]

0000 epoch train loss: 4.703148562461138


  1%|          | 1/100 [00:28<46:13, 28.01s/it]

0000 epoch validation loss: 4.53040530611487
0000 epoch validation accuracy: 1.3725490196078431 %


  2%|▏         | 2/100 [00:37<28:12, 17.28s/it]

0001 epoch train loss: 4.541913975030184


  3%|▎         | 3/100 [00:47<22:22, 13.85s/it]

0002 epoch train loss: 4.443713501095772


  4%|▍         | 4/100 [00:57<19:33, 12.23s/it]

0003 epoch train loss: 4.259026499465108


  5%|▌         | 5/100 [01:07<18:00, 11.38s/it]

0004 epoch train loss: 4.018710654228926


  6%|▌         | 6/100 [01:17<17:27, 11.14s/it]

0005 epoch train loss: 3.9088589008897543


  7%|▋         | 7/100 [01:27<16:36, 10.71s/it]

0006 epoch train loss: 3.7978671938180923


  8%|▊         | 8/100 [01:38<16:27, 10.73s/it]

0007 epoch train loss: 3.681610157713294


  9%|▉         | 9/100 [01:48<15:50, 10.45s/it]

0008 epoch train loss: 3.6498721539974213


 10%|█         | 10/100 [01:58<15:24, 10.27s/it]

0009 epoch train loss: 3.570178894326091
0010 epoch train loss: 3.512020258232951


 11%|█         | 11/100 [02:20<20:36, 13.89s/it]

0010 epoch validation loss: 3.5201761686918784
0010 epoch validation accuracy: 11.666666666666666 %


 12%|█▏        | 12/100 [02:30<18:33, 12.65s/it]

0011 epoch train loss: 3.4555709920823574


 13%|█▎        | 13/100 [02:40<17:10, 11.84s/it]

0012 epoch train loss: 3.36345280893147


 14%|█▍        | 14/100 [02:49<16:08, 11.27s/it]

0013 epoch train loss: 3.3145977910608053


 15%|█▌        | 15/100 [02:59<15:23, 10.86s/it]

0014 epoch train loss: 3.2684223521500826


 16%|█▌        | 16/100 [03:09<14:48, 10.58s/it]

0015 epoch train loss: 3.2387063428759575


 17%|█▋        | 17/100 [03:19<14:19, 10.36s/it]

0016 epoch train loss: 3.181736074388027


 18%|█▊        | 18/100 [03:29<13:57, 10.21s/it]

0017 epoch train loss: 3.1163027808070183


 19%|█▉        | 19/100 [03:40<14:03, 10.41s/it]

0018 epoch train loss: 3.048980388790369


 20%|██        | 20/100 [03:50<13:40, 10.25s/it]

0019 epoch train loss: 3.001675721257925
0020 epoch train loss: 2.9591119941323996


 21%|██        | 21/100 [04:12<18:13, 13.84s/it]

0020 epoch validation loss: 3.343587352262408
0020 epoch validation accuracy: 17.45098039215686 %


 22%|██▏       | 22/100 [04:22<16:28, 12.67s/it]

0021 epoch train loss: 2.9072993770241737


 23%|██▎       | 23/100 [04:32<15:13, 11.86s/it]

0022 epoch train loss: 2.7745768409222364


 24%|██▍       | 24/100 [04:42<14:19, 11.31s/it]

0023 epoch train loss: 2.778054365888238


 25%|██▌       | 25/100 [04:52<13:37, 10.91s/it]

0024 epoch train loss: 2.7429682202637196


 26%|██▌       | 26/100 [05:02<13:06, 10.63s/it]

0025 epoch train loss: 2.673994604498148


 27%|██▋       | 27/100 [05:12<12:43, 10.46s/it]

0026 epoch train loss: 2.582981782965362


 28%|██▊       | 28/100 [05:22<12:21, 10.30s/it]

0027 epoch train loss: 2.5446953531354666


 29%|██▉       | 29/100 [05:32<12:01, 10.16s/it]

0028 epoch train loss: 2.4150648526847363


 30%|███       | 30/100 [05:43<12:08, 10.40s/it]

0029 epoch train loss: 2.3903857180848718
0030 epoch train loss: 2.2593083577230573


 31%|███       | 31/100 [06:05<16:03, 13.97s/it]

0030 epoch validation loss: 3.2785057016684873
0030 epoch validation accuracy: 20.19607843137255 %


 32%|███▏      | 32/100 [06:15<14:25, 12.73s/it]

0031 epoch train loss: 2.273761701770127


 33%|███▎      | 33/100 [06:25<13:15, 11.88s/it]

0032 epoch train loss: 2.1127526154741645


 34%|███▍      | 34/100 [06:35<12:26, 11.31s/it]

0033 epoch train loss: 2.102050017565489


 35%|███▌      | 35/100 [06:45<11:49, 10.91s/it]

0034 epoch train loss: 2.053819213062525


 36%|███▌      | 36/100 [06:55<11:20, 10.63s/it]

0035 epoch train loss: 1.972987374290824


 37%|███▋      | 37/100 [07:05<10:56, 10.43s/it]

0036 epoch train loss: 1.902940553613007


 38%|███▊      | 38/100 [07:14<10:35, 10.25s/it]

0037 epoch train loss: 1.8570461990311742


 39%|███▉      | 39/100 [07:24<10:17, 10.12s/it]

0038 epoch train loss: 1.8562166057527065


 40%|████      | 40/100 [07:34<10:03, 10.05s/it]

0039 epoch train loss: 1.6765793655067682
0040 epoch train loss: 1.6047869115136564


 41%|████      | 41/100 [07:57<13:46, 14.01s/it]

0040 epoch validation loss: 3.365970396735322
0040 epoch validation accuracy: 22.941176470588236 %


 42%|████▏     | 42/100 [08:07<12:20, 12.77s/it]

0041 epoch train loss: 1.5299201677553356


 43%|████▎     | 43/100 [08:17<11:18, 11.90s/it]

0042 epoch train loss: 1.493717124685645


 44%|████▍     | 44/100 [08:27<10:31, 11.28s/it]

0043 epoch train loss: 1.429032081272453


 45%|████▌     | 45/100 [08:37<09:58, 10.88s/it]

0044 epoch train loss: 1.3556502759456635


 46%|████▌     | 46/100 [08:47<09:32, 10.61s/it]

0045 epoch train loss: 1.3361254949122667


 47%|████▋     | 47/100 [08:57<09:11, 10.41s/it]

0046 epoch train loss: 1.3160971459001303


 48%|████▊     | 48/100 [09:07<08:55, 10.30s/it]

0047 epoch train loss: 1.2423184358049184


 49%|████▉     | 49/100 [09:18<09:04, 10.67s/it]

0048 epoch train loss: 1.1222286813426763


 50%|█████     | 50/100 [09:30<09:01, 10.83s/it]

0049 epoch train loss: 1.0318728655111045
0050 epoch train loss: 1.0579576510936022


 51%|█████     | 51/100 [09:53<11:54, 14.58s/it]

0050 epoch validation loss: 3.67556517147402
0050 epoch validation accuracy: 21.96078431372549 %


 52%|█████▏    | 52/100 [10:03<10:32, 13.19s/it]

0051 epoch train loss: 1.0203673930373043


 53%|█████▎    | 53/100 [10:13<09:34, 12.21s/it]

0052 epoch train loss: 0.9345624158158898


 54%|█████▍    | 54/100 [10:23<08:49, 11.51s/it]

0053 epoch train loss: 0.8930054039228708


 55%|█████▌    | 55/100 [10:33<08:21, 11.14s/it]

0054 epoch train loss: 0.8844450765755028


 56%|█████▌    | 56/100 [10:43<07:55, 10.80s/it]

0055 epoch train loss: 0.7953005027957261


 57%|█████▋    | 57/100 [10:53<07:34, 10.58s/it]

0056 epoch train loss: 0.7490074665984139


 58%|█████▊    | 58/100 [11:03<07:16, 10.38s/it]

0057 epoch train loss: 0.6936205751262605


 59%|█████▉    | 59/100 [11:13<07:00, 10.27s/it]

0058 epoch train loss: 0.73809605371207


 60%|██████    | 60/100 [11:23<06:47, 10.19s/it]

0059 epoch train loss: 0.6529801802244037
0060 epoch train loss: 0.6139501583529636


 61%|██████    | 61/100 [11:45<09:00, 13.85s/it]

0060 epoch validation loss: 3.595623810457376
0060 epoch validation accuracy: 23.137254901960784 %


 62%|██████▏   | 62/100 [11:56<08:13, 12.99s/it]

0061 epoch train loss: 0.6459455764852464


 63%|██████▎   | 63/100 [12:06<07:27, 12.09s/it]

0062 epoch train loss: 0.5617168387398124


 64%|██████▍   | 64/100 [12:16<06:51, 11.42s/it]

0063 epoch train loss: 0.5769044760381803


 65%|██████▌   | 65/100 [12:26<06:23, 10.95s/it]

0064 epoch train loss: 0.5173674813122489


 66%|██████▌   | 66/100 [12:36<06:02, 10.65s/it]

0065 epoch train loss: 0.535263801633846


 67%|██████▋   | 67/100 [12:46<05:44, 10.44s/it]

0066 epoch train loss: 0.5326763191260397


 68%|██████▊   | 68/100 [12:56<05:29, 10.29s/it]

0067 epoch train loss: 0.42367219034349546


 69%|██████▉   | 69/100 [13:06<05:16, 10.20s/it]

0068 epoch train loss: 0.41007834702031687


 70%|███████   | 70/100 [13:16<05:04, 10.14s/it]

0069 epoch train loss: 0.43730927509022877
0070 epoch train loss: 0.44013053382514045


 71%|███████   | 71/100 [13:38<06:39, 13.77s/it]

0070 epoch validation loss: 3.9099902548772447
0070 epoch validation accuracy: 23.03921568627451 %


 72%|███████▏  | 72/100 [13:48<05:53, 12.63s/it]

0071 epoch train loss: 0.3945621619350277


 73%|███████▎  | 73/100 [13:59<05:26, 12.10s/it]

0072 epoch train loss: 0.41999379836488515


 74%|███████▍  | 74/100 [14:09<04:59, 11.52s/it]

0073 epoch train loss: 0.36132220638683066


 75%|███████▌  | 75/100 [14:19<04:36, 11.05s/it]

0074 epoch train loss: 0.377717976924032


 76%|███████▌  | 76/100 [14:29<04:16, 10.69s/it]

0075 epoch train loss: 0.3851268233556766


 77%|███████▋  | 77/100 [14:39<04:01, 10.49s/it]

0076 epoch train loss: 0.3075634903216269


 78%|███████▊  | 78/100 [14:49<03:47, 10.32s/it]

0077 epoch train loss: 0.27352738624904305


 79%|███████▉  | 79/100 [14:59<03:33, 10.18s/it]

0078 epoch train loss: 0.3355127993854694


 80%|████████  | 80/100 [15:09<03:22, 10.12s/it]

0079 epoch train loss: 0.32627833579317667
0080 epoch train loss: 0.31301490135956556


 81%|████████  | 81/100 [15:31<04:22, 13.84s/it]

0080 epoch validation loss: 3.9666111910802155
0080 epoch validation accuracy: 24.41176470588235 %


 82%|████████▏ | 82/100 [15:41<03:48, 12.68s/it]

0081 epoch train loss: 0.26743385472218506


 83%|████████▎ | 83/100 [15:51<03:22, 11.89s/it]

0082 epoch train loss: 0.25976024684496224


 84%|████████▍ | 84/100 [16:02<03:06, 11.65s/it]

0083 epoch train loss: 0.29112469803658314


 85%|████████▌ | 85/100 [16:12<02:47, 11.15s/it]

0084 epoch train loss: 0.360466378449928


 86%|████████▌ | 86/100 [16:22<02:30, 10.77s/it]

0085 epoch train loss: 0.2707072865741793


 87%|████████▋ | 87/100 [16:32<02:16, 10.50s/it]

0086 epoch train loss: 0.2690725880820537


 88%|████████▊ | 88/100 [16:42<02:04, 10.35s/it]

0087 epoch train loss: 0.2881027713010553


 89%|████████▉ | 89/100 [16:52<01:52, 10.23s/it]

0088 epoch train loss: 0.24944184778723866


 90%|█████████ | 90/100 [17:02<01:41, 10.14s/it]

0089 epoch train loss: 0.22109247493790463
0090 epoch train loss: 0.18111696525011212


 91%|█████████ | 91/100 [17:24<02:04, 13.84s/it]

0090 epoch validation loss: 4.020668491619633
0090 epoch validation accuracy: 25.88235294117647 %


 92%|█████████▏| 92/100 [17:34<01:41, 12.68s/it]

0091 epoch train loss: 0.23292005185066955


 93%|█████████▎| 93/100 [17:44<01:23, 11.87s/it]

0092 epoch train loss: 0.22918009896966396


 94%|█████████▍| 94/100 [17:54<01:07, 11.33s/it]

0093 epoch train loss: 0.19065124177723192


 95%|█████████▌| 95/100 [18:05<00:56, 11.24s/it]

0094 epoch train loss: 0.21227868631103775


 96%|█████████▌| 96/100 [18:15<00:43, 10.87s/it]

0095 epoch train loss: 0.17902731374488212


 97%|█████████▋| 97/100 [18:25<00:31, 10.60s/it]

0096 epoch train loss: 0.19936686467553955


 98%|█████████▊| 98/100 [18:35<00:20, 10.41s/it]

0097 epoch train loss: 0.20716054188960698


 99%|█████████▉| 99/100 [18:46<00:10, 10.36s/it]

0098 epoch train loss: 0.196897069356055


100%|██████████| 100/100 [18:56<00:00, 11.36s/it]

0099 epoch train loss: 0.20180261260975385





In [None]:
# Write the current weights to disk, then put the model in evaluation mode.
torch.save(model.state_dict(), 'last.pth')
model.eval()

# test loop: count top-1 hits over every batch of the test dataloader
correct, total = 0, 0
with torch.no_grad():
  for images, labels in tqdm(test_dataloader):
    images = images.to(device)
    labels = labels.to(device)
    logits = model(images)
    # Index of the largest score along the class dimension for each sample.
    predicted = logits.data.max(1).indices
    total += labels.shape[0]
    correct += predicted.eq(labels).sum().item()

print(f'Accuracy of the network on the test images: {100 * correct / float(total)} %')

100%|██████████| 6149/6149 [01:19<00:00, 77.19it/s]

Accuracy of the network on the test images: 20.86518133029761 %





# ViT Inference Code

In [None]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
# (The CPU branch previously called the misspelled `torch.deivce`, which raised
# AttributeError on machines without CUDA.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
! pip install timm
import timm

model = timm.create_model('vit_tiny_r_s16_p8_224', pretrained=True)
model.eval()
model.to(device)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


VisionTransformer(
  (patch_embed): HybridEmbed(
    (backbone): Sequential(
      (conv): StdConv2dSame(3, 64, kernel_size=(7, 7), stride=(2, 2), bias=False)
      (norm): GroupNormAct(
        32, 64, eps=1e-05, affine=True
        (drop): Identity()
        (act): ReLU(inplace=True)
      )
      (pool): MaxPool2dSame(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ceil_mode=False)
    )
    (proj): Conv2d(64, 192, kernel_size=(8, 8), stride=(8, 8))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e

In [None]:
# `import urllib` alone does not load the `request` submodule; import it
# explicitly so `urllib.request.urlretrieve` works on a fresh kernel.
import urllib.request
from PIL import Image

# Preprocessing for the sample image: resize, 224-pixel center crop, to tensor.
# NOTE(review): no transforms.Normalize is applied here; pretrained checkpoints
# commonly expect normalized inputs — confirm against the model's data config.
image_preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Download the sample image and load it as RGB; the context manager closes the
# underlying file once the converted copy has been made.
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)
with Image.open(filename) as image_file:
    img = image_file.convert('RGB')
# unsqueeze(0) adds a leading batch dimension before moving to the device.
tensor = image_preprocess(img).unsqueeze(0).to(device)

with torch.no_grad():
  output = model(tensor)
# Softmax over the scores of the single image in the batch.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities.shape)

# Download the class-name list (one name per line) and read it back from the
# path it was just saved to.
url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", "imagenet_classes.txt")
urllib.request.urlretrieve(url, filename)
with open(filename, "r") as f:
    categories = [line.strip() for line in f]

# Print top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for prob, catid in zip(top5_prob.tolist(), top5_catid.tolist()):
    print(categories[catid], prob)

torch.Size([1000])
Samoyed 0.8307119607925415
Pomeranian 0.03616778552532196
white wolf 0.02717074751853943
Great Pyrenees 0.019983064383268356
Arctic fox 0.016034426167607307
