# Data Visualization and Exploration

Simple notebook to visualize data

In [1]:
import os
from pathlib import Path

In [2]:
# change working directory to the root of the project
cwd = Path.cwd()
if cwd.name == "notebooks":
    os.chdir("..")

In [27]:
import numpy as np
import rasterio as rio
import shutil
from matplotlib import pyplot as plt
from PIL import Image
from tqdm import tqdm
from baseg.datasets import EMSImageDataset

In [4]:
targets = {
    "S2L2A": "image",
    "DEL": "mask",
    "CM": "mask",
    "GRA": "mask",
    "ESA_LC": "mask",
}

In [5]:
def mask2rgb(image: np.ndarray, palette: dict) -> np.ndarray:
    lut = np.zeros((256, 3), dtype=np.uint8)
    for k, v in palette.items():
        lut[k, :] = v
    return lut[image]

In [6]:
def create_mask(image, palette):
    # Load the RGB image
    image_array = np.array(image)
    # Create an empty mask array
    mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)

    # Iterate over each pixel in the image
    for i in range(image_array.shape[0]):
        for j in range(image_array.shape[1]):
            # Find the closest color in the palette
            pixel_color = tuple(image_array[i, j])
            closest_color = min(palette, key=lambda x: np.linalg.norm(np.array(x) - np.array(pixel_color)))
            
            # Assign the class index to the mask
            mask[i, j] = palette.index(closest_color)

    return mask

In [14]:
root_path = Path("data/ems")
images = list(root_path.glob("**/*S2L2A.tif"))
len(images)

560

In [15]:
def read_image(path: Path, bands: list = None, return_profile: bool = False) -> np.ndarray:
    """Read a raster image from disk."""
    with rio.open(path) as dataset:
        if bands is None:
            bands = dataset.indexes
        image = dataset.read(bands)
        if return_profile:
            return image, dataset.profile
        return image

In [16]:
def write_image(path: Path, image: np.ndarray, profile: dict) -> None:
    with rio.open(path, "w", **profile) as dst:
        dst.write(image)

In [23]:
def create_mask(image, palette: dict):
    # Load the RGB image
    image_array = np.array(image)
    # Create an empty mask array
    mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
    inv_palette = {v: k for k, v in palette.items()}

    # Iterate over each pixel in the image
    for i in range(image_array.shape[0]):
        for j in range(image_array.shape[1]):
            # Find the closest color in the palette
            pixel_color = tuple(image_array[i, j])
            mask_category = inv_palette[pixel_color]
            
            # Assign the class index to the mask
            mask[i, j] = mask_category

    return np.expand_dims(mask, axis=0)

In [29]:
indices = list(EMSImageDataset.palette.keys())
indices = set(indices)
indices

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 255}

In [36]:
def process_image(image_path, clean=True):
    image, profile = read_image(image_path, return_profile=True)
    png_path = image_path.parent / image_path.name.replace("S2L2A.tif", "ESA_LC.png")
    assert png_path.exists()

    # Load and resize the mask
    png_mask = Image.open(png_path)
    png_mask = png_mask.resize(image.shape[1:], resample=Image.NEAREST)
    mask = np.asarray(png_mask).astype(np.uint8)
    mask = create_mask(png_mask, EMSImageDataset.palette)

    mask_path = image_path.parent / png_path.name.replace(".png", ".tif")
    if mask_path.exists() and clean:
        old_path = mask_path.parent / (mask_path.name.replace(".tif", "_old.tif"))
        shutil.move(mask_path, old_path)
    profile.update(dtype=np.uint8, count=1)
    write_image(mask_path, mask, profile)

In [37]:
import multiprocessing
from functools import partial

clean = True
num_processes = multiprocessing.cpu_count() // 4 # Number of parallel processes
pool = multiprocessing.Pool(processes=num_processes)
parallel_fn = partial(process_image, clean=clean)

# Process images in parallel using the pool
with tqdm(total=len(images)) as pbar:
    for _ in pool.imap_unordered(parallel_fn, images):
        pbar.update(1)

# Close the pool to free resources
pool.close()
pool.join()

100%|██████████| 560/560 [03:43<00:00,  2.51it/s]


In [50]:
output_path = Path("data/ems/plots")

# iterate images, gather the equivalent mask and plot it
for image_path in tqdm(images):
    image, profile = read_image(image_path, bands=(4, 3, 2), return_profile=True)
    mask_path = image_path.parent / image_path.name.replace("S2L2A.tif", "ESA_LC.tif")
    del_path = image_path.parent / image_path.name.replace("S2L2A.tif", "DEL.tif")
    # read and/or transform to obtain an rgb
    image = np.clip(image.transpose(1, 2, 0) * 3, 0, 1)
    mask = read_image(mask_path)
    mask = mask2rgb(mask[0], EMSImageDataset.palette)
    delineation = mask2rgb(read_image(del_path)[0], EMSImageDataset.palette)
    assert mask.shape == image.shape == delineation.shape

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(image)
    ax[1].imshow(mask)
    ax[2].imshow(delineation)
    # save the plot to disk
    output_path.mkdir(exist_ok=True)
    plt.savefig(output_path / f"{image_path.stem}.png")
    plt.close("all")

100%|██████████| 560/560 [10:18<00:00,  1.10s/it]
