# Postprocessing Annotations Supervisely

**Objective:** Process annotated rooftop segmentation data for quality control and standardization.

**Workflow:**
1. Load dataset and annotation metadata
2. Clip images using building polygons
3. Handle tile overlaps and remove duplicates
4. Standardize dimensions to 1280x1280
5. Validate data integrity and export

## Imports

In [None]:
import datetime
import os
import random
import shutil
import tempfile
import traceback
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import geopandas as gpd
from shapely.geometry import Polygon, mapping

import rasterio
import rasterio.mask
from rasterio.crs import CRS
from rasterio.windows import from_bounds, Window

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


## Configuration

In [None]:
# Timestamp that makes every processed-dataset folder name unique
todays_date = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Inputs: raw Supervisely export and the CSV with per-tile metadata
DATASET_ANNOTATED_BRUT_PATH = "datasets/supervisely/341575_free_space_rooftop_geneva_20250511_binary_mask"
DATASET_TILES_INFORMATION_CSV_PATH = "data/notebook_06/dataset_20250405-193125/PNG_dataset_roboflow_20250405-193143/sampled_tiles.csv"

# Swiss coordinate reference system EPSG:2056, as a rasterio CRS object and as a string
CORRECT_CRS = CRS.from_epsg(2056)
EPSG_SUISSE = "EPSG:2056"

# Merged rooftop polygons
CAD_BATIMENT_HORSOL_TOIT_MERGE_PARQUET_PATH = "data/notebook_04/parquet/04_02_merged_rooftops_poly.parquet"

# Parquet outputs of this notebook
VERIFICATION_OUTPUT_PARQUET_PATH = "data/notebook_06/parquet/06_01_verification.parquet"
DATASET_OUTPUT_PARQUET_PATH = "data/notebook_06/parquet/06_02_dataset_processed.parquet"
DATASET_FINAL_OUTPUT_PARQUET_PATH = "data/notebook_06/parquet/06_03_dataset_final.parquet"

# Processed dataset layout: <root>/images, <root>/masks, <root>/check_dataset
DATASET_PROCESSED_NAME = f"dataset_processed_{todays_date}"
DATASET_PROCESSED_PATH = f"datasets/supervisely/{DATASET_PROCESSED_NAME}"
DATASET_OUTPUT_IMG_PATH = f"{DATASET_PROCESSED_PATH}/images"
DATASET_OUTPUT_MASKS_PATH = f"{DATASET_PROCESSED_PATH}/masks"
DATASET_OUTPUT_CHECKS_PATH = f"{DATASET_PROCESSED_PATH}/check_dataset"

# Buffer used for tile overlaps
BUFFER_DISTANCE = 0  # in metres
OVERLAP_POSITIONS = ['top', 'right', 'top-left', 'top-right']

# Create the output tree, parent first. No exist_ok: an already existing folder raises,
# so a run never writes into a previous run's output.
for _output_dir in (
    DATASET_PROCESSED_PATH,
    DATASET_OUTPUT_IMG_PATH,
    DATASET_OUTPUT_MASKS_PATH,
    DATASET_OUTPUT_CHECKS_PATH,
):
    os.makedirs(_output_dir)

#! Regenerate the tiles in tile_1024_split from SITG's 1.4 GB GeoTIFFs (takes about 1 h).
#! Only useful if the 1024 tiles were corrupted or deleted by mistake.
REGENERATE_TILE_1024_SPLIT_FROM_SITG = False
REGENERATE_TILE_1024_SPLIT_FROM_SITG_NUM_PROCESSES = 2
REGENERATE_TILE_1024_SPLIT_FROM_SITG_NUM_THREADS = 2
REGENERATE_TILE_1024_SPLIT_FROM_SITG_COMBINED_METADATA_PARQUET = "data/notebook_04/geotiff/tile_1024_split_old_20250519-120028/combined_metadata.parquet"
REGENERATE_TILE_1024_SPLIT_FROM_SITG_OUTPUT_DIR_TILE_1024 = "data/notebook_04/geotiff/tile_1024_split"
REGENERATE_TILE_1024_SPLIT_FROM_SITG_OUTPUT_DIR_TEMP_1280 = 'data/notebook_04/geotiff/tile_1280_split'

## Optional tile_1024_split Regeneration

Note: optionnel, uniquement utile dans le cas où le tile_1024_split a été modifié ou effacé par erreur

Note: regenerate tile_1024_split tiles from SITG GeoTIFFs if source files are corrupted or missing.

In [None]:
def process_tile(row, output_dir, debug_dir):
    """
    Cut one tile out of its source GeoTIFF and write it as a standalone GeoTIFF.

    The tile extent comes from ``row['buffered_bounds']``. This is either an
    unpackable ``(min_x, min_y, max_x, max_y)`` value or its string form
    ``"(min_x, min_y, max_x, max_y)"``. The matching pixel window is snapped to
    whole pixels and clipped to the raster extent when it spills over an edge.
    The output keeps the source profile but is forced to ``CORRECT_CRS`` and
    written uncompressed, untiled, band-interleaved and as BigTIFF.

    Parameters:
        row: DataFrame row with 'buffered_bounds', 'tile_path' and 'geotiff_path'
        output_dir: Directory receiving the tile (file name = basename of 'tile_path')
        debug_dir: Kept for interface compatibility; not used by this function

    Returns:
        tuple: ``(row.name, True, None)`` on success. On any exception, returns
        ``(row.name, False, message)``, where message is "Error: <exception>"
        followed by the formatted traceback.
    """
    try:
        raw_bounds = row['buffered_bounds']
        if not isinstance(raw_bounds, str):
            bounds = raw_bounds
        else:
            # "(x0, y0, x1, y1)" -> (x0, y0, x1, y1) as floats
            inner = raw_bounds.replace(' ', '').strip('()')
            bounds = tuple(float(token) for token in inner.split(','))
        min_x, min_y, max_x, max_y = bounds

        destination = os.path.join(output_dir, os.path.basename(row['tile_path']))

        with rasterio.open(row['geotiff_path']) as source:
            fractional = from_bounds(min_x, min_y, max_x, max_y, source.transform)

            # Snap offsets and size to whole pixels to avoid floating-point drift
            window = rasterio.windows.Window(
                col_off=int(round(fractional.col_off)),
                row_off=int(round(fractional.row_off)),
                width=int(round(fractional.width)),
                height=int(round(fractional.height)),
            )

            fits_inside = (
                window.col_off >= 0
                and window.row_off >= 0
                and window.col_off + window.width <= source.width
                and window.row_off + window.height <= source.height
            )
            if not fits_inside:
                # Keep only the part of the window that lies on the raster
                full_extent = rasterio.windows.Window(0, 0, source.width, source.height)
                window = window.intersection(full_extent)

            pixels = source.read(window=window)
            tile_transform = rasterio.windows.transform(window, source.transform)

            # Uncompressed, untiled output to preserve the source values as-is
            profile = source.profile.copy()
            profile.update({
                'height': window.height,
                'width': window.width,
                'transform': tile_transform,
                'crs': CORRECT_CRS,
                'driver': 'GTiff',
                'compress': None,
                'predictor': 1,
                'tiled': False,
                'interleave': 'band',
                'bigtiff': True,
                'dtype': source.dtypes[0],
            })

            with rasterio.open(destination, 'w', **profile) as sink:
                sink.write(pixels)

        return (row.name, True, None)

    except Exception as exc:
        details = traceback.format_exc()
        return (row.name, False, f"Error: {str(exc)}\n{details}")


def copy_file(args):
    """
    Copy one file, including its metadata, without raising on failure.

    Parameters:
        args: ``(source_file, destination_file)`` pair. The pair form suits
            ``Executor.map``.

    Returns:
        tuple: ``(True, source_file)`` when the copy succeeded. Otherwise
        ``(False, "Error copying <src> to <dst>: <reason>")``.
    """
    source, destination = args
    try:
        shutil.copy2(source, destination)
    except Exception as exc:
        return (False, f"Error copying {source} to {destination}: {str(exc)}")
    return (True, source)


def process_geotiffs(chunk_df, output_dir, debug_dir):
    """
    Run ``process_tile`` on every row of a DataFrame chunk, in row order.

    Parameters:
        chunk_df: DataFrame whose rows describe the tiles to cut
        output_dir: Directory where ``process_tile`` writes the tiles
        debug_dir: Passed through to ``process_tile`` unchanged

    Returns:
        list: One ``(index, success, error_message)`` tuple per row.
    """
    return [
        process_tile(tile_row, output_dir, debug_dir)
        for _, tile_row in chunk_df.iterrows()
    ]


if REGENERATE_TILE_1024_SPLIT_FROM_SITG:
    # Rebuild every tile of tile_1024_split from the SITG source GeoTIFFs into a
    # temporary folder and copy over the remaining files of the current folder.
    # Then check the result, archive the current folder and put the new one in its place.

    # Suppress warnings
    warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

    # Load per-tile metadata (columns used: geotiff_path, tile_path, buffered_bounds)
    df = pd.read_parquet(REGENERATE_TILE_1024_SPLIT_FROM_SITG_COMBINED_METADATA_PARQUET)

    # Create output directories; per-tile error logs go to <new folder>/debug
    os.makedirs(REGENERATE_TILE_1024_SPLIT_FROM_SITG_OUTPUT_DIR_TEMP_1280, exist_ok=True)
    debug_dir = os.path.join(REGENERATE_TILE_1024_SPLIT_FROM_SITG_OUTPUT_DIR_TEMP_1280, 'debug')
    os.makedirs(debug_dir, exist_ok=True)

    # Group by source GeoTIFF so each worker task reads a single source file
    grouped = df.groupby('geotiff_path')
    group_dfs = [group for _, group in grouped]

    print(f"Processing {len(df)} tiles from {len(group_dfs)} source GeoTIFFs using {REGENERATE_TILE_1024_SPLIT_FROM_SITG_NUM_PROCESSES} processes")

    # Process groups in parallel; results are collected in submission order
    with ProcessPoolExecutor(max_workers=REGENERATE_TILE_1024_SPLIT_FROM_SITG_NUM_PROCESSES) as executor:
        futures = [executor.submit(process_geotiffs, group_df, REGENERATE_TILE_1024_SPLIT_FROM_SITG_OUTPUT_DIR_TEMP_1280, debug_dir)
                   for group_df in group_dfs]

        all_results = []
        for future in tqdm(futures, total=len(futures), desc="Processing GeoTIFF groups"):
            all_results.extend(future.result())

    # Count outcomes and write one log file per failed tile
    success_count = 0
    error_count = 0
    for idx, success, error_msg in all_results:
        if success:
            success_count += 1
        else:
            error_count += 1
            error_info_path = os.path.join(debug_dir, f"error_row_{idx}.txt")
            with open(error_info_path, 'w') as f:
                f.write(error_msg)

    print(f"Tile processing complete: {success_count} successful, {error_count} errors")

    # Copy every other file of the current folder, keeping the directory structure
    src_dir = REGENERATE_TILE_1024_SPLIT_FROM_SITG_OUTPUT_DIR_TILE_1024
    dst_dir = REGENERATE_TILE_1024_SPLIT_FROM_SITG_OUTPUT_DIR_TEMP_1280

    # File names of the regenerated tiles (written at the root of dst_dir)
    processed_tif_files = set(df['tile_path'].map(os.path.basename))

    files_to_copy = []
    print("\nScanning directory structure...")
    for root, dirs, files in os.walk(src_dir):
        rel_path = os.path.relpath(root, src_dir)

        # Create corresponding directory
        if rel_path != '.':
            dst_root = os.path.join(dst_dir, rel_path)
            os.makedirs(dst_root, exist_ok=True)
            print(f"Created directory: {rel_path}")

        for file in files:
            # Skip root-level tiles that were just regenerated
            if rel_path == '.' and file.endswith('.tif') and file in processed_tif_files:
                continue

            src_file = os.path.join(root, file)
            dst_file = os.path.join(dst_dir, rel_path, file) if rel_path != '.' else os.path.join(dst_dir, file)
            files_to_copy.append((src_file, dst_file))

    # Copy files in parallel (I/O bound, so threads)
    print(f"\nCopying {len(files_to_copy)} additional files using {REGENERATE_TILE_1024_SPLIT_FROM_SITG_NUM_THREADS} threads...")
    with ThreadPoolExecutor(max_workers=REGENERATE_TILE_1024_SPLIT_FROM_SITG_NUM_THREADS) as executor:
        results = list(tqdm(executor.map(copy_file, files_to_copy), total=len(files_to_copy), desc="Copying files"))

    # Check for copy errors
    copy_errors = [result for result in results if not result[0]]
    if copy_errors:
        print(f"Warning: {len(copy_errors)} files failed to copy:")
        for _, error in copy_errors[:10]:  # Show first 10 errors
            print(f"  {error}")
        if len(copy_errors) > 10:
            print(f"  ... and {len(copy_errors) - 10} more errors")

    # Verify the new folder. Raw file counts cannot be compared: a regenerated tile
    # replaces a file of the old folder when one exists, adds a file when it was missing,
    # and failed tiles add error logs. Instead, require that every file of the old folder
    # (by relative path) and every regenerated tile exist in the new folder.
    def list_relative_files(directory):
        """Return the set of file paths below `directory`, relative to it."""
        return {
            os.path.relpath(os.path.join(root, name), directory)
            for root, _, files in os.walk(directory)
            for name in files
        }

    old_files = list_relative_files(src_dir)
    new_files = list_relative_files(dst_dir)

    print(f"\nTotal files in old directory: {len(old_files)}")
    print(f"Total files in new directory: {len(new_files)}")

    missing_files = sorted((old_files - new_files) | (processed_tif_files - new_files))
    print(f"Expected files missing from new directory: {len(missing_files)}")

    if missing_files:
        print("WARNING: The new directory is missing expected files:")
        for missing in missing_files[:10]:
            print(f"  {missing}")
        if len(missing_files) > 10:
            print(f"  ... and {len(missing_files) - 10} more")
        proceed = input("Do you want to proceed with the renaming? (y/n): ")
        if proceed.lower() != 'y':
            # Raise SystemExit rather than calling exit(), a site/IPython helper that
            # may terminate the Jupyter kernel instead of just stopping this cell
            raise SystemExit("Operation aborted")

    # Archive the current folder (when it exists) under a timestamped name derived from
    # src_dir, then promote the new folder
    todaysdate = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    old_folder_new_name = os.path.normpath(src_dir) + "_old_" + str(todaysdate)

    if os.path.exists(src_dir):
        print(f"\nRenaming old folder '{src_dir}' to '{old_folder_new_name}'")
        os.rename(src_dir, old_folder_new_name)
    else:
        print(f"\nOld folder '{src_dir}' does not exist, nothing to archive")

    print(f"Renaming new folder '{dst_dir}' to '{src_dir}'")
    os.rename(dst_dir, src_dir)

    print("\nProcessing complete!")

## Load Data
### Pre-annotation Dataset

In [None]:
# Load the per-tile metadata of the pre-annotation dataset
gdf_dataset = gpd.read_file(DATASET_TILES_INFORMATION_CSV_PATH)
# Drop geometry columns read from the CSV when present ("geometry" first, then "geometry_x")
for _geometry_column in ("geometry", "geometry_x"):
    if _geometry_column in gdf_dataset.columns:
        gdf_dataset = gdf_dataset.drop(columns=[_geometry_column])


In [None]:
# Inspect the object type and the first rows of the tile metadata loaded above
print(type(gdf_dataset))
gdf_dataset.head()

In [None]:
# Verify no duplicates in tile_id
# Each tile must appear only once in the pre-annotation dataset
assert not gdf_dataset.duplicated(subset=["tile_id"]).any(), f"gdf_dataset has duplicates in tile_id: {gdf_dataset[gdf_dataset.duplicated(subset=['tile_id'])]}"

### Annotated Dataset

In [None]:
# Load image paths
dataset_img_path = os.path.join(DATASET_ANNOTATED_BRUT_PATH, [f for f in os.listdir(DATASET_ANNOTATED_BRUT_PATH) if f.startswith("dataset")][0], "img")
assert(os.path.exists(dataset_img_path)), f"Path does not exist: {dataset_img_path}"
print(f"dataset_masks_path: {dataset_img_path}")
print(f"Number of files: {len(os.listdir(dataset_img_path))}")

# Load mask paths
dataset_masks_path = os.path.join(DATASET_ANNOTATED_BRUT_PATH, [f for f in os.listdir(DATASET_ANNOTATED_BRUT_PATH) if f.startswith("dataset")][0], "masks_machine")
assert(os.path.exists(dataset_masks_path)), f"Path does not exist: {dataset_masks_path}"
print(f"dataset_masks_path: {dataset_masks_path}")
print(f"Number of files: {len(os.listdir(dataset_masks_path))}")

# Verify equal number of images and masks
assert(len(os.listdir(dataset_img_path)) == len(os.listdir(dataset_masks_path))), f"Number of files in {dataset_masks_path} is not equal to number of files in {dataset_masks_path}"

In [None]:
# Create dataframe with annotation paths
dataset_original_masks_path_list = [os.path.join(dataset_masks_path, f) for f in os.listdir(dataset_masks_path)]
dataset_original_img_path_list = [os.path.join(dataset_img_path, f) for f in os.listdir(dataset_img_path)]

df_annotations = pd.DataFrame(
    {
        "tile_id": [os.path.basename(f).split(".")[0] for f in dataset_original_masks_path_list],
        "original_mask_path_png": dataset_original_masks_path_list,
        "original_img_path_png": dataset_original_img_path_list,
    }
)
# Extract tile_id from filename
df_annotations["tile_id"] = df_annotations["tile_id"].apply(lambda x: "_".join(x.split("_")[2:]))

# Validate data integrity
assert(len(df_annotations[df_annotations.duplicated(subset=["tile_id"])]) == 0), f"gdf_dataset has duplicates in tile_id: {df_annotations[df_annotations.duplicated(subset=['tile_id'])]}"
assert(len(df_annotations[df_annotations.duplicated(subset=["original_mask_path_png"])]) == 0), f"gdf_dataset has duplicates in original_mask_path_png: {df_annotations[df_annotations.duplicated(subset=['original_mask_path_png'])]}"
assert(len(df_annotations[df_annotations.duplicated(subset=["original_img_path_png"])]) == 0), f"gdf_dataset has duplicates in original_img_path_png: {df_annotations[df_annotations.duplicated(subset=['original_img_path_png'])]}"

# Check for null values
assert(df_annotations["tile_id"].notnull().all()), f"df_annotations has null values in tile_id: {df_annotations[df_annotations["tile_id"].isnull()]}"
assert(df_annotations["original_mask_path_png"].notnull().all()), f"df_annotations has null values in original_mask_path_png: {df_annotations[df_annotations["original_mask_path_png"].isnull()]}"
assert(df_annotations["original_img_path_png"].notnull().all()), f"df_annotations has null values in original_img_path_png: {df_annotations[df_annotations["original_img_path_png"].isnull()]}"

# Verify counts match
assert(len(df_annotations["tile_id"]) == len(gdf_dataset["tile_id"])), f"len(df_annotations['tile_id']) is not equal to len(gdf_dataset['tile_id']): {len(df_annotations['tile_id'])} != {len(gdf_dataset['tile_id'])}"

display(df_annotations)

In [None]:
# Merge annotation paths with dataset
gdf_dataset = gdf_dataset.merge(
    df_annotations,
    how="left",
    left_on="tile_id",
    right_on="tile_id",
)

In [None]:
# Validate merged data
assert(len(gdf_dataset[gdf_dataset.duplicated(subset=["tile_id"])]) == 0), f"gdf_dataset has duplicates in tile_id: {gdf_dataset[gdf_dataset.duplicated(subset=['tile_id'])]}"
assert(len(gdf_dataset[gdf_dataset.duplicated(subset=["original_img_path_png"])]) == 0), f"gdf_dataset has duplicates in img_path: {gdf_dataset[gdf_dataset.duplicated(subset=['original_img_path_png'])]}"
assert(len(gdf_dataset[gdf_dataset.duplicated(subset=["original_mask_path_png"])]) == 0), f"gdf_dataset has duplicates in original_mask_path_png: {gdf_dataset[gdf_dataset.duplicated(subset=['original_mask_path_png'])]}"

# Check for null values
assert(gdf_dataset["tile_id"].notnull().all()), f"gdf_dataset has null values in tile_id: {gdf_dataset[gdf_dataset["tile_id"].isnull()]}"
assert(gdf_dataset["original_img_path_png"].notnull().all()), f"gdf_dataset has null values in original_img_path_png: {gdf_dataset[gdf_dataset["original_img_path_png"].isnull()]}"
assert(gdf_dataset["original_mask_path_png"].notnull().all()), f"gdf_dataset has null values in original_mask_path_png: {gdf_dataset[gdf_dataset["original_mask_path_png"].isnull()]}"

# Verify record counts
assert(len(gdf_dataset["tile_id"]) == len(df_annotations["tile_id"])), f"len(gdf_dataset['tile_id']) is not equal to len(df_annotations['tile_id']): {len(gdf_dataset['tile_id'])} != {len(df_annotations['tile_id'])}"

In [None]:
# Preview the dataset after merging the annotation paths
gdf_dataset.head()

### Building Polygons

In [None]:
# Load the merged rooftop polygons (parquet written under data/notebook_04) and preview them
gdf_cad_batiment_horsol = gpd.read_parquet(CAD_BATIMENT_HORSOL_TOIT_MERGE_PARQUET_PATH)
print(type(gdf_cad_batiment_horsol))
gdf_cad_batiment_horsol.head()

### Enrich Dataset with Spatial Information

In [None]:
def get_geometry_from_tiff(tiff_path, crs=CORRECT_CRS):
    """
    Return the footprint of a GeoTIFF as a rectangular shapely Polygon.

    The rectangle is built from the raster's affine transform. It uses only the
    origin (``c``, ``f``) and the pixel sizes (``a``, ``e``); rotation/shear
    terms are ignored.

    Parameters:
        tiff_path: Path to the GeoTIFF
        crs: CRS the raster is expected to use; a mismatch only prints a warning

    Returns:
        Polygon covering the raster. Returns None when the file cannot be read
        or the footprint has no area. Errors are printed, never raised.
    """
    try:
        with rasterio.open(tiff_path) as src:
            affine = src.transform
            n_rows, n_cols = src.shape
            found_crs = src.crs

            if found_crs != crs:
                print(f"Warning: CRS mismatch in {tiff_path}. Found {found_crs}, expected {crs}")

            # Corners from the origin and pixel sizes (e is negative for north-up rasters)
            left = affine.c
            top = affine.f
            bottom = top + n_rows * affine.e
            right = left + n_cols * affine.a

            footprint = Polygon([
                (left, bottom), (right, bottom), (right, top), (left, top), (left, bottom)
            ])

            if not footprint.is_valid:
                print(f"Warning: Invalid polygon created from {tiff_path}")
                footprint = footprint.buffer(0)  # standard shapely trick to repair a geometry

            if footprint.area <= 0:
                print(f"Warning: Zero-area polygon created from {tiff_path}")
                return None

            return footprint
    except Exception as e:
        print(f"Error processing {tiff_path}: {e}")
        return None


def get_image_dimensions(image_path):
    """
    Read the pixel size of an image with PIL.

    Parameters:
        image_path: Path to the image file

    Returns:
        tuple: ``(width, height)``, or ``(None, None)`` when the file is absent
        or cannot be opened. Open errors are printed, not raised.
    """
    try:
        if not os.path.exists(image_path):
            return None, None
        with Image.open(image_path) as picture:
            width, height = picture.size
        return width, height
    except Exception as e:
        print(f"Error opening {image_path}: {e}")
        return None, None


def create_geodataframe_from_tiffs(df):
    """
    Attach the footprint of each row's GeoTIFF as the geometry of a GeoDataFrame.

    Rows whose ``tile_path`` does not exist, or whose footprint cannot be
    computed (``get_geometry_from_tiff`` returns None), are dropped. Missing
    files are reported on stdout, and a summary line gives the number of rows kept.

    Parameters:
        df: DataFrame with a 'tile_path' column

    Returns:
        GeoDataFrame: Copy of the kept rows (original index labels) with the
        footprints as geometry, in CRS ``EPSG_SUISSE``.
    """
    print("Processing GeoTIFF files...")

    kept_labels = []
    footprints = []
    for label, tiff_path in df['tile_path'].items():
        if not os.path.exists(tiff_path):
            print(f"File not found: {tiff_path}")
            continue
        footprint = get_geometry_from_tiff(tiff_path)
        if footprint is None:
            continue
        kept_labels.append(label)
        footprints.append(footprint)

    print(f"Successfully processed {len(footprints)} out of {len(df)} files")

    kept_rows = df.loc[kept_labels].copy()
    return gpd.GeoDataFrame(kept_rows, geometry=footprints, crs=EPSG_SUISSE)

# Process dataset: keep only tiles with a readable GeoTIFF footprint, used as geometry
gdf_dataset = create_geodataframe_from_tiffs(gdf_dataset)

# Add image dimensions: (width, height) per tile, (None, None) when the file is unreadable.
# Unpacked with comprehensions rather than `zip(*image_dimensions)`, which raises
# ValueError ("not enough values to unpack") when the dataset is empty.
image_dimensions = gdf_dataset['tile_path'].apply(get_image_dimensions)
gdf_dataset['image_width'] = [width for width, _ in image_dimensions]
gdf_dataset['image_height'] = [height for _, height in image_dimensions]

In [None]:
# Preview the dataset with footprint geometries and image dimensions
gdf_dataset.head()

In [None]:
# Validate image dimensions
# Every tile must have readable image dimensions
for dimension_column in ("image_width", "image_height"):
    assert gdf_dataset[dimension_column].notnull().all(), f"gdf_dataset has null values in {dimension_column}: {gdf_dataset[gdf_dataset[dimension_column].isnull()]}"

In [None]:
# Check the object type (expected: GeoDataFrame) and preview two rows
print(type(gdf_dataset))
gdf_dataset.head(2)

## Process Dataset Images
### Clip Images Using Building Polygons

In [None]:
def clip_geotiff_and_png_masks(gdf_dataset, gdf_buildings, output_img_dir, output_mask_dir, convert_masks_to_geotiff=True):
    """
    Clip GeoTIFF images and masks using building polygons.
    
    Removes pixels outside building footprints by setting them to zero.
    Ensures image and mask pairs maintain identical dimensions and alignment.
    
    Parameters:
        gdf_dataset: GeoDataFrame with tile paths and geometries
        gdf_buildings: GeoDataFrame containing building polygons
        output_img_dir: Directory to save clipped images
        output_mask_dir: Directory to save clipped masks
        convert_masks_to_geotiff: Convert PNG masks to GeoTIFF format
        
    Returns:
        tuple: (processed_tile_ids, skipped_tiles_dict)
    """
    # Track progress
    successful_img_count = 0
    successful_mask_count = 0
    error_count = 0
    
    # Track processing status
    processed_tile_ids = []
    skipped_tile_ids = []
    skipped_reasons = {}
    
    # Process each file pair
    for idx, row in tqdm(gdf_dataset.iterrows(), total=len(gdf_dataset), desc="Clipping files"):
        # Get file paths
        tiff_path = row['tile_path']
        mask_path = row.get('original_mask_path_png')
        tile_id = row['tile_id']
        
        # Validate paths
        if pd.isna(tiff_path):
            skipped_tile_ids.append(tile_id)
            skipped_reasons[tile_id] = "Missing tiff_path"
            continue
            
        if not os.path.exists(tiff_path):
            skipped_tile_ids.append(tile_id)
            skipped_reasons[tile_id] = f"TIFF file not found: {tiff_path}"
            continue
            
        # Create output paths
        output_img_path = os.path.join(output_img_dir, os.path.basename(tiff_path))
        
        # Handle mask path
        if pd.isna(mask_path):
            mask_path = None
            output_mask_path = None
            skipped_reasons[tile_id] = "Missing mask_path"
        elif not os.path.exists(mask_path):
            mask_path = None
            output_mask_path = None
            skipped_reasons[tile_id] = f"Mask file not found: {mask_path}"
        else:
            # Change extension if converting to GeoTIFF
            if convert_masks_to_geotiff:
                mask_basename = os.path.splitext(os.path.basename(mask_path))[0] + '.tif'
                output_mask_path = os.path.join(output_mask_dir, mask_basename)
            else:
                output_mask_path = os.path.join(output_mask_dir, os.path.basename(mask_path))
        
        try:
            # Get tile geometry
            tile_geom = row.geometry
            
            if tile_geom is None:
                skipped_tile_ids.append(tile_id)
                skipped_reasons[tile_id] = "Missing geometry"
                continue
                
            # Find intersecting buildings
            buildings_in_tile = gdf_buildings[gdf_buildings.intersects(tile_geom)]
            
            if len(buildings_in_tile) == 0:
                skipped_tile_ids.append(tile_id)
                skipped_reasons[tile_id] = "No intersecting buildings found"
                continue
            
            # Process GeoTIFF and mask together
            with rasterio.open(tiff_path) as src:
                # Get source metadata
                src_meta = src.meta.copy()
                original_height, original_width = src.height, src.width
                
                # Convert building geometries to GeoJSON format
                shapes = [mapping(geom) for geom in buildings_in_tile.geometry]
                
                # Create mask for clipping
                masked_data, mask_transform = rasterio.mask.mask(
                    src, 
                    shapes, 
                    crop=False, 
                    all_touched=True,
                    invert=True,
                    filled=True,
                    nodata=0
                )
                
                # Create binary mask
                binary_mask = (masked_data[0] == 0).astype(np.uint8)
                
                # Apply mask to all bands
                original_img = src.read()
                masked_img = original_img.copy()
                
                for i in range(masked_img.shape[0]):
                    masked_img[i][binary_mask == 0] = 0
                
                # Update metadata
                out_meta = src_meta.copy()
                
                # Write masked image
                with rasterio.open(output_img_path, 'w', **out_meta) as dest:
                    dest.write(masked_img)
                
                successful_img_count += 1
                
                # Process mask if it exists
                if mask_path is not None:
                    try:
                        # Open PNG mask
                        with Image.open(mask_path) as mask_img:
                            mask_array = np.array(mask_img)
                            
                            # Check dimension consistency
                            if mask_array.shape[:2] != (original_height, original_width):
                                tiff_dims = f"{original_width}x{original_height}"
                                mask_dims = f"{mask_array.shape[1]}x{mask_array.shape[0]}"
                                print(f"Warning: Mask dimensions don't match GeoTIFF for {os.path.basename(tiff_path)}")
                                print(f"  - GeoTIFF dimensions: {tiff_dims}")
                                print(f"  - Mask dimensions: {mask_dims}")
                                skipped_reasons[tile_id] = f"Mask dimensions don't match GeoTIFF: GeoTIFF={tiff_dims}, Mask={mask_dims}"
                                continue

                            # Apply same binary mask to ensure identical dimensions
                            if len(mask_array.shape) == 3:  # RGB or RGBA
                                for i in range(mask_array.shape[2]):
                                    mask_array[:, :, i] = mask_array[:, :, i] * binary_mask
                            else:  # Grayscale
                                mask_array = mask_array * binary_mask
                            
                            if convert_masks_to_geotiff:
                                # Create GeoTIFF metadata for mask
                                mask_meta = src_meta.copy()
                                
                                # Update metadata
                                if len(mask_array.shape) == 3:  # RGB or RGBA
                                    mask_meta.update(
                                        dtype=mask_array.dtype,
                                        count=mask_array.shape[2],
                                        nodata=0,
                                    )
                                else:  # Grayscale
                                    mask_meta.update(
                                        dtype=mask_array.dtype,
                                        count=1,
                                        nodata=0,
                                    )
                                
                                # Write mask as GeoTIFF
                                with rasterio.open(output_mask_path, 'w', **mask_meta) as dest:
                                    if len(mask_array.shape) == 3:  # RGB or RGBA
                                        for i in range(mask_array.shape[2]):
                                            dest.write(mask_array[:, :, i], i+1)
                                    else:  # Grayscale
                                        dest.write(mask_array, 1)
                            else:
                                # Save as PNG
                                Image.fromarray(mask_array).save(output_mask_path)
                                
                            successful_mask_count += 1
                            
                            # Mark as successfully processed
                            processed_tile_ids.append(tile_id)
                            
                    except Exception as e:
                        error_count += 1
                        skipped_tile_ids.append(tile_id)
                        skipped_reasons[tile_id] = f"Error processing mask: {str(e)}"
                        print(f"Error processing mask {mask_path}: {e}")
                else:
                    # No mask but image processed successfully
                    processed_tile_ids.append(tile_id)
                        
        except Exception as e:
            error_count += 1
            skipped_tile_ids.append(tile_id)
            skipped_reasons[tile_id] = f"Error: {str(e)}"
            print(f"Error processing {tiff_path}: {e}")
    
    # Remove duplicates
    processed_tile_ids = list(set(processed_tile_ids))
    skipped_tile_ids = list(set(skipped_tile_ids))
    
    # Check for overlap between processed and skipped
    overlap = set(processed_tile_ids) & set(skipped_tile_ids)
    
    print(f"Completed processing {len(gdf_dataset)} files:")
    print(f"- Successfully processed {successful_img_count} GeoTIFFs")
    print(f"- Successfully processed {successful_mask_count} masks")
    if convert_masks_to_geotiff:
        print(f"- Converted {successful_mask_count} PNG masks to GeoTIFF format")
    print(f"- Encountered {error_count} errors")
    print(f"- Successfully processed tiles: {len(processed_tile_ids)}")
    print(f"- Skipped tiles: {len(skipped_tile_ids)}")
    
    if overlap:
        print(f"Warning: {len(overlap)} tile_ids appear in both processed and skipped lists")
     
    # Analyze skipped reasons
    if skipped_tile_ids:
        # Group dimension mismatches
        dimension_mismatches = [reason for tile_id, reason in skipped_reasons.items() 
                            if "Mask dimensions don't match" in reason]
        
        print("\nDimension mismatch summary:")
        print(f"  - Total files with dimension mismatches: {len(dimension_mismatches)}")
        
        # Analyze dimension patterns
        if dimension_mismatches:
            import re
            geotiff_dims = []
            mask_dims = []
            pattern = r"GeoTIFF=(\d+x\d+), Mask=(\d+x\d+)"
            
            for reason in dimension_mismatches:
                match = re.search(pattern, reason)
                if match:
                    geotiff_dims.append(match.group(1))
                    mask_dims.append(match.group(2))
            
            # Count frequency
            from collections import Counter
            geotiff_counter = Counter(geotiff_dims)
            mask_counter = Counter(mask_dims)
            
            print("\nMost common GeoTIFF dimensions:")
            for dims, count in geotiff_counter.most_common(3):
                print(f"  - {dims}: {count} files")
                
            print("\nMost common mask dimensions:")
            for dims, count in mask_counter.most_common(3):
                print(f"  - {dims}: {count} files")
        
        # Convert to dictionary
        skipped_tiles = {tile_id: skipped_reasons[tile_id] for tile_id in skipped_tile_ids}
        
        return processed_tile_ids, skipped_tiles

# Execute clipping: crop every tile's GeoTIFF and mask with the building
# polygons, converting the masks to GeoTIFF on the way.
# NOTE(review): in clip_geotiff_and_png_masks the final `return` is indented
# under `if skipped_tile_ids:`, so when no tile is skipped the function returns
# None and the tuple unpacking below raises TypeError -- confirm and dedent it.
processed_tile_ids, skipped_tiles = clip_geotiff_and_png_masks(
    gdf_dataset,
    gdf_cad_batiment_horsol,
    DATASET_OUTPUT_IMG_PATH,
    DATASET_OUTPUT_MASKS_PATH,
    convert_masks_to_geotiff=True,
)

print(f"Processed tile_ids: {len(processed_tile_ids)}\nSkipped tile_ids: {len(skipped_tiles)}")


### Update Dataset with Processed File Paths

In [None]:
# Get list of processed files.
# os.listdir returns entries in arbitrary order. The image and mask columns
# below are paired purely by position, so both listings are sorted to give them
# a common order; unsorted, a tile's image could be paired with another tile's
# mask.
# NOTE(review): pairing by sorted file name assumes each folder uses one fixed
# filename prefix followed by the tile_id (as the tile_id extraction below
# implies for masks) -- confirm against the naming used by the clipping step.
dataset_processed_masks_path_list = sorted(
    os.path.join(DATASET_OUTPUT_MASKS_PATH, f) for f in os.listdir(DATASET_OUTPUT_MASKS_PATH)
)
dataset_processed_img_path_list = sorted(
    os.path.join(DATASET_OUTPUT_IMG_PATH, f) for f in os.listdir(DATASET_OUTPUT_IMG_PATH)
)

# Positional pairing only makes sense when there is exactly one mask per image.
if len(dataset_processed_masks_path_list) != len(dataset_processed_img_path_list):
    raise ValueError(
        f"Found {len(dataset_processed_img_path_list)} processed images but "
        f"{len(dataset_processed_masks_path_list)} processed masks; "
        "images and masks cannot be paired"
    )

df_processed = pd.DataFrame(
    {
        "tile_id": [os.path.basename(f).split(".")[0] for f in dataset_processed_masks_path_list],
        "processed_mask_path_tif": dataset_processed_masks_path_list,
        "processed_img_path_tif": dataset_processed_img_path_list,
    }
)
# Extract tile_id: drop the first two "_"-separated tokens of the mask file stem
df_processed["tile_id"] = df_processed["tile_id"].apply(lambda x: "_".join(x.split("_")[2:]))


In [None]:
# Validate processed data: each key column must be free of duplicates...
for column in ("tile_id", "processed_img_path_tif", "processed_mask_path_tif"):
    duplicated_rows = df_processed.duplicated(subset=[column])
    assert not duplicated_rows.any(), f"df_processed has duplicates in {column}: {df_processed[duplicated_rows]}"

# ...and free of null values.
for column in ("tile_id", "processed_img_path_tif", "processed_mask_path_tif"):
    null_rows = df_processed[column].isnull()
    assert not null_rows.any(), f"df_processed has null in {column}: {df_processed[null_rows]}"

display(df_processed.head())

In [None]:
# Merge processed paths onto the dataset. The left join keeps every tile, so
# tiles without processed files end up with missing values in the path columns.
gdf_dataset = gdf_dataset.merge(df_processed, on="tile_id", how="left")

### Handle Tile Overlaps
#### Detect Overlap Positions

In [None]:
def determine_relative_position(geom1, geom2, tolerance=0.5):
    """
    Determine the relative position of geom1 with respect to geom2.

    Accounts for Swiss coordinate system (EPSG:2056) where Y increases northward.
    Note: QGIS display orientation differs from coordinate values; the labels
    follow this code's convention, in which a geom1 lying north of geom2 is
    "bottom" and a geom1 lying east of geom2 is "left".

    Parameters:
        geom1: First geometry (needs ``bounds``, ``area`` and ``intersection``,
            e.g. a shapely geometry)
        geom2: Second geometry
        tolerance: Fraction of the average bounding-box width (height) within
            which a horizontal (vertical) centre offset is ignored

    Returns:
        str: "top", "bottom", "left", "right", a "<vertical>-<horizontal>"
        combination such as "bottom-left", or -- when the centres are aligned
        on both axes -- "substantial-overlap" if the intersection covers more
        than 90 % of the smaller geometry, otherwise "center".
    """
    minx1, miny1, maxx1, maxy1 = geom1.bounds
    minx2, miny2, maxx2, maxy2 = geom2.bounds

    # Offset of geom1's bounding-box centre from geom2's, per axis
    horizontal_diff = (minx1 + maxx1) / 2 - (minx2 + maxx2) / 2
    vertical_diff = (miny1 + maxy1) / 2 - (miny2 + maxy2) / 2

    # Offsets within tolerance * average box size count as "aligned"
    x_tolerance = tolerance * (((maxx1 - minx1) + (maxx2 - minx2)) / 2)
    y_tolerance = tolerance * (((maxy1 - miny1) + (maxy2 - miny2)) / 2)

    # Vertical component (Y-axis)
    if abs(vertical_diff) <= y_tolerance:
        vertical_position = None
    elif vertical_diff > 0:
        vertical_position = "bottom"  # geom1 north of geom2
    else:
        vertical_position = "top"     # geom1 south of geom2

    # Horizontal component (X-axis)
    if abs(horizontal_diff) <= x_tolerance:
        horizontal_position = None
    elif horizontal_diff > 0:
        horizontal_position = "left"   # geom1 east of geom2
    else:
        horizontal_position = "right"  # geom1 west of geom2

    if vertical_position or horizontal_position:
        return "-".join(p for p in (vertical_position, horizontal_position) if p)

    # Centres are aligned on both axes. Only this branch needs the overlap, so
    # the (comparatively expensive) intersection is computed here, lazily.
    smaller_area = min(geom1.area, geom2.area)
    if smaller_area > 0:
        overlap_percentage = (geom1.intersection(geom2).area / smaller_area) * 100
    else:
        overlap_percentage = 0
    return "substantial-overlap" if overlap_percentage > 90 else "center"


def get_opposite_position(position):
    """
    Return the mirror image of a relative-position label.

    Directional labels are swapped with their counterpart ("top" <-> "bottom",
    "top-left" <-> "bottom-right", ...). The non-directional labels "center"
    and "substantial-overlap", as well as any unknown value, come back unchanged.

    Parameters:
        position: Position label to invert

    Returns:
        str: The opposite label (or ``position`` itself when it has none)
    """
    counterpart_pairs = (
        ('top', 'bottom'),
        ('left', 'right'),
        ('top-left', 'bottom-right'),
        ('top-right', 'bottom-left'),
    )
    opposites = {'center': 'center', 'substantial-overlap': 'substantial-overlap'}
    for first, second in counterpart_pairs:
        opposites[first] = second
        opposites[second] = first
    return opposites.get(position, position)


def check_geotiffs_overlap(geom1, geom2, min_overlap_area=0.0):
    """
    Describe how two geometries overlap.

    Parameters:
        geom1: First geometry (needs ``intersects``, ``intersection`` and
            ``area``, e.g. a shapely geometry)
        geom2: Second geometry
        min_overlap_area: The intersection area must be strictly greater than
            this value for the pair to count as overlapping

    Returns:
        dict: ``overlaps`` (bool), ``overlap_area``, ``relative_position``
        (from ``determine_relative_position``; None when not overlapping), and
        ``overlap_percentage_1`` / ``overlap_percentage_2`` -- the intersection
        area as a percentage of geom1's / geom2's own area.
    """
    info = {
        'overlaps': False,
        'overlap_area': 0.0,
        'relative_position': None,
        'overlap_percentage_1': 0.0,
        'overlap_percentage_2': 0.0
    }

    # Disjoint geometries: report the defaults
    if not geom1.intersects(geom2):
        return info

    shared_area = geom1.intersection(geom2).area
    # Only a strictly larger area counts (a NaN area therefore does not)
    if not (shared_area > min_overlap_area):
        return info

    info['overlaps'] = True
    info['overlap_area'] = shared_area
    info['relative_position'] = determine_relative_position(geom1, geom2)
    info['overlap_percentage_1'] = shared_area / geom1.area * 100
    info['overlap_percentage_2'] = shared_area / geom2.area * 100
    return info


def check_overlaps_in_dataframe(gdf, min_overlap_area=1.0, include_symmetric=False, buffer_distance=0.01):
    """
    Check for overlapping geometries in a GeoDataFrame.

    Candidate pairs come from the GeoDataFrame's spatial index. Each pair is
    then tested with ``check_geotiffs_overlap`` on (optionally) buffered
    geometries, so tiles that merely touch or nearly touch can also be reported.

    Parameters:
        gdf: GeoDataFrame with 'geometry' and 'tile_id' columns
        min_overlap_area: Minimum overlap area to report (squared CRS units)
        include_symmetric: Also emit the mirrored row (tile2 -> tile1, with the
            opposite relative position and swapped percentages) for every pair
        buffer_distance: Buffer applied to both geometries before testing;
            0 or less disables buffering

    Returns:
        DataFrame: One row per reported pair with columns tile_id1, tile_id2,
        index1, index2 (positional indices into ``gdf``), overlap_area,
        relative_position, overlap_percentage_1, overlap_percentage_2 and
        buffered. Empty (with those columns) when nothing overlaps. Errors
        raised while scanning are printed and the pairs found so far are
        returned (best effort).

    Raises:
        TypeError: If ``gdf`` is not a GeoDataFrame.
    """
    result_columns = ['tile_id1', 'tile_id2', 'index1', 'index2',
                      'overlap_area', 'relative_position',
                      'overlap_percentage_1', 'overlap_percentage_2',
                      'buffered']

    # Validate before the best-effort try block below; raising inside it would
    # be swallowed and a wrong input reported as "No overlapping pairs found".
    if not isinstance(gdf, gpd.GeoDataFrame):
        raise TypeError("Input must be a GeoDataFrame")

    n = len(gdf)
    if n == 0:
        print("GeoDataFrame is empty")
        print("No overlapping pairs found")
        return pd.DataFrame(columns=result_columns)

    use_buffer = buffer_distance > 0
    # The spatial index holds unbuffered bounding boxes, but two buffered
    # geometries can intersect while their raw boxes are up to
    # 2 * buffer_distance apart -- widen the query box by that margin.
    search_margin = 2 * buffer_distance if use_buffer else 0.0
    overlap_results = []

    try:
        print("Creating spatial index...")
        sindex = gdf.sindex
        geometries = gdf['geometry']
        tile_ids = gdf['tile_id']

        print(f"Checking overlaps among {n} geometries...")
        for i in tqdm(range(n), desc="Checking overlaps"):
            geom1 = geometries.iloc[i]
            tile_id1 = tile_ids.iloc[i]

            if geom1 is None or not geom1.is_valid:
                print(f"Warning: Skipping invalid geometry for {tile_id1}")
                continue

            # Buffer geom1 once per outer iteration, not once per candidate
            test_geom1 = geom1.buffer(buffer_distance) if use_buffer else geom1

            minx, miny, maxx, maxy = geom1.bounds
            candidates = sindex.intersection((minx - search_margin, miny - search_margin,
                                              maxx + search_margin, maxy + search_margin))

            # Positions after i only: skips the self-match and visits each
            # unordered pair once; sorted for a deterministic output order.
            for j in sorted(int(c) for c in candidates if c > i):
                geom2 = geometries.iloc[j]
                tile_id2 = tile_ids.iloc[j]

                if geom2 is None or not geom2.is_valid:
                    print(f"Warning: Skipping invalid geometry for {tile_id2}")
                    continue

                test_geom2 = geom2.buffer(buffer_distance) if use_buffer else geom2

                # check_geotiffs_overlap applies both the intersects test and
                # the min_overlap_area threshold itself.
                result = check_geotiffs_overlap(test_geom1, test_geom2, min_overlap_area)
                if not result['overlaps']:
                    continue

                overlap_results.append({
                    'tile_id1': tile_id1,
                    'tile_id2': tile_id2,
                    'index1': i,
                    'index2': j,
                    'overlap_area': result['overlap_area'],
                    'relative_position': result['relative_position'],
                    'overlap_percentage_1': result['overlap_percentage_1'],
                    'overlap_percentage_2': result['overlap_percentage_2'],
                    'buffered': use_buffer
                })

                if include_symmetric:
                    overlap_results.append({
                        'tile_id1': tile_id2,
                        'tile_id2': tile_id1,
                        'index1': j,
                        'index2': i,
                        'overlap_area': result['overlap_area'],
                        'relative_position': get_opposite_position(result['relative_position']),
                        'overlap_percentage_1': result['overlap_percentage_2'],
                        'overlap_percentage_2': result['overlap_percentage_1'],
                        'buffered': use_buffer
                    })

    except Exception as e:
        # Best effort: report the failure and keep the pairs found so far
        print(f"Error during overlap check: {e}")
        traceback.print_exc()

    if not overlap_results:
        print("No overlapping pairs found")
        return pd.DataFrame(columns=result_columns)

    overlap_df = pd.DataFrame(overlap_results)
    print(f"Found {len(overlap_df)} overlapping pairs")

    # Summary by position
    position_counts = overlap_df['relative_position'].value_counts()
    print("\nOverlap positions before filtering:")
    for pos, count in position_counts.items():
        print(f"  {pos}: {count}")

    if 'buffered' in overlap_df.columns:
        buffered_count = overlap_df['buffered'].sum()
        print(f"\nOverlaps using buffered geometries: {buffered_count} ({(buffered_count/len(overlap_df))*100:.1f}%)")

    return overlap_df

# Detect overlapping tile pairs. With include_symmetric=True every pair is
# listed in both directions (A->B and B->A with the opposite position label).
overlap_df = check_overlaps_in_dataframe(
    gdf_dataset,
    min_overlap_area=1.0,
    include_symmetric=True,
    buffer_distance=BUFFER_DISTANCE,
)

# Show the distinct position labels that were detected
print("\nUnique position types found:")
for position_label in sorted(overlap_df['relative_position'].unique()):
    print(f"  {position_label}")


#### Remove Overlapping Regions

In [None]:
def _pixel_span(offset, length, limit, eps=1e-6):
    """
    Convert a fractional pixel interval into a clipped integer half-open range.

    Returns ``(start, stop)`` covering every pixel that the interval
    ``[offset, offset + length)`` touches, clipped to ``[0, limit)``. Each end
    is clipped on its own, so an interval that starts before the raster
    (negative ``offset``) is shortened rather than shifted onto valid pixels.
    ``eps`` (in pixels) absorbs floating-point noise so an edge at 14.9999999
    does not pull in an extra pixel. ``stop <= start`` means there is nothing
    to modify (e.g. the interval lies entirely outside the raster).
    """
    start = max(0, int(np.floor(offset + eps)))
    stop = min(limit, int(np.ceil(offset + length - eps)))
    return start, stop


def _choose_first_tile(position, geom1, geom2, row, index1, index2):
    """
    Return True when the pair's first tile (``index1``) should be blanked.

    A pair and its mirrored row (opposite position label, indices and
    percentages swapped, as produced with ``include_symmetric=True``) always
    select the same physical tile. Ties are broken on the positional index so
    that two equal-area tiles are never both blanked.
    """
    if 'right' in position and 'left' not in position:
        return True   # Modify left file
    if 'left' in position and 'right' not in position:
        return False  # Modify right file
    if 'top' in position and 'bottom' not in position:
        return False  # Modify bottom file
    if 'bottom' in position and 'top' not in position:
        return True   # Modify top file
    if 'center' in position or 'substantial' in position:
        # Modify the smaller tile (lower index on ties)
        return (geom1.area, index1) <= (geom2.area, index2)
    # Other labels: modify tile 1 when the overlap makes up no larger a share
    # of it than of tile 2 (lower index on ties)
    return (row['overlap_percentage_1'], index1) <= (row['overlap_percentage_2'], index2)


def _zero_raster_region(src_path, write_path, bounds):
    """
    Write a copy of ``src_path`` to ``write_path`` with ``bounds`` zeroed.

    ``bounds`` is ``(minx, miny, maxx, maxy)`` in the raster's CRS; all bands
    are set to zero over the pixels those bounds touch.

    Returns:
        tuple: ``(modified_pixels, error)``. ``error`` is None on success;
        otherwise it is a short reason string and nothing has been written.
    """
    with rasterio.open(src_path) as src:
        window = from_bounds(*bounds, src.transform)
        if any(np.isnan(v) for v in (window.col_off, window.row_off, window.width, window.height)):
            return 0, "Invalid window coordinates"

        col_start, col_stop = _pixel_span(window.col_off, window.width, src.width)
        row_start, row_stop = _pixel_span(window.row_off, window.height, src.height)
        if col_stop <= col_start or row_stop <= row_start:
            return 0, "Invalid window dimensions"

        data = src.read()
        data[:, row_start:row_stop, col_start:col_stop] = 0
        profile = src.profile

    with rasterio.open(write_path, 'w', **profile) as dst:
        dst.write(data)
    return (col_stop - col_start) * (row_stop - row_start), None


def remove_overlap_in_geotiffs(overlap_df, gdf_dataset, overlap_positions=None, overwrite=True, buffer_distance=0.01):
    """
    Remove overlapping regions by setting pixels to zero in one of the tiles.

    For overlapping tile pairs, determines which tile should have its overlap
    region set to background (zero) based on relative position. Both the
    image and the mask of the chosen tile are modified.

    Parameters:
        overlap_df: DataFrame with overlap information (index1/index2 are
            positional indices into ``gdf_dataset``)
        gdf_dataset: GeoDataFrame with file paths and geometries
        overlap_positions: Position labels to process. Matching is by
            substring, e.g. 'top' also selects 'top-left' and 'top-right'.
            Defaults to ['right', 'top', 'top-right', 'bottom-right'].
        overwrite: Whether to overwrite original files; otherwise results are
            written next to them with an ``overlap_fixed_`` prefix
        buffer_distance: Buffer distance for overlap calculation

    Returns:
        DataFrame: One row per successfully modified file (empty if none).
        Files whose path is missing (NaN/None, e.g. tiles without processed
        output after the earlier left merge) or that do not exist are
        reported as failures instead of aborting the run.
    """
    # Default positions to process
    if overlap_positions is None:
        overlap_positions = ['right', 'top', 'top-right', 'bottom-right']

    processed_files = []
    failed_files = []

    if len(overlap_df) == 0:
        print("No overlaps to process")
        return pd.DataFrame()

    # Filter by position (substring match)
    filtered_df = overlap_df.copy()
    if overlap_positions:
        position_filter = filtered_df['relative_position'].apply(
            lambda pos: any(p in pos for p in overlap_positions)
        )
        filtered_df = filtered_df[position_filter]
        print(f"Processing {len(filtered_df)} out of {len(overlap_df)} overlaps that match position criteria")

    if len(filtered_df) == 0:
        print("No overlaps match the specified positions")
        return pd.DataFrame()

    for _, row in tqdm(filtered_df.iterrows(), total=len(filtered_df), desc="Processing overlaps"):
        index1 = row['index1']
        index2 = row['index2']
        position = row['relative_position']

        record1 = gdf_dataset.iloc[index1]
        record2 = gdf_dataset.iloc[index2]
        tiff_path1 = record1['processed_img_path_tif']
        tiff_path2 = record2['processed_img_path_tif']
        mask_path1 = record1['processed_mask_path_tif']
        mask_path2 = record2['processed_mask_path_tif']
        geom1 = record1['geometry']
        geom2 = record2['geometry']

        # Use the same (buffered) geometries the overlap was detected with
        if row.get('buffered', True):
            buffered_geom1 = geom1.buffer(buffer_distance)
            buffered_geom2 = geom2.buffer(buffer_distance)
        else:
            buffered_geom1 = geom1
            buffered_geom2 = geom2

        intersection = buffered_geom1.intersection(buffered_geom2)
        if intersection.is_empty or intersection.area <= 0:
            failed_files.append({
                'file_path': f"{tiff_path1} / {tiff_path2}",
                'file_type': "both",
                'position': position,
                'error': "Empty intersection"
            })
            continue

        if _choose_first_tile(position, geom1, geom2, row, index1, index2):
            targets = {"image": tiff_path1, "mask": mask_path1}
            partners = {"image": tiff_path2, "mask": mask_path2}
            tile_id = record1['tile_id']
        else:
            targets = {"image": tiff_path2, "mask": mask_path2}
            partners = {"image": tiff_path1, "mask": mask_path1}
            tile_id = record2['tile_id']

        for file_type in ("image", "mask"):
            file_to_modify = targets[file_type]

            # A NaN path would make os.path.exists raise TypeError
            if pd.isna(file_to_modify):
                failed_files.append({
                    'file_path': file_to_modify,
                    'file_type': file_type,
                    'position': position,
                    'error': "File path missing"
                })
                continue
            if not os.path.exists(file_to_modify):
                failed_files.append({
                    'file_path': file_to_modify,
                    'file_type': file_type,
                    'position': position,
                    'error': "File does not exist"
                })
                continue

            directory, base_name = os.path.split(file_to_modify)
            if overwrite:
                # Write beside the original first, then swap it in
                write_path = os.path.join(directory, f"temp_{base_name}")
            else:
                write_path = os.path.join(directory, f"overlap_fixed_{base_name}")

            try:
                modified_pixels, error = _zero_raster_region(file_to_modify, write_path, intersection.bounds)
                if error is None and overwrite:
                    # Atomic on the same filesystem (unlike remove + move)
                    os.replace(write_path, file_to_modify)
            except Exception as e:
                error_msg = str(e)
                print(f"Error processing {file_type} file {file_to_modify}: {error_msg}")
                # Clean up temp file
                if overwrite and os.path.exists(write_path):
                    os.remove(write_path)
                failed_files.append({
                    'file_path': file_to_modify,
                    'file_type': file_type,
                    'position': position,
                    'error': error_msg
                })
                continue

            if error is not None:
                failed_files.append({
                    'file_path': file_to_modify,
                    'file_type': file_type,
                    'position': position,
                    'error': error
                })
                continue

            processed_files.append({
                'file_path': file_to_modify,
                'file_type': file_type,
                'position': position,
                'overlap_with': partners[file_type],
                'overlap_area': row['overlap_area'],
                'modified_pixels': modified_pixels,
                'tile_id': tile_id
            })

    # Create results dataframe
    if processed_files:
        results_df = pd.DataFrame(processed_files)
        print(f"Successfully processed {len(results_df)} files")

        if failed_files:
            failed_df = pd.DataFrame(failed_files)
            print(f"Failed to process {len(failed_df)} files")
            print("First few failures:")
            print(failed_df.head())

        return results_df

    if failed_files:
        failed_df = pd.DataFrame(failed_files)
        print(f"Failed to process all {len(failed_df)} files")
        print("First few failures:")
        print(failed_df.head())

    print("No files were processed successfully")
    return pd.DataFrame()

# Blank the overlapping region in one tile of every selected pair
results = remove_overlap_in_geotiffs(
    overlap_df,
    gdf_dataset,
    overlap_positions=OVERLAP_POSITIONS,
    buffer_distance=BUFFER_DISTANCE,
)

# Report what was modified; nothing is printed when no file was touched
if len(results):
    print("\nSummary of processed files:")
    print(f"Total modified files: {len(results)}")

    file_type_counts = results['file_type'].value_counts()
    print("\nFiles by type:")
    print(file_type_counts)

    print("\nSample of processed files:")
    display(results.head())

#### Verify Overlap Corrections

In [None]:
def _overlap_zero_percentage(tiff_path, bounds):
    """
    Measure the percentage of zero-valued pixels of band 1 inside ``bounds``.

    The bounds are mapped to a pixel window covering every pixel they touch,
    and the window's start and stop are each clipped to the raster extent.

    Parameters:
        tiff_path: Path to the GeoTIFF to inspect
        bounds: (minx, miny, maxx, maxy) in the raster's CRS

    Returns:
        tuple: (error_kind, zero_percentage). ``error_kind`` is None on
        success, 'invalid_window' when the bounds map to a NaN window, or
        'invalid_dimensions' when the window does not overlap the raster.
        ``zero_percentage`` is None whenever ``error_kind`` is set.
    """
    # Tolerance so float noise from the affine transform (e.g. 9.9999999)
    # does not pull in an extra row/column of pixels.
    eps = 1e-6
    minx, miny, maxx, maxy = bounds

    with rasterio.open(tiff_path) as src:
        window = from_bounds(minx, miny, maxx, maxy, src.transform)
        if np.isnan([window.col_off, window.row_off, window.width, window.height]).any():
            return 'invalid_window', None

        # Clip start and stop independently. Clamping only the offset (while
        # keeping the full width) over-reads when the bounds begin before the
        # raster origin, and truncating the offset before adding ceil(width)
        # can drop the last partially covered column/row.
        col_start = max(0, int(np.floor(window.col_off + eps)))
        row_start = max(0, int(np.floor(window.row_off + eps)))
        col_stop = min(src.width, int(np.ceil(window.col_off + window.width - eps)))
        row_stop = min(src.height, int(np.ceil(window.row_off + window.height - eps)))

        if col_stop <= col_start or row_stop <= row_start:
            return 'invalid_dimensions', None

        data = src.read(1, window=((row_start, row_stop), (col_start, col_stop)))

    return None, np.count_nonzero(data == 0) / data.size * 100


def _verify_overlap_pair(row, gdf_dataset, buffer_distance):
    """
    Check a single overlapping pair and build its verification record.

    Parameters:
        row: Row of the overlap DataFrame (reads index1, index2, tile_id1,
            tile_id2, relative_position, overlap_area and optional 'buffered')
        gdf_dataset: GeoDataFrame with 'geometry' and 'processed_img_path_tif';
            index1/index2 are used as positional (iloc) indices into it
        buffer_distance: Buffer applied to both geometries unless
            row['buffered'] is falsy (missing column counts as buffered)

    Returns:
        dict: Verification record. 'status' is 'both_have_zeros',
        'one_has_zeros' or 'no_zeros' when both files were read, otherwise the
        reason the pair was skipped ('invalid_geometry', 'empty_intersection',
        'invalid_window_fileN', 'invalid_dimensions_fileN' or 'error: ...').
    """
    result = {
        'tile_id1': row['tile_id1'],
        'tile_id2': row['tile_id2'],
        'position': row['relative_position'],
        'overlap_area': row['overlap_area'],
        'file1_has_zeros': False,
        'file2_has_zeros': False,
        'file1_zero_percentage': 0.0,
        'file2_zero_percentage': 0.0,
        'both_have_zeros': False,
        'either_has_zeros': False,
        'avg_zero_percentage': 0.0,
        'status': 'unchecked'
    }

    try:
        # Lookups live inside the try so a bad index is recorded for this pair
        # instead of aborting the whole verification run.
        tile1 = gdf_dataset.iloc[row['index1']]
        tile2 = gdf_dataset.iloc[row['index2']]
        geom1 = tile1['geometry']
        geom2 = tile2['geometry']

        if geom1 is None or not geom1.is_valid or geom2 is None or not geom2.is_valid:
            result['status'] = 'invalid_geometry'
            return result

        if row.get('buffered', True):
            geom1 = geom1.buffer(buffer_distance)
            geom2 = geom2.buffer(buffer_distance)

        intersection = geom1.intersection(geom2)
        if intersection.is_empty or intersection.area <= 0:
            result['status'] = 'empty_intersection'
            return result

        # Read the same overlap bounds from both rasters; stop at the first failure
        # (file1 statistics are kept if only file2 fails).
        for file_no, tile in ((1, tile1), (2, tile2)):
            error_kind, zero_pct = _overlap_zero_percentage(
                tile['processed_img_path_tif'], intersection.bounds
            )
            if error_kind is not None:
                result['status'] = f"{error_kind}_file{file_no}"
                return result
            result[f'file{file_no}_has_zeros'] = zero_pct > 0
            result[f'file{file_no}_zero_percentage'] = zero_pct

        result['both_have_zeros'] = result['file1_has_zeros'] and result['file2_has_zeros']
        result['either_has_zeros'] = result['file1_has_zeros'] or result['file2_has_zeros']
        result['avg_zero_percentage'] = (result['file1_zero_percentage'] + result['file2_zero_percentage']) / 2

        if result['both_have_zeros']:
            result['status'] = 'both_have_zeros'
        elif result['either_has_zeros']:
            result['status'] = 'one_has_zeros'
        else:
            result['status'] = 'no_zeros'

    except Exception as e:
        # Best-effort: record the failure for this pair and keep going.
        result['status'] = f"error: {str(e)}"

    return result


def verify_overlap_corrections(overlap_df, gdf_dataset, buffer_distance=0.01):
    """
    Verify that overlap corrections were applied correctly.

    Checks that overlapping regions have been set to background (zero)
    in at least one of the overlapping tiles.

    Parameters:
        overlap_df: DataFrame with overlap information (index1, index2,
            tile_id1, tile_id2, relative_position, overlap_area, optional buffered)
        gdf_dataset: GeoDataFrame with geometries and 'processed_img_path_tif'
        buffer_distance: Buffer distance used for overlaps

    Returns:
        DataFrame: Verification results for each overlap (empty if overlap_df is empty)
    """
    if len(overlap_df) == 0:
        print("No overlaps to verify")
        return pd.DataFrame()

    # Statuses of pairs whose overlap was actually read from both files;
    # every other status means the pair was skipped.
    checked_statuses = ['both_have_zeros', 'one_has_zeros', 'no_zeros']

    verification_results = []
    skipped_pairs = 0

    print(f"Verifying {len(overlap_df)} overlapping pairs...")
    with tqdm(total=len(overlap_df), desc="Verifying overlaps") as pbar:
        for _, row in overlap_df.iterrows():
            result = _verify_overlap_pair(row, gdf_dataset, buffer_distance)
            if result['status'] not in checked_statuses:
                skipped_pairs += 1
            verification_results.append(result)
            pbar.update(1)

    df_verification = pd.DataFrame(verification_results)

    # Summary statistics
    status_counts = df_verification['status'].value_counts()
    print("\nVerification results:")
    for status, count in status_counts.items():
        print(f"  {status}: {count} pairs ({count/len(df_verification)*100:.1f}%)")

    # Statistics restricted to pairs that were actually checked
    valid_df = df_verification[df_verification['status'].isin(checked_statuses)]

    if len(valid_df) > 0:
        both_zeros_count = valid_df['both_have_zeros'].sum()
        either_zeros_count = valid_df['either_has_zeros'].sum()

        print(f"\n  Pairs where both tiles have zeros in overlap: {both_zeros_count} ({both_zeros_count/len(valid_df)*100:.1f}%)")
        print(f"  Pairs where at least one tile has zeros in overlap: {either_zeros_count} ({either_zeros_count/len(valid_df)*100:.1f}%)")

        avg_zero_pct = valid_df['avg_zero_percentage'].mean()
        print(f"  Average percentage of zeros in overlap areas: {avg_zero_pct:.1f}%")

        # A pair with no background pixels in the overlap was not corrected
        failed_verification = valid_df[valid_df['status'] == 'no_zeros']
        if len(failed_verification) > 0:
            print(f"\nWARNING: {len(failed_verification)} pairs have no background pixels in overlap regions!")
            print("\nSample of problematic pairs:")
            display(failed_verification.head(5))

    print(f"\nSkipped {skipped_pairs} pairs due to geometry or window issues")

    return df_verification


In [None]:
# Run verification
df_verification = verify_overlap_corrections(overlap_df, gdf_dataset, buffer_distance=BUFFER_DISTANCE)

if len(df_verification) > 0:
    # Skipped pairs (invalid geometry, empty intersection, window problems,
    # read errors) keep the 0.0 placeholder percentages, so they would be
    # falsely listed as "low zero percentage". Keep only pairs whose overlap
    # was actually read from both files.
    checked_pairs = df_verification[
        df_verification['status'].isin(['both_have_zeros', 'one_has_zeros', 'no_zeros'])
    ]
    summary_columns = ['tile_id1', 'tile_id2', 'position', 'file1_zero_percentage', 'file2_zero_percentage', 'status']

    # Show low zero percentages
    low_zeros = checked_pairs[checked_pairs['avg_zero_percentage'] < 10]
    if len(low_zeros) > 0:
        print("\nPairs with low zero percentage (<10%):")
        display(low_zeros[summary_columns])

    # Show high zero percentages
    high_zeros = checked_pairs[(checked_pairs['file1_zero_percentage'] > 90) &
                               (checked_pairs['file2_zero_percentage'] > 90)]
    if len(high_zeros) > 0:
        print("\nPairs where both files have high zero percentage (>90%):")
        display(high_zeros[summary_columns])

In [None]:
def visualize_overlap_corrections(overlap_df, df_verification, gdf_dataset, dataset_output_checks_path, zero_threshold=99.9):
    """
    Create visualizations of overlap corrections including mask overlays.

    For every verified pair, saves a 2x3 diagnostic figure: both tiles with
    the overlap highlighted, a side-by-side crop of the overlap, and the same
    views with mask overlays. Also flags "mask conflicts", i.e. overlap pixels
    where both masks are positive.

    Parameters:
        overlap_df: DataFrame with overlap information (not used by this
            function; kept for interface compatibility)
        df_verification: DataFrame with verification results (one row per pair)
        gdf_dataset: GeoDataFrame with 'tile_id', 'processed_img_path_tif' and
            optionally 'processed_mask_path_tif'
        dataset_output_checks_path: Parent directory for the timestamped output folder
        zero_threshold: Minimum zero percentage (in both overlap crops) for the
            overlap to be labelled 'both_background'

    Returns:
        dict: Visualization statistics ('successful', 'failed', 'mask_issues',
        'output_dir', 'results_path'; the latter is None when df_verification is empty)
    """

    # Create output directory
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    output_dir = os.path.join(dataset_output_checks_path, f"overlap_check_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # Track statistics
    successful = 0
    failed = 0
    mask_issues = 0
    results_path = None

    def fmt_pct(value):
        """Format a percentage with one decimal, or 'N/A' when it is missing/non-numeric."""
        try:
            return f"{float(value):.1f}"
        except (TypeError, ValueError):
            return "N/A"

    def path_exists(path):
        """Return True when *path* is a real path (not None/NaN) that exists on disk."""
        return isinstance(path, (str, os.PathLike)) and os.path.exists(path)

    def get_safe_window_data(src, intersection_bounds):
        """
        Extract window data with consistent dimensions.

        Parameters:
            src: Rasterio dataset
            intersection_bounds: Bounds tuple (minx, miny, maxx, maxy)

        Returns:
            tuple: (data, window) where window is clamped to the raster extent
        """
        minx, miny, maxx, maxy = intersection_bounds

        # Get window
        window = from_bounds(minx, miny, maxx, maxy, src.transform)

        # Round and ensure bounds
        col_off = max(0, min(int(round(window.col_off)), src.width - 1))
        row_off = max(0, min(int(round(window.row_off)), src.height - 1))
        width = max(1, min(int(round(window.width)), src.width - col_off))
        height = max(1, min(int(round(window.height)), src.height - row_off))

        # Create safe window
        safe_window = Window(col_off, row_off, width, height)

        # Read data
        data = src.read(1, window=safe_window)

        return data, safe_window

    def visualize_pair(row, output_path):
        """
        Create visualization for a single overlap pair.

        Parameters:
            row: Series with overlap information; updated in place with
                content_status, zeros1, zeros2 and mask fields on success
            output_path: Path to save visualization

        Returns:
            tuple: (success, has_mask_issue). Every failure path returns
            (False, False) so the caller can always unpack the result.
        """
        # Resolved before the try so the error message can always name the pair.
        tile_id1 = row.get('tile_id1')
        tile_id2 = row.get('tile_id2')
        try:
            # Find indices
            idx1 = gdf_dataset[gdf_dataset['tile_id'] == tile_id1].index[0]
            idx2 = gdf_dataset[gdf_dataset['tile_id'] == tile_id2].index[0]

            # Get file paths
            tiff_path1 = gdf_dataset.loc[idx1, 'processed_img_path_tif']
            tiff_path2 = gdf_dataset.loc[idx2, 'processed_img_path_tif']

            # Check file existence
            if not path_exists(tiff_path1) or not path_exists(tiff_path2):
                print(f"Files not found for {tile_id1} and {tile_id2}")
                return False, False

            # Get mask paths
            has_masks = False
            mask_path1 = mask_path2 = None
            if 'processed_mask_path_tif' in gdf_dataset.columns:
                mask_path1 = gdf_dataset.loc[idx1, 'processed_mask_path_tif']
                mask_path2 = gdf_dataset.loc[idx2, 'processed_mask_path_tif']

                has_masks = path_exists(mask_path1) and path_exists(mask_path2)
                if not has_masks:
                    print(f"Warning: Mask files not found for {tile_id1} and/or {tile_id2}")
            else:
                print("Warning: 'processed_mask_path_tif' column not found")

            # Open and analyze images
            with rasterio.open(tiff_path1) as src1, rasterio.open(tiff_path2) as src2:
                bounds1 = src1.bounds
                bounds2 = src2.bounds

                # Intersection of the two raster extents (minx, miny, maxx, maxy)
                intersection = (
                    max(bounds1.left, bounds2.left),
                    max(bounds1.bottom, bounds2.bottom),
                    min(bounds1.right, bounds2.right),
                    min(bounds1.top, bounds2.top)
                )

                if intersection[2] <= intersection[0] or intersection[3] <= intersection[1]:
                    print(f"No valid intersection for {tile_id1} and {tile_id2}")
                    return False, False

                # Get window data safely
                data1, window1 = get_safe_window_data(src1, intersection)
                data2, window2 = get_safe_window_data(src2, intersection)

                # Read full images
                full_data1 = src1.read(1)
                full_data2 = src2.read(1)

                # Boolean masks marking the overlap window inside each full tile
                overlap_mask1 = np.zeros_like(full_data1, dtype=bool)
                overlap_mask1[window1.row_off:window1.row_off+window1.height,
                              window1.col_off:window1.col_off+window1.width] = True

                overlap_mask2 = np.zeros_like(full_data2, dtype=bool)
                overlap_mask2[window2.row_off:window2.row_off+window2.height,
                              window2.col_off:window2.col_off+window2.width] = True

                # Process masks if available
                has_mask_conflict = False
                mask_conflict_percentage = 0
                mask_data1 = None
                mask_data2 = None
                mask_overlap1 = None
                mask_overlap2 = None

                if has_masks:
                    try:
                        with rasterio.open(mask_path1) as mask_src1, rasterio.open(mask_path2) as mask_src2:
                            mask_data1 = mask_src1.read(1)
                            mask_data2 = mask_src2.read(1)

                            # Use same windows as images
                            mask_overlap1 = mask_src1.read(1, window=window1)
                            mask_overlap2 = mask_src2.read(1, window=window2)

                            # Crop both overlap crops (and image crops) to the common shape
                            if mask_overlap1.shape != mask_overlap2.shape:
                                print(f"Mask shape mismatch for {tile_id1} and {tile_id2}")

                                min_height = min(mask_overlap1.shape[0], mask_overlap2.shape[0])
                                min_width = min(mask_overlap1.shape[1], mask_overlap2.shape[1])

                                mask_overlap1 = mask_overlap1[:min_height, :min_width]
                                mask_overlap2 = mask_overlap2[:min_height, :min_width]

                                data1 = data1[:min_height, :min_width]
                                data2 = data2[:min_height, :min_width]

                            # A conflict is an overlap pixel labelled positive in both masks
                            if mask_overlap1.shape == mask_overlap2.shape and mask_overlap1.size > 0:
                                mask_conflict = np.logical_and(mask_overlap1 > 0, mask_overlap2 > 0)
                                has_mask_conflict = np.any(mask_conflict)
                                mask_conflict_percentage = np.sum(mask_conflict) / mask_conflict.size * 100
                            else:
                                has_mask_conflict = False
                                mask_conflict_percentage = 0

                    except Exception as e:
                        print(f"Error reading mask files: {str(e)}")
                        has_masks = False

                # Ensure consistent dimensions
                if data1.shape != data2.shape:
                    min_height = min(data1.shape[0], data2.shape[0])
                    min_width = min(data1.shape[1], data2.shape[1])
                    data1 = data1[:min_height, :min_width]
                    data2 = data2[:min_height, :min_width]

                fig, axs = plt.subplots(2, 3, figsize=(18, 12))
                try:
                    # Row 1: Image analysis

                    # Tile 1 with overlap highlighted in red
                    axs[0, 0].imshow(full_data1, cmap='gray')
                    highlighted1 = np.zeros((*full_data1.shape, 4))
                    highlighted1[..., 0] = 1  # Red
                    highlighted1[..., 3] = np.where(overlap_mask1, 0.4, 0)  # Alpha
                    axs[0, 0].imshow(highlighted1)
                    axs[0, 0].set_title(f"Tile {tile_id1}\nZero %: {fmt_pct(row.get('file1_zero_percentage'))}%")
                    axs[0, 0].axis('off')

                    # Tile 2 with overlap highlighted in blue
                    axs[0, 1].imshow(full_data2, cmap='gray')
                    highlighted2 = np.zeros((*full_data2.shape, 4))
                    highlighted2[..., 2] = 1  # Blue
                    highlighted2[..., 3] = np.where(overlap_mask2, 0.4, 0)  # Alpha
                    axs[0, 1].imshow(highlighted2)
                    axs[0, 1].set_title(f"Tile {tile_id2}\nZero %: {fmt_pct(row.get('file2_zero_percentage'))}%")
                    axs[0, 1].axis('off')

                    # Overlap comparison: both overlap crops side by side
                    if data1.size > 0 and data2.size > 0:
                        composite = np.zeros((data1.shape[0], data1.shape[1] * 2))
                        composite[:, :data1.shape[1]] = data1
                        composite[:, data1.shape[1]:] = data2

                        axs[0, 2].imshow(composite, cmap='gray')
                        axs[0, 2].axvline(x=data1.shape[1], color='r', linestyle='--')

                        zeros1 = np.sum(data1 == 0) / data1.size * 100
                        zeros2 = np.sum(data2 == 0) / data2.size * 100

                        if zeros1 >= zero_threshold and zeros2 >= zero_threshold:
                            content_status = "both_background"
                        else:
                            content_status = "partial_image"

                        axs[0, 2].set_title(f"Overlap Comparison\nPosition: {row.get('position', 'N/A')}, Status: {content_status}")

                        axs[0, 2].text(data1.shape[1] * 0.5, data1.shape[0] * 0.9,
                                       f"{zeros1:.1f}% zeros", ha='center', color='white',
                                       bbox=dict(facecolor='red', alpha=0.7))
                        axs[0, 2].text(data1.shape[1] * 1.5, data1.shape[0] * 0.9,
                                       f"{zeros2:.1f}% zeros", ha='center', color='white',
                                       bbox=dict(facecolor='blue', alpha=0.7))
                    else:
                        axs[0, 2].text(0.5, 0.5, "No overlap data available",
                                       ha='center', va='center', fontsize=12)
                        content_status = "no_data"
                        zeros1 = zeros2 = 0

                    axs[0, 2].axis('off')

                    # Row 2: Mask analysis

                    if has_masks and mask_data1 is not None and mask_data2 is not None:
                        # Tile 1 with mask overlay
                        axs[1, 0].imshow(full_data1, cmap='gray')
                        mask_overlay1 = np.zeros((*full_data1.shape, 4))
                        mask_overlay1[..., 0] = 1  # Red
                        mask_overlay1[..., 3] = np.where(mask_data1 > 0, 0.5, 0)
                        axs[1, 0].imshow(mask_overlay1)
                        axs[1, 0].set_title(f"Tile {tile_id1} with mask overlay")
                        axs[1, 0].axis('off')

                        # Tile 2 with mask overlay
                        axs[1, 1].imshow(full_data2, cmap='gray')
                        mask_overlay2 = np.zeros((*full_data2.shape, 4))
                        mask_overlay2[..., 2] = 1  # Blue
                        mask_overlay2[..., 3] = np.where(mask_data2 > 0, 0.5, 0)
                        axs[1, 1].imshow(mask_overlay2)
                        axs[1, 1].set_title(f"Tile {tile_id2} with mask overlay")
                        axs[1, 1].axis('off')

                        # Mask overlap comparison
                        if (mask_overlap1 is not None and mask_overlap2 is not None and
                                data1.size > 0 and data2.size > 0):

                            mask_composite = np.zeros((data1.shape[0], data1.shape[1] * 2, 4))

                            # Grey base image, each half normalised by its own maximum
                            for c in range(3):
                                if np.max(data1) > 0:
                                    mask_composite[:, :data1.shape[1], c] = data1 / np.max(data1)
                                if np.max(data2) > 0:
                                    mask_composite[:, data1.shape[1]:, c] = data2 / np.max(data2)
                            mask_composite[..., 3] = 1.0

                            if mask_overlap1.shape == data1.shape and mask_overlap2.shape == data2.shape:
                                mask_overlay_left = np.zeros((data1.shape[0], data1.shape[1], 4))
                                mask_overlay_left[..., 0] = 1.0  # Red
                                mask_overlay_left[..., 3] = np.where(mask_overlap1 > 0, 0.5, 0)

                                mask_overlay_right = np.zeros((data2.shape[0], data2.shape[1], 4))
                                mask_overlay_right[..., 2] = 1.0  # Blue
                                mask_overlay_right[..., 3] = np.where(mask_overlap2 > 0, 0.5, 0)

                                axs[1, 2].imshow(mask_composite)
                                axs[1, 2].imshow(np.pad(mask_overlay_left, ((0, 0), (0, data1.shape[1]), (0, 0)), 'constant'))
                                axs[1, 2].imshow(np.pad(mask_overlay_right, ((0, 0), (data1.shape[1], 0), (0, 0)), 'constant'))

                                axs[1, 2].axvline(x=data1.shape[1], color='yellow', linestyle='--')

                                if has_mask_conflict:
                                    title = f"Mask Overlap Comparison\nWarning: {mask_conflict_percentage:.1f}% mask conflict!"
                                else:
                                    title = "Mask Overlap Comparison\nNo mask conflicts"

                                axs[1, 2].set_title(title)
                                axs[1, 2].axis('off')
                            else:
                                axs[1, 2].text(0.5, 0.5, "Mask-image dimension mismatch",
                                               ha='center', va='center', fontsize=12)
                                axs[1, 2].axis('off')
                        else:
                            axs[1, 2].text(0.5, 0.5, "Unable to process mask overlaps",
                                           ha='center', va='center', fontsize=12)
                            axs[1, 2].axis('off')

                    else:
                        for i in range(3):
                            axs[1, i].text(0.5, 0.5, "No mask files found",
                                           ha='center', va='center', fontsize=12)
                            axs[1, i].axis('off')

                    fig.suptitle(f"Overlap Analysis: {tile_id1} and {tile_id2}",
                                 fontsize=16, y=0.98)

                    fig.tight_layout()
                    fig.subplots_adjust(top=0.92)
                    fig.savefig(output_path, dpi=150, bbox_inches='tight')
                finally:
                    # Always release the figure, even when plotting fails, so a
                    # long run over many pairs does not accumulate open figures.
                    plt.close(fig)

                # Record the analysis on the (caller-owned copy of the) row
                row['content_status'] = content_status
                row['zeros1'] = zeros1
                row['zeros2'] = zeros2

                if has_masks:
                    row['has_masks'] = True
                    row['has_mask_conflict'] = has_mask_conflict
                    row['mask_conflict_percentage'] = mask_conflict_percentage
                else:
                    row['has_masks'] = False

                return True, has_masks and has_mask_conflict

        except Exception as e:
            print(f"Error visualizing pair {tile_id1} and {tile_id2}: {str(e)}")
            traceback.print_exc()
            return False, False

    # Process verification results
    if len(df_verification) > 0:
        print(f"Processing {len(df_verification)} verified pairs...")

        # Collect successful rows and build the DataFrame once (avoids quadratic concat)
        analysed_rows = []

        for _, row in tqdm(df_verification.iterrows(), total=len(df_verification)):
            tile_id1 = row['tile_id1']
            tile_id2 = row['tile_id2']
            position = row.get('position', 'unknown')

            # visualize_pair annotates the row in place; work on a copy
            row_copy = row.copy()

            filename = f"{tile_id1}_{tile_id2}_{position}.png"
            output_path = os.path.join(output_dir, filename)

            success, has_mask_issue = visualize_pair(row_copy, output_path)
            if success:
                successful += 1
                if has_mask_issue:
                    mask_issues += 1
                analysed_rows.append(row_copy.to_dict())
            else:
                failed += 1

        results_df = pd.DataFrame(analysed_rows)

        results_path = os.path.join(output_dir, "overlap_analysis_results.csv")
        results_df.to_csv(results_path, index=False)
        print(f"Saved results to {results_path}")

        # mask_issues > 0 guarantees at least one row carries 'has_mask_conflict'
        if mask_issues > 0:
            mask_issues_df = results_df[results_df['has_mask_conflict'].eq(True)]
            mask_issues_path = os.path.join(output_dir, "mask_issues.csv")
            mask_issues_df.to_csv(mask_issues_path, index=False)
            print(f"Found {mask_issues} tile pairs with mask issues. Saved to {mask_issues_path}")

    print(f"Visualization complete. Created {successful} visualizations in {output_dir}")
    print(f"- Successful: {successful}")
    print(f"- Failed: {failed}")
    print(f"- Mask issues: {mask_issues}")

    return {
        "successful": successful,
        "failed": failed,
        "mask_issues": mask_issues,
        "output_dir": output_dir,
        "results_path": results_path
    }

In [None]:
# Render one diagnostic figure per verified pair, then report where the
# figures went and how many succeeded, failed, or showed mask conflicts.
visualization_results = visualize_overlap_corrections(
    overlap_df=overlap_df,
    df_verification=df_verification,
    gdf_dataset=gdf_dataset,
    dataset_output_checks_path=DATASET_OUTPUT_CHECKS_PATH,
)

print("\n".join(
    f"{label}: {visualization_results[key]}"
    for label, key in (
        ("Results saved to", "output_dir"),
        ("Successful visualizations", "successful"),
        ("Failed visualizations", "failed"),
        ("Tiles with mask issues", "mask_issues"),
    )
))

### Standardize Dimensions to 1280x1280

In [None]:
def _pad_to_canvas(data, target_size, start_x, start_y):
    """Return a zero-filled (bands, target_h, target_w) copy of *data* placed at (start_y, start_x)."""
    bands, height, width = data.shape
    canvas = np.zeros((bands, target_size[1], target_size[0]), dtype=data.dtype)
    canvas[:, start_y:start_y + height, start_x:start_x + width] = data
    return canvas


def _replace_raster(path, data, meta):
    """Write *data*/*meta* to a temp file next to *path*, then atomically replace *path*.

    The temporary name ends in '.tif.tmp' so an interrupted run never leaves a
    stray '.tif' behind for later directory scans; it is removed on failure.
    The GTiff driver comes from ``meta['driver']``, so the extension is irrelevant.
    """
    fd, tmp_path = tempfile.mkstemp(suffix='.tif.tmp', dir=os.path.dirname(path) or None)
    os.close(fd)
    try:
        with rasterio.open(tmp_path, 'w', **meta) as dst:
            dst.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def standardize_image_dimensions(img_dir, mask_dir, target_size=(1280, 1280), overwrite=True):
    """
    Pad images and masks to standard dimensions with consistent alignment.

    Centers each image within the target canvas (zero fill) and applies the
    identical pixel offset and georeferencing shift to its mask. Both rasters
    of a pair are read and validated before either file is rewritten, so a
    problem with the mask no longer leaves a padded image next to an
    unpadded mask. Images larger than the target are skipped.

    Parameters:
        img_dir: Directory containing '.tif'/'.tiff' images
        mask_dir: Directory containing '.tif'/'.tiff' masks, paired by basename
        target_size: Target dimensions (width, height); any 2-sequence works
        overwrite: Currently unused - files are always replaced in place
                   (kept for backward compatibility)

    Returns:
        list: One dict per padded pair with keys 'img_file', 'mask_file',
              'from_size' (width, height), 'to_size' (target_size as passed)
              and 'padding' (start_x, start_y, pad_width, pad_height)
    """
    # Unpack once so list/tuple targets compare identically below.
    target_w, target_h = target_size

    img_files = [f for f in os.listdir(img_dir) if f.endswith(('.tif', '.tiff'))]
    # Mask basename -> mask file name (a later duplicate basename wins).
    mask_map = {
        os.path.splitext(f)[0]: f
        for f in os.listdir(mask_dir)
        if f.endswith(('.tif', '.tiff'))
    }

    total_images = len(img_files)
    resized_pairs = 0
    errors = 0
    skipped = 0
    modified_files = []

    print(f"Processing {total_images} images to ensure {target_size[0]}x{target_size[1]} dimensions...")

    for img_filename in tqdm(img_files, desc="Standardizing images"):
        img_path = os.path.join(img_dir, img_filename)

        mask_filename = mask_map.get(os.path.splitext(img_filename)[0])
        if not mask_filename:
            print(f"Warning: No matching mask found for {img_filename}. Skipping.")
            skipped += 1
            continue
        mask_path = os.path.join(mask_dir, mask_filename)
        if not os.path.exists(mask_path):
            print(f"Warning: Mask file {mask_path} not found. Skipping pair.")
            skipped += 1
            continue

        try:
            # Read the image fully and close it before anything is rewritten.
            with rasterio.open(img_path) as src:
                height, width = src.height, src.width
                if (width, height) == (target_w, target_h):
                    continue
                # Checked on the raw sizes: a clamped padding value can never be negative.
                if width > target_w or height > target_h:
                    print(f"Warning: {img_filename} is larger than target size. Skipping pair.")
                    skipped += 1
                    continue
                img_data = src.read()
                img_meta = src.meta.copy()
                src_transform = src.transform

            pad_width = target_w - width
            pad_height = target_h - height
            start_x, start_y = pad_width // 2, pad_height // 2
            # Move the origin by the offset in pixel space; composing with a
            # translation is also correct for rotated/sheared transforms.
            new_transform = src_transform * rasterio.Affine.translation(-start_x, -start_y)

            with rasterio.open(mask_path) as mask_src:
                mask_height, mask_width = mask_src.height, mask_src.width
                mask_data = mask_src.read()
                mask_meta = mask_src.meta.copy()

            if (mask_width, mask_height) != (width, height):
                print(f"Warning: Dimension mismatch between {img_filename} and {mask_filename}")
            # Validate before writing so the pair is never left half-padded.
            if start_x + mask_width > target_w or start_y + mask_height > target_h:
                print(f"Warning: {mask_filename} does not fit in the padded canvas. Skipping pair.")
                skipped += 1
                continue

            canvas_meta = {'height': target_h, 'width': target_w, 'transform': new_transform}
            img_meta.update(canvas_meta)
            mask_meta.update(canvas_meta)  # mask shares the image's georeferencing

            _replace_raster(img_path, _pad_to_canvas(img_data, (target_w, target_h), start_x, start_y), img_meta)
            _replace_raster(mask_path, _pad_to_canvas(mask_data, (target_w, target_h), start_x, start_y), mask_meta)
        except Exception as e:
            errors += 1
            print(f"Error processing {img_filename}: {e}")
            continue

        modified_files.append({
            'img_file': img_path,
            'mask_file': mask_path,
            'from_size': (width, height),
            'to_size': target_size,
            'padding': (start_x, start_y, pad_width, pad_height)
        })
        resized_pairs += 1

    print("Standardization complete:")
    print(f"- Total images processed: {total_images}")
    print(f"- Image/mask pairs resized: {resized_pairs}")
    print(f"- Pairs skipped: {skipped}")
    print(f"- Errors encountered: {errors}")
    print(f"- Already at target size: {total_images - resized_pairs - skipped - errors}")

    return modified_files

# Pad every image/mask pair of the processed dataset to 1280x1280.
modified_files = standardize_image_dimensions(
    img_dir=DATASET_OUTPUT_IMG_PATH,
    mask_dir=DATASET_OUTPUT_MASKS_PATH,
    target_size=(1280, 1280),
)

# Report the outcome, echoing at most the first five padded pairs.
if not modified_files:
    print("\nNo files were modified.")
else:
    print(f"\nModified {len(modified_files)} file pairs. First 5 examples:")
    for i, file_info in enumerate(modified_files[:5]):
        print(
            f"{i+1}. {os.path.basename(file_info['img_file'])}: "
            f"{file_info['from_size']} -> {file_info['to_size']}"
        )

In [None]:
def _find_mask_filename(mask_dir, img_basename, extensions=('.tif', '.tiff', '.png', '.PNG')):
    """Return the first ``img_basename + ext`` that exists in *mask_dir*, or None."""
    for ext in extensions:
        candidate = img_basename + ext
        if os.path.exists(os.path.join(mask_dir, candidate)):
            return candidate
    return None


def _normalize_unit_range(arr):
    """Linearly rescale *arr* to [0, 1]; a constant array becomes all zeros."""
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        return (arr - lo) / (hi - lo)
    return np.zeros_like(arr)


def verify_padding(processed_img_dir, processed_mask_dir, output_dir, show_images=False, modified_files=None,
                   modified_sample_count=5, unmodified_sample_count=5, target_size=(1280, 1280)):
    """
    Verify padding consistency and create visualizations.

    For every '.tif'/'.tiff' image, finds the mask with the same basename
    (.tif, .tiff, .png or .PNG, in that order). It then checks that image and
    mask dimensions agree and that the image matches *target_size*. Finally it
    saves image/mask/overlay figures for random samples of modified and
    unmodified files.

    Parameters:
        processed_img_dir: Directory with processed images
        processed_mask_dir: Directory with processed masks
        output_dir: Output directory for visualizations (created if missing)
        show_images: Whether to display figures in addition to saving them
        modified_files: List of dicts returned by standardize_image_dimensions
        modified_sample_count: Number of modified samples to visualize
        unmodified_sample_count: Number of unmodified samples to visualize
        target_size: Expected (width, height) of every image after padding

    Returns:
        dict: Always a dict, even when no image is found. Keys:
              - 'total_files', 'missing_masks'
              - 'verified_files' (image and mask dimensions agree)
              - 'dimension_mismatches'
              - 'modified_files_count', 'unmodified_files_count'
              - 'wrong_target_size' (image differs from target_size)
              - 'errors' (pairs that could not be read)
    """
    target_w, target_h = target_size
    target_text = f"{target_w}×{target_h}"

    os.makedirs(output_dir, exist_ok=True)

    all_tiff_files = [f for f in os.listdir(processed_img_dir) if f.endswith(('.tif', '.tiff'))]

    # Index the standardization report by image file name.
    modified_info = {os.path.basename(info['img_file']): info for info in (modified_files or [])}
    modified_tiff_files = [f for f in all_tiff_files if f in modified_info]
    unmodified_tiff_files = [f for f in all_tiff_files if f not in modified_info]

    stats = {
        'total_files': len(all_tiff_files),
        'verified_files': 0,
        'dimension_mismatches': 0,
        'missing_masks': 0,
        'modified_files_count': len(modified_tiff_files),
        'unmodified_files_count': len(unmodified_tiff_files),
        'wrong_target_size': 0,
        'errors': 0,
    }

    if not all_tiff_files:
        print("No GeoTIFF files found for verification.")
        return stats  # callers iterate over the result, so never return None

    print(f"Found {len(modified_tiff_files)} modified files and {len(unmodified_tiff_files)} unmodified files")
    print(f"Verifying all {stats['total_files']} image-mask pairs...")

    for img_filename in tqdm(all_tiff_files, desc="Verifying files"):
        img_path = os.path.join(processed_img_dir, img_filename)
        img_basename = os.path.splitext(img_filename)[0]

        mask_filename = _find_mask_filename(processed_mask_dir, img_basename)
        if mask_filename is None:
            print(f"No matching mask found for {img_filename}.")
            stats['missing_masks'] += 1
            continue
        mask_path = os.path.join(processed_mask_dir, mask_filename)

        try:
            with rasterio.open(img_path) as src:
                geotiff_height, geotiff_width = src.height, src.width
            if mask_filename.lower().endswith(('.tif', '.tiff')):
                with rasterio.open(mask_path) as mask_src:
                    mask_height, mask_width = mask_src.height, mask_src.width
            else:
                with Image.open(mask_path) as mask_img:
                    mask_width, mask_height = mask_img.size
        except Exception as e:
            print(f"Error verifying {img_filename}: {e}")
            stats['errors'] += 1
            continue

        if (geotiff_height, geotiff_width) != (mask_height, mask_width):
            print(f"Dimension mismatch for {img_basename}: GeoTIFF {geotiff_width}x{geotiff_height}, "
                  f"Mask {mask_width}x{mask_height}")
            stats['dimension_mismatches'] += 1
        else:
            stats['verified_files'] += 1

        if (geotiff_width, geotiff_height) != (target_w, target_h):
            print(f"Unexpected size for {img_basename}: {geotiff_width}x{geotiff_height} "
                  f"(expected {target_w}x{target_h})")
            stats['wrong_target_size'] += 1

    print("\nVerification Summary:")
    print(f"- Total files checked: {stats['total_files']}")
    print(f"- Successfully verified pairs: {stats['verified_files']}")
    print(f"- Dimension mismatches: {stats['dimension_mismatches']}")
    print(f"- Missing masks: {stats['missing_masks']}")
    print(f"- Not at target size {target_w}x{target_h}: {stats['wrong_target_size']}")
    print(f"- Read errors: {stats['errors']}")

    def visualize_sample(img_filename, sample_type):
        """Save (and optionally show) image / mask / overlay panels for one pair."""
        img_basename = os.path.splitext(img_filename)[0]
        mask_filename = _find_mask_filename(processed_mask_dir, img_basename)
        if mask_filename is None:
            print(f"No mask found for {img_filename}")
            return False

        img_path = os.path.join(processed_img_dir, img_filename)
        mask_path = os.path.join(processed_mask_dir, mask_filename)

        fig = None
        try:
            from matplotlib.patches import Rectangle

            with rasterio.open(img_path) as src:
                geotiff_data = src.read(1)
            if mask_filename.lower().endswith(('.tif', '.tiff')):
                with rasterio.open(mask_path) as mask_src:
                    mask_data = mask_src.read(1)
            else:
                with Image.open(mask_path) as mask_img:
                    mask_data = np.array(mask_img)
                if mask_data.ndim == 3:
                    mask_data = mask_data[:, :, 0]

            padding_info = modified_info.get(img_filename)
            if padding_info is not None:
                original_size = padding_info['from_size']
                dimension_text = f"Original: {original_size[0]}×{original_size[1]} → Current: {target_text}"
            else:
                dimension_text = f"Original: {target_text} (no change needed)"

            # Red channel = image, blue channel = mask; overlap shows as purple.
            overlay = np.zeros((geotiff_data.shape[0], geotiff_data.shape[1], 3))
            overlay[:, :, 0] = _normalize_unit_range(geotiff_data)
            overlay[:, :, 2] = _normalize_unit_range(mask_data)

            fig, axes = plt.subplots(1, 3, figsize=(18, 6))
            panels = (
                (geotiff_data, f"GeoTIFF: {img_filename}", 'gray'),
                (mask_data, f"Mask: {mask_filename}", 'gray'),
                (overlay, "Overlay (purple shows alignment)", None),
            )
            for ax, (data, title, cmap) in zip(axes, panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

            # Outline the original (pre-padding) footprint on every panel.
            if padding_info is not None:
                start_x, start_y = padding_info['padding'][0], padding_info['padding'][1]
                width, height = padding_info['from_size']
                for ax in axes:
                    ax.add_patch(Rectangle((start_x, start_y), width, height, linewidth=2,
                                           edgecolor='yellow', facecolor='none', linestyle='--'))

            fig.suptitle(f"{sample_type} Sample: {img_basename}\n{dimension_text}", fontsize=16)
            fig.tight_layout()

            save_path = os.path.join(output_dir, f"{sample_type.lower()}_sample_{img_basename}.png")
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            if show_images:
                plt.show()

            print(f"Visualization for {img_basename} ({sample_type}):")
            print(f"- GeoTIFF dimensions: {geotiff_data.shape}")
            print(f"- Mask dimensions: {mask_data.shape}")
            print(f"- {dimension_text}")
            print(f"- Saved to: {save_path}")
            print("-" * 50)
            return True
        except Exception as e:
            print(f"Error visualizing {img_filename}: {e}")
            return False
        finally:
            # Release the figure in every case so repeated runs do not pile up open figures.
            if fig is not None:
                plt.close(fig)

    for label, candidates, requested in (
        ("Modified", modified_tiff_files, modified_sample_count),
        ("Unmodified", unmodified_tiff_files, unmodified_sample_count),
    ):
        if requested <= 0 or not candidates:
            continue
        n_samples = min(requested, len(candidates))
        print(f"\nGenerating visualizations for {n_samples} {label.lower()} samples...")
        created = sum(visualize_sample(f, label) for f in random.sample(candidates, n_samples))
        print(f"Successfully created {created} {label.lower()} sample visualizations")

    return stats

# Run verification: dimension checks for every pair plus sample figures.
verification_results = verify_padding(
    processed_img_dir=DATASET_OUTPUT_IMG_PATH,
    processed_mask_dir=DATASET_OUTPUT_MASKS_PATH,
    output_dir=DATASET_OUTPUT_CHECKS_PATH + "/padding_verification",
    modified_files=modified_files,
    modified_sample_count=5,
    unmodified_sample_count=3,
    show_images=False
)

print("\nFinal Verification Results:")
# Guard against a None/empty result so the report cell never crashes.
if not verification_results:
    print("   No verification results returned.")
else:
    for key, value in verification_results.items():
        print(f"   {key}: {value}")

## Final Validation

In [None]:
# Banner for the pre-verification pass that purges tiles whose image is entirely zero.
print("PRE-VERIFICATION CLEANUP: REMOVING 100% ZERO FILES", "=" * 60, sep="\n")

def _raster_is_all_zero(path):
    """Return True when every pixel of every band of the raster at *path* is zero."""
    with rasterio.open(path) as src:
        return not np.any(src.read())


def remove_zero_files_and_records(gdf_dataset, dataset_output_img_path, dataset_output_masks_path):
    """
    Remove image files that are 100% zeros and their records.

    A tile is removed (image file, mask file and record) in two cases:
    - every band of its image is entirely zero;
    - an existing image or mask file cannot be read.
    Masks that are entirely zero next to a non-zero image are kept.

    Parameters:
        gdf_dataset: GeoDataFrame with 'tile_id', 'processed_img_path_tif'
                     and 'processed_mask_path_tif' columns (not mutated)
        dataset_output_img_path: Image directory path (unused; file paths
                                 are taken from the dataframe columns)
        dataset_output_masks_path: Mask directory path (unused)

    Returns:
        GeoDataFrame: Copy without the removed records. The index is reset
        only when at least one record was removed.
    """
    print(f"Starting with {len(gdf_dataset)} records")

    # Work on a copy so the caller's dataframe is never mutated.
    gdf_cleaned = gdf_dataset.copy()

    zero_files_to_remove = []
    files_to_delete = []
    # Positional keep flags. Dropping by index label would also remove every
    # other row that happens to share a removed row's label.
    keep = np.ones(len(gdf_cleaned), dtype=bool)

    print("\n1. Scanning for 100% zero files...")

    for pos, (_, row) in enumerate(gdf_cleaned.iterrows()):
        img_path = row.get('processed_img_path_tif')
        mask_path = row.get('processed_mask_path_tif')
        tile_id = row['tile_id']

        img_exists = bool(pd.notna(img_path) and os.path.exists(img_path))
        mask_exists = bool(pd.notna(mask_path) and os.path.exists(mask_path))

        img_all_zero = False
        mask_all_zero = False
        read_error = False

        if img_exists:
            try:
                img_all_zero = _raster_is_all_zero(img_path)
            except Exception as e:
                print(f"   Error reading image {tile_id}: {e}")
                read_error = True
            else:
                if img_all_zero:
                    print(f"   Image is 100% zeros: {tile_id}")

        if mask_exists:
            try:
                mask_all_zero = _raster_is_all_zero(mask_path)
            except Exception as e:
                print(f"   Error reading mask {tile_id}: {e}")
                read_error = True
            else:
                if mask_all_zero:
                    print(f"   Mask is 100% zeros: {tile_id}")

        # Removal is driven by the image (or unreadable files) only.
        if img_all_zero or read_error:
            zero_files_to_remove.append(tile_id)
            keep[pos] = False

            if img_all_zero and mask_all_zero:
                print(f"   MARKED FOR REMOVAL: {tile_id} (both image and mask are 100% zeros)")
            elif img_all_zero:
                print(f"   MARKED FOR REMOVAL: {tile_id} (image is 100% zeros, removing both)")
            else:
                print(f"   MARKED FOR REMOVAL: {tile_id} (file read errors)")

            if img_exists:
                files_to_delete.append(img_path)
            if mask_exists:
                files_to_delete.append(mask_path)
        elif mask_all_zero:
            print(f"   KEEPING: {tile_id} (image has content, mask is 100% zeros - acceptable)")
        else:
            print(f"   KEEPING: {tile_id} (both image and mask have content)")

    print(f"\n2. Found {len(zero_files_to_remove)} records where image is 100% zeros")
    if zero_files_to_remove:
        print(f"   Records to remove: {zero_files_to_remove}")
        print("   Note: Removing pairs ONLY when image is 100% zeros (mask state doesn't matter)")

    print(f"\n3. Deleting {len(files_to_delete)} files from disk...")
    deleted_count = 0
    delete_errors = 0

    for file_path in files_to_delete:
        name = os.path.basename(file_path)
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
                deleted_count += 1
                print(f"   Deleted: {name}")
            else:
                print(f"   File already missing: {name}")
        except OSError as e:
            print(f"   Error deleting {name}: {e}")
            delete_errors += 1

    print(f"   Successfully deleted {deleted_count} files")
    if delete_errors > 0:
        print(f"   Failed to delete {delete_errors} files")

    records_removed = len(zero_files_to_remove)
    if records_removed:
        print(f"\n4. Removing {records_removed} records from dataframe...")
        gdf_cleaned = gdf_cleaned[keep].reset_index(drop=True)
        print(f"   Dataframe now has {len(gdf_cleaned)} records")
    else:
        print("\n4. No records to remove from dataframe")

    print("\n5. Cleanup summary:")
    print(f"   Original records: {len(gdf_dataset)}")
    print(f"   Records removed: {records_removed}")
    print(f"   Final records: {len(gdf_cleaned)}")
    print(f"   Files deleted: {deleted_count}")

    return gdf_cleaned

# Remove records with missing files
print("STEP 1: Filter out records with missing files")
gdf_dataset_filtered = gdf_dataset.copy()

# One positional flag per row: True when both processed files exist on disk.
# Filtering with this mask (instead of drop() by label) cannot remove extra
# rows that merely share an index label with a missing-file record.
has_both_files = []
for _, row in gdf_dataset_filtered.iterrows():
    img_path = row.get('processed_img_path_tif')
    mask_path = row.get('processed_mask_path_tif')

    img_exists = pd.notna(img_path) and os.path.exists(img_path)
    mask_exists = pd.notna(mask_path) and os.path.exists(mask_path)

    has_both_files.append(bool(img_exists and mask_exists))

has_both_files = np.array(has_both_files, dtype=bool)
missing_file_records = gdf_dataset_filtered.index[~has_both_files].tolist()

if missing_file_records:
    gdf_dataset_filtered = gdf_dataset_filtered[has_both_files]
    print(f"Removed {len(missing_file_records)} records with missing files")
else:
    print("No records with missing files found")

print(f"After filtering missing files: {len(gdf_dataset_filtered)} records")

# Remove zero content files
print("\nSTEP 2: Remove 100% zero files and records")
gdf_dataset_cleaned = remove_zero_files_and_records(
    gdf_dataset=gdf_dataset_filtered,
    dataset_output_img_path=DATASET_OUTPUT_IMG_PATH,
    dataset_output_masks_path=DATASET_OUTPUT_MASKS_PATH
)

# Run final verification
print("\nSTEP 3: Running final verification on cleaned dataset")
print(f"Records going into verification: {len(gdf_dataset_cleaned)}")

def final_verification_checks_no_zeros(gdf_dataset, dataset_output_img_path, dataset_output_masks_path,
                                       target_size=(1280, 1280), dimension_sample_size=10,
                                       random_state=None):
    """
    Perform final verification checks on the dataset.

    Validates file existence, raster dimensions (on a random sample),
    required columns, geometries, the on-disk directory contents and basic
    data integrity, then flags duplicate images (identical pixel content).
    The zero-content check is skipped because it is done in preprocessing.

    Parameters:
        gdf_dataset: (Geo)DataFrame to verify. It is copied, never mutated.
            Must contain 'tile_id', 'processed_img_path_tif' and
            'processed_mask_path_tif' columns.
        dataset_output_img_path: Path to the image directory.
        dataset_output_masks_path: Path to the mask directory.
        target_size: Expected (width, height) of every image and mask.
        dimension_sample_size: Number of random rows whose rasters are opened
            for the dimension check.
        random_state: Optional seed for the dimension sample. None uses
            numpy's global RNG, which is the original behaviour.

    Returns:
        A copy of the dataset with a 'validation_processing' column set to
        'ok' or 'ko'. Duplicate images and unreadable images are 'ko'.

    Raises:
        AssertionError: if any hard check fails (missing files, wrong
            dimensions, null required values, invalid geometries, missing
            directories or expected files, empty or oversized dataset).
    """
    import hashlib  # only needed for the duplicate-image digest

    print("RUNNING FINAL VERIFICATION CHECKS (NO ZERO CHECK)")
    print("=" * 50)

    # Work on a copy so the caller's frame is untouched.
    gdf_dataset = gdf_dataset.copy()
    gdf_dataset['validation_processing'] = 'ok'
    print(f"\n0. Initialized validation_processing column for {len(gdf_dataset)} records")

    # 1. File existence verification
    print("\n1. Verifying all referenced files exist...")

    missing_images = []
    missing_masks = []
    for idx, row in gdf_dataset.iterrows():
        img_path = row.get('processed_img_path_tif')
        if pd.notna(img_path) and not os.path.exists(img_path):
            missing_images.append(row['tile_id'])
            gdf_dataset.loc[idx, 'validation_processing'] = 'ko'

        mask_path = row.get('processed_mask_path_tif')
        if pd.notna(mask_path) and not os.path.exists(mask_path):
            missing_masks.append(row['tile_id'])
            gdf_dataset.loc[idx, 'validation_processing'] = 'ko'

    assert len(missing_images) == 0, f"Missing image files for tiles: {missing_images[:10]}"
    assert len(missing_masks) == 0, f"Missing mask files for tiles: {missing_masks[:10]}"
    print(f"   All {len(gdf_dataset)} image and mask files exist")

    # 2. Dimension consistency check (random sample, for performance)
    expected_size = tuple(target_size)
    width, height = expected_size
    print(f"\n2. Verifying file dimensions are {width}x{height}...")

    dimension_errors = []
    sample_size = min(dimension_sample_size, len(gdf_dataset))
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    sample_positions = rng.choice(len(gdf_dataset), sample_size, replace=False)

    for pos in sample_positions:
        # choice() yields POSITIONS. Convert to the index LABEL before using
        # .loc, otherwise a non-RangeIndex would flag the wrong row or append
        # a brand-new all-NaN row (setting with enlargement).
        label = gdf_dataset.index[pos]
        row = gdf_dataset.iloc[pos]
        tile_id = row['tile_id']

        for kind, path_col in (("Image", 'processed_img_path_tif'),
                               ("Mask", 'processed_mask_path_tif')):
            raster_path = row.get(path_col)
            if pd.isna(raster_path):
                continue
            with rasterio.open(raster_path) as src:
                if (src.width, src.height) != expected_size:
                    dimension_errors.append(f"{kind} {tile_id}: {src.width}x{src.height}")
                    gdf_dataset.loc[label, 'validation_processing'] = 'ko'

    assert len(dimension_errors) == 0, f"Dimension errors found: {dimension_errors}"
    print(f"   Sample check: All {sample_size} files have correct {width}x{height} dimensions")

    # 3. Column validation
    print("\n3. Validating critical columns...")

    required_columns = ['tile_id', 'processed_img_path_tif', 'processed_mask_path_tif']
    for col in required_columns:
        assert col in gdf_dataset.columns, f"Required column missing: {col}"
        null_count = gdf_dataset[col].isnull().sum()
        assert null_count == 0, f"Found {null_count} null values in required column: {col}"

    print("   All required columns present with no null values")

    # 4. Geometry validation (only when a geometry column is present)
    print("\n4. Validating geometries...")

    if 'geometry' in gdf_dataset.columns:
        null_geoms = gdf_dataset['geometry'].isnull().sum()
        assert null_geoms == 0, f"Found {null_geoms} null geometries"

        invalid_geoms = []
        for idx, geom, tile_id in zip(gdf_dataset.index, gdf_dataset['geometry'], gdf_dataset['tile_id']):
            if not geom.is_valid:
                invalid_geoms.append(tile_id)
                gdf_dataset.loc[idx, 'validation_processing'] = 'ko'

        assert len(invalid_geoms) == 0, f"Invalid geometries found for tiles: {invalid_geoms[:10]}"
        print(f"   All {len(gdf_dataset)} geometries are valid")

    # 5. Directory structure check
    print("\n5. Verifying directory structure...")

    assert os.path.exists(dataset_output_img_path), f"Image directory does not exist: {dataset_output_img_path}"
    assert os.path.exists(dataset_output_masks_path), f"Mask directory does not exist: {dataset_output_masks_path}"

    # Required columns are guaranteed present and non-null by step 3.
    expected_img_files = {os.path.basename(p) for p in gdf_dataset['processed_img_path_tif'] if pd.notna(p)}
    expected_mask_files = {os.path.basename(p) for p in gdf_dataset['processed_mask_path_tif'] if pd.notna(p)}

    tif_suffixes = ('.tif', '.tiff')
    actual_img_files = {f for f in os.listdir(dataset_output_img_path) if f.endswith(tif_suffixes)}
    actual_mask_files = {f for f in os.listdir(dataset_output_masks_path) if f.endswith(tif_suffixes)}

    missing_expected_imgs = expected_img_files - actual_img_files
    missing_expected_masks = expected_mask_files - actual_mask_files

    assert len(missing_expected_imgs) == 0, f"Expected image files missing: {list(missing_expected_imgs)[:5]}"
    assert len(missing_expected_masks) == 0, f"Expected mask files missing: {list(missing_expected_masks)[:5]}"

    print(f"   Directory structure correct: {len(expected_img_files)} expected images, {len(expected_mask_files)} expected masks")

    # 6. Data integrity check
    print("\n6. Checking basic data integrity...")

    empty_rows = gdf_dataset.isnull().all(axis=1).sum()
    assert empty_rows == 0, f"Found {empty_rows} completely empty rows"

    assert len(gdf_dataset) > 0, "Dataset is empty"

    # Guard against accidentally loading something far larger than expected.
    memory_mb = gdf_dataset.memory_usage(deep=True).sum() / 1024**2
    assert memory_mb < 1000, f"Dataset unusually large: {memory_mb:.1f} MB"

    print(f"   Dataset integrity OK: {len(gdf_dataset)} records, {memory_mb:.1f} MB")

    # 7. Zero content check (intentionally skipped)
    print("\n7. Zero content check: SKIPPED (already done in preprocessing)")

    # 8. Duplicate image detection
    print("\n8. Checking for duplicate images...")

    # Content digest -> [(index label, tile_id), ...] of images with identical
    # pixels. SHA-256 over dtype + shape + raw bytes is deterministic across
    # runs, unlike the per-process randomized built-in hash().
    rows_by_digest = {}
    for idx, row in gdf_dataset.iterrows():
        img_path = row.get('processed_img_path_tif')
        tile_id = row['tile_id']
        if not (pd.notna(img_path) and os.path.exists(img_path)):
            continue

        try:
            with rasterio.open(img_path) as src:
                img_data = src.read()
        except Exception as e:
            # Best effort: an unreadable image is flagged, not fatal.
            print(f"   Error processing image {tile_id}: {e}")
            gdf_dataset.loc[idx, 'validation_processing'] = 'ko'
            continue

        digest = hashlib.sha256()
        digest.update(str(img_data.dtype).encode())
        digest.update(str(img_data.shape).encode())
        digest.update(img_data.tobytes())
        rows_by_digest.setdefault(digest.hexdigest(), []).append((idx, tile_id))

    duplicate_groups = [members for members in rows_by_digest.values() if len(members) > 1]

    # Every member of a duplicate group is flagged, including the first
    # occurrence. Flag by index label directly so rows sharing a tile_id are
    # all marked.
    duplicate_count = 0
    for members in duplicate_groups:
        for idx, _ in members:
            gdf_dataset.loc[idx, 'validation_processing'] = 'ko'
        duplicate_count += len(members)

    print(f"   Found {len(duplicate_groups)} duplicate groups affecting {duplicate_count} files")
    if duplicate_groups:
        print(f"   Example duplicate group: {[tile_id for _, tile_id in duplicate_groups[0]]}")

    # 9. Validation summary
    print("\n9. Validation summary...")

    validation_counts = gdf_dataset['validation_processing'].value_counts()
    ok_count = validation_counts.get('ok', 0)
    ko_count = validation_counts.get('ko', 0)

    print(f"   Records marked 'ok': {ok_count}")
    print(f"   Records marked 'ko': {ko_count}")
    print(f"   Success rate: {ok_count/len(gdf_dataset)*100:.1f}%")

    print("\n" + "=" * 50)
    # Only claim success when no record was flagged.
    if ko_count == 0:
        print("ALL VERIFICATION CHECKS PASSED")
    else:
        print(f"VERIFICATION COMPLETED WITH {ko_count} RECORDS MARKED 'ko'")
    print(f"Dataset validated: {ok_count} OK, {ko_count} KO")
    if ko_count == 0:
        print("Dataset is ready for saving")
    else:
        print("Review or remove 'ko' records before saving")
    print("=" * 50)

    return gdf_dataset

# Run the final verification on the cleaned dataset and summarize the result.
gdf_dataset_final = final_verification_checks_no_zeros(
    gdf_dataset=gdf_dataset_cleaned,
    dataset_output_img_path=DATASET_OUTPUT_IMG_PATH,
    dataset_output_masks_path=DATASET_OUTPUT_MASKS_PATH,
)

_summary_rule = "=" * 60
print("\n" + _summary_rule)
print("FINAL RESULTS AFTER CLEANUP AND VERIFICATION")
print(_summary_rule)

_final_status = gdf_dataset_final['validation_processing']
print("Final validation counts:")
print(_final_status.value_counts())
print(f"\nDataset ready for saving: {len(gdf_dataset_final)} total records")
print(f"High-quality records: {(_final_status == 'ok').sum()}")

# Stop the notebook here if any record failed validation.
assert (_final_status == 'ko').sum() == 0, "There are still 'ko' records in the dataset"

## Save Results

In [None]:
# Save final datasets as parquet files.
# to_parquet() does not create missing parent directories, so ensure each
# output directory exists before writing.
_parquet_outputs = [
    (df_verification, VERIFICATION_OUTPUT_PARQUET_PATH),
    (gdf_dataset, DATASET_OUTPUT_PARQUET_PATH),
    (gdf_dataset_final, DATASET_FINAL_OUTPUT_PARQUET_PATH),
]
for _frame, _out_path in _parquet_outputs:
    _out_dir = os.path.dirname(_out_path)
    if _out_dir:
        os.makedirs(_out_dir, exist_ok=True)
    _frame.to_parquet(_out_path, index=False)