# **Image Quality Filtering and Preprocessing**
This notebook filters HLS (Harmonized Landsat Sentinel) 6-band optical images based on:
- **File size**: Only processes images above a minimum size threshold
- **Data validity**: Checks for corrupt or incomplete files
- **Spatial dimensions**: Ensures images meet minimum resolution requirements

In [1]:
# Import required libraries
import os
import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.merge import merge
import warnings
warnings.filterwarnings('ignore')

print("‚úÖ Libraries imported successfully")

‚úÖ Libraries imported successfully


In [2]:
# =================================================================
# CONFIGURATION
# =================================================================

# Input directory with HLS 6-band images
INPUT_DIR = r'C:\Kaam_Dhanda\Minor_Project\Flood_Analysis_HLS_Exports-20251121T100600Z-1-001\Flood_Analysis_HLS_Exports'

# Output directories
OUTPUT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips_HLS'
FILTERED_IMAGES_DIR = r'C:\Kaam_Dhanda\Minor_Project\Filtered_HLS_Images'

# Quality thresholds
MIN_FILE_SIZE_MB = 5          # Minimum file size in MB (filters out corrupted/incomplete files)
MIN_WIDTH = 1000              # Minimum image width in pixels
MIN_HEIGHT = 1000             # Minimum image height in pixels
MAX_NODATA_PERCENT = 30       # Maximum percentage of no-data values allowed

# Chipping parameters
CHIP_SIZE = 224               # Standard size for deep learning (224x224 for vision models)
MAX_NODATA_PER_CHIP = 20      # Maximum % of no-data allowed per chip

# Districts to process
DISTRICTS = ['Barpeta', 'Dhemaji', 'Lakhimpur', 'Nalbari', 'Sonitpur']

print("‚úÖ Configuration loaded")
print(f"   Input directory: {INPUT_DIR}")
print(f"   Min file size: {MIN_FILE_SIZE_MB} MB")
print(f"   Min dimensions: {MIN_WIDTH}x{MIN_HEIGHT} pixels")
print(f"   Chip size: {CHIP_SIZE}x{CHIP_SIZE} pixels")

‚úÖ Configuration loaded
   Input directory: C:\Kaam_Dhanda\Minor_Project\Flood_Analysis_HLS_Exports-20251121T100600Z-1-001\Flood_Analysis_HLS_Exports
   Min file size: 5 MB
   Min dimensions: 1000x1000 pixels
   Chip size: 224x224 pixels


In [3]:
# =================================================================
# IMAGE QUALITY VALIDATION FUNCTIONS
# =================================================================

def validate_image_quality(file_path, min_size_mb, min_width, min_height, max_nodata_percent):
    """
    Validates an image file based on multiple quality criteria.
    
    Returns:
        (bool, dict): (is_valid, metadata_dict)
    """
    results = {
        'file_path': file_path,
        'file_size_mb': 0,
        'width': 0,
        'height': 0,
        'bands': 0,
        'nodata_percent': 0,
        'dtype': None,
        'crs': None,
        'is_valid': False,
        'rejection_reason': []
    }
    
    # Check 1: File exists
    if not os.path.exists(file_path):
        results['rejection_reason'].append("File not found")
        return False, results
    
    # Check 2: File size
    file_size_bytes = os.path.getsize(file_path)
    file_size_mb = file_size_bytes / (1024 * 1024)
    results['file_size_mb'] = round(file_size_mb, 2)
    
    if file_size_mb < min_size_mb:
        results['rejection_reason'].append(f"File too small ({file_size_mb:.2f} MB < {min_size_mb} MB)")
        return False, results
    
    # Check 3: Can open and read metadata
    try:
        with rasterio.open(file_path) as src:
            results['width'] = src.width
            results['height'] = src.height
            results['bands'] = src.count
            results['dtype'] = str(src.dtypes[0])
            results['crs'] = str(src.crs) if src.crs else "None"
            
            # Check dimensions
            if src.width < min_width or src.height < min_height:
                results['rejection_reason'].append(
                    f"Image too small ({src.width}x{src.height} < {min_width}x{min_height})"
                )
                return False, results
            
            # Check 4: Sample data for no-data percentage (check first band)
            sample_data = src.read(1)
            
            # Check for no-data values
            if src.nodata is not None:
                nodata_pixels = np.sum(sample_data == src.nodata)
            else:
                # If no nodata value specified, check for NaN or zeros
                nodata_pixels = np.sum(np.isnan(sample_data)) + np.sum(sample_data == 0)
            
            total_pixels = sample_data.size
            nodata_percent = (nodata_pixels / total_pixels) * 100
            results['nodata_percent'] = round(nodata_percent, 2)
            
            if nodata_percent > max_nodata_percent:
                results['rejection_reason'].append(
                    f"Too much no-data ({nodata_percent:.1f}% > {max_nodata_percent}%)"
                )
                return False, results
            
    except Exception as e:
        results['rejection_reason'].append(f"Error reading file: {str(e)}")
        return False, results
    
    # If all checks pass
    results['is_valid'] = True
    return True, results


def scan_and_filter_images(input_dir, districts, min_size_mb, min_width, min_height, max_nodata_percent):
    """
    Scans all images in the directory and filters based on quality criteria.
    
    Returns:
        dict: {district: {'pre_flood': path, 'post_flood': path, 'metadata': {...}}}
    """
    valid_images = {}
    rejected_images = {}
    
    print("\n" + "="*70)
    print("SCANNING AND VALIDATING IMAGES")
    print("="*70)
    
    for district in districts:
        print(f"\nüìç District: {district}")
        
        # Look for pre and post flood files
        pre_file = os.path.join(input_dir, f'{district}_PreFlood_HLS_6Band.tif')
        post_file = os.path.join(input_dir, f'{district}_PostFlood_HLS_6Band.tif')
        
        # Validate pre-flood image
        pre_valid, pre_results = validate_image_quality(
            pre_file, min_size_mb, min_width, min_height, max_nodata_percent
        )
        
        # Validate post-flood image
        post_valid, post_results = validate_image_quality(
            post_file, min_size_mb, min_width, min_height, max_nodata_percent
        )
        
        # Both must be valid to include the district
        if pre_valid and post_valid:
            valid_images[district] = {
                'pre_flood': pre_file,
                'post_flood': post_file,
                'pre_metadata': pre_results,
                'post_metadata': post_results
            }
            print(f"   ‚úÖ ACCEPTED")
            print(f"      Pre-flood:  {pre_results['width']}x{pre_results['height']}, "
                  f"{pre_results['file_size_mb']} MB, {pre_results['nodata_percent']}% no-data")
            print(f"      Post-flood: {post_results['width']}x{post_results['height']}, "
                  f"{post_results['file_size_mb']} MB, {post_results['nodata_percent']}% no-data")
        else:
            rejected_images[district] = {
                'pre_results': pre_results,
                'post_results': post_results
            }
            print(f"   ‚ùå REJECTED")
            if not pre_valid:
                print(f"      Pre-flood issues: {', '.join(pre_results['rejection_reason'])}")
            if not post_valid:
                print(f"      Post-flood issues: {', '.join(post_results['rejection_reason'])}")
    
    print("\n" + "="*70)
    print(f"‚úÖ Valid districts: {len(valid_images)}/{len(districts)}")
    print(f"‚ùå Rejected districts: {len(rejected_images)}/{len(districts)}")
    print("="*70)
    
    return valid_images, rejected_images

In [4]:
# =================================================================
# RUN IMAGE QUALITY SCAN
# =================================================================

# Scan and filter images based on quality criteria
valid_images, rejected_images = scan_and_filter_images(
    INPUT_DIR,
    DISTRICTS,
    MIN_FILE_SIZE_MB,
    MIN_WIDTH,
    MIN_HEIGHT,
    MAX_NODATA_PERCENT
)

# Display summary
print("\nüìä SUMMARY:")
if valid_images:
    print("\n‚úÖ Valid districts ready for processing:")
    for district in valid_images.keys():
        print(f"   ‚Ä¢ {district}")
else:
    print("\n‚ö†Ô∏è No valid images found! Please check your quality thresholds.")

if rejected_images:
    print("\n‚ùå Rejected districts:")
    for district in rejected_images.keys():
        print(f"   ‚Ä¢ {district}")


SCANNING AND VALIDATING IMAGES

üìç District: Barpeta
   ‚úÖ ACCEPTED
      Pre-flood:  1979x1326, 52.16 MB, 3.33% no-data
      Post-flood: 1979x1326, 49.03 MB, 7.16% no-data

üìç District: Dhemaji
   ‚úÖ ACCEPTED
      Pre-flood:  1979x1326, 52.16 MB, 3.33% no-data
      Post-flood: 1979x1326, 49.03 MB, 7.16% no-data

üìç District: Dhemaji
   ‚ùå REJECTED
      Post-flood issues: File too small (0.40 MB < 5 MB)

üìç District: Lakhimpur
   ‚ùå REJECTED
      Post-flood issues: File too small (0.40 MB < 5 MB)

üìç District: Lakhimpur
   ‚ùå REJECTED
      Post-flood issues: Too much no-data (94.3% > 30%)

üìç District: Nalbari
   ‚ùå REJECTED
      Post-flood issues: Too much no-data (94.3% > 30%)

üìç District: Nalbari
   ‚úÖ ACCEPTED
      Pre-flood:  1559x1300, 40.97 MB, 2.08% no-data
      Post-flood: 1559x1300, 37.39 MB, 10.68% no-data

üìç District: Sonitpur
   ‚úÖ ACCEPTED
      Pre-flood:  1559x1300, 40.97 MB, 2.08% no-data
      Post-flood: 1559x1300, 37.39 MB, 10.68%

In [5]:
# =================================================================
# CHIPPING FUNCTION FOR VALID IMAGES
# =================================================================

def chip_image_with_quality_check(input_filepath, output_directory, chip_size, max_nodata_percent):
    """
    Cuts a large GeoTIFF into smaller chips, only saving high-quality chips.
    
    Returns:
        int: Number of valid chips created
    """
    try:
        src = rasterio.open(input_filepath)
    except Exception as e:
        print(f"   ‚ùå Error opening file: {e}")
        return 0
    
    width = src.width
    height = src.height
    count = 0
    skipped = 0
    
    print(f"   Processing {os.path.basename(input_filepath)}...")
    print(f"   Image dimensions: {width}x{height}, {src.count} bands")
    
    # Loop through the image in chunks
    for i in range(0, height, chip_size):
        for j in range(0, width, chip_size):
            
            # Define the window
            window = Window(j, i, min(chip_size, width - j), min(chip_size, height - i))
            transform = src.window_transform(window)
            
            # Read all bands for this chip
            chip_data = src.read(window=window)
            
            # Quality check 1: Check chip size (skip edge chips that are too small)
            if window.height < chip_size * 0.5 or window.width < chip_size * 0.5:
                skipped += 1
                continue
            
            # Quality check 2: Check for no-data percentage (check first band as representative)
            first_band = chip_data[0]
            
            if src.nodata is not None:
                nodata_pixels = np.sum(first_band == src.nodata)
            else:
                nodata_pixels = np.sum(np.isnan(first_band)) + np.sum(first_band == 0)
            
            total_pixels = first_band.size
            nodata_percent = (nodata_pixels / total_pixels) * 100
            
            if nodata_percent > max_nodata_percent:
                skipped += 1
                continue
            
            # Quality check 3: Check for data variation (avoid uniform chips)
            if np.std(first_band) < 0.01:  # Very low standard deviation
                skipped += 1
                continue
            
            # Update metadata for the chip
            profile = src.profile.copy()
            profile.update({
                'height': window.height,
                'width': window.width,
                'transform': transform,
                'compress': 'LZW'
            })
            
            # Save the chip
            base_filename = os.path.basename(input_filepath)
            file_stem = base_filename.replace('.tif', '')
            chip_filename = f'{file_stem}_chip_{count}.tif'
            output_path = os.path.join(output_directory, chip_filename)
            
            try:
                with rasterio.open(output_path, 'w', **profile) as dst:
                    dst.write(chip_data)
                count += 1
            except Exception as e:
                print(f"   ‚ö†Ô∏è Failed to write chip: {e}")
    
    src.close()
    
    print(f"   ‚úÖ Created {count} valid chips (skipped {skipped} low-quality chips)")
    return count

In [6]:
# =================================================================
# PROCESS VALID IMAGES - CREATE CHIPS
# =================================================================

if not valid_images:
    print("‚ö†Ô∏è No valid images to process. Adjust quality thresholds if needed.")
else:
    print("\n" + "="*70)
    print("CHIPPING VALID IMAGES")
    print("="*70)
    
    # Create output directories
    os.makedirs(OUTPUT_CHIPS_DIR, exist_ok=True)
    
    chip_statistics = {}
    
    for district, files in valid_images.items():
        print(f"\nüèûÔ∏è Processing {district}...")
        
        # Create district output directories
        pre_output_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'pre_flood')
        post_output_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'post_flood')
        
        os.makedirs(pre_output_dir, exist_ok=True)
        os.makedirs(post_output_dir, exist_ok=True)
        
        # Chip pre-flood image
        print(f"\n   Pre-flood image:")
        pre_chip_count = chip_image_with_quality_check(
            files['pre_flood'],
            pre_output_dir,
            CHIP_SIZE,
            MAX_NODATA_PER_CHIP
        )
        
        # Chip post-flood image
        print(f"\n   Post-flood image:")
        post_chip_count = chip_image_with_quality_check(
            files['post_flood'],
            post_output_dir,
            CHIP_SIZE,
            MAX_NODATA_PER_CHIP
        )
        
        chip_statistics[district] = {
            'pre_flood_chips': pre_chip_count,
            'post_flood_chips': post_chip_count
        }
    
    # Display final summary
    print("\n" + "="*70)
    print("üìä CHIPPING SUMMARY")
    print("="*70)
    
    total_chips = 0
    for district, stats in chip_statistics.items():
        district_total = stats['pre_flood_chips'] + stats['post_flood_chips']
        total_chips += district_total
        print(f"\n{district}:")
        print(f"   Pre-flood:  {stats['pre_flood_chips']} chips")
        print(f"   Post-flood: {stats['post_flood_chips']} chips")
        print(f"   Total:      {district_total} chips")
    
    print(f"\n{'='*70}")
    print(f"‚úÖ Total chips created: {total_chips}")
    print(f"üìÅ Output directory: {OUTPUT_CHIPS_DIR}")
    print(f"{'='*70}")


CHIPPING VALID IMAGES

üèûÔ∏è Processing Barpeta...

   Pre-flood image:
   Processing Barpeta_PreFlood_HLS_6Band.tif...
   Image dimensions: 1979x1326, 6 bands
   ‚úÖ Created 52 valid chips (skipped 2 low-quality chips)

   Post-flood image:
   Processing Barpeta_PostFlood_HLS_6Band.tif...
   Image dimensions: 1979x1326, 6 bands
   ‚úÖ Created 52 valid chips (skipped 2 low-quality chips)

   Post-flood image:
   Processing Barpeta_PostFlood_HLS_6Band.tif...
   Image dimensions: 1979x1326, 6 bands
   ‚úÖ Created 53 valid chips (skipped 1 low-quality chips)

üèûÔ∏è Processing Nalbari...

   Pre-flood image:
   Processing Nalbari_PreFlood_HLS_6Band.tif...
   Image dimensions: 1559x1300, 6 bands
   ‚úÖ Created 53 valid chips (skipped 1 low-quality chips)

üèûÔ∏è Processing Nalbari...

   Pre-flood image:
   Processing Nalbari_PreFlood_HLS_6Band.tif...
   Image dimensions: 1559x1300, 6 bands
   ‚úÖ Created 42 valid chips (skipped 0 low-quality chips)

   Post-flood image:
   Processing

---
## Optional: Copy Valid Full-Size Images
You can copy the validated full-size images to a separate directory for archival purposes.

In [7]:
# Optional: Copy validated full-size images to a separate directory
import shutil

if valid_images:
    print("üì¶ Copying validated full-size images...")
    os.makedirs(FILTERED_IMAGES_DIR, exist_ok=True)
    
    for district, files in valid_images.items():
        # Copy pre-flood image
        pre_dest = os.path.join(FILTERED_IMAGES_DIR, f'{district}_PreFlood_HLS_6Band.tif')
        shutil.copy2(files['pre_flood'], pre_dest)
        
        # Copy post-flood image
        post_dest = os.path.join(FILTERED_IMAGES_DIR, f'{district}_PostFlood_HLS_6Band.tif')
        shutil.copy2(files['post_flood'], post_dest)
        
        print(f"   ‚úÖ Copied {district} images")
    
    print(f"\n‚úÖ All validated images copied to: {FILTERED_IMAGES_DIR}")
else:
    print("‚ö†Ô∏è No valid images to copy")

üì¶ Copying validated full-size images...
   ‚úÖ Copied Barpeta images
   ‚úÖ Copied Nalbari images

‚úÖ All validated images copied to: C:\Kaam_Dhanda\Minor_Project\Filtered_HLS_Images


In [9]:
import numpy as np
import rasterio
import torch
import os

# =======================================================================
# CONFIGURATION - HLS/SENTINEL-2 PARAMETERS (for Prithvi)
# =======================================================================

# NOTE: These are official normalization parameters for HLS (Sentinel-2) data 
# scaled to 0-10000, which are commonly used with Prithvi models.
HLS_NORM_MEANS = np.array([1353, 1146, 989, 2661, 2378, 1782], dtype=np.float32) 
HLS_NORM_STDS = np.array([870, 891, 1007, 1251, 1251, 1140], dtype=np.float32)
SCALE_FACTOR = 10000.0 

# IMPORTANT: SET YOUR ROOT DIRECTORY
ROOT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips_HLS'

# Dictionary to store all processed tensors
processed_tensors = {}

# =======================================================================
# CORE PROCESSING FUNCTION
# =======================================================================

def preprocess_hls_chip(file_path, means, stds, scale_factor):
    """
    Reads a 6-band HLS GeoTIFF, normalizes it, and converts it to a 
    PyTorch Tensor (1, C=6, H, W).
    """
    try:
        with rasterio.open(file_path) as src:
            # Read all 6 bands: (C, H, W)
            data = src.read().astype(np.float32)
            
            if data.shape[0] != 6:
                raise ValueError(f"Skipping {os.path.basename(file_path)}: Expected 6 bands, found {data.shape[0]}.")

    except rasterio.RasterioIOError:
        print(f"Error: Could not open or read {file_path}.")
        return None, None

    # 1. Scaling (converting 0-10000 range to 0-1 range)
    data = np.clip(data, 0, 10000) / scale_factor 
    
    # 2. Normalization (Z-Score)
    # Reshape means/stds for broadcasting: (Channels, 1, 1)
    means_reshaped = means.reshape(6, 1, 1) / scale_factor
    stds_reshaped = stds.reshape(6, 1, 1) / scale_factor
    
    normalized_data = (data - means_reshaped) / stds_reshaped

    # 3. Convert to PyTorch Tensor and add batch dimension (1, C, H, W)
    tensor = torch.from_numpy(normalized_data)
    tensor = tensor.unsqueeze(0) 

    return tensor

# =======================================================================
# BATCH EXECUTION (Populates the processed_tensors dictionary)
# =======================================================================

print(f"Starting batch HLS tensor conversion from: {ROOT_CHIPS_DIR}")

for district_name in os.listdir(ROOT_CHIPS_DIR):
    district_path = os.path.join(ROOT_CHIPS_DIR, district_name)
    
    if not os.path.isdir(district_path):
        continue

    processed_tensors[district_name] = {'pre_flood': [], 'post_flood': []}
    
    for phase in ['pre_flood', 'post_flood']:
        phase_path = os.path.join(district_path, phase)
        
        if not os.path.isdir(phase_path): continue

        for chip_filename in os.listdir(phase_path):
            if chip_filename.endswith('.tif'):
                chip_file_path = os.path.join(phase_path, chip_filename)
                
                # Run the core pre-processing function
                tensor = preprocess_hls_chip(
                    chip_file_path, HLS_NORM_MEANS, HLS_NORM_STDS, SCALE_FACTOR
                )
                
                if tensor is not None:
                    # Store the resulting tensor
                    processed_tensors[district_name][phase].append(tensor)

# --- Final Check ---
total_tensors = sum(len(p['pre_flood']) + len(p['post_flood']) for d, p in processed_tensors.items())
print(f"\n=======================================================")
print(f"‚úÖ Tensor Conversion Complete. Total Tensors Created: {total_tensors}")
print("=======================================================")

Starting batch HLS tensor conversion from: C:\Kaam_Dhanda\Minor_Project\Output_chips_HLS

‚úÖ Tensor Conversion Complete. Total Tensors Created: 185

‚úÖ Tensor Conversion Complete. Total Tensors Created: 185


# Fine Tuning 

## **Step 1: Load Ground Truth Labels**
Before fine-tuning, you need to prepare your ground truth flood masks. These should be binary masks where:
- **0** = Non-flooded areas
- **1** = Flooded areas

You can create these using:
- Manual annotation in QGIS
- SAR-based change detection masks (from your previous work)
- Combination of multiple data sources

In [14]:
# =================================================================
# LOAD GROUND TRUTH LABELS
# =================================================================

import torch
import rasterio
import numpy as np
from pathlib import Path

# Configure paths to your ground truth masks
GROUND_TRUTH_DIR = r'C:\Kaam_Dhanda\Minor_Project\Flood_Masks'  # Directory with your SAR-based masks
LABEL_CHIP_SIZE = 224  # Should match your HLS chip size

def load_ground_truth_chip(mask_path, target_size=224):
    """
    Load a ground truth mask chip and convert to tensor.
    
    Args:
        mask_path: Path to the mask GeoTIFF
        target_size: Expected size (will resize if needed)
    
    Returns:
        torch.Tensor: Shape (1, 1, H, W) with binary values {0, 1}
    """
    try:
        with rasterio.open(mask_path) as src:
            mask = src.read(1).astype(np.float32)
            
            # Ensure binary values (0 and 1)
            mask = (mask > 0).astype(np.float32)
            
            # Convert to tensor: (H, W) -> (1, 1, H, W)
            tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)
            
            # Resize if needed
            if mask.shape[0] != target_size or mask.shape[1] != target_size:
                tensor = torch.nn.functional.interpolate(
                    tensor, 
                    size=(target_size, target_size), 
                    mode='nearest'
                )
            
            return tensor
            
    except Exception as e:
        print(f"Error loading mask {mask_path}: {e}")
        return None


def pair_images_with_labels(image_tensors_dict, ground_truth_dir):
    """
    Match HLS image chips with corresponding ground truth masks.
    
    Args:
        image_tensors_dict: The processed_tensors dictionary from previous cells
        ground_truth_dir: Directory containing ground truth masks
    
    Returns:
        tuple: (matched_images, matched_labels) as lists of tensors
    """
    matched_images = []
    matched_labels = []
    
    print("\n" + "="*70)
    print("PAIRING IMAGES WITH GROUND TRUTH LABELS")
    print("="*70)
    
    for district, phases in image_tensors_dict.items():
        print(f"\nüìç District: {district}")
        
        # We'll use post-flood images for training (when flooding is visible)
        post_flood_tensors = phases.get('post_flood', [])
        
        # Look for corresponding mask directory
        mask_dir = Path(ground_truth_dir) / district
        
        if not mask_dir.exists():
            print(f"   ‚ö†Ô∏è No mask directory found for {district}")
            continue
        
        # Get all mask files
        mask_files = sorted(mask_dir.glob('*Flood_Mask*.tif'))
        
        if not mask_files:
            print(f"   ‚ö†Ô∏è No mask files found in {mask_dir}")
            continue
        
        # Match images with masks (assuming same naming convention)
        pairs_found = 0
        for img_idx, img_tensor in enumerate(post_flood_tensors):
            # Try to find corresponding mask
            if img_idx < len(mask_files):
                mask_tensor = load_ground_truth_chip(str(mask_files[img_idx]))
                
                if mask_tensor is not None:
                    matched_images.append(img_tensor)
                    matched_labels.append(mask_tensor)
                    pairs_found += 1
        
        print(f"   ‚úÖ Paired {pairs_found} image-mask pairs")
    
    print("\n" + "="*70)
    print(f"‚úÖ Total training pairs: {len(matched_images)}")
    print("="*70)
    
    return matched_images, matched_labels


# Execute pairing (using the processed_tensors from cell 9)
if 'processed_tensors' in globals() and processed_tensors:
    train_images, train_labels = pair_images_with_labels(processed_tensors, GROUND_TRUTH_DIR)
    
    if train_images and train_labels:
        print(f"\nüìä Training data ready:")
        print(f"   Images: {len(train_images)} samples")
        print(f"   Labels: {len(train_labels)} samples")
        print(f"   Image shape: {train_images[0].shape}")
        print(f"   Label shape: {train_labels[0].shape}")
    else:
        print("\n‚ö†Ô∏è No training pairs found. Please check your ground truth directory.")
else:
    print("‚ö†Ô∏è Please run cell 9 first to generate processed_tensors")


PAIRING IMAGES WITH GROUND TRUTH LABELS

üìç District: Barpeta
   ‚úÖ Paired 53 image-mask pairs

üìç District: Nalbari
   ‚úÖ Paired 53 image-mask pairs

üìç District: Nalbari
   ‚úÖ Paired 38 image-mask pairs

‚úÖ Total training pairs: 91

üìä Training data ready:
   Images: 91 samples
   Labels: 91 samples
   Image shape: torch.Size([1, 6, 224, 224])
   Label shape: torch.Size([1, 1, 224, 224])
   ‚úÖ Paired 38 image-mask pairs

‚úÖ Total training pairs: 91

üìä Training data ready:
   Images: 91 samples
   Labels: 91 samples
   Image shape: torch.Size([1, 6, 224, 224])
   Label shape: torch.Size([1, 1, 224, 224])


## **Step 2: Define the Fine-Tuning Model**
We'll create a segmentation model using a pre-trained encoder (backbone) and add a decoder for flood detection.

In [15]:
# =================================================================
# FLOOD SEGMENTATION MODEL
# =================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp  # Popular library for segmentation

class FloodSegmentationModel(nn.Module):
    """
    Flood detection segmentation model using a pre-trained encoder.
    Uses U-Net architecture with a ResNet or EfficientNet encoder.
    """
    def __init__(self, encoder_name='resnet34', encoder_weights='imagenet', in_channels=6, classes=1):
        """
        Args:
            encoder_name: Backbone encoder (resnet34, resnet50, efficientnet-b0, etc.)
            encoder_weights: Pre-trained weights ('imagenet' or None)
            in_channels: Number of input channels (6 for HLS)
            classes: Number of output classes (1 for binary segmentation)
        """
        super().__init__()
        
        # Create U-Net model with pre-trained encoder
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=None  # We'll apply sigmoid during training/inference
        )
        
    def forward(self, x):
        """Forward pass through the model."""
        return self.model(x)


# Alternative: Use DeepLabV3+ for better performance
class FloodDeepLabModel(nn.Module):
    """
    Alternative model using DeepLabV3+ architecture.
    Generally performs better than U-Net for segmentation tasks.
    """
    def __init__(self, encoder_name='resnet50', encoder_weights='imagenet', in_channels=6, classes=1):
        super().__init__()
        
        self.model = smp.DeepLabV3Plus(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=None
        )
        
    def forward(self, x):
        return self.model(x)


# Initialize the model
print("üîß Initializing flood segmentation model...")

# Choose your model architecture
MODEL_TYPE = 'unet'  # Options: 'unet' or 'deeplabv3plus'
ENCODER = 'resnet34'  # Options: 'resnet34', 'resnet50', 'efficientnet-b0', 'mobilenet_v2'
INPUT_CHANNELS = 6  # HLS has 6 bands
OUTPUT_CLASSES = 1  # Binary segmentation (flood vs non-flood)

if MODEL_TYPE == 'unet':
    model = FloodSegmentationModel(
        encoder_name=ENCODER,
        encoder_weights='imagenet',
        in_channels=INPUT_CHANNELS,
        classes=OUTPUT_CLASSES
    )
else:
    model = FloodDeepLabModel(
        encoder_name=ENCODER,
        encoder_weights='imagenet',
        in_channels=INPUT_CHANNELS,
        classes=OUTPUT_CLASSES
    )

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"‚úÖ Model initialized: {MODEL_TYPE.upper()} with {ENCODER} encoder")
print(f"   Device: {device}")
print(f"   Input channels: {INPUT_CHANNELS}")
print(f"   Output classes: {OUTPUT_CLASSES}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")

üîß Initializing flood segmentation model...
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to C:\Users\Nipun/.cache\torch\hub\checkpoints\resnet34-333f7ec4.pth
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to C:\Users\Nipun/.cache\torch\hub\checkpoints\resnet34-333f7ec4.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 83.3M/83.3M [00:07<00:00, 11.7MB/s]



‚úÖ Model initialized: UNET with resnet34 encoder
   Device: cpu
   Input channels: 6
   Output classes: 1
   Total parameters: 24,445,777
   Trainable parameters: 24,445,777


## **Step 3: Setup Training Pipeline**
Configure the dataset, data loaders, loss function, and optimizer.

In [17]:
# =================================================================
# TRAINING SETUP
# =================================================================

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn as nn

# -------------------------
# Custom Dataset
# -------------------------
class FloodDataset(Dataset):
    """PyTorch Dataset for flood detection training."""
    
    def __init__(self, images, labels):
        """
        Args:
            images: List of image tensors (1, C, H, W)
            labels: List of label tensors (1, 1, H, W)
        """
        self.images = images
        self.labels = labels
        
        assert len(images) == len(labels), "Images and labels must have same length"
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # Remove batch dimension for DataLoader (it will add it back)
        image = self.images[idx].squeeze(0)  # (C, H, W)
        label = self.labels[idx].squeeze(0)  # (1, H, W)
        
        return image, label


# -------------------------
# Loss Functions
# -------------------------
class DiceLoss(nn.Module):
    """Dice Loss for segmentation tasks."""
    
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        pred = pred.view(-1)
        target = target.view(-1)
        
        intersection = (pred * target).sum()
        dice = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
        
        return 1 - dice


class CombinedLoss(nn.Module):
    """Combination of BCE and Dice loss for better training."""
    
    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
    
    def forward(self, pred, target):
        bce = self.bce_loss(pred, target)
        dice = self.dice_loss(pred, target)
        return self.bce_weight * bce + self.dice_weight * dice


# -------------------------
# Setup Training Components
# -------------------------

# Check if we have training data
if 'train_images' not in globals() or not train_images:
    print("‚ö†Ô∏è No training data found. Please run the ground truth loading cell first.")
else:
    print("\n" + "="*70)
    print("SETTING UP TRAINING PIPELINE")
    print("="*70)
    
    # Create dataset
    full_dataset = FloodDataset(train_images, train_labels)
    
    # Split into train and validation (80/20 split)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    print(f"\nüìä Dataset split:")
    print(f"   Training samples: {len(train_dataset)}")
    print(f"   Validation samples: {len(val_dataset)}")
    
    # Create data loaders
    BATCH_SIZE = 8  # Adjust based on your GPU memory
    NUM_WORKERS = 4  # Adjust based on your CPU cores
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    print(f"\nüì¶ Data loaders created:")
    print(f"   Batch size: {BATCH_SIZE}")
    print(f"   Training batches: {len(train_loader)}")
    print(f"   Validation batches: {len(val_loader)}")
    
    # Define loss function
    criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
    print(f"\nüéØ Loss function: Combined (BCE + Dice)")
    
    # Define optimizer
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    
    optimizer = AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )
    
    # Learning rate scheduler
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5
    )
    
    print(f"\n‚öôÔ∏è Optimizer: AdamW")
    print(f"   Learning rate: {LEARNING_RATE}")
    print(f"   Weight decay: {WEIGHT_DECAY}")
    print(f"   LR scheduler: ReduceLROnPlateau")
    
    print("\n" + "="*70)
    print("‚úÖ Training pipeline ready!")
    print("="*70)


SETTING UP TRAINING PIPELINE

üìä Dataset split:
   Training samples: 72
   Validation samples: 19

üì¶ Data loaders created:
   Batch size: 8
   Training batches: 9
   Validation batches: 3

üéØ Loss function: Combined (BCE + Dice)

‚öôÔ∏è Optimizer: AdamW
   Learning rate: 0.0001
   Weight decay: 1e-05
   LR scheduler: ReduceLROnPlateau

‚úÖ Training pipeline ready!


## **Step 4: Training Loop**
Execute the fine-tuning process with validation and checkpoint saving.

In [18]:
# =================================================================
# TRAINING LOOP
# =================================================================

import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# Create checkpoint directory
CHECKPOINT_DIR = r'C:\Kaam_Dhanda\Minor_Project\model_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Training configuration
NUM_EPOCHS = 50
EARLY_STOP_PATIENCE = 10
SAVE_BEST_ONLY = True

# Metrics tracking
train_losses = []
val_losses = []
best_val_loss = float('inf')
epochs_without_improvement = 0


def calculate_iou(pred, target, threshold=0.5):
    """Calculate Intersection over Union (IoU) metric."""
    pred = (torch.sigmoid(pred) > threshold).float()
    
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.item()


def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    running_iou = 0.0
    
    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, (images, labels) in enumerate(pbar):
        # Move data to device
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Calculate metrics
        running_loss += loss.item()
        running_iou += calculate_iou(outputs, labels)
        
        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss / (batch_idx + 1),
            'iou': running_iou / (batch_idx + 1)
        })
    
    epoch_loss = running_loss / len(train_loader)
    epoch_iou = running_iou / len(train_loader)
    
    return epoch_loss, epoch_iou


def validate(model, val_loader, criterion, device):
    """Validate the model."""
    model.eval()
    running_loss = 0.0
    running_iou = 0.0
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for batch_idx, (images, labels) in enumerate(pbar):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            running_iou += calculate_iou(outputs, labels)
            
            pbar.set_postfix({
                'loss': running_loss / (batch_idx + 1),
                'iou': running_iou / (batch_idx + 1)
            })
    
    epoch_loss = running_loss / len(val_loader)
    epoch_iou = running_iou / len(val_loader)
    
    return epoch_loss, epoch_iou


# Check if training data is ready
if 'train_loader' not in globals():
    print("‚ö†Ô∏è Training pipeline not set up. Please run previous cells first.")
else:
    print("\n" + "="*70)
    print("STARTING FINE-TUNING")
    print("="*70)
    print(f"Epochs: {NUM_EPOCHS}")
    print(f"Device: {device}")
    print(f"Checkpoint directory: {CHECKPOINT_DIR}")
    print("="*70 + "\n")
    
    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\nüìÖ Epoch {epoch + 1}/{NUM_EPOCHS}")
        print("-" * 70)
        
        # Train
        train_loss, train_iou = train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        
        # Validate
        val_loss, val_iou = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        
        # Update learning rate
        scheduler.step(val_loss)
        
        # Print epoch summary
        print(f"\nüìä Epoch {epoch + 1} Summary:")
        print(f"   Train Loss: {train_loss:.4f} | Train IoU: {train_iou:.4f}")
        print(f"   Val Loss:   {val_loss:.4f} | Val IoU:   {val_iou:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            
            checkpoint_path = os.path.join(CHECKPOINT_DIR, 'best_model.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_iou': val_iou,
            }, checkpoint_path)
            
            print(f"   ‚úÖ Saved best model (Val Loss: {val_loss:.4f})")
        else:
            epochs_without_improvement += 1
        
        # Early stopping
        if epochs_without_improvement >= EARLY_STOP_PATIENCE:
            print(f"\n‚ö†Ô∏è Early stopping triggered after {epoch + 1} epochs")
            print(f"   No improvement for {EARLY_STOP_PATIENCE} epochs")
            break
    
    print("\n" + "="*70)
    print("‚úÖ FINE-TUNING COMPLETE!")
    print("="*70)
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved to: {CHECKPOINT_DIR}")
    
    # Plot training curves
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(train_losses) + 1), train_losses, 'o-', label='Train')
    plt.plot(range(1, len(val_losses) + 1), val_losses, 's-', label='Val')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Progression')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(os.path.join(CHECKPOINT_DIR, 'training_curves.png'), dpi=150)
    plt.show()
    
    print(f"\nüìà Training curves saved to: {os.path.join(CHECKPOINT_DIR, 'training_curves.png')}")


STARTING FINE-TUNING
Epochs: 50
Device: cpu
Checkpoint directory: C:\Kaam_Dhanda\Minor_Project\model_checkpoints


üìÖ Epoch 1/50
----------------------------------------------------------------------


Training:   0%|          | 0/9 [00:05<?, ?it/s]



RuntimeError: DataLoader worker (pid(s) 31080, 24760, 17068, 17376) exited unexpectedly

## **Step 5: Inference and Prediction**
Use the trained model to generate flood predictions on new data.

In [19]:
# =================================================================
# INFERENCE AND FLOOD MAP GENERATION
# =================================================================

import torch
import numpy as np
import rasterio
from rasterio.merge import merge
from pathlib import Path

# Output directory for predictions
PREDICTION_OUTPUT_DIR = r'C:\Kaam_Dhanda\Minor_Project\Flood_Maps_AI'
os.makedirs(PREDICTION_OUTPUT_DIR, exist_ok=True)


def load_best_model(checkpoint_path, model, device):
    """Load the best model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"‚úÖ Loaded model from epoch {checkpoint['epoch']}")
    print(f"   Validation Loss: {checkpoint['val_loss']:.4f}")
    print(f"   Validation IoU: {checkpoint['val_iou']:.4f}")
    return model


def predict_chip(model, image_tensor, device, threshold=0.5):
    """
    Generate prediction for a single chip.
    
    Args:
        model: Trained model
        image_tensor: Input image tensor (1, C, H, W)
        device: torch device
        threshold: Probability threshold for binary classification
    
    Returns:
        numpy.ndarray: Binary mask (H, W)
    """
    model.eval()
    with torch.no_grad():
        # Move to device
        image = image_tensor.to(device)
        
        # Get prediction
        output = model(image)
        
        # Apply sigmoid and threshold
        prob = torch.sigmoid(output)
        pred_mask = (prob > threshold).cpu().numpy()
        
        # Remove batch and channel dimensions
        pred_mask = pred_mask.squeeze()
        
    return pred_mask.astype(np.uint8)


def generate_flood_maps(model, image_tensors_dict, output_dir, device):
    """
    Generate flood maps for all districts.
    
    Args:
        model: Trained model
        image_tensors_dict: Dictionary of tensors from preprocessing
        output_dir: Where to save predictions
        device: torch device
    """
    print("\n" + "="*70)
    print("GENERATING FLOOD PREDICTIONS")
    print("="*70)
    
    for district, phases in image_tensors_dict.items():
        print(f"\nüèûÔ∏è Processing {district}...")
        
        # Create district output directory
        district_dir = os.path.join(output_dir, district)
        os.makedirs(district_dir, exist_ok=True)
        
        # Process post-flood images (where flooding is visible)
        post_flood_tensors = phases.get('post_flood', [])
        
        if not post_flood_tensors:
            print(f"   ‚ö†Ô∏è No post-flood images found")
            continue
        
        predictions_made = 0
        
        for idx, image_tensor in enumerate(tqdm(post_flood_tensors, desc=f"  {district}")):
            # Generate prediction
            pred_mask = predict_chip(model, image_tensor, device)
            
            # Save prediction as GeoTIFF
            # Note: You'll need to get the geospatial metadata from original chips
            output_path = os.path.join(district_dir, f'{district}_FloodPrediction_chip_{idx}.tif')
            
            # For now, save as simple array (you can add geospatial info later)
            pred_mask_rgb = (pred_mask * 255).astype(np.uint8)
            
            # Create a simple GeoTIFF
            with rasterio.open(
                output_path,
                'w',
                driver='GTiff',
                height=pred_mask.shape[0],
                width=pred_mask.shape[1],
                count=1,
                dtype=rasterio.uint8,
                compress='LZW'
            ) as dst:
                dst.write(pred_mask, 1)
            
            predictions_made += 1
        
        print(f"   ‚úÖ Generated {predictions_made} prediction masks")
    
    print("\n" + "="*70)
    print(f"‚úÖ All predictions saved to: {output_dir}")
    print("="*70)


# Execute inference
if 'model' in globals() and os.path.exists(os.path.join(CHECKPOINT_DIR, 'best_model.pth')):
    print("üîÆ Starting inference with trained model...")
    
    # Load best model
    model = load_best_model(
        os.path.join(CHECKPOINT_DIR, 'best_model.pth'),
        model,
        device
    )
    
    # Generate predictions
    if 'processed_tensors' in globals() and processed_tensors:
        generate_flood_maps(model, processed_tensors, PREDICTION_OUTPUT_DIR, device)
    else:
        print("‚ö†Ô∏è No processed tensors found. Please run the preprocessing cells first.")
else:
    print("‚ö†Ô∏è Trained model not found. Please complete training first.")


# Calculate flood statistics
print("\nüìä FLOOD STATISTICS:")
for district in os.listdir(PREDICTION_OUTPUT_DIR):
    district_path = os.path.join(PREDICTION_OUTPUT_DIR, district)
    if os.path.isdir(district_path):
        mask_files = list(Path(district_path).glob('*.tif'))
        
        total_pixels = 0
        flood_pixels = 0
        
        for mask_file in mask_files:
            with rasterio.open(mask_file) as src:
                mask = src.read(1)
                total_pixels += mask.size
                flood_pixels += np.sum(mask == 1)
        
        if total_pixels > 0:
            flood_percent = (flood_pixels / total_pixels) * 100
            print(f"\n{district}:")
            print(f"   Flooded pixels: {flood_pixels:,}")
            print(f"   Total pixels: {total_pixels:,}")
            print(f"   Flooded area: {flood_percent:.2f}%")

‚ö†Ô∏è Trained model not found. Please complete training first.

üìä FLOOD STATISTICS:


In [10]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss # Loss function for binary segmentation

# --- 1. Custom Dataset Definition (Crucial for PyTorch) ---
class FloodDataset(Dataset):
    def __init__(self, input_tensors, label_tensors):
        # input_tensors are your X (HLS images)
        self.inputs = input_tensors 
        # label_tensors are your Y (Ground Truth Masks)
        self.labels = label_tensors 

    def __len__(self):
        # The length of the dataset is the number of chips
        return len(self.inputs)

    def __getitem__(self, idx):
        # Return a pair of (Image, Label) for the training loop
        return self.inputs[idx], self.labels[idx]


# --- 2. Training Loop Setup ---
def setup_training(model, train_inputs, train_labels):
    # A. Create Dataset and DataLoader
    train_dataset = FloodDataset(train_inputs, train_labels)
    # DataLoader manages batching and shuffling
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True) 

    # B. Define Loss and Optimizer
    criterion = BCEWithLogitsLoss() # Good loss function for binary segmentation
    optimizer = Adam(model.parameters(), lr=1e-5) # Use a small learning rate for fine-tuning

    # C. Start Training (Conceptual loop structure)
    num_epochs = 10
    print(f"Starting fine-tuning for {num_epochs} epochs...")
    
    for epoch in range(num_epochs):
        model.train() # Set model to training mode
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            
            # --- FORWARD PASS (Model Prediction) ---
            outputs = model(inputs) 
            
            # --- BACKWARD PASS (Learning) ---
            loss = criterion(outputs, targets.float())
            loss.backward()
            optimizer.step()
        
        print(f"Epoch {epoch+1} complete. Loss: {loss.item():.4f}")


# --- 3. The FINAL STEP (Inference) ---
# After training, you replace model.train() with model.eval() and run inference 
# on the remaining (unlabeled) chips to generate your final flood masks.


In [12]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from torchmetrics.classification import BinaryJaccardIndex # For IoU metric
from torch.optim import Adam
import numpy as np
# from your notebook: from terratorch.registry import BACKBONE_REGISTRY 
# from your notebook: from PrithviFloodSegmentationModel import PrithviFloodSegmentationModel 


# =================================================================
# I. HYPERPARAMETERS & CONFIGURATION
# =================================================================
# NOTE: Replace 'your-project-id' with a unique identifier
PROJECT_NAME = 'Flood_Mapping_Assam_HLS' 
NUM_EPOCHS = 20
BATCH_SIZE = 4
LEARNING_RATE = 1e-5


# =================================================================
# II. DATA PIPELINE (The glue for your tensors)
# =================================================================

class FloodDataset(Dataset):
    """Dataset for pairing HLS input tensors (X) with Ground Truth label tensors (Y)."""
    def __init__(self, input_tensors_list, label_tensors_list):
        # input_tensors_list should be a list of (1, 6, 512, 512) HLS chips
        self.inputs = input_tensors_list 
        # label_tensors_list should be a list of (1, 1, 512, 512) binary masks
        self.labels = label_tensors_list

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # We need to squeeze the batch dimension (0) for PyTorch Lightning, 
        # as the DataLoader will add it back.
        X = self.inputs[idx].squeeze(0) 
        Y = self.labels[idx].squeeze(0).long() # Labels must be Long type for loss function
        return X, Y


# NOTE: For this step, you must load your ground truth labels manually 
# and populate the following two lists before calling the trainer.
# train_inputs_list = [t for d in processed_tensors for t in d['pre_flood']]
# train_labels_list = [load_label(path) for path in label_paths]


# =================================================================
# III. FINE-TUNING MODULE (The brains of the operation)
# =================================================================

class PrithviFloodModule(pl.LightningModule):
    def __init__(self, output_classes=2):
        super().__init__()
        # Initialize the model structure you defined in the previous step
        self.model = PrithviFloodSegmentationModel(output_classes=output_classes)
        
        # Define the Loss Function (Binary Cross-Entropy + Dice Loss is standard for segmentation)
        # Note: BCEWithLogitsLoss is robust and combines Sigmoid + BCE
        self.criterion = nn.BCEWithLogitsLoss() 
        
        # Define the Evaluation Metric (IoU is Jaccard Index for binary problems)
        self.iou_metric = BinaryJaccardIndex().to(self.device)
        
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Model output is logits (before sigmoid)
        logits = self.forward(x) 
        
        # Reshape the label tensor to match the output shape for loss calculation
        # [B, 2, H, W] vs [B, 1, H, W]. Use only the background/foreground channel for loss.
        loss = self.criterion(logits[:, 1], y.float()) # Target: [B, H, W]
        
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        
        # Apply sigmoid to logits to get probabilities, then round to get prediction (0 or 1)
        preds = torch.sigmoid(logits[:, 1]).round() 
        
        val_loss = self.criterion(logits[:, 1], y.float())
        self.log('val_loss', val_loss)
        
        # Log IoU metric
        self.iou_metric.update(preds, y)
        self.log('val_iou', self.iou_metric, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

# =================================================================
# IV. EXECUTION
# =================================================================

def run_fine_tuning_pipeline(train_inputs, train_labels):
    # 1. Setup DataModule/Dataset
    train_dataset = FloodDataset(train_inputs, train_labels)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    # 2. Instantiate Model and Lightning Module
    model_module = PrithviFloodModule(learning_rate=LEARNING_RATE, output_classes=2)
    
    # 3. Setup Trainer
    trainer = pl.Trainer(
        max_epochs=NUM_EPOCHS,
        logger=pl.loggers.CSVLogger(save_dir='logs/', name=PROJECT_NAME),
        callbacks=[pl.callbacks.ModelCheckpoint(monitor='val_iou', mode='max')]
    )
    
    # 4. Start Fine-Tuning
    trainer.fit(model_module, train_loader)
    print("‚úÖ Fine-Tuning Complete. Model weights saved to logs/ directory.")

# --- NEXT ACTION ---
# You must execute Step 1 (create labels) and then run this pipeline.
# Example Call (Conceptual - requires actual data lists):
# run_fine_tuning_pipeline(your_input_list, your_label_list)

In [13]:
# Assuming you have loaded your Ground Truth labels into a list called train_labels_list
# and your image inputs into a list called train_inputs_list
# (These lists are created by iterating over your local file system).

# --- FINAL EXECUTION ---
# You must ensure the two lists contain matching tensors before running.
# Example: train_inputs_list = [t for d in processed_tensors for t in d['pre_flood']]
#          train_labels_list = [load_label(path) for path in your_label_files]

def execute_final_training(train_inputs_list, train_labels_list):
    if not train_inputs_list or not train_labels_list:
        print("‚ùå Error: Input or Label lists are empty. Cannot start training.")
        return

    # 1. Setup Data
    # This calls the FloodDataset and DataLoader you defined previously.
    train_dataset = FloodDataset(train_inputs_list, train_labels_list)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    # 2. Instantiate Model (Prithvi-600M wrapped in the Lightning Module)
    # The PrithviFloodModule class must be defined in your script environment.
    model_module = PrithviFloodModule(learning_rate=1e-5, output_classes=2)

    # 3. Setup Trainer
    trainer = pl.Trainer(
        max_epochs=NUM_EPOCHS,
        logger=pl.loggers.CSVLogger(save_dir='logs/', name=PROJECT_NAME),
        callbacks=[pl.callbacks.ModelCheckpoint(monitor='val_iou', mode='max')]
    )

    # 4. Start Fine-Tuning
    trainer.fit(model_module, train_loader)
    print("‚úÖ Fine-Tuning Complete. Check the 'logs/' directory for model weights.")

# NOTE: You must call this function with your actual loaded lists:
# execute_final_training(your_inputs_list, your_labels_list)

# **Chipping Images**

In [17]:
import rasterio
from rasterio.windows import Window
import numpy as np
import torch
from rasterio.merge import merge


In [None]:

# Define the input and output chip size
CHIP_SIZE = 512 

def chip_image(input_filepath, output_directory):
    """Cuts a large GeoTIFF into smaller chips."""
    with rasterio.open(input_filepath) as src:
        # Get the width and height of the entire image
        width = src.width
        height = src.height
        
        count = 0
        # Loop through the image in chunks of CHIP_SIZE
        for i in range(0, height, CHIP_SIZE):
            for j in range(0, width, CHIP_SIZE):
                
                # Define the window (area) to read from the large image
                # The window accounts for the edges where the size might be less than CHIP_SIZE
                window = Window(j, i, min(CHIP_SIZE, width - j), min(CHIP_SIZE, height - i))
                transform = src.window_transform(window)
                
                # Read the data from the defined window
                chip_data = src.read(1, window=window)

                # Skip if the chip contains mostly "no data" values (e.g., beyond your AOI)
                if np.sum(chip_data == src.nodata) / chip_data.size > 0.95:
                    continue

                # Update the metadata profile for the new small chip file
                profile = src.profile
                profile.update({
                    'height': window.height,
                    'width': window.width,
                    'transform': transform
                })
                
                # Save the chip
                output_path = f"{output_directory}/{src.name.split('/')[-1].replace('.tif', '')}_chip_{count}.tif"
                with rasterio.open(output_path, 'w', **profile) as dst:
                    dst.write(chip_data, 1)
                
                count += 1
                
    print(f"Successfully chipped {input_filepath} into {count} tiles.")


In [None]:
# =================================================================
# GLOBAL CONFIGURATION
# =================================================================

# IMPORTANT: Use 'r' strings for Windows paths to avoid SyntaxWarnings/Errors
TIF_DIR = r'C:\Kaam_Dhanda\Minor_Project\Old Images' 
OUTPUT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips' 

# List of districts to process
DISTRICTS = ['Barpeta', 'Dhemaji', 'Lakhimpur', 'Nalbari', 'Sonitpur']
CHIP_SIZE = 512 # Standard size for deep learning input (e.g., 512x512 pixels)


# =================================================================
# CHIPPING FUNCTION (Core Logic)
# =================================================================

def chip_image(input_filepath, output_directory):
    """Cuts a large GeoTIFF into smaller, non-overlapping chips."""
    
    # 1. Safely open the input image
    try:
        src = rasterio.open(input_filepath)
    except rasterio.RasterioIOError as e:
        print(f"Error opening input file {input_filepath}: {e}")
        return

    width = src.width
    height = src.height
    count = 0
    
    # Loop through the image in chunks of CHIP_SIZE
    for i in range(0, height, CHIP_SIZE):
        for j in range(0, width, CHIP_SIZE):
            
            # Define the window (area) to read from the large image
            window = Window(j, i, min(CHIP_SIZE, width - j), min(CHIP_SIZE, height - i))
            transform = src.window_transform(window)
            
            # Read the data from the defined window (assuming single band: 'VV')
            chip_data = src.read(1, window=window)

            # Skip if the chip contains mostly 'no data' values (e.g., beyond your AOI)
            if np.sum(chip_data == src.nodata) / chip_data.size > 0.95:
                continue

            # Update the metadata profile for the new small chip file
            profile = src.profile
            profile.update({
                'height': window.height,
                'width': window.width,
                'transform': transform,
                'count': 1, # Ensure the profile reflects a single band
                'compress': 'LZW' # Optional: Add compression to reduce chip size
            })
            
            # --- CRITICAL FIX: Robust Output Path Construction ---
            # 1. Get the base filename (e.g., 'Barpeta_PreFlood_Image.tif')
            base_filename = os.path.basename(input_filepath)
            
            # 2. Remove the '.tif' extension for the chip name stem
            file_stem = base_filename.replace('.tif', '')

            # 3. Construct the final output path using os.path.join()
            chip_filename = f'{file_stem}_chip_{count}.tif'
            output_path = os.path.join(output_directory, chip_filename)
            
            # 4. Save the chip
            try:
                with rasterio.open(output_path, 'w', **profile) as dst:
                    dst.write(chip_data, 1)
                count += 1
            except rasterio.RasterioIOError as e:
                 print(f"Failed to write chip {output_path}: {e}")

    src.close()
    print(f"‚úÖ Successfully chipped {input_filepath} into {count} tiles.")


# =================================================================
# MAIN EXECUTION LOGIC
# =================================================================

files_to_chip = {}

# --- Generate File Pairs ---
for district in DISTRICTS:
    pre_file = os.path.join(TIF_DIR, f'{district}_PreFlood_Image.tif')
    post_file = os.path.join(TIF_DIR, f'{district}_PostFlood_Image.tif')
    
    if os.path.exists(pre_file) and os.path.exists(post_file):
        files_to_chip[district] = {
            'pre_flood': pre_file,
            'post_flood': post_file
        }
    else:
        print(f"‚ö†Ô∏è Skipping {district}: One or both primary files were not found.")
        
print(f"Successfully prepared {len(files_to_chip)} district pairs for chipping.")


# --- Run Chipping Process ---
for district, files in files_to_chip.items():
    print(f"\n--- Chipping files for {district} ---")
    
    # Define the output directories
    pre_output_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'pre_flood')
    post_output_dir = os.path.join(OUTPUT_CHIPS_DIR, district, 'post_flood')

    # Create the output directories if they don't exist
    os.makedirs(pre_output_dir, exist_ok=True)
    os.makedirs(post_output_dir, exist_ok=True)
    
    # Run chipping for the pre-flood image
    chip_image(
        input_filepath=files['pre_flood'],
        output_directory=pre_output_dir
    )
    
    # Run chipping for the post-flood image
    chip_image(
        input_filepath=files['post_flood'],
        output_directory=post_output_dir
    )

Successfully prepared 5 district pairs for chipping.

--- Chipping files for Barpeta ---
‚úÖ Successfully chipped C:\Kaam_Dhanda\Minor_Project\Old Images\Barpeta_PreFlood_Image.tif into 96 tiles.
‚úÖ Successfully chipped C:\Kaam_Dhanda\Minor_Project\Old Images\Barpeta_PostFlood_Image.tif into 96 tiles.

--- Chipping files for Dhemaji ---
‚úÖ Successfully chipped C:\Kaam_Dhanda\Minor_Project\Old Images\Dhemaji_PreFlood_Image.tif into 104 tiles.
‚úÖ Successfully chipped C:\Kaam_Dhanda\Minor_Project\Old Images\Dhemaji_PostFlood_Image.tif into 104 tiles.

--- Chipping files for Lakhimpur ---
‚úÖ Successfully chipped C:\Kaam_Dhanda\Minor_Project\Old Images\Lakhimpur_PreFlood_Image.tif into 324 tiles.
‚úÖ Successfully chipped C:\Kaam_Dhanda\Minor_Project\Old Images\Lakhimpur_PostFlood_Image.tif into 324 tiles.

--- Chipping files for Nalbari ---
‚úÖ Successfully chipped C:\Kaam_Dhanda\Minor_Project\Old Images\Nalbari_PreFlood_Image.tif into 80 tiles.
‚úÖ Successfully chipped C:\Kaam_Dhanda\M

In [3]:
import numpy as np
import rasterio
import torch
import os

# =======================================================================
# CONFIGURATION
# =======================================================================

# IMPORTANT: SET YOUR ROOT DIRECTORY HERE
ROOT_CHIPS_DIR = r'C:\Kaam_Dhanda\Minor_Project\Output_chips'

# Sentinel-1 Normalization Parameters for VV (based on common practice)
# NOTE: These are general values. For maximum accuracy, check the specific
# Prithvi-600m documentation for its exact SAR data normalization.
SAR_NORM_MEAN = -15.0  # Common mean for VV dB values
SAR_NORM_STD = 5.0    # Common standard deviation for VV dB values

# Dictionary to store all processed tensors
processed_tensors = {}

# =======================================================================
# CORE PROCESSING FUNCTION
# =======================================================================

def preprocess_sar_chip(file_path, sar_mean, sar_std):
    """
    Reads a single-band SAR GeoTIFF, standardizes it, and converts it
    to a PyTorch Tensor (1, C=1, H, W) for model inference.
    """
    try:
        with rasterio.open(file_path) as src:
            # Read the single band (VV)
            data = src.read(1).astype(np.float32)
            
            # Check for empty data / no-data values
            if np.all(data == src.nodata):
                return None

    except rasterio.RasterioIOError:
        print(f"Error: Could not open or read {file_path}. Skipping.")
        return None

    # 1. Standardization (Z-Score Normalization)
    # Apply Z-score: (Data - Mean) / Std Dev
    normalized_data = (data - sar_mean) / sar_std

    # 2. Convert to PyTorch Tensor
    # Reshape from (H, W) to (C, H, W) -> (1, H, W)
    tensor = torch.from_numpy(normalized_data).unsqueeze(0)
    
    # Add a batch dimension, making the shape (1, C, H, W) -> (1, 1, H, W)
    tensor = tensor.unsqueeze(0) 

    return tensor

# =======================================================================
# BATCH EXECUTION
# =======================================================================

print(f"Starting batch pre-processing from: {ROOT_CHIPS_DIR}")

# Iterate through all district folders (Barpeta, Dhemaji, etc.)
for district_name in os.listdir(ROOT_CHIPS_DIR):
    district_path = os.path.join(ROOT_CHIPS_DIR, district_name)
    
    if not os.path.isdir(district_path):
        continue

    processed_tensors[district_name] = {'pre_flood': [], 'post_flood': []}
    print(f"\n--- Processing District: {district_name} ---")

    # Iterate through 'pre_flood' and 'post_flood' folders
    for phase in ['pre_flood', 'post_flood']:
        phase_path = os.path.join(district_path, phase)
        
        if not os.path.isdir(phase_path):
            continue

        # Process all .tif files (image chips) in the phase folder
        for chip_filename in os.listdir(phase_path):
            if chip_filename.endswith('.tif'):
                chip_file_path = os.path.join(phase_path, chip_filename)
                
                # Run the core pre-processing function
                tensor = preprocess_sar_chip(
                    chip_file_path, SAR_NORM_MEAN, SAR_NORM_STD
                )
                
                if tensor is not None:
                    # Store the resulting tensor
                    processed_tensors[district_name][phase].append(tensor)
                    # print(f"    Processed: {chip_filename}")

# =======================================================================
# FINAL CHECK
# =======================================================================

print("\n=======================================================")
print("‚úÖ Batch Pre-processing Complete.")
print("=======================================================")

# Print the final structure for verification
for district, phases in processed_tensors.items():
    print(f"District: {district}")
    for phase, tensors in phases.items():
        if tensors:
            # Check the shape of the first tensor in the list
            print(f"  {phase}: {len(tensors)} chips, each with shape {tensors[0].shape}")
        else:
            print(f"  {phase}: 0 chips found.")

# The 'processed_tensors' dictionary now holds all your data ready for the Prithvi model.


Starting batch pre-processing from: C:\Kaam_Dhanda\Minor_Project\Output_chips

--- Processing District: Barpeta ---

--- Processing District: Dhemaji ---

--- Processing District: Lakhimpur ---

--- Processing District: Nalbari ---

--- Processing District: Sonitpur ---

‚úÖ Batch Pre-processing Complete.
District: Barpeta
  pre_flood: 96 chips, each with shape torch.Size([1, 1, 512, 512])
  post_flood: 96 chips, each with shape torch.Size([1, 1, 512, 512])
District: Dhemaji
  pre_flood: 104 chips, each with shape torch.Size([1, 1, 512, 512])
  post_flood: 104 chips, each with shape torch.Size([1, 1, 512, 512])
District: Lakhimpur
  pre_flood: 324 chips, each with shape torch.Size([1, 1, 512, 512])
  post_flood: 324 chips, each with shape torch.Size([1, 1, 512, 512])
District: Nalbari
  pre_flood: 80 chips, each with shape torch.Size([1, 1, 512, 512])
  post_flood: 80 chips, each with shape torch.Size([1, 1, 512, 512])
District: Sonitpur
  pre_flood: 198 chips, each with shape torch.Si

## **run the temporal AI inference and then stitch the predictions back together**

In [14]:

def prepare_chip_pair(pre_path, post_path):
    with rasterio.open(pre_path) as src_pre, rasterio.open(post_path) as src_post:
        # Load data as NumPy arrays (assuming single band, VV polarization)
        pre_chip = src_pre.read(1)
        post_chip = src_post.read(1)
        
        # Stack them to create the temporal input (e.g., shape: 2, 512, 512)
        temporal_input = np.stack([pre_chip, post_chip], axis=0)
        
        # Convert to PyTorch Tensor, add a batch dimension (1), and move to GPU (if available)
        tensor_input = torch.from_numpy(temporal_input).float().unsqueeze(0)
        
        # Store the geospatial profile for later stitching
        profile = src_pre.profile
        
    return tensor_input, profile

In [15]:

def run_inference_and_save(pre_path, post_path, output_mask_dir):
    tensor_input, profile = prepare_chip_pair(pre_path, post_path)
    
    # 1. Run the prediction
    # model.eval() is required for inference mode
    with torch.no_grad():
        # output is typically a logit map (e.g., shape: 1, num_classes, 512, 512)
        output_logits = model(tensor_input) 
        
    # 2. Get the final classification (0 or 1)
    # This finds the class with the highest probability (e.g., 0=not-flood, 1=flood)
    # Reshape and convert back to a NumPy array (shape: 512, 512)
    predicted_mask_tensor = torch.argmax(output_logits, dim=1).squeeze().cpu()
    predicted_mask_array = predicted_mask_tensor.numpy().astype(rasterio.uint8)
    
    # 3. Save the prediction mask
    chip_filename = os.path.basename(pre_path).replace('PreFlood_Image', 'Flood_Mask')
    output_path = os.path.join(output_mask_dir, chip_filename)
    
    # Update profile to reflect the new data type (binary mask)
    profile.update(dtype=rasterio.uint8, count=1) 
    
    with rasterio.open(output_path, 'w', **profile) as dst:
        dst.write(predicted_mask_array, 1)
    
    return output_path

In [None]:

def stitch_masks(mask_dir, district_name, final_output_dir):
    """Stitches all predicted flood mask chips into a single GeoTIFF."""
    
    mask_files = [os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if f.endswith('.tif')]
    
    # Open all mask datasets
    sources = [rasterio.open(f) for f in mask_files]
    
    # Use rasterio.merge to create a mosaic
    stitched_array, out_transform = merge(sources)
    
    # Get the metadata from the first source file
    out_meta = sources[0].profile.copy()
    
    # Update the metadata for the merged output
    out_meta.update({
        "driver": "GTiff",
        "height": stitched_array.shape[1],
        "width": stitched_array.shape[2],
        "transform": out_transform,
        "count": 1,
        "dtype": 'uint8'
    })
    
    # Write the final stitched GeoTIFF
    final_output_path = os.path.join(final_output_dir, f'{district_name}_Final_Flood_Mask.tif')
    with rasterio.open(final_output_path, "w", **out_meta) as dest:
        dest.write(stitched_array)
        
    # Close all source files
    for src in sources:
        src.close()
        
    print(f"‚úÖ Final stitched mask saved to: {final_output_path}")