<a href="https://colab.research.google.com/github/agroimpacts/adleo/blob/main/assignments/assignment5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super-Resolution Visualization


## Summary
The purpose of this notebook is to compare the outputs of our super-resolution model (adapted from WorldStrat) to the original Sentinel-2 input imagery, as well as to PlanetScope high-resolution imagery.

Requirements: You must add a shortcut to our [Google Drive](https://drive.google.com/drive/u/0/folders/1poQVjxeLIgITe0vYxI51rtVrJH9nQrVP) to your own Drive. The file path should be as follows:  
`/content/drive/MyDrive/SuperResolution12RV2/`

Note: Due to memory limitations of the Colab environment, the results in this notebook are clipped to the smaller extent of the Planet imagery for comparison. To download the full images, please visit our [Google Drive](https://drive.google.com/drive/u/0/folders/1poQVjxeLIgITe0vYxI51rtVrJH9nQrVP) or download the files directly:
____
Site 0:  
Input Sentinel 2 Image  
[Super-Resolution Output (Per-Band Normalization)](https://drive.google.com/file/d/1cVEdAeIEw2C4fLNy8QAfdlOIpOCFM8ha/view)  
[Super-Resolution Output (Cross-Band Normalization)]()

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%pip install rasterio leafmap localtileserver

In [3]:
import os
import tempfile
import subprocess
import numpy as np
import rasterio
from rasterio.windows import from_bounds
from rasterio.windows import Window
from rasterio.enums import Resampling
from rasterio.transform import from_origin
import leafmap

import logging
logging.getLogger("rasterio._env").setLevel(logging.ERROR)

### Loading TIFFs for visualization (do not run unless new files need to be visualized)
These steps only need to be executed once for each file. They build pyramids (overviews) and convert the images to Cloud-Optimized GeoTIFFs (COGs) for display with Leafmap.
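
The factors `2 4 8 16 32` passed to `gdaladdo` in the cells below each request a reduced-resolution copy of the raster. As a rough sketch of what that means for image size, the snippet below computes the pixel dimensions of each level. It assumes overview dimensions are rounded up, and the 10980 x 10980 raster size is hypothetical, chosen only for illustration.

```python
import math

def overview_sizes(width, height, factors=(2, 4, 8, 16, 32)):
    """Return {factor: (width, height)} for each overview level, rounding up."""
    return {f: (math.ceil(width / f), math.ceil(height / f)) for f in factors}

# Hypothetical full-resolution size (not taken from the project files)
sizes = overview_sizes(10980, 10980)
for factor, (w, h) in sizes.items():
    print(f"1/{factor}: {w} x {h} pixels")
```

Each level holds roughly a quarter of the pixels of the previous one, which is why a map viewer can draw zoomed-out views without reading the full-resolution data.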

In [None]:
# # Add gdal for next steps
# !apt-get update && apt-get install -y gdal-bin

In [None]:
# # Build pyramids / tiling for display with Leafmap (only once per file lifetime)
# !gdaladdo -r nearest /content/drive/MyDrive/SuperResolution12RV2/SuperResolutionInference/merged_raster_COG.tif 2 4 8 16 32



In [None]:
# # Confirming tiling pyramids/overview
# !gdalinfo /content/drive/MyDrive/SuperResolution12RV2/SuperResolutionInference/merged_raster_COG.tif

In [None]:
# # Confirming tiling pyramids/overview
# !gdalinfo /content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site0_Planet_COG.tif

In [None]:
# Convert Planet images into COG (run only once)

# planet_paths = [
#     '/content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site0_Planet.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site1_Planet.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site2_Planet.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site3_Planet.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site4_Planet.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site5_Planet.tif'
# ]

# for input_path in planet_paths:
#     output_path = input_path.replace(".tif", "_COG.tif")

#     cmd = [
#         "rio", "cogeo", "create",
#         input_path,
#         output_path,
#         "--overview-resampling", "nearest",
#         "--co", "BLOCKSIZE=512",
#         "--co", "COMPRESS=DEFLATE",
#         "--co", "TILED=YES",
#         "--nodata", "0"
#     ]

#     subprocess.run(cmd, check=True)

# Convert Sentinel 2 images into COG (run only once)
# site0_paths = [
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_0.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_1.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_2.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_3.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_4.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_5.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_6.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_7.tif',
# ]

# for input_path in site0_paths:
#     output_path = input_path.replace(".tif", "_COG.tif")

#     cmd = [
#         "rio", "cogeo", "create",
#         input_path,
#         output_path,
#         "--overview-resampling", "nearest",
#         "--co", "BLOCKSIZE=512",
#         "--co", "COMPRESS=DEFLATE",
#         "--co", "TILED=YES",
#         "--nodata", "0"  # Optional: adjust nodata if needed
#     ]

#     print(f"Converting to COG: {os.path.basename(input_path)}")
#     subprocess.run(cmd, check=True)


# cog_paths = [
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_0_COG.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_1_COG.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_2_COG.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_3_COG.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_4_COG.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_5_COG.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_6_COG.tif',
#     '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_7_COG.tif',
# ]

# pyramid_levels = ["2", "4", "8", "16", "32"]

# for path in cog_paths:
#     if os.path.exists(path):
#         try:
#             subprocess.run(["gdaladdo", "-r", "nearest", path] + pyramid_levels, check=True)
#             print(f"Pyramids built for: {os.path.basename(path)}")
#         except subprocess.CalledProcessError as e:
#             print(f"Error building pyramids for {path}: {e}")
#     else:
#         print(f"File not found: {path}")

### Rescaling of imagery for visualization
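
The rescaling functions in the next cell apply a linear min-max stretch to each band: values at or below the band minimum map to 0, values at or above the band maximum map to 255, and everything in between is scaled linearly. A minimal sketch of that stretch on a handful of made-up pixel values is shown here; `stretch_to_uint8` is a hypothetical helper name, and 482 and 1326 are the Planet red-band limits used later in this notebook.

```python
import numpy as np

def stretch_to_uint8(data, bmin, bmax):
    """Linearly map [bmin, bmax] to [0, 255], clipping values outside the range."""
    scaled = np.clip((data.astype(np.float32) - bmin) / (bmax - bmin), 0, 1) * 255
    return np.round(scaled).astype(np.uint8)

# Illustrative pixel values only (not read from any image)
pixels = np.array([400, 482, 693, 1326, 2000])
stretched = stretch_to_uint8(pixels, bmin=482, bmax=1326)
print(stretched)
```

Note that the stretch divides by `bmax - bmin`, so the two limits must differ for every band.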

In [4]:
def composite_rescale(input_paths, output_path, bands=(4,3,2), band_mins=(0,0,0), band_maxs=(1,1,1), block_size=512):
    """
    Rescale the selected bands of several rasters to the range [0, 255],
    average the rescaled values pixel by pixel, and save the result as a
    new 3-band uint8 image.

    Args:
        input_paths (list of str): Input rasters; all are read with the same
            pixel windows, so they must share the grid of the first raster
        output_path (str): Output raster
        bands (tuple): Bands to be rescaled
        band_mins (tuple): Specified minimum values for each band
        band_maxs (tuple): Specified maximum values for each band
        block_size (int): Size of the blocks to process
    """
    with rasterio.open(input_paths[0]) as src0:
        width, height = src0.width, src0.height
        profile = src0.profile.copy()
        transform = src0.transform

    profile.update({
        "count": 3,
        "dtype": "uint8",
        "compress": "deflate",
        "tiled": True,
        "blockxsize": block_size,
        "blockysize": block_size
    })

    with rasterio.open(output_path, "w", **profile) as dst:
        for y in range(0, height, block_size):
            for x in range(0, width, block_size):
                win = Window(x, y, min(block_size, width - x), min(block_size, height - y))
                win_shape = (win.height, win.width)

                # store sum of pixel values
                stacked_sum = np.zeros((3, *win_shape), dtype=np.float32)

                for path in input_paths:
                    with rasterio.open(path) as src:
                        for i, (band_idx, bmin, bmax) in enumerate(zip(bands, band_mins, band_maxs)):
                            data = src.read(band_idx, window=win).astype(np.float32)
                            data = np.clip((data - bmin) / (bmax - bmin), 0, 1) * 255
                            stacked_sum[i] += data

                averaged = stacked_sum / len(input_paths)
                averaged = np.round(averaged).astype(np.uint8)

                dst.write(averaged, window=win)

def rescale(input_path, output_path, bands=(1,2,3), band_mins=(0,0,0), band_maxs=(1,1,1), block_size=512):
    """
    Rescale the bands of a raster image to the range [0, 255] for visualization,
    save as a new image.

    Args:
        input_path (str): Input raster
        output_path (str): Output raster
        bands (tuple): Bands to be rescaled
        band_mins (tuple): Specified minimum values for each band
        band_maxs (tuple): Specified maximum values for each band
        block_size (int): Size of the blocks to process
    """
    with rasterio.open(input_path) as src:
        width, height = src.width, src.height
        profile = src.profile.copy()
        profile.update({
            "count": 3,
            "dtype": "uint8",
            "compress": "deflate",
            "tiled": True,
            "blockxsize": block_size,
            "blockysize": block_size
        })

        with rasterio.open(output_path, 'w', **profile) as dst:
            for y in range(0, height, block_size):
                for x in range(0, width, block_size):
                    win = Window(x, y, min(block_size, width - x), min(block_size, height - y))
                    win_shape = (win.height, win.width)
                    rgb_scaled = np.zeros((3, *win_shape), dtype=np.uint8)

                    for i, (band_idx, bmin, bmax) in enumerate(zip(bands, band_mins, band_maxs)):
                        data = src.read(band_idx, window=win).astype(np.float32)
                        data = np.clip((data - bmin) / (bmax - bmin), 0, 1) * 255
                        rgb_scaled[i] = np.round(data).astype(np.uint8)

                    dst.write(rgb_scaled, window=win)


def clip(input_path, output_path, bounds, block_size=512):
    """
    Clip the input raster image to bounds and save as a new file.

    Args:
        input_path (str): Input raster
        output_path (str): Output raster
        bounds (tuple): Bounding box in the form (minx, miny, maxx, maxy)
        block_size (int): Size of the blocks to process
    """
    with rasterio.open(input_path) as src:
        window = from_bounds(*bounds, transform=src.transform)
        window = window.round_offsets().round_lengths()
        transform = rasterio.windows.transform(window, src.transform)
        profile = src.profile.copy()
        profile.update({
            "height": int(window.height),
            "width": int(window.width),
            "transform": transform,
            "compress": "deflate",
            "tiled": True,
            "blockxsize": block_size,
            "blockysize": block_size
        })

        with rasterio.open(output_path, 'w', **profile) as dst:
            # Cast to int so that range() and the Window sizes below receive integers
            col_off, row_off = int(window.col_off), int(window.row_off)
            out_width, out_height = int(window.width), int(window.height)
            for band_idx in range(1, src.count + 1):
                for y in range(0, out_height, block_size):
                    for x in range(0, out_width, block_size):
                        win = Window(x + col_off, y + row_off,
                                     min(block_size, out_width - x),
                                     min(block_size, out_height - y))
                        data = src.read(band_idx, window=win)
                        dst.write(data, window=Window(x, y, data.shape[1], data.shape[0]), indexes=band_idx)

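All three functions above process the raster in blocks rather than loading it whole, which keeps memory use bounded on Colab. The block layout can be sketched without any raster at all. `iter_blocks` is a hypothetical helper that mirrors the nested `for y` / `for x` loops, and the 1000 x 600 size is made up for illustration.

```python
def iter_blocks(width, height, block_size=512):
    """Yield (col_off, row_off, block_width, block_height), row by row."""
    for y in range(0, height, block_size):
        for x in range(0, width, block_size):
            # Edge blocks are trimmed so they never extend past the raster
            yield x, y, min(block_size, width - x), min(block_size, height - y)

# Hypothetical 1000 x 600 raster
blocks = list(iter_blocks(1000, 600))
for block in blocks:
    print(block)
```

The trimmed edge blocks together with the full blocks cover every pixel exactly once.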

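`composite_rescale` differs from `rescale` in that it averages the rescaled values of several images into one composite. A minimal sketch of that averaging step on tiny made-up arrays follows; `mean_composite` is a hypothetical name, and like the function above it applies no nodata or cloud masking.

```python
import numpy as np

def mean_composite(scenes):
    """Average equally shaped arrays pixel by pixel and round to uint8."""
    total = np.zeros(scenes[0].shape, dtype=np.float32)
    for scene in scenes:
        total += scene
    return np.round(total / len(scenes)).astype(np.uint8)

# Three made-up 2 x 2 "scenes" that are already stretched to the 0-255 range
scenes = [np.full((2, 2), value, dtype=np.float32) for value in (100, 120, 200)]
composite = mean_composite(scenes)
print(composite)
```

Accumulating in `float32` before rounding avoids the overflow that would occur if the `uint8` values were summed directly.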
In [9]:
# This could take 10 minutes or more given the size of the images
site0_paths = [
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_0_COG.tif',
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_1_COG.tif',
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_2_COG.tif',
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_3_COG.tif',
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_4_COG.tif',
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_5_COG.tif',
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_6_COG.tif',
    '/content/drive/MyDrive/SuperResolution12RV2/Site_0_Image_7_COG.tif',
]

planet_site0_path = '/content/drive/MyDrive/SuperResolution12RV2/PlanetImagery/Site0_Planet_COG.tif'
superres_path = '/content/drive/MyDrive/SuperResolution12RV2/SuperResolutionInference/merged_raster_COG.tif'

# temp paths to hold rescaled outputs
temp_planet = os.path.join(tempfile.gettempdir(), "planet_rgb_scaled.tif")
temp_sentinel = os.path.join(tempfile.gettempdir(), "sentinel2_rgb_composite.tif")
temp_superres = os.path.join(tempfile.gettempdir(), "site0_cog_rgb_scaled.tif")

# temp paths to hold clipped outputs
clipped_planet = os.path.join(tempfile.gettempdir(), "planet_rgb_clipped.tif")
clipped_sentinel = os.path.join(tempfile.gettempdir(), "sentinel2_rgb_clipped.tif")
clipped_superres = os.path.join(tempfile.gettempdir(), "superres_rgb_clipped.tif")

# Rescale to 8-bit RGB (note: the band min/max values were taken directly from QGIS for visualization)
rescale(planet_site0_path, temp_planet, bands=(3,2,1), band_mins=(482,454,285), band_maxs=(1326,971,683))
composite_rescale(site0_paths, temp_sentinel, bands=(4,3,2), band_mins=(530,545,435), band_maxs=(1814,1376,1090))
rescale(superres_path, temp_superres, bands=(1,2,3), band_mins=(219,203,425), band_maxs=(663,941,1322))

# clip
with rasterio.open(temp_planet) as src:
    planet_bounds = src.bounds

clip(temp_sentinel, clipped_sentinel, planet_bounds)
clip(temp_superres, clipped_superres, planet_bounds)

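The clip step above turns the Planet bounding box into a pixel window on the larger rasters. The arithmetic behind that conversion can be sketched as follows, assuming a north-up raster with square pixels. `bounds_to_window` is a hypothetical helper, all coordinates are made up, and rounding to the nearest pixel is an assumption that may differ from rasterio's own rounding.

```python
def bounds_to_window(bounds, left, top, pixel_size):
    """Convert (minx, miny, maxx, maxy) to (col_off, row_off, width, height) in pixels."""
    minx, miny, maxx, maxy = bounds
    col_off = round((minx - left) / pixel_size)   # columns count eastward from the left edge
    row_off = round((top - maxy) / pixel_size)    # rows count downward from the top edge
    width = round((maxx - minx) / pixel_size)
    height = round((maxy - miny) / pixel_size)
    return col_off, row_off, width, height

# Hypothetical 10 m raster whose upper-left corner is at (500000, 4000000)
window = bounds_to_window((500100, 3999400, 500500, 3999900),
                          left=500000, top=4000000, pixel_size=10)
print(window)
```

Row offsets are measured from the top edge because raster rows increase downward while map y coordinates increase upward.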
# Visualizing and Comparing Super-Resolution (Normalized Per-Band)

### Comparison against Sentinel-2 Imagery (10 m native resolution)
Visualizing our first test site (Site 0) against Sentinel-2 imagery
___
Instructions: To toggle between layers, click the Options icon at the top right, then click the Layers pane directly to its left. The images may take time to load at different extents.  
___
Note: The commented-out code below displays the full extent. It may not respond because of the size of the image, but you can try it. Users of High-RAM runtimes (Colab Pro) should be able to display it.

In [None]:
# # Display results (Full Extent)
# m = leafmap.Map()
# m.add_raster(temp_sentinel, layer_name='Sentinel-2', opacity=1)
# m.add_raster(temp_superres, layer_name='Super-Resolution', opacity=1)
# m

In [None]:
# Display results (clipped to Planet Extent)
m = leafmap.Map()
m.add_raster(clipped_sentinel, layer_name='Sentinel-2 (Clipped)', opacity=1)
m.add_raster(clipped_superres, layer_name='Super-Resolution (Clipped)', opacity=1)
m

### Comparison against Planet Imagery
Visualizing our first test site (Site 0) against Planet imagery

In [None]:
m = leafmap.Map()
m.add_raster(temp_planet, layer_name='Planet', opacity=1)
m.add_raster(clipped_superres, layer_name='Super-Res (Clipped)', opacity=1)
m

# Visualizing and Comparing Super-Resolution (Normalized Across Bands)