In [1]:
import os
import cv2
import numpy as np
import pandas as pd
import random, tqdm
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
import albumentations as album

In [2]:
import segmentation_models_pytorch as smp

In [3]:
DATA_DIR = '..data/unet_train/input/inria-aerial-image-labeling-dataset/AerialImageDataset'

x_train_dir = os.path.join(DATA_DIR, 'train/images')
y_train_dir = os.path.join(DATA_DIR, 'train/gt')


In [4]:
class_names = ['background', 'building']
class_rgb_values = [[0,0,0],[255,255,255]]

print('All dataset classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)

All dataset classes and their corresponding RGB values in labels:
Class Names:  ['background', 'building']
Class RGB values:  [[0, 0, 0], [255, 255, 255]]


In [5]:
select_class_indices = [class_names.index(cls.lower()) for cls in class_names]
select_class_rgb_values =  np.array(class_rgb_values)[select_class_indices]

print('Selected classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)

Selected classes and their corresponding RGB values in labels:
Class Names:  ['background', 'building']
Class RGB values:  [[0, 0, 0], [255, 255, 255]]


In [6]:
# helper function for data visualization
def visualize(**images):
    """
    Plot images in one row
    """
    n_images = len(images)
    plt.figure(figsize=(20,8))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([]); 
        plt.yticks([])
        # get title from the parameter names
        plt.title(name.replace('_',' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()

# perform one hot encoding on label
def one_hot_encode(label, label_values):
    """
    Convert a segmentation image label array to one-hot format
    by replacing each pixel value with a vector of length num_classes
    # Arguments
        label: The 2D array segmentation image label
        label_values
        
    # Returns
        A 2D array with the same width and hieght as the input, but
        with a depth size of num_classes
    """
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis = -1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)

    return semantic_map
    
# perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """
    Transform a 2D array in one-hot format (depth is num_classes),
    to a 2D array with only 1 channel, where each pixel value is
    the classified class key.
    # Arguments
        image: The one-hot format image 
        
    # Returns
        A 2D array with the same width and hieght as the input, but
        with a depth size of 1, where each pixel value is the classified 
        class key.
    """
    x = np.argmax(image, axis = -1)
    return x

# perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values):
    """
    Given a 1-channel array of class keys, colour code the segmentation results.
    # Arguments
        image: single channel array where each value represents the class key.
        label_values

    # Returns
        Colour coded image for segmentation visualization
    """
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]

    return x

In [7]:
class BuildingsDataset(torch.utils.data.Dataset):

    """Inria Arial Image Buildings Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_rgb_values (list): RGB values of select classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    
    """
    
    def __init__(
            self, 
            images_dir, 
            masks_dir, 
            class_rgb_values=None, 
            augmentation=None, 
            preprocessing=None,
            crop_size=512,
    ):
        
        self.image_paths = [os.path.join(images_dir, image_id) for image_id in sorted(os.listdir(images_dir))]
        self.mask_paths = [os.path.join(masks_dir, image_id) for image_id in sorted(os.listdir(masks_dir))]

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.crop_size = crop_size
    
    def __getitem__(self, i):
        
        # read images and masks
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)
        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        assert image.shape[:2] == mask.shape[:2], "Image and mask size mismatch!"
        image, mask = self.random_crop(image, mask, self.crop_size)
        
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            
        return image, mask
        
    def __len__(self):
        return len(self.image_paths)

    def random_crop(self, image, mask, crop_size):
        """Randomly crops the same region from both image and mask"""
        h, w = image.shape[:2]
        if h < crop_size or w < crop_size:
            raise ValueError("Crop size is larger than image size!")

        # Randomly select top-left corner for the crop
        x = random.randint(0, w - crop_size)
        y = random.randint(0, h - crop_size)

        # Crop both image and mask
        image_cropped = image[y:y+crop_size, x:x+crop_size]
        mask_cropped = mask[y:y+crop_size, x:x+crop_size]

        return image_cropped, mask_cropped


In [8]:
def get_training_augmentation():
    train_transform = [    
        album.OneOf(
            [
                album.HorizontalFlip(p=1),
                album.VerticalFlip(p=1),
                album.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
    ]
    return album.Compose(train_transform, additional_targets={'mask': 'mask'})


def get_validation_augmentation():   
    # Add sufficient padding to ensure image is divisible by 32
    test_transform = [
        album.PadIfNeeded(min_height=512, min_width=512, always_apply=True, border_mode=0),
    ]
    return album.Compose(test_transform, additional_targets={'mask': 'mask'})


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def grayscale_preprocessing(image, **kwargs):
    """
    Convert an image to grayscale.
    
    Args:
        image (numpy array): Input image in RGB format.
    
    Returns:
        numpy array: Grayscale image with shape (H, W, 1).
    """
    image_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    image_gray = np.expand_dims(image_gray, axis=-1)
    image_gray = np.repeat(image_gray, 3, axis=-1)
    return image_gray


def get_preprocessing(preprocessing_fn=None):
    """Construct preprocessing transform    
    Args:
        preprocessing_fn (callable): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    """   
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))

    return album.Compose(_transform)

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class InceptionResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(InceptionResNetBlock, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 4, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True),
        )

        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 4, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels // 4, out_channels // 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True),
        )

        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 4, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels // 4, out_channels // 4, kernel_size=5, stride=1, padding=2, bias=False),
            nn.BatchNorm2d(out_channels // 4),
            nn.ReLU(inplace=True),
        )

        self.conv1x1 = nn.Conv2d(out_channels // 4 * 3, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        # ensure residual connection matches output dimensions
        self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.residual_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = self.residual_bn(self.residual_conv(x))  # Ensure matching dimensions
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x = torch.cat([x1, x2, x3], dim=1)
        x = self.conv1x1(x)
        x = self.bn(x)
        x += residual
        x = self.relu(x)
        return x

class UNetDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNetDecoderBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_channels * 2, out_channels)
    
    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x

class InceptionResNetUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=2):
        super(InceptionResNetUNet, self).__init__()

        # encoder
        self.enc1 = InceptionResNetBlock(in_channels, 64)
        self.enc2 = InceptionResNetBlock(64, 128)
        self.enc3 = InceptionResNetBlock(128, 256)
        self.enc4 = InceptionResNetBlock(256, 512)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # bottleneck
        self.bottleneck = InceptionResNetBlock(512, 1024)
        
        # decoder
        self.dec4 = UNetDecoderBlock(1024, 512)
        self.dec3 = UNetDecoderBlock(512, 256)
        self.dec2 = UNetDecoderBlock(256, 128)
        self.dec1 = UNetDecoderBlock(128, 64)
        
        # output Layer
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        
    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        
        bottleneck = self.bottleneck(self.pool(enc4))
        
        dec4 = self.dec4(bottleneck, enc4)
        dec3 = self.dec3(dec4, enc3)
        dec2 = self.dec2(dec3, enc2)
        dec1 = self.dec1(dec2, enc1)
        
        out = self.out_conv(dec1)
        out = torch.sigmoid(out)
        return out
    
# initalize
model = InceptionResNetUNet(in_channels=3, out_channels=2)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = torch.nn.DataParallel(model)
    model.eval()
    print('loaded model')

Using 2 GPUs!
loaded model


In [10]:
# define proportions
train_ratio = 0.9
val_ratio = 0.05
test_ratio = 0.05

full_dataset = BuildingsDataset(
    x_train_dir, y_train_dir, 
    class_rgb_values=select_class_rgb_values,
    augmentation=None,
    preprocessing=None
)

total_size = len(full_dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size  

train_indices, val_indices, test_indices = random_split(range(total_size), [train_size, val_size, test_size])

train_dataset = Subset(full_dataset, train_indices)
valid_dataset = Subset(full_dataset, val_indices)
test_dataset = Subset(full_dataset, test_indices)

train_dataset.dataset.augmentation = get_training_augmentation()
valid_dataset.dataset.augmentation = get_validation_augmentation()
test_dataset.dataset.augmentation = get_validation_augmentation()

train_dataset.dataset.preprocessing = get_preprocessing(preprocessing_fn=None)
valid_dataset.dataset.preprocessing = get_preprocessing(preprocessing_fn=None)
test_dataset.dataset.preprocessing = get_preprocessing(preprocessing_fn=None)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=12, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

In [11]:
# Set flag to train the model or not. If set to 'False', only prediction is performed (using an older model checkpoint)
TRAINING = True

# Set num of epochs
EPOCHS = 40

# Set device: `cuda` or `cpu`
DEVICE = torch.device("cuda")

# define loss function
loss = smp.utils.losses.DiceLoss()

# define metrics
metrics = [
    smp.utils.metrics.IoU(threshold=0.7),
]

# define optimizer
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.00008),
])

# define learning rate scheduler (not used in this NB)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, T_mult=2, eta_min=5e-5,
)

In [12]:
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [13]:
best_iou_score = 0.0
train_logs_list, valid_logs_list = [], []

In [None]:
if TRAINING:
    train_device = torch.device("cuda:0")
    valid_device = torch.device("cuda:1")

    model.to(train_device)

    for i in range(0, EPOCHS):

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run((inputs.to(train_device), targets.to(train_device)) for inputs, targets in train_loader)
        train_logs_list.append(train_logs)
        del train_logs
        with torch.no_grad():
            torch.cuda.empty_cache()
            valid_logs = valid_epoch.run((inputs.to(valid_device), targets.to(valid_device)) for inputs, targets in valid_loader)
        
            valid_logs_list.append(valid_logs)
        lr_scheduler.step()

        if best_iou_score < valid_logs['iou_score']:
            best_iou_score = valid_logs['iou_score']
            torch.save(model, './best_model.pth')
            print('Model saved!')


Epoch: 0
train: 11it [02:24, 13.12s/it, dice_loss - 0.3426, iou_score - 0.475] 
valid: 9it [00:07,  1.13it/s, dice_loss - 0.3062, iou_score - 0.606] 
Model saved!

Epoch: 1
train: 11it [02:22, 12.92s/it, dice_loss - 0.3304, iou_score - 0.5148]
valid: 9it [00:08,  1.03it/s, dice_loss - 0.3177, iou_score - 0.5887]

Epoch: 2
train: 11it [02:33, 13.93s/it, dice_loss - 0.3191, iou_score - 0.5513]
valid: 9it [00:08,  1.05it/s, dice_loss - 0.3347, iou_score - 0.5132]

Epoch: 3
train: 11it [02:22, 12.98s/it, dice_loss - 0.3152, iou_score - 0.5563]
valid: 9it [00:07,  1.13it/s, dice_loss - 0.3153, iou_score - 0.5333]

Epoch: 4
train: 11it [02:22, 12.96s/it, dice_loss - 0.3091, iou_score - 0.6368]
valid: 9it [00:07,  1.13it/s, dice_loss - 0.3305, iou_score - 0.5878]

Epoch: 5
train: 11it [02:22, 12.96s/it, dice_loss - 0.3031, iou_score - 0.6711]
valid: 9it [00:07,  1.13it/s, dice_loss - 0.336, iou_score - 0.5994] 

Epoch: 6
train: 11it [02:23, 13.07s/it, dice_loss - 0.2915, iou_score - 0.7006]


In [None]:
# load best saved model checkpoint from the current run
if os.path.exists('./best_model.pth'):
    best_model = torch.load('./best_model.pth', map_location=DEVICE)
    print('Loaded UNet model from this run.')

# load best saved model checkpoint from previous commit (if present)
elif os.path.exists('../input/unet-for-building-segmentation-pytorch/best_model.pth'):
    best_model = torch.load('../input/unet-for-building-segmentation-pytorch/best_model.pth', map_location=DEVICE)
    print('Loaded UNet model from a previous commit.')

In [None]:
# center crop padded image / mask to original image dims
def crop_image(image, target_image_dims=[512,512,3]):
   
    target_size = target_image_dims[0]
    image_size = len(image)
    padding = (image_size - target_size) // 2

    return image[
        padding:image_size - padding,
        padding:image_size - padding,
        :,
    ]

In [None]:
sample_preds_folder = '../data/unet_train/sample_predictions/'
if not os.path.exists(sample_preds_folder):
    os.makedirs(sample_preds_folder)

In [None]:
for idx in range(len(test_dataset)):
    
    random_idx = random.randint(0, len(test_dataset)-1)
    image, gt_mask = test_dataset[random_idx]
    image_vis = crop_image(image.astype('uint8'))
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)

    pred_mask = best_model(x_tensor)
    pred_mask = pred_mask.detach().squeeze().cpu().numpy()
    # Convert pred_mask from `CHW` format to `HWC` format
    pred_mask = np.transpose(pred_mask,(1,2,0))
    # Get prediction channel corresponding to building
    pred_building_heatmap = pred_mask[:,:,0]
    pred_mask = crop_image(colour_code_segmentation(reverse_one_hot(pred_mask), select_class_rgb_values))
    # Convert gt_mask from `CHW` format to `HWC` format
    gt_mask = np.transpose(gt_mask,(1,2,0))
    gt_mask = crop_image(colour_code_segmentation(reverse_one_hot(gt_mask), select_class_rgb_values))
    cv2.imwrite(os.path.join(sample_preds_folder, f"sample_pred_{idx}.png"), np.hstack([image_vis, gt_mask, pred_mask, pred_building_heatmap])[:,:,::-1])
    
    visualize(
        original_image = image_vis,
        ground_truth_mask = gt_mask,
        predicted_mask = pred_mask,
        predicted_building_heatmap = pred_building_heatmap
    )