# Building segmentation on the INRIA Aerial Image Labeling dataset with a Vision Transformer (fine-tuning pre-trained SegFormer weights in PyTorch)

### Download the dataset and import libraries

In [None]:
# Download the INRIA aerial image labeling dataset archive from Kaggle into the working directory.
# NOTE(review): the kaggle CLI needs API credentials (e.g. ~/.kaggle/kaggle.json) — confirm they are configured.
! kaggle datasets download -d sagar100rathod/inria-aerial-image-labeling-dataset

In [None]:
# Create the extraction folder; -p makes this a no-op when the folder already exists (cell re-run)
! mkdir -p data
# -o overwrites existing files without prompting, so re-running the cell does not block on an interactive prompt
! unzip -o inria-aerial-image-labeling-dataset.zip -d data

In [None]:
import einops
from tqdm.notebook import tqdm
import torch
import torchvision
from torch import nn
import torch.optim as optim
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
import torch.utils
import torch.utils.data
from PIL import Image
import os
import matplotlib.pyplot as plt

In [None]:
# Enable the ipywidgets notebook extension (tqdm.notebook progress bars render as widgets).
# NOTE(review): `jupyter nbextension` is a classic-Notebook command and may not exist on Notebook 7 / JupyterLab — confirm.
!jupyter nbextension enable --py widgetsnbextension

### Set device and hyperparameters

In [None]:
# Pick the first CUDA GPU when one is present, otherwise run on the CPU
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

# --- Model / input hyperparameters ---
# NOTE(review): patch_size, latent_size, n_channels, num_heads, num_encoders, dropout and size
# are not referenced by the SegFormer cells in this notebook; they appear to be left over
# from a from-scratch ViT setup.
patch_size = 16
latent_size = 768     # embedding dim: 16 * 16 pixels per patch * 3 colour channels
n_channels = 3        # RGB input images
num_heads = 12        # attention heads
num_encoders = 12     # encoder layers
dropout = 0.1
size = 224            # input image side length (224 x 224)
num_labels = 1        # one output logit: building vs. not building

# --- Optimisation hyperparameters ---
epochs = 40
lr = 1e-3             # learning rate
weight_decay = 0.03   # weight decay
batch_size = 4

### Preprocess the dataset

#### For training data

In [None]:
# Calculate mean and standard deviation of the training images across all channels (R, G, B) for normalizing the dataset
def compute_mean_std(image_dir):
    """Compute the per-channel (R, G, B) mean and standard deviation of a folder of images.

    Every entry in ``image_dir`` is opened as an image, converted to RGB and
    scaled to [0, 1] with ``ToTensor``, so the directory should contain only
    images. Statistics are pooled over all pixels of all images (larger
    images weigh more), with std = sqrt(E[x^2] - E[x]^2).

    Args:
        image_dir: directory of image files.

    Returns:
        ``(mean, std)`` — two lists of three Python floats.

    Raises:
        ValueError: if the directory contains no images.
    """
    training_images = os.listdir(image_dir)

    # Cross-image running sums are kept in float64: the total pixel count can be
    # very large for high-resolution imagery, and E[x^2] - E[x]^2 is sensitive
    # to rounding error in these sums.
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    num_pixels = 0

    for training_image in tqdm(training_images, desc="Processing images"):
        img_path = os.path.join(image_dir, training_image)
        # Context manager closes the file handle instead of waiting for garbage collection
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        img_tensor = ToTensor()(image)  # (C, H, W) float tensor in [0, 1]
        num_pixels += img_tensor.size(1) * img_tensor.size(2)
        # Per-image reductions stay in float32 (no full-size float64 copy); only the 3-element results are promoted
        channel_sum += img_tensor.sum(dim=[1, 2]).double()
        channel_sq_sum += (img_tensor ** 2).sum(dim=[1, 2]).double()

    if num_pixels == 0:
        raise ValueError(f"No images found in {image_dir!r}")

    mean = channel_sum / num_pixels
    # clamp_min(0) guards against a tiny negative variance produced by rounding
    std = (channel_sq_sum / num_pixels - mean ** 2).clamp_min(0).sqrt()

    return mean.tolist(), std.tolist()

image_dir = "data/AerialImageDataset/train/images"
mean, std = compute_mean_std(image_dir)
print(f"Mean: {mean}")
print(f"Standard deviation: {std}")


In [None]:
# Per-channel (R, G, B) statistics of the training images — presumably copied from the
# compute_mean_std printout above; re-check if the training set changes.
train_mean_rgb = [0.4048401415348053, 0.427262544631958, 0.3927135467529297]
train_std_rgb = [0.20133039355278015, 0.1835126429796219, 0.17614711821079254]

# Resize to 224x224, convert to a [0, 1] float tensor, then standardise each channel
transform_training_data = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=train_mean_rgb, std=train_std_rgb),
    ]
)

In [None]:
# Class for INRIA training dataset
class INRIATrainDataset(torch.utils.data.Dataset):
    """Paired (image, building mask) samples from the INRIA training split.

    Images and masks are matched by filename: each file in ``image_dir`` must
    have a mask with the same name in ``mask_dir``.

    Args:
        image_dir: directory of aerial images (loaded as RGB).
        mask_dir: directory of ground-truth masks (loaded as grayscale "L").
        transform: optional callable applied to the PIL image. When given, the
            mask is also resized to ``mask_size`` and converted to a float
            tensor in [0, 1]; when ``None`` both image and mask are returned as
            PIL images.
        mask_size: (height, width) the mask is resized to. The training loop
            compares masks directly with the model logits, so this must equal
            the logits' spatial size — the default (56, 56) is 1/4 of the
            224x224 input.

    Note:
        ``Resize`` uses torchvision's default (bilinear) interpolation, so mask
        values along building edges become fractional rather than strictly 0/1.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_size=(56, 56)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_size = mask_size
        # Sorted so the index -> file mapping is reproducible (os.listdir order is arbitrary)
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        # Context managers close the file handles; convert() returns an independent copy
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')  # single-channel grayscale mask

        if self.transform:
            image = self.transform(image)
            # Downsample the mask to the logits' resolution, then to a (1, H, W) tensor in [0, 1]
            mask = Resize(self.mask_size)(mask)
            mask = F.to_tensor(mask)

        return image, mask

image_dir = 'data/AerialImageDataset/train/images'
mask_dir = 'data/AerialImageDataset/train/gt'

train_data = INRIATrainDataset(image_dir, mask_dir, transform=transform_training_data)
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

### Visualize the training data (image and mask)

In [None]:
# Unnormalize the dataset for visualization
def unnormalized(tensor, mean, std):
    """Undo ``Normalize(mean, std)`` in place: ``x * std + mean`` per channel.

    Works for a single image ``(C, H, W)`` as well as a batch ``(N, C, H, W)``:
    the statistics are broadcast over the channel dimension (third from last),
    so every channel gets its own mean/std even when a batch dimension is
    present.

    Args:
        tensor: normalized image tensor; modified in place (pass a clone to
            keep the original).
        mean: per-channel means, one per channel.
        std: per-channel standard deviations, one per channel.

    Returns:
        The same tensor object, now unnormalized.
    """
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor.mul_(std_t).add_(mean_t)

# Visualize some of the preprocessed training samples
def visualize_dataset(dataset, num_samples=5,
                      mean=(0.4048401415348053, 0.427262544631958, 0.3927135467529297),
                      std=(0.20133039355278015, 0.1835126429796219, 0.17614711821079254)):
    """Show the first ``num_samples`` (image, mask) pairs of ``dataset`` side by side.

    Args:
        dataset: indexable dataset yielding ``(image, mask)`` tensors, where
            image is a normalized ``(C, H, W)`` tensor.
        num_samples: how many samples to plot; capped at ``len(dataset)``.
        mean: per-channel means used to undo normalization (defaults are the
            same values used by transform_training_data).
        std: per-channel standard deviations used to undo normalization.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axis` 2-D even when num_samples == 1
    fig, axis = plt.subplots(num_samples, 2, figsize=(10, num_samples * 5), squeeze=False)
    for i in range(num_samples):
        image, mask = dataset[i]

        # Unnormalize a clone so the dataset's tensor is not modified
        image = unnormalized(image.clone(), mean, std)
        # Clamp to [0, 1]: unnormalizing can overshoot slightly and imshow would clip with a warning
        image_np = image.clamp(0, 1).permute(1, 2, 0).cpu().numpy()  # (H, W, C)
        mask_np = mask.squeeze().cpu().numpy()

        axis[i, 0].imshow(image_np)
        axis[i, 0].set_title(f"Image {i+1}")
        axis[i, 0].axis('off')

        axis[i, 1].imshow(mask_np, cmap='gray')
        axis[i, 1].set_title(f"Mask {i+1}")
        axis[i, 1].axis('off')
    plt.show()

visualize_dataset(train_data)



#### For test images

In [None]:
# Calculate mean and standard deviation of the test images across all channels (R, G, B) for normalizing the dataset
def compute_mean_std(image_dir):
    """Compute the per-channel (R, G, B) mean and standard deviation of a folder of images.

    Every entry in ``image_dir`` is opened as an image, converted to RGB and
    scaled to [0, 1] with ``ToTensor``, so the directory should contain only
    images. Statistics are pooled over all pixels of all images (larger
    images weigh more), with std = sqrt(E[x^2] - E[x]^2).

    Args:
        image_dir: directory of image files.

    Returns:
        ``(mean, std)`` — two lists of three Python floats.

    Raises:
        ValueError: if the directory contains no images.
    """
    test_images = os.listdir(image_dir)

    # Cross-image running sums are kept in float64: the total pixel count can be
    # very large for high-resolution imagery, and E[x^2] - E[x]^2 is sensitive
    # to rounding error in these sums.
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    num_pixels = 0

    for test_image in tqdm(test_images, desc="Processing test images"):
        img_path = os.path.join(image_dir, test_image)
        # Context manager closes the file handle instead of waiting for garbage collection
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        img_tensor = ToTensor()(image)  # (C, H, W) float tensor in [0, 1]
        num_pixels += img_tensor.size(1) * img_tensor.size(2)
        # Per-image reductions stay in float32 (no full-size float64 copy); only the 3-element results are promoted
        channel_sum += img_tensor.sum(dim=[1, 2]).double()
        channel_sq_sum += (img_tensor ** 2).sum(dim=[1, 2]).double()

    if num_pixels == 0:
        raise ValueError(f"No images found in {image_dir!r}")

    mean = channel_sum / num_pixels
    # clamp_min(0) guards against a tiny negative variance produced by rounding
    std = (channel_sq_sum / num_pixels - mean ** 2).clamp_min(0).sqrt()

    return mean.tolist(), std.tolist()

image_dir = "data/AerialImageDataset/test/images"
mean, std = compute_mean_std(image_dir)
print(f"Mean: {mean}")
print(f"Standard deviation: {std}")


In [None]:
# Per-channel (R, G, B) statistics of the test images — presumably copied from the
# test-set compute_mean_std printout above; re-check if the test set changes.
test_mean_rgb = [0.4543627202510834, 0.47366490960121155, 0.4127490520477295]
test_std_rgb = [0.20975154638290405, 0.1924573928117752, 0.18913407623767853]

# Test-image pipeline: resize to 224x224, convert to a [0, 1] tensor, standardise per channel
test_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=test_mean_rgb, std=test_std_rgb),
    ]
)


In [None]:
# Class for INRIA test dataset
class INRIATestDataset(torch.utils.data.Dataset):
    """Images from the INRIA test split; no masks are loaded.

    Args:
        image_dir: directory of aerial images (loaded as RGB).
        transform: optional callable applied to each PIL image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Sorted so dataset[i] refers to the same file on every run (os.listdir order is arbitrary)
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])

        # Context manager closes the file handle; convert() returns an independent copy
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image

image_dir = 'data/AerialImageDataset/test/images'

# Use the test-set transform: the test visualisation and prediction cells unnormalise
# with the test-set statistics, so the forward normalisation must use the same values.
# NOTE(review): the model was trained on inputs normalised with the training statistics;
# if matching the training input distribution matters more than display fidelity, use
# transform_training_data here and unnormalise with the training statistics instead.
test_data = INRIATestDataset(image_dir, transform=test_transforms)


### Visualize test images

In [None]:
# Unnormalize the dataset for visualization
def unnormalized(tensor, mean, std):
    """Undo ``Normalize(mean, std)`` in place: ``x * std + mean`` per channel.

    Works for a single image ``(C, H, W)`` as well as a batch ``(N, C, H, W)``:
    the statistics are broadcast over the channel dimension (third from last),
    so every channel gets its own mean/std even when a batch dimension is
    present.

    Args:
        tensor: normalized image tensor; modified in place (pass a clone to
            keep the original).
        mean: per-channel means, one per channel.
        std: per-channel standard deviations, one per channel.

    Returns:
        The same tensor object, now unnormalized.
    """
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor.mul_(std_t).add_(mean_t)

# Visualize some of the preprocessed test images
def visualize_dataset(dataset, num_samples=5,
                      mean=(0.4543627202510834, 0.47366490960121155, 0.4127490520477295),
                      std=(0.20975154638290405, 0.1924573928117752, 0.18913407623767853)):
    """Show the first ``num_samples`` images of an image-only ``dataset``, one per row.

    Args:
        dataset: indexable dataset yielding normalized ``(C, H, W)`` image tensors.
        num_samples: how many images to plot; capped at ``len(dataset)``.
        mean: per-channel means used to undo normalization (defaults are the
            same values used by test_transforms).
        std: per-channel standard deviations used to undo normalization.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axis` 2-D even when num_samples == 1
    fig, axis = plt.subplots(num_samples, 1, figsize=(10, num_samples * 5), squeeze=False)
    for i in range(num_samples):
        image = dataset[i]

        # Unnormalize a clone so the dataset's tensor is not modified
        image = unnormalized(image.clone(), mean, std)
        # Clamp to [0, 1]: unnormalizing can overshoot slightly and imshow would clip with a warning
        image_np = image.clamp(0, 1).permute(1, 2, 0).cpu().numpy()  # (H, W, C)

        # Test images have no ground-truth mask, so only the image is shown
        axis[i, 0].imshow(image_np)
        axis[i, 0].set_title(f"Test Image {i+1}")
        axis[i, 0].axis('off')

    plt.show()

visualize_dataset(test_data)

## Load the pre-trained SegFormer weights

In [None]:
from transformers import SegformerForSemanticSegmentation, SegformerConfig

# Hugging Face Hub checkpoint to start from; the segmentation head is sized for `num_labels` outputs.
# NOTE(review): "nvidia/mit-b0" looks like an encoder-only checkpoint, so the decode head is
# presumably freshly initialised (transformers logs a warning) — confirm.
pretrained_checkpoint = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(pretrained_checkpoint, num_labels=num_labels)
model = model.to(device)

In [None]:
# Binary segmentation with a single output channel: nn.BCEWithLogitsLoss applies the
# sigmoid internally, so the training loop passes the model's raw logits to it directly.
criterion = nn.BCEWithLogitsLoss()

# AdamW optimizer (Adam with decoupled weight decay)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

def iou_score(preds, targets, threshold=0.5, target_threshold=None):
    """Intersection-over-Union between thresholded predictions and targets.

    Args:
        preds: raw logits of any shape; ``sigmoid(preds) > threshold`` marks
            positive pixels.
        targets: ground-truth mask with the same number of elements as
            ``preds``. Used as-is (soft values give a soft IoU) unless
            ``target_threshold`` is given.
        threshold: probability cut-off for a positive prediction.
        target_threshold: if not ``None``, targets are binarised with
            ``targets > target_threshold`` before scoring.

    Returns:
        0-dim tensor with the IoU over all elements. A 1e-6 smoothing term
        makes an empty prediction against an empty target score 1.0.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work
    preds_flat = (torch.sigmoid(preds) > threshold).reshape(-1).float()
    targets_flat = targets.reshape(-1).float()
    if target_threshold is not None:
        targets_flat = (targets_flat > target_threshold).float()

    intersection = (preds_flat * targets_flat).sum()
    union = preds_flat.sum() + targets_flat.sum() - intersection

    return (intersection + 1e-6) / (union + 1e-6)



In [None]:
# Fine-tune SegFormer
def train_model(model, dataloader, criterion, optimizer, epochs=10):
    """Fine-tune ``model`` for ``epochs`` passes over ``dataloader`` and plot the curves.

    Each batch is moved to the global ``device``. The model output's
    ``.logits`` are compared directly with the masks by ``criterion``, so the
    masks must already have the logits' shape. IoU is tracked with
    ``iou_score`` as the mean of per-batch scores.

    Args:
        model: model whose forward output has a ``.logits`` attribute.
        dataloader: yields ``(images, masks)`` batches.
        criterion: loss taking ``(logits, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``dataloader``.

    Returns:
        ``(loss_history, iou_history)``: lists with the mean loss and mean
        batch IoU of each epoch.

    Raises:
        ValueError: if ``dataloader`` has no batches.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    model.train()

    # Per-epoch mean loss and mean batch IoU
    loss_history = []
    iou_history = []

    for epoch in range(epochs):
        running_loss = 0.0
        running_iou = 0.0

        for images, masks in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images).logits
            loss = criterion(outputs, masks)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Metric only: detach so no autograd graph is built for the IoU computation
            running_iou += iou_score(outputs.detach(), masks).item()

        epoch_loss = running_loss / len(dataloader)
        epoch_iou = running_iou / len(dataloader)

        loss_history.append(epoch_loss)
        iou_history.append(epoch_iou)

        print(f"Epoch {epoch+1} Loss: {epoch_loss}, IoU: {epoch_iou}")

    # Plot training loss and IoU score over epochs side by side
    epoch_axis = range(1, epochs + 1)
    fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(14, 5))

    loss_ax.plot(epoch_axis, loss_history, label="Training Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training Loss Over Epochs")
    loss_ax.legend()

    iou_ax.plot(epoch_axis, iou_history, label="IoU Score")
    iou_ax.set_xlabel("Epoch")
    iou_ax.set_ylabel("IoU Score")
    iou_ax.set_title("IoU Score Over Epochs")
    iou_ax.legend()

    plt.tight_layout()
    plt.show()

    return loss_history, iou_history

# Train the model (epoch count comes from the hyperparameter cell)
loss_history, iou_history = train_model(model, trainloader, criterion, optimizer, epochs=epochs)

## Predict on test set images

In [None]:
# Function to make predictions on the test set
def predict_and_visualize(model, testloader, num_samples=5,
                          mean=(0.4543627202510834, 0.47366490960121155, 0.4127490520477295),
                          std=(0.20975154638290405, 0.1924573928117752, 0.18913407623767853),
                          threshold=0.5):
    """Plot test images next to the model's thresholded building masks.

    Works with any ``testloader`` batch size: every image in a batch is
    plotted until ``num_samples`` rows are filled.

    Args:
        model: model whose forward output has a ``.logits`` attribute.
        testloader: yields batches of normalized ``(N, C, H, W)`` images.
        num_samples: number of image/mask rows to plot.
        mean: per-channel means used to undo normalization (defaults are the
            same values used by test_transforms).
        std: per-channel standard deviations used to undo normalization.
        threshold: probability cut-off for marking a pixel as building.
    """
    if num_samples <= 0:
        return
    model.eval()  # Set the model to evaluation mode

    # squeeze=False keeps `axis` 2-D even when num_samples == 1
    fig, axis = plt.subplots(num_samples, 2, figsize=(10, num_samples * 5), squeeze=False)
    shown = 0
    with torch.no_grad():
        for images in testloader:
            if shown == num_samples:
                break

            images = images.to(device)
            logits = model(images).logits  # tensor inside the SemanticSegmenterOutput
            # The logits are coarser than the input (training masks are 56x56 for 224x224 images);
            # upsample them so the mask lines up pixel-for-pixel with the displayed image
            logits = nn.functional.interpolate(
                logits, size=images.shape[-2:], mode="bilinear", align_corners=False
            )
            # Sigmoid + threshold turns per-pixel probabilities into a binary building mask
            pred_masks = (torch.sigmoid(logits) > threshold).float()

            for image, pred_mask in zip(images, pred_masks):
                if shown == num_samples:
                    break

                # Unnormalize one (C, H, W) image at a time so each channel uses its own mean/std;
                # clamp because unnormalizing can overshoot [0, 1] slightly
                image_np = (
                    unnormalized(image.clone(), mean, std)
                    .clamp(0, 1)
                    .permute(1, 2, 0)
                    .cpu()
                    .numpy()
                )
                mask_np = pred_mask.squeeze(0).cpu().numpy()  # (1, H, W) -> (H, W)

                # Display the image and predicted mask
                axis[shown, 0].imshow(image_np)
                axis[shown, 0].set_title(f"Test Image {shown+1}")
                axis[shown, 0].axis('off')

                axis[shown, 1].imshow(mask_np, cmap='gray')
                axis[shown, 1].set_title(f"Predicted Mask {shown+1}")
                axis[shown, 1].axis('off')

                shown += 1

    plt.show()

# Load test data and visualize
testloader = DataLoader(test_data, batch_size=1, shuffle=False)
predict_and_visualize(model, testloader)