In [1]:
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms
from helper_functions import set_seeds

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [3]:
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT 

pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False
    
class_names = ["basal_rot_disease","Netural"]

set_seeds()
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)


In [4]:
from torchinfo import summary

summary(model=pretrained_vit, 
        input_size=(32, 3, 224, 224), 
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 2]              768                  Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 14, 14]    (590,592)            False
├─Encoder (encoder)                                          [32, 197, 768]       [32, 197, 768]       151,296              False
│    └─Dropout (dropout)                                     [32, 197, 768]       [32, 197, 768]       --                   --
│    └─Sequential (layers)                                   [32, 197, 768]       [32, 197, 768]       --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 197, 768]       [32, 197, 768]       (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 197, 768]       [32, 

In [5]:
train_dir =r'dataset/train'
test_dir =r'dataset/test'

Remember, if you're going to use a pretrained model, it's generally important to ensure your own custom data is transformed/formatted in the same way the data the original model was trained on.

In [6]:
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [7]:
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [8]:
import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

NUM_WORKERS = os.cpu_count()

def create_dataloaders(
    train_dir: str, 
    test_dir: str, 
    transform: transforms.Compose, 
    batch_size: int, 
    num_workers: int=NUM_WORKERS
):

  train_data = datasets.ImageFolder(train_dir, transform=transform)
  test_data = datasets.ImageFolder(test_dir, transform=transform)

  class_names = train_data.classes

  train_dataloader = DataLoader(
      train_data,
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers,
      pin_memory=True,
  )
  test_dataloader = DataLoader(
      test_data,
      batch_size=batch_size,
      shuffle=False,
      num_workers=num_workers,
      pin_memory=True,
  )

  return train_dataloader, test_dataloader, class_names

In [9]:
train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders(train_dir=train_dir,
                                                                                                     test_dir=test_dir,
                                                                                                     transform=pretrained_vit_transforms,
                                                                                                     batch_size=32) # Could increase if we had more samples, such as here: https://arxiv.org/abs/2205.01580 (there are other improvements there too...)


In [10]:
from engine import engine

optimizer = torch.optim.Adam(params=pretrained_vit.parameters(), 
                             lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

#set_seeds()
pretrained_vit_results = engine.train(model=pretrained_vit,
                                      train_dataloader=train_dataloader_pretrained,
                                      test_dataloader=test_dataloader_pretrained,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=2,
                                      device=device)

  from .autonotebook import tqdm as notebook_tqdm


 50%|█████     | 1/2 [01:10<01:10, 70.33s/it]

Epoch: 1 | train_loss: 0.4746 | train_acc: 0.7593 | test_loss: 0.2632 | test_acc: 0.9375


100%|██████████| 2/2 [02:19<00:00, 69.82s/it]

Epoch: 2 | train_loss: 0.1638 | train_acc: 0.9688 | test_loss: 0.1381 | test_acc: 0.9688





In [11]:
# Save the entire model
torch.save(pretrained_vit, 'model1.pth')



In [12]:
import os
from PIL import Image
import torch
from torchvision import transforms

# Assuming pretrained_vit is already defined
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = 'model1.pth'
pretrained_vit = torch.load(model_path, map_location=device, weights_only=False)
pretrained_vit.eval()
pretrained_vit.to(device)

# Define the folder containing the images
folder_path = "dataset/test/Netural"

# Define the transformation for the images
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# List all files in the folder
image_files = os.listdir(folder_path)

# Iterate through the image files
for image_file in image_files:
    # Construct the full path to the image
    image_path = os.path.join(folder_path, image_file)
    
    # Open and preprocess the image
    img = Image.open(image_path)
    img = img.resize((224, 224))
    img_tensor = image_transform(img).unsqueeze(dim=0).to(device)
    
    # Perform inference
    with torch.no_grad():
        model_output = pretrained_vit(img_tensor)
    
    # Get predicted class
    probabilities = torch.nn.functional.softmax(model_output[0], dim=0)
    predicted_class = torch.argmax(probabilities).item()
    
    # Print the predicted class
    print(f"Image: {image_file}, Predicted class: {predicted_class}")


Image: 428.jpg, Predicted class: 0
Image: 429.jpg, Predicted class: 0
Image: 430.jpg, Predicted class: 0
Image: 431.jpg, Predicted class: 0
Image: 432.jpg, Predicted class: 0
Image: image_0.jpg, Predicted class: 1
Image: image_1.jpg, Predicted class: 1
Image: image_10.jpg, Predicted class: 0
Image: image_11.jpg, Predicted class: 0
Image: image_12.jpg, Predicted class: 0
Image: image_13.jpg, Predicted class: 0
Image: image_14.jpg, Predicted class: 0
Image: image_15.jpg, Predicted class: 0
Image: image_16.jpg, Predicted class: 0
Image: image_17.jpg, Predicted class: 0
Image: image_18.jpg, Predicted class: 0
Image: image_19.jpg, Predicted class: 0
Image: image_2.jpg, Predicted class: 0
Image: image_20.jpg, Predicted class: 0
Image: image_21.jpg, Predicted class: 0
Image: image_3.jpg, Predicted class: 0
Image: image_4.jpg, Predicted class: 0
Image: image_5.jpg, Predicted class: 0
Image: image_6.jpg, Predicted class: 0
Image: image_7.jpg, Predicted class: 0
Image: image_8.jpg, Predicted cla

tensor([0.7204, 0.2796])
0
tensor([0.7166, 0.2834])
0
tensor([0.7303, 0.2697])
0
tensor([0.6938, 0.3062])
0
tensor([0.6771, 0.3229])
0
tensor([0.6586, 0.3414])
0
tensor([0.6373, 0.3627])
0
tensor([0.6132, 0.3868])
0
tensor([0.4771, 0.5229])
1
tensor([0.5136, 0.4864])
0
tensor([0.5551, 0.4449])
0
tensor([0.5373, 0.4627])
0
tensor([0.5605, 0.4395])
0
tensor([0.5780, 0.4220])
0
tensor([0.6193, 0.3807])
0
tensor([0.6901, 0.3099])
0
tensor([0.6514, 0.3486])
0
tensor([0.6721, 0.3279])
0
tensor([0.6473, 0.3527])
0
tensor([0.6346, 0.3654])
0
tensor([0.6388, 0.3612])
0
tensor([0.6386, 0.3614])
0
tensor([0.6179, 0.3821])
0
tensor([0.6441, 0.3559])
0
tensor([0.6255, 0.3745])
0
tensor([0.6331, 0.3669])
0
tensor([0.6332, 0.3668])
0
tensor([0.6311, 0.3689])
0
tensor([0.6496, 0.3504])
0
tensor([0.6435, 0.3565])
0
tensor([0.6345, 0.3655])
0
tensor([0.6373, 0.3627])
0
tensor([0.6522, 0.3478])
0
tensor([0.6454, 0.3546])
0
tensor([0.6521, 0.3479])
0
tensor([0.5351, 0.4649])
0
tensor([0.5867, 0.4133])
0
t