# Segment and embed buildings in GeoTIFF orthophotographs

In [None]:
import rasterio
from rasterio.windows import Window
from rasterio.features import shapes
from rasterio.mask import mask as rasterio_mask
from rasterio.plot import show

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import models

import albumentations as album
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2

import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os
import geopandas as gpd
from shapely.geometry import shape
import json
import osmnx as ox
import glob
import pandas as pd

In [None]:
import warnings

# Install a catch-all "ignore" filter so no warning of any category is
# displayed for the remainder of the notebook session.
warnings.simplefilter('ignore')

# Slice GeoTIFFs into tiles for the UNet model

In [None]:
TILE_SIZE = 512


def slice_geotiff(input_tiff, output_folder, tile_size=TILE_SIZE):
    """Cut a GeoTIFF into georeferenced ``tile_size`` x ``tile_size`` tiles.

    Tiles on the right/bottom edge are smaller when the raster size is not a
    multiple of ``tile_size``. Every tile keeps the source metadata, with its
    own width, height and affine transform, and is written as
    ``<output_folder>/<stem>_tile_<i>_<j>.tif``. Here ``i`` is the column
    (x) index, ``j`` is the row (y) index and ``stem`` is the source file
    name without its extension.

    Parameters:
        input_tiff (str): Path to the source GeoTIFF.
        output_folder (str): Directory for the tiles (created if missing).
        tile_size (int): Tile edge length in pixels; must be positive.

    Raises:
        ValueError: If ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    os.makedirs(output_folder, exist_ok=True)
    # os.path handles platform separators and keeps dots inside the stem
    # ("a.b.tif" -> "a.b"), so differently named sources cannot collide.
    stem = os.path.splitext(os.path.basename(input_tiff))[0]

    with rasterio.open(input_tiff) as src:
        meta = src.meta.copy()
        img_width, img_height = src.width, src.height

        # Ceiling division: any remainder becomes an extra, smaller edge tile.
        x_tiles = (img_width + tile_size - 1) // tile_size
        y_tiles = (img_height + tile_size - 1) // tile_size
        print(f"Slicing into {x_tiles}x{y_tiles} tiles")

        for i in range(x_tiles):
            x_offset = i * tile_size
            width = min(tile_size, img_width - x_offset)
            for j in range(y_tiles):
                y_offset = j * tile_size
                height = min(tile_size, img_height - y_offset)

                window = Window(x_offset, y_offset, width, height)
                # Every key that differs per tile is overwritten on each
                # iteration, so reusing one meta dict is safe.
                meta.update({
                    'width': width,
                    'height': height,
                    'transform': rasterio.windows.transform(window, src.transform),
                })

                output_path = os.path.join(output_folder, f"{stem}_tile_{i}_{j}.tif")
                with rasterio.open(output_path, "w", **meta) as dst:
                    dst.write(src.read(window=window))

                print(f"Saved: {output_path}")

In [None]:
# One entry per (country, year) orthophoto set:
#   input_folder   - source GeoTIFFs to be tiled
#   output_folder  - destination of the tiles written by slice_geotiff
#   vector_outputs - destination of the GeoJSON files written further down
datasets = [
    {
        'input_folder': '../data/EE/orto/2020',
        'output_folder': '../data/sliced/EE/2020',
        'vector_outputs': '../data/sliced/EE/2020_vectors',
    },
    {
        'input_folder': '../data/EE/orto/2024',
        'output_folder': '../data/sliced/EE/2024',
        'vector_outputs': '../data/sliced/EE/2024_vectors',
    },
    {
        'input_folder': '../data/LT/orto/2020',
        'output_folder': '../data/sliced/LT/2020',
        'vector_outputs': '../data/sliced/LT/2020_vectors',
    },
    {
        'input_folder': '../data/LT/orto/2024',
        'output_folder': '../data/sliced/LT/2024',
        'vector_outputs': '../data/sliced/LT/2024_vectors',
    },
]

# Tile every .tif of every dataset into that dataset's output folder.
for dataset in datasets:
    source_pattern = os.path.join(dataset['input_folder'], "*.tif")
    for input_tiff in glob.glob(source_pattern):
        slice_geotiff(input_tiff, dataset['output_folder'])

print('finished slicing')

# Segmentation using UNet

In [None]:
class ConvBlock(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    The submodule names ``conv``/``bn``/``relu`` must stay stable: this
    notebook restores ``best_model.pth`` as a whole pickled module, whose
    stored attributes are looked up by these names in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # No conv bias: the following BatchNorm adds its own learned shift.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class InceptionResNetBlock(nn.Module):
    """Three parallel conv branches (1x1, 1x1->3x3, 1x1->5x5) plus a projected shortcut.

    Each branch produces ``out_channels // 4`` maps. The concatenation of the
    three branches is fused back to ``out_channels`` by a 1x1 conv + BN, then
    added to a 1x1-projected copy of the input before the final ReLU.
    Spatial size is preserved (stride 1, "same" padding everywhere).

    Submodule names and their creation order are kept stable because the
    checkpoint is restored as a pickled module.
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size):
        """Return [Conv2d, BatchNorm2d, ReLU] for a stride-1, size-preserving conv."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=1,
                      padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, in_channels, out_channels):
        super().__init__()
        quarter = out_channels // 4
        unit = self._conv_bn_relu

        self.branch1 = nn.Sequential(*unit(in_channels, quarter, 1))
        self.branch2 = nn.Sequential(*unit(in_channels, quarter, 1), *unit(quarter, quarter, 3))
        self.branch3 = nn.Sequential(*unit(in_channels, quarter, 1), *unit(quarter, quarter, 5))

        # Fuse the three concatenated branches back to the requested width.
        self.conv1x1 = nn.Conv2d(quarter * 3, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut projection so the residual matches out_channels.
        self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.residual_bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        shortcut = self.residual_bn(self.residual_conv(x))
        branches = (self.branch1, self.branch2, self.branch3)
        merged = torch.cat([branch(x) for branch in branches], dim=1)
        fused = self.bn(self.conv1x1(merged))
        return self.relu(fused + shortcut)

class UNetDecoderBlock(nn.Module):
    """UNet decoder stage: 2x transposed-conv upsampling, skip concatenation, ConvBlock.

    ``forward(x, skip)`` upsamples ``x`` to ``out_channels`` maps at twice the
    resolution, appends ``skip`` along the channel axis and convolves the
    result back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # Input width after concatenation is 2 * out_channels; this assumes the
        # skip tensor also carries out_channels maps (true for every use here).
        self.conv = ConvBlock(2 * out_channels, out_channels)

    def forward(self, x, skip):
        upsampled = self.up(x)
        return self.conv(torch.cat((upsampled, skip), dim=1))

class InceptionResNetUNet(nn.Module):
    """UNet with Inception-ResNet encoder stages and a sigmoid per-class output.

    The encoder has four stages (64-512 channels), each followed by 2x2 max
    pooling, then a 1024-channel bottleneck. The decoder has four stages that
    upsample 2x and concatenate the matching encoder output. Input height and
    width should be divisible by 16 so the skip tensors line up.

    Attribute names are kept stable: ``best_model.pth`` is restored in this
    notebook as a pickled instance of this class.
    """

    def __init__(self, in_channels=3, out_channels=2):
        super().__init__()

        # Encoder.
        self.enc1 = InceptionResNetBlock(in_channels, 64)
        self.enc2 = InceptionResNetBlock(64, 128)
        self.enc3 = InceptionResNetBlock(128, 256)
        self.enc4 = InceptionResNetBlock(256, 512)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = InceptionResNetBlock(512, 1024)

        # Decoder, deepest first.
        self.dec4 = UNetDecoderBlock(1024, 512)
        self.dec3 = UNetDecoderBlock(512, 256)
        self.dec2 = UNetDecoderBlock(256, 128)
        self.dec1 = UNetDecoderBlock(128, 64)

        # Per-pixel class scores.
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4)
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)

        skips = []
        feat = x
        for depth, encoder in enumerate(encoders):
            # Every stage after the first runs on a 2x downsampled input.
            feat = encoder(feat if depth == 0 else self.pool(feat))
            skips.append(feat)

        feat = self.bottleneck(self.pool(feat))

        # Pair each decoder with the encoder output of the same resolution.
        for decoder, skip in zip(decoders, reversed(skips)):
            feat = decoder(feat, skip)

        return torch.sigmoid(self.out_conv(feat))


# Freshly initialised network; the next cell swaps in the saved checkpoint if present.
model = InceptionResNetUNet(in_channels=3, out_channels=2)

In [None]:
DEVICE = torch.device('cpu')
MODEL_PATH = './best_model.pth'

if os.path.exists(MODEL_PATH):
    # SECURITY: weights_only=False unpickles arbitrary Python objects (here a
    # whole pickled nn.Module). Only load checkpoints from a trusted source.
    model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    # A network saved while wrapped for multi-GPU training sits under `.module`.
    # Unwrap only in that case, so a plain saved model also loads.
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        model = model.module
    print('Loaded UNet model from file.')
else:
    print(f'{MODEL_PATH} not found; keeping the in-memory model '
          '(randomly initialised unless trained in this session).')

# Always prepare for inference, whichever model is in use: move it to the
# device and make BatchNorm use its running statistics.
model.to(DEVICE)
model.eval()

In [None]:
def preprocess_unseen_image(image_path, min_size=512):
    """Load an image tile as a 1xCxHxW tensor padded to at least ``min_size`` pixels.

    The image is read with OpenCV, converted BGR -> RGB and zero-padded (never
    resized) to ``min_size`` x ``min_size``. The downstream centre crop assumes
    the padding is centred (the albumentations default position). ToTensorV2
    only reorders HWC -> CHW, so pixel values keep their original dtype and
    range (no /255 scaling). Presumably this matches how the model was
    trained; confirm.

    Parameters:
        image_path (str): Path of the image/tile to load.
        min_size (int): Minimum height and width after padding.

    Returns:
        torch.Tensor: Tensor of shape (1, C, H, W) with H, W >= ``min_size``.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"cv2 could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    transform = album.Compose([
        # border_mode=0 is cv2.BORDER_CONSTANT, i.e. pad with zeros.
        album.PadIfNeeded(min_height=min_size, min_width=min_size, always_apply=True, border_mode=0),
        ToTensorV2(),
    ])

    return transform(image=image)["image"].unsqueeze(0)

In [None]:
def crop_image(image, target_image_dims=(512, 512, 3)):
    """Centre-crop ``image`` to the height/width in ``target_image_dims``.

    Parameters:
        image: Array indexed as (H, W, ...), e.g. HWC or HW.
        target_image_dims: Sequence whose first two entries are the target
            height and width; any further entries (channels) are ignored.

    Returns:
        A view of ``image`` with shape (min(H, th), min(W, tw), ...). The crop
        is centred, rounding the start index down on odd margins. An image
        already smaller than the target along an axis is left whole on that
        axis.
    """
    target_h, target_w = target_image_dims[0], target_image_dims[1]
    img_h, img_w = image.shape[:2]

    # Compute a start offset per axis and take exactly `target` pixels from it.
    # This avoids the off-by-one that symmetric `padding` trimming gives on odd
    # margins. Clamping at 0 avoids negative (wrap-around) slicing for small
    # inputs.
    top = max((img_h - target_h) // 2, 0)
    left = max((img_w - target_w) // 2, 0)

    return image[top:top + target_h, left:left + target_w, ...]
# Helper for quick side-by-side display of images.
def visualize(**images):
    """Show every keyword-argument image in a single row of subplots.

    Each keyword becomes its subplot's title, with underscores replaced by
    spaces and title-cased; tick marks are hidden.
    """
    total = len(images)
    plt.figure(figsize=(20, 8))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, total, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title(), fontsize=20)
        plt.imshow(picture)
    plt.show()

# One-hot encode a colour-coded label image.
def one_hot_encode(label, label_values):
    """Convert a colour-coded segmentation label into a boolean one-hot stack.

    # Arguments
        label: Label image whose last axis holds each pixel's colour.
        label_values: Sequence of colours, one per class, in class order.

    # Returns
        Boolean array with the same leading (height, width) shape as ``label``
        and one trailing plane per entry of ``label_values``. A plane is True
        where the pixel's colour matches that class colour on every component.
    """
    planes = [np.equal(label, colour).all(axis=-1) for colour in label_values]
    return np.stack(planes, axis=-1)
    
# Collapse one-hot labels / per-class predictions back to class keys.
def reverse_one_hot(image):
    """Return, per pixel, the index of the largest value along the last axis.

    # Arguments
        image: Array whose last axis holds one score (or one-hot flag) per class.

    # Returns
        Integer array with the last axis removed, containing class keys.
    """
    return np.argmax(image, axis=-1)

# Perform colour coding on the reverse-one-hot outputs
def colour_code_segmentation(image, label_values, threshold=0.35):
    """Colour a single-channel class/score map for visualisation.

    # Arguments
        image: 2D array of per-pixel class keys, or scores when thresholding.
        label_values: Sequence of colours, indexed by class key.
        threshold: When not None (the default, 0.35), the map is binarised.
            Pixels with ``image > threshold`` get ``label_values[1]`` and all
            others ``label_values[0]``. This suits two-class keys (0/1) or a
            foreground probability map. Pass ``None`` to index
            ``label_values`` directly by integer class key, which is needed
            for more than two classes.

    # Returns
        Array of shape ``image.shape + colour.shape`` holding one colour per pixel.
    """
    colour_codes = np.array(label_values)
    values = np.asarray(image)
    if threshold is None:
        keys = values.astype(int)
    else:
        keys = (values > threshold).astype(int)
    return colour_codes[keys]

In [None]:
class_names = ['background', 'building']
class_rgb_values = [[0, 0, 0], [255, 255, 255]]

# Look every class name up (lower-cased) in class_names; with the names above
# this selects all classes in their original order.
select_class_indices = list(map(class_names.index, (name.lower() for name in class_names)))
select_class_rgb_values = np.array(class_rgb_values)[select_class_indices]

# Process orthophotograph tiles: segment buildings, vectorize them and compute embeddings

In [None]:
# Post-processing / export settings for the tile loop below.
MIN_BUILDING_AREA = 500  # pixels; smaller connected components are discarded
FLUSH_EVERY = 100        # a GeoJSON file is written once more tiles than this are buffered

# ResNet-50 without its final fc layer yields one pooled feature vector per
# crop. It is loaded once here rather than once per tile. IMAGENET1K_V1 is the
# checkpoint that the deprecated `pretrained=True` flag selected.
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet50 = torch.nn.Sequential(*list(resnet50.children())[:-1])
resnet50.eval()

resnet_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _predict_building_mask(image_path, height, width):
    """Run the UNet on one tile; return its RGB uint8 class-colour mask cropped to (height, width)."""
    img = preprocess_unseen_image(image_path).to(DEVICE, dtype=torch.float32)
    with torch.no_grad():  # inference only: no autograd graph needed
        pred = model(img)
    # (1, C, H, W) -> (C, H, W) -> (H, W, C)
    pred = np.transpose(pred.squeeze().cpu().numpy(), (1, 2, 0))
    colour_mask = crop_image(
        colour_code_segmentation(reverse_one_hot(pred), select_class_rgb_values)
    )

    # The tile was padded (centred) up to the model input size; cut the centre
    # back out so the mask aligns pixel-for-pixel with the raster.
    mask_h, mask_w = colour_mask.shape[:2]
    top = (mask_h - height) // 2
    left = (mask_w - width) // 2
    return colour_mask[top:top + height, left:left + width].astype(np.uint8)


def _clean_building_mask(colour_mask, min_area=MIN_BUILDING_AREA):
    """Split touching buildings with a watershed and drop small blobs; return a 0/255 uint8 mask."""
    small_kernel = np.ones((2, 2), np.uint8)
    tiny_kernel = np.ones((1, 1), np.uint8)

    gray = cv2.cvtColor(colour_mask, cv2.COLOR_RGB2GRAY)
    mask = cv2.erode(gray, small_kernel, iterations=2)
    mask = cv2.dilate(mask, small_kernel, iterations=1)
    # NOTE: opening with a 1x1 kernel leaves the image unchanged; kept as a tuning hook.
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, tiny_kernel, iterations=3)
    _, binary = cv2.threshold(opened, 127, 255, cv2.THRESH_BINARY)

    # Watershed seeds: pixels farther than 40% of the max distance from the mask edge.
    dist = cv2.distanceTransform(binary, cv2.DIST_L2, 3)
    _, sure_fg = cv2.threshold(dist, 0.40 * dist.max(), 255, cv2.THRESH_BINARY)
    sure_fg = np.uint8(sure_fg)
    sure_bg = cv2.dilate(binary, small_kernel, iterations=2)
    unknown = cv2.subtract(sure_bg, sure_fg)

    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1            # shift labels so 0 can mark the "unknown" band
    markers[unknown == 255] = 0
    cv2.watershed(colour_mask, markers)  # updates `markers` in place
    separated = np.uint8(markers > 1) * 255

    # Keep only components of at least `min_area` pixels (label 0 is background).
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(separated, connectivity=8)
    keep = stats[:num_labels, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False
    return np.where(keep[labels], 255, 0).astype(np.uint8)


def _mask_to_polygons(building_mask, src):
    """Vectorise the non-zero regions of `building_mask` into a GeoDataFrame in `src`'s CRS."""
    if building_mask.shape != (src.height, src.width):
        raise ValueError('Mask dimensions do not match the original image dimensions')
    polygons = [
        shape(geom)
        for geom, value in shapes(building_mask, transform=src.transform)
        if value > 0
    ]
    return gpd.GeoDataFrame(geometry=polygons, crs=src.crs)


def _embed_building(src, poly):
    """Return the ResNet-50 feature vector of the raster pixels inside `poly`."""
    cropped, _ = rasterio_mask(src, [poly], crop=True)
    # NOTE(review): assumes bands 1-3 are RGB. Any extra band (e.g. alpha) is
    # dropped because the ImageNet normalisation expects exactly 3 channels.
    rgb = np.transpose(cropped[:3], (1, 2, 0)).astype(np.uint8)
    tensor = resnet_preprocess(Image.fromarray(rgb)).unsqueeze(0)
    with torch.no_grad():
        features = resnet50(tensor)
    return features.squeeze().numpy()


def _write_vectors(gdfs, vector_outputs, file_name):
    """Concatenate buffered per-tile GeoDataFrames, reproject to WGS84 and write one GeoJSON."""
    if not gdfs:
        return
    os.makedirs(vector_outputs, exist_ok=True)
    merged = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))
    # GeoJSON properties must be scalar values, so each embedding array is
    # stored as a JSON-encoded list of floats.
    merged['resnet50_embed'] = [
        json.dumps(np.asarray(vec).tolist()) for vec in merged['resnet50_embed']
    ]
    (
        merged
        .to_crs(epsg=4326)  # convert coordinate system to WGS84
        .to_file(os.path.join(vector_outputs, f"{file_name}.geojson"), driver='GeoJSON')
    )


for dataset in datasets:
    input_folder = dataset['output_folder']  # tiles written by the slicing step
    vector_outputs = dataset['vector_outputs']

    # Sorted so processing order (and therefore batch file names) is reproducible.
    tif_files = sorted(glob.glob(os.path.join(input_folder, "*.tif")))

    gdfs = []
    file_name = None
    for tif_file in tif_files:
        print(f'Processing: {tif_file}')

        with rasterio.open(tif_file) as src:
            colour_mask = _predict_building_mask(tif_file, src.height, src.width)
            building_mask = _clean_building_mask(colour_mask)
            gdf = _mask_to_polygons(building_mask, src)
            gdf['resnet50_embed'] = gdf['geometry'].apply(lambda poly: _embed_building(src, poly))

        gdfs.append(gdf)
        file_name = os.path.splitext(os.path.basename(tif_file))[0]

        if len(gdfs) > FLUSH_EVERY:
            _write_vectors(gdfs, vector_outputs, file_name)
            gdfs = []

    # Write the final partial batch; otherwise up to FLUSH_EVERY tiles per
    # dataset would never reach disk.
    _write_vectors(gdfs, vector_outputs, file_name)
    gdfs = []

# -- Image processed --