# Vision Transformer (ViT) for building segmentation on the INRIA Aerial Image Labeling dataset, implemented in PyTorch

### Import libraries

In [None]:
# Download the INRIA aerial image labeling dataset archive with the Kaggle CLI.
# NOTE(review): assumes Kaggle API credentials are already configured — confirm.
! kaggle datasets download -d sagar100rathod/inria-aerial-image-labeling-dataset

In [None]:
# Extract the downloaded archive into ./data (later cells read data/AerialImageDataset/...).
! mkdir data
! unzip inria-aerial-image-labeling-dataset.zip -d data

In [None]:
import einops
from tqdm.notebook import tqdm
import torch
import torchvision
from torch import nn
import torch.optim as optim
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
import torch.utils
import torch.utils.data
from PIL import Image
import os
import matplotlib.pyplot as plt

In [None]:
# Enable the Jupyter widgets extension; presumably needed so the tqdm.notebook progress bars render.
!jupyter nbextension enable --py widgetsnbextension

### Set device and hyperparameters

In [None]:
# Set device to run on GPU if available
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

# Set hyperparameters
patch_size = 16  # Side length (in pixels) of each square image patch
latent_size = 768  # Patch embedding dimension (ViT-Base); equals 16 * 16 * 3, the flattened patch length
n_channels = 3  # Number of channels of the input images (RGB)
num_heads = 12  # Number of attention heads per encoder layer
num_encoders = 12  # Number of stacked Transformer encoder layers
dropout = 0.1
size = 224  # Height/width of the square input images after resizing; must be divisible by patch_size
num_labels = 1  # Number of output channels: a single logit (building / not building)

epochs = 40
# Optimizer settings: these are the values used for AdamW in the training setup cell.
# Keep both places in sync when tuning.
lr = 1e-4  # Learning rate
weight_decay = 1e-4  # Weight decay
batch_size = 4

### Preprocess the dataset

#### For training data

In [None]:
# Compute the per-channel (R, G, B) mean and standard deviation of the training images, used to normalize the dataset
def compute_mean_std(image_dir):
    """Compute the per-channel (R, G, B) mean and standard deviation of a folder of images.

    Pixels are scaled to [0, 1] by ``ToTensor``. The statistics are pooled over every
    pixel of every image, so larger images weigh proportionally more.

    Args:
        image_dir: Directory in which every entry is a readable image file.

    Returns:
        ``(mean, std)``: two 3-element Python lists, ready for ``Normalize``.

    Raises:
        ValueError: If ``image_dir`` contains no files.
    """
    training_images = sorted(os.listdir(image_dir))
    if not training_images:
        raise ValueError(f"No images found in {image_dir!r}")

    # Accumulate in float64. The sums run over a very large number of pixels, and the
    # E[x^2] - E[x]^2 variance formula below is sensitive to float32 round-off.
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    num_pixels = 0
    to_tensor = ToTensor()

    for training_image in tqdm(training_images, desc="Processing images"):
        img_path = os.path.join(image_dir, training_image)
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        img_tensor = to_tensor(image)  # (C, H, W) float tensor with values in [0, 1]
        num_pixels += img_tensor.size(1) * img_tensor.size(2)
        # Sum in float64 without materializing a float64 copy of the whole image.
        channel_sum += img_tensor.sum(dim=[1, 2], dtype=torch.float64)
        channel_sq_sum += (img_tensor ** 2).sum(dim=[1, 2], dtype=torch.float64)

    mean = channel_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative round-off before the square root.
    std = (channel_sq_sum / num_pixels - mean ** 2).clamp_min(0).sqrt()

    return mean.tolist(), std.tolist()

# Compute statistics over the training images. Note this (re)binds the module-level
# `image_dir`, `mean` and `std` names. NOTE(review): presumably the printed values are
# the constants hard-coded in transform_training_data — confirm.
image_dir = "data/AerialImageDataset/train/images"
mean, std = compute_mean_std(image_dir)
print(f"Mean: {mean}")
print(f"Standard deviation: {std}")


In [None]:
# Training-image preprocessing:
#   1. resize to 224x224;
#   2. convert to a float tensor with ToTensor;
#   3. normalize each channel with the training-set mean/std.
# NOTE(review): the constants presumably come from the compute_mean_std output above;
# update them if the training data changes.
transform_training_data = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.4048401415348053, 0.427262544631958, 0.3927135467529297],
              std=[0.20133039355278015, 0.1835126429796219, 0.17614711821079254])
])

In [None]:
# Class for INRIA training dataset
class INRIATrainDataset(torch.utils.data.Dataset):
    """Paired (image, building mask) samples from the INRIA training split.

    Images and masks are matched by file name: every file in ``image_dir`` must have a
    same-named mask in ``mask_dir``.

    Args:
        image_dir: Directory of aerial images (loaded as RGB).
        mask_dir: Directory of ground-truth masks (loaded as single-channel grayscale).
        transform: Optional transform applied to the image.
            - When given, the mask is also resized to 224x224 and converted to a
              (1, 224, 224) float tensor.
            - Otherwise both image and mask are returned as PIL images.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sort so that index i maps to the same file on every run and machine.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        # Context managers close the file handles once the pixels are converted.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as msk:
            mask = msk.convert('L')  # Single-channel grayscale mask

        if self.transform:
            image = self.transform(image)
            # Nearest-neighbour keeps mask values binary. The default bilinear resize
            # would blend background/building pixels at edges into fractional labels.
            mask = Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.NEAREST)(mask)
            mask = F.to_tensor(mask)  # (1, H, W); 8-bit values scaled to [0, 1]

        return image, mask

# Training split locations: images and same-named ground-truth masks.
image_dir = 'data/AerialImageDataset/train/images'
mask_dir = 'data/AerialImageDataset/train/gt'

# Dataset of (image, mask) pairs and a loader that reshuffles it every epoch.
train_data = INRIATrainDataset(image_dir, mask_dir, transform=transform_training_data)
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

### Visualize the training data (image and mask)

In [None]:
# Unnormalize the dataset for visualization
def unnormalized(tensor, mean, std):
    """Undo a torchvision ``Normalize`` in place, computing ``x * std + mean`` per channel.

    Args:
        tensor: Image of shape (C, H, W) or a batch of shape (B, C, H, W); modified in place.
        mean: Per-channel means that were used for normalization (length C).
        std: Per-channel standard deviations that were used for normalization (length C).

    Returns:
        The same tensor object, un-normalized.
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W and, for a 4-D
    # batch, over the leading batch axis too. A per-element zip over a 4-D tensor would
    # pair mean/std with batch items instead of color channels.
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    tensor.mul_(std_t).add_(mean_t)
    return tensor

# Visualize some of the preprocessed training samples
def visualize_dataset(dataset, num_samples=5):
    """Plot the first ``num_samples`` (image, mask) training pairs side by side.

    Images are un-normalized with the training-set statistics hard-coded below, so they
    display with natural colors.

    Args:
        dataset: Indexable dataset returning ``(image, mask)`` tensors. The image must be
            normalized with the training mean/std below.
        num_samples: Number of rows to draw; capped at ``len(dataset)``.
    """
    mean = [0.4048401415348053, 0.427262544631958, 0.3927135467529297]
    std = [0.20133039355278015, 0.1835126429796219, 0.17614711821079254]

    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axis` 2-D even when num_samples == 1.
    _, axis = plt.subplots(num_samples, 2, figsize=(10, num_samples * 5), squeeze=False)
    for i in range(num_samples):
        image, mask = dataset[i]

        # Clone so the in-place un-normalization leaves the dataset's tensor untouched.
        # Clamp round-off so imshow receives valid [0, 1] RGB values.
        image = unnormalized(image.clone(), mean, std).clamp(0, 1)
        image_np = image.permute(1, 2, 0).cpu().numpy()  # (H, W, C) for imshow
        mask_np = mask.squeeze().cpu().numpy()  # (H, W)

        axis[i, 0].imshow(image_np)
        axis[i, 0].set_title(f"Image {i+1}")
        axis[i, 0].axis('off')

        axis[i, 1].imshow(mask_np, cmap='gray')
        axis[i, 1].set_title(f"Mask {i+1}")
        axis[i, 1].axis('off')
    plt.show()

# Plot the first few (image, mask) training pairs.
visualize_dataset(train_data)



#### For test images

In [None]:
# Compute the per-channel (R, G, B) mean and standard deviation of the test images, used to normalize the test set
def compute_mean_std(image_dir):
    """Compute the per-channel (R, G, B) mean and standard deviation of a folder of images.

    Pixels are scaled to [0, 1] by ``ToTensor``. The statistics are pooled over every
    pixel of every image, so larger images weigh proportionally more.

    Args:
        image_dir: Directory in which every entry is a readable image file.

    Returns:
        ``(mean, std)``: two 3-element Python lists, ready for ``Normalize``.

    Raises:
        ValueError: If ``image_dir`` contains no files.
    """
    test_images = sorted(os.listdir(image_dir))
    if not test_images:
        raise ValueError(f"No images found in {image_dir!r}")

    # Accumulate in float64. The sums run over a very large number of pixels, and the
    # E[x^2] - E[x]^2 variance formula below is sensitive to float32 round-off.
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    num_pixels = 0
    to_tensor = ToTensor()

    for test_image in tqdm(test_images, desc="Processing test images"):
        img_path = os.path.join(image_dir, test_image)
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        img_tensor = to_tensor(image)  # (C, H, W) float tensor with values in [0, 1]
        num_pixels += img_tensor.size(1) * img_tensor.size(2)
        # Sum in float64 without materializing a float64 copy of the whole image.
        channel_sum += img_tensor.sum(dim=[1, 2], dtype=torch.float64)
        channel_sq_sum += (img_tensor ** 2).sum(dim=[1, 2], dtype=torch.float64)

    mean = channel_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative round-off before the square root.
    std = (channel_sq_sum / num_pixels - mean ** 2).clamp_min(0).sqrt()

    return mean.tolist(), std.tolist()

# Compute statistics over the test images. Note this (re)binds the module-level
# `image_dir`, `mean` and `std` names. NOTE(review): presumably the printed values are
# the constants hard-coded in test_transforms — confirm.
image_dir = "data/AerialImageDataset/test/images"
mean, std = compute_mean_std(image_dir)
print(f"Mean: {mean}")
print(f"Standard deviation: {std}")


In [None]:
# Test-image preprocessing:
#   1. resize to 224x224;
#   2. convert to a float tensor;
#   3. normalize each channel with the test-set mean/std, which differ from the
#      training statistics used in transform_training_data.
# NOTE(review): the model is trained on training-statistics-normalized inputs;
# confirm that per-split normalization is intended.
test_transforms = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.4543627202510834, 0.47366490960121155, 0.4127490520477295],
              std=[0.20975154638290405, 0.1924573928117752, 0.18913407623767853])
])


In [None]:
# Class for INRIA test dataset
class INRIATestDataset(torch.utils.data.Dataset):
    """Unlabeled images from the INRIA test split.

    Args:
        image_dir: Directory of aerial images (loaded as RGB).
        transform: Optional callable applied to each PIL image (e.g. a torchvision ``Compose``).
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Sort so that index i maps to the same file on every run and machine.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])

        # Close the file handle once the converted in-memory image has been produced.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image

image_dir = 'data/AerialImageDataset/test/images'

# Normalize with test_transforms. The test-image visualization and prediction code
# un-normalizes these images with the test-set mean/std, so the forward transform must
# use the same statistics; transform_training_data would make displayed colors wrong.
# NOTE(review): if inference should instead match the training normalization, switch
# the display code to the training statistics rather than reverting this line.
test_data = INRIATestDataset(image_dir, transform=test_transforms)


### Visualize test images

In [None]:
# Unnormalize the dataset for visualization
def unnormalized(tensor, mean, std):
    """Undo a torchvision ``Normalize`` in place, computing ``x * std + mean`` per channel.

    Args:
        tensor: Image of shape (C, H, W) or a batch of shape (B, C, H, W); modified in place.
        mean: Per-channel means that were used for normalization (length C).
        std: Per-channel standard deviations that were used for normalization (length C).

    Returns:
        The same tensor object, un-normalized.
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W and, for a 4-D
    # batch, over the leading batch axis too. A per-element zip over a 4-D tensor would
    # pair mean/std with batch items instead of color channels.
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    tensor.mul_(std_t).add_(mean_t)
    return tensor

# Visualize some of the preprocessed test images
def visualize_dataset(dataset, num_samples=5):
    """Plot the first ``num_samples`` images of an unlabeled test dataset, one per row.

    Images are un-normalized with the test-set statistics hard-coded below, so the dataset
    must have been normalized with those same statistics (i.e. ``test_transforms``).

    Args:
        dataset: Indexable dataset returning a normalized (C, H, W) image tensor per item.
        num_samples: Number of rows to draw; capped at ``len(dataset)``.
    """
    mean = [0.4543627202510834, 0.47366490960121155, 0.4127490520477295]
    std = [0.20975154638290405, 0.1924573928117752, 0.18913407623767853]

    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axis` 2-D (num_samples x 1) even when num_samples == 1.
    _, axis = plt.subplots(num_samples, 1, figsize=(10, num_samples * 5), squeeze=False)
    for i in range(num_samples):
        image = dataset[i]

        # Clone so the in-place un-normalization leaves the dataset's tensor untouched.
        # Clamp round-off so imshow receives valid [0, 1] RGB values.
        image = unnormalized(image.clone(), mean, std).clamp(0, 1)
        image_np = image.permute(1, 2, 0).cpu().numpy()  # (H, W, C) for imshow

        axis[i, 0].imshow(image_np)
        axis[i, 0].set_title(f"Test Image {i+1}")
        axis[i, 0].axis('off')

    plt.show()

# Plot the first few test images.
visualize_dataset(test_data)

## Building Vision Transformer model

### Input Embedding class. This class performs all the steps needed before the data enters the ViT encoder: splitting the input images into patches, linearly projecting each flattened patch into a vector, and adding a learned position embedding that tells the model where in the image each patch comes from. Its output is fed into the Transformer encoder stack.

In [None]:
# Create a subclass of Module (which is the base class for neural network modules)
class InputEmbedding(nn.Module):
    """Patchify images, linearly project each patch, and add learned position embeddings.

    The output has shape ``(batch, num_patches, latent_size)``, where
    ``num_patches = (image_size // patch_size) ** 2``.

    Args:
        patch_size: Side length of each square patch, in pixels.
        n_channels: Number of input image channels.
        device: Stored for backward compatibility. Inputs are moved to the device of this
            module's parameters, so ``model.to(...)`` controls placement.
        latent_size: Embedding dimension of each projected patch.
        batch_size: Stored for backward compatibility; not used in the computation.
        image_size: Height/width of the square input images; must be divisible by
            ``patch_size``.
    """

    def __init__(self, patch_size=patch_size, n_channels=n_channels, device=device, latent_size=latent_size, batch_size=batch_size, image_size=size):
        super().__init__()
        self.latent_size = latent_size
        self.laten_size = latent_size  # Original (misspelled) attribute name, kept for compatibility
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.device = device
        self.batch_size = batch_size
        self.input_size = self.patch_size * self.patch_size * self.n_channels

        self.linearProjection = nn.Linear(self.input_size, self.latent_size)

        # Register the position embedding as a real parameter. Calling `.to(device)` on an
        # nn.Parameter returns a plain non-leaf tensor whenever the device changes. That
        # tensor is not registered, so on GPU it would be silently missing from
        # model.parameters() and state_dict() and never trained. `model.to(device)` moves
        # it together with every other parameter.
        self.num_patches = (image_size // self.patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, self.latent_size))

    def forward(self, input_data):
        """Embed a batch of images.

        Args:
            input_data: Tensor of shape (batch, n_channels, image_size, image_size).

        Returns:
            Tensor of shape (batch, num_patches, latent_size).
        """
        input_data = input_data.to(self.linearProjection.weight.device)
        # (b, c, H, W) -> (b, num_patches, patch_size * patch_size * c), patches in row-major order.
        patches = einops.rearrange(
            input_data, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
            p1=self.patch_size, p2=self.patch_size,
        )
        # Out-of-place add; the (1, num_patches, latent) embedding broadcasts over the batch.
        return self.linearProjection(patches) + self.pos_embedding

### Transformer Encoder class. Each encoder layer has two sublayers: Multi-Head Self-Attention and a Multi-Layer Perceptron. Layer Normalization is applied to the input of each sublayer (pre-norm), and a residual connection adds each sublayer's input to its output. The model stacks 12 of these encoder layers (`num_encoders`).



In [None]:
class TransformerEncoder(nn.Module):
    """One pre-norm Transformer encoder layer, as used in ViT.

    Computes ``x = x + MHA(LN(x))`` followed by ``x = x + MLP(LN(x))``.

    Args:
        latent_size: Token embedding dimension.
        num_heads: Number of attention heads; must divide ``latent_size``.
        device: Stored for backward compatibility; not used in the computation.
        dropout: Dropout probability inside attention and the MLP.
    """

    def __init__(self, latent_size=latent_size, num_heads=num_heads, device=device, dropout=dropout):
        super().__init__()
        self.latent_size = latent_size
        self.num_heads = num_heads
        self.device = device
        self.dropout = dropout

        # Normalization layers (one before each sublayer)
        self.norm1 = nn.LayerNorm(self.latent_size)
        self.norm2 = nn.LayerNorm(self.latent_size)

        # batch_first=True because inputs are (batch, num_patches, latent_size). With the
        # default batch_first=False, PyTorch reads dim 0 as the sequence axis. Attention
        # would then mix the same patch position across different images in the batch
        # instead of attending over patches within one image.
        self.multihead = nn.MultiheadAttention(
            self.latent_size, self.num_heads, dropout=self.dropout, batch_first=True
        )

        # Position-wise MLP; ViT-Base uses a hidden size of 4 * latent_size (3072 for 768).
        self.enc_MLP = nn.Sequential(
            nn.Linear(self.latent_size, self.latent_size * 4),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_size * 4, self.latent_size),
            nn.Dropout(self.dropout)
        )

    def forward(self, embedded_patches):
        """Apply the layer to a (batch, num_patches, latent_size) tensor; the shape is preserved."""
        # Sublayer 1: LayerNorm -> self-attention -> residual.
        normed = self.norm1(embedded_patches)
        # need_weights=False: the attention map is discarded, so skip computing it.
        attention_output, _ = self.multihead(normed, normed, normed, need_weights=False)
        first_added_output = attention_output + embedded_patches

        # Sublayer 2: LayerNorm -> MLP -> residual.
        ff_output = self.enc_MLP(self.norm2(first_added_output))
        return ff_output + first_added_output


### Put together the whole Vision Transformer. On top of the input embedding layer and the encoder stack, this adds a convolutional segmentation head that maps the patch features to per-pixel logits and upsamples them back to the input resolution.

In [None]:
class VisionTransformer(nn.Module):
    """ViT encoder with a convolutional head that outputs per-pixel segmentation logits.

    Pipeline:
      1. InputEmbedding;
      2. ``num_encoders`` TransformerEncoder layers;
      3. fold the patch tokens into a (latent_size, H/patch, W/patch) feature map;
      4. two convolutions;
      5. bilinear upsampling by the patch size back to the input resolution.

    Args:
        num_encoders: Number of stacked encoder layers.
        latent_size: Token embedding dimension, forwarded to the embedding and encoders.
        device: Forwarded to the sub-modules.
        num_labels: Number of output logit channels.
        dropout: Dropout probability, forwarded to the encoders.
    """

    def __init__(self, num_encoders=num_encoders, latent_size=latent_size, device=device, num_labels=num_labels, dropout=dropout):
        super().__init__()
        self.num_encoders = num_encoders
        self.latent_size = latent_size
        self.device = device
        self.num_labels = num_labels
        self.dropout = dropout

        # Pass the constructor arguments through so they actually take effect in the
        # sub-modules; otherwise those would use the module-level defaults.
        self.embedding = InputEmbedding(device=device, latent_size=latent_size)

        # Create a stack of encoder layers
        self.enc_stack = nn.ModuleList([
            TransformerEncoder(latent_size=latent_size, device=device, dropout=dropout)
            for _ in range(self.num_encoders)
        ])

        # Segmentation head: patch features -> per-patch logits -> upsample by the patch
        # size so the output matches the input resolution.
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(latent_size, latent_size // 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(latent_size // 2, num_labels, kernel_size=1),
            nn.Upsample(scale_factor=self.embedding.patch_size, mode='bilinear', align_corners=False)
        )

    def forward(self, test_input):
        """Return logits of shape (batch, num_labels, H, W) for images of shape (batch, C, H, W)."""
        # Patchify + linear projection + positional embedding
        enc_output = self.embedding(test_input)

        for enc_layer in self.enc_stack:
            enc_output = enc_layer(enc_output)

        # Fold tokens (ordered row-major by InputEmbedding) back into a spatial grid:
        # (b, grid_h * grid_w, latent) -> (b, latent, grid_h, grid_w).
        batch_size, _, latent_size = enc_output.shape
        grid_h = test_input.shape[-2] // self.embedding.patch_size
        grid_w = test_input.shape[-1] // self.embedding.patch_size
        enc_output = enc_output.transpose(1, 2).reshape(batch_size, latent_size, grid_h, grid_w)

        return self.segmentation_head(enc_output)


### IoU metric

In [None]:
def iou_score(preds, targets, threshold=0.5):
    """Intersection-over-Union between thresholded predictions and binary targets.

    The IoU is pooled over every element of the batch: one ratio for the whole batch,
    not an average of per-image IoUs.

    Args:
        preds: Raw logits of any shape; a sigmoid is applied here.
        targets: Ground-truth mask with 0/1 values and the same number of elements as ``preds``.
        threshold: Probability above which a pixel counts as positive.

    Returns:
        0-dim tensor with the IoU. When prediction and target are both empty the 1e-6
        smoothing terms cancel and the result is 1.0.
    """
    # A metric never needs gradients; avoid building an autograd graph for the sigmoid.
    with torch.no_grad():
        probs = torch.sigmoid(preds)  # Logits -> probabilities
        pred_mask = (probs > threshold).float()  # Then binarize at the threshold
        # reshape (unlike view) also handles non-contiguous inputs such as transposed tensors.
        preds_flat = pred_mask.reshape(-1)
        targets_flat = targets.reshape(-1)
        intersection = (preds_flat * targets_flat).sum()
        union = preds_flat.sum() + targets_flat.sum() - intersection
        return (intersection + 1e-6) / (union + 1e-6)


### Instantiate the model, loss function, optimizer and learning-rate scheduler

In [None]:
model = VisionTransformer().to(device)
# The model outputs raw logits; BCEWithLogitsLoss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
iou_metric = iou_score
# NOTE: literal values; keep them in sync with the `lr` / `weight_decay` hyperparameters.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# Multiply the LR by 0.1 once the monitored loss has not improved for more than 5 epochs.
# The deprecated `verbose` flag is omitted; read optimizer.param_groups[0]['lr'] to see
# the current LR.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)


In [None]:
import matplotlib.pyplot as plt

# Mean training IoU per epoch, plotted after training.
iou_scores = []

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    running_iou = 0.0
    for inputs, masks in trainloader:
        inputs, masks = inputs.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, masks)

        # Backpropagation
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # The metric needs no gradients; detach so no autograd graph is built for it.
        # (Measured in train mode, i.e. with dropout active.)
        running_iou += iou_metric(outputs.detach(), masks).item()

    epoch_loss = running_loss / len(trainloader)
    epoch_iou = running_iou / len(trainloader)
    iou_scores.append(epoch_iou)

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, IoU: {epoch_iou:.4f}, LR: {current_lr:.2e}")

    # Let the plateau scheduler lower the LR based on the epoch's training loss.
    scheduler.step(epoch_loss)

# Save model weights after training
torch.save(model.state_dict(), 'vit_segmentation_model.pth')

# Plot IoU over the epochs that were actually recorded
plt.figure(figsize=(8, 6))
plt.plot(range(1, len(iou_scores) + 1), iou_scores, marker='o', label='IoU')
plt.title('IoU Score Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('IoU Score')
plt.grid(True)
plt.legend()
plt.show()


## Predict on test set images

In [None]:
# Function to make predictions on the test set
def predict_and_visualize(model, testloader, num_samples=5):
    """Plot test images next to the model's predicted building-probability maps.

    Runs ``model`` on the first ``num_samples`` batches of ``testloader``. For each batch
    it shows the first image together with its sigmoid probability map.

    Args:
        model: Segmentation model returning logits of shape (B, num_labels, H, W); the
            first label channel is displayed.
        testloader: Iterable of image-only batches of shape (B, C, H, W), normalized
            with the test-set mean/std hard-coded below.
        num_samples: Maximum number of rows to draw.
    """
    mean = [0.4543627202510834, 0.47366490960121155, 0.4127490520477295]
    std = [0.20975154638290405, 0.1924573928117752, 0.18913407623767853]
    model.eval()  # Disable dropout for inference
    if num_samples <= 0:
        return
    # Run on whatever device the model lives on.
    model_device = next(model.parameters()).device

    # squeeze=False keeps `axis` 2-D even when num_samples == 1.
    _, axis = plt.subplots(num_samples, 2, figsize=(10, num_samples * 5), squeeze=False)
    with torch.no_grad():
        for i, image in enumerate(testloader):
            if i == num_samples:
                break

            image = image.to(model_device)
            # Logits -> probabilities in [0, 1]; keep the first image / first label as (H, W).
            pred_mask = torch.sigmoid(model(image))[0, 0].cpu().numpy()

            # Un-normalize a single (C, H, W) image so mean/std pair with color channels
            # rather than with the batch axis. Clamp round-off for imshow.
            image_chw = unnormalized(image[0].clone(), mean, std).clamp(0, 1)
            image_np = image_chw.permute(1, 2, 0).cpu().numpy()  # (H, W, C) for imshow

            axis[i, 0].imshow(image_np)
            axis[i, 0].set_title(f"Test Image {i+1}")
            axis[i, 0].axis('off')

            axis[i, 1].imshow(pred_mask, cmap='gray')
            axis[i, 1].set_title(f"Predicted Mask {i+1}")
            axis[i, 1].axis('off')

    plt.show()

# Wrap the test images in a loader with one image per batch, in a fixed order,
# then plot the model's predictions for the first few of them.
testloader = DataLoader(test_data, batch_size=1, shuffle=False)
predict_and_visualize(model, testloader)