In [None]:
import rasterio
from rasterio.windows import Window

import torch
import torch.nn.functional as F

import os
import os.path as path

import numpy as np

In [None]:
# Label rasters are read from <cwd>/data; relabeled copies are written to <cwd>/relabeled.
project_dir = os.getcwd()
label_dir = path.join(project_dir, "data")
output_dir = path.join(project_dir, "relabeled")

# Create the output directory up front so that listing it (below) and writing
# into it (later in the notebook) cannot fail on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

# When False, label files that already have a same-named file in output_dir are skipped.
overwrite_existing = False

# List the output directory once instead of once per candidate file.
existing_outputs = set() if overwrite_existing else set(os.listdir(output_dir))

# Full paths of all .tif label files that still need processing, in sorted order.
label_files = sorted(
    path.join(label_dir, f)
    for f in os.listdir(label_dir)
    if f.endswith(".tif") and f not in existing_outputs
)

In [None]:
# RGB color -> class index. The position of a color in the list is its class index.
COLOR_TO_CLASS = {
    rgb: class_idx
    for class_idx, rgb in enumerate([
        (255, 255, 255),  # 0: Impervious surfaces
        (0, 0, 255),      # 1: Building
        (255, 255, 0),    # 2: Car
        (0, 255, 255),    # 3: Low vegetation
        (0, 255, 0),      # 4: Tree
        (255, 0, 0),      # 5: Clutter/background
    ])
}
# Inverse lookup: class index -> RGB color.
CLASS_TO_COLOR = dict(zip(COLOR_TO_CLASS.values(), COLOR_TO_CLASS.keys()))
NUM_CLASSES = len(CLASS_TO_COLOR)

# Side length (in pixels) of the square window read from each label raster.
IMAGE_SIZE = 6000

In [None]:
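# NOTE: Original per-pixel (Python-loop) implementation, kept commented out for reference only;
# the notebook itself calls the vectorized functions defined in the next cell.
# def relabel_by_neighbourhood(label_patch, target_class):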
#     new_patch = label_patch.clone()
#     rows, cols = torch.where(torch.BoolTensor(label_patch == target_class))
#
#     H, W = label_patch.shape
#
#     neighbours = [
#         (-1, -1), (-1, 0), (-1, 1),
#         (0, -1),           (0, 1),
#         (1, -1),  (1, 0),  (1, 1)
#     ]
#
#     num_relabeled = 0
#     for r, c, in zip(rows.tolist(), cols.tolist()):
#         neighbour_labels = []
#         for dr, dc in neighbours:
#             rr, cc = r + dr, c + dc
#             if 0 <= rr < H and 0 <= cc < W:
#                 neighbour_label = label_patch[rr, cc].item()
#                 if neighbour_label != target_class:
#                     neighbour_labels.append(neighbour_label)
#
#         if neighbour_labels:
#             majority_label = max(set(neighbour_labels), key=neighbour_labels.count)
#             new_patch[r, c] = majority_label
#             num_relabeled += 1
#
#     return new_patch, num_relabeled
#
# def iterative_relabel_by_neighbourhood(label_patch, target_class, max_iterations=100):
#     new_patch = label_patch.clone()
#     for i in range(max_iterations):
#         if not torch.any(torch.BoolTensor(new_patch == target_class)):
#             print(f"Relabeling stopped after {i} iterations")
#             break
#         new_patch, num_relabeled = relabel_by_neighbourhood(new_patch, target_class)
#         # print(f"Iteration {i}: relabeled {num_relabeled} pixels")
#
#         if num_relabeled == 0:
#             print(f"No more pixels could be relabeled at iteration {i}. Stopping.")
#             break
#
#     return new_patch

In [None]:
def relabel_by_neighbourhood_vectorized(label_patch, target_class, num_classes):
    """Relabel ``target_class`` pixels with the majority class of their 8 neighbours.

    Runs a single pass: every decision is based on the labels in ``label_patch``
    as passed in, so pixels relabeled in this pass do not influence each other.
    Neighbours that themselves carry ``target_class`` do not vote. A target pixel
    whose neighbourhood contains only target pixels (or lies outside the image)
    is left unchanged.

    Args:
        label_patch: 2-D (H, W) tensor of class indices. Any dtype that can be
            cast to int64 is accepted; values must be valid for ``F.one_hot``,
            i.e. in ``[0, num_classes)``.
        target_class: Class index whose pixels should be replaced.
        num_classes: Total number of classes.

    Returns:
        Tuple ``(new_patch, num_relabeled)``: a new tensor with the same shape,
        dtype and device as ``label_patch`` (the input is not modified), and the
        number of pixels that were relabeled as a Python int.
    """
    # F.one_hot only accepts int64 index tensors.
    labels = label_patch.long()

    # Transform H, W into 1, C, H, W with one channel per class to obtain a stack of binary masks
    one_hot = F.one_hot(labels, num_classes=num_classes).permute(2, 0, 1).unsqueeze(0).float()
    one_hot[:, target_class, :, :] = 0  # ignore neighbours with target_class label

    # 3x3 kernel with center = 0, one per class (depthwise convolution).
    # Allocated on the same device/dtype as the data so non-CPU inputs work too.
    kernel = torch.ones((num_classes, 1, 3, 3), dtype=one_hot.dtype, device=one_hot.device)
    kernel[:, 0, 1, 1] = 0

    # Per-class neighbour counts, shape C, H, W (zero padding at the image border).
    neighbor_counts = F.conv2d(one_hot, kernel, padding=1, groups=num_classes)[0]

    # Majority vote based on neighbouring classes
    majority_classes = neighbor_counts.argmax(dim=0)

    # Only target pixels that have at least one non-target neighbour are relabeled.
    mask = (neighbor_counts.sum(dim=0) > 0) & (labels == target_class)

    # torch.where builds a fresh tensor, leaving the input untouched and keeping its dtype.
    new_patch = torch.where(mask, majority_classes.to(label_patch.dtype), label_patch)
    num_relabeled = int(mask.sum().item())

    return new_patch, num_relabeled

def iterative_relabel_by_neighbourhood_vectorized(label_patch, target_class, num_classes, max_iterations=100):
    """Repeatedly relabel ``target_class`` pixels until none remain or no progress is made.

    Each iteration calls ``relabel_by_neighbourhood_vectorized`` on the result of
    the previous one. The loop stops when no ``target_class`` pixel is left, when
    an iteration relabels zero pixels, or after ``max_iterations`` iterations.
    The first two stop conditions print a short status message.

    Args:
        label_patch: Tensor of class indices; it is cloned, not modified.
        target_class: Class index to eliminate.
        num_classes: Total number of classes, forwarded to the single-pass helper.
        max_iterations: Upper bound on the number of relabeling passes.

    Returns:
        The relabeled tensor. It may still contain ``target_class`` pixels if the
        iteration limit was hit or no further pixel could be relabeled.
    """
    new_patch = label_patch.clone()
    for i in range(max_iterations):
        # The comparison already yields a bool tensor; reduce it directly rather than
        # round-tripping through the legacy torch.BoolTensor constructor.
        if not (new_patch == target_class).any().item():
            print(f"Relabeling stopped after {i} iterations")
            break
        new_patch, num_relabeled = relabel_by_neighbourhood_vectorized(new_patch, target_class, num_classes)
        # print(f"Iteration {i}: relabeled {num_relabeled} pixels")

        # No progress means the remaining target pixels have no non-target neighbours.
        if num_relabeled == 0:
            print(f"No more pixels could be relabeled at iteration {i}. Stopping.")
            break

    return new_patch

In [None]:
# Class index whose pixels are relabeled from their neighbourhood; 2 is "Car" in COLOR_TO_CLASS.
TARGET_CLASS = 2

In [None]:
# For every pending label raster: decode RGB -> class indices, relabel TARGET_CLASS
# pixels from their neighbourhood, encode back to RGB and write the result to output_dir.
for label_file_path in label_files:
    # Load the label image. The context manager guarantees the dataset handle is
    # closed on every path, including the early `continue` below.
    with rasterio.open(label_file_path) as label_file:
        label_rgb = label_file.read(window=Window(0, 0, IMAGE_SIZE, IMAGE_SIZE))
        # Copy the metadata while the dataset is open; it is reused for the output file.
        meta_rgb = label_file.meta.copy()
    # (bands, H, W) -> (H, W, bands)
    label_rgb = np.transpose(label_rgb, (1, 2, 0))

    # Ensure that no unknown colors are present
    unique_colors = {tuple(color) for color in np.unique(label_rgb.reshape(-1, 3), axis=0)}
    unknown_colors = unique_colors - set(COLOR_TO_CLASS.keys())
    if unknown_colors:
        print(f"Unknown colors found in {label_file_path}:\n{[(int(r), int(g), int(b)) for r,g,b in unknown_colors]}")
        continue

    # Map rgb to class indices
    class_mask = np.zeros((label_rgb.shape[0], label_rgb.shape[1]), dtype=np.int64)
    for color, cls_idx in COLOR_TO_CLASS.items():
        class_mask[np.all(label_rgb == np.array(color), axis=-1)] = cls_idx

    label_patch = torch.from_numpy(class_mask).long()

    # Relabeling
    relabeled_patch = iterative_relabel_by_neighbourhood_vectorized(label_patch, TARGET_CLASS, NUM_CLASSES)

    # Map class indices back to rgb. Convert to numpy once so the boolean masks
    # used for indexing are numpy arrays rather than torch tensors.
    relabeled_classes = relabeled_patch.cpu().numpy()
    H, W = relabeled_classes.shape
    relabeled_rgb = np.zeros((H, W, 3), dtype=np.uint8)
    for class_index, rgb in CLASS_TO_COLOR.items():
        relabeled_rgb[relabeled_classes == class_index] = np.array(rgb, dtype=np.uint8)

    # Save relabeled file, reusing the source metadata with the output's band count,
    # dtype and dimensions.
    meta_rgb.update({
        "count": 3,
        "dtype": "uint8",
        "height": H,
        "width": W
    })

    output_file_path = path.join(output_dir, path.basename(label_file_path))
    with rasterio.open(output_file_path, "w", **meta_rgb) as f:
        # rasterio expects (bands, H, W)
        f.write(np.transpose(relabeled_rgb, (2, 0, 1)))

    print(f"Saved relabeled file to {output_file_path}")