<a href="https://colab.research.google.com/github/hanzila1/Sentinel-2-HighRes-Stack-Image-downloader/blob/main/S2DR3_High_Resolution_Stack_Downloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install the Dependencies

In [None]:
# Make sure to select the T4 GPU instance from the menu Runtime > Change runtime type
# OPTIONAL: To link your Google Drive and save output files, uncomment and modify the line below
#!ln -s /content/drive/MyDrive/{PATH_TO_OUTPUT_FILES} /content/output

# Install GDAL (tool for geospatial data manipulation)
!apt install -qq gdal-bin python3-gdal

# Install the S2DR3 package from a direct link (precompiled for Python 3.11 and Linux)
!pip -q install https://storage.googleapis.com/0x7ff601307fa5/s2dr3-20250307.1-cp311-cp311-linux_x86_64.whl

# Install additional packages for tile processing and stitching
!pip install -q rasterio matplotlib folium ipywidgets

In [None]:
import ee
import os
import shutil

# Sign in to Earth Engine and bind this session to a Cloud project
ee.Authenticate()
ee.Initialize(project='ee-hanzilabinyounus')  # Replace with your project ID

# Start from an empty scratch directory: remove a leftover one, then recreate it
output_dir = '/content/op'
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)
print("Output folder created!")

# Starting point in WGS84 as longitude / latitude
# (labelled "Islamabad" by the original author)
initial_lon = 72.914531
initial_lat = 33.234768
point = ee.Geometry.Point([initial_lon, initial_lat])

print(f"Initial coordinates set: {initial_lon}, {initial_lat}")



# Apply the S2DR3 model

In [None]:
# Pick the least-cloudy Sentinel-2 scene covering the initial point
def find_best_s2_image(point, start_date='2025-01-01', end_date='2025-04-30', cloud_percent=3):
    """Return the acquisition date of the least-cloudy Sentinel-2 scene over ``point``.

    The COPERNICUS/S2_SR_HARMONIZED collection is filtered by location, by the
    date range and by ``CLOUDY_PIXEL_PERCENTAGE < cloud_percent``, then sorted
    on that property; the first image of the sorted collection is used.

    If nothing matches and ``cloud_percent`` is below 20, the search is repeated
    once with a threshold of 20.

    Args:
        point: Earth Engine geometry passed to ``filterBounds``.
        start_date: Start of the date range handed to ``filterDate``.
        end_date: End of the date range handed to ``filterDate``.
        cloud_percent: Upper bound for ``CLOUDY_PIXEL_PERCENTAGE``.

    Returns:
        The image date formatted as ``YYYY-MM-dd``.

    Raises:
        ValueError: If no image is found even at the relaxed threshold.
    """
    threshold = cloud_percent
    while True:
        candidates = (
            ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
            .filterBounds(point)
            .filterDate(start_date, end_date)
            .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', threshold))
            .sort('CLOUDY_PIXEL_PERCENTAGE')
        )

        # One round trip to the server to learn whether anything matched
        if candidates.size().getInfo() != 0:
            break

        # Nothing matched: give up unless a relaxed retry is still possible
        if not threshold < 20:
            raise ValueError(f"No images found for the specified date range {start_date} to {end_date}")

        print(f"No images found with {threshold}% cloud cover. Trying with 20%...")
        threshold = 20

    # Head of the sorted collection; fetch its date and cloud cover
    best = ee.Image(candidates.first())
    date = best.date().format('YYYY-MM-dd').getInfo()
    cloud = best.get('CLOUDY_PIXEL_PERCENTAGE').getInfo()

    print(f"Best image found: Date {date} with {cloud:.1f}% cloud cover")
    return date

# Resolve the acquisition date to use for the starting location
best_date = find_best_s2_image(point=point)

# Download the result as GeoTIFF

In [None]:
import numpy as np
import glob
from IPython.display import display, HTML
import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Tunable constants shared by the tile helpers
TILE_SIZE_KM = 4                  # edge length of one tile, in kilometres
EARTH_CIRCUMFERENCE_KM = 40_075   # approximate equatorial circumference of the Earth

def calculate_adjacent_coordinates(base_lon, base_lat, direction):
    """Return ``[lon, lat]`` of the tile one tile-width away in ``direction``.

    The step is ``TILE_SIZE_KM`` converted to degrees. The latitude step uses
    the full circumference; the longitude step uses the circumference scaled
    by ``cos(base_lat)``.

    Args:
        base_lon: Longitude of the current tile, in degrees.
        base_lat: Latitude of the current tile, in degrees.
        direction: One of ``"east"``, ``"west"``, ``"north"``, ``"south"``.

    Raises:
        ValueError: If ``direction`` is not one of the four names above.
    """
    # Length of the parallel through base_lat, then the tile width as a fraction of it
    parallel_km = EARTH_CIRCUMFERENCE_KM * np.cos(np.radians(base_lat))
    lon_shift = (TILE_SIZE_KM / parallel_km) * 360
    lat_shift = (TILE_SIZE_KM / EARTH_CIRCUMFERENCE_KM) * 360

    if direction == "east" or direction == "west":
        moved_lon = base_lon + lon_shift if direction == "east" else base_lon - lon_shift
        return [moved_lon, base_lat]

    if direction == "north" or direction == "south":
        moved_lat = base_lat + lat_shift if direction == "north" else base_lat - lat_shift
        return [base_lon, moved_lat]

    raise ValueError("Direction must be 'east', 'west', 'north', or 'south'")

def plot_processed_tiles(tiles, current_tile_idx=None):
    """Draw a schematic map of the processed tiles.

    Every tile is drawn as a unit square whose position is its offset from the
    first tile, expressed in tile widths. The tile at ``current_tile_idx`` is
    drawn in red, all others in blue, and each square is labelled with its
    1-based position in ``tiles``.

    Args:
        tiles: Sequence of dicts with at least the keys ``'lon'`` and ``'lat'``.
        current_tile_idx: Index into ``tiles`` of the tile to highlight, or
            ``None`` to highlight nothing.

    Returns:
        None. The figure is shown with ``plt.show()`` and then closed.
    """
    # Nothing to draw; without this guard the first-tile lookup below would
    # raise IndexError on an empty list.
    if not tiles:
        print("No tiles to plot yet.")
        return

    lons = [t['lon'] for t in tiles]
    lats = [t['lat'] for t in tiles]

    # The first tile is the origin of the schematic grid
    base_lon, base_lat = lons[0], lats[0]

    # Size of one tile in degrees, computed once instead of once per tile.
    # The longitude step is evaluated at the first tile's latitude.
    lon_step = (TILE_SIZE_KM / (EARTH_CIRCUMFERENCE_KM * np.cos(np.radians(base_lat))) * 360)
    lat_step = (TILE_SIZE_KM / EARTH_CIRCUMFERENCE_KM * 360)

    x_positions = [(lon - base_lon) / lon_step for lon in lons]
    y_positions = [(lat - base_lat) / lat_step for lat in lats]

    fig, ax = plt.subplots(figsize=(8, 6))

    # One unit square per tile, centred on its grid position
    for i, (x, y) in enumerate(zip(x_positions, y_positions)):
        is_current = (i == current_tile_idx)
        color = 'red' if is_current else 'blue'
        alpha = 0.7 if is_current else 0.5
        rect = Rectangle((x - 0.5, y - 0.5), 1, 1, linewidth=1,
                         edgecolor=color, facecolor=color, alpha=alpha)
        ax.add_patch(rect)
        ax.text(x, y, str(i+1), ha='center', va='center', color='white', fontweight='bold')

    # Leave a one-tile margin around the outermost tiles
    ax.set_xlim(min(x_positions) - 1, max(x_positions) + 1)
    ax.set_ylim(min(y_positions) - 1, max(y_positions) + 1)
    ax.set_aspect('equal')
    ax.grid(True)
    ax.set_title('Processed Tiles')
    ax.set_xlabel('East-West (Tile Units)')
    ax.set_ylabel('North-South (Tile Units)')

    plt.tight_layout()
    plt.show()
    # A new figure is created on every call; release it so repeated calls
    # do not accumulate open figures.
    plt.close(fig)

In [None]:
import s2dr3.inferutils
import time
from google.colab import files
import rasterio
from rasterio.merge import merge
import os
import subprocess

class TileProcessor:
    """Run S2DR3 tile by tile, track the results and stitch them into a mosaic.

    Attributes:
        processed_tiles: List of dicts, one per registered tile, with the keys
            ``'lon'``, ``'lat'``, ``'output_file'`` (the TIF used for
            stitching) and ``'all_files'`` (every TIF attributed to the tile).
        current_lon, current_lat: Coordinates of the tile the processor is
            currently positioned on.
        date: Date string passed to ``s2dr3.inferutils.test``.
        output_base_dir: Directory that is scanned for TIF outputs.
    """

    def __init__(self, initial_lon, initial_lat, date):
        """Store the starting position and date; no processing happens here."""
        self.processed_tiles = []  # List to store processed tile information
        self.current_lon = initial_lon
        self.current_lat = initial_lat
        self.date = date
        self.output_base_dir = '/content/output'

        # Display initialization message
        print(f"Tile Processor initialized for date: {date}")
        print(f"Ready to process tiles starting at: {initial_lon}, {initial_lat}")

    def process_current_tile(self):
        """Run S2DR3 for the current coordinates and register its output.

        The TIF files under ``output_base_dir`` are snapshotted (path and
        modification time) before and after the run. Files that are new or
        whose modification time changed are attributed to this tile. This
        also catches files written into directories that already existed.

        Returns:
            True if at least one new or changed TIF was found and the tile was
            appended to ``processed_tiles``; False otherwise. On False nothing
            is registered: an unchanged TIF belongs to an earlier run and
            would put the wrong image at these coordinates in the mosaic.
        """
        start_time = time.time()
        print(f"Processing tile at coordinates: [{self.current_lon}, {self.current_lat}]")

        before = self._snapshot_tifs(self.output_base_dir)

        # Run S2DR3 for the current coordinates
        s2dr3.inferutils.test([self.current_lon, self.current_lat], self.date)

        # Wait a moment for file system to update
        time.sleep(2)

        after = self._snapshot_tifs(self.output_base_dir)

        # Sorted so the choice of the primary file is deterministic
        output_files = sorted(
            path for path, stamp in after.items() if before.get(path) != stamp
        )

        if not output_files:
            print("❌ Failed to find output TIF files.")
            if after:
                existing = sorted(after)
                print("TIF files already present in the output directory (none changed during this run):")
                for tif in existing[:5]:  # Show just the first 5 to avoid clutter
                    print(f" - {tif}")
                if len(existing) > 5:
                    print(f"  ... and {len(existing) - 5} more")
            return False

        # Prefer the high-resolution tile files (containing 'x10' in the name);
        # if there are none, fall back to any TIF from this run
        hi_res_files = [f for f in output_files if 'x10' in f]
        tile_file = hi_res_files[0] if hi_res_files else output_files[0]

        self.processed_tiles.append({
            'lon': self.current_lon,
            'lat': self.current_lat,
            'output_file': tile_file,
            'all_files': output_files
        })

        elapsed_time = time.time() - start_time
        print(f"✅ Tile {len(self.processed_tiles)} processed successfully in {elapsed_time:.1f} seconds")
        print(f"Primary output file: {tile_file}")
        return True

    def _snapshot_tifs(self, base_dir):
        """Map every TIF path under ``base_dir`` to its modification time (ns)."""
        snapshot = {}
        for path in self._find_all_tifs(base_dir):
            try:
                snapshot[path] = os.stat(path).st_mtime_ns
            except OSError:
                # File vanished between listing and stat; ignore it
                continue
        return snapshot

    def _walk_tifs(self, directory):
        """Return the paths of all ``.tif`` files below ``directory``."""
        found = []
        for root, _dirnames, filenames in os.walk(directory):
            for filename in filenames:
                if filename.endswith('.tif'):
                    found.append(os.path.join(root, filename))
        return found

    def _get_all_subdirs(self, base_dir):
        """Get all subdirectories (recursively) in the base directory."""
        if not os.path.exists(base_dir):
            return []

        result = []
        for root, dirnames, _filenames in os.walk(base_dir):
            for d in dirnames:
                result.append(os.path.join(root, d))
        return result

    def _find_output_tifs(self, directories):
        """Find all TIF files in the given directories."""
        result = []
        for directory in directories:
            result.extend(self._walk_tifs(directory))
        return result

    def _find_all_tifs(self, base_dir):
        """Find all TIF files in the base directory (empty list if it is missing)."""
        if not os.path.exists(base_dir):
            return []
        return self._walk_tifs(base_dir)

    def process_adjacent_tile(self, direction):
        """Move one tile in ``direction``, process it and show the selector again.

        If processing fails (returns False or raises), the position is put
        back to the previous tile so that the next click starts from a tile
        that really exists instead of skipping over a gap.

        Returns:
            The result of ``process_current_tile`` for the new position.
        """
        previous_position = (self.current_lon, self.current_lat)
        new_coords = calculate_adjacent_coordinates(self.current_lon, self.current_lat, direction)
        self.current_lon, self.current_lat = new_coords

        success = False
        try:
            success = self.process_current_tile()
        finally:
            if not success:
                self.current_lon, self.current_lat = previous_position
                print("Tile was not processed; position restored to the previous tile.")

        # Always display the direction selector after processing, even if there was an issue
        # This ensures the workflow continues
        self.display_direction_selector()
        return success

    def display_direction_selector(self):
        """Display interactive buttons for selecting the next tile direction."""
        # One button per direction; the default argument binds the direction
        # at definition time so each handler keeps its own value
        direction_buttons = {}
        for direction, label in (("north", "↑ North"), ("south", "↓ South"),
                                 ("east", "→ East"), ("west", "← West")):
            button = widgets.Button(description=label, button_style='info')
            button.on_click(lambda _b, d=direction: self.process_adjacent_tile(d))
            direction_buttons[direction] = button

        btn_stitch = widgets.Button(description="🔄 Stitch All Tiles", button_style='success')
        btn_stitch.on_click(lambda _b: self.stitch_tiles())

        # Display a grid of buttons with a title
        display(HTML("<h3>Select the direction for the next 4x4 km tile:</h3>"))
        display(widgets.HBox([
            widgets.VBox([
                widgets.HBox([direction_buttons["west"], direction_buttons["east"]]),
                widgets.HBox([direction_buttons["north"], direction_buttons["south"]]),
            ]),
            widgets.VBox([btn_stitch]),
        ]))

        # Display tile map
        if self.processed_tiles:
            plot_processed_tiles(self.processed_tiles, len(self.processed_tiles)-1)

        # Show status of processed tiles
        print(f"\nCurrent status: {len(self.processed_tiles)} tiles processed")
        print(f"Current position: {self.current_lon}, {self.current_lat}")
        print("Choose a direction to process the next adjacent 4x4 km tile")

    def stitch_tiles(self):
        """Stitch all processed tiles into a single GeoTIFF and offer it for download.

        Every dataset that was opened is closed again on all paths, including
        the early return when fewer than two tiles could be opened. The mosaic
        is written to ``/content`` first; a failing download does not discard it.
        """
        if len(self.processed_tiles) <= 1:
            print("Need at least 2 tiles to stitch!")
            return

        print(f"Stitching {len(self.processed_tiles)} tiles together...")

        src_files_to_mosaic = []
        try:
            # Open all raster files, skipping the ones that cannot be read
            for tile in self.processed_tiles:
                try:
                    src_files_to_mosaic.append(rasterio.open(tile['output_file']))
                except Exception as e:
                    print(f"Warning: Could not open {tile['output_file']}: {e}")

            if len(src_files_to_mosaic) < 2:
                print("Not enough valid tiles to create a mosaic!")
                return

            # Merge rasters
            mosaic, out_trans = merge(src_files_to_mosaic)

            # Start from the first image's metadata and adapt it to the mosaic;
            # mosaic is indexed (band, row, column)
            out_meta = src_files_to_mosaic[0].meta.copy()
            out_meta.update({
                "driver": "GTiff",
                "count": mosaic.shape[0],
                "height": mosaic.shape[1],
                "width": mosaic.shape[2],
                "transform": out_trans
            })

            # Write the mosaic to disk
            output_filename = f"/content/mosaic_{len(self.processed_tiles)}tiles_{self.date}.tif"
            with rasterio.open(output_filename, "w", **out_meta) as dest:
                dest.write(mosaic)

            print(f"✅ Mosaic created successfully: {output_filename}")
        except Exception as e:
            print(f"Error during stitching: {e}")
            return
        finally:
            # Runs on success, on error and on the early return above
            for src in src_files_to_mosaic:
                src.close()

        # Provide download link. The module is imported here under a private
        # name because the notebook-level name `files` is easily rebound by a
        # `for root, dirs, files in os.walk(...)` loop in another cell.
        print("Downloading stitched image...")
        try:
            from google.colab import files as colab_files
            colab_files.download(output_filename)
        except Exception as e:
            print(f"Mosaic saved at {output_filename}, but the download could not be started: {e}")

    def add_existing_tile(self, tif_path):
        """Register an already processed TIF as the tile at the current position.

        Returns:
            True if the file exists and was registered, False otherwise.
        """
        if not os.path.exists(tif_path):
            print(f"Error: File {tif_path} not found")
            return False

        self.processed_tiles.append({
            'lon': self.current_lon,
            'lat': self.current_lat,
            'output_file': tif_path,
            'all_files': [tif_path]
        })

        print(f"✅ Added existing tile: {tif_path}")
        return True

In [None]:
# Process the initial tile using S2DR3
import s2dr3.inferutils
import glob
import os

# Use our initial coordinates and best date from previous cells
print(f"Starting inference for initial tile at [{initial_lon}, {initial_lat}] on date {best_date}...")

# Track existing TIF files before processing
before_files = set(glob.glob('/content/*.tif') + glob.glob('/content/op/*.tif') + glob.glob('/content/output/*.tif'))

# Process the first tile
s2dr3.inferutils.test([initial_lon, initial_lat], best_date)

# Find new TIF files that were created
after_files = set(glob.glob('/content/*.tif') + glob.glob('/content/op/*.tif') + glob.glob('/content/output/*.tif'))
new_files = after_files - before_files

# Show what files were created
if new_files:
    print(f"Created {len(new_files)} new TIF files:")
    for file in new_files:
        print(f" - {file}")
    first_tile_path = list(new_files)[0] if new_files else None
else:
    print("No new TIF files were found. This could indicate an issue with the process.")
    # Try to find any TIF files in the system
    all_tifs = !find /content -name "*.tif" -type f
    if all_tifs:
        print("But we found these TIF files in the system:")
        for tif in all_tifs:
            print(f" - {tif}")
        first_tile_path = all_tifs[0] if all_tifs else None
    else:
        first_tile_path = None

print("\nInitial tile processed. You can now expand to adjacent areas.")

In [None]:
# Initialize the tile processor with our initial coordinates
processor = TileProcessor(initial_lon, initial_lat, best_date)

# Find all existing TIF files in the output directory.
# The loop variables are deliberately NOT named `files`: at notebook level that
# would rebind `files` imported from google.colab, and TileProcessor.stitch_tiles
# would then fail on `files.download(...)`.
output_files = []
if os.path.exists('/content/output'):
    for root, _dirnames, filenames in os.walk('/content/output'):
        for filename in filenames:
            if filename.endswith('.tif'):
                output_files.append(os.path.join(root, filename))
output_files.sort()  # deterministic choice of the first tile

# If we have output files from the first run, use them
if output_files:
    # Filter for high-resolution files if available
    hi_res_files = [f for f in output_files if 'x10' in f]
    first_tile_path = hi_res_files[0] if hi_res_files else output_files[0]

    print(f"Found existing tile: {first_tile_path}")
    processor.add_existing_tile(first_tile_path)

    # Now display the direction selector to continue processing
    processor.display_direction_selector()
else:
    print("No existing tiles found. Processing first tile...")
    # Process the first tile directly through the processor
    if processor.process_current_tile():
        processor.display_direction_selector()
    else:
        print("Failed to process first tile. Please check the S2DR3 output and logs.")