# Land segmentation 

This notebook:
* Predicts land cover classes for the test image using the trained UNET or ResNet model.
    * For prediction, the test data is tiled.
* Computes the accuracy, IoU and F1 score of the predictions by comparing them to the ground truth.
* Plots the results.

Labels used:
* 1 - forest from forest inventory data
* 2 - fields from agricultural parcels data
* 3 - water from CORINE land cover data
* 0 - everything else 
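As a small sketch, the label ids above can be kept in a dictionary and used to check that a label raster contains only the expected classes. `CLASS_NAMES` and `check_labels` are hypothetical helpers, not part of the notebook, and the example raster is synthetic.

```python
import numpy as np

# Class ids and names as listed above (hypothetical dictionary for illustration)
CLASS_NAMES = {0: "other", 1: "forest", 2: "fields", 3: "water"}


def check_labels(labels: np.ndarray) -> dict:
    """Return the pixel count per class name; raise if unknown ids are present."""
    ids, counts = np.unique(labels, return_counts=True)
    unexpected = set(ids.tolist()) - set(CLASS_NAMES)
    if unexpected:
        raise ValueError(f"Unexpected label ids: {sorted(unexpected)}")
    return {CLASS_NAMES[int(i)]: int(c) for i, c in zip(ids, counts)}


# Tiny synthetic label raster
labels = np.array([[0, 1, 1], [2, 2, 3]], dtype=np.uint8)
print(check_labels(labels))  # {'other': 1, 'forest': 2, 'fields': 2, 'water': 1}
```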

In [None]:
import os
import numpy as np
import math, time
from typing import Optional, Any, Tuple

# Reading and writing raster data
import rasterio

# PyTorch
import torch

# Model evaluation
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex, MulticlassF1Score
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# Plotting
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Models
from UNET_model import UNET
from resnet_model import ResNet

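# Hide all GPUs so that torch.cuda.is_available() returns False and inference runs on the CPU.
# This must be set before CUDA is first used; remove it to run on a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""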

## Settings

Define folders and files.

In [None]:
# Folders
# Set path to data and labels files
base_folder = os.path.join(os.sep, 'scratch', 'project_2017263') 
data_folder = os.path.join(base_folder,'data', 'raster')

data_test = os.path.join(data_folder, 'data_test.tif')
labels_test = os.path.join(data_folder, 'labels_test.tif')
output_folder = os.path.join(base_folder, os.environ.get('USER'), 'lumi-aif-fmi', 'day2', 'exercise3','inference')
os.makedirs(output_folder, exist_ok=True)
test_output = os.path.join(output_folder, 'segmentation_results.tif') 
test_output_all_classes = os.path.join(output_folder, 'segmentation_results_all_classes.tif') 

Settings for prediction

In [None]:
num_classes = 4
tile_size = 512  # Use the same as for model training; must not be larger than the data height/width.
batch_size = 8
overlap = 20  # Overlap between neighbouring tiles in pixels
no_of_bands = 8  # Number of input bands in the data

In [None]:
# Set computing device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Model

Load the trained model from its checkpoint. Choose either the trained ResNet or the UNET model here.

In [None]:
model = ResNet(no_of_bands, num_classes)  # Or UNET; see UNET_model.py for its constructor arguments
model_name = model.__class__.__name__
model.load_state_dict(torch.load(f'model_training/{model_name}_model.pt', map_location=device))
model.eval()

## Tiled inference to predict the classes

During inference, the model should be given tiles similar to those used in training, so the large raster again has to be tiled. Prediction quality near tile edges is often weak, so we use overlapping tiles and, for each pixel, keep the prediction the model is more confident of.

The steps of tiled inference:
* Calculate an importance for each pixel in the tile. Pixels near the tile edge get lower importance, because the model usually makes more mistakes there.
* Split the raster into overlapping tiles.
* Run inference on each tile. Each pixel gets a score for each class, indicating how likely the pixel belongs to that class.
    * In practice, inference is run in batches, so the GPU is better utilized and the total prediction time is shorter.
* Merge the tiles: for each pixel and class, keep the highest score after weighting it with the importance (distance to the tile edge).
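The tiling step can be sketched for one axis. The hypothetical `tile_starts` helper uses the same start-position rule as `inference_on_geotiff` later in this notebook (stride `tile_size - overlap`, last tile shifted back so that it ends at the raster edge). Unlike the notebook loop, it also drops duplicate positions, which there only cost some extra computation.

```python
def tile_starts(length: int, tile_size: int, overlap: int) -> list:
    """Start offsets of overlapping tiles along one axis of the raster."""
    if not 0 <= overlap < tile_size <= length:
        raise ValueError("need 0 <= overlap < tile_size <= length")
    stride = tile_size - overlap
    steps = (length - overlap) // stride  # same as math.floor for non-negative values
    last = length - tile_size             # last start that still fits inside the raster
    starts = [min(i * stride, last) for i in range(steps + 1)]
    # Several positions may be clamped to the edge; keep each start only once
    return sorted(set(starts))


print(tile_starts(1200, tile_size=512, overlap=20))  # [0, 492, 688]
```

For a 1200-pixel axis with the settings used here, three tiles cover the axis; the last one is shifted back and overlaps its neighbour by more than 20 pixels.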

One could also use SAHI for automatic tiled inference with RGB images, but as we have 8 input channels, we write the script ourselves.

### Importance matrix for each predicted tile

Calculate an importance for each pixel in the tile. Pixels near the tile edge get lower importance, because the model usually makes more mistakes there, while pixels in the center of the tile get full importance. This gives smoother blending at tile boundaries. In practice, only pixels closer than `overlap / 2` to the tile edge get reduced importance; these all lie in the zone where neighbouring tiles overlap. In the plot, the darker the pixel, the lower its importance.

Code modified from: https://github.com/opengeos/geoai/blob/main/geoai/train.py

In [None]:
h = tile_size
w = tile_size

y_grid, x_grid = np.mgrid[0:h, 0:w]

# Calculate distance from each edge
dist_from_left = x_grid
dist_from_right = w - x_grid - 1
dist_from_top = y_grid
dist_from_bottom = h - y_grid - 1

# Combine distances (minimum distance to any edge)
edge_distance = np.minimum.reduce(
    [
        dist_from_left,
        dist_from_right,
        dist_from_top,
        dist_from_bottom,
    ]
)

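# Convert to weight (higher weight for center pixels).
# The 0.1 offset keeps edge pixels at a small non-zero weight; clip at overlap/2 and normalize to [0, 1]
edge_distance = np.minimum(edge_distance + 0.1, overlap / 2)
importance = edge_distance / (overlap / 2)

# Use the same importances for all classes (float32 to match the model output)
importances = torch.from_numpy(np.repeat(importance[np.newaxis, :, :], num_classes, axis=0)).float()

# Plot the importance matrix: darker pixels have lower importance
plt.imshow(importance, cmap='gray')
plt.colorbar()
plt.title('Pixel importance')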
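A one-dimensional version of the importance calculation shows the shape of the weights more clearly: a short ramp at each edge and a plateau of 1 in the center. The `importance_1d` helper is illustrative only. For a tiny 8-pixel tile with `overlap = 4`, the weights are 0.05, 0.55, 1, 1, 1, 1, 0.55, 0.05. Because the weight grows monotonically with the edge distance, the 2D matrix equals the element-wise minimum of the row and column ramps.

```python
import numpy as np


def importance_1d(tile_size: int, overlap: int) -> np.ndarray:
    """1D version of the importance weights: ramps at the edges, flat in the center."""
    x = np.arange(tile_size)
    # Distance to the nearer of the two edges
    edge_distance = np.minimum(x, tile_size - x - 1)
    # Same offset, clipping and normalization as in the 2D version
    edge_distance = np.minimum(edge_distance + 0.1, overlap / 2)
    return edge_distance / (overlap / 2)


ramp = importance_1d(8, overlap=4)
print(ramp)

# 2D weights as the element-wise minimum of the row and column ramps
importance_2d = np.minimum.outer(ramp, ramp)
print(importance_2d.shape)  # (8, 8)
```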

Function to calculate predictions for a large raster.

Code partly from: https://github.com/opengeos/geoai/blob/main/geoai/train.py

In [None]:
def inference_on_geotiff(
    model: torch.nn.Module,
    data: np.ndarray,
    tile_size: int = 512,
    overlap: int = 256,
    batch_size: int = 4,
    num_channels: int = 3,
    device: Optional[torch.device] = None,
    **kwargs: Any,
) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Perform inference on a large GeoTIFF using a sliding window approach with improved blending.

    Uses the global `num_classes` and `importances` defined above.

    Args:
        model (torch.nn.Module): Trained model for inference.
        data (numpy array): Data of a GeoTIFF file with shape (bands, height, width).
        tile_size (int): Size of sliding window for inference.
        overlap (int): Overlap between adjacent tiles in pixels.
        batch_size (int): Batch size for inference.
        num_channels (int): Number of channels in the input image (currently unused).
        device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
        **kwargs: Additional arguments (unused).

    Returns:
        tuple: Importance-weighted class scores with shape (num_classes, height, width)
            and the most probable class of each pixel as a uint8 array with shape (height, width).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move the model to the device and put it in evaluation mode
    model.to(device)
    model.eval()

    height = data.shape[1]
    width = data.shape[2]

    # Initialize the predictions array with -inf, so that any score is higher
    pixel_predictions = torch.full((num_classes, height, width), -float('inf'), device=device)

    # Importance weights on the same device and with the same dtype as the predictions
    weights = importances.to(device=device, dtype=pixel_predictions.dtype)

    # Calculate the number of windows needed to cover the entire image
    stride = tile_size - overlap
    steps_y = math.floor((height - overlap) / stride)
    steps_x = math.floor((width - overlap) / stride)

    # Start positions of the last tiles, so that they end exactly at the image edge
    last_y = height - tile_size
    last_x = width - tile_size

    total_windows = (steps_y + 1) * (steps_x + 1)
    print(
        f"Processing {total_windows} tiles with size {tile_size}x{tile_size} and overlap {overlap}..."
    )

    # Process in batches: the calculation is faster if data is fed to the GPU in batches.
    batch_inputs = []
    batch_positions = []

    # Change data type to float32 as required by the model and convert to a tensor
    image_tensor = torch.from_numpy(data.astype(np.float32)).to(device)

    # Slide the window over the image, making sure the entire image is covered
    for i in range(steps_y + 1):  # +1 to ensure we reach the edge
        y = min(i * stride, last_y)

        for j in range(steps_x + 1):  # +1 to ensure we reach the edge
            x = min(j * stride, last_x)

            # Add the tile to the batch and keep track of where it is located
            batch_inputs.append(image_tensor[:, y:y+tile_size, x:x+tile_size])
            batch_positions.append((y, x))

            # Process the batch when it reaches the batch size or at the last tile
            if len(batch_inputs) == batch_size or (i == steps_y and j == steps_x):
                batch_inputs_tensor = torch.stack(batch_inputs)
                # Forward pass: give the model a batch of data
                with torch.no_grad():
                    outputs = model(batch_inputs_tensor)

                # Process each output in the batch
                for (y_pos, x_pos), output in zip(batch_positions, outputs):
                    # Multiply by the importances based on each pixel's distance to the tile edge
                    weighted_scores = output * weights
                    # Keep the higher score for each pixel and class
                    window = pixel_predictions[:, y_pos:y_pos+tile_size, x_pos:x_pos+tile_size]
                    pixel_predictions[:, y_pos:y_pos+tile_size, x_pos:x_pos+tile_size] = torch.maximum(window, weighted_scores)

                # Reset batch
                batch_inputs = []
                batch_positions = []

    # Calculate the most probable class for each pixel (move to CPU before converting to NumPy)
    class_predictions = pixel_predictions.argmax(dim=0).cpu().numpy().astype(np.uint8)

    return pixel_predictions, class_predictions
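The merging rule of the function can be illustrated on a one-dimensional strip: two tiles of length 4 overlap by 2 pixels on a 6-pixel strip with 2 classes, and the element-wise maximum keeps the more confident score in the overlap. The `merge_max` helper and the scores are made up for illustration and are assumed to be already multiplied by the importances. The sketch also assumes non-negative scores such as softmax probabilities; whether the models here output probabilities or raw logits depends on `UNET_model.py` and `resnet_model.py`. With negative scores, multiplying by a small edge weight moves them towards zero, i.e. makes them larger.

```python
import numpy as np


def merge_max(length, tiles):
    """Merge (start, scores) tiles on a 1D strip with an element-wise maximum.

    scores has shape (num_classes, tile_length) and is assumed to be
    importance-weighted already.
    """
    num_classes = tiles[0][1].shape[0]
    merged = np.full((num_classes, length), -np.inf)
    for start, scores in tiles:
        end = start + scores.shape[1]
        merged[:, start:end] = np.maximum(merged[:, start:end], scores)
    # Most probable class per pixel
    return merged, merged.argmax(axis=0)


tile_a = np.array([[0.9, 0.8, 0.3, 0.1],    # class 0
                   [0.1, 0.2, 0.2, 0.05]])  # class 1
tile_b = np.array([[0.1, 0.2, 0.3, 0.2],
                   [0.4, 0.8, 0.9, 0.7]])
merged, classes = merge_max(6, [(0, tile_a), (2, tile_b)])
print(classes)  # [0 0 1 1 1 1]
```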

Read test data from file and calculate predicted classes.

In [None]:
with rasterio.open(data_test) as src:
    data = src.read()
    pixel_predictions, class_predictions = inference_on_geotiff(model, data, tile_size, overlap, batch_size, no_of_bands, device)

### Evaluate results

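Calculate accuracy, intersection over union (IoU) and F1 score with TorchMetrics multiclass metrics:
* Accuracy measures the share of correctly labeled pixels. With `average='macro'` it is computed per class (i.e. per-class recall) and then averaged, so it differs from overall pixel accuracy.
* IoU measures, per class, the overlap between predicted and true pixels: the intersection divided by the union.
* F1 is the harmonic mean of precision and recall for each class.

Results are averaged over all classes with `average='macro'`.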
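To make the macro averaging concrete, the same quantities can be computed from a confusion matrix with NumPy. This is a sketch for illustration, not how TorchMetrics implements them; the helper `macro_metrics` is hypothetical, and classes missing from both arrays are skipped here with `np.nanmean`, which may differ from the TorchMetrics defaults. The synthetic example shows how overall pixel accuracy (5/6) and macro-averaged accuracy (2/3) differ when a rare class is missed.

```python
import numpy as np


def macro_metrics(y_true, y_pred, num_classes):
    """Pixel accuracy and macro-averaged accuracy, IoU and F1 from a confusion matrix."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # Rows: true class, columns: predicted class
    np.add.at(cm, (y_true.ravel().astype(np.int64), y_pred.ravel().astype(np.int64)), 1)
    tp = np.diag(cm).astype(float)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = tp / (tp + fn)          # per-class accuracy
        iou = tp / (tp + fp + fn)        # intersection / union
        f1 = 2 * tp / (2 * tp + fp + fn)
    return {
        "pixel_accuracy": tp.sum() / cm.sum(),
        "macro_accuracy": np.nanmean(recall),
        "macro_iou": np.nanmean(iou),
        "macro_f1": np.nanmean(f1),
    }


y_true = np.array([0, 0, 0, 0, 1, 2])
y_pred = np.array([0, 0, 0, 0, 0, 2])  # the single class-1 pixel is missed
print(macro_metrics(y_true, y_pred, num_classes=3))
```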

In [None]:
# Open ground truth data for evaluation
with rasterio.open(labels_test) as src:
    ground_truth = src.read(1)

# Convert both datasets to Pytorch tensors
gt = torch.from_numpy(ground_truth).long()
prediction_tensor = torch.from_numpy(class_predictions).long()

In [None]:
acc = MulticlassAccuracy(num_classes=num_classes, average='macro')
iou = MulticlassJaccardIndex(num_classes=num_classes, average='macro')
f1  = MulticlassF1Score(num_classes=num_classes, average='macro')

# Calculate metrics from the predicted class map and the ground truth, both with shape (H, W)
accuracy = acc(prediction_tensor, gt)
iou_val = iou(prediction_tensor, gt)
f1_val = f1(prediction_tensor, gt)

# Print the macro-averaged accuracy, IoU and F1
print({
    "Mean class accuracy": accuracy.item(),
    "Mean IoU": iou_val.item(),
    "Mean F1": f1_val.item()
})

In [None]:
print('Classification report: \n', classification_report(ground_truth.reshape(-1), class_predictions.reshape(-1)))

In [None]:
plot = ConfusionMatrixDisplay.from_predictions(ground_truth.reshape(-1), class_predictions.reshape(-1), normalize='true', cmap=plt.cm.Blues)

In [None]:
plot.figure_.savefig(os.path.join(output_folder, f"{model_name}_confusion_matrix.png"))

### Plot results: input data raster, predicted classes and ground truth labels

In [None]:
# For better plotting of the Sentinel image, normalize the values.
# Helper function to stretch band values between the 2nd and 98th percentiles to enhance contrast,
# similar to what QGIS does automatically.
def normalize(array):
    min_percent = 2   # Low percentile
    max_percent = 98  # High percentile
    # Percentiles per band for an (H, W, C) array
    lo, hi = np.percentile(array, (min_percent, max_percent), axis=(0, 1), keepdims=True)
    new_min, new_max = 1, 255
    # Clip to the percentile range so that values do not overflow uint8
    clipped = np.clip(array, lo, hi)
    rgb_norm = (clipped - lo) / (hi - lo) * (new_max - new_min) + new_min
    return rgb_norm.astype(np.uint8)

In [None]:
# Open the raster
with rasterio.open(data_test) as src:
    # Get RGB channels of May data
    rgb = src.read([5, 3, 1])
    # Transpose to get (H, W, C)
    rgb = np.transpose(rgb, (1, 2, 0))
    rgb_norm = normalize(rgb)

    # Plot the 3 rasters
    fig, axes = plt.subplots(1, 3, figsize=(14, 20))

    # Set the colors for classified data; the fixed norm maps class ids 0-3 to the same colors in both maps,
    # even if some class is missing from one of them
    cmap = ListedColormap(["black", "forestgreen", "lightyellow", "lightblue"])
    norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5, 3.5], cmap.N)

    # Sentinel data
    axes[0].imshow(rgb_norm)
    axes[0].set_title("Input image")
    axes[0].axis("off")

    # Predicted classes
    axes[1].imshow(class_predictions, cmap=cmap, norm=norm)
    axes[1].set_title("Predicted classes")
    axes[1].axis("off")

    # Ground truth
    axes[2].imshow(ground_truth, cmap=cmap, norm=norm)
    axes[2].set_title("Ground truth")
    axes[2].axis("off")

    plt.savefig(f"{output_folder}/{model_name}_segmentation_results.png")

Note that some roads are visible in the `other` class of the ground truth, but with default settings they are not predicted as `other`. If needed, we could manually increase the `other` class scores so that the class is better represented in the final classification.
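A minimal sketch of that idea, assuming non-negative scores such as probabilities: multiply the scores of one class by a factor before taking the argmax. `boost_class` and the factor 1.5 are hypothetical; in this notebook the scores would come from `pixel_predictions.cpu().numpy()`, and the factor should be chosen on validation data rather than on the test set. If the scores can be negative (raw logits), adding a bias to the class scores is the safer variant.

```python
import numpy as np


def boost_class(scores, class_id, factor):
    """Class map after multiplying one class's scores by factor.

    scores: array with shape (num_classes, H, W), assumed non-negative.
    """
    boosted = scores.copy()          # leave the original scores untouched
    boosted[class_id] *= factor
    return boosted.argmax(axis=0)


# Two pixels, four classes (other, forest, fields, water); made-up scores
scores = np.array([[[0.35, 0.10]],
                   [[0.40, 0.80]],
                   [[0.15, 0.05]],
                   [[0.10, 0.05]]])
print(scores.argmax(axis=0))                           # [[1 1]]
print(boost_class(scores, class_id=0, factor=1.5))     # [[0 1]]
```

Only the first pixel, where `other` was a close second, changes class; the confidently predicted forest pixel is unaffected.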