In [5]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchinfo import summary
import torchvision
from torchvision import transforms
from prepare_data import create_dataloaders

In [26]:
dataset_dir = "../datasets/"

transform_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([224])
])


train_data = torchvision.datasets.CIFAR10(root=dataset_dir,
                                         download=False,
                                         train=True,
                                         transform=transform_tensor)
test_data = torchvision.datasets.CIFAR10(root=dataset_dir,
                                        download=False,
                                        train=False,
                                        transform=transform_tensor)

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [10]:
vit_l_16_model_weigths = torchvision.models.ViT_L_16_Weights.DEFAULT

vit_l_16_model = torchvision.models.vit_l_16(vit_l_16_model_weigths).to(device)



In [11]:
summary(model=vit_l_16_model, 
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 1000]           1,024                True
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 1024, 14, 14]   787,456              True
├─Encoder (encoder)                                          [32, 197, 1024]      [32, 197, 1024]      201,728              True
│    └─Dropout (dropout)                                     [32, 197, 1024]      [32, 197, 1024]      --                   --
│    └─Sequential (layers)                                   [32, 197, 1024]      [32, 197, 1024]      --                   True
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 1024]      [32, 197, 1024]      12,596,224           True
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 1024]      [32, 197, 10

In [12]:
dataloader_transforms = vit_l_16_model_weigths.transforms

In [14]:
print(dataloader_transforms)

functools.partial(<class 'torchvision.transforms._presets.ImageClassification'>, crop_size=224, resize_size=242)


In [28]:
train_dataloader, test_dataloader, class_names = create_dataloaders(train_data=train_data,
                                                                    test_data=test_data,
                                                                    transform=dataloader_transforms,
                                                                    data_folder_imported=True,
                                                                    batch_size=32)

In [29]:
image, label = next(iter(train_dataloader))

In [30]:
image.shape

torch.Size([32, 3, 224, 224])

In [17]:
for param in vit_l_16_model.parameters():
    param.requires_grad = False
    print(param.requires_grad)

False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
Fals

In [19]:
vit_l_16_model.heads.requires_grad_ = True # unfreezing classifier

vit_l_16_model.heads = nn.Linear(in_features=1024, out_features=len(class_names), device=device)

In [20]:
summary(model=vit_l_16_model, 
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 10]             1,024                Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 1024, 14, 14]   (787,456)            False
├─Encoder (encoder)                                          [32, 197, 1024]      [32, 197, 1024]      201,728              False
│    └─Dropout (dropout)                                     [32, 197, 1024]      [32, 197, 1024]      --                   --
│    └─Sequential (layers)                                   [32, 197, 1024]      [32, 197, 1024]      --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 1024]      [32, 197, 1024]      (12,596,224)         False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 1024]      [32, 

In [21]:
def manual_seed():
    torch.cuda.manual_seed(1)
    np.random.seed(1)

In [32]:
summary(model=vit_l_16_model, 
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 10]             1,024                Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 1024, 14, 14]   (787,456)            False
├─Encoder (encoder)                                          [32, 197, 1024]      [32, 197, 1024]      201,728              False
│    └─Dropout (dropout)                                     [32, 197, 1024]      [32, 197, 1024]      --                   --
│    └─Sequential (layers)                                   [32, 197, 1024]      [32, 197, 1024]      --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 1024]      [32, 197, 1024]      (12,596,224)         False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 1024]      [32, 

In [33]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(lr=0.001, params=vit_l_16_model.parameters())
manual_seed()
from train_model import train
vit_l_16_model_results = train(model=vit_l_16_model,
                              train_dataloader=train_dataloader,
                              test_dataloader=test_dataloader,
                              optimizer=optimizer,
                              loss_fn=loss_fn,
                              device=device,
                              epochs=10)

 10%|█         | 1/10 [11:31<1:43:47, 691.97s/it]

Epoch: 1 | train_loss: 0.4643 | train_acc: 0.8985 | test_loss: 0.3813 | test_acc: 0.9026


 20%|██        | 2/10 [23:04<1:32:17, 692.16s/it]

Epoch: 2 | train_loss: 0.3379 | train_acc: 0.9097 | test_loss: 0.3226 | test_acc: 0.9085


 30%|███       | 3/10 [34:36<1:20:45, 692.26s/it]

Epoch: 3 | train_loss: 0.2961 | train_acc: 0.9161 | test_loss: 0.2954 | test_acc: 0.9122


 40%|████      | 4/10 [46:09<1:09:13, 692.32s/it]

Epoch: 4 | train_loss: 0.2736 | train_acc: 0.9198 | test_loss: 0.2787 | test_acc: 0.9154


 50%|█████     | 5/10 [57:41<57:41, 692.31s/it]  

Epoch: 5 | train_loss: 0.2588 | train_acc: 0.9227 | test_loss: 0.2672 | test_acc: 0.9186


 60%|██████    | 6/10 [1:09:13<46:09, 692.37s/it]

Epoch: 6 | train_loss: 0.2479 | train_acc: 0.9245 | test_loss: 0.2587 | test_acc: 0.9200


 70%|███████   | 7/10 [1:20:46<34:37, 692.38s/it]

Epoch: 7 | train_loss: 0.2397 | train_acc: 0.9263 | test_loss: 0.2520 | test_acc: 0.9215


 80%|████████  | 8/10 [1:32:18<23:04, 692.31s/it]

Epoch: 8 | train_loss: 0.2328 | train_acc: 0.9279 | test_loss: 0.2466 | test_acc: 0.9217


 90%|█████████ | 9/10 [1:43:50<11:32, 692.32s/it]

Epoch: 9 | train_loss: 0.2272 | train_acc: 0.9291 | test_loss: 0.2419 | test_acc: 0.9229


100%|██████████| 10/10 [1:55:23<00:00, 692.32s/it]

Epoch: 10 | train_loss: 0.2225 | train_acc: 0.9302 | test_loss: 0.2382 | test_acc: 0.9238



