In [1]:
import matplotlib.pyplot as plt
import torch
import torchvision
import torchinfo

from torch import nn
from torchvision import transforms

In [2]:
# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [4]:
from PIL import Image, UnidentifiedImageError
import os

# Root of the training images that the conversion step below rewrites in place.
input_dir = "dataset/train"

# Fail fast when the dataset folder is missing. Raising stops this cell (and a
# "Run All") with a clear traceback; `exit()` is an interactive-shell helper
# and is not a reliable way to abort a notebook with an error.
# isdir() is used rather than exists() so that a plain file at this path is
# rejected too (os.walk would silently yield nothing for it).
if not os.path.isdir(input_dir):
    raise FileNotFoundError(f"Input directory '{input_dir}' does not exist.")

def convert_images_in_directory(directory):
    """Convert every non-RGBA image under ``directory`` to an RGBA PNG, in place.

    Walks ``directory`` recursively. For each file whose name ends in
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive):

    * images already in RGBA mode are left untouched;
    * other images are converted to RGBA, saved next to the original as
      ``<stem>_rgba.png``, and the original file is then deleted;
    * files that cannot be opened or written are reported and skipped.

    Progress is reported with ``print``; nothing is returned.

    Args:
        directory: Root folder to scan.
    """
    for root, _, files in os.walk(directory):
        for file_name in files:
            if not file_name.lower().endswith((".png", ".jpg", ".jpeg")):
                continue

            file_path = os.path.join(root, file_name)

            try:
                # The context manager closes the source file handle as soon as
                # we are done with it; otherwise handles leak over a large
                # dataset and the os.remove() below can fail on Windows.
                with Image.open(file_path) as image:
                    if image.mode == "RGBA":
                        print(f"Skipping {file_path} (already in RGBA format).")
                        continue

                    # convert() loads the pixel data, so the result stays
                    # usable after the source file has been closed.
                    rgba_image = image.convert("RGBA")

                # Save the converted copy next to the original.
                rgba_file = os.path.splitext(file_name)[0] + "_rgba.png"
                output_path = os.path.join(root, rgba_file)
                rgba_image.save(output_path)

                # Only delete the original once the converted copy is on disk.
                os.remove(file_path)

                print(f"Converted {file_path} to {output_path}")

            except (OSError, UnidentifiedImageError) as e:
                print(f"Skipping {file_path} due to an error: {e}")

# Run the conversion over the whole training tree. This modifies the dataset
# on disk: convert_images_in_directory deletes each original it converts.
convert_images_in_directory(input_dir)

print("Conversion and replacement to RGBA complete.")

Skipping dataset/train/apple/Image_1.jpg due to an error: cannot identify image file 'dataset/train/apple/Image_1.jpg'
Skipping dataset/train/apple/Image_10.jpg due to an error: cannot identify image file 'dataset/train/apple/Image_10.jpg'
Skipping dataset/train/apple/Image_16.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_17.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_18.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_19.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_2.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_20.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_21.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_23.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_24.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_25.jpg (already in RGBA format).
Skipping dataset/train/apple/Image_26.jpg (already in RGBA format).
Skipping data

In [None]:
# Roots of the training and testing image folders (relative paths).
train_dir, test_dir = "dataset/train", "dataset/test"

In [None]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# os.cpu_count() may return None when the count cannot be determined; fall back
# to 0 (load data in the main process) so DataLoader always receives an int.
NUM_WORKERS = os.cpu_count() or 0

def create_dataloaders(
        train_dir: str,
        test_dir: str,
        transform: transforms.Compose,
        batch_size: int,
        num_workers: int = NUM_WORKERS
):
    """Build training and testing DataLoaders from two image folders.

    Both directories are read with ``datasets.ImageFolder`` and the same
    ``transform`` is applied to each.

    Args:
        train_dir: Path to the training images.
        test_dir: Path to the testing images.
        transform: Transform applied to every image of both datasets.
        batch_size: Number of samples per batch in each DataLoader.
        num_workers: Worker processes per DataLoader. Defaults to NUM_WORKERS.

    Returns:
        A tuple ``(train_dataloader, test_dataloader, class_names)`` where
        ``class_names`` is the class list of the training dataset.
    """
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Class names are taken from the training split only.
    class_names = train_data.classes

    # Training batches are reshuffled every epoch.
    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    # Evaluation gains nothing from shuffling; a fixed order keeps the test
    # batches (and anything inspected from them) reproducible between runs.
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_dataloader, test_dataloader, class_names

In [None]:
# Edge length that every image is resized to (images become IMG_SIZE x IMG_SIZE).
IMG_SIZE = 224

# Preprocessing pipeline: resize to a fixed square, then convert to a tensor.
manual_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)
print(f"manually created transforms: {manual_transforms}")


In [None]:
# Number of images per batch for both loaders.
BATCH_SIZE = 32
# Build the train/test loaders; the same manual_transforms pipeline is applied
# to both splits, and class_names comes back from create_dataloaders.
train_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=manual_transforms, 
    batch_size=BATCH_SIZE)
# Display the three objects as the cell output.
train_dataloader, test_dataloader, class_names

In [None]:
# Pull a single batch from the training loader and keep its first sample.
# `image` and `label` are reused by later demo cells.
image_batch, label_batch = next(iter(train_dataloader))
image, label = image_batch[0], label_batch[0]
print(image.shape, label)
# permute(1, 2, 0) moves the first axis to the end for plotting; this assumes
# `image` is channels-first (C, H, W), which imshow needs as (H, W, C).
plt.imshow(image.permute(1,2,0))
plt.title(class_names[label])
# Trailing semicolon suppresses the textual repr in the cell output.
plt.axis(False);

In [None]:
class PatchEmbedding(nn.Module):
    """Turn a batch of 2D images into a sequence of patch embeddings.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` cuts the image into non-overlapping patches and projects
    each one to ``embedding_dim`` channels; the resulting spatial grid is then
    flattened into a sequence.

    Args:
        in_channels: Number of channels in the input images.
        patch_size: Edge length of each square patch.
        embedding_dim: Size of the embedding produced for each patch.
    """

    def __init__(self,
                 in_channels:int=3,
                 patch_size:int=16,
                 embedding_dim:int=768):
        super().__init__()

        # Kept on the instance so forward() validates against this layer's own
        # patch size instead of whatever global `patch_size` happens to exist.
        self.patch_size = patch_size

        # kernel_size == stride -> one output position per non-overlapping patch.
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)

        # Collapse the two spatial dimensions of the conv output into one.
        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches, embedding_dim)``.

        Raises:
            AssertionError: If the height or width of ``x`` is not divisible
                by the patch size.
        """
        height, width = x.shape[-2], x.shape[-1]
        assert height % self.patch_size == 0 and width % self.patch_size == 0, (
            f"Input image size must be divisible by patch size, "
            f"image shape: {height}x{width}, patch size: {self.patch_size}"
        )

        x_patched = self.patcher(x)
        x_flattened = self.flatten(x_patched)

        # (batch, embedding_dim, num_patches) -> (batch, num_patches, embedding_dim)
        return x_flattened.permute(0, 2, 1)

In [None]:
# Patch edge length used by the demo cells in this notebook.
patch_size = 16


def set_seeds(seed: int=42):
    """Seed PyTorch's default and CUDA random number generators.

    Args:
        seed: Value handed to both seeding calls. Defaults to 42.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    
set_seeds()

# Demo: run the sample image (with a batch dimension added) through a
# stand-alone patch embedding layer and compare input and output shapes.
patchify = PatchEmbedding(
    in_channels=3,
    patch_size=16,
    embedding_dim=768,
)
print(f"Input image shape: {image.unsqueeze(0).shape}")
patch_embedded_image = patchify(image.unsqueeze(0))
print(f"Output patch embedding shape: {patch_embedded_image.shape}")

In [None]:
set_seeds()
patch_size = 16

# --- 1. Sample image -> batch of one ---------------------------------------
print(f"Image tensor shape: {image.shape}")
height, width = image.shape[1:3]
x = image.unsqueeze(0)
print(f"Input image with batch dimension shape: {x.shape}")

# --- 2. Patch embedding ----------------------------------------------------
patch_embedding_layer = PatchEmbedding(
    in_channels=3,
    patch_size=patch_size,
    embedding_dim=768,
)
patch_embedding = patch_embedding_layer(x)
print(f"Patching embedding shape: {patch_embedding.shape}")
batch_size, embedding_dimension = patch_embedding.shape[0], patch_embedding.shape[-1]

# --- 3. Class token, prepended along the sequence dimension ----------------
# All-ones values are used here purely to make the token easy to spot.
class_token = nn.Parameter(
    torch.ones(batch_size, 1, embedding_dimension),
    requires_grad=True,
)
print(f"Class token embedding shape: {class_token.shape}")

patch_embedding_class_token = torch.cat((class_token, patch_embedding), dim=1)
print(f"Patch embedding with class token shape: {patch_embedding_class_token.shape}")

# --- 4. Position embedding: one slot per patch plus one for the class token -
number_of_patches = (height * width) // patch_size**2

position_embedding = nn.Parameter(
    torch.ones(1, number_of_patches + 1, embedding_dimension),
    requires_grad=True,
)
patch_and_position_embedding = patch_embedding_class_token + position_embedding
print(f"Patch and position embedding shape: {patch_and_position_embedding.shape}")

print(patch_embedding_class_token)

In [None]:
class MultiheadSelfAttentionBlock(nn.Module):
    """Layer normalisation followed by multi-head self-attention.

    The residual connection is not applied here; this block only returns the
    attention output.

    Args:
        embedding_dim: Size of the token embeddings.
        num_heads: Number of attention heads.
        attn_dropout: Dropout probability passed to ``nn.MultiheadAttention``.
    """

    def __init__(self,
                 embedding_dim:int=768,
                 num_heads:int=12,
                 attn_dropout:float=0):
        super().__init__()

        # Normalise over the embedding dimension before attending.
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # batch_first=True -> tensors are laid out as (batch, sequence, embedding).
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )

    def forward(self, x):
        """Return the self-attention output for ``x`` (attention weights are discarded)."""
        normed = self.layer_norm(x)
        # Query, key and value are all the same normalised sequence.
        attended_and_weights = self.multihead_attn(
            query=normed,
            key=normed,
            value=normed,
            need_weights=False,
        )
        return attended_and_weights[0]

In [None]:
class MLPBlock(nn.Module):
    """Layer normalisation followed by a two-layer feed-forward network.

    The network is Linear -> GELU -> Dropout -> Linear -> Dropout and maps the
    embedding back to its original size. No residual connection is applied
    here.

    Args:
        embedding_dim: Size of the token embeddings (input and output width).
        mlp_size: Width of the hidden layer.
        dropout: Dropout probability used after each linear layer.
    """

    def __init__(self,
                 embedding_dim:int=768,
                 mlp_size:int=3072,
                 dropout:float=0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # Expand to the hidden width, apply the non-linearity, project back.
        feed_forward_layers = [
            nn.Linear(in_features=embedding_dim, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, out_features=embedding_dim),
            nn.Dropout(p=dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Normalise ``x`` and pass it through the feed-forward network."""
        return self.mlp(self.layer_norm(x))

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention then MLP, each with a residual.

    Args:
        embedding_dim: Size of the token embeddings.
        num_heads: Number of attention heads.
        mlp_size: Hidden width of the MLP block.
        mlp_dropout: Dropout probability inside the MLP block.
        attn_dropout: Dropout probability inside the attention block.
    """

    def __init__(self,
                 embedding_dim:int=768,
                 num_heads:int=12,
                 mlp_size:int=3072,
                 mlp_dropout:float=0.1,
                 attn_dropout:float=0):
        super().__init__()

        self.msa_block = MultiheadSelfAttentionBlock(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            attn_dropout=attn_dropout,
        )

        self.mlp_block = MLPBlock(
            embedding_dim=embedding_dim,
            mlp_size=mlp_size,
            dropout=mlp_dropout,
        )

    def forward(self, x):
        """Apply attention and MLP sub-blocks, adding each one's input back to its output."""
        # Residual connection around the attention sub-block.
        after_attention = self.msa_block(x) + x
        # Residual connection around the MLP sub-block.
        return self.mlp_block(after_attention) + after_attention

In [None]:
# Instantiate one encoder block with its default hyperparameters and show a
# per-layer summary for a single input of shape (1, 197, 768).
transformer_encoder_block = TransformerEncoderBlock()

from torchinfo import summary
summary(
    model=transformer_encoder_block,
    input_size=(1, 197, 768),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
class ViT(nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default.

    Args:
        img_size: Edge length of the (square) input images.
        in_channels: Number of channels in the input images.
        patch_size: Edge length of each square patch; must divide ``img_size``.
        num_transformer_layers: Number of stacked encoder blocks.
        embedding_dim: Size of the token embeddings.
        mlp_size: Hidden width of each encoder block's MLP.
        num_heads: Number of attention heads per encoder block.
        attn_dropout: Dropout probability inside each attention block.
        mlp_dropout: Dropout probability inside each MLP block.
        embedding_dropout: Dropout applied to the patch + position embeddings.
        num_classes: Number of output logits.

    Raises:
        AssertionError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self,
                 img_size:int=224,
                 in_channels:int=3,
                 patch_size:int=16,
                 num_transformer_layers:int=12,
                 embedding_dim:int=768,
                 mlp_size:int=3072,
                 num_heads:int=12,
                 attn_dropout:float=0,
                 mlp_dropout:float=0.1,
                 embedding_dropout:float=0.1,
                 num_classes:int=1000):
        super().__init__()

        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        # Number of non-overlapping patches in one image.
        self.num_patches = (img_size * img_size) // patch_size**2

        # Learnable class token; expanded across the batch in forward().
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        # One learnable position vector per patch, plus one for the class token.
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
                                               requires_grad=True)

        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              embedding_dim=embedding_dim)

        # attn_dropout is forwarded explicitly; without it the constructor
        # argument would be accepted but have no effect on the encoder.
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                            num_heads=num_heads,
                                                                            mlp_size=mlp_size,
                                                                            mlp_dropout=mlp_dropout,
                                                                            attn_dropout=attn_dropout) for _ in range(num_transformer_layers)])

        # Classification head applied to the class token only.
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch; its first dimension is the batch size.

        Returns:
            Tensor of logits with one row per image and ``num_classes`` columns.
        """
        batch_size = x.shape[0]

        # Share the single learned class token across every item in the batch.
        class_token = self.class_embedding.expand(batch_size, -1, -1)

        x = self.patch_embedding(x)

        # Prepend the class token along the sequence dimension.
        x = torch.cat((class_token, x), dim=1)

        x = self.position_embedding + x

        x = self.embedding_dropout(x)

        x = self.transformer_encoder(x)

        # Index 0 of the sequence is the class token prepended above.
        x = self.classifier(x[:, 0])

        return x

In [None]:
# Build a ViT with its default hyperparameters and one output per dataset class.
vit = ViT(num_classes=len(class_names))

In [None]:
from going_modular.going_modular import engine

# Setup the optimizer to optimize our ViT model parameters using hyperparameters from the ViT paper
# NOTE(review): lr=3e-3 with weight_decay=0.3 is taken as-is from the comment
# above; confirm these values suit a dataset of this size.
optimizer = torch.optim.Adam(params=vit.parameters(), 
                             lr=3e-3,
                             betas=(0.9, 0.999),
                             weight_decay=0.3) 

# Multi-class classification loss computed from the model's raw logits.
loss_fn = torch.nn.CrossEntropyLoss()

# Seeding here happens after `vit` was constructed, so it cannot change the
# model's initial weights; it only affects randomness from this point on.
set_seeds()

# NOTE(review): engine.train is imported from going_modular and is not visible
# in this notebook; confirm it moves the model and batches to `device`.
results = engine.train(model=vit,
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=25,
                       device=device)

In [None]:
from helper_functions import plot_loss_curves

# Plot the training history returned by engine.train above.
# plot_loss_curves is defined in helper_functions, outside this notebook.
plot_loss_curves(results)

In [None]:
import requests


from going_modular.going_modular.predictions import pred_and_plot_image


# Image to classify (relative path).
custom_image_path = "test_img.jpg"

# NOTE(review): pred_and_plot_image comes from going_modular and is not visible
# here. No transform or device is passed, so confirm that its defaults match
# the Resize + ToTensor preprocessing the model was trained with.
pred_and_plot_image(model=vit,
                    image_path=custom_image_path,
                    class_names=class_names)