In [1]:
from PIL import Image, ImageColor
import numpy as np
import rasterio
import os, glob
import shutil

In [2]:
DATASET_DIR = "dataset"

## Setup utility functions

In [3]:
def preprocess_images(images: list[Image.Image]) -> list[np.ndarray]:
    # Convert images to numpy arrays
    color_arrays = [np.array(image) for image in images]

    # Ignore any images that contain missing chunks
    for color_array in color_arrays:
        if np.isnan(color_array).any() or np.mean(color_array) < 200.0:
            return None

    # Normalize pixel values to range [1.0, 0.0]
    def normalize_color(array):
        max = np.max(array)
        min = np.min(array)
        return (array - min) / (max - min)

    return [normalize_color(color_array) for color_array in color_arrays]

In [4]:
def process_image(
        image_1: Image.Image,
        image_2: Image.Image,
        image_3: Image.Image
        ) -> Image.Image:
    color_arrays = preprocess_images([image_1, image_2, image_3])
    if color_arrays == None:
        return None

    colors = [Image.fromarray((color_array * 255.0).astype(np.uint8)) for color_array in color_arrays]

    # Combine the three images into a single RGB image
    return Image.merge('RGB', colors)

In [5]:
def process_ndvi_image(
        image_B04: Image.Image,
        image_B08: Image.Image
        ) -> Image.Image:
    color_arrays = preprocess_images([image_B04, image_B08])
    if color_arrays == None:
        return None

    # https://custom-scripts.sentinel-hub.com/sentinel-2/ndvi/
    def ndvi_to_color(ndvi: float):
        match ndvi:
            case _ if ndvi <  -0.5:   return ImageColor.getcolor("#0c0c0c", "RGB")
            case _ if ndvi <= -0.2:   return ImageColor.getcolor("#bfbfbf", "RGB")
            case _ if ndvi <= -0.1:   return ImageColor.getcolor("#dbdbdb", "RGB")
            case _ if ndvi <=  0.0:   return ImageColor.getcolor("#eaeaea", "RGB")
            case _ if ndvi <=  0.025: return ImageColor.getcolor("#fff9cc", "RGB")
            case _ if ndvi <=  0.05:  return ImageColor.getcolor("#ede8b5", "RGB")
            case _ if ndvi <=  0.075: return ImageColor.getcolor("#ddd89b", "RGB")
            case _ if ndvi <=  0.1:   return ImageColor.getcolor("#ccc682", "RGB")
            case _ if ndvi <=  0.125: return ImageColor.getcolor("#bcb76b", "RGB")
            case _ if ndvi <=  0.15:  return ImageColor.getcolor("#afc160", "RGB")
            case _ if ndvi <=  0.175: return ImageColor.getcolor("#a3cc59", "RGB")
            case _ if ndvi <=  0.2:   return ImageColor.getcolor("#91bf51", "RGB")
            case _ if ndvi <=  0.25:  return ImageColor.getcolor("#7fb247", "RGB")
            case _ if ndvi <=  0.3:   return ImageColor.getcolor("#70a33f", "RGB")
            case _ if ndvi <=  0.35:  return ImageColor.getcolor("#609635", "RGB")
            case _ if ndvi <=  0.4:   return ImageColor.getcolor("#4f892d", "RGB")
            case _ if ndvi <=  0.45:  return ImageColor.getcolor("#3f7c23", "RGB")
            case _ if ndvi <=  0.5:   return ImageColor.getcolor("#306d1c", "RGB")
            case _ if ndvi <=  0.55:  return ImageColor.getcolor("#216011", "RGB")
            case _ if ndvi <=  0.6:   return ImageColor.getcolor("#0f540a", "RGB")
            case _ if ndvi <=  1.0:   return ImageColor.getcolor("#004400", "RGB")
            case _: raise
    
    B04 = color_arrays[0]
    B08 = color_arrays[1]
    ndvi = (B08 - B04) / (B08 + B04)

    # When both B04 and B08 are zero, we get division by 0 and NaNs, fix that
    ndvi = np.nan_to_num(ndvi, nan=-1.0)
    
    color_arrays = np.vectorize(ndvi_to_color)(ndvi)
    colors = [Image.fromarray(color_array.astype(np.uint8)) for color_array in color_arrays]

    return Image.merge('RGB', colors)

In [6]:
def convert_geotiff_to_tiff(path: str) -> Image.Image:
    image_classification = rasterio.open(path)
    # Layer 1 is for labels, layer 2 are the probabilities
    image_classification_array = image_classification.read(1)

    image_classification.close()

    return Image.fromarray(image_classification_array)

In [7]:
def generate_mask_visualization(mask: Image.Image) -> Image.Image:
    mask_array = np.array(mask)

    class_colors = [
        ImageColor.getcolor("#000000", "RGB"), # Unknown/Clouds
        ImageColor.getcolor("#0000ff", "RGB"), # Water (Permanent)
        ImageColor.getcolor("#888888", "RGB"), # Artificial Bare Ground
        ImageColor.getcolor("#d1a46d", "RGB"), # Natural Bare Ground
        ImageColor.getcolor("#f5f5ff", "RGB"), # Snow/Ice (Permanent)
        ImageColor.getcolor("#d64c2b", "RGB"), # Woody
        ImageColor.getcolor("#186818", "RGB"), # Non-Woody Cultivated
        ImageColor.getcolor("#00ff00", "RGB"), # Non-Woody (Semi) Natural
    ]

    image_out = Image.new("RGB", (255, 255), (0, 0, 0))
    pixels = image_out.load()

    for y in range(255):
        for x in range(255):
            class_index = mask_array[y, x]
            pixels[x, y] = class_colors[class_index]
    
    return image_out

In [8]:
def get_cloud_coverage(image_coverage) -> float:
    coverage_array = np.array(image_coverage)
    return np.mean(coverage_array)

## Preprocess dataset

In [9]:
for tile in glob.glob("*", root_dir = DATASET_DIR):
    tile_path = os.path.join(DATASET_DIR, tile)
    for chip in glob.glob("*", root_dir = tile_path):
        chip_path = os.path.join(tile_path, chip)
        S2_path = os.path.join(chip_path, "S2")

        tc_path = os.path.join(chip_path, "TrueColor")
        fc_path = os.path.join(chip_path, "FalseColor")
        ndvi_path = os.path.join(chip_path, "NDVI")
        swir_path =  os.path.join(chip_path, "SWIR")

        if not os.path.exists(tc_path):
            os.mkdir(tc_path)
        if not os.path.exists(fc_path):
            os.mkdir(fc_path)
        if not os.path.exists(ndvi_path):
            os.mkdir(ndvi_path)
        if not os.path.exists(swir_path):
            os.mkdir(swir_path)

        class_mask = convert_geotiff_to_tiff(os.path.join(chip_path, f"{tile}_{chip}_2018_LC_10m.tif"))
        class_mask.save(os.path.join(chip_path, f"{tile}_{chip}_2018_MASK.tif"))

        mask_visualization = generate_mask_visualization(class_mask)
        mask_visualization.save(os.path.join(chip_path, f"{tile}_{chip}_2018_MASK_VISUAL.png"))

        for instance in glob.glob("*", root_dir = S2_path):
            instance_path = os.path.join(S2_path, instance)

            image_coverage = Image.open(os.path.join(instance_path, f"{instance}_CLD_10m.tif"),)
            cloud_coverage = get_cloud_coverage(image_coverage)
            image_coverage.close()

            if cloud_coverage > 0.5:
                continue

            image_B02 = Image.open(os.path.join(instance_path, f"{instance}_B02_10m.tif"))
            image_B03 = Image.open(os.path.join(instance_path, f"{instance}_B03_10m.tif"))
            image_B04 = Image.open(os.path.join(instance_path, f"{instance}_B04_10m.tif"))
            image_B08 = Image.open(os.path.join(instance_path, f"{instance}_B08_10m.tif"))
            image_B8A = Image.open(os.path.join(instance_path, f"{instance}_B8A_10m.tif"))
            image_B12 = Image.open(os.path.join(instance_path, f"{instance}_B12_10m.tif"))

            # True Color
            tc_image = process_image(image_B04, image_B03, image_B02)
            if tc_image != None:
                tc_image.save(os.path.join(tc_path, f"{instance}_TC.png"))
            
            # False Color
            fc_image = process_image(image_B08, image_B04, image_B03)
            if fc_image != None:
                fc_image.save(os.path.join(fc_path, f"{instance}_FC.png"))
            
            # SWIR
            swir_image = process_image(image_B12, image_B8A, image_B04)
            if swir_image != None:
                swir_image.save(os.path.join(swir_path, f"{instance}_SWIR.png"))
            
            # NDVI
            ndvi_image = process_ndvi_image(image_B04, image_B08)
            if ndvi_image != None:
                ndvi_image.save(os.path.join(ndvi_path, f"{instance}_NDVI.png"))
            
            image_B02.close()
            image_B03.close()
            image_B04.close()
            image_B08.close()
            image_B8A.close()
            image_B12.close()


  ndvi = (B08 - B04) / (B08 + B04)


## Cleanup remaining files

In [53]:
for dirpath, dirnames, filenames in os.walk(DATASET_DIR, topdown=False):
    for filename in filenames:
        if filename.endswith("2018_LC_10m.tif"):
            os.remove(os.path.join(dirpath, filename))
        if filename.endswith("labeling_dates.csv"):
            os.remove(os.path.join(dirpath, filename))

    for dirname in dirnames:
        if dirname in ["L8", "S1", "S2"]:
            dir_to_delete = os.path.join(dirpath, dirname)
            shutil.rmtree(dir_to_delete)
