### Device-agnostic code

In [None]:
import torch
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'

In [None]:
from pathlib import Path
# Root folder that holds the data used in this notebook.
data_path = Path("./data/")
# CT-KIDNEY folder inside the data root.
image_path = data_path.joinpath("CT-KIDNEY")

In [None]:
from helperfunctions import walk_through_dir 
# Re-point image_path at the CT-KIDNEY-VAL folder (this replaces the earlier value).
image_path = data_path.joinpath("CT-KIDNEY-VAL")
# NOTE(review): walk_through_dir comes from helperfunctions; presumably it
# reports the folder's contents - confirm against that module.
walk_through_dir(image_path)

In [None]:
# Build the train / test / val sub-folder paths under the dataset root.
train_dir, test_dir, val_dir = (
    image_path / split_name for split_name in ("train", "test", "val")
)

train_dir, test_dir, val_dir

### 2. Create Datasets and DataLoaders

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [None]:
# Side length (pixels) every image is resized to.
IMG_SIZE = 224
# Resize to IMG_SIZE x IMG_SIZE, then convert the image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)

In [None]:
# One ImageFolder dataset per split, all sharing the same transform.
train_data, test_data, val_data = (
    datasets.ImageFolder(root=split_dir, transform=transform)
    for split_dir in (train_dir, test_dir, val_dir)
)
train_data, test_data, val_data

In [None]:
import os
# Number of samples per mini-batch, shared by all three loaders.
BATCH_SIZE = 32
# Only the training loader shuffles; test and validation keep dataset order.
train_dataloader = DataLoader(
    train_data,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=os.cpu_count(),
)
test_dataloader = DataLoader(
    test_data,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=os.cpu_count(),
)
val_dataloader = DataLoader(
    val_data,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=os.cpu_count(),
)
train_dataloader, test_dataloader, val_dataloader

### 3. Visualization

In [None]:
# Pull one mini-batch from the training loader.
image_batch, label_batch = next(iter(train_dataloader))

# Keep the first sample of that batch for the visualisations below.
image = image_batch[0]
label = label_batch[0]

# Show the sample's shape together with its label.
image.shape, label

In [None]:
# Class names as exposed by the training dataset's `classes` attribute;
# used below to turn a label index into a plot title.
class_names = train_data.classes

In [None]:
import matplotlib.pyplot as plt 
# permute(1, 2, 0) moves the first axis of the image tensor to the end before plotting.
plt.imshow(image.permute(1, 2, 0))
# Hide the axes.
plt.axis("off")
# Title the plot with the class name for this label.
plt.title(class_names[label])

### ViT from Scratch

In [None]:
# Example image geometry and patch size
height, width = 224, 224
color_channels = 3
patch_size = 16

# Patches per image = image area divided by the area of one patch
number_of_patches = (height * width) // patch_size**2
number_of_patches 

In [None]:
# Shape of one image entering the embedding layer: (height, width, channels)
embedding_layer_input_shape = height, width, color_channels
# Shape after patching: (number of patches, values per flattened patch)
embedding_layer_output_shape = (
    number_of_patches,
    color_channels * patch_size**2,
)
embedding_layer_input_shape, embedding_layer_output_shape

### Turning a single image into patches

In [None]:
# Plot the sample image with its class name as the title and the axes hidden.
plt.imshow(image.permute(1, 2, 0))
plt.title(class_names[label])
plt.axis("off")

In [None]:
# Permuted copy of the image, reused by the patch plots below.
image_permuted = image.permute(1, 2, 0)
patch_size = 16
plt.figure(figsize=(patch_size, patch_size))
# Show only the first `patch_size` rows of the image.
plt.imshow(image_permuted[:patch_size])

In [None]:
img_size = 224
patch_size = 16
# Validate first so a bad configuration fails before anything is computed or drawn.
assert img_size % patch_size == 0, "Image size must be divisible by patch size."
# Integer division: a patch count is a whole number (true division made this 14.0).
num_patches = img_size // patch_size
print(f'Number of patches per row: {num_patches}\nPatch size: {patch_size} pixels x {patch_size} pixels')
# Create a series of subplots: one row, one column per patch
fig, axs = plt.subplots(nrows=1,
                        ncols=num_patches,
                        sharex=True,
                        sharey=True,
                        figsize=(patch_size, patch_size))
# Iterate through the patches of the top row; `patch` is the left pixel column of each patch
for i, patch in enumerate(range(0, img_size, patch_size)):
    axs[i].imshow(image_permuted[:patch_size, patch:patch+patch_size, :])
    axs[i].set_xlabel(i+1)  # 1-based patch number
    axs[i].set_xticks([])
    axs[i].set_yticks([])

In [None]:
# Setup code to plot the whole image as a grid of patches
img_size = 224
patch_size = 16
# Validate first so a bad configuration fails before anything is computed or drawn.
assert img_size % patch_size == 0, "Image size must be divisible by patch size."
# Integer division: patch counts are whole numbers (true division printed 14.0 / 196.0).
num_patches = img_size // patch_size
print(f'''Number of patches per row: {num_patches}
Number of patches per column: {num_patches}
Total patches: {num_patches*num_patches}
Patch size: {patch_size} pixels x {patch_size} pixels''')

# Create a square grid of subplots, one per patch
fig, axs = plt.subplots(nrows=num_patches,
                        ncols=num_patches,
                        figsize=(num_patches, num_patches),
                        sharex=True,
                        sharey=True)
# Loop through height and width; patch_height / patch_width are the top-left
# pixel coordinates of each patch
for i, patch_height in enumerate(range(0, img_size, patch_size)):
    for j, patch_width in enumerate(range(0, img_size, patch_size)):
        # Plot this patch of the permuted image on its own axis
        axs[i, j].imshow(image_permuted[patch_height:patch_height+patch_size,
                                        patch_width:patch_width+patch_size,
                                        :])
        # 1-based row / column numbers as axis labels
        axs[i, j].set_ylabel(i+1,
                             rotation="horizontal",
                             horizontalalignment='right',
                             verticalalignment='center')
        axs[i, j].set_xlabel(j+1)
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
        # Keep labels only on the outer edge of the grid
        axs[i, j].label_outer()
fig.suptitle(f'{class_names[label]} -> Patchified', fontsize=14)

In [None]:
from torch import nn  
# Patch side length in pixels
patch_size = 16
# Convolution whose kernel and stride both equal the patch size, so each
# output position covers exactly one non-overlapping patch.
conv2d = nn.Conv2d(
    3,
    768,  # D size from table 1
    kernel_size=patch_size,
    stride=patch_size,
    padding=0,
)
# Add a leading batch dimension before running the single image through the layer.
image_out_of_conv = conv2d(image.unsqueeze(dim=0))
image_out_of_conv.shape, image_out_of_conv.requires_grad

In [None]:
import random 
# Sample from every feature map the convolution produced. The previous
# hard-coded upper bound of 758 was a typo for the 768 output channels and
# made the last ten maps unreachable; reading the count from the tensor
# keeps this in sync with the layer.
num_feature_maps = image_out_of_conv.shape[1]
random_indexs = random.sample(range(num_feature_maps), k=5)
print(f'Showing random convolutional feature maps from indexes: {random_indexs}')

# Create plot: one row with a column per sampled feature map
fig, axs = plt.subplots(nrows=1, ncols=5)

for i, idx in enumerate(random_indexs):
    # Select one output channel across the batch dimension
    image_conv_feature_map = image_out_of_conv[:, idx, :, :]
    # squeeze -> drops size-1 dimensions, detach -> drops grad tracking,
    # numpy -> converts to a NumPy array for matplotlib
    axs[i].imshow(image_conv_feature_map.squeeze().detach().numpy())
    axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Merge dimensions 2 and 3 of the convolution output into a single dimension.
flatten_layer = nn.Flatten(2, 3)
image_out_of_conv_flattened = flatten_layer(image_out_of_conv)
image_out_of_conv_flattened.shape

In [None]:
# Report the feature-map shape before and after flattening, one per line.
print(
    f'Image feature map (patches) shape: {image_out_of_conv.shape}',
    f'Flattened image feature map shape: {image_out_of_conv_flattened.shape}',
    sep="\n",
)


In [None]:
# Swap the last two dimensions of the flattened output.
image_out_of_conv_flattened_permuted = image_out_of_conv_flattened.transpose(1, 2)
# Take index 0 along the last dimension.
single_flattened_feature_map = image_out_of_conv_flattened_permuted[..., 0]
plt.figure(figsize=(22, 22))
# Detach from the autograd graph before converting to NumPy for plotting.
plt.imshow(single_flattened_feature_map.detach().numpy())
plt.title(f'Flattened feature map shape: {single_flattened_feature_map.shape}')
plt.axis("off")

## Equation 1: Patch Embedding

In [None]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of learnable patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` cuts
    each image into non-overlapping patches and projects every patch to
    ``embedding_dim`` values; the two spatial dimensions are then flattened
    into one sequence dimension.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length, in pixels, of each square patch.
        embedding_dim: Number of values produced for each patch.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 768):
        super().__init__()
        # Stored so forward() validates against this instance's own patch size.
        self.patch_size = patch_size
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                 out_channels=embedding_dim,
                                 kernel_size=patch_size,
                                 stride=patch_size,
                                 padding=0)
        # Collapse the two spatial dimensions of the conv output into one.
        self.flatten = nn.Flatten(start_dim=2,
                                  end_dim=3)

    def forward(self, x):
        """Embed ``x`` patch by patch.

        Args:
            x: Batched image tensor laid out as (batch, channels, height, width).

        Returns:
            Tensor laid out as (batch, number_of_patches, embedding_dim).

        Raises:
            AssertionError: If the last dimension of ``x`` is not divisible
                by ``self.patch_size``.
        """
        image_resolution = x.shape[-1]
        # Use self.patch_size: the bare name `patch_size` would resolve to
        # whatever global happens to exist (or raise NameError).
        assert image_resolution % self.patch_size == 0, (
            f"Input image size must be divisible by patch size, "
            f"image shape: {image_resolution}, patch size: {self.patch_size}"
        )
        x_patched = self.patcher(x)
        x_flattened = self.flatten(x_patched)
        # Move the embedding dimension last: (batch, patches, embedding_dim).
        return x_flattened.permute(0, 2, 1)

In [None]:
# Seed the CPU and CUDA generators before the layer is initialised.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
patchify = PatchEmbedding(
    in_channels=3,
    patch_size=16,
    embedding_dim=768,
)
print(f'Input image size: {image.unsqueeze(0).shape}')
# Add a batch dimension, then embed the single image.
patch_embedded_image = patchify(image.unsqueeze(dim=0))
print(f'Output patch embedding sequeence shape: {patch_embedded_image.shape}')

### Creating the class token embedding

In [None]:
# First and last dimension sizes of the patch embedding.
batch_size = patch_embedded_image.size(0)
embedding_dimension = patch_embedded_image.size(-1)
batch_size, embedding_dimension

In [None]:
# Learnable class token: one all-ones vector of embedding size per batch item.
class_token = nn.Parameter(
    torch.ones((batch_size, 1, embedding_dimension)),
    requires_grad=True,
)
class_token.shape

In [None]:
# Prepend the class token to the patch sequence along dimension 1.
patch_embedded_image_with_class_embedding = torch.cat(
    [class_token, patch_embedded_image],
    dim=1,
)
print(patch_embedded_image_with_class_embedding)
print(f'Sequence of the patch embeddings with class token prepend shape: {patch_embedded_image_with_class_embedding.shape} -> (bathc_size, class_token + number_of_patchs, embedding_dim)')

### Creating the position embedding

In [None]:
patch_embedded_image_with_class_embedding, patch_embedded_image_with_class_embedding.shape

In [None]:
# Patches per image (image area / patch area)
number_of_patches = int((height * width) / patch_size**2)
embedding_dimension = patch_embedded_image_with_class_embedding.size(-1)
# Learnable position embedding with one slot per patch plus one extra
# (the sequence it is added to has the class token prepended).
position_embedding = nn.Parameter(
    torch.ones((1, number_of_patches + 1, embedding_dimension)),
    requires_grad=True,
)
position_embedding

In [None]:
# Sum the class-token + patch sequence with the position embedding.
patch_and_position_embedding = torch.add(
    patch_embedded_image_with_class_embedding,
    position_embedding,
)
print(patch_and_position_embedding)
print(f'patch and position embedding shape: {patch_and_position_embedding.shape}')

## Equation 2: Multihead Self-Attention (MSA Block)

In [None]:
class MultiHeadSelfAttentionBlock(nn.Module):
    """Layer normalisation followed by multi-head self-attention.

    The block does not add a residual connection itself; the caller adds it.

    Args:
        embedding_dimension: Size of the last dimension of the input, used
            for both the LayerNorm and the attention layer.
        num_heads: Number of attention heads.
        attn_dropout: Dropout probability handed to ``nn.MultiheadAttention``
            (a float; the default of 0 disables it).
    """

    def __init__(self,
                 embedding_dimension: int = 768,
                 num_heads: int = 12,
                 attn_dropout: float = 0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dimension)
        # batch_first=True: inputs are laid out as (batch, sequence, embedding).
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dimension,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x):
        """Normalise ``x`` and attend over it, using it as query, key and value.

        Args:
            x: Tensor laid out as (batch, sequence, embedding_dimension).

        Returns:
            The attention output; attention weights are not requested
            (``need_weights=False``) and are discarded.
        """
        x = self.layer_norm(x)
        attn_output, _ = self.multihead_attn(query=x,
                                             key=x,
                                             value=x,
                                             need_weights=False)
        return attn_output

In [None]:
# Instantiate the MSA block and run the embedded image through it.
multihead_self_attention_block = MultiHeadSelfAttentionBlock(
    embedding_dimension=768,
    num_heads=12,
    attn_dropout=0,
)
patched_image_through_msa_block = multihead_self_attention_block(
    patch_and_position_embedding
)
print(f'Input shape of MSA block: {patch_and_position_embedding.shape}')
print(f'Output shape of MSA block: {patched_image_through_msa_block.shape}')

## Equation 3: MultiLayer Perceptron block

```python
#MLP 
x = linear -> non-linear -> dropout -> linear -> dropout
```

In [None]:
class MLPBlock(nn.Module):
    """Layer normalisation followed by a two-layer perceptron.

    The MLP is Linear -> GELU -> Dropout -> Linear -> Dropout and maps the
    input back to its original embedding size. The block does not add a
    residual connection itself; the caller adds it.

    Args:
        embedding_dim: Size of the last dimension of the input and output.
        mlp_size: Width of the hidden layer.
        dropout: Dropout probability applied after each linear layer
            (a float, matching the default of 0.1).
    """

    def __init__(self,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 dropout: float = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size,
                      out_features=embedding_dim),
            nn.Dropout(p=dropout)
        )

    def forward(self, x):
        """Normalise ``x`` and pass it through the MLP.

        Args:
            x: Tensor whose last dimension has size ``embedding_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x

In [None]:
# Instantiate the MLP block and feed it the MSA block's output.
mlp_block = MLPBlock(
    embedding_dim=768,
    mlp_size=3072,
    dropout=0.1,
)
patched_image_through_mlp_block = mlp_block(patched_image_through_msa_block)
print(f'Input shape of MLP block: {patched_image_through_msa_block.shape}')
print(f'Output shape of MLP block: {patched_image_through_mlp_block.shape}')

### Creating Transformer Encoder

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One Transformer encoder layer: an MSA block then an MLP block.

    Each sub-block's output is added back to its own input (residual
    connection).

    Args:
        embedding_dim: Embedding size shared by both sub-blocks.
        num_heads: Number of attention heads in the MSA block.
        mlp_size: Hidden width of the MLP block.
        mlp_dropout: Dropout probability used inside the MLP block.
        attn_dropout: Dropout probability used inside the attention layer.
    """

    def __init__(self,
                 embedding_dim:int=768, 
                 num_heads:int=12, 
                 mlp_size:int=3072, 
                 mlp_dropout:float=0.1, 
                 attn_dropout:float=0): 
        super().__init__()
        self.msa_block = MultiHeadSelfAttentionBlock(
            embedding_dimension=embedding_dim,
            num_heads=num_heads,
            attn_dropout=attn_dropout,
        )
        self.mlp_block = MLPBlock(
            embedding_dim=embedding_dim,
            mlp_size=mlp_size,
            dropout=mlp_dropout,
        )

    def forward(self, x):
        """Apply attention then the MLP, each with a residual connection."""
        after_attention = self.msa_block(x) + x
        return self.mlp_block(after_attention) + after_attention

In [None]:
from torchinfo import summary
# Build one encoder block with default settings and print its layer summary.
transformer_encoder_block = TransformerEncoderBlock()
summary(
    model=transformer_encoder_block,
    input_size=(1, 197, 768),
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

### Putting it all together

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier assembled from the blocks in this file.

    Args:
        img_size: Side length of the (square) input images.
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch.
        num_transformer_layers: Number of stacked TransformerEncoderBlock layers.
        embedding_dim: Embedding size used throughout the model.
        mlp_size: Hidden width of each encoder block's MLP.
        num_heads: Number of attention heads per encoder block.
        attn_dropout: Dropout probability inside the attention layers.
        mlp_dropout: Dropout probability inside the MLP blocks.
        embedding_dropout: Dropout probability applied to the patch +
            position embeddings.
        num_classes: Number of output classes of the classifier head.

    Raises:
        AssertionError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self,
                 img_size: int = 224,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 attn_dropout: float = 0,
                 mlp_dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 num_classes: int = 1000):
        super().__init__()
        # The message must use `patch_size`; the former `patch_width` is not
        # defined here and turned a failed check into a NameError.
        assert img_size % patch_size == 0, (
            f'Image size must be divisible by patch size, '
            f'image: {img_size}, patch size: {patch_size}'
        )
        self.num_patches = (img_size * img_size) // patch_size**2
        # Single learnable class token; expanded across the batch in forward().
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)
        # One position vector per patch plus one for the class token.
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim))
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)
        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              embedding_dim=embedding_dim)
        # attn_dropout is forwarded here; previously it was accepted by this
        # constructor but silently ignored.
        self.transformer_encoder = nn.Sequential(*[
            TransformerEncoderBlock(embedding_dim=embedding_dim,
                                    num_heads=num_heads,
                                    mlp_size=mlp_size,
                                    mlp_dropout=mlp_dropout,
                                    attn_dropout=attn_dropout)
            for _ in range(num_transformer_layers)
        ])
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Batched image tensor laid out as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_classes) produced by the classifier
            head from the class-token position of the encoded sequence.
        """
        # Move the input onto whichever device the model's parameters live on.
        x = x.to(next(self.parameters()).device)
        batch_size = x.shape[0]
        # Expand the single class token to one per batch item (no copy).
        class_token = self.class_embedding.expand(batch_size, -1, -1)
        x = self.patch_embedding(x)
        x = torch.cat((class_token, x), dim=1)
        x = self.position_embedding + x
        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)
        # Index 0 along the sequence dimension is the class token.
        x = self.classifier(x[:, 0])
        return x

In [None]:
# Seed both generators so the random input and the model weights are reproducible.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Random stand-in for a single 3-channel 224x224 image.
random_image_tensor = torch.randn((1, 3, 224, 224))
# One output per dataset class; move the model to the selected device.
vit = ViT(num_classes=len(class_names))
vit = vit.to(device)
vit(random_image_tensor)

In [None]:
# Summarise a freshly built ViT for one 3x224x224 input.
summary(
    model=ViT(num_classes=len(class_names)),
    input_size=(1, 3, 224, 224),  # (batch_size, color_channels, height, width)
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

### Training code

In [None]:
# Adam over all ViT parameters, with weight decay.
optimizer = torch.optim.Adam(
    params=vit.parameters(),
    lr=1e-3,
    weight_decay=0.1,
    betas=(0.9, 0.999),
)
# Cross-entropy loss for multi-class classification.
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
from model_train import train

# Seed both generators before training.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# The validation loader is passed as `test_dataloader`; the separate test
# loader is not used during training.
results = train(
    model=vit,
    train_dataloader=train_dataloader,
    test_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=30,
    device=device,
)

In [None]:
results

In [None]:
from typing import Tuple, Dict, List

def plot_loss_curves(results: Dict[str, List[float]]):
    """Plot loss and accuracy curves side by side.

    Args:
        results: Dictionary with the keys "train_loss", "test_loss",
            "train_acc" and "test_acc", each holding one value per epoch.
    """
    # Read every series up front so a missing key fails before any drawing.
    train_loss, held_out_loss = results["train_loss"], results["test_loss"]
    train_acc, held_out_acc = results["train_acc"], results["test_acc"]
    # One x position per recorded epoch
    epoch_axis = range(len(results["train_loss"]))

    plt.figure(figsize=(15, 7))
    # (panel title, ((series, legend label), ...)) for the left and right subplot
    panels = (
        ("Loss", ((train_loss, "train_loss"), (held_out_loss, "test_loss"))),
        ("Accuracy", ((train_acc, "train_acc"), (held_out_acc, "test_acc"))),
    )
    for position, (panel_title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for series, legend_label in curves:
            plt.plot(epoch_axis, series, label=legend_label)
        plt.title(panel_title)
        plt.xlabel("Epochs")
        plt.legend()

In [None]:
plot_loss_curves(results)

In [None]:
# Evaluate the scratch ViT on the test loader
vit.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = vit(images)
        # torch.max over dim 1 returns (values, indices); the indices are the predictions
        _, predicted = torch.max(outputs.data, 1)
        total += labels.shape[0]
        correct += int((predicted == labels).sum())

# Accuracy as a percentage
test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy:.2f}%')

## Pretrained ViT from `torchvision.models`

In [None]:
import torchvision

In [None]:
from torch import nn
# Load ViT-B/16 with its default pretrained weights and move it to the device.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)
pretrained_vit = pretrained_vit.to(device)
# Freeze every existing parameter.
pretrained_vit.requires_grad_(False)
# Seed before creating the new head so its initial weights are reproducible.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Replace the head with a fresh linear layer sized to this dataset's classes;
# the new layer is created after the freeze, so it stays trainable.
pretrained_vit.heads = nn.Linear(
    in_features=768,
    out_features=len(class_names),
).to(device)

In [None]:
# Summarise the pretrained ViT (with its replaced head) for one 3x224x224 input.
summary(
    model=pretrained_vit,
    input_size=(1, 3, 224, 224),  # (batch_size, color_channels, height, width)
    row_settings=["var_names"],
    col_width=20,
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

### Preparing data for the pretrained ViT model

In [None]:
# Preprocessing object returned by the pretrained weights' transforms() method;
# the trailing expression displays it in the notebook.
pretrained_vit_transforms = pretrained_vit_weights.transforms()
pretrained_vit_transforms

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
# Image folders
# Use the preprocessing that belongs to the pretrained weights
# (pretrained_vit_transforms). These datasets previously reused the plain
# resize + ToTensor `transform`, leaving the transforms prepared for the
# pretrained model unused.
train_data_pretrained = datasets.ImageFolder(root=train_dir,
                                             transform=pretrained_vit_transforms)
test_data_pretrained = datasets.ImageFolder(root=test_dir,
                                            transform=pretrained_vit_transforms)
val_data_pretrained = datasets.ImageFolder(root=val_dir,
                                           transform=pretrained_vit_transforms)
# Data Loaders (only the training loader shuffles)
train_dataloader_pretrained = DataLoader(dataset=train_data_pretrained,
                                         batch_size=BATCH_SIZE,
                                         num_workers=os.cpu_count(),
                                         shuffle=True)
test_dataloader_pretrained = DataLoader(dataset=test_data_pretrained,
                                        batch_size=BATCH_SIZE,
                                        num_workers=os.cpu_count(),
                                        shuffle=False)
val_dataloader_pretrained = DataLoader(dataset=val_data_pretrained,
                                       batch_size=BATCH_SIZE,
                                       num_workers=os.cpu_count(),
                                       shuffle=False)

In [None]:
# Class names as exposed by the pretrained-pipeline training dataset's
# `classes` attribute; the trailing expression displays them.
class_names = train_data_pretrained.classes
class_names

In [None]:
# Adam over the pretrained model's parameters, plus a cross-entropy loss.
optimizer = torch.optim.Adam(
    pretrained_vit.parameters(),
    lr=1e-3,
)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Seed both generators before fine-tuning.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# The validation loader is passed as `test_dataloader`.
pretrained_vit_results = train(
    model=pretrained_vit,
    train_dataloader=train_dataloader_pretrained,
    test_dataloader=val_dataloader_pretrained,
    epochs=10,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
)

In [None]:
from sklearn.metrics import confusion_matrix 
import seaborn as sns

# Predictions and ground-truth labels collected over the whole test set
all_preds = []
all_labels = []
# Start the counters from zero: without this reset they still hold the totals
# from the earlier scratch-ViT test loop, so the accuracy printed below would
# mix the results of both models.
correct = 0
total = 0

pretrained_vit.eval()
with torch.inference_mode():
    for images, labels in test_dataloader_pretrained:
        images, labels = images.to(device), labels.to(device)
        outputs = pretrained_vit(images)
        # torch.max over dim 1 returns (values, indices); the indices are the predictions
        _, predicted = torch.max(outputs, 1)
        # Move to CPU / NumPy so scikit-learn can consume the results
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy:.2f}%')
# Compute the confusion matrix (rows: true labels, columns: predictions)
cm = confusion_matrix(all_labels, all_preds)

# Plot the confusion matrix using seaborn
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()

In [None]:
plot_loss_curves(pretrained_vit_results)

In [None]:
vit.state_dict()

In [None]:
from pathlib import Path 
# Directory for saved models; created (with parents) if it does not exist yet
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)
# File the scratch ViT's weights are written to
MODEL_NAME = "vit.pth"
MODEL_SAVE_PATH = MODEL_PATH.joinpath(MODEL_NAME)
# Persist only the state dict, not the whole module object
print(f"Saving model to : {MODEL_SAVE_PATH}")
torch.save(
    obj=vit.state_dict(),
    f=MODEL_SAVE_PATH,
)

In [None]:
# File the fine-tuned pretrained model's weights are written to
MODEL_NAME = "vit_pretrained.pth"
MODEL_SAVE_PATH = MODEL_PATH/MODEL_NAME
# Save the pretrained model's state dict. This cell previously saved
# `vit.state_dict()` again, so "vit_pretrained.pth" held the scratch model.
print(f"Saving model to : {MODEL_SAVE_PATH}")
torch.save(obj=pretrained_vit.state_dict(),
           f=MODEL_SAVE_PATH)

## Prediction on custom image

In [None]:
# Path of the custom image inside the data root
custom_image_path = data_path / "Tumor1.jpg"
# Pass the path as a string: read_image documents its argument as a str path.
# NOTE(review): a pathlib.Path may not be accepted by every torchvision
# release - confirm against the installed version.
custom_image_uint8 = torchvision.io.read_image(str(custom_image_path))
print(f" Custom image tensor: \n{custom_image_uint8}")
print(f" Custom image shape: {custom_image_uint8.shape}")
print(f" Custom image dtype: {custom_image_uint8.dtype}")

In [None]:
from helperfunctions import pred_and_plot_image
# Transform pipeline for the custom image: a single resize to 224x224.
custom_transform = transforms.Compose(
    [transforms.Resize(size=(224, 224))]
)
# NOTE(review): pred_and_plot_image comes from helperfunctions; how it loads
# and scales the image is not visible here - confirm it matches the
# preprocessing the model was trained with.
pred_and_plot_image(
    model=pretrained_vit,
    image_path=custom_image_path,
    class_names=class_names,
    transform=custom_transform,
)