In [17]:
import os

import torch
import torch.nn as nn
import torchvision.models as models
from torchinfo import summary
from torchvision import transforms, datasets
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Pick the checkpoint to load; DEFAULT can be swapped for any other
# ViT_B_16_Weights member.
weights = ViT_B_16_Weights.DEFAULT

# Build a ViT-B/16 initialised from that checkpoint.
pretrained_vit = vit_b_16(weights=weights)



# Select the compute device: prefer CUDA, then Apple-Silicon MPS, and fall
# back to CPU. The previous one-liner returned "mps" whenever CUDA was missing,
# which names an unusable device on machines without MPS support.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

# Number of samples per DataLoader batch. The recorded run of the data-loading
# cell reports a single batch for both the train and the test loader with this
# value, i.e. each split fits entirely in one batch.
batch_size = 1024
     

mps


In [11]:
# Start from a fresh pretrained ViT-B/16 ("b" = base variant, 16 = patch size)
# so that re-running this cell always begins from the original checkpoint.
pretrained_vit = models.vit_b_16(weights=weights)

class_names = ["pizza", "steak", "sushi"]

# Width of the token embeddings fed to the classification head
# (the model summary reports encoder activations of shape [1, 197, 768]).
embedding_dim = 768

# Key idea: freeze the pretrained model and train only the new head.
# Step 1 - freeze every pretrained parameter.
for param in pretrained_vit.parameters():
    param.requires_grad = False

# Step 2 - replace the classification head exactly once, AFTER freezing.
# Newly created layers default to requires_grad=True, so the head is the only
# trainable part. (Previously this assignment sat inside the loop above, which
# rebuilt the head once per parameter while the model was being iterated.)
pretrained_vit.heads = nn.Sequential(
    nn.LayerNorm(normalized_shape=embedding_dim),
    # out_features goes from the 1000 pretrained classes to len(class_names) == 3
    nn.Linear(in_features=embedding_dim, out_features=len(class_names)),
)

In [12]:
# Per-layer report for one 3x224x224 input: input/output shapes, parameter
# counts, and whether each layer is trainable. Rows are labelled with the
# attribute (variable) names of the layers.
summary(
    model=pretrained_vit,
    input_size=(1, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [1, 3, 224, 224]     [1, 3]               768                  Partial
├─Conv2d (conv_proj)                                         [1, 3, 224, 224]     [1, 768, 14, 14]     (590,592)            False
├─Encoder (encoder)                                          [1, 197, 768]        [1, 197, 768]        151,296              False
│    └─Dropout (dropout)                                     [1, 197, 768]        [1, 197, 768]        --                   --
│    └─Sequential (layers)                                   [1, 197, 768]        [1, 197, 768]        --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [1, 197, 768]        [1, 197, 768]        (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [1, 197, 768]        [1, 1

In [13]:
# import os
# import wget
# import zipfile

# # Create a directory to store the downloaded data
# data_directory = "pizza_steak_sushi"
# os.makedirs(data_directory, exist_ok=True)

# # Download the ZIP file
# url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi_20_percent.zip"
# zip_file = os.path.join(data_directory, "pizza_steak_sushi_20_percent.zip")
# wget.download(url, zip_file)

# # Extract the contents of the ZIP file
# with zipfile.ZipFile(zip_file, 'r') as zip_ref:
#     zip_ref.extractall(data_directory)


# # Remove the ZIP file
# os.remove(zip_file)

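# # Print the paths (derived from data_directory; the extracted archive is
# # expected to contain "train" and "test" sub-folders)
# print("Train directory:", os.path.join(data_directory, "train"))
# print("Test directory:", os.path.join(data_directory, "test"))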


In [21]:
# Set up train/test directories. ImageFolder expects each of them to contain
# one sub-folder per class. Requires `os` from the notebook's import cell.
data_directory = "pizza_steak_sushi"
train_dir_20_percent = os.path.join(data_directory, "train")
test_dir = os.path.join(data_directory, "test")

# Define data transformations (shared by the train and test splits)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the 224x224 input used for the model summary
    transforms.ToTensor(),          # Convert PIL images to float tensors
    # Per-channel normalisation with the commonly used ImageNet mean/std.
    # NOTE(review): the pretrained checkpoint also ships its own preprocessing
    # via `weights.transforms()`; confirm this manual pipeline is intended.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create train data loader (shuffled every epoch)
train_dataset = datasets.ImageFolder(train_dir_20_percent, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create test data loader (fixed order for reproducible evaluation)
test_dataset = datasets.ImageFolder(test_dir, transform=transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Label indices are assigned from the class folder names of each split
# independently; if the two splits disagree, evaluation labels would be
# silently misaligned, so fail loudly instead.
assert train_dataset.classes == test_dataset.classes, (
    "train and test directories must contain the same class folders"
)

# Get class names (overrides the hand-written list with the folder names on disk)
class_names = train_dataset.classes

# len(DataLoader) is the number of batches per epoch
print(len(train_dataloader))
print(len(test_dataloader))

1
1
