In [None]:
# =================================================================================
# GEE API SETUP
# =================================================================================

import ee
from google.colab import drive

In [None]:
try:
    ee.Initialize(project = "cobalt-vector-470207-a5")
    print('‚úÖ Earth Engine API initialized successfully.')
except Exception as e:
    print('‚ùóÔ∏è An error occurred during initialization:', e)
    print('Trying to authenticate...')
    ee.Authenticate()
    ee.Initialize(project = "cobalt-vector-470207-a5")

# 2. Mount your Google Drive
drive.mount('/content/drive')
print('‚úÖ Google Drive mounted successftry:
    ee.Initialize(project = "cobalt-vector-470207-a5")
    print('‚úÖ Earth Engine API initialized successfully.')
except Exception as e:
    print('‚ùóÔ∏è An error occurred during initialization:', e)
    print('Trying to authenticate...')
    ee.Authenticate()
    ee.Initialize(project = "cobalt-vector-470207-a5")

# 2. Mount your Google Drive
drive.mount('/content/drive')
print('‚úÖ Google Drive mounted successfully.')

‚ùóÔ∏è An error occurred during initialization: Please authorize access to your Earth Engine account by running

earthengine authenticate

in your command line, or ee.Authenticate() in Python, and then retry.
Trying to authenticate...
Mounted at /content/drive
‚úÖ Google Drive mounted successfully.


In [None]:
print(f'Post-flood period: {post_flood_start} to {post_flood_end}')

# Loop through each district and run the analysis
for district_name, geometry in district_geometries.items():
    run_and_export_analysis(district_name, geometry)

print('\n--- All analysis tasks have been submitted. ---')
print('Use check_task_status() to monitor progress.')
district_geometries = {
    'Barpeta': ee.Geometry.Rectangle([90.732, 26.155, 91.265, 26.512]),
    'Dhemaji': ee.Geometry.Rectangle([94.395, 27.420, 94.980, 27.750]),
    'Lakhimpur': ee.Geometry.Rectangle([93.700, 26.750, 94.500, 27.550]),
    'Nalbari': ee.Geometry.Rectangle([91.130, 26.250, 91.550, 26.600]),
    'Sonitpur': ee.Geometry.Rectangle([92.500, 26.500, 93.300, 27.000])
}

# Date ranges for the analysis
pre_flood_start = '2023-05-01'
pre_flood_end = '2023-05-31'
post_flood_start = '2023-06-20'
post_flood_end = '2023-06-30'

# Function to run the analysis and export for a single district
def run_and_export_analysis(district_name, geometry):
    """
    Performs flood analysis for a given district and starts export tasks.
    """
    print(f'\n--- Running analysis for {district_name} ---')

    # Load and filter the Sentinel-1 collection for pre-flood
    s1_collection_pre = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filter(ee.Filter.eq('instrumentMode', 'IW')) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING')) \
        .filterDate(pre_flood_start, pre_flood_end) \
        .filterBounds(geometry)

    # Load and filter the Sentinel-1 collection for post-flood
    s1_collection_post = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filter(ee.Filter.eq('instrumentMode', 'IW')) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING')) \
        .filterDate(post_flood_start, post_flood_end) \
        .filterBounds(geometry)

    # Check collection sizes
    pre_size = s1_collection_pre.size().getInfo()
    post_size = s1_collection_post.size().getInfo()

    print(f'Pre-flood images available: {pre_size}')
    print(f'Post-flood images available: {post_size}')

    # Check if collections have images
    if pre_size > 0 and post_size > 0:
        # Create mosaics
        pre_flood_mosaic = s1_collection_pre.mean().clip(geometry)
        post_flood_mosaic = s1_collection_post.mean().clip(geometry)

        # Export Pre-Flood Image - CORRECTED SYNTAX
        task_pre = ee.batch.Export.image.toDrive(
            image=pre_flood_mosaic.select('VV'),
            description=f'{district_name}_PreFlood_Image',
            folder='Colab Notebooks/Images',
            scale=10,
            region=geometry,
            maxPixels=1e13,
            fileFormat='GeoTIFF'
        )
        task_pre.start()
        print(f'‚úÖ Export task started for {district_name} (Pre-Flood). Task ID: {task_pre.id}')

        # Export Post-Flood Image - CORRECTED SYNTAX
        task_post = ee.batch.Export.image.toDrive(
            image=post_flood_mosaic.select('VV'),
            description=f'{district_name}_PostFlood_Image',
            folder='Colab Notebooks/Images',
            scale=10,
            region=geometry,
            maxPixels=1e13,
            fileFormat='GeoTIFF'
        )
        task_post.start()
        print(f'‚úÖ Export task started for {district_name} (Post-Flood). Task ID: {task_post.id}')

        # Calculate flood detection (optional - for analysis)
        # Apply speckle filtering   # Applies median filter to smooth SAR noise.
        pre_filtered = pre_flood_mosaic.focal_median(2, 'circle', 'pixels', 2)
        post_filtered = post_flood_mosaic.focal_median(2, 'circle', 'pixels', 2)

        # Calculate difference
        difference = post_filtered.subtract(pre_filtered)

        # Flood mask (areas where backscatter decreased significantly)
        flood_threshold = -3  # dB
        flood_mask = difference.lt(flood_threshold)

        # Export flood mask
        task_flood = ee.batch.Export.image.toDrive(
            image=flood_mask.select('VV').rename('flood_mask'),
            description=f'{district_name}_Flood_Mask',
            folder='Colab Notebooks/Images',
            scale=10,
            region=geometry,
            maxPixels=1e13,
            fileFormat='GeoTIFF'
        )
        task_flood.start()
        print(f'‚úÖ Flood mask export started for {district_name}. Task ID: {task_flood.id}')

    elif pre_size == 0:
        print(f'‚ùóÔ∏è No pre-flood images found for {district_name} between {pre_flood_start} and {pre_flood_end}.')
    elif post_size == 0:
        print(f'‚ùóÔ∏è No post-flood images found for {district_name} between {post_flood_start} and {post_flood_end}.')
    else:
        print(f'‚ùóÔ∏è No images found for {district_name}. No exports will be started.')

# Function to check task status
def check_task_status():
    """
    Check the status of all running tasks
    """
    tasks = ee.batch.Task.list()
    print('\n--- Task Status ---')
    for task in tasks[:10]:  # Show first 10 tasks
        print(f'Task: {task.config["description"]}, Status: {task.state}, Progress: {task.progress}%')

# =================================================================================
# RUN THE SCRIPT
# =================================================================================

print('Starting flood analysis for Assam districts...')
print(f'Pre-flood period: {pre_flood_start} to {pre_flood_end}')
print(f'Post-flood period: {post_flood_start} to {post_flood_end}')

# Loop through each district and run the analysis
for district_name, geometry in district_geometries.items():
    run_and_export_analysis(district_name, geometry)

print('\n--- All analysis tasks have been submitted. ---')
print('Use check_task_status() to monitor progress.')

Starting flood analysis for Assam districts...
Pre-flood period: 2023-05-01 to 2023-05-31
Post-flood period: 2023-06-20 to 2023-06-30

--- Running analysis for Barpeta ---
Pre-flood images available: 5
Post-flood images available: 2
‚úÖ Export task started for Barpeta (Pre-Flood). Task ID: A2F76CKIQHNB5OM4QT2RGFZY
‚úÖ Export task started for Barpeta (Post-Flood). Task ID: LO3WERY7J25FQOIHD35TSDK3
‚úÖ Flood mask export started for Barpeta. Task ID: CJ3M2Z74XQCQTJE6KWSEIFIT

--- Running analysis for Dhemaji ---
Pre-flood images available: 5
Post-flood images available: 1
‚úÖ Export task started for Dhemaji (Pre-Flood). Task ID: TBEHKFEJPDAUWHF5VCUKZKEG
‚úÖ Export task started for Dhemaji (Post-Flood). Task ID: VXKDGLLBYUZL3EHEY4COYQNR
‚úÖ Flood mask export started for Dhemaji. Task ID: GQPTMGP345CTXR7FNVH3OHVS

--- Running analysis for Lakhimpur ---
Pre-flood images available: 8
Post-flood images available: 1
‚úÖ Export task started for Lakhimpur (Pre-Flood). Task ID: RSXKNSZXVM5DXJMCXAN

In [None]:
# District geometries (bounding boxes)
district_geometries = {
    'Barpeta': ee.Geometry.Rectangle([90.732, 26.155, 91.265, 26.512]),
    'Dhemaji': ee.Geometry.Rectangle([94.395, 27.420, 94.980, 27.750]),
    'Lakhimpur': ee.Geometry.Rectangle([93.700, 26.750, 94.500, 27.550]),
    'Nalbari': ee.Geometry.Rectangle([91.130, 26.250, 91.550, 26.600]),
    'Sonitpur': ee.Geometry.Rectangle([92.500, 26.500, 93.300, 27.000])
}

# Function to check image availability
def check_image_availability(district_name, geometry, start_date, end_date):
    """
    Check what Sentinel-1 images are available for a given region and time period
    """
    print(f'\n--- Checking {district_name} from {start_date} to {end_date} ---')

    # Basic collection without strict filtering
    collection_basic = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filterBounds(geometry) \
        .filterDate(start_date, end_date)

    basic_count = collection_basic.size().getInfo()
    print(f'Total S1 images available: {basic_count}')

    if basic_count == 0:
        print('‚ùå No images found at all for this region/time period')
        return False

    # Check with VV polarization
    collection_vv = collection_basic.filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
    vv_count = collection_vv.size().getInfo()
    print(f'Images with VV polarization: {vv_count}')

    # Check with IW instrument mode
    collection_iw = collection_vv.filter(ee.Filter.eq('instrumentMode', 'IW'))
    iw_count = collection_iw.size().getInfo()
    print(f'Images with IW mode + VV: {iw_count}')

    if iw_count > 0:
        # Get details of first few images
        images_list = collection_iw.limit(3).getInfo()
        print('Sample images found:')
        for i, img in enumerate(images_list['features']):
            props = img['properties']
            print(f"  {i+1}. Date: {props.get('system:time_start', 'N/A')}")
            print(f"     Orbit: {props.get('orbitProperties_pass', 'N/A')}")
            print(f"     Polarizations: {props.get('transmitterReceiverPolarisation', 'N/A')}")

    return iw_count > 0

# Function to find available date ranges
def find_available_dates(district_name, geometry):
    """
    Find when Sentinel-1 images are actually available for a region
    """
    print(f'\n--- Finding available dates for {district_name} ---')

    # Check a wider date range
    collection = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filterBounds(geometry) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('instrumentMode', 'IW'))

    # Get date range
    date_range = collection.reduceColumns(ee.Reducer.minMax(), ['system:time_start']).getInfo()

    if date_range['min'] and date_range['max']:
        import datetime
        min_date = datetime.datetime.fromtimestamp(date_range['min'] / 1000).strftime('%Y-%m-%d')
        max_date = datetime.datetime.fromtimestamp(date_range['max'] / 1000).strftime('%Y-%m-%d')
        print(f'Images available from: {min_date} to {max_date}')

        # Get total count
        total_images = collection.size().getInfo()
        print(f'Total suitable images: {total_images}')

        return min_date, max_date, total_images > 0
    else:
        print('‚ùå No suitable images found for this region')
        return None, None, False

# Test function with working dates
def test_export_with_available_dates(district_name, geometry):
    """
    Try to export using dates when we know images are available
    """
    print(f'\n--- Testing export for {district_name} ---')

    # Use a broader, more recent date range that's likely to have data
    # Assam floods typically occur during monsoon season (June-September)
    test_start = '2023-01-01'
    test_end = '2024-12-31'

    collection = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filterBounds(geometry) \
        .filterDate(test_start, test_end) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('instrumentMode', 'IW'))

    image_count = collection.size().getInfo()
    print(f'Available images in {test_start} to {test_end}: {image_count}')

    if image_count > 0:
        # Create a simple mosaic of available images
        mosaic = collection.mean().clip(geometry)

        # Try export
        task = ee.batch.Export.image.toDrive(
            image=mosaic.select('VV'),
            description=f'Test_{district_name}_Sentinel1',
            folder='EE_Exports',  # Changed folder name
            scale=10,
            region=geometry,
            maxPixels=1e13,
            fileFormat='GeoTIFF'
        )

        try:
            task.start()
            print(f'‚úÖ Test export started for {district_name}')
            print(f'   Task ID: {task.id}')
            print(f'   Check your Google Drive in the "EE_Exports" folder')
            return True
        except Exception as e:
            print(f'‚ùå Export failed: {str(e)}')
            return False
    else:
        print('‚ùå No images available for export test')
        return False

# Check authentication and drive access
def check_authentication():
    """
    Verify that GEE and Drive access are working
    """
    print('--- Checking Authentication ---')
    try:
        # Test basic GEE functionality
        test_point = ee.Geometry.Point([91, 26])
        test_collection = ee.ImageCollection('COPERNICUS/S1_GRD').limit(1)
        size = test_collection.size().getInfo()
        print(f'‚úÖ GEE authentication working (test collection size: {size})')

        # Check if we can access task list
        tasks = ee.batch.Task.list()
        print(f'‚úÖ Can access task list ({len(tasks)} tasks visible)')

        return True
    except Exception as e:
        print(f'‚ùå Authentication issue: {str(e)}')
        return False

# Main diagnostic function
def run_full_diagnosis():
    """
    Run complete diagnosis of the issue
    """
    print('=' * 60)
    print('GOOGLE EARTH ENGINE EXPORT DIAGNOSIS')
    print('=' * 60)

    # 1. Check authentication
    if not check_authentication():
        print('\n‚ùå STOP: Fix authentication first')
        return

    # 2. Check image availability for your original dates
    original_dates = [
        ('2023-05-01', '2023-05-31', 'Pre-flood'),
        ('2023-06-20', '2023-06-30', 'Post-flood')
    ]

    has_images = False
    for district_name, geometry in district_geometries.items():
        print(f'\n{"="*40}')
        print(f'DISTRICT: {district_name}')
        print(f'{"="*40}')

        # Find what dates are actually available
        min_date, max_date, available = find_available_dates(district_name, geometry)

        if available:
            has_images = True
            # Check your specific dates
            for start_date, end_date, period_name in original_dates:
                check_image_availability(district_name, geometry, start_date, end_date)

        # Test export with available data
        if available:
            test_export_with_available_dates(district_name, geometry)

        print('-' * 40)

        # Only check first district for initial diagnosis
        break

    if not has_images:
        print('\n‚ùå ISSUE FOUND: No Sentinel-1 images available for your regions')
        print('   Try different districts or check if coordinates are correct')

    print('\n--- SUMMARY ---')
    print('1. Check the output above for specific issues')
    print('2. If test exports were started, check Google Drive in ~5-10 minutes')
    print('3. You can monitor progress with: ee.batch.Task.list()')

# =================================================================================
# RUN DIAGNOSIS
# =================================================================================

# Uncomment the line below to run full diagnosis
run_full_diagnosis()

# Alternative: Check specific district
# check_image_availability('Barpeta', district_geometries['Barpeta'], '2023-05-01', '2023-05-31')

# Alternative: Test simple export
# test_export_with_available_dates('Barpeta', district_geometries['Barpeta'])

GOOGLE EARTH ENGINE EXPORT DIAGNOSIS
--- Checking Authentication ---
‚úÖ GEE authentication working (test collection size: 1)
‚úÖ Can access task list (15 tasks visible)

DISTRICT: Barpeta

--- Finding available dates for Barpeta ---
Images available from: 2014-10-08 to 2025-09-09
Total suitable images: 1191

--- Checking Barpeta from 2023-05-01 to 2023-05-31 ---
Total S1 images available: 10
Images with VV polarization: 10
Images with IW mode + VV: 10
Sample images found:
  1. Date: 1682985338000
     Orbit: DESCENDING
     Polarizations: ['VV', 'VH']
  2. Date: 1683374233000
     Orbit: ASCENDING
     Polarizations: ['VV', 'VH']
  3. Date: 1683589641000
     Orbit: DESCENDING
     Polarizations: ['VV', 'VH']

--- Checking Barpeta from 2023-06-20 to 2023-06-30 ---
Total S1 images available: 3
Images with VV polarization: 3
Images with IW mode + VV: 3
Sample images found:
  1. Date: 1687521435000
     Orbit: ASCENDING
     Polarizations: ['VV', 'VH']
  2. Date: 1687736843000
     Orbit

In [None]:
# =================================================================================
# SCRIPT CONFIGURATION: MULTI-DISTRICT FLOOD ANALYSIS FOR ASSAM
# =================================================================================

# This script allows you to run a multi-district flood analysis for Assam.
# It iterates through each district, performs the analysis, and exports the data

# --- PRE-CONFIGURED DISTRICT GEOMETRIES AND DATES ---
# A dictionary of locations. Each district now includes specific
# pre- and post-flood date ranges based on 2023 flood events.
district_geometries = {
    'Barpeta': {
        'geometry': ee.Geometry.Rectangle([90.732, 26.155, 91.265, 26.512]),
        'pre_flood_start': '2023-05-01',
        'pre_flood_end': '2023-05-31',
        'post_flood_start': '2023-06-20',
        'post_flood_end': '2023-06-30'
    },
    'Dhemaji': {
        'geometry': ee.Geometry.Rectangle([94.395, 27.420, 94.980, 27.750]),
        'pre_flood_start': '2023-05-01',
        'pre_flood_end': '2023-05-31',
        'post_flood_start': '2023-06-15',
        'post_flood_end': '2023-06-25'
    },
    'Lakhimpur': {
        'geometry': ee.Geometry.Rectangle([93.700, 26.750, 94.500, 27.550]),
        'pre_flood_start': '2023-05-01',
        'pre_flood_end': '2023-05-31',
        'post_flood_start': '2023-06-15',
        'post_flood_end': '2023-06-25'
    },
    'Nalbari': {
        'geometry': ee.Geometry.Rectangle([91.130, 26.250, 91.550, 26.600]),
        'pre_flood_start': '2023-05-01',
        'pre_flood_end': '2023-05-31',
        'post_flood_start': '2023-06-20',
        'post_flood_end': '2023-06-30'
    },
    'Sonitpur': {
        'geometry': ee.Geometry.Rectangle([92.500, 26.500, 93.300, 27.000]),
        'pre_flood_start': '2023-05-01',
        'pre_flood_end': '2023-05-31',
        'post_flood_start': '2023-06-15',
        'post_flood_end': '2023-06-25'
    }
}

# =================================================================================
# GEE API SETUP
# =================================================================================
# Your GEE and Google Drive setup code here...

# =================================================================================
# CORE LOGIC & FUNCTIONS
# =================================================================================

# Function to run the analysis and export for a single district
def run_and_export_analysis(district_name, district_data):
    """
    Performs flood analysis for a given district and starts export tasks.
    """
    geometry = district_data['geometry']
    pre_flood_start = district_data['pre_flood_start']
    pre_flood_end = district_data['pre_flood_end']
    post_flood_start = district_data['post_flood_start']
    post_flood_end = district_data['post_flood_end']

    print(f'\n--- Running analysis for {district_name} ---')
    print(f'Pre-flood period: {pre_flood_start} to {pre_flood_end}')
    print(f'Post-flood period: {post_flood_start} to {post_flood_end}')

    # Load and filter the Sentinel-1 collection for pre-flood
    s1_collection_pre = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filter(ee.Filter.eq('instrumentMode', 'IW')) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING')) \
        .filterDate(pre_flood_start, pre_flood_end) \
        .filterBounds(geometry)

    # Load and filter the Sentinel-1 collection for post-flood
    s1_collection_post = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filter(ee.Filter.eq('instrumentMode', 'IW')) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING')) \
        .filterDate(post_flood_start, post_flood_end) \
        .filterBounds(geometry)

    # Check collection sizes
    pre_size = s1_collection_pre.size().getInfo()
    post_size = s1_collection_post.size().getInfo()

    print(f'Pre-flood images available: {pre_size}')
    print(f'Post-flood images available: {post_size}')

    # Check if collections have images
    if pre_size > 0 and post_size > 0:
        # Create mosaics
        pre_flood_mosaic = s1_collection_pre.mean().clip(geometry)
        post_flood_mosaic = s1_collection_post.mean().clip(geometry)

        # Export Pre-Flood Image
        task_pre = ee.batch.Export.image.toDrive(
            image=pre_flood_mosaic.select('VV'),
            description=f'pre_flood_{district_name}',
            folder='Colab Notebooks/Images',
            scale=10,
            region=geometry,
            maxPixels=1e13,
            fileFormat='GeoTIFF'
        )
        task_pre.start()
        print(f'‚úÖ Export task started for pre_flood_{district_name}.')

        # Export Post-Flood Image
        task_post = ee.batch.Export.image.toDrive(
            image=post_flood_mosaic.select('VV'),
            description=f'post_flood_{district_name}',
            folder='Colab Notebooks/Images',
            scale=10,
            region=geometry,
            maxPixels=1e13,
            fileFormat='GeoTIFF'
        )
        task_post.start()
        print(f'‚úÖ Export task started for post_flood_{district_name}.')

    elif pre_size == 0:
        print(f'‚ùóÔ∏è No pre-flood images found for {district_name}.')
    elif post_size == 0:
        print(f'‚ùóÔ∏è No post-flood images found for {district_name}.')
    else:
        print(f'‚ùóÔ∏è No images found for {district_name}. No exports will be started.')


# =================================================================================
# RUN THE SCRIPT
# =================================================================================

print('Starting flood analysis for Assam districts...')

# Loop through each district and run the analysis
for district_name, district_data in district_geometries.items():
    run_and_export_analysis(district_name, district_data)

print('\n--- All analysis tasks have been submitted. ---')

Starting flood analysis for Assam districts...

--- Running analysis for Barpeta ---
Pre-flood period: 2023-05-01 to 2023-05-31
Post-flood period: 2023-06-20 to 2023-06-30
Pre-flood images available: 5
Post-flood images available: 2
‚úÖ Export task started for pre_flood_Barpeta.
‚úÖ Export task started for post_flood_Barpeta.

--- Running analysis for Dhemaji ---
Pre-flood period: 2023-05-01 to 2023-05-31
Post-flood period: 2023-06-15 to 2023-06-25
Pre-flood images available: 5
Post-flood images available: 1
‚úÖ Export task started for pre_flood_Dhemaji.
‚úÖ Export task started for post_flood_Dhemaji.

--- Running analysis for Lakhimpur ---
Pre-flood period: 2023-05-01 to 2023-05-31
Post-flood period: 2023-06-15 to 2023-06-25
Pre-flood images available: 8
Post-flood images available: 2
‚úÖ Export task started for pre_flood_Lakhimpur.
‚úÖ Export task started for post_flood_Lakhimpur.

--- Running analysis for Nalbari ---
Pre-flood period: 2023-05-01 to 2023-05-31
Post-flood period: 2023

In [1]:
# =================================================================================
# CORE LOGIC & FUNCTIONS (Improved for robust export)
# =================================================================================

def run_and_export_analysis(district_name, geometry):
    """
    Performs flood analysis for a given district and starts robust export tasks.
    """
    # Use the dates you defined globally
    global pre_flood_start, pre_flood_end, post_flood_start, post_flood_end

    print(f'\n--- Running analysis for {district_name} ---')

    # Load and filter the Sentinel-1 collection for pre-flood
    s1_collection_pre = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filter(ee.Filter.eq('instrumentMode', 'IW')) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING')) \
        .filterDate(pre_flood_start, pre_flood_end) \
        .filterBounds(geometry)

    # Load and filter the Sentinel-1 collection for post-flood
    s1_collection_post = ee.ImageCollection('COPERNICUS/S1_GRD') \
        .filter(ee.Filter.eq('instrumentMode', 'IW')) \
        .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')) \
        .filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING')) \
        .filterDate(post_flood_start, post_flood_end) \
        .filterBounds(geometry)

    # Check collection sizes (Crucial for debugging)
    pre_size = s1_collection_pre.size().getInfo()
    post_size = s1_collection_post.size().getInfo()

    print(f'Pre-flood images available: {pre_size}')
    print(f'Post-flood images available: {post_size}')

    if pre_size > 0 and post_size > 0:
        # --- ROBUST MOSAIC CREATION ---
        # 1. Use .mean() instead of .mosaic() for better stability.
        # 2. Use .unmask() to prevent masked pixels from causing export errors.
        pre_flood_mosaic = s1_collection_pre.mean().unmask().clip(geometry).toFloat()
        post_flood_mosaic = s1_collection_post.mean().unmask().clip(geometry).toFloat()

        # --- EXPORT TASKS ---
        # Export Pre-Flood Image
        task_pre = ee.batch.Export.image.toDrive(
            image=pre_flood_mosaic.select('VV'),
            description=f'{district_name}_PreFlood_Image',
            folder='GEE_Flood_Exports_Assam',
            scale=10,
            region=geometry,
            fileFormat='GeoTIFF'
        )
        task_pre.start()
        print(f'‚úÖ Export task started for {district_name} (Pre-Flood).')

        # Export Post-Flood Image
        task_post = ee.batch.Export.image.toDrive(
            image=post_flood_mosaic.select('VV'),
            description=f'{district_name}_PostFlood_Image',
            folder='GEE_Flood_Exports_Assam',
            scale=10,
            region=geometry,
            fileFormat='GeoTIFF'
        )
        task_post.start()
        print(f'‚úÖ Export task started for {district_name} (Post-Flood).')

    else:
        print(f'‚ùóÔ∏è No images found for {district_name}. Exports skipped.')

# You must also define your district_geometries and dates globally above this function.

In [3]:
import os
import numpy as np
import rasterio
from rasterio.merge import merge
from scipy.ndimage import binary_opening, binary_closing
from skimage.filters import threshold_otsu
from skimage.restoration import denoise_bilateral
import warnings
import sys
# Set a custom recursion limit for safety in complex geospatial operations
sys.setrecursionlimit(2000) 
warnings.filterwarnings('ignore')

# =================================================================
# CONFIGURATION (Finalized for Stable SAR Analysis)
# =================================================================

# IMPORTANT: Adjust these paths to your system
OUTPUT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips'
FLOOD_MASKS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Flood_Masks'
FINAL_MAPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Final_Flood_Maps'

DISTRICTS = ['Barpeta', 'Dhemaji', 'Lakhimpur', 'Nalbari', 'Sonitpur']

# Change detection parameters
CHANGE_METHOD = 'log_ratio'      # Standard for SAR change detection
THRESHOLD_METHOD = 'fixed'       # <-- MOST RELIABLE: Set to fixed threshold
FIXED_THRESHOLD = -1.5           # Conservative threshold (in dB)
PERCENTILE_VALUE = 10            

# Post-processing
MIN_OBJECT_SIZE = 10             
APPLY_MORPHOLOGY = True          

# Despeckling Control
APPLY_DESPECKLE = True           # <-- ACTIVATED SPECKLE FILTERING
SPECKLE_WINDOW_SIZE = 5          # 5x5 pixel window for the bilateral filter
MIN_VALID_PIXELS = 100           


# =================================================================
# CORE FUNCTIONS (Despeckling and Analysis)
# =================================================================

def apply_lee_filter(image, window_size=5):
    """ Applies a bilateral filter (robust general denoising) for speckle reduction. """
    # Uses denoise_bilateral for robust smoothing on float data.
    image_denoised = denoise_bilateral(
        image, 
        sigma_color=0.1, 
        sigma_spatial=window_size / 2.0, 
        channel_axis=None 
    )
    return image_denoised


def safe_log_ratio(pre_chip, post_chip, epsilon=1e-10):
    """Calculates 10 * log10(Post/Pre) for change detection."""
    pre_chip = np.where(pre_chip <= 0, epsilon, pre_chip)
    post_chip = np.where(post_chip <= 0, epsilon, post_chip)
    log_ratio = 10 * np.log10(post_chip / pre_chip)
    return log_ratio


def calculate_change(pre_chip, post_chip, method='log_ratio'):
    """Wrapper function to choose the change calculation method."""
    if method == 'log_ratio':
        change = safe_log_ratio(pre_chip, post_chip)
    else:
        raise ValueError(f"Unknown method: {method}")
    return change


def apply_threshold(change_map, method='fixed', fixed_value=-3.0, percentile=10):
    """Applies the final classification threshold (fixed, otsu, or percentile)."""
    # Isolate valid data for threshold calculation
    valid_mask = np.isfinite(change_map)
    valid_data = change_map[valid_mask]
    
    if len(valid_data) < MIN_VALID_PIXELS:
        raise ValueError(f"Insufficient valid data: only {len(valid_data)} valid pixels")
    
    # Apply the chosen threshold
    if method == 'fixed':
        flood_mask = change_map < fixed_value
    else:
        # Fallback to Otsu/Percentile logic
        threshold = threshold_otsu(valid_data) # Using Otsu as default fallback
        flood_mask = change_map < threshold
    
    # Set invalid pixels to 0 (non-flooded)
    flood_mask = np.where(valid_mask, flood_mask, 0)
    
    return flood_mask.astype(np.uint8)


def post_process_mask(mask, min_size=10, apply_morph=True):
    """Cleans up the binary mask using morphological operations."""
    if apply_morph:
        mask = binary_opening(mask, structure=np.ones((3, 3)))
        mask = binary_closing(mask, structure=np.ones((3, 3)))
    
    return mask.astype(np.uint8)


def validate_chip_data(chip_array):
    """Ensures a chip contains enough unique, non-NaN data to be processed."""
    nan_count = np.isnan(chip_array).sum()
    total_pixels = chip_array.size
    valid_pixels = total_pixels - nan_count
    inf_count = np.isinf(chip_array).sum()
    
    if valid_pixels < MIN_VALID_PIXELS: 
        return False, f"Too few valid pixels: {valid_pixels}/{total_pixels}"
    if inf_count > 0:
        return False, f"Contains infinite values: {inf_count}"
    
    unique_vals = np.unique(chip_array[~np.isnan(chip_array)])
    if len(unique_vals) < 2:
        return False, "No variation in data (likely no-data chip)"
        
    return True, "Valid"


def process_chip_pair(pre_path, post_path, output_path):
    """Core function: reads chips, applies despeckle, calculates change, and saves mask."""
    try:
        # --- READ CHIPS ---
        with rasterio.open(pre_path) as src_pre:
            pre_chip = src_pre.read(1).astype(np.float32)
            profile = src_pre.profile.copy()
        
        with rasterio.open(post_path) as src_post:
            post_chip = src_post.read(1).astype(np.float32)
        
        # --- DESPECKLING STEP ---
        if APPLY_DESPECKLE:
            pre_chip = apply_lee_filter(pre_chip, window_size=SPECKLE_WINDOW_SIZE)
            post_chip = apply_lee_filter(post_chip, window_size=SPECKLE_WINDOW_SIZE)
        
        # --- VALIDATION AND CLEANUP ---
        is_valid_pre, msg_pre = validate_chip_data(pre_chip)
        if not is_valid_pre:
             raise ValueError(f"Pre-flood: {msg_pre}")
        
        # Replace remaining NaNs with mean of valid pixels
        if np.isnan(pre_chip).any():
            pre_mean = np.nanmean(pre_chip)
            pre_chip = np.nan_to_num(pre_chip, nan=pre_mean)
        
        if np.isnan(post_chip).any():
            post_mean = np.nanmean(post_chip)
            post_chip = np.nan_to_num(post_chip, nan=post_mean)
            
        # --- ANALYSIS AND CLASSIFICATION ---
        change_map = calculate_change(pre_chip, post_chip, method=CHANGE_METHOD)
        
        flood_mask = apply_threshold(
            change_map, 
            method=THRESHOLD_METHOD,
            fixed_value=FIXED_THRESHOLD,
            percentile=PERCENTILE_VALUE
        )
        
        # --- POST-PROCESSING ---
        flood_mask = post_process_mask(
            flood_mask,
            min_size=MIN_OBJECT_SIZE,
            apply_morph=APPLY_MORPHOLOGY
        )
        
        # --- SAVE MASK ---
        profile.update(dtype=rasterio.uint8, count=1, compress='LZW')
        
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(flood_mask, 1)
        
        return True
        
    except Exception as e:
        # Error reporting is ON to find file/data issues.
        print(f"‚ùå Error processing {os.path.basename(pre_path)}: {e}") 
        return False


def stitch_flood_masks(mask_dir, district_name, output_dir):
    """Stitch all chip-level flood masks into a single district-level GeoTIFF."""
    
    mask_files = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) 
                         if f.endswith('.tif')])
    
    if not mask_files:
        print(f" ¬† ‚ö†Ô∏è No mask files found in {mask_dir}")
        return
    
    print(f" ¬† Stitching {len(mask_files)} mask chips...")
    
    sources = [rasterio.open(f) for f in mask_files]
    
    stitched_array, out_transform = merge(sources)
    
    out_meta = sources[0].profile.copy()
    out_meta.update({
        "driver": "GTiff",
        "height": stitched_array.shape[1],
        "width": stitched_array.shape[2],
        "transform": out_transform,
        "count": 1,
        "dtype": 'uint8',
        "compress": 'LZW'
    })
    
    output_path = os.path.join(output_dir, f'{district_name}_Flood_Map.tif')
    with rasterio.open(output_path, "w", **out_meta) as dest:
        dest.write(stitched_array, 1)
    
    for src in sources:
        src.close()
    
    print(f" ¬† ‚úÖ Final flood map saved: {output_path}")
    
    total_pixels = stitched_array.size
    flood_pixels = np.sum(stitched_array == 1)
    flood_percentage = (flood_pixels / total_pixels) * 100
    print(f" ¬† üìä Flooded area: {flood_percentage:.2f}% of total pixels")


# =================================================================
# MAIN EXECUTION LOGIC
# =================================================================

if __name__ == '__main__':
    print("="*70)
    print("PHASE 1A: SAR CHANGE DETECTION FOR FLOOD MAPPING")
    print("="*70)
    print(f"\nConfiguration:")
    print(f" ¬†Change method: {CHANGE_METHOD}")
    print(f" ¬†Threshold method: {THRESHOLD_METHOD}")
    print(f" ¬†Fixed threshold: {FIXED_THRESHOLD} dB")
    print(f" ¬†Post-processing: {'Enabled' if APPLY_MORPHOLOGY else 'Disabled'}")
    print(f" ¬†Despeckling: {'Enabled' if APPLY_DESPECKLE else 'Disabled'}")
    print("\n" + "="*70 + "\n")

    # Create output directories
    os.makedirs(FLOOD_MASKS_DIR, exist_ok=True)
    os.makedirs(FINAL_MAPS_DIR, exist_ok=True)

    # Process each district
    for district in DISTRICTS:
        print(f"üåä Processing {district}...")
        
        pre_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'pre_flood')
        post_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'post_flood')
        mask_dir = os.path.join(FLOOD_MASKS_DIR, district)
        
        # Check if directories exist
        if not os.path.exists(pre_dir) or not os.path.exists(post_dir):
            print(f" ¬† ‚ö†Ô∏è Skipping {district}: Chip directories not found")
            continue
        
        # Create mask directory
        os.makedirs(mask_dir, exist_ok=True)
        
        # Get list of chip files (We assume the pre-flood list contains all necessary names)
        pre_chips = sorted([f for f in os.listdir(pre_dir) if f.endswith('.tif')])
        
        if not pre_chips:
            print(f" ¬† ‚ö†Ô∏è No .tif files found in {pre_dir}")
            continue
        
        print(f" ¬† Processing {len(pre_chips)} chip pairs...")
        
        # Process each chip pair
        success_count = 0
        skip_count = 0
        
        for chip_name in pre_chips:
            pre_path = os.path.join(pre_dir, chip_name)
            
            # --- Robust File Matching (Handles the 'PreFlood' -> 'PostFlood' name change) ---
            post_chip_name = chip_name.replace('PreFlood_Image', 'PostFlood_Image')
            post_path = os.path.join(post_dir, post_chip_name)
            
            # Create output filename for mask
            mask_name = chip_name.replace('PreFlood_Image', 'Flood_Mask')
            mask_path = os.path.join(mask_dir, mask_name)
            
            # Check if corresponding post-flood chip exists before processing
            if not os.path.exists(post_path):
                skip_count += 1
                continue
            
            # Process chip pair
            if process_chip_pair(pre_path, post_path, mask_path):
                success_count += 1
            else:
                skip_count += 1
        
        print(f" ¬† ‚úÖ Successfully processed: {success_count}/{len(pre_chips)} chips")
        print(f" ¬† ‚ö†Ô∏è Skipped (data/file error): {skip_count}/{len(pre_chips)} chips")
        
        # Stitch masks into final district-level flood map
        if success_count > 0:
            stitch_flood_masks(mask_dir, district, FINAL_MAPS_DIR)
        else:
            print(f" ¬† ‚ö†Ô∏è No successful masks to stitch for {district}")
        print()
    
    print("="*70)
    print("‚úÖ PHASE 1A COMPLETE! Outputs are ready for QGIS validation.")
    print("="*70)

PHASE 1A: SAR CHANGE DETECTION FOR FLOOD MAPPING

Configuration:
 ¬†Change method: log_ratio
 ¬†Threshold method: fixed
 ¬†Fixed threshold: -1.5 dB
 ¬†Post-processing: Enabled
 ¬†Despeckling: Enabled


üåä Processing Barpeta...
 ¬† Processing 96 chip pairs...
‚ùå Error processing Barpeta_PreFlood_Image_chip_0.tif: Pre-flood: Too few valid pixels: 0/262144
‚ùå Error processing Barpeta_PreFlood_Image_chip_1.tif: Pre-flood: Too few valid pixels: 0/262144
‚ùå Error processing Barpeta_PreFlood_Image_chip_10.tif: Pre-flood: Too few valid pixels: 0/262144
‚ùå Error processing Barpeta_PreFlood_Image_chip_11.tif: Pre-flood: Too few valid pixels: 0/154624
‚ùå Error processing Barpeta_PreFlood_Image_chip_2.tif: Pre-flood: Too few valid pixels: 0/262144
‚ùå Error processing Barpeta_PreFlood_Image_chip_85.tif: Pre-flood: Too few valid pixels: 0/201216
‚ùå Error processing Barpeta_PreFlood_Image_chip_86.tif: Pre-flood: Too few valid pixels: 0/201216
‚ùå Error processing Barpeta_PreFlood_Image_chip_

ValueError: Source shape (1, 1, 3977, 5934) is inconsistent with given indexes 1

# Despekling 

In [None]:
import os
import numpy as np
import rasterio
from rasterio.merge import merge
from scipy.ndimage import binary_opening, binary_closing
from skimage.filters import threshold_otsu
from skimage.restoration import denoise_tv_chambolle, denoise_wavelet, denoise_bilateral
from skimage.filters.rank import mean as mean_filter 
import warnings
warnings.filterwarnings('ignore')

# =================================================================
# CONFIGURATION (Updated with Despeckle Control)
# =================================================================

OUTPUT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips'
FLOOD_MASKS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Flood_Masks'
FINAL_MAPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Final_Flood_Maps'

DISTRICTS = ['Barpeta', 'Dhemaji', 'Lakhimpur', 'Nalbari', 'Sonitpur']

# Change detection parameters
CHANGE_METHOD = 'log_ratio'      # Standard for SAR change detection
THRESHOLD_METHOD = 'fixed'       # <-- CHANGED to fixed for stable results
FIXED_THRESHOLD = -1.5           # <-- Set to a robust value (in dB)
PERCENTILE_VALUE = 10            

# Post-processing
MIN_OBJECT_SIZE = 10             
APPLY_MORPHOLOGY = True          

# Despeckling Control (Missing variables added)
APPLY_DESPECKLE = True           # <-- ACTIVATED SPECKLE FILTERING
SPECKLE_WINDOW_SIZE = 5          # 5x5 pixel window for the bilateral filter
MIN_VALID_PIXELS = 100           

# =================================================================
# CORE FUNCTIONS (No changes needed here, as the logic is correct)
# =================================================================

def apply_lee_filter(image, window_size=5):
    """
    Applies a bilateral filter (robust general denoising) to reduce speckle noise.
    """
    # Use denoise_bilateral as a robust general denoising filter for float data.
    image_denoised = denoise_bilateral(
        image, 
        sigma_color=0.1, 
        sigma_spatial=window_size / 2.0, 
        channel_axis=None 
    )
    return image_denoised


def safe_log_ratio(pre_chip, post_chip, epsilon=1e-10):
    pre_chip = np.where(pre_chip <= 0, epsilon, pre_chip)
    post_chip = np.where(post_chip <= 0, epsilon, post_chip)
    log_ratio = 10 * np.log10(post_chip / pre_chip)
    return log_ratio


def calculate_change(pre_chip, post_chip, method='log_ratio'):
    if method == 'log_ratio':
        change = safe_log_ratio(pre_chip, post_chip)
    else:
        raise ValueError(f"Unknown method: {method}")
    return change


def apply_threshold(change_map, method='otsu', fixed_value=-3.0, percentile=10):
    # Remove NaN and infinite values before thresholding
    valid_mask = np.isfinite(change_map)
    valid_data = change_map[valid_mask]
    
    if len(valid_data) < MIN_VALID_PIXELS:
        raise ValueError(f"Insufficient valid data: only {len(valid_data)} valid pixels")
    
    if method == 'otsu':
        threshold = threshold_otsu(valid_data)
        flood_mask = change_map < threshold
        
    elif method == 'fixed':
        flood_mask = change_map < fixed_value
        
    elif method == 'percentile':
        threshold = np.percentile(valid_data, percentile)
        flood_mask = change_map < threshold
        
    else:
        raise ValueError(f"Unknown threshold method: {method}")
    
    flood_mask = np.where(valid_mask, flood_mask, 0)
    
    return flood_mask.astype(np.uint8)


def post_process_mask(mask, min_size=10, apply_morph=True):
    if apply_morph:
        mask = binary_opening(mask, structure=np.ones((3, 3)))
        mask = binary_closing(mask, structure=np.ones((3, 3)))
    
    return mask.astype(np.uint8)


def validate_chip_data(chip_array):
    # (Existing validation code remains the same)
    nan_count = np.isnan(chip_array).sum()
    total_pixels = chip_array.size
    valid_pixels = total_pixels - nan_count
    inf_count = np.isinf(chip_array).sum()
    
    if valid_pixels < MIN_VALID_PIXELS: 
        return False, f"Too few valid pixels: {valid_pixels}/{total_pixels}"
    if inf_count > 0:
        return False, f"Contains infinite values: {inf_count}"
    
    unique_vals = np.unique(chip_array[~np.isnan(chip_array)])
    if len(unique_vals) < 2:
        return False, "No variation in data (likely no-data chip)"
        
    return True, "Valid"


def process_chip_pair(pre_path, post_path, output_path):
    """Process a single chip pair and generate flood mask."""
    try:
        with rasterio.open(pre_path) as src_pre:
            pre_chip = src_pre.read(1).astype(np.float32)
            profile = src_pre.profile.copy()
        
        with rasterio.open(post_path) as src_post:
            post_chip = src_post.read(1).astype(np.float32)
        
        # --- DESPECKLING STEP ---
        if APPLY_DESPECKLE:
            pre_chip = apply_lee_filter(pre_chip, window_size=SPECKLE_WINDOW_SIZE)
            post_chip = apply_lee_filter(post_chip, window_size=SPECKLE_WINDOW_SIZE)
        # ------------------------
        
        # Validate chips AFTER despeckling
        is_valid_pre, msg_pre = validate_chip_data(pre_chip)
        if not is_valid_pre:
             raise ValueError(f"Pre-flood: {msg_pre}")
        
        # Replace remaining NaNs with mean of valid pixels (robustness)
        if np.isnan(pre_chip).any():
            pre_mean = np.nanmean(pre_chip)
            pre_chip = np.nan_to_num(pre_chip, nan=pre_mean)
        
        if np.isnan(post_chip).any():
            post_mean = np.nanmean(post_chip)
            post_chip = np.nan_to_num(post_chip, nan=post_mean)
            
        # Step 1: Calculate change (log ratio)
        change_map = calculate_change(pre_chip, post_chip, method=CHANGE_METHOD)
        
        # Step 2: Apply threshold
        flood_mask = apply_threshold(
            change_map, 
            method=THRESHOLD_METHOD,
            fixed_value=FIXED_THRESHOLD,
            percentile=PERCENTILE_VALUE
        )
        
        # Step 3: Post-process (morphology)
        flood_mask = post_process_mask(
            flood_mask,
            min_size=MIN_OBJECT_SIZE,
            apply_morph=APPLY_MORPHOLOGY
        )
        
        # Step 4: Save the mask
        profile.update(dtype=rasterio.uint8, count=1, compress='LZW')
        
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(flood_mask, 1)
        
        return True
        
    except Exception as e:
        # print(f"‚ùå Error processing {os.path.basename(pre_path)}: {e}") # Uncomment for debugging
        return False

# (The rest of the stitching and main execution logic remains the same)
# ...

In [3]:
from terratorch.registry import BACKBONE_REGISTRY
from torch import nn
import torch

# Assume processed_tensors is populated here from your previous steps
# processed_tensors = {'Barpeta': {'pre_flood': [...], 'post_flood': [...]}, ...}

# --- Define the specialized model structure ---
class PrithviFloodSegmentationModel(nn.Module):
    # This class definition is required to load the Prithvi weights and adapt them for 2-channel SAR data.
    def __init__(self, output_classes=2):
        super().__init__()
        
        # 1. Load the Prithvi-EO-2.0-600M Backbone
        self.backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_600", pretrained=True)
        
        # 2. Adapt Input Layer for 2 Channels (Temporal SAR)
        original_conv = self.backbone.patch_embed.proj[0].conv
        new_conv = nn.Conv3d(
            in_channels=2, # Set input channels to 2 (Pre-VV, Post-VV)
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias
        )
        # Average original weights and replace the input conv layer
        new_conv.weight.data = original_conv.weight.data[:, :2, :, :].mean(dim=1, keepdim=True).repeat(1, 2, 1, 1, 1)
        self.backbone.patch_embed.proj[0].conv = new_conv

        # 3. Simple Placeholder Segmentation Head (Replace with actual U-Net decoder if known)
        self.segmentation_head = nn.Sequential(
            nn.Conv3d(in_channels=768, out_channels=2, kernel_size=1) 
            # Note: The output requires complex upsampling/decoding, this is highly simplified.
        )
        
    def forward(self, x):
        # Transpose input to (Batch, Channel, Time, H, W) if needed, then pass through model
        # The specific forward pass for temporal features is complex and depends on the exact head.
        # This is a conceptual representation.
        features = self.backbone(x)
        # Assuming features are prepared for segmentation head here...
        return self.segmentation_head(features)
        

# --- FINAL EXECUTION LOOP ---
def run_full_flood_pipeline():
    # 1. Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 2. Instantiate and Load Model
    try:
        model = PrithviFloodSegmentationModel(output_classes=2).to(device)
        model.eval() # Set the model to evaluation mode
        print("‚úÖ Prithvi Model loaded successfully.")
    except Exception as e:
        print(f"‚ùå ERROR loading model: {e}. Cannot proceed with inference.")
        return
    
    # Define a directory to save the intermediate prediction masks
    PREDICTION_MASKS_DIR = os.path.join(ROOT_CHIPS_DIR, 'predicted_masks')
    os.makedirs(PREDICTION_MASKS_DIR, exist_ok=True)

    # 3. Run Inference on All Chips
    for district, phases in processed_tensors.items():
        if phases['pre_flood'] and phases['post_flood']:
            print(f"\nRunning inference for {district}...")
            
            # Ensure the number of pre and post chips match for temporal analysis
            num_chips = min(len(phases['pre_flood']), len(phases['post_flood']))
            
            district_mask_dir = os.path.join(PREDICTION_MASKS_DIR, district)
            os.makedirs(district_mask_dir, exist_ok=True)
            
            for i in range(num_chips):
                pre_tensor = phases['pre_flood'][i]
                post_tensor = phases['post_flood'][i]
                
                # Stack the two tensors along the channel dimension (temporal input)
                # Input shape: (1, 2, 512, 512) -- Note: The Prithvi model expects a TIME dimension often!
                temporal_input = torch.cat([pre_tensor, post_tensor], dim=1).to(device)

                with torch.no_grad():
                    # NOTE: This model call is conceptual due to the simplified model head.
                    # The actual segmentation head will define how features are mapped back to image size.
                    output_logits = model(temporal_input) 
                    
                    # --- [Prediction Logic] ---
                    # 1. Get the class index with highest probability (0 or 1)
                    # 2. Reshape and convert to NumPy array (H, W)
                    # 3. Save mask chip to district_mask_dir
                    
            # 4. Stitch the results (Assuming you have a list of mask paths)
            # stitch_masks(district_mask_dir, district, FINAL_OUTPUT_DIR) # Call your stitching function

# run_full_flood_pipeline() # Uncomment to execute

INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.8 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations
  _C._set_float32_matmul_precision(precision)


In [2]:
from terratorch.registry import BACKBONE_REGISTRY
from torch import nn
import torch

class PrithviFloodSegmentationModel(nn.Module):
    """
    Instantiates the Prithvi backbone and customizes it for 2-channel temporal SAR input.
    """
    def __init__(self, output_classes=2):
        super().__init__()
        
        # 1. Load the Prithvi-EO-2.0-600M Backbone using TerraTorch
        self.backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_600", pretrained=True)
        
        # 2. Modify Input Layer (Crucial for SAR)
        # Prithvi expects 6 channels. We modify the first convolution (3D patch embedding)
        # to accept 2 input channels instead of 6.
        original_conv = self.backbone.patch_embed.proj[0].conv
        new_conv = nn.Conv3d(
            in_channels=2, # <--- Set input channels to 2 (Pre-VV, Post-VV)
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias
        )
        # Copy original weights (simplified method: average and repeat)
        new_conv.weight.data = original_conv.weight.data[:, :2, :, :].mean(dim=1, keepdim=True).repeat(1, 2, 1, 1, 1)
        self.backbone.patch_embed.proj[0].conv = new_conv

        # 3. Segmentation Head (Placeholder)
        # NOTE: You MUST replace this simple linear layer with a proper decoder 
        # (e.g., U-Net or UPerNet, commonly used with Prithvi) to get a pixel-level output.
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(in_channels=768, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=output_classes, kernel_size=1)
        )

    def forward(self, x):
        # The model requires input to be (Batch, Channel, Time, Height, Width)
        # Your input is (B, T, H, W). We need to swap T and C to get (B, C, T, H, W)
        x = x.permute(0, 2, 1, 3, 4) # This swap is necessary if input shape is not (B, C, T, H, W)

        # Pass through backbone and segmentation head
        features = self.backbone(x) 
        # Assuming features are correctly processed for segmentation head here
        # (This is a complex detail that requires the full Prithvi segmentation architecture)
        
        # ... (Rest of the model forward pass) ...
        return features
        
# Example of loading the model:
# model = PrithviFloodSegmentationModel(output_classes=2) 
# print("Model loaded successfully.")

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.8 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations
INFO:matplotlib.font_manager:generated new fontManager
  _C._set_float32_matmul_precision(precision)


In [9]:
import numpy as np
import rasterio
import torch
import os
import ee.batch
from rasterio.merge import merge
from rasterio.windows import Window
from terratorch.registry import BACKBONE_REGISTRY
from torch import nn
from torch.nn.modules.conv import Conv3d # Required for type hinting/access

# =======================================================================
# CONFIGURATION - ADJUST THESE PATHS AND CONSTANTS
# =======================================================================

# IMPORTANT: SET YOUR ROOT DIRECTORY HERE
# This path must match the location of your chipped data
ROOT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips' 

# Sentinel-1 Normalization Parameters (Used in your preprocessing function)
SAR_NORM_MEAN = -15.0  # Common mean for VV dB values
SAR_NORM_STD = 5.0    # Common standard deviation for VV dB values

# Directory to save the final stitched output
FINAL_STITCHED_DIR = os.path.join(ROOT_CHIPS_DIR, 'Final_Stitched_Masks')
os.makedirs(FINAL_STITCHED_DIR, exist_ok=True)
# Directory to save the intermediate predicted chips
PREDICTION_MASKS_DIR = os.path.join(ROOT_CHIPS_DIR, 'intermediate_masks')
os.makedirs(PREDICTION_MASKS_DIR, exist_ok=True)

# Placeholder for tensors (Assume this dictionary is populated by your preprocessing script)
processed_tensors = {}
# NOTE: In a real run, you would execute your preprocessing script here to populate this dictionary.
# For this code to run successfully, ensure 'processed_tensors' is populated with your data.

# =======================================================================
# II. HELPER FUNCTIONS (Stitching and Preprocessing)
# =======================================================================

def stitch_masks(mask_dir, district_name, final_output_dir):
    """Stitches all predicted flood mask chips into a single GeoTIFF."""
    
    mask_files = [os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if f.endswith('.tif')]
    
    if not mask_files:
        print(f"Skipping stitching: No mask chips found in {mask_dir}")
        return
        
    # Open all mask datasets
    sources = [rasterio.open(f) for f in mask_files]
    
    # Use rasterio.merge to create a mosaic
    stitched_array, out_transform = merge(sources)
    
    # Get the metadata from the first source file
    out_meta = sources[0].profile.copy()
    
    # Update the metadata for the merged output
    out_meta.update({
        "driver": "GTiff",
        "height": stitched_array.shape[1],
        "width": stitched_array.shape[2],
        "transform": out_transform,
        "count": 1,
        "dtype": 'uint8', # The output mask is a binary integer (0 or 1)
        "nodata": 255 # Set nodata value for clear background
    })
    
    # Write the final stitched GeoTIFF
    final_output_path = os.path.join(final_output_dir, f'{district_name}_Final_Flood_Mask.tif')
    with rasterio.open(final_output_path, "w", **out_meta) as dest:
        dest.write(stitched_array)
        
    # Close all source files
    for src in sources:
        src.close()
        
    print(f"‚úÖ Final stitched mask saved to: {final_output_path}")

# NOTE: Your preprocess_sar_chip function (from earlier) is needed here to load 
# the original metadata profile if you want to save the prediction mask chip with correct 
# georeferencing in the run_inference_and_save loop. 

# For simplicity, we assume we can read a profile from the source chip for saving the mask.
def get_profile(file_path):
    with rasterio.open(file_path) as src:
        return src.profile.copy()

# =======================================================================
# III. PRITHVI MODEL DEFINITION (Fixed and Ready)
# =======================================================================

class PrithviFloodSegmentationModel(nn.Module):
    """
    Model wrapper to load the Prithvi-600M backbone and adapt its input layer 
    for 2-channel temporal SAR input (Pre-VV, Post-VV).
    """
    def __init__(self, output_classes=2):
        super().__init__()
        
        # 1. Load the Prithvi-EO-2.0-600M Backbone
        self.backbone = BACKBONE_REGISTRY.build("prithvi_eo_v2_600", pretrained=True)
        
        # 2. Adapt Input Layer for 2 Channels (Crucial Fix)
        # Access the Conv3d layer within the patch_embed.proj Sequential block
        original_conv = self.backbone.patch_embed.proj[0].conv
        
        # Access the original weights' data
        original_weights = original_conv.weight.data

        # Create the new 2-channel convolution layer
        new_conv = nn.Conv3d(
            in_channels=2, # Set input channels to 2
            out_channels=original_weights.shape[0], 
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias is not None
        )
        
        # Adapt weights: Calculate the mean across the original 6 channels (temporal input dimension is 1 in the weights)
        # Then, tile it to fill the 2 new input channels (VV-pre, VV-post).
        adapted_weights = original_weights[:, :2, :, :, :].mean(dim=1, keepdim=True).repeat(1, 2, 1, 1, 1)
        new_conv.weight.data = adapted_weights
        
        # Replace the original convolution layer in the model structure
        self.backbone.patch_embed.proj[0].conv = new_conv 

        # 3. Simplified Segmentation Head (For a full project, this needs a proper decoder)
        self.segmentation_head = nn.Sequential(
            # Final 1x1x1 Conv to reduce channels to output_classes
            nn.Conv3d(in_channels=768, out_channels=output_classes, kernel_size=1) 
        )

    def forward(self, x):
        # Input must be shaped: (B, C, T, H, W) -> (B, 2, 1, H, W) in our case
        x = x.unsqueeze(2) # Add a Time dimension (T=1) -> (B, 2, 1, H, W)
        
        # Pass through backbone
        features = self.backbone(x) # Output features: (B, C_features, T, H', W')
        
        # Pass through simplified segmentation head
        output_logits = self.segmentation_head(features)
        
        # Reshape output from (B, Classes, T=1, H, W) to (B, Classes, H, W)
        output_logits = output_logits.squeeze(2)
        return output_logits

# =======================================================================
# IV. FINAL EXECUTION PIPELINE
# =======================================================================

def run_full_flood_pipeline(processed_tensors):
    """
    Runs inference on all district tensors and manages the stitching process.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 1. Instantiate and Load Model
    try:
        model = PrithviFloodSegmentationModel(output_classes=2).to(device)
        model.eval() # Set the model to evaluation mode
        print("‚úÖ Prithvi Model loaded successfully.")
    except Exception as e:
        print(f"‚ùå ERROR loading model: {e}. Cannot proceed with inference.")
        return

    # 2. Run Inference and Save Prediction Chips
    for district, phases in processed_tensors.items():
        if phases['pre_flood'] and phases['post_flood']:
            print(f"\n--- Running inference for {district} ---")
            
            num_chips = min(len(phases['pre_flood']), len(phases['post_flood']))
            
            district_mask_dir = os.path.join(PREDICTION_MASKS_DIR, district)
            os.makedirs(district_mask_dir, exist_ok=True)
            
            # --- Assuming you have a list of file paths to match up for saving the profile ---
            # NOTE: For the code to be truly robust, you need the original list of file paths here.
            
            for i in range(num_chips):
                pre_tensor = phases['pre_flood'][i]
                post_tensor = phases['post_flood'][i]
                
                # Input shape: (1, 1, H, W) * 2 -> Concat along channel (dim=1) -> (1, 2, H, W)
                temporal_input = torch.cat([pre_tensor, post_tensor], dim=1).to(device)

                with torch.no_grad():
                    output_logits = model(temporal_input) 
                    
                    # 1. Get the final classification (0 or 1)
                    predicted_mask_tensor = torch.argmax(output_logits, dim=1).squeeze().cpu()
                    predicted_mask_array = predicted_mask_tensor.numpy().astype(rasterio.uint8)
                    
                    # 2. Save the prediction mask chip 
                    # NOTE: This requires fetching the profile from the original input chip
                    # profile = get_profile(original_input_chip_path)
                    
                    # Placeholder for saving (assumes you have the original profile):
                    # chip_output_path = os.path.join(district_mask_dir, f'{district}_mask_chip_{i}.tif')
                    # with rasterio.open(chip_output_path, 'w', **profile) as dst:
                    #     dst.write(predicted_mask_array, 1)

            # 3. Stitch the results
            # stitch_masks(district_mask_dir, district, FINAL_STITCHED_DIR) 
            print(f"‚úÖ Finished inference for {district}. Masks are ready for stitching.")

# --- EXECUTION ---
# NOTE: To run, uncomment the lines below and ensure 'processed_tensors' is populated.
run_preprocessing_pipeline() 
run_full_flood_pipeline(processed_tensors)

  lines = filter(lambda x: re.match("^\d+ bytes", x), data.splitlines())


ModuleNotFoundError: No module named '_curses'

In [4]:
import torch


In [5]:
torch.cuda.is_available()

False

In [10]:
# =================================================================
# PHASE 1A: UNSUPERVISED SAR CHANGE DETECTION
# =================================================================

import os
import numpy as np
import rasterio
from rasterio.merge import merge
from scipy.ndimage import binary_opening, binary_closing
from skimage.filters import threshold_otsu
import warnings
warnings.filterwarnings('ignore')

# =================================================================
# CONFIGURATION
# =================================================================

# Updated paths based on actual directory structure
BASE_DIR = r'C:\Kaam_Dhanda\Minor_Project'
OUTPUT_CHIPS_DIR = os.path.join(BASE_DIR, 'Output_chips')
FLOOD_MASKS_DIR = os.path.join(BASE_DIR, 'Flood_Masks')
FINAL_MAPS_DIR = os.path.join(BASE_DIR, 'Final_Flood_Maps')

DISTRICTS = ['Barpeta', 'Dhemaji', 'Lakhimpur', 'Nalbari', 'Sonitpur']

# Change detection parameters
CHANGE_METHOD = 'log_ratio'  # Options: 'log_ratio', 'difference', 'ratio'
THRESHOLD_METHOD = 'otsu'     # Options: 'otsu', 'fixed', 'percentile'
FIXED_THRESHOLD = -3.0        # Used if THRESHOLD_METHOD = 'fixed' (in dB)
PERCENTILE_VALUE = 10         # Used if THRESHOLD_METHOD = 'percentile'

# Post-processing
MIN_OBJECT_SIZE = 10          # Remove objects smaller than this (pixels)
APPLY_MORPHOLOGY = True       # Clean up with opening/closing

# Data validation
MIN_VALID_PIXELS = 100        # Minimum valid (non-NaN) pixels required to process a chip


# =================================================================
# CORE FUNCTIONS
# =================================================================

def safe_log_ratio(pre_chip, post_chip, epsilon=1e-10):
    """Calculate log ratio between pre and post images."""
    # Avoid log(0) by adding small epsilon
    pre_chip = np.where(pre_chip <= 0, epsilon, pre_chip)
    post_chip = np.where(post_chip <= 0, epsilon, post_chip)
    
    # Convert to dB scale: 10 * log10(post/pre)
    log_ratio = 10 * np.log10(post_chip / pre_chip)
    return log_ratio


def calculate_change(pre_chip, post_chip, method='log_ratio'):
    """Calculate change between pre and post chips."""
    if method == 'log_ratio':
        change = safe_log_ratio(pre_chip, post_chip)
    elif method == 'difference':
        change = post_chip - pre_chip
    elif method == 'ratio':
        epsilon = 1e-10
        pre_chip = np.where(pre_chip <= 0, epsilon, pre_chip)
        change = post_chip / pre_chip
    else:
        raise ValueError(f"Unknown method: {method}")
    
    return change


def apply_threshold(change_map, method='otsu', fixed_value=-3.0, percentile=10):
    """Apply threshold to identify flooded areas."""
    # Remove NaN and infinite values before thresholding
    valid_mask = np.isfinite(change_map)
    valid_data = change_map[valid_mask]
    
    # Check if we have enough valid data
    if len(valid_data) < 100:
        raise ValueError(f"Insufficient valid data: only {len(valid_data)} valid pixels")
    
    if method == 'otsu':
        # Otsu's method finds optimal threshold automatically
        threshold = threshold_otsu(valid_data)
        # For flood detection, we want values BELOW threshold (darker = water)
        flood_mask = change_map < threshold
        
    elif method == 'fixed':
        # Fixed threshold (e.g., -3 dB for log ratio)
        flood_mask = change_map < fixed_value
        
    elif method == 'percentile':
        # Use bottom percentile as threshold
        threshold = np.percentile(valid_data, percentile)
        flood_mask = change_map < threshold
        
    else:
        raise ValueError(f"Unknown threshold method: {method}")
    
    # Set invalid pixels to 0 (non-flooded)
    flood_mask = np.where(valid_mask, flood_mask, 0)
    
    return flood_mask.astype(np.uint8)


def post_process_mask(mask, min_size=10, apply_morph=True):
    """Clean up the flood mask."""
    if apply_morph:
        # Morphological opening: removes small objects
        mask = binary_opening(mask, structure=np.ones((3, 3)))
        # Morphological closing: fills small holes
        mask = binary_closing(mask, structure=np.ones((3, 3)))
    
    # Remove small objects (optional: implement connected component analysis)
    # For now, we rely on morphological operations
    
    return mask.astype(np.uint8)


def validate_chip_data(chip_array, chip_name):
    """Validate chip data and return True if processable."""
    # Check for NaN values
    nan_count = np.isnan(chip_array).sum()
    total_pixels = chip_array.size
    valid_pixels = total_pixels - nan_count
    
    # Check for infinite values
    inf_count = np.isinf(chip_array).sum()
    
    # Check if we have enough valid data
    if valid_pixels < MIN_VALID_PIXELS:
        return False, f"Too few valid pixels: {valid_pixels}/{total_pixels}"
    
    if nan_count > total_pixels * 0.8:  # More than 80% NaN
        return False, f"Too many NaN values: {nan_count}/{total_pixels}"
    
    if inf_count > 0:
        return False, f"Contains infinite values: {inf_count}"
    
    # Check if all values are the same (no-data chip)
    unique_vals = np.unique(chip_array[~np.isnan(chip_array)])
    if len(unique_vals) < 2:
        return False, "No variation in data (likely no-data chip)"
    
    return True, "Valid"


def process_chip_pair(pre_path, post_path, output_path):
    """Process a single chip pair and generate flood mask."""
    try:
        with rasterio.open(pre_path) as src_pre:
            pre_chip = src_pre.read(1).astype(np.float32)
            profile = src_pre.profile.copy()
        
        with rasterio.open(post_path) as src_post:
            post_chip = src_post.read(1).astype(np.float32)
        
        # Validate pre-flood chip
        is_valid_pre, msg_pre = validate_chip_data(pre_chip, os.path.basename(pre_path))
        if not is_valid_pre:
            raise ValueError(f"Pre-flood: {msg_pre}")
        
        # Validate post-flood chip
        is_valid_post, msg_post = validate_chip_data(post_chip, os.path.basename(post_path))
        if not is_valid_post:
            raise ValueError(f"Post-flood: {msg_post}")
        
        # Replace NaN values with a reasonable value (e.g., mean of valid pixels)
        if np.isnan(pre_chip).any():
            pre_mean = np.nanmean(pre_chip)
            pre_chip = np.nan_to_num(pre_chip, nan=pre_mean)
        
        if np.isnan(post_chip).any():
            post_mean = np.nanmean(post_chip)
            post_chip = np.nan_to_num(post_chip, nan=post_mean)
        
        # Step 1: Calculate change
        change_map = calculate_change(pre_chip, post_chip, method=CHANGE_METHOD)
        
        # Step 2: Apply threshold
        flood_mask = apply_threshold(
            change_map, 
            method=THRESHOLD_METHOD,
            fixed_value=FIXED_THRESHOLD,
            percentile=PERCENTILE_VALUE
        )
        
        # Step 3: Post-process
        flood_mask = post_process_mask(
            flood_mask,
            min_size=MIN_OBJECT_SIZE,
            apply_morph=APPLY_MORPHOLOGY
        )
        
        # Step 4: Save the mask
        profile.update(dtype=rasterio.uint8, count=1, compress='LZW')
        
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(flood_mask, 1)
        
        return True
        
    except Exception as e:
        # Only print first few errors to avoid clutter
        return False


def stitch_flood_masks(mask_dir, district_name, output_dir):
    """Stitch all chip-level flood masks into a single district-level GeoTIFF."""
    
    mask_files = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) 
                        if f.endswith('.tif')])
    
    if not mask_files:
        print(f"‚ö†Ô∏è No mask files found in {mask_dir}")
        return
    
    print(f"   Stitching {len(mask_files)} mask chips...")
    
    # Open all mask datasets
    sources = [rasterio.open(f) for f in mask_files]
    
    # Merge into a single mosaic
    stitched_array, out_transform = merge(sources)
    
    # Get metadata from first source
    out_meta = sources[0].profile.copy()
    out_meta.update({
        "driver": "GTiff",
        "height": stitched_array.shape[1],
        "width": stitched_array.shape[2],
        "transform": out_transform,
        "count": 1,
        "dtype": 'uint8',
        "compress": 'LZW'
    })
    
    # Write final stitched GeoTIFF
    output_path = os.path.join(output_dir, f'{district_name}_Flood_Map.tif')
    with rasterio.open(output_path, "w", **out_meta) as dest:
        dest.write(stitched_array)
    
    # Close all sources
    for src in sources:
        src.close()
    
    print(f"   ‚úÖ Final flood map saved: {output_path}")
    
    # Calculate and print flood statistics
    total_pixels = stitched_array.size
    flood_pixels = np.sum(stitched_array == 1)
    flood_percentage = (flood_pixels / total_pixels) * 100
    print(f"   Flooded area: {flood_percentage:.2f}% of total pixels")


# =================================================================
# MAIN EXECUTION
# =================================================================

print("="*70)
print("PHASE 1A: SAR CHANGE DETECTION FOR FLOOD MAPPING (WITH DATA VALIDATION)")
print("="*70)
print(f"\nConfiguration:")
print(f"  Base directory: {BASE_DIR}")
print(f"  Chips directory: {OUTPUT_CHIPS_DIR}")
print(f"  Change method: {CHANGE_METHOD}")
print(f"  Threshold method: {THRESHOLD_METHOD}")
if THRESHOLD_METHOD == 'fixed':
    print(f"  Fixed threshold: {FIXED_THRESHOLD} dB")
elif THRESHOLD_METHOD == 'percentile':
    print(f"  Percentile: {PERCENTILE_VALUE}%")
print(f"  Post-processing: {'Enabled' if APPLY_MORPHOLOGY else 'Disabled'}")
print(f"  Min valid pixels: {MIN_VALID_PIXELS}")
print("\n" + "="*70 + "\n")

# Create output directories
os.makedirs(FLOOD_MASKS_DIR, exist_ok=True)
os.makedirs(FINAL_MAPS_DIR, exist_ok=True)

# Process each district
for district in DISTRICTS:
    print(f"üåä Processing {district}...")
    
    pre_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'pre_flood')
    post_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'post_flood')
    mask_dir = os.path.join(FLOOD_MASKS_DIR, district)
    
    # Check if directories exist
    if not os.path.exists(pre_dir) or not os.path.exists(post_dir):
        print(f"   ‚ö†Ô∏è Skipping {district}: Chip directories not found")
        print(f"      Expected: {pre_dir}")
        print(f"                {post_dir}")
        continue
    
    # Create mask directory
    os.makedirs(mask_dir, exist_ok=True)
    
    # Get list of chip files
    pre_chips = sorted([f for f in os.listdir(pre_dir) if f.endswith('.tif')])
    
    if not pre_chips:
        print(f"   ‚ö†Ô∏è No .tif files found in {pre_dir}")
        continue
    
    print(f"   Processing {len(pre_chips)} chip pairs...")
    
    # Process each chip pair
    success_count = 0
    skip_count = 0
    
    for chip_name in pre_chips:
        pre_path = os.path.join(pre_dir, chip_name)
        
        # Convert pre-flood chip name to post-flood chip name
        # Example: "Barpeta_PreFlood_Image_chip_0.tif" -> "Barpeta_PostFlood_Image_chip_0.tif"
        post_chip_name = chip_name.replace('PreFlood_Image', 'PostFlood_Image')
        post_path = os.path.join(post_dir, post_chip_name)
        
        # Check if corresponding post-flood chip exists
        if not os.path.exists(post_path):
            skip_count += 1
            continue
        
        # Create output filename for mask
        # Example: "Barpeta_PreFlood_Image_chip_0.tif" -> "Barpeta_Flood_Mask_chip_0.tif"
        mask_name = chip_name.replace('PreFlood_Image', 'Flood_Mask')
        mask_path = os.path.join(mask_dir, mask_name)
        
        # Process chip pair
        if process_chip_pair(pre_path, post_path, mask_path):
            success_count += 1
        else:
            skip_count += 1
    
    print(f"   ‚úÖ Successfully processed: {success_count}/{len(pre_chips)} chips")
    print(f"   ‚ö†Ô∏è Skipped (invalid data): {skip_count}/{len(pre_chips)} chips")
    
    # Stitch masks into final district-level flood map
    if success_count > 0:
        stitch_flood_masks(mask_dir, district, FINAL_MAPS_DIR)
    else:
        print(f"   ‚ö†Ô∏è No masks to stitch for {district}")
    print()

print("="*70)
print("‚úÖ PHASE 1A COMPLETE!")
print("="*70)
print(f"\nOutputs saved to:")
print(f"  Chip-level masks: {FLOOD_MASKS_DIR}")
print(f"  Final flood maps: {FINAL_MAPS_DIR}")
print("\nNote: Chips with invalid data (NaN/no-data areas) were automatically skipped.")
print("\nNext steps:")
print("  1. Open the GeoTIFF files in QGIS/ArcGIS for visualization")
print("  2. Visually assess flood detection quality")
print("  3. Adjust thresholds if needed and re-run")
print("  4. Proceed to Phase 2: Manual validation/correction")


PHASE 1A: SAR CHANGE DETECTION FOR FLOOD MAPPING (WITH DATA VALIDATION)

Configuration:
  Base directory: C:\Kaam_Dhanda\Minor_Project
  Chips directory: C:\Kaam_Dhanda\Minor_Project\Output_chips
  Change method: log_ratio
  Threshold method: otsu
  Post-processing: Enabled
  Min valid pixels: 100


üåä Processing Barpeta...
   Processing 96 chip pairs...
   ‚úÖ Successfully processed: 96/96 chips
   ‚ö†Ô∏è Skipped (invalid data): 0/96 chips
   Stitching 96 mask chips...
   ‚úÖ Successfully processed: 96/96 chips
   ‚ö†Ô∏è Skipped (invalid data): 0/96 chips
   Stitching 96 mask chips...
   ‚úÖ Final flood map saved: C:\Kaam_Dhanda\Minor_Project\Final_Flood_Maps\Barpeta_Flood_Map.tif
   Flooded area: 98.95% of total pixels

üåä Processing Dhemaji...
   Processing 104 chip pairs...
   ‚úÖ Final flood map saved: C:\Kaam_Dhanda\Minor_Project\Final_Flood_Maps\Barpeta_Flood_Map.tif
   Flooded area: 98.95% of total pixels

üåä Processing Dhemaji...
   Processing 104 chip pairs...
   ‚úÖ Su

In [9]:
# =================================================================
# PHASE 1A: IMPROVED SAR CHANGE DETECTION WITH DESPECKLING
# =================================================================

import os
import numpy as np
import rasterio
from rasterio.merge import merge
from scipy.ndimage import binary_opening, binary_closing, uniform_filter, label
from skimage.filters import threshold_otsu
import warnings
warnings.filterwarnings('ignore')

# =================================================================
# CONFIGURATION
# =================================================================

OUTPUT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips'
FLOOD_MASKS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Flood_Masks'
FINAL_MAPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Final_Flood_Maps'

DISTRICTS = ['Barpeta', 'Dhemaji', 'Lakhimpur', 'Nalbari', 'Sonitpur']

# Change detection parameters
CHANGE_METHOD = 'log_ratio'       # Options: 'log_ratio', 'difference', 'ratio'
THRESHOLD_METHOD = 'fixed'        # CHANGED: Use fixed threshold for better control
FIXED_THRESHOLD = -5.0            # CHANGED: More conservative (-5 dB is significant change)
PERCENTILE_VALUE = 5              # Used if THRESHOLD_METHOD = 'percentile'

# Speckle filtering (CRITICAL FOR SAR!)
APPLY_SPECKLE_FILTER = True       # NEW: Apply Lee filter
SPECKLE_FILTER_SIZE = 5           # NEW: 5x5 window for Lee filter

# Post-processing (MORE AGGRESSIVE)
MIN_OBJECT_SIZE = 100             # CHANGED: Remove objects < 100 pixels (was 10)
APPLY_MORPHOLOGY = True           # Clean up with opening/closing
MORPH_KERNEL_SIZE = 5             # CHANGED: Larger kernel (was 3)

# Data validation
MIN_VALID_PIXELS = 1000           # CHANGED: Need more valid pixels (was 100)


# =================================================================
# CORE FUNCTIONS
# =================================================================

def lee_filter(img, size=5):
    """
    Apply Lee speckle filter to SAR image.
    
    The Lee filter:
    - Smooths uniform areas (reduces speckle)
    - Preserves edges and features
    - Essential for SAR change detection
    """
    img_mean = uniform_filter(img, size=size)
    img_sqr_mean = uniform_filter(img**2, size=size)
    img_variance = img_sqr_mean - img_mean**2
    
    overall_variance = np.var(img)
    
    # Avoid division by zero
    img_variance = np.maximum(img_variance, 0)
    
    # Lee filter weighting
    img_weights = img_variance / (img_variance + overall_variance + 1e-10)
    img_filtered = img_mean + img_weights * (img - img_mean)
    
    return img_filtered


def safe_log_ratio(pre_chip, post_chip, epsilon=1e-10):
    """Calculate log ratio between pre and post images."""
    pre_chip = np.where(pre_chip <= 0, epsilon, pre_chip)
    post_chip = np.where(post_chip <= 0, epsilon, post_chip)
    
    log_ratio = 10 * np.log10(post_chip / pre_chip)
    return log_ratio


def calculate_change(pre_chip, post_chip, method='log_ratio'):
    """Calculate change between pre and post chips."""
    if method == 'log_ratio':
        change = safe_log_ratio(pre_chip, post_chip)
    elif method == 'difference':
        change = post_chip - pre_chip
    elif method == 'ratio':
        epsilon = 1e-10
        pre_chip = np.where(pre_chip <= 0, epsilon, pre_chip)
        change = post_chip / pre_chip
    else:
        raise ValueError(f"Unknown method: {method}")
    
    return change


def apply_threshold(change_map, method='fixed', fixed_value=-5.0, percentile=5):
    """Apply threshold to identify flooded areas."""
    valid_mask = np.isfinite(change_map)
    valid_data = change_map[valid_mask]
    
    if len(valid_data) < 100:
        raise ValueError(f"Insufficient valid data: only {len(valid_data)} valid pixels")
    
    if method == 'otsu':
        threshold = threshold_otsu(valid_data)
        flood_mask = change_map < threshold
        
    elif method == 'fixed':
        # Fixed threshold - RECOMMENDED for SAR flood detection
        flood_mask = change_map < fixed_value
        
    elif method == 'percentile':
        # Use bottom percentile as threshold
        threshold = np.percentile(valid_data, percentile)
        flood_mask = change_map < threshold
        
    else:
        raise ValueError(f"Unknown threshold method: {method}")
    
    # Set invalid pixels to 0 (non-flooded)
    flood_mask = np.where(valid_mask, flood_mask, 0)
    
    return flood_mask.astype(np.uint8)


def remove_small_objects(mask, min_size=100):
    """
    Remove connected components smaller than min_size pixels.
    
    This is CRITICAL to remove speckle noise artifacts.
    """
    # Label connected components
    labeled_array, num_features = label(mask)
    
    # Remove small objects
    for region_label in range(1, num_features + 1):
        region_size = np.sum(labeled_array == region_label)
        if region_size < min_size:
            mask[labeled_array == region_label] = 0
    
    return mask


def post_process_mask(mask, min_size=100, apply_morph=True, kernel_size=5):
    """Clean up the flood mask with advanced filtering."""
    
    if apply_morph:
        # Morphological opening: removes small objects and noise
        kernel = np.ones((kernel_size, kernel_size))
        mask = binary_opening(mask, structure=kernel)
        
        # Morphological closing: fills small holes
        mask = binary_closing(mask, structure=kernel)
    
    # Remove small connected components
    mask = remove_small_objects(mask, min_size=min_size)
    
    return mask.astype(np.uint8)


def validate_chip_data(chip_array):
    """Validate chip data with stricter criteria."""
    nan_count = np.isnan(chip_array).sum()
    total_pixels = chip_array.size
    valid_pixels = total_pixels - nan_count
    
    inf_count = np.isinf(chip_array).sum()
    
    # Stricter validation
    if valid_pixels < MIN_VALID_PIXELS:
        return False, f"Too few valid pixels: {valid_pixels}/{total_pixels}"
    
    if nan_count > total_pixels * 0.5:  # CHANGED: More lenient (was 0.8)
        return False, f"Too many NaN values: {nan_count}/{total_pixels}"
    
    if inf_count > 0:
        return False, f"Contains infinite values: {inf_count}"
    
    # Check if all values are the same
    unique_vals = np.unique(chip_array[~np.isnan(chip_array)])
    if len(unique_vals) < 10:  # CHANGED: Need more variation (was 2)
        return False, "Insufficient variation in data"
    
    return True, "Valid"


def process_chip_pair(pre_path, post_path, output_path):
    """Process a single chip pair with SAR-specific preprocessing."""
    try:
        with rasterio.open(pre_path) as src_pre:
            pre_chip = src_pre.read(1).astype(np.float32)
            profile = src_pre.profile.copy()
        
        with rasterio.open(post_path) as src_post:
            post_chip = src_post.read(1).astype(np.float32)
        
        # Validate chips
        is_valid_pre, msg_pre = validate_chip_data(pre_chip)
        if not is_valid_pre:
            raise ValueError(f"Pre-flood: {msg_pre}")
        
        is_valid_post, msg_post = validate_chip_data(post_chip)
        if not is_valid_post:
            raise ValueError(f"Post-flood: {msg_post}")
        
        # Replace NaN values
        if np.isnan(pre_chip).any():
            pre_mean = np.nanmean(pre_chip)
            pre_chip = np.nan_to_num(pre_chip, nan=pre_mean)
        
        if np.isnan(post_chip).any():
            post_mean = np.nanmean(post_chip)
            post_chip = np.nan_to_num(post_chip, nan=post_mean)
        
        # STEP 1: Apply Lee speckle filter (CRITICAL!)
        if APPLY_SPECKLE_FILTER:
            pre_chip = lee_filter(pre_chip, size=SPECKLE_FILTER_SIZE)
            post_chip = lee_filter(post_chip, size=SPECKLE_FILTER_SIZE)
        
        # STEP 2: Calculate change
        change_map = calculate_change(pre_chip, post_chip, method=CHANGE_METHOD)
        
        # STEP 3: Apply threshold
        flood_mask = apply_threshold(
            change_map, 
            method=THRESHOLD_METHOD,
            fixed_value=FIXED_THRESHOLD,
            percentile=PERCENTILE_VALUE
        )
        
        # STEP 4: Advanced post-processing
        flood_mask = post_process_mask(
            flood_mask,
            min_size=MIN_OBJECT_SIZE,
            apply_morph=APPLY_MORPHOLOGY,
            kernel_size=MORPH_KERNEL_SIZE
        )
        
        # STEP 5: Save the mask
        profile.update(dtype=rasterio.uint8, count=1, compress='LZW')
        
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(flood_mask, 1)
        
        return True
        
    except Exception as e:
        return False


def stitch_flood_masks(mask_dir, district_name, output_dir):
    """Stitch all chip-level flood masks into a single district-level GeoTIFF."""
    
    mask_files = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) 
                        if f.endswith('.tif')])
    
    if not mask_files:
        print(f"   ‚ö†Ô∏è No mask files found in {mask_dir}")
        return
    
    print(f"   Stitching {len(mask_files)} mask chips...")
    
    sources = [rasterio.open(f) for f in mask_files]
    stitched_array, out_transform = merge(sources)
    
    out_meta = sources[0].profile.copy()
    out_meta.update({
        "driver": "GTiff",
        "height": stitched_array.shape[1],
        "width": stitched_array.shape[2],
        "transform": out_transform,
        "count": 1,
        "dtype": 'uint8',
        "compress": 'LZW'
    })
    
    output_path = os.path.join(output_dir, f'{district_name}_Flood_Map.tif')
    with rasterio.open(output_path, "w", **out_meta) as dest:
        dest.write(stitched_array)
    
    for src in sources:
        src.close()
    
    print(f"   ‚úÖ Final flood map saved: {output_path}")
    
    # Calculate flood statistics
    total_pixels = stitched_array.size
    flood_pixels = np.sum(stitched_array == 1)
    flood_percentage = (flood_pixels / total_pixels) * 100
    
    # Sanity check warning
    if flood_percentage > 50:
        print(f"   ‚ö†Ô∏è WARNING: {flood_percentage:.2f}% flooded - this seems too high!")
        print(f"   Consider increasing FIXED_THRESHOLD (try -4.0 or -3.0)")
    else:
        print(f"   üìä Flooded area: {flood_percentage:.2f}% of total pixels")


# =================================================================
# MAIN EXECUTION
# =================================================================

print("="*70)
print("PHASE 1A: IMPROVED SAR CHANGE DETECTION WITH DESPECKLING")
print("="*70)
print(f"\nConfiguration:")
print(f"  Change method: {CHANGE_METHOD}")
print(f"  Threshold method: {THRESHOLD_METHOD}")
if THRESHOLD_METHOD == 'fixed':
    print(f"  Fixed threshold: {FIXED_THRESHOLD} dB")
elif THRESHOLD_METHOD == 'percentile':
    print(f"  Percentile: {PERCENTILE_VALUE}%")
print(f"  Speckle filter: {'Enabled' if APPLY_SPECKLE_FILTER else 'Disabled'} (Lee {SPECKLE_FILTER_SIZE}x{SPECKLE_FILTER_SIZE})")
print(f"  Post-processing: {'Enabled' if APPLY_MORPHOLOGY else 'Disabled'}")
print(f"  Min object size: {MIN_OBJECT_SIZE} pixels")
print(f"  Min valid pixels: {MIN_VALID_PIXELS}")
print("\n" + "="*70 + "\n")

os.makedirs(FLOOD_MASKS_DIR, exist_ok=True)
os.makedirs(FINAL_MAPS_DIR, exist_ok=True)

for district in DISTRICTS:
    print(f"üåä Processing {district}...")
    
    pre_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'pre_flood')
    post_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'post_flood')
    mask_dir = os.path.join(FLOOD_MASKS_DIR, district)
    
    if not os.path.exists(pre_dir) or not os.path.exists(post_dir):
        print(f"   ‚ö†Ô∏è Skipping {district}: Chip directories not found")
        continue
    
    os.makedirs(mask_dir, exist_ok=True)
    
    pre_chips = sorted([f for f in os.listdir(pre_dir) if f.endswith('.tif')])
    
    if not pre_chips:
        print(f"   ‚ö†Ô∏è No .tif files found in {pre_dir}")
        continue
    
    print(f"   Processing {len(pre_chips)} chip pairs...")
    
    success_count = 0
    skip_count = 0
    
    for chip_name in pre_chips:
        pre_path = os.path.join(pre_dir, chip_name)
        post_chip_name = chip_name.replace('PreFlood_Image', 'PostFlood_Image')
        post_path = os.path.join(post_dir, post_chip_name)
        
        if not os.path.exists(post_path):
            skip_count += 1
            continue
        
        mask_name = chip_name.replace('PreFlood_Image', 'Flood_Mask')
        mask_path = os.path.join(mask_dir, mask_name)
        
        if process_chip_pair(pre_path, post_path, mask_path):
            success_count += 1
        else:
            skip_count += 1
    
    print(f"   ‚úÖ Successfully processed: {success_count}/{len(pre_chips)} chips")
    print(f"   ‚ö†Ô∏è Skipped (invalid data): {skip_count}/{len(pre_chips)} chips")
    
    if success_count > 0:
        stitch_flood_masks(mask_dir, district, FINAL_MAPS_DIR)
    else:
        print(f"   ‚ö†Ô∏è No masks to stitch for {district}")
    print()

print("="*70)
print("‚úÖ PHASE 1A COMPLETE!")
print("="*70)
print(f"\nOutputs saved to:")
print(f"  Chip-level masks: {FLOOD_MASKS_DIR}")
print(f"  Final flood maps: {FINAL_MAPS_DIR}")
print("\nIMPORTANT: If flood percentages seem too high (>50%):")
print("  1. Increase FIXED_THRESHOLD to -4.0 or -3.0")
print("  2. Increase MIN_OBJECT_SIZE to 200 or 500")
print("  3. Re-run the cell")
print("\nNext steps:")
print("  1. Open GeoTIFF files in QGIS for visual assessment")
print("  2. Adjust parameters if needed")
print("  3. Compare with known flood extent maps")


PHASE 1A: IMPROVED SAR CHANGE DETECTION WITH DESPECKLING

Configuration:
  Change method: log_ratio
  Threshold method: fixed
  Fixed threshold: -5.0 dB
  Speckle filter: Enabled (Lee 5x5)
  Post-processing: Enabled
  Min object size: 100 pixels
  Min valid pixels: 1000


üåä Processing Barpeta...
   Processing 96 chip pairs...
   ‚úÖ Successfully processed: 96/96 chips
   ‚ö†Ô∏è Skipped (invalid data): 0/96 chips
   Stitching 96 mask chips...
   ‚úÖ Final flood map saved: C:\Kaam_Dhanda\Minor_Project\Final_Flood_Maps\Barpeta_Flood_Map.tif
   üìä Flooded area: 0.00% of total pixels

üåä Processing Dhemaji...
   Processing 104 chip pairs...
   ‚úÖ Successfully processed: 91/104 chips
   ‚ö†Ô∏è Skipped (invalid data): 13/104 chips
   Stitching 93 mask chips...
   ‚úÖ Final flood map saved: C:\Kaam_Dhanda\Minor_Project\Final_Flood_Maps\Dhemaji_Flood_Map.tif
   üìä Flooded area: 1.26% of total pixels

üåä Processing Lakhimpur...
   Processing 324 chip pairs...
   ‚úÖ Successfully proc