In [None]:
# Essential imports. No blanket try/except here: a swallowed ImportError used to
# skip every import after the failing line and surface later as confusing
# NameErrors in unrelated cells. Import order is kept as-is on purpose (some
# native geo/ML libraries are sensitive to load order).
import re
import os
import sys
import json
import glob
import torch
import torch.nn as nn
import logging
import rasterio
import numpy as np
import pandas as pd
import seaborn as sns
import geopandas as gpd
import matplotlib.pyplot as plt
from PIL import Image
from datetime import datetime
from pathlib import Path
from scipy import stats
from typing import Dict, Tuple, Optional
from collections import defaultdict
from rasterio.mask import mask
from shapely.ops import unary_union
from shapely.wkt import dumps, loads
from shapely.geometry import mapping, box, Polygon, MultiPolygon
from rasterio.windows import from_bounds

# s2cloudless is only needed for cloud masking; keep the rest of the notebook
# usable when it is not installed.
try:
    from s2cloudless import S2PixelCloudDetector
except ImportError as e:
    S2PixelCloudDetector = None
    print(f"Error : {e}")

# tqdm was imported twice (plain, then notebook); the second import won.
# tqdm.auto selects the notebook widget under Jupyter and falls back to the
# console bar elsewhere.
from tqdm.auto import tqdm
from sklearn.model_selection import TimeSeriesSplit

In [None]:
# Build the project folder layout; mkdir(exist_ok=True) makes this cell safe to re-run.
project_dir = Path("../Solutions/Land_Change_Monitoring")
subdirs = [
    "../Datasets/Sentinel-2/",        # Original GeoJSON and Sentinel-2 data
    "../Datasets/Testing/Processed",  # Processed and grouped events
    "../Datasets/Testing/Samples",    # Sampled datasets
    "../Datasets/Testing/Tiles",      # Generated image tiles
    "../Docs/Diagrams",               # Results and visualizations
    "../Models",                      # Trained models
    "../Docs/Logs",                   # Processing logs
]

for folder in subdirs:
    Path(folder).mkdir(parents=True, exist_ok=True)

# Send INFO-level log records to a file inside the logs folder created above
logging.basicConfig(
    filename=Path("../Docs/Logs") / "processing.log",
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Report which folders exist now (✓) or are still missing (✗)
print("Created directory structure:")
for folder in subdirs:
    status = "✓" if Path(folder).exists() else "✗"
    print(f"{status} {folder}")

In [None]:
# Print the PyTorch version
print(f"PyTorch version: {torch.__version__}")

# Detect Google Colab. get_ipython() only exists under an IPython kernel, so it is
# guarded instead of raising NameError in a plain interpreter; an already-imported
# google.colab module is accepted as a fallback signal.
try:
    _shell_repr = str(get_ipython())  # noqa: F821 - injected by IPython
except NameError:
    _shell_repr = ""
in_colab = "google.colab" in _shell_repr or "google.colab" in sys.modules

if in_colab:
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = 'cpu'
        print("GPU not available in Colab, consider enabling a GPU runtime.")
# Running on a local machine: prefer Apple MPS, then CUDA, then CPU
else:
    if torch.backends.mps.is_available():
        device = 'mps'
        print(f"Is Apple MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
        print(f"Is Apple MPS available? {torch.backends.mps.is_available()}")
    elif torch.cuda.is_available():
        # ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API,
        # so this branch also covers AMD GPUs.
        device = 'cuda'
    else:
        device = 'cpu'

# Print the device being used
print(f"Using device: {device}")

In [None]:
# Input locations: Sentinel-2 products and the deforestation polygons.
# The .SAFE product folders are sorted so that positional access further down
# (e.g. safe_dirs[0]) does not depend on file-system listing order.
base_path = Path('../Datasets/Sentinel-2')
geojson_path = Path('../Datasets/BoundingBox/deforestation.geojson')
safe_dirs = sorted(base_path.glob("*/*.SAFE"))

In [None]:
# Load the GeoJSON file
gdf = gpd.read_file(geojson_path)

# Display GeoJSON information
print("GeoJSON Information:", len(gdf))
print("-" * 50)
# DataFrame.info() writes its report to stdout itself and returns None; wrapping
# it in print() emitted an extra "None" line.
gdf.info()
print("\nFirst few records:")
print(gdf.head())

# Get the bounding box coordinates
bbox = gdf.total_bounds
print("\nBounding Box (minx, miny, maxx, maxy):")
print(bbox)

# Display basic information about the GeoJSON
print("\nGeoJSON CRS:", gdf.crs)
print("Number of features:", len(gdf))
print("Columns:", gdf.columns.tolist())
# Guard against an empty layer or a missing first geometry
first_geometry = gdf.geometry.iloc[0] if len(gdf) else None
print("First geometry type:", first_geometry.geom_type if first_geometry is not None else None)

In [None]:
# Normalise 'img_date' to datetimes (when present) and report events per day
if 'img_date' not in gdf.columns:
    print("'img_date' column not found in the GeoDataFrame.")
else:
    # Unparseable values become NaT; those rows are removed
    gdf['img_date'] = pd.to_datetime(gdf['img_date'], errors='coerce')
    gdf = gdf.dropna(subset=['img_date'])

    # Calendar days (time of day dropped), used for both summaries below
    img_days = gdf['img_date'].dt.date

    # Sorted array of the distinct days
    unique_dates = np.sort(img_days.unique())

    date_counts = img_days.value_counts().sort_index()
    print("Occurrences of each 'img_date':")
    print(date_counts)

In [None]:
def reduce_precision(geometry, decimal_places=5):
    """Return a copy of ``geometry`` with coordinates rounded to ``decimal_places``.

    The rounding is a WKT round-trip: the shape is serialised with a fixed
    number of decimals and parsed back. Geometries that differ only below that
    precision can therefore become exactly equal.
    """
    rounded_wkt = dumps(geometry, rounding_precision=decimal_places)
    return loads(rounded_wkt)


# Round every geometry with the default precision (5 decimal places)
gdf['geometry'] = gdf['geometry'].apply(reduce_precision)

# Report exact duplicates after rounding (keep=False flags every copy)
duplicate_geometries = gdf[gdf.geometry.duplicated(keep=False)]
print("Duplicate geometries after reducing precision:")
print(duplicate_geometries)

In [None]:
# Duplicate removal is currently disabled; all geometries are kept.
# To drop exact duplicates and persist the cleaned layer, run:
#   gdf = gdf.drop_duplicates(subset='geometry')
#   gdf.to_file("deforestation_unique.geojson", driver='GeoJSON')

print("Number of geometries present in gdf: " + str(len(gdf)))

In [None]:
# Map of all deforestation polygons (GeoDataFrame.plot returns its Axes)
ax = gdf.plot()
ax.set_title('Deforestation Areas (Ukraine)')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
plt.show()

In [None]:
# Analyze temporal distribution (requires the datetime 'img_date' column)
gdf['year'] = gdf['img_date'].dt.year
gdf['month'] = gdf['img_date'].dt.month

# Year x month table of event counts
temporal_dist = gdf.groupby(['year', 'month']).size().unstack(fill_value=0)
print("Deforestation events by year and month:")
print(temporal_dist)

# Distribution by tile
print("\nDeforestation events by tile:")
print(gdf['tile'].value_counts())

# Monthly summary plot: months on the x-axis, one stacked segment per year.
# The axes are passed explicitly because DataFrame.plot() without ax= opens a
# new figure, which left the sized figure from plt.figure() blank.
fig, ax = plt.subplots(figsize=(12, 6))
temporal_dist.T.plot(kind='bar', stacked=True, ax=ax)
ax.set_title('Deforestation Events by Month and Year')
ax.set_xlabel('Month')
ax.set_ylabel('Number of Events')
ax.legend(title='Year')
fig.tight_layout()
plt.show()

In [None]:
# Area (m^2) and perimeter (m) of each polygon. Project to UTM zone 36 once and
# reuse the result instead of reprojecting the whole GeoSeries twice.
utm_geometry = gdf.geometry.to_crs({'proj': 'utm', 'zone': 36, 'ellps': 'WGS84'})
gdf['area'] = utm_geometry.area
gdf['perimeter'] = utm_geometry.length

# Basic statistics of polygon sizes
print("Polygon area statistics (square meters):")
print(gdf['area'].describe())

# Histogram of polygon sizes (log-scaled counts for readability)
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(gdf['area'], bins=50, edgecolor='black')
ax.set_title('Distribution of Deforestation Polygon Sizes')
ax.set_xlabel('Area (square meters)')
ax.set_ylabel('Count')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.show()

# Shape complexity: perimeter / (4 * sqrt(area)) is 1.0 for a square and grows for
# elongated or ragged outlines (zero-area polygons yield inf, or NaN if the
# perimeter is also zero).
gdf['complexity'] = gdf['perimeter'] / (4 * np.sqrt(gdf['area']))

print("Shape complexity statistics (1.0 = perfect square):")
print(gdf['complexity'].describe())

In [None]:
# Reproject to UTM zone 36N (EPSG:32636) so lengths are metres and areas m^2
deforestation_data_utm = gdf.to_crs(epsg=32636)

# Per-polygon areas (m^2) and bounding boxes (minx/miny/maxx/maxy columns)
polygon_areas = deforestation_data_utm.geometry.area
polygon_bounds = deforestation_data_utm.geometry.bounds

# Largest bounding-box extents plus mean/median area
max_width = (polygon_bounds['maxx'] - polygon_bounds['minx']).max()
max_height = (polygon_bounds['maxy'] - polygon_bounds['miny']).max()
mean_area, median_area = polygon_areas.mean(), polygon_areas.median()

print("Polygon Statistics:")
print(f"Mean area: {mean_area/10000:.2f} hectares")
print(f"Median area: {median_area/10000:.2f} hectares")
print(f"Max width: {max_width:.2f} meters")
print(f"Max height: {max_height:.2f} meters")

# Equal-width histogram of areas (20 bins); empty bins are not reported.
# 10000 m^2 = 1 hectare.
hist, bins = np.histogram(polygon_areas, bins=20)
print("\nArea Distribution (hectares):")
for bin_count, bin_lo, bin_hi in zip(hist, bins[:-1], bins[1:]):
    if bin_count > 0:
        print(f"{bin_lo/10000:.2f} - {bin_hi/10000:.2f}: {bin_count} polygons")

# Additional useful statistics
total_polygons = len(deforestation_data_utm)
print("\nAdditional Statistics:")
print(f"Total number of polygons: {total_polygons}")
print(f"Total area of deforestation: {polygon_areas.sum()/10000:.2f} hectares")
print(f"95th percentile area: {np.percentile(polygon_areas, 95)/10000:.2f} hectares")

In [None]:
def collect_jp2_files(base_dir):
    """Recursively collect every ``.jp2`` file below ``base_dir``.

    Args:
        base_dir (str | os.PathLike): Root directory to walk. A missing
            directory yields an empty list.

    Returns:
        list[str]: Paths built with ``os.path.join``, sorted lexicographically.
        ``os.walk`` yields entries in file-system order, and callers index the
        result positionally (e.g. ``jp2_files[70]``), so sorting keeps that
        reproducible. The suffix match is case-sensitive.
    """
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(base_dir)
        for name in names
        if name.endswith(".jp2")
    )


jp2_files = collect_jp2_files(base_path)

# Summary line first, then one path per line
print(f"Total files found: {len(jp2_files)}\n")

print("Found .jp2 files:")
if jp2_files:
    print("\n".join(jp2_files))

In [None]:
def data_parser(file_path, verbose=True):
    """Parse Sentinel-2 product metadata from the path of a band ``.jp2`` file.

    The product name is the first path component starting with ``'S2'``, e.g.
    ``S2A_MSIL1C_20180801T083601_N0206_R064_T36UYA_20180801T105140``. A trailing
    ``.SAFE`` is stripped, because otherwise it leaked into ``product_time``
    whenever the .SAFE folder was the first such component. The band is taken
    from the file name.

    Args:
        file_path (str | os.PathLike): Path to a band file inside a product.
        verbose (bool): Pretty-print the parsed fields (default True, as before).

    Returns:
        dict: mission, product_level, sensing_date, sensing_time,
        processing_number, orbit_number, tile_number (without the leading 'T'),
        product_date, product_time, band (e.g. '02', '8A') and file_path (the
        argument as given).

    Raises:
        ValueError: if no path component starts with 'S2', or the product name
            has fewer than 7 underscore-separated fields.
    """
    # Path.parts accepts either separator on Windows, unlike str.split(os.sep)
    path_parts = Path(file_path).parts

    product_name = next((part for part in path_parts if part.startswith('S2')), None)
    if product_name is None:
        raise ValueError(f"No Sentinel-2 product directory (S2*) in path: {file_path}")
    if product_name.endswith('.SAFE'):
        product_name = product_name[:-len('.SAFE')]

    parts = product_name.split('_')
    if len(parts) < 7:
        raise ValueError(f"Unexpected Sentinel-2 product name: {product_name}")

    # L1C names end in '_B02.jp2', L2A names in '_B02_10m.jp2'; anything else
    # (e.g. TCI) falls back to the last '_' token with 'B' removed.
    filename = path_parts[-1]
    band_match = re.search(r'_B(\d{2}|8A)(?:_\d+m)?\.jp2$', filename)
    if band_match:
        band = band_match.group(1)
    else:
        band = filename.split('_')[-1].replace('.jp2', '').replace('B', '')

    details = {
        'mission': parts[0],            # S2A or S2B satellite
        'product_level': parts[1],      # Processing level (MSIL1C)
        'sensing_date': parts[2][:8],   # YYYYMMDD
        'sensing_time': parts[2][9:],   # HHMMSS
        'processing_number': parts[3],  # Processing baseline number
        'orbit_number': parts[4],       # Relative orbit number
        'tile_number': parts[5][1:],    # Tile identifier without the 'T'
        'product_date': parts[6][:8],   # Product generation date
        'product_time': parts[6][9:],   # Product generation time
        'band': band,                   # Band number
        'file_path': file_path          # Path exactly as passed in
    }

    if verbose:
        labels = [
            ('Mission', 'mission'),
            ('Product Level', 'product_level'),
            ('Sensing Date', 'sensing_date'),
            ('Sensing Time', 'sensing_time'),
            ('Processing Number', 'processing_number'),
            ('Orbit Number', 'orbit_number'),
            ('Tile Number', 'tile_number'),
            ('Product Date', 'product_date'),
            ('Product Time', 'product_time'),
            ('Band Number', 'band'),
            ('File Path', 'file_path'),
        ]
        for label, key in labels:
            print(f"{label:<18}: {details[key]}")

    return details

# Example usage: parse one band file. Index 70 is an arbitrary pick; clamp it so
# the cell still works on smaller downloads instead of raising IndexError.
if jp2_files:
    file_path = jp2_files[min(70, len(jp2_files) - 1)]
    result = data_parser(file_path)
else:
    print("No .jp2 files found - skipping data_parser example.")

In [None]:
def analyze_sentinel_structure(base_path):
    """
    Analyzes Sentinel-2 dataset structure and returns key information

    Scans ``<base_path>/*/*.SAFE``. Satellite, acquisition date and tile are
    parsed from the name of each .SAFE folder's *parent* directory; band files
    are counted under ``GRANULE/*/IMG_DATA`` (TCI excluded).

    Args:
        base_path (str | os.PathLike): Path to the Sentinel-2 dataset directory

    Returns:
        pd.DataFrame: columns satellite, acquisition_date, tile_id, num_bands and
        path, sorted by acquisition_date. An empty dataset returns an empty
        frame with those columns (previously sort_values raised KeyError).
    """
    # Local import: a later notebook cell runs `import datetime`, which turns the
    # global name into the module, so the global binding is not relied on here.
    from datetime import datetime as _datetime

    columns = ['satellite', 'acquisition_date', 'tile_id', 'num_bands', 'path']
    date_pattern = re.compile(r'(\d{8}T\d{6})')
    tile_pattern = re.compile(r'T(\d{2}[A-Z]{3})')

    metadata = []
    for safe_dir in Path(base_path).glob('*/*.SAFE'):
        dir_name = safe_dir.parent.name

        date_match = date_pattern.search(dir_name)
        acquisition_date = (
            _datetime.strptime(date_match.group(1), '%Y%m%dT%H%M%S') if date_match else None
        )

        tile_match = tile_pattern.search(dir_name)
        tile_id = tile_match.group(1) if tile_match else None

        # Count band images, ignoring the true-colour (TCI) preview
        band_files = safe_dir.glob('GRANULE/*/IMG_DATA/*.jp2')
        num_bands = sum(1 for b in band_files if not b.name.endswith('TCI.jp2'))

        metadata.append({
            'satellite': dir_name[:3],  # e.g. S2A / S2B
            'acquisition_date': acquisition_date,
            'tile_id': tile_id,
            'num_bands': num_bands,
            'path': safe_dir,
        })

    df = pd.DataFrame(metadata, columns=columns)
    return df.sort_values('acquisition_date')

# Build the per-product metadata table
df = analyze_sentinel_structure(base_path)

# Summary statistics
print("Dataset Summary:")
print(f"Total number of images: {len(df)}")
print("\nAcquisitions by satellite:")
print(df['satellite'].value_counts())
print("\nDate range:")
acquisition_dates = df['acquisition_date']
print(f"First acquisition: {acquisition_dates.min()}")
print(f"Last acquisition: {acquisition_dates.max()}")

# Preview the first rows (rich display as the cell's last expression)
df.head()

In [None]:
def get_band_statistics(image_path):
    """Map Sentinel-2 band names ('B02', ..., 'B8A') to their .jp2 file paths.

    Looks under ``GRANULE/*/IMG_DATA`` of the product at ``image_path``. Files
    without a band token (e.g. TCI) are skipped; if several files carry the same
    token, the last one found wins. Keys are returned in sorted order, values
    as strings.
    """
    band_regex = re.compile(r'B\d{2}|B8A')
    found = {}
    for jp2 in Path(image_path).glob('GRANULE/*/IMG_DATA/*.jp2'):
        hit = band_regex.search(jp2.name)
        if hit is not None:
            found[hit.group(0)] = str(jp2)
    return {key: found[key] for key in sorted(found)}

def get_quality_masks(image_path):
    """Return the file names (not full paths) of the .gml masks in GRANULE/*/QI_DATA."""
    names = []
    for gml_file in Path(image_path).glob('GRANULE/*/QI_DATA/*.gml'):
        names.append(gml_file.name)
    return names

# Example: list the band files of the first product in df (sorted by date)
image_path = df.iloc[0]['path']
print("Band files in first image:")
for band_key, band_file in get_band_statistics(image_path).items():
    print(f"{band_key}: {band_file}")

In [None]:
# Inspect the band files of the first .SAFE product in safe_dirs
first_image = safe_dirs[0]
granule_dirs = list((first_image / "GRANULE").glob("*"))
img_data_path = granule_dirs[0] / "IMG_DATA"
band_files = list(img_data_path.glob("*.jp2"))

# One row per band file: short band name plus raster size, dtype and pixel size
band_info = []
for jp2_file in band_files:
    # e.g. '..._B02.jp2' -> 'B02'
    short_name = jp2_file.name.split('_')[-1].split('.')[0]
    with rasterio.open(jp2_file) as src:
        row = dict(
            band=short_name,
            width=src.width,
            height=src.height,
            dtype=src.dtypes[0],
            resolution=src.res[0],  # pixel width in CRS units (metres for UTM tiles)
        )
    band_info.append(row)

# Tabulate, sorted by band name
df_bands = pd.DataFrame(band_info)
print("Band Information:")
print("-" * 50)
print(df_bands.sort_values('band'))

# Number of .jp2 files in this product's IMG_DATA folder
print("\nNumber of files per image:")
print(len(band_files))

# Distinct band names found
print("\nAvailable bands:")
unique_bands = sorted(df_bands['band'].unique())
print(unique_bands)

In [None]:
def visualize_band(jp2_path):
    """Show band 1 of a raster in grayscale and return it as a NumPy array.

    The array is now returned (previously None) so callers can reuse the pixels
    instead of decoding the full JPEG2000 image a second time.
    """
    with rasterio.open(jp2_path) as src:
        band = src.read(1)  # first band of the file
    plt.imshow(band, cmap="gray")
    plt.title(jp2_path)
    plt.show()
    return band


# Example band file (index 9 is an arbitrary pick from the sorted listing)
example = jp2_files[9]

# Plot it and keep the pixel array, so the file is read only once
band_data = visualize_band(example)

# Re-open for the metadata only; reading .meta does not read the pixel array
with rasterio.open(example) as dataset:
    metadata = dataset.meta

print(metadata)

In [None]:
# Parse product metadata from each .SAFE folder name, e.g.
# S2A_MSIL1C_<YYYYMMDDTHHMMSS>_N0206_R064_T36UYA_<...>.SAFE
metadata_list = []
for safe_dir in safe_dirs:
    dir_parts = safe_dir.name.split('_')
    metadata = {
        'satellite': dir_parts[0],  # S2A or S2B
        'processing_level': dir_parts[1],  # MSIL1C
        # pd.to_datetime rather than datetime.strptime: a later cell runs
        # `import datetime`, which rebinds the global `datetime` to the module,
        # so re-running this cell afterwards failed with AttributeError.
        'timestamp': pd.to_datetime(dir_parts[2], format='%Y%m%dT%H%M%S'),
        'relative_orbit': dir_parts[4],  # R064
        'tile_id': dir_parts[5],  # T36UYA
        'path': safe_dir
    }
    metadata_list.append(metadata)

# Create DataFrame
df_metadata = pd.DataFrame(metadata_list)

# Basic analysis
print("Dataset Summary:")
print("-" * 50)
print(f"Date range: {df_metadata['timestamp'].min()} to {df_metadata['timestamp'].max()}")
print(f"Number of unique tiles: {df_metadata['tile_id'].nunique()}")
print(f"Number of satellites: {df_metadata['satellite'].nunique()}")
print("\nSatellite distribution:")
print(df_metadata['satellite'].value_counts())
print("\nTile distribution:")
print(df_metadata['tile_id'].value_counts())

# Temporal distribution of acquisitions
fig, ax = plt.subplots(figsize=(15, 6))
ax.hist(df_metadata['timestamp'], bins=20, edgecolor='black')
ax.set_title('Temporal Distribution of Satellite Images')
ax.set_xlabel('Date')
ax.set_ylabel('Number of Images')
ax.tick_params(axis='x', rotation=45)
fig.tight_layout()
plt.show()

# Monthly distribution (calendar month, all years pooled)
df_metadata['month'] = df_metadata['timestamp'].dt.month
monthly_counts = df_metadata['month'].value_counts().sort_index()

fig, ax = plt.subplots(figsize=(12, 5))
monthly_counts.plot(kind='bar', ax=ax)
ax.set_title('Monthly Distribution of Images')
ax.set_xlabel('Month')
ax.set_ylabel('Number of Images')
ax.tick_params(axis='x', rotation=0)
fig.tight_layout()
plt.show()

In [None]:
# Sample size calculation with confidence level
def calculate_strata_sizes(total_events=480, confidence_level=0.95, margin_error=0.1,
                           strata_props=(0.8, 0.15, 0.05)):
    """Per-stratum sample sizes from Cochran's formula for a proportion.

    For each stratum proportion p: n = ceil(z^2 * p * (1 - p) / e^2), where z is
    the two-sided normal quantile for ``confidence_level`` and e is
    ``margin_error``. n is then capped at the stratum's expected population,
    int(total_events * p).

    Args:
        total_events (int): population size across all strata.
        confidence_level (float): in (0, 1), e.g. 0.95.
        margin_error (float): absolute margin of error, > 0.
        strata_props (sequence of float): stratum proportions. The defaults are
            the small/medium/large strata bounded at 3.37 and 10.10 ha.

    Returns:
        list[int]: one sample size per entry of ``strata_props``.

    Raises:
        ValueError: if confidence_level is outside (0, 1) or margin_error <= 0.
    """
    if not 0 < confidence_level < 1:
        raise ValueError("confidence_level must be strictly between 0 and 1")
    if margin_error <= 0:
        raise ValueError("margin_error must be positive")

    # Two-sided critical value (about 1.96 at 95 % confidence)
    z_score = stats.norm.ppf((1 + confidence_level) / 2)

    sample_sizes = []
    for prop in strata_props:
        stratum_n = int(np.ceil((z_score**2 * prop * (1 - prop)) / margin_error**2))
        # Never request more events than the stratum is expected to contain
        sample_sizes.append(min(stratum_n, int(total_events * prop)))

    return sample_sizes

In [None]:
def group_deforestation_events(gdf, spatial_thresh, temporal_thresh, date_col=None):
    """Dissolve deforestation polygons that fall into the same time window.

    Events are binned into consecutive windows of ``temporal_thresh`` days. In
    every non-empty window, each polygon is buffered by ``spatial_thresh`` and
    the buffers are merged with ``unary_union`` into one (Multi)Polygon.
    Polygons less than about ``2 * spatial_thresh`` apart therefore become
    connected. The returned shapes stay buffered (they are not shrunk back).

    Args:
        gdf: GeoDataFrame of events. It is not modified.
        spatial_thresh: buffer distance in units of ``gdf.crs``. For a
            geographic CRS such as EPSG:4326 this is degrees; reproject first if
            the threshold is meant in metres.
        temporal_thresh (int): window length in days.
        date_col (str, optional): date column to bin on. Defaults to 'date' when
            present, otherwise 'img_date' (the column this dataset uses).

    Returns:
        GeoDataFrame with one merged geometry per non-empty window, in ``gdf.crs``.
    """
    if date_col is None:
        date_col = 'date' if 'date' in gdf.columns else 'img_date'

    # Work on a copy; the previous version added a helper column to the
    # caller's frame.
    events = gdf.copy()
    # Day-resolution datetimes (same granularity as the former to_period('D')),
    # which pd.Grouper(freq=...) bins directly.
    events[date_col] = pd.to_datetime(events[date_col]).dt.normalize()

    merged = []
    for _, window in events.groupby(pd.Grouper(key=date_col, freq=f'{temporal_thresh}D')):
        if window.empty:
            continue
        merged.append(unary_union(window.geometry.buffer(spatial_thresh)))

    # Keep the input CRS so later reprojection/area calculations still work
    return gpd.GeoDataFrame(geometry=merged, crs=gdf.crs)

In [None]:
# NOTE: this rebinds the notebook-global `datetime` from the class (imported at
# the top via `from datetime import datetime`) to the module. Later cells call
# `datetime.datetime.now()`, so the import must stay; any earlier cell that calls
# `datetime.strptime(...)` will fail if it is re-run after this one.
import datetime

# Inclusive [start, end] window used to filter events: 15 Jul - 15 Sep 2018
temporal_window = [datetime.datetime(*ymd) for ymd in ((2018, 7, 15), (2018, 9, 15))]

In [None]:
def analyze_size_distribution(gdf, small_max=3.37, medium_max=10.10):
    """Analyze the size distribution of events before sampling.

    Classes are taken from the 'area_ha' column (hectares):
    small <= small_max, medium in (small_max, medium_max], large > medium_max.
    Rows with a missing area fall into no class.

    Args:
        gdf: (Geo)DataFrame with an 'area_ha' column.
        small_max (float): upper bound of the small class, in hectares.
        medium_max (float): upper bound of the medium class, in hectares.
            The defaults are the original strata; later cells use other
            thresholds, hence the parameters.

    Returns:
        dict[str, int]: {'small': n, 'medium': n, 'large': n}.
    """
    area = gdf['area_ha']
    return {
        'small': int((area <= small_max).sum()),
        'medium': int(((area > small_max) & (area <= medium_max)).sum()),
        'large': int((area > medium_max).sum()),
    }

def sample_deforestation_events(geojson_path, output_dir, temporal_window):
    """Stratified random sample of deforestation events inside a time window.

    Events whose 'img_date' lies within the inclusive ``temporal_window`` are
    projected to UTM zone 36N, their area is computed in hectares, and they are
    binned into small (<= 3.37 ha), medium (<= 10.10 ha) and large classes.
    Up to 50/15/7 events are drawn per class (random_state=42). The sample is
    written as GeoJSON (EPSG:4326) to a timestamped file in ``output_dir``.

    Returns:
        tuple: (sampled GeoDataFrame in EPSG:4326, path of the written file,
        size-class counts of the filtered events before sampling)
    """
    # Local import: the notebook-global `datetime` is rebound by another cell.
    from datetime import datetime as _datetime

    # Read and initial processing
    gdf = gpd.read_file(geojson_path)
    logging.info(f"Total events in dataset: {len(gdf)}")

    # Convert dates and keep the events inside the window
    gdf['img_date'] = pd.to_datetime(gdf['img_date'])
    in_window = (gdf['img_date'] >= temporal_window[0]) & (gdf['img_date'] <= temporal_window[1])
    filtered_gdf = gdf[in_window].copy()
    logging.info(f"Events in temporal window: {len(filtered_gdf)}")

    # EPSG:32636 is UTM zone 36N (Ukraine). The former EPSG:32736 is zone 36S,
    # which contradicted the comment. Areas are unaffected because only the
    # false northing differs.
    filtered_gdf = filtered_gdf.to_crs('EPSG:32636')
    filtered_gdf['area_ha'] = filtered_gdf.geometry.area / 10000

    # Analyze size distribution before sampling
    initial_distribution = analyze_size_distribution(filtered_gdf)
    logging.info("Initial size distribution:")
    logging.info(initial_distribution)

    # Size strata (hectares)
    filtered_gdf['size_category'] = pd.cut(
        filtered_gdf['area_ha'],
        bins=[0, 3.37, 10.10, float('inf')],
        labels=['small', 'medium', 'large']
    )

    # Cap each target at what is actually available
    target_sizes = {'small': 50, 'medium': 15, 'large': 7}
    actual_sizes = {}
    for category, target in target_sizes.items():
        available = len(filtered_gdf[filtered_gdf['size_category'] == category])
        actual_sizes[category] = min(target, available)
        logging.info(f"{category}: Target={target}, Available={available}, Will sample={actual_sizes[category]}")

    # Stratified sampling; collect the pieces and concatenate once
    samples = []
    for category, size in actual_sizes.items():
        if size > 0:
            stratum = filtered_gdf[filtered_gdf['size_category'] == category]
            samples.append(stratum.sample(n=size, random_state=42))

    # Concatenating onto a bare DataFrame lost the 'geometry' column when no
    # stratum had events; fall back to an empty frame that keeps geometry and CRS.
    if samples:
        sampled_events = gpd.GeoDataFrame(pd.concat(samples), geometry='geometry')
    else:
        sampled_events = filtered_gdf.iloc[0:0].copy()
    sampled_events = sampled_events.to_crs(4326)

    # Save results
    output_dir = Path(output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    timestamp = _datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = output_dir / f"sampled_events_{timestamp}.geojson"

    sampled_events.to_file(output_path, driver='GeoJSON')

    return sampled_events, output_path, initial_distribution

# Execute sampling
sampled_events, saved_path, initial_dist = sample_deforestation_events(
    geojson_path=geojson_path,
    output_dir=Path("../Datasets/Testing/Samples").resolve(),
    temporal_window=temporal_window
)

# Display comprehensive results
print("\nInitial Distribution in Temporal Window:")
for category, count in initial_dist.items():
    print(f"{category}: {count}")

print("\nFinal Sample Distribution:")
print(sampled_events['size_category'].value_counts())

print("\nSpatial Distribution:")
print(sampled_events.groupby('tile')['size_category'].count())

print("\nTemporal Distribution:")
print(sampled_events.groupby(sampled_events['img_date'].dt.strftime('%Y-%m-%d'))['size_category'].count())

In [None]:
# First, let's analyze the raw data to understand the area distribution
def analyze_raw_distribution(geojson_path):
    """Print area, tile and monthly summaries of the full event dataset.

    Returns:
        GeoDataFrame reprojected to UTM zone 36N (EPSG:32636), with added
        'area_ha' (hectares) and 'month' ('YYYY-MM') columns.
    """
    gdf = gpd.read_file(geojson_path)
    # EPSG:32636 is UTM zone 36N (Ukraine). The former EPSG:32736 is zone 36S;
    # areas are identical, but the returned coordinates were mislabelled.
    gdf = gdf.to_crs('EPSG:32636')
    gdf['area_ha'] = gdf.geometry.area / 10000

    print("Area Statistics (hectares):")
    print(gdf['area_ha'].describe())

    print("\nQuantile Distribution:")
    quantiles = [0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    print(gdf['area_ha'].quantile(quantiles))

    # Count events by tile
    print("\nEvents per tile:")
    print(gdf['tile'].value_counts())

    # Temporal distribution by month
    gdf['month'] = pd.to_datetime(gdf['img_date']).dt.strftime('%Y-%m')
    print("\nEvents per month:")
    print(gdf['month'].value_counts().sort_index())

    return gdf

# Analyze full dataset
raw_data = analyze_raw_distribution(geojson_path)

# Let's adjust our size categories based on the actual distribution
def recalculate_size_thresholds(gdf, num_categories=3):
    """Propose area class boundaries as equal-frequency quantiles of 'area_ha'.

    Args:
        gdf: (Geo)DataFrame with an 'area_ha' column (hectares).
        num_categories (int): number of classes (>= 2). The
            ``num_categories - 1`` inner quantiles become the boundaries.

    Returns:
        list[float]: ascending boundaries, length ``num_categories - 1``.

    Raises:
        ValueError: if num_categories < 2.
    """
    if num_categories < 2:
        raise ValueError("num_categories must be at least 2")

    quantiles = np.linspace(0, 1, num_categories + 1)[1:-1]
    thresholds = gdf['area_ha'].quantile(quantiles).tolist()

    # The old printout assumed exactly three classes (iloc[1] raised for two).
    # Keep the Small/Medium/Large wording for three and number them otherwise.
    if num_categories == 3:
        labels = ['Small', 'Medium', 'Large']
    else:
        labels = [f'Category {i}' for i in range(1, num_categories + 1)]

    print("\nProposed new size thresholds:")
    print(f"{labels[0]}: <= {thresholds[0]:.2f} ha")
    for label, low, high in zip(labels[1:-1], thresholds[:-1], thresholds[1:]):
        print(f"{label}: {low:.2f} - {high:.2f} ha")
    print(f"{labels[-1]}: > {thresholds[-1]:.2f} ha")

    return thresholds

# Candidate three-class thresholds from the full dataset's area quantiles
new_thresholds = recalculate_size_thresholds(raw_data, num_categories=3)

In [None]:
def _sample_stratum_by_month(stratum, target, random_state=42):
    """Draw up to ``target`` rows from ``stratum``, spread across calendar months.

    Each month contributes min(len(month), max(1, target // n_months)) rows. A
    shortfall is topped up at random from rows not yet drawn. An overshoot
    (possible when there are more months than ``target``) is trimmed at random.
    The original row index is preserved so the top-up can exclude drawn rows.
    """
    months = stratum['img_date'].dt.to_period('M')
    per_month = max(1, target // months.nunique())

    # group_keys=False keeps the original index. The former reset_index(drop=True)
    # replaced it with 0..k-1, so the "remaining" filter below could re-draw rows
    # that were already sampled and produce duplicates.
    sampled = stratum.groupby(months, group_keys=False).apply(
        lambda rows: rows.sample(n=min(per_month, len(rows)), random_state=random_state)
    )

    if len(sampled) < target:
        remaining = stratum[~stratum.index.isin(sampled.index)]
        if len(remaining) > 0:
            additional = remaining.sample(
                n=min(target - len(sampled), len(remaining)),
                random_state=random_state
            )
            sampled = pd.concat([sampled, additional])
    elif len(sampled) > target:
        sampled = sampled.sample(n=target, random_state=random_state)

    return sampled


def sample_deforestation_events(geojson_path, output_dir, temporal_window):
    """Month-balanced stratified sample of events in ``temporal_window``; saves GeoJSON.

    Events in the inclusive window are projected to UTM zone 36N, their area is
    computed in hectares, and they are binned into small (<= 0.60 ha), medium
    (<= 2.35 ha) and large classes. At most 10 small and 15 medium events are
    drawn; the large class is currently excluded (cap 0). Within each class,
    draws are spread over calendar months.

    Returns:
        tuple: (sampled_events, output_path). sampled_events is in EPSG:4326,
        has a 'name' column ('PLOT-00001', ...) and 'img_date' as 'YYYY-MM-DD'
        strings. output_path is the written GeoJSON file.
    """
    # Local import: the notebook-global `datetime` is rebound by another cell.
    from datetime import datetime as _datetime

    # Read and process data
    gdf = gpd.read_file(geojson_path)
    gdf['img_date'] = pd.to_datetime(gdf['img_date'])

    # Filter temporal window
    in_window = (gdf['img_date'] >= temporal_window[0]) & (gdf['img_date'] <= temporal_window[1])
    filtered_gdf = gdf[in_window].copy()

    # Areas in hectares. EPSG:32636 is UTM zone 36N; the former EPSG:32736 is
    # zone 36S (same areas, mislabelled coordinates).
    filtered_gdf = filtered_gdf.to_crs('EPSG:32636')
    filtered_gdf['area_ha'] = filtered_gdf.geometry.area / 10000

    # Size categories (boundaries presumably taken from the quantiles printed
    # by recalculate_size_thresholds - confirm when the data changes)
    filtered_gdf['size_category'] = pd.cut(
        filtered_gdf['area_ha'],
        bins=[0, 0.60, 2.35, float('inf')],
        labels=['small', 'medium', 'large']
    )

    # Calculate available events per category
    available = filtered_gdf['size_category'].value_counts()
    print("Available events per category:")
    print(available)

    # Target samples per category, capped by availability ('large' disabled)
    target_samples = {
        'small': min(10, available.get('small', 0)),
        'medium': min(15, available.get('medium', 0)),
        'large': min(0, available.get('large', 0))
    }
    print("\nTarget samples per category:")
    print(target_samples)

    # Sample from each category
    samples = [
        _sample_stratum_by_month(filtered_gdf[filtered_gdf['size_category'] == category], target)
        for category, target in target_samples.items()
        if target > 0
    ]

    # Fall back to an empty frame that keeps geometry and CRS when nothing was drawn
    if samples:
        sampled_events = gpd.GeoDataFrame(pd.concat(samples), geometry='geometry', crs=filtered_gdf.crs)
    else:
        sampled_events = filtered_gdf.iloc[0:0].copy()
    sampled_events = sampled_events.to_crs(4326)

    # Assign 'name' property
    sampled_events['name'] = ['PLOT-{0:05d}'.format(i) for i in range(1, len(sampled_events) + 1)]

    # Convert 'img_date' back to string format 'YYYY-MM-DD' before saving
    sampled_events['img_date'] = sampled_events['img_date'].dt.strftime('%Y-%m-%d')

    # Save results
    output_dir = Path(output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = _datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = output_dir / f"sampled_events_{timestamp}.geojson"
    sampled_events.to_file(output_path, driver='GeoJSON')

    # Print detailed statistics
    print(f"\nTotal number of sampled events: {len(sampled_events)}")
    print("\nFinal sample distribution:")
    print("\nBy size category:")
    print(sampled_events['size_category'].value_counts())
    print("\nTemporal distribution:")
    print(sampled_events.groupby([
        sampled_events['img_date'].str.slice(0, 7),  # Get 'YYYY-MM' from 'YYYY-MM-DD'
        'size_category'
    ], observed=True).size().unstack(fill_value=0))

    return sampled_events, output_path

# Execute sampling with new parameters
sampled_events, saved_path = sample_deforestation_events(
    geojson_path=geojson_path,
    output_dir=Path("../Datasets/Testing/Samples").resolve(),
    temporal_window=temporal_window
)

# Additional analysis of the results
print("\nArea statistics of sampled events:")
print(sampled_events['area_ha'].describe())

print("\nDate distribution:")
print(sampled_events['img_date'].value_counts().sort_index())

In [None]:
def visualize_sample_distribution(sampled_events, output_dir="../Docs/Diagrams/"):
    """Plot size, date and area summaries of a sample and save them with metadata.

    Writes ``sample_distribution.png`` and ``sample_metadata.json`` into
    ``output_dir``, creating it if missing. ``sampled_events`` is not modified
    (the previous version added a 'date' column to it as a side effect).

    Args:
        sampled_events: frame with 'size_category', 'img_date' and 'area_ha'
            columns.
        output_dir (str | os.PathLike): destination folder (default unchanged).
    """
    # Plot from a derived frame that carries the calendar-day 'date' column
    plot_df = sampled_events.assign(
        date=pd.to_datetime(sampled_events['img_date']).dt.date
    )

    # 2x2 grid: two panels on top, one full-width panel at the bottom
    fig = plt.figure(figsize=(15, 12))
    gs = plt.GridSpec(2, 2)
    ax1 = fig.add_subplot(gs[0, 0])  # top-left
    ax2 = fig.add_subplot(gs[0, 1])  # top-right
    ax3 = fig.add_subplot(gs[1, :])  # bottom, spanning both columns

    # Size category distribution
    sns.countplot(data=plot_df, x='size_category', ax=ax1)
    ax1.set_title('Distribution by Size Category')
    ax1.set_ylabel('Count')

    # Temporal distribution
    sns.histplot(data=plot_df, x='date', ax=ax2)
    ax2.set_title('Temporal Distribution')
    ax2.tick_params(axis='x', rotation=45)

    # Area distribution per category
    sns.boxplot(data=plot_df, y='area_ha', x='size_category', ax=ax3)
    ax3.set_title('Area Distribution by Category')
    ax3.set_ylabel('Area (hectares)')

    # Summary metadata; explicit str/int casts keep json.dump independent of
    # pandas/NumPy scalar types
    size_counts = sampled_events['size_category'].value_counts()
    sample_metadata = {
        'total_samples': len(sampled_events),
        'temporal_range': {
            'start': str(sampled_events['img_date'].min()),
            'end': str(sampled_events['img_date'].max())
        },
        'size_distribution': {str(k): int(v) for k, v in size_counts.items()},
        'area_statistics': {
            'mean': float(sampled_events['area_ha'].mean()),
            'median': float(sampled_events['area_ha'].median()),
            'std': float(sampled_events['area_ha'].std())
        }
    }

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / 'sample_metadata.json', 'w') as f:
        json.dump(sample_metadata, f, indent=2)

    # Save and release this specific figure
    fig.tight_layout()
    fig.savefig(output_dir / 'sample_distribution.png')
    plt.close(fig)

# Generate visualizations
visualize_sample_distribution(sampled_events)

In [None]:
class SentinelPatchProcessor:
    """Cut fixed-size multi-band patches out of Sentinel-2 ``.SAFE`` scenes.

    For each GeoJSON feature that intersects a scene's tile, a
    ``(9, patch_size, patch_size)`` array is saved. Its channels are B02, B03,
    B04, B08, B8A, B11, B12, NDVI and NDMI, all on the B02 grid.
    """

    def __init__(self, patch_size=224):
        """
        Initialize the Sentinel-2 patch processor.

        Args:
            patch_size (int): Size of the output patches (default: 224)
        """
        self.patch_size = patch_size

    def get_tile_id(self, sentinel_path):
        """Extract the tile ID (e.g. ``T36UYA``) from a Sentinel path, or return ``None``."""
        match = re.search(r'T\d{2}[A-Z]{3}', str(sentinel_path))
        return match.group(0) if match else None

    def _reference_band_path(self, sentinel_path):
        """Return the first B02 JPEG2000 file of the scene (the reference grid)."""
        try:
            return next(Path(sentinel_path).glob('GRANULE/*/IMG_DATA/*B02.jp2'))
        except StopIteration:
            raise FileNotFoundError(f"B02 band not found in {sentinel_path}") from None

    def get_tile_bounds(self, sentinel_path):
        """Return the tile footprint as a shapely box in the raster's CRS."""
        with rasterio.open(self._reference_band_path(sentinel_path)) as src:
            return box(*src.bounds)

    def group_geometries_by_tile(self, geojson_path, sentinel_path):
        """Return the GeoJSON features lying on this scene's tile, clipped to it.

        Features are reprojected to the raster CRS before testing. Returns ``None``
        when no feature intersects the tile.
        """
        gdf = gpd.read_file(geojson_path)

        # One open of the reference band supplies both the footprint and the CRS
        with rasterio.open(self._reference_band_path(sentinel_path)) as src:
            tile_bounds = box(*src.bounds)
            if gdf.crs != src.crs:
                gdf = gdf.to_crs(src.crs)

        on_tile = gdf.geometry.intersects(tile_bounds)
        tile_geometries = gdf[on_tile].copy()

        # Clip geometries to tile bounds
        tile_geometries['geometry'] = tile_geometries.geometry.intersection(tile_bounds)

        return tile_geometries if not tile_geometries.empty else None

    def load_bands(self, sentinel_path):
        """Load the required bands into a ``{band_name: 2-D array}`` dict.

        Side effect: the B02 raster metadata is stored in ``self.meta``. Its
        ``transform`` is later used to place patches. Returns ``None``, after
        logging, if any required band could not be loaded.
        """
        band_paths = list(Path(sentinel_path).glob('GRANULE/*/IMG_DATA/*.jp2'))
        band_data = {}
        required_bands = ['B02', 'B03', 'B04', 'B08', 'B8A', 'B11', 'B12']

        for band_path in band_paths:
            # For "...B8A.jp2" the B\d{2} alternative fails and B8A matches instead
            band_name = re.search(r'B\d{2}|B8A', band_path.name)
            if band_name:
                band_name = band_name.group(0)
                if band_name in required_bands:
                    try:
                        with rasterio.open(band_path) as src:
                            band_data[band_name] = src.read(1)
                            if band_name == 'B02':
                                self.meta = src.meta.copy()
                            logging.info(f"Loaded band {band_name} from {self.get_tile_id(sentinel_path)}")
                    except Exception as e:
                        logging.error(f"Error loading {band_name} from {sentinel_path}: {str(e)}")

        if len(band_data) != len(required_bands):
            missing_bands = set(required_bands) - set(band_data.keys())
            logging.error(f"Missing required bands for {self.get_tile_id(sentinel_path)}: {missing_bands}")
            return None

        return band_data

    def validate_bands(self, band_data, required_bands):
        """Raise ``ValueError`` if any of ``required_bands`` is missing from ``band_data``."""
        missing_bands = [band for band in required_bands if band not in band_data]
        if missing_bands:
            raise ValueError(f"Missing required bands: {missing_bands}")

    def resample_to_10m(self, band_data):
        """Resample B8A, B11 and B12 in place to B02's shape, then return ``band_data``."""
        try:
            # Get shape from a 10m band (B02)
            target_shape = band_data['B02'].shape
            logging.debug(f"Target shape for resampling: {target_shape}")

            # Bands that need resampling (20m bands)
            bands_to_resample = ['B8A', 'B11', 'B12']

            for band in bands_to_resample:
                if band in band_data and band_data[band].shape != target_shape:
                    logging.info(f"Resampling {band} to 10m resolution")
                    band_data[band] = self._resample_array(
                        band_data[band],
                        target_shape
                    )
                    logging.debug(f"Resampled {band} shape: {band_data[band].shape}")

            return band_data

        except Exception as e:
            logging.error(f"Error in resample_to_10m: {str(e)}")
            raise

    def _resample_array(self, array, target_shape):
        """Bilinearly resize a 2-D array to ``target_shape`` (rows, cols); returns float32."""
        try:
            # Float ('F' mode) images support bilinear resizing in PIL and keep the
            # interpolated values instead of rounding them back to integers
            img = Image.fromarray(np.asarray(array, dtype=np.float32))

            # PIL sizes are (width, height)
            resized = img.resize(
                (target_shape[1], target_shape[0]),
                resample=Image.BILINEAR
            )

            return np.array(resized)

        except Exception as e:
            logging.error(f"Error in _resample_array: {str(e)}")
            raise

    @staticmethod
    def _normalized_difference(a, b):
        """Return ``(a - b) / (a + b)`` in float32, with 0 wherever ``a + b == 0``."""
        a = np.asarray(a, dtype=np.float32)
        b = np.asarray(b, dtype=np.float32)
        total = a + b
        return np.divide(a - b, total, out=np.zeros_like(total), where=total != 0)

    def compute_indices(self, band_data):
        """
        Compute NDVI ``(B08 - B04) / (B08 + B04)`` and NDMI ``(B8A - B11) / (B8A + B11)``.

        Both are float32, clipped to [-1, 1], with 0 where the denominator is 0.

        Raises:
            ValueError: If B04, B08, B8A or B11 is missing.
        """
        try:
            required_bands = ['B04', 'B08', 'B8A', 'B11']
            self.validate_bands(band_data, required_bands)

            # Work in float32: on unsigned-integer band arrays (as read from the JP2s,
            # typically uint16) B08 - B04 wraps around whenever red > NIR, and the sum
            # can overflow, which silently corrupts every negative index value.
            ndvi = self._normalized_difference(band_data['B08'], band_data['B04'])
            ndmi = self._normalized_difference(band_data['B8A'], band_data['B11'])

            # Add bounds to prevent extreme values
            ndvi = np.clip(ndvi, -1, 1)
            ndmi = np.clip(ndmi, -1, 1)

            # Replace NaN values with 0
            ndvi = np.nan_to_num(ndvi, nan=0.0)
            ndmi = np.nan_to_num(ndmi, nan=0.0)

            logging.info(f"Successfully computed NDVI and NDMI indices")
            logging.debug(f"NDVI range: [{ndvi.min():.3f}, {ndvi.max():.3f}]")
            logging.debug(f"NDMI range: [{ndmi.min():.3f}, {ndmi.max():.3f}]")

            return ndvi, ndmi

        except Exception as e:
            logging.error(f"Error computing indices: {str(e)}")
            raise

    def create_patches(self, stacked_bands, geometries, output_dir):
        """
        Save one ``patch_size`` x ``patch_size`` patch per geometry as ``<name>_P<n>.npy``.

        Each patch is anchored at the pixel offset of the geometry's upper-left
        bound in ``self.meta['transform']``; ``n`` counts features sharing a
        ``name`` within this call. Patches that would run past the array edge are
        skipped. Per-geometry errors are logged and skipped.
        """
        try:
            os.makedirs(output_dir, exist_ok=True)

            # How many features with each name have been seen so far
            patch_counts = {}

            for idx, feature in geometries.iterrows():
                try:
                    plot_name = feature['name']

                    patch_counts[plot_name] = patch_counts.get(plot_name, 0) + 1
                    patch_name = f"{plot_name}_P{patch_counts[plot_name]}"

                    bounds = feature.geometry.bounds
                    window = from_bounds(*bounds, transform=self.meta['transform'])

                    row_start = int(window.row_off)
                    col_start = int(window.col_off)
                    patch = stacked_bands[
                        :,
                        row_start:int(window.row_off + self.patch_size),
                        col_start:int(window.col_off + self.patch_size)
                    ]

                    if patch.shape[1:] == (self.patch_size, self.patch_size):
                        output_path = Path(output_dir) / f"{patch_name}.npy"
                        np.save(output_path, patch)
                        logging.info(f"Saved patch {patch_name} to {output_path}")
                    else:
                        logging.warning(f"Skipping patch {patch_name} due to incorrect size: {patch.shape[1:]}")

                except KeyError as ke:
                    logging.error(f"'name' property not found in geometry at index {idx}: {str(ke)}")
                    continue
                except Exception as e:
                    logging.error(f"Error processing patch for geometry at index {idx}: {str(e)}")
                    continue

        except Exception as e:
            logging.error(f"Error in create_patches: {str(e)}")
            raise

    def process_imagery(self, sentinel_path, geojson_path, output_dir):
        """Cut and save patches for every GeoJSON feature on this scene's tile.

        Returns early, after logging, when the tile ID cannot be parsed, no
        feature intersects the tile, or bands are missing. Any other error is
        logged and re-raised.
        """
        # Resolved before the try block so the error handler can always name the tile
        tile_id = self.get_tile_id(sentinel_path)
        try:
            if not tile_id:
                logging.error(f"Could not determine tile ID for {sentinel_path}")
                return

            tile_geometries = self.group_geometries_by_tile(geojson_path, sentinel_path)
            if tile_geometries is None:
                logging.info(f"No geometries intersect with tile {tile_id}")
                return

            os.makedirs(output_dir, exist_ok=True)

            band_data = self.load_bands(sentinel_path)
            if band_data is None:
                return

            band_data = self.resample_to_10m(band_data)
            ndvi, ndmi = self.compute_indices(band_data)

            # Channel order of every saved patch
            stacked_bands = np.stack([
                band_data['B02'], band_data['B03'], band_data['B04'],
                band_data['B08'], band_data['B8A'], band_data['B11'],
                band_data['B12'], ndvi, ndmi
            ])

            self.create_patches(stacked_bands, tile_geometries, output_dir)
            logging.info(f"Successfully processed tile {tile_id}")

        except Exception as e:
            logging.error(f"Error processing tile {tile_id}: {str(e)}")
            raise

In [None]:
# A single processor instance is shared by every scene
processor = SentinelPatchProcessor(patch_size=224)

# Scenes are laid out as <base_path>/<scene>/<scene>.SAFE
base_path = Path("../Datasets/Sentinel-2")
safe_dirs = list(base_path.glob("*/*.SAFE"))

# Cut patches scene by scene; a scene that fails is logged and skipped
for safe_dir in tqdm(safe_dirs, desc="Processing Sentinel-2 images"):
    try:
        processor.process_imagery(
            str(safe_dir),
            geojson_path=saved_path,
            output_dir="../Datasets/Testing/Tiles/" + safe_dir.parent.name,
        )
    except Exception as e:
        logging.error(f"Failed to process {safe_dir}: {str(e)}")

In [None]:
def visualize_patch(patch_path):
    """Plot each channel of a saved patch on a 3x3 grid with its own colour bar.

    Args:
        patch_path: Path or str of a ``.npy`` patch whose channels are ordered
            B02, B03, B04, B08, B8A, B11, B12, NDVI, NDMI.

    Returns:
        The matplotlib Figure; the caller decides when to show or close it.
    """
    patch_path = Path(patch_path)  # accept plain strings as well as Path objects
    patch = np.load(patch_path)

    band_names = ['B02', 'B03', 'B04', 'B08', 'B8A', 'B11', 'B12', 'NDVI', 'NDMI']
    n_channels = min(len(band_names), patch.shape[0])

    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    fig.suptitle(f'Patch Visualization: {patch_path.name}', fontsize=16)

    for idx, ax in enumerate(axes.flat):
        if idx >= n_channels:
            # Fewer channels than panels: leave the spare panel blank
            ax.axis('off')
            continue
        im = ax.imshow(patch[idx], cmap='viridis')
        ax.set_title(band_names[idx])
        fig.colorbar(im, ax=ax)

    # Leave headroom so the suptitle does not overlap the top row
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    return fig

# Collect saved patches in sorted order so list indices refer to the same file on every run
tiles_dir = Path("../Datasets/Testing/Tiles")
patch_files = sorted(tiles_dir.glob("**/*.npy"))

# Preview only the first few patches; close each figure once shown to free memory
N_PREVIEW_PATCHES = 3
for patch_file in patch_files[:N_PREVIEW_PATCHES]:
    fig = visualize_patch(patch_file)
    plt.show()
    plt.close(fig)

In [None]:
def show_rgb_composite(patch_path):
    """Display the true-colour composite (B04, B03, B02) of a saved patch.

    All three bands are stretched together to [0, 1] by their joint min/max.
    A constant patch (zero range) is shown as black rather than as NaNs.

    Args:
        patch_path: Path or str of a ``.npy`` patch with B02, B03, B04 as its
            first three channels.
    """
    patch_path = Path(patch_path)
    patch = np.load(patch_path)

    # Channels 2, 1, 0 are B04 (red), B03 (green), B02 (blue)
    rgb = np.stack([patch[2], patch[1], patch[0]], axis=-1).astype(float)

    rgb_min, rgb_max = rgb.min(), rgb.max()
    if rgb_max > rgb_min:
        rgb = (rgb - rgb_min) / (rgb_max - rgb_min)
    else:
        rgb = np.zeros_like(rgb)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(rgb)
    ax.set_title(f'RGB Composite: {patch_path.name}')
    ax.axis('off')
    plt.show()

# Show one example composite; the index is arbitrary, so guard against having fewer patches
EXAMPLE_PATCH_INDEX = 8
if len(patch_files) > EXAMPLE_PATCH_INDEX:
    show_rgb_composite(patch_files[EXAMPLE_PATCH_INDEX])
else:
    print(f"Only {len(patch_files)} patches found; nothing to show at index {EXAMPLE_PATCH_INDEX}")

In [None]:
def visualize_temporal_patches(tiles_dir, plot_identifier):
    """Show every saved patch of one plot as an RGB image, ordered by capture date.

    Args:
        tiles_dir: Path or str of the directory searched recursively for ``.npy`` patches.
        plot_identifier: A plot number (int, numpy integer or digit string), which
            is zero-padded to ``PLOT-NNNNN``, or a full name starting with ``PLOT-``.

    Returns:
        None. Prints a message and returns early for an invalid identifier or when
        no matching patch exists.
    """
    # Normalise through str() so numpy integers and other number-like inputs work too
    identifier_text = str(plot_identifier).strip()
    if identifier_text.isdecimal():
        plot_name = f"PLOT-{int(identifier_text):05d}"
    elif identifier_text.startswith('PLOT-'):
        plot_name = identifier_text
    else:
        print(f"Invalid plot identifier: {plot_identifier}")
        return

    # Exact file-name match: "<plot_name>.npy" or "<plot_name>_P<n>.npy". A bare
    # prefix glob would also pick up longer names that merely share the prefix.
    name_pattern = re.compile(rf'{re.escape(plot_name)}(?:_P\d+)?\.npy')
    patch_files = [p for p in Path(tiles_dir).rglob('*.npy') if name_pattern.fullmatch(p.name)]

    if not patch_files:
        print(f"No patches found for plot {plot_name}")
        return

    def get_capture_date(filepath):
        """Return the first ``_YYYYMMDDT`` date in the path as ``YYYY-MM-DD``, else "Unknown Date"."""
        match = re.search(r'_(\d{8})T', str(filepath))
        if match:
            date_str = match.group(1)
            return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
        return "Unknown Date"

    def get_patch_sequence(filepath):
        """Return the ``_P<n>`` sequence number of the file, or 1 when it has none."""
        match = re.search(r'_P(\d+)\.npy$', filepath.name)
        return int(match.group(1)) if match else 1

    patch_files.sort(key=lambda p: (get_capture_date(p), get_patch_sequence(p)))

    n_patches = len(patch_files)
    # squeeze=False always yields a 2-D array of axes, even for a single patch
    fig, axes = plt.subplots(n_patches, 1, figsize=(15, 5 * n_patches), squeeze=False)

    for ax, patch_file in zip(axes[:, 0], patch_files):
        patch = np.load(patch_file)
        rgb = np.stack([patch[2], patch[1], patch[0]], axis=-1)

        # Min-max stretch; a constant patch is shown as black instead of NaNs
        rgb_min, rgb_max = rgb.min(), rgb.max()
        if rgb_max > rgb_min:
            rgb = (rgb - rgb_min) / (rgb_max - rgb_min)
        else:
            rgb = np.zeros_like(rgb)

        capture_date = get_capture_date(patch_file)
        patch_seq = get_patch_sequence(patch_file)

        title = f'Plot: {plot_name}\nFile: {patch_file.name}\nDate: {capture_date}'
        if patch_seq > 1:
            title += f'\nPatch Sequence: {patch_seq}'

        ax.imshow(rgb)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    
# Display the temporal sequence for plot 2 (the number is expanded to the name "PLOT-00002")
tiles_dir = Path("../Datasets/Testing/Tiles")
visualize_temporal_patches(tiles_dir, 2)

In [None]:
import logging
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import rasterio
from rasterio.enums import Resampling
from tqdm import tqdm
from s2cloudless import S2PixelCloudDetector

class CloudProcessor:
    """Copy patches whose source scene passes an s2cloudless cloud-cover threshold.

    NOTE: cloud cover is estimated over the full B02 grid of the scene, not over
    the patch footprint, so every patch of a scene gets the same keep/drop
    decision. The estimate is therefore computed once per SAFE directory.
    """

    # Input bands, in order, expected by S2PixelCloudDetector when all_bands=False (its default)
    S2CLOUDLESS_BANDS = ('B01', 'B02', 'B04', 'B05', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12')

    def __init__(self, base_path: str, cloud_threshold: float = 0.1):
        """
        Args:
            base_path: Directory searched for ``<scene>.SAFE``, directly or one level down.
            cloud_threshold: Maximum cloudy-pixel fraction (0-1) for a patch to be kept.

        Raises:
            ValueError: If ``base_path`` does not exist.
        """
        self.base_path = Path(base_path)
        self.cloud_threshold = cloud_threshold
        self.s2cloudless_bands = list(self.S2CLOUDLESS_BANDS)
        # Pixels with cloud probability above 0.4 count as cloudy (see detect_clouds)
        self.cloud_detector = S2PixelCloudDetector(threshold=0.4)

        if not self.base_path.exists():
            raise ValueError(f"Base path does not exist: {self.base_path}")

    def get_safe_dir(self, patch_file: Path) -> Path:
        """Locate the ``.SAFE`` directory a patch was cut from.

        The scene name is taken from the patch's parent directory
        (``.../<scene>/<patch>.npy``) or, failing that, its grandparent
        (``.../<scene>/<tile_id>/<patch>.npy``). Each name is looked up directly
        under ``base_path`` and then one level deeper.

        Raises:
            ValueError: If no matching SAFE directory exists.
        """
        patch_file = Path(patch_file)
        for product_name in (patch_file.parent.name, patch_file.parent.parent.name):
            if not product_name:
                continue
            safe_pattern = f"{product_name}.SAFE"
            safe_dirs = list(self.base_path.glob(safe_pattern))
            if not safe_dirs:
                safe_dirs = list(self.base_path.glob(f"*/{safe_pattern}"))
            if safe_dirs:
                return safe_dirs[0]

        raise ValueError(f"No SAFE directory found for product: {patch_file.parent.name}")

    def load_bands_for_patch(self, safe_dir: Path) -> Dict[str, np.ndarray]:
        """Load the s2cloudless input bands of a whole scene, resampled onto B02's grid.

        Every band is bilinearly resampled to the shape of B02. Full-resolution
        scene bands can be memory-heavy.

        Raises:
            FileNotFoundError: If B02 or a required band is missing.
            ValueError: If a band's CRS differs from B02's.
        """
        bands = {}
        ref_band = 'B02'
        ref_files = list(safe_dir.glob(f'GRANULE/*/IMG_DATA/*{ref_band}*.jp2'))
        if not ref_files:
            raise FileNotFoundError(f"Reference band {ref_band} not found in {safe_dir}")

        # Only the reference CRS and shape are needed, so no pixel data is read here
        with rasterio.open(ref_files[0]) as ref_src:
            ref_crs = ref_src.crs
            ref_shape = ref_src.shape

        for band in self.s2cloudless_bands:
            band_files = list(safe_dir.glob(f'GRANULE/*/IMG_DATA/*{band}*.jp2'))
            if not band_files:
                raise FileNotFoundError(f"Band {band} not found in {safe_dir}")
            with rasterio.open(band_files[0]) as src:
                if src.crs != ref_crs:
                    raise ValueError(f"Band {band} CRS does not match reference CRS")
                # Resample band to reference resolution and shape
                bands[band] = src.read(
                    out_shape=(1, ref_shape[0], ref_shape[1]),
                    resampling=Resampling.bilinear
                )[0]
        return bands

    def detect_clouds(self, bands: Dict[str, np.ndarray]) -> Tuple[bool, float]:
        """Return ``(keep, cloud_fraction)`` for a dict of same-shaped band arrays.

        Digital numbers are divided by 10000 before being passed to the detector.
        ``keep`` is True when the fraction of pixels above the detector's
        probability threshold is at most ``cloud_threshold``.
        """
        stacked = np.stack([bands[b] for b in self.s2cloudless_bands], axis=2) / 10000.0  # (H, W, C)
        data = np.expand_dims(stacked, axis=0)  # (1, H, W, C)

        cloud_probs = self.cloud_detector.get_cloud_probability_maps(data)
        cloud_mask = cloud_probs > self.cloud_detector.threshold
        cloud_percentage = float(np.mean(cloud_mask))

        return cloud_percentage <= self.cloud_threshold, cloud_percentage

    def process_patches(self, input_dir: str, output_dir: str) -> Dict[str, int]:
        """Copy every patch under ``input_dir`` whose scene passes the cloud threshold.

        The directory structure relative to ``input_dir`` is mirrored under
        ``output_dir``. Per-patch errors are counted and logged, not raised.

        Returns:
            Counts under ``total``, ``processed``, ``kept`` and ``errors``
            (also written to the log).
        """
        input_path = Path(input_dir)
        output_path = Path(output_dir)
        output_path.mkdir(exist_ok=True, parents=True)

        patches = list(input_path.rglob("*.npy"))

        results = {
            'total': len(patches),
            'processed': 0,
            'kept': 0,
            'errors': 0
        }

        # SAFE dir -> (keep, cloud fraction). The answer is per scene, so compute it once.
        scene_cache: Dict[Path, Tuple[bool, float]] = {}

        for patch_file in tqdm(patches, desc="Processing patches"):
            try:
                safe_dir = self.get_safe_dir(patch_file)
                if safe_dir not in scene_cache:
                    bands = self.load_bands_for_patch(safe_dir)
                    scene_cache[safe_dir] = self.detect_clouds(bands)
                    del bands  # drop the full-scene arrays before the next patch
                keep_patch, cloud_pct = scene_cache[safe_dir]

                results['processed'] += 1

                if keep_patch:
                    out_file = output_path / patch_file.relative_to(input_path)
                    out_file.parent.mkdir(exist_ok=True, parents=True)
                    patch_data = np.load(patch_file)
                    np.save(out_file, patch_data)
                    results['kept'] += 1

                logging.info(f"Processed {patch_file.name}: {cloud_pct:.1%} clouds, kept: {keep_patch}")

            except Exception as e:
                results['errors'] += 1
                logging.error(f"Error processing {patch_file}: {str(e)}")

        logging.info(f"\nProcessing Summary:")
        logging.info(f"Total patches: {results['total']}")
        logging.info(f"Successfully processed: {results['processed']}")
        logging.info(f"Kept (low clouds): {results['kept']}")
        logging.info(f"Errors: {results['errors']}")
        return results

In [None]:
# Pick one saved patch and sanity-check its layout and value ranges
patch_path = Path(
    "../Datasets/Testing/Tiles/S2A_MSIL1C_20190815T083601_N0208_R064_T36UYA_20190815T123742/T36UYA/patch_9.npy"
)
patch_data = np.load(patch_path)

# Expected layout is (channels, height, width) = (9, 224, 224)
print(f"Patch shape: {patch_data.shape}")
print("Channel stats:")
channel_labels = ['B02', 'B03', 'B04', 'B08', 'B8A', 'B11', 'B12', 'NDVI', 'NDMI']
for i, channel_name in enumerate(channel_labels):
    channel = patch_data[i]
    print(f"{channel_name}: min={channel.min():.3f}, max={channel.max():.3f}")

In [None]:
def visualize_patch(patch_data, rgb_scale=3000):
    """Show a patch as a true-colour RGB image next to its NDVI and NDMI channels.

    Args:
        patch_data: Either an array of shape ``(9, H, W)`` with channels B02, B03,
            B04, B08, B8A, B11, B12, NDVI, NDMI, or a path (str or Path) to a
            ``.npy`` file holding one.
        rgb_scale: Value mapped to full brightness in the RGB panel. Bands are
            divided by it and clipped to [0, 1].
    """
    # This cell redefines the earlier path-based visualize_patch, so paths are accepted too
    if isinstance(patch_data, (str, Path)):
        patch_data = np.load(patch_data)

    # True colour composite (B04, B03, B02)
    rgb = np.stack([patch_data[2], patch_data[1], patch_data[0]], axis=-1)
    rgb = np.clip(rgb / rgb_scale, 0, 1)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(rgb)
    axes[0].set_title("RGB")

    axes[1].imshow(patch_data[7], cmap='RdYlGn', vmin=-1, vmax=1)
    axes[1].set_title("NDVI")

    axes[2].imshow(patch_data[8], cmap='RdYlBu', vmin=-1, vmax=1)
    axes[2].set_title("NDMI")

    plt.tight_layout()
    plt.show()

# Show the sample patch as RGB, NDVI and NDMI panels
visualize_patch(patch_data)

In [None]:
# Locate one specific scene among the SAFE directories discovered earlier
example_path = "../Datasets/Sentinel-2/S2A_MSIL1C_20190815T083601_N0208_R064_T36UYA_20190815T123742/S2A_MSIL1C_20190815T083601_N0208_R064_T36UYA_20190815T123742.SAFE"
example_safe_dir = Path(example_path)

# list.index raises ValueError if the scene was not discovered
folder_index = safe_dirs.index(example_safe_dir)
print(f"The index of the folder is: {folder_index}")

In [None]:
from pathlib import Path
import geopandas as gpd
import rasterio
from rasterio import plot
from rasterio.windows import from_bounds, bounds as window_bounds
from shapely.geometry import box
import numpy as np
import matplotlib.pyplot as plt

def verify_patch_coordinates(sentinel_path, geojson_path, patch_path):
    """Check that a saved patch actually covers the geometry it was cut for.

    The patch extent is rebuilt the way ``SentinelPatchProcessor.create_patches``
    cuts it. That is a window anchored at the truncated pixel offset of the
    geometry's upper-left bound, sized by the saved array's height and width.
    The geometry is then tested against that extent. (Testing the geometry
    against a window built from its own bounds is always true and verifies nothing.)

    Patch names are resolved as follows:
      * ``patch_<i>.npy``: feature ``i`` of the GeoJSON (as the previous check
        assumed; TODO confirm this matches how those files were produced).
      * ``<name>_P<n>.npy``: the n-th (1-based) feature with that ``name`` that
        intersects the tile, clipped to the tile, mirroring ``create_patches``.

    Returns:
        Dict with ``geometry_bounds``, ``patch_bounds``, ``intersects``,
        ``is_within`` and ``patch_shape``; or ``None`` if anything fails (the
        error is printed).
    """
    try:
        from rasterio.windows import Window  # not imported at the top of this cell

        gdf = gpd.read_file(geojson_path)

        # Find B02 band path safely
        b02_paths = list(Path(sentinel_path).glob('GRANULE/*/IMG_DATA/*B02.jp2'))
        if not b02_paths:
            raise FileNotFoundError("B02 band not found")
        b02_path = b02_paths[0]

        with rasterio.open(b02_path) as src:
            if not gdf.crs:
                raise ValueError("GeoJSON CRS is missing")
            if not src.crs:
                raise ValueError("Raster CRS is missing")
            if gdf.crs != src.crs:
                gdf = gdf.to_crs(src.crs)

            stem = Path(patch_path).stem
            legacy_match = re.fullmatch(r'patch_(\d+)', stem)
            named_match = re.fullmatch(r'(.+)_P(\d+)', stem)
            if legacy_match:
                patch_idx = int(legacy_match.group(1))
                if patch_idx >= len(gdf):
                    raise IndexError(f"Patch index {patch_idx} exceeds GeoJSON features")
                geometry = gdf.iloc[patch_idx].geometry
            elif named_match and 'name' in gdf.columns:
                tile_box = box(*src.bounds)
                same_name = gdf[(gdf['name'] == named_match.group(1)) & gdf.geometry.intersects(tile_box)]
                seq = int(named_match.group(2))
                if not 1 <= seq <= len(same_name):
                    raise IndexError(f"Patch {stem} has no matching feature on this tile")
                geometry = same_name.iloc[seq - 1].geometry.intersection(tile_box)
            else:
                raise ValueError(f"Unrecognised patch file name: {Path(patch_path).name}")

            geom_bounds = geometry.bounds
            geom_window = from_bounds(*geom_bounds, transform=src.transform)

            # Rebuild the real patch extent from the saved array's size
            patch = np.load(patch_path)
            patch_height, patch_width = patch.shape[-2:]
            patch_window = Window(int(geom_window.col_off), int(geom_window.row_off),
                                  patch_width, patch_height)
            patch_bounds = window_bounds(patch_window, transform=src.transform)
            patch_polygon = box(*patch_bounds)

            intersects = geometry.intersects(patch_polygon)
            is_within = geometry.within(patch_polygon)

            print(f"Geometry bounds: {geom_bounds}")
            print(f"Patch bounds: {patch_bounds}")
            print(f"Does geometry intersect patch bounds: {intersects}")
            print(f"Is geometry entirely within patch bounds: {is_within}")
            print(f"Patch shape: {patch.shape}")

            # Show the pixels actually covered by the patch with the geometry overlaid
            visualize_patch_with_geometry(sentinel_path, geometry, patch_window)

            return {
                'geometry_bounds': geom_bounds,
                'patch_bounds': patch_bounds,
                'intersects': intersects,
                'is_within': is_within,
                'patch_shape': patch.shape
            }
    except Exception as e:
        print(f"Error: {str(e)}")
        return None

def visualize_patch_with_geometry(sentinel_path, geometry, window):
    """Draw the B02 pixels inside ``window`` in grey with the geometry outlined in red.

    Polygon and MultiPolygon exteriors are drawn; other geometry types are
    reported and only the raster is shown.

    Raises:
        FileNotFoundError: If the scene has no B02 band.
    """
    candidates = list(Path(sentinel_path).glob('GRANULE/*/IMG_DATA/*B02.jp2'))
    if len(candidates) == 0:
        raise FileNotFoundError("B02 band not found")

    with rasterio.open(candidates[0]) as src:
        pixels = src.read(1, window=window)
        pixel_transform = src.window_transform(window)

    fig, ax = plt.subplots(figsize=(10, 10))
    rasterio.plot.show(pixels, transform=pixel_transform, ax=ax, cmap='gray')

    # Collect the exterior rings to outline, based on the shapely geometry type
    kind = geometry.geom_type
    if kind == 'Polygon':
        rings = [geometry.exterior]
    elif kind == 'MultiPolygon':
        rings = [part.exterior for part in geometry.geoms]
    else:
        rings = []
        print(f"Unsupported geometry type: {kind}")

    for ring in rings:
        xs, ys = ring.xy
        ax.plot(xs, ys, color='red', linewidth=2)

    plt.show()

# Example usage: check the sample patch (patch_path) against the scene picked via folder_index.
# saved_path is the GeoJSON path defined in an earlier cell (assumed: the sampled-events file).
result = verify_patch_coordinates(
    sentinel_path=safe_dirs[folder_index],
    geojson_path=saved_path,
    patch_path=patch_path
)

In [None]:
def verify_all_patches(sentinel_path, geojson_path, patches_dir):
    """Check every saved patch in ``patches_dir`` against the geometry it was cut for.

    For each ``.npy`` file, the patch extent is rebuilt the way
    ``SentinelPatchProcessor.create_patches`` cuts it. That is a window anchored
    at the truncated pixel offset of the geometry's upper-left bound, sized by
    the saved array's height and width. (Comparing a geometry with a window built
    from its own bounds, as before, is always true and verifies nothing.)

    Patch names are resolved as follows:
      * ``patch_<i>.npy``: feature ``i`` of the GeoJSON (the previous
        assumption; TODO confirm for files produced by older code).
      * ``<name>_P<n>.npy``: the n-th (1-based) feature with that ``name`` that
        intersects the tile, clipped to the tile, mirroring ``create_patches``.
    Unresolvable files are reported and skipped.

    Returns:
        Dict with ``total_patches``, ``intersects_count``, ``within_count`` and
        ``skipped_count``; or ``None`` if a fatal error occurs (it is printed).
    """
    try:
        from rasterio.windows import Window  # not imported at the top of this section

        gdf = gpd.read_file(geojson_path)

        b02_paths = list(Path(sentinel_path).glob('GRANULE/*/IMG_DATA/*B02.jp2'))
        if not b02_paths:
            raise FileNotFoundError("B02 band not found in the specified Sentinel-2 data directory.")
        b02_path = b02_paths[0]

        with rasterio.open(b02_path) as src:
            if not gdf.crs:
                raise ValueError("CRS is missing in the GeoJSON file.")
            if not src.crs:
                raise ValueError("CRS is missing in the raster data.")
            if gdf.crs != src.crs:
                gdf = gdf.to_crs(src.crs)

            tile_box = box(*src.bounds)
            # Features on this tile, in file order: the set create_patches numbers _P<n> over
            on_tile = gdf[gdf.geometry.intersects(tile_box)]
            has_names = 'name' in gdf.columns

            total_patches = 0
            intersects_count = 0
            within_count = 0
            skipped_count = 0

            for patch_path in sorted(Path(patches_dir).glob('*.npy')):
                total_patches += 1
                stem = patch_path.stem

                legacy_match = re.fullmatch(r'patch_(\d+)', stem)
                named_match = re.fullmatch(r'(.+)_P(\d+)', stem)
                geometry = None
                if legacy_match:
                    patch_idx = int(legacy_match.group(1))
                    if patch_idx < len(gdf):
                        geometry = gdf.iloc[patch_idx].geometry
                elif named_match and has_names:
                    same_name = on_tile[on_tile['name'] == named_match.group(1)]
                    seq = int(named_match.group(2))
                    if 1 <= seq <= len(same_name):
                        geometry = same_name.iloc[seq - 1].geometry.intersection(tile_box)

                if geometry is None:
                    skipped_count += 1
                    print(f"Warning: could not match {patch_path.name} to a GeoJSON feature; skipping.")
                    continue

                geom_window = from_bounds(*geometry.bounds, transform=src.transform)
                # mmap_mode only reads the .npy header, which is all the shape needs
                patch_height, patch_width = np.load(patch_path, mmap_mode='r').shape[-2:]
                patch_window = Window(int(geom_window.col_off), int(geom_window.row_off),
                                      patch_width, patch_height)
                patch_polygon = box(*window_bounds(patch_window, transform=src.transform))

                if geometry.intersects(patch_polygon):
                    intersects_count += 1
                if geometry.within(patch_polygon):
                    within_count += 1

            print(f"Total patches processed: {total_patches}")
            print(f"Patches where geometry intersects patch bounds: {intersects_count}")
            print(f"Patches where geometry is entirely within patch bounds: {within_count}")
            print(f"Patches skipped (no matching feature): {skipped_count}")

            return {
                'total_patches': total_patches,
                'intersects_count': intersects_count,
                'within_count': within_count,
                'skipped_count': skipped_count
            }
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None

In [None]:
# Verify every patch of the example scene. The scene name and tile ID come from
# sentinel_path itself, not from `safe_dir`, which is whichever scene the
# processing loop happened to visit last.
sentinel_path = safe_dirs[folder_index]
scene_name = Path(sentinel_path).parent.name
tile_match = re.search(r'T\d{2}[A-Z]{3}', str(sentinel_path))
if tile_match is None:
    raise ValueError(f"Could not find a tile ID in {sentinel_path}")
tile_id = tile_match.group(0)
geojson_path = saved_path

# Older runs saved patches under Tiles/<scene>/<tile_id>/; the processing loop above
# writes them directly to Tiles/<scene>/. Use whichever exists.
scene_tiles_dir = Path("../Datasets/Testing/Tiles") / scene_name
tile_subdir = scene_tiles_dir / tile_id
patches_dir = str(tile_subdir if tile_subdir.is_dir() else scene_tiles_dir)

result = verify_all_patches(sentinel_path, geojson_path, patches_dir)

In [None]:
import os
import numpy as np
import json
from datetime import datetime
from glob import glob
from pathlib import Path

def load_geojson_dates(print_loading=False, samples_dir='../Datasets/Testing/Samples'):
    """Return the sorted ``img_date`` values of the newest sampled-events GeoJSON.

    Args:
        print_loading: If True, print which file is read.
        samples_dir: Directory holding ``*.geojson`` sample files. The one with the
            latest ``os.path.getctime`` is used.

    Returns:
        Sorted list of naive ``datetime`` objects at midnight. ISO dates
        (``2019-08-15``) and ISO timestamps (``2019-08-15T00:00:00``) are accepted.
        Missing, non-string or unparseable values are reported and skipped.

    Raises:
        FileNotFoundError: If ``samples_dir`` contains no ``.geojson`` file.
    """
    sample_files = glob(os.path.join(samples_dir, '*.geojson'))
    if not sample_files:
        raise FileNotFoundError(f"No .geojson files found in {samples_dir}")
    latest_file = max(sample_files, key=os.path.getctime)
    if print_loading:
        print(f"Loading events from {latest_file}")

    with open(latest_file) as f:
        data = json.load(f)

    event_dates = []
    for feature in data.get('features', []):
        date_str = (feature.get('properties') or {}).get('img_date')
        if not isinstance(date_str, str):
            print(f"Missing or non-string img_date: {date_str!r}")
            continue
        try:
            parsed = datetime.fromisoformat(date_str)
        except ValueError:
            print(f"Date format error in {date_str}")
            continue
        # Keep calendar-day resolution (and drop any tz info) so all values compare and sort together
        event_dates.append(datetime(parsed.year, parsed.month, parsed.day))

    return sorted(event_dates)

def get_tile_date(patch_file_path):
    """Return the sensing date of the Sentinel-2 scene a patch belongs to.

    The scene directory name (e.g. ``S2A_MSIL1C_20190815T083601_...``) holds the
    date in the first 8 characters of its third ``_``-separated field. The search
    order is:
      1. the grandparent (``Tiles/<scene>/<tile_id>/patch.npy``, the original lookup);
      2. the parent (``Tiles/<scene>/<patch>.npy``);
      3. any further ancestor.

    Raises:
        ValueError: If no ancestor directory name has a parseable date there.
    """
    path = Path(patch_file_path)
    ancestors = list(path.parents)
    search_order = ancestors[1:2] + ancestors[:1] + ancestors[2:]

    for ancestor in search_order:
        parts = ancestor.name.split('_')
        if len(parts) < 3:
            continue
        try:
            return datetime.strptime(parts[2][:8], '%Y%m%d')
        except ValueError:
            continue

    raise ValueError(f"No Sentinel-2 scene date found in the directories of {path}")

def load_and_sort_patches(base_path='../Datasets/Testing/Tiles', num_patches=66):
    """Load every ``patch_<n>.npy`` under the ``S2*`` folders and sort by date.

    Args:
        base_path: Root of the tiles tree; patches are searched for as
            ``<base_path>/S2*/**/T*/patch_<n>.npy``.
        num_patches: Only patch numbers ``0 .. num_patches - 1`` are loaded
            (default 66, i.e. patches 0-65).

    Returns:
        tuple[np.ndarray, np.ndarray]: the patches stacked along a new first
        axis and their ``datetime64`` dates (from ``get_tile_date``), both in
        ascending date order. Patches sharing a date keep their load order
        (scene folder, tile folder, numeric patch number).

    Raises:
        FileNotFoundError: if no ``S2*`` entry exists under ``base_path``.
        RuntimeError: if no patch could be loaded.
        ValueError: if the loaded patches do not all have the same shape.
    """
    # Exact names only: patch_7.npy, not patch_07.npy or patch_a.npy
    patch_name = re.compile(r'patch_(0|[1-9]\d*)\.npy')
    base_path = Path(base_path)
    # Sorted so the load order (and hence same-date ordering) is deterministic
    tile_dirs = sorted(base_path.glob('S2*'))

    if not tile_dirs:
        raise FileNotFoundError(f"No tile directories found in {base_path}")

    # Report how many scene (tile) directories will be scanned
    print(f"Found {len(tile_dirs)} tile directories")

    patches = []
    patch_dates = []
    for tile_dir in tile_dirs:
        # One directory walk per scene instead of one walk per patch number
        numbered = []
        for path in tile_dir.rglob('T*/patch_*.npy'):
            match = patch_name.fullmatch(path.name)
            if match and int(match.group(1)) < num_patches:
                numbered.append((path.parent, int(match.group(1)), path))

        if not numbered:
            print(f"No patches found in {tile_dir}")
            continue

        print(f"Found {len(numbered)} patches in {tile_dir}")

        # Numeric order within each tile folder (patch_2 before patch_10)
        for _, _, patch_file in sorted(numbered):
            try:
                patch = np.load(patch_file)
                patch_date = get_tile_date(patch_file)
            except Exception as e:
                print(f"Error loading {patch_file}: {e}")
                continue
            patch_dates.append(patch_date)
            patches.append(patch)

    if not patches:
        raise RuntimeError("No patches loaded successfully.")

    # np.array would otherwise fail obscurely (or build an object array) here
    shapes = {patch.shape for patch in patches}
    if len(shapes) > 1:
        raise ValueError(f"Patches have differing shapes: {sorted(shapes)}")

    # Convert patch_dates to numpy datetime64 for sorting
    patch_dates_np = np.array(patch_dates, dtype='datetime64')
    patches_np = np.array(patches)

    # Stable sort: patches with the same date keep their load order
    sorted_indices = np.argsort(patch_dates_np, kind='stable')
    patches_sorted = patches_np[sorted_indices]
    patch_dates_sorted = patch_dates_np[sorted_indices]

    print(f"\nTotal patches loaded: {len(patches_sorted)}")
    return patches_sorted, patch_dates_sorted

# Load the event dates (reporting which GeoJSON is read) and the date-sorted patches
event_dates = load_geojson_dates(print_loading=True)
patches, patch_dates = load_and_sort_patches()

# Sanity check: after sorting, the earliest acquisition dates come first
print("First 5 patch dates:")
for date in patch_dates[:5]:
    print(date)

In [None]:
import matplotlib.pyplot as plt
from collections import Counter
from pathlib import Path
import numpy as np

def analyze_patch_distribution(base_path='../Datasets/Testing/Tiles', num_patches=66):
    """Count ``patch_<n>.npy`` files per patch number, plot and print a summary.

    Args:
        base_path: Root of the tiles tree; files are searched for as
            ``<base_path>/S2*/**/T*/patch_<n>.npy``.
        num_patches: Patch numbers ``0 .. num_patches - 1`` are counted and
            plotted (default 66).

    Returns:
        collections.Counter: patch number -> number of files found. When at
        least one ``S2*`` entry exists, every number in ``range(num_patches)``
        is present as a key, possibly with count 0.
    """
    patch_name = re.compile(r'patch_(0|[1-9]\d*)\.npy')
    tile_dirs = list(Path(base_path).glob('S2*'))

    # Pre-seed zero counts (in ascending order) so unseen patch numbers are
    # still listed, and so ties in `most_common` resolve by patch number.
    patch_counts = Counter(dict.fromkeys(range(num_patches), 0)) if tile_dirs else Counter()

    # Count patches with a single directory walk per scene folder
    for tile_dir in tile_dirs:
        for path in tile_dir.rglob('T*/patch_*.npy'):
            match = patch_name.fullmatch(path.name)
            if match and int(match.group(1)) < num_patches:
                patch_counts[int(match.group(1))] += 1

    # Prepare data for plotting
    patch_numbers = list(range(num_patches))
    frequencies = [patch_counts[i] for i in patch_numbers]

    # Create histogram
    fig, ax = plt.subplots(figsize=(15, 6))
    ax.bar(patch_numbers, frequencies)
    ax.set_title('Distribution of Patch Numbers')
    ax.set_xlabel('Patch Number')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

    # Print statistics
    print(f"Total patches: {sum(frequencies)}")
    print("Most common patch numbers:")
    for num, freq in patch_counts.most_common(5):
        print(f"Patch {num}: {freq} occurrences")

    plt.show()

    return patch_counts

# Run analysis: plot the per-patch-number file counts and keep the Counter
# that analyze_patch_distribution returns for later inspection.
distribution = analyze_patch_distribution()

In [None]:
import os
import numpy as np
from datetime import datetime, timedelta
import h5py

# Build every (earlier, later) patch pair at most MAX_PAIR_GAP_DAYS apart and
# stream them to HDF5 in chunks, so the full pair array never sits in memory.
MAX_PAIR_GAP_DAYS = 30
ANCHORS_PER_CHUNK = 10  # anchor patches (not pairs) handled per HDF5 write
pairs_h5_dir = '../Datasets/Testing/Processed/HDF5-TemporalPairs'

# Create a directory for intermediate storage
os.makedirs(pairs_h5_dir, exist_ok=True)

# Read the event GeoJSON once; it used to be re-read for every candidate pair.
event_dates_np = np.array(load_geojson_dates(), dtype='datetime64[us]')

# The early `break` below relies on chronological order (load_and_sort_patches sorts by date).
assert np.all(patch_dates[:-1] <= patch_dates[1:]), "patch_dates must be sorted ascending"

n_patches = len(patches)
patch_shape = patches[0].shape

with h5py.File(f'{pairs_h5_dir}/pairs.h5', 'w') as f:
    # Store pairs in the patches' own dtype; without `dtype` h5py defaults to float32.
    f.create_dataset('pairs', shape=(0, 2, *patch_shape), maxshape=(None, 2, *patch_shape),
                     chunks=True, dtype=patches[0].dtype)
    f.create_dataset('labels', shape=(0,), maxshape=(None,), dtype=bool)

    pair_count = 0
    for start in range(0, n_patches - 1, ANCHORS_PER_CHUNK):
        chunk_pairs = []
        chunk_labels = []

        for j in range(start, min(start + ANCHORS_PER_CHUNK, n_patches - 1)):
            date1 = patch_dates[j]
            for k in range(j + 1, n_patches):
                date2 = patch_dates[k]
                if (date2 - date1).astype('timedelta64[D]').astype(int) > MAX_PAIR_GAP_DAYS:
                    break  # dates are sorted, so every later k is even further apart
                chunk_pairs.append([patches[j], patches[k]])
                # Positive when any event date falls in [date1, date2] (inclusive)
                chunk_labels.append(bool(((event_dates_np >= date1) & (event_dates_np <= date2)).any()))

        if chunk_pairs:
            # Grow both datasets, then write this chunk at the end
            new_size = pair_count + len(chunk_pairs)
            f['pairs'].resize(new_size, axis=0)
            f['labels'].resize(new_size, axis=0)
            f['pairs'][pair_count:new_size] = chunk_pairs
            f['labels'][pair_count:new_size] = chunk_labels
            pair_count = new_size

print(f"Total pairs saved: {pair_count}")

In [None]:
import random
from sklearn.model_selection import train_test_split

# Create temporal pairs: every (earlier, later) patch combination at most
# `max_time_diff` days apart, labelled by whether an event date lies between them.
temporal_pairs = []
labels = []
max_time_diff = 30  # Maximum days between image pairs

# Read the event GeoJSON once; it used to be re-read for every candidate pair.
event_dates_np = np.array(load_geojson_dates(), dtype='datetime64[us]')

# The early `break` below relies on chronological order (load_and_sort_patches sorts by date).
assert np.all(patch_dates[:-1] <= patch_dates[1:]), "patch_dates must be sorted ascending"

for i, (date1, patch1) in enumerate(zip(patch_dates[:-1], patches[:-1])):
    for date2, patch2 in zip(patch_dates[i + 1:], patches[i + 1:]):
        time_diff = (date2 - date1).astype('timedelta64[D]').astype(int)
        if time_diff > max_time_diff:
            break  # every later patch is at least this far away
        temporal_pairs.append((patch1, patch2))
        # Binary label: True if any event date falls in [date1, date2] (inclusive)
        labels.append(bool(((event_dates_np >= date1) & (event_dates_np <= date2)).any()))

if not labels:
    raise RuntimeError(f"No patch pairs within {max_time_diff} days were found")

# Convert to numpy arrays
X = np.array(temporal_pairs)
y = np.array(labels)

# Split data: 70% train, 15% validation, 15% test.
# NOTE(review): a random split of overlapping pairs can put the same patch (and
# overlapping date windows) in both train and test, which can inflate scores.
# Consider a chronological split (TimeSeriesSplit is imported in the first cell).
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

print(f"Training pairs: {len(X_train)}")
print(f"Validation pairs: {len(X_val)}")
print(f"Testing pairs: {len(X_test)}")
print(f"Positive samples: {y.sum()}/{len(y)} ({y.mean() * 100:.2f}%)")

In [None]:
class ConvBlock(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    The first conv maps ``in_ch`` to ``out_ch``; the second keeps ``out_ch``.
    ``padding=1`` keeps the spatial size unchanged. The layers live in
    ``self.conv`` (indices 0-5), so parameter names match earlier checkpoints.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class UNetCH(nn.Module):
    """Three-level U-Net with an auxiliary image-level classification head.

    ``forward(x)`` returns ``(seg_output, cls_output)``:

    * ``seg_output``: ``(N, 1, H, W)`` per-pixel probabilities (sigmoid of a
      1x1 conv on the last decoder features).
    * ``cls_output``: ``(N, 1)`` probability from the global max-pooled
      256-channel bottleneck (``enc3`` output).

    Any spatial size of at least 4x4 works. Decoder features are resized to
    the exact size of their skip connection, so H and W need not be multiples
    of 4 (``MaxPool2d`` floors odd sizes).
    """

    def __init__(self, in_channels=27):  # 9 channels each for img1, img2, diff (per original note)
        super().__init__()

        # Encoder: full, 1/2 and 1/4 resolution
        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)

        # Decoder: input channels = upsampled features + skip connection
        self.dec3 = ConvBlock(256 + 128, 128)
        self.dec2 = ConvBlock(128 + 64, 64)
        self.dec1 = ConvBlock(64, 32)

        # Classification Head on the bottleneck features
        self.cls_head = nn.Sequential(
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

        # Final 1x1 conv to one segmentation channel
        self.final = nn.Conv2d(32, 1, kernel_size=1)

        # Pooling and Upsampling; `self.up` also supplies the interpolation
        # mode / align_corners used by `_upsample_to`.
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _upsample_to(self, x, ref):
        """Resize ``x`` to ``ref``'s spatial size using ``self.up``'s settings.

        When ``ref`` is exactly twice ``x``'s size, this matches ``self.up(x)``
        (with ``align_corners=True`` the sampling grid depends only on the
        input and output sizes). Otherwise it absorbs the off-by-one left by
        ``MaxPool2d`` on odd sizes, which used to make ``torch.cat`` fail.
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-2:], mode=self.up.mode, align_corners=self.up.align_corners
        )

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        # Classification Branch
        cls_output = self.cls_head(e3)

        # Decoder with Skip Connections (resized to match each skip exactly)
        d3 = self.dec3(torch.cat([self._upsample_to(e3, e2), e2], dim=1))
        d2 = self.dec2(torch.cat([self._upsample_to(d3, e1), e1], dim=1))
        d1 = self.dec1(d2)

        # Final Output
        seg_output = torch.sigmoid(self.final(d1))

        return seg_output, cls_output