Note:

You must have a folder named "SAR_model_data" in your Google Drive with the following tree:

SAR_model_data

---->Prelabelled

-------->Unmasked, raw images in .tiff format. Must be 256x256, scaled to 0-255 and stored as uint8

---->Labelled

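-------->Masked, labelled images in .tiff format. Must be 256x256, with values in 0-1, stored as float32. Each mask must have exactly the same filename as its raw image, because pairs are matched by filename.


On Colab the expected location is "/content/drive/MyDrive/SAR_model_data" (mount Google Drive first and point `Config.DATA_DIR` at this path).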

In [1]:
# Report the installed PyTorch build and whether a CUDA device is visible.
import torch

cuda_ok = torch.cuda.is_available()
print(f"✅ PyTorch: {torch.__version__}")
print(f"✅ CUDA available: {cuda_ok}")

if not cuda_ok:
    print("❌ GPU not available - make sure you selected GPU in Runtime settings")
else:
    # Device 0 is the only GPU queried; total_memory is reported in bytes.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✅ GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")

✅ PyTorch: 2.8.0+cu126
✅ CUDA available: True
✅ GPU: Tesla T4
✅ GPU Memory: 15.8 GB


In [5]:
# Install the third-party packages that the import cell below depends on.
# NOTE(review): versions are unpinned, so a later run may resolve different
# releases than the ones shown in the captured pip log.
# NOTE(review): the shell escape does not appear to raise on failure, so the
# success message below prints regardless — check the pip log to confirm.
!pip install rasterio segmentation-models-pytorch albumentations opencv-python matplotlib torchgeo
print("✅ All packages installed!")


Collecting torchgeo
  Downloading torchgeo-0.7.1-py3-none-any.whl.metadata (18 kB)
Collecting fiona>=1.8.22 (from torchgeo)
  Downloading fiona-1.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.6/56.6 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting kornia>=0.7.4 (from torchgeo)
  Downloading kornia-0.8.1-py2.py3-none-any.whl.metadata (17 kB)
Collecting lightly!=1.4.26,>=1.4.5 (from torchgeo)
  Downloading lightly-1.5.22-py3-none-any.whl.metadata (38 kB)
Collecting lightning!=2.3.*,!=2.5.0,>=2 (from lightning[pytorch-extra]!=2.3.*,!=2.5.0,>=2->torchgeo)
  Downloading lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
Collecting rtree>=1.0.1 (from torchgeo)
  Downloading rtree-1.4.1-py3-none-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (2.1 kB)
Collecting torchmetrics>=1.2 (from torchgeo)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collectin

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import segmentation_models_pytorch as smp
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import rasterio
from sklearn.model_selection import train_test_split
import glob
from tqdm.auto import tqdm
from torchgeo.models import resnet50



print("✅ All packages imported successfully!")


class Config:
    """Hyper-parameters and paths shared by the dataset, model and trainer cells."""

    BACKBONE = 'efficientnet-b3'   # encoder name handed to create_model
    NUM_CLASSES = 1                # number of output channels of the model
    BATCH_SIZE = 32                # used for both the train and validation loaders
    LEARNING_RATE = 0.0001         # initial learning rate of the optimizer
    NUM_EPOCHS = 40                # upper bound on training epochs
    IMAGE_SIZE = 256               # recorded in checkpoints
    # Prefer the GPU when one is visible, otherwise fall back to the CPU.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #DEVICE = torch.device('cpu')
    DATA_DIR = "./SAR_model_data"          # must contain Prelabelled/ and Labelled/
    CHECKPOINT_DIR = "./lake_checkpoints"  # where model checkpoints are written


config = Config()
print(f"Using device: {config.DEVICE}")
# Querying the GPU name raises when CUDA is unavailable, so only do it when
# the selected device really is a CUDA device.
if config.DEVICE.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: none (running on CPU)")

✅ All packages imported successfully!
Using device: cuda
GPU: Tesla T4


In [None]:
# Check what we have
print("📁 Checking uploaded files...")

# Resolve both folders from the shared config so this cell inspects the same
# data the trainer reads, instead of an absolute path valid on one machine only.
labelled_path = os.path.join(config.DATA_DIR, "Labelled")
prelabelled_path = os.path.join(config.DATA_DIR, "Prelabelled")

# List a folder only when it exists. Falling back to an empty list keeps the
# "not found" branches reachable and keeps both names defined for later code.
labelled_files = os.listdir(labelled_path) if os.path.isdir(labelled_path) else []
prelabelled_files = os.listdir(prelabelled_path) if os.path.isdir(prelabelled_path) else []

# Report each folder: file count plus the first 10 names, or a missing notice.
for label, folder, files in (
    ("Labeled", labelled_path, labelled_files),
    ("Unmasked", prelabelled_path, prelabelled_files),
):
    if os.path.isdir(folder):
        print(f"✅ {label} lake: {len(files)} files")
        for name in files[:10]:
            print(f"   📄 {name}")
    else:
        print(f"❌ {label} lake folder not found")

# Check file types
print("\n🔍 Checking file types...")
if os.path.isdir(labelled_path):
    # "*.tif*" matches both the .tif and .tiff spellings.
    tiff_files = glob.glob(os.path.join(labelled_path, "*.tif*"))
    png_files = glob.glob(os.path.join(labelled_path, "*.png"))
    jpg_files = glob.glob(os.path.join(labelled_path, "*.jpg"))
    print(f"   TIFF files: {len(tiff_files)}")
    print(f"   PNG files: {len(png_files)}")
    print(f"   JPG files: {len(jpg_files)}")

def check_file_structure(filepath):
    """Print basic raster metadata for one file and return ``(bands, shape)``.

    Opens ``filepath`` with rasterio, prints the band count, height x width and
    the dtype of the first band, then reads band 1.

    Returns:
        ``(src.count, band1.shape)`` on success, or ``(None, None)`` after
        printing the error if anything in the open/read sequence fails.
    """
    try:
        with rasterio.open(filepath) as src:
            print(f"📄 {os.path.basename(filepath)}:")
            print(f"   Bands: {src.count}")
            print(f"   Shape: {src.height} x {src.width}")
            print(f"   Dtype: {src.dtypes[0]}")

            # Reading band 1 confirms the pixel data is actually readable and
            # supplies the array shape that is returned to the caller.
            band1 = src.read(1)

            return src.count, band1.shape
    except Exception as e:
        # Best-effort diagnostic: report the problem and keep scanning.
        print(f"   Error: {e}")
        return None, None

# Inspect every GeoTIFF in both folders. Both the ".tif" and ".tiff" spellings
# are accepted, so files that the trainer would pick up are not skipped here.
for heading, folder, filenames in (
    ("\n🔍 Checking Prelabelled file structure...", prelabelled_path, prelabelled_files),
    ("\n🔍 Checking Labelled file structure...", labelled_path, labelled_files),
):
    print(heading)
    for filename in filenames:
        if filename.endswith(('.tif', '.tiff')):
            filepath = os.path.join(folder, filename)
            check_file_structure(filepath)
            print("---")

📁 Checking uploaded files...
✅ Labeled lake: 245 files
   📄 S1A_IW_20210828T123039_DVP_RTC20_G_gpufed_9809_VV_clipped_to_tilichoTshoAOI.tif
   📄 S1A_IW_20240729T001133_DVP_RTC20_G_gpufed_53BD_VV.tif_clipped_to_chamlangTshoAOI.geojson.tif
   📄 S1A_IW_20250724T001125_DVP_RTC20_G_gpufed_2009_VV.tif_clipped_to_chamlangTshoAOI.geojson.tif
   📄 S1A_IW_20220926T001133_DVP_RTC20_G_gpufed_7367_VV.tif_clipped_to_tshoRolpaAOI.geojson.tif
   📄 S1A_IW_20210323T001129_DVP_RTC20_G_gpufed_AD1C_VV.tif_clipped_to_chamlangTshoAOI.geojson.tif
   📄 S1A_IW_20230720T122237_DVP_RTC20_G_gpufed_536B_VV.tif_clipped_to_tshoRolpaAOI.geojson.tif
   📄 S1A_IW_20250323T122230_DVP_RTC20_G_gpufed_6144_VV.tif_clipped_to_imjaTshoAOI.geojson.tif
   📄 S1A_IW_20230925T121409_DVP_RTC20_G_gpufed_C3D9_VV.tif_clipped_to_gokyoTshoAOI.geojson.tif
   📄 S1A_IW_20210413T122220_DVP_RTC20_G_gpufed_9EC5_VV.tif_clipped_to_tshoRolpaAOI.geojson.tif
   📄 S1A_IW_20240328T122236_DVP_RTC20_G_gpufed_689A_VV.tif_clipped_to_tshoRolpaAOI.geojson.t

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


---
📄 S1A_IW_20240328T122236_DVP_RTC20_G_gpufed_689A_VV.tif_clipped_to_imjaTshoAOI.geojson.tif:
   Bands: 1
   Shape: 256 x 256
   Dtype: float32
---
📄 S1A_IW_20220728T001131_DVP_RTC20_G_gpufed_610B_VV.tif_clipped_to_chamlangTshoAOI.geojson.tif:
   Bands: 1
   Shape: 256 x 256
   Dtype: float32
---
📄 S1A_IW_20210313T123031_DVP_RTC20_G_gpufed_0242_VV_clipped_to_tilichoTshoAOI.tif:
   Bands: 1
   Shape: 256 x 256
   Dtype: float32
---
📄 S1A_IW_20250227T122229_DVP_RTC20_G_gpufed_5DE1_VV.tif_clipped_to_tshoRolpaAOI.geojson.tif:
   Bands: 1
   Shape: 256 x 256
   Dtype: float32
---
📄 S1A_IW_20210823T122228_DVP_RTC20_G_gpufed_BB91_VV.tif_clipped_to_imjaTshoAOI.geojson.tif:
   Bands: 1
   Shape: 256 x 256
   Dtype: float32
---
📄 S1A_IW_20240726T122235_DVP_RTC20_G_gpufed_48F5_VV.tif_clipped_to_tshoRolpaAOI.geojson.tif:
   Bands: 1
   Shape: 256 x 256
   Dtype: float32
---
📄 S1A_IW_20230825T122239_DVP_RTC20_G_gpufed_9269_VV.tif_clipped_to_tshoRolpaAOI.geojson.tif:
   Bands: 1
   Shape: 256 x 25

In [10]:
class SegmentationDataset(Dataset):
    """Paired single-band image / binary-mask dataset read from raster files.

    Every source pair is exposed ``num_augmentations`` times, so random
    augmentations in ``transform`` yield several variants of it per epoch.

    Args:
        image_paths: paths of the input rasters (band 1 is read).
        mask_paths: paths of the label rasters, aligned index-by-index with
            ``image_paths``.
        transform: optional albumentations pipeline called as
            ``transform(image=..., mask=...)``. When ``None`` the arrays are
            converted to tensors directly.
        image_size: expected height and width of every returned sample.
        num_augmentations: how many times each pair appears in the dataset.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, image_paths, mask_paths, transform=None, image_size=256, num_augmentations=1):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and mask_paths ({len(mask_paths)}) must have the same length"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.image_size = image_size
        self.num_augmentations = num_augmentations

        # Length is scaled by num_augmentations
        self.length = len(image_paths) * num_augmentations

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as float32 tensors of shape (1, S, S) and (S, S)."""
        # Indices beyond the number of files wrap around to the same pairs.
        original_idx = idx % len(self.image_paths)

        img_path = self.image_paths[original_idx]
        mask_path = self.mask_paths[original_idx]

        # 1. READ IMAGE AND MASK (first band only)
        with rasterio.open(img_path) as src:
            image = src.read(1)

        with rasterio.open(mask_path) as src:
            mask = src.read(1)

        # 2. PREPROCESS
        if mask.ndim > 2:
            mask = np.squeeze(mask)
        if image.ndim == 2:
            image = image[..., np.newaxis]  # HW -> HWC with a single channel
        # Snap the mask to {0, 1} and force float32 so integer-typed masks
        # also satisfy the dtype check below.
        mask = np.round(mask).astype(np.float32)

        # Scale, then standardise each patch by its own mean/std; epsilon
        # guards against division by zero on constant patches.
        image = image.astype(np.float32) / 255.0
        patch_mean = image.mean()
        patch_std = image.std()
        epsilon = 1e-6
        image = (image - patch_mean) / (patch_std + epsilon)

        # 3. AUGMENTATION / TENSOR CONVERSION
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        else:
            # No pipeline supplied: build the same layout the checks below
            # require (channel-first image tensor, 2-D mask tensor).
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        size = self.image_size
        assert image.shape == (1, size, size), f"Image shape mismatch: {image.shape}"
        assert mask.shape == (size, size), f"Mask shape mismatch: {mask.shape}"
        assert image.dtype == torch.float32, f"Image dtype mismatch: {image.dtype}"
        assert mask.dtype == torch.float32, f"Mask dtype mismatch: {mask.dtype}"

        return image, mask

    @staticmethod
    def get_train_transform(image_size=256):
        """Training pipeline: resize, random flip/rotate/elastic/noise, then to tensor."""
        # Stronger augmentation to increase diversity
        return A.Compose([
            A.Resize(image_size, image_size),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=90, p=0.5, border_mode=cv2.BORDER_CONSTANT),
            A.ElasticTransform(alpha=1, sigma=4, p=0.3),
            A.GaussNoise(p=0.3),
            ToTensorV2()
        ], is_check_shapes=False)

    @staticmethod
    def get_val_transform(image_size=256):
        """Validation pipeline: deterministic resize followed by tensor conversion."""
        return A.Compose([
            A.Resize(image_size, image_size),
            ToTensorV2(),
        ], is_check_shapes=False)


In [11]:
# Shared constants for the binary segmentation setup.
NUM_CLASSES = 1     # one output channel
INPUT_CHANNELS = 1  # one input channel


def create_model(num_classes=NUM_CLASSES, backbone='efficientnet-b3'):
    """Build an ``smp.Unet`` with an ImageNet-initialised encoder.

    Args:
        num_classes: number of output channels of the segmentation head.
        backbone: encoder name understood by segmentation_models_pytorch.

    Returns:
        The constructed model, taking ``INPUT_CHANNELS`` input channels.
    """
    unet_kwargs = {
        'encoder_name': backbone,
        'encoder_weights': 'imagenet',
        'in_channels': INPUT_CHANNELS,
        'classes': num_classes,
    }
    return smp.Unet(**unet_kwargs)

class LakeDetectionLoss(nn.Module):
    """Weighted sum of BCE-with-logits, Dice and Focal losses for binary masks.

    Args:
        alpha: forwarded to ``smp.losses.FocalLoss`` as its ``alpha``.
        bce_weight: multiplier of the BCE term.
        dice_weight: multiplier of the Dice term.
        focal_weight: multiplier of the Focal term.

    The defaults reproduce the fixed 0.4 / 0.3 / 0.3 mix.
    """

    def __init__(self, alpha=0.8, bce_weight=0.4, dice_weight=0.3, focal_weight=0.3):
        super().__init__()
        self.alpha = alpha
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight

        # All three components take raw logits from a 1-channel output.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = smp.losses.DiceLoss(mode='binary')
        self.focal_loss = smp.losses.FocalLoss(mode='binary', alpha=self.alpha)

    def forward(self, outputs, targets):
        """Return the combined scalar loss.

        Args:
            outputs: raw logits with a channel axis at position 1.
            targets: mask with the same number of dimensions as ``outputs``,
                or one fewer (no channel axis); any dtype castable to float.
        """
        targets_float = targets.float()
        # Add the channel axis only when it is missing, so a target that
        # already matches the logits' rank is not turned into a 5-D tensor.
        if targets_float.dim() == outputs.dim() - 1:
            targets_float = targets_float.unsqueeze(1)

        bce = self.bce_loss(outputs, targets_float)
        dice = self.dice_loss(outputs, targets_float)
        focal = self.focal_loss(outputs, targets_float)

        return (self.bce_weight * bce) + (self.dice_weight * dice) + (self.focal_weight * focal)

# Smoke test: build the network and the loss once, then report the model size.
model = create_model(num_classes=NUM_CLASSES)
loss_fn = LakeDetectionLoss()

# Total element count over every parameter tensor of the model.
n_params = sum(param.numel() for param in model.parameters())

print("✅ Model created successfully!")
print(f"📊 Model parameters: {n_params:,}")

config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]

✅ Model created successfully!
📊 Model parameters: 13,158,313


In [23]:
class LakeTrainer:
    """Prepare data, build the model and run the train/validate loop.

    Args:
        config: object exposing ``DATA_DIR``, ``CHECKPOINT_DIR``, ``BACKBONE``,
            ``NUM_CLASSES``, ``BATCH_SIZE``, ``LEARNING_RATE``, ``NUM_EPOCHS``,
            ``IMAGE_SIZE`` and ``DEVICE``.
    """

    def __init__(self, config):
        self.config = config
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.train_loader = None
        self.val_loader = None
        self.best_iou = 0.0
        self.criterion = LakeDetectionLoss()
        # Per-epoch history, (re)filled by train() and also returned by it.
        self.train_losses = []
        self.val_losses = []
        self.val_ious = []

    def find_matching_pairs(self):
        """Return ``(image_paths, mask_paths, filenames)`` for files present in both folders.

        Images are read from ``DATA_DIR/Prelabelled`` and masks from
        ``DATA_DIR/Labelled``; a pair exists when both folders contain a
        ``.tif``/``.tiff`` file with exactly the same name.
        """
        image_dir = os.path.join(self.config.DATA_DIR, "Prelabelled")
        mask_dir = os.path.join(self.config.DATA_DIR, "Labelled")

        # Get all files
        all_images = sorted([f for f in os.listdir(image_dir) if f.endswith(('.tiff', '.tif'))])
        all_masks = sorted([f for f in os.listdir(mask_dir) if f.endswith(('.tiff', '.tif'))])

        print(f"🔍 Found {len(all_images)} images and {len(all_masks)} masks")

        # Sorted intersection gives a deterministic order for the later split.
        common_files = sorted(set(all_images) & set(all_masks))

        print(f"✅ Using {len(common_files)} matched image-mask pairs")
        print(f"❌ Ignoring {len(all_images) - len(common_files)} images without masks")

        # Create paths for common files only
        image_paths = [os.path.join(image_dir, f) for f in common_files]
        mask_paths = [os.path.join(mask_dir, f) for f in common_files]

        return image_paths, mask_paths, common_files

    def prepare_data(self):
        """Split matched pairs 80/20 and build the train/validation loaders.

        Returns:
            ``(train_img, train_mask, val_img, val_mask)`` path lists.

        Raises:
            ValueError: if no pairs, or fewer than 10 pairs, were found.
        """
        image_paths, mask_paths, common_files = self.find_matching_pairs()

        if not image_paths:
            raise ValueError("No matching image-mask pairs found! Please check your files.")

        if len(image_paths) < 10:
            raise ValueError(f"Only {len(image_paths)} pairs found. Need at least 10 for training.")

        # Split data: 80% train, 20% validation (fixed seed for a repeatable split)
        train_img, val_img, train_mask, val_mask = train_test_split(
            image_paths, mask_paths, test_size=0.2, random_state=42, shuffle=True
        )

        print(f"📁 Training samples: {len(train_img)}")
        print(f"📁 Validation samples: {len(val_img)}")

        # Show some examples
        print("\n📸 Sample pairs being used:")
        for i in range(min(3, len(common_files))):
            print(f"   {i+1}. {common_files[i]}")

        train_dataset = SegmentationDataset(
            train_img, train_mask,
            transform=SegmentationDataset.get_train_transform(),
            num_augmentations=20  # multiply training data 20x
        )
        val_dataset = SegmentationDataset(
            val_img, val_mask,
            transform=SegmentationDataset.get_val_transform(),
            num_augmentations=1
        )

        # Test one sample to check data types
        test_image, test_mask = train_dataset[0]
        print(f"✅ Data type check - Image: {test_image.dtype}, Mask: {test_mask.dtype}")
        print(f"✅ Shape check - Image: {test_image.shape}, Mask: {test_mask.shape}")

        # Pinned host memory only helps host-to-GPU copies, so it is requested
        # only when the configured device is a CUDA device.
        use_pinned = torch.device(self.config.DEVICE).type == 'cuda'
        self.train_loader = DataLoader(train_dataset, batch_size=self.config.BATCH_SIZE,
                                       shuffle=True, num_workers=0, pin_memory=use_pinned)
        self.val_loader = DataLoader(val_dataset, batch_size=self.config.BATCH_SIZE,
                                     shuffle=False, num_workers=0, pin_memory=use_pinned)

        return train_img, train_mask, val_img, val_mask

    def setup_model(self):
        """Create the model on the configured device plus its optimizer and LR scheduler."""
        self.model = create_model(self.config.NUM_CLASSES, self.config.BACKBONE)
        self.model.to(self.config.DEVICE)

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.config.LEARNING_RATE,
            weight_decay=1e-5
        )

        # Halve the learning rate after 5 epochs without a lower monitored value
        # (train() feeds it the validation loss).
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

        print(f"✅ Model setup complete: {self.config.BACKBONE} backbone")
        print(f"✅ Using custom LakeDetectionLoss (BCE + Dice + Focal).")
        print(f"✅ Using Optimizer: AdamW.")
        print(f"✅ Using device: {self.config.DEVICE}")

    def calculate_iou(self, outputs, targets):
        """IoU of the positive class, computed over the whole batch at once.

        Args:
            outputs: (B, 1, H, W) raw logits.
            targets: (B, H, W) mask whose positive pixels equal 1.

        Returns:
            A 0-d tensor; 0.0 when neither prediction nor target has any
            positive pixel (empty union).
        """
        # Sigmoid + 0.5 threshold turns logits into a boolean prediction mask.
        preds = torch.sigmoid(outputs).squeeze(1) > 0.5
        true_mask = (targets == 1)

        intersection = (preds & true_mask).float().sum()
        union = (preds | true_mask).float().sum()

        iou_water = intersection / union if union > 0 else torch.tensor(0.0, device=outputs.device)

        return iou_water

    def train_epoch(self, epoch):
        """Run one optimisation pass over the training loader; return the mean batch loss."""
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(self.train_loader, desc=f'🏋️ Epoch {epoch+1}/{self.config.NUM_EPOCHS}')

        for batch_idx, (images, masks) in enumerate(progress_bar):
            images, masks = images.to(self.config.DEVICE), masks.to(self.config.DEVICE)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            # The criterion reshapes/casts the mask itself, so it is passed as-is.
            loss = self.criterion(outputs, masks)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

            progress_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{total_loss/(batch_idx+1):.4f}'
            })

        return total_loss / len(self.train_loader)

    def validate(self, epoch):
        """Evaluate on the validation loader; return ``(mean batch loss, mean batch IoU)``."""
        self.model.eval()
        val_loss = 0
        total_iou = 0

        with torch.no_grad():
            val_bar = tqdm(self.val_loader, desc=f'🧪 Validating Epoch {epoch+1}')
            for images, masks in val_bar:
                images, masks = images.to(self.config.DEVICE), masks.to(self.config.DEVICE)

                outputs = self.model(images)
                # The criterion reshapes/casts the mask itself, so it is passed as-is.
                loss = self.criterion(outputs, masks)
                val_loss += loss.item()

                iou = self.calculate_iou(outputs, masks)
                total_iou += iou.item()

                val_bar.set_postfix({
                    'Val Loss': f'{loss.item():.4f}',
                    'IoU (Water)': f'{iou.item():.4f}'
                })

        # Both averages are unweighted means over batches.
        avg_loss = val_loss / len(self.val_loader)
        avg_iou = total_iou / len(self.val_loader)

        return avg_loss, avg_iou

    def save_checkpoint(self, epoch, val_iou, is_best=False):
        """Write ``epoch_XXX.pth`` and, when ``is_best``, also ``best_lake_model.pth``.

        The per-epoch file holds model and optimizer state plus metadata; the
        best-model file holds the model ``state_dict`` only.
        """
        os.makedirs(self.config.CHECKPOINT_DIR, exist_ok=True)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_iou': val_iou,
            'config': {
                'backbone': self.config.BACKBONE,
                'num_classes': self.config.NUM_CLASSES,
                'image_size': self.config.IMAGE_SIZE
            }
        }
        checkpoint_path = os.path.join(self.config.CHECKPOINT_DIR, f'epoch_{epoch+1:03d}.pth')
        torch.save(checkpoint, checkpoint_path)
        if is_best:
            best_path = os.path.join(self.config.CHECKPOINT_DIR, 'best_lake_model.pth')
            torch.save(self.model.state_dict(), best_path)

    def train(self):
        """Run the full training loop.

        Returns:
            ``(train_losses, val_losses, val_ious, train_img_paths,
            train_mask_paths)``. The three history lists are also kept on the
            instance under the same names.
        """
        print("🚀 Starting Lake Detection Training...")
        print("=" * 60)
        train_img_paths, train_mask_paths, val_img_paths, val_mask_paths = self.prepare_data()
        self.setup_model()
        print(f"🔧 Model: U-Net/DeepLabV3+ with {self.config.BACKBONE}")
        print(f"🎯 Target: Binary Segmentation (Water Class)")
        print(f"📊 Dataset: {len(train_img_paths)} train, {len(val_img_paths)} validation")
        print(f"⚙️  Batch size: {self.config.BATCH_SIZE}, Learning rate: {self.config.LEARNING_RATE}")
        print("=" * 60)

        # A fresh model was just built, so history and the best score start over.
        self.train_losses, self.val_losses, self.val_ious = [], [], []
        self.best_iou = 0.0

        for epoch in range(self.config.NUM_EPOCHS):
            train_loss = self.train_epoch(epoch)
            self.train_losses.append(train_loss)
            val_loss, val_iou = self.validate(epoch)
            self.val_losses.append(val_loss)
            self.val_ious.append(val_iou)
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f'\n📈 Epoch {epoch+1:02d}/{self.config.NUM_EPOCHS} Summary:')
            print(f'   Train Loss: {train_loss:.4f}')
            print(f'   Val Loss:   {val_loss:.4f}')
            print(f'   Val IoU:    {val_iou:.4f}')
            print(f'   LR:         {current_lr:.6f}')

            is_best = val_iou > self.best_iou
            is_periodic = (epoch + 1) % 10 == 0
            if is_best:
                self.best_iou = val_iou
            # One write covers both triggers; the per-epoch file is identical
            # either way, so it is not saved twice for the same epoch.
            if is_best or is_periodic:
                self.save_checkpoint(epoch, val_iou, is_best)
            if is_best:
                print(f'   🎯 NEW BEST! IoU: {val_iou:.4f}')
            if is_periodic:
                print(f'   💾 Checkpoint saved at epoch {epoch+1}')
            print("-" * 50)
            # epoch is 0-based: give up only after more than 31 completed
            # epochs if the IoU is still below 0.3.
            if epoch > 30 and val_iou < 0.3:
                print("🛑 Early stopping - model not learning well")
                break
        print("=" * 60)
        print(f"✅ Training complete! Best IoU: {self.best_iou:.4f}")
        print(f"📁 Models saved in: {self.config.CHECKPOINT_DIR}")
        return self.train_losses, self.val_losses, self.val_ious, train_img_paths, train_mask_paths
# Smoke test: constructing the trainer does not load data or build the model yet.
# Printing it only shows the default object repr.
trainer = LakeTrainer(config)
print(trainer)


<__main__.LakeTrainer object at 0x7febfcf00680>


In [None]:
# Cell 7 - Main Training (Safe version with plotting)
def main():
    """Train the lake-detection model and plot its learning curves.

    Returns:
        The ``LakeTrainer`` instance after training, or ``None`` when the data
        folders are missing or training raised an exception. A failure while
        plotting does not discard the trainer.
    """
    print("🌊 Lake Detection Model Training")
    print("=" * 50)
    print(f"📁 Looking for data at: {config.DATA_DIR}")

    # Check directories
    images_dir = os.path.join(config.DATA_DIR, "Prelabelled")
    masks_dir = os.path.join(config.DATA_DIR, "Labelled")
    if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
        print("❌ Data directories not found!")
        return None

    # Create checkpoint directory
    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
    print(f"✅ Created checkpoint directory: {config.CHECKPOINT_DIR}")

    # Create trainer
    trainer = LakeTrainer(config)

    # Only the training call is guarded here, so the message below is accurate
    # and the traceback shows where training actually failed.
    try:
        result = trainer.train()  # might return None or a tuple
    except Exception as e:
        import traceback
        print(f"❌ An error occurred during training: {e}")
        traceback.print_exc()
        return None

    def _history_from_trainer():
        # Fallback: read the curves from attributes, defaulting to empty lists.
        return (
            getattr(trainer, 'train_losses', []),
            getattr(trainer, 'val_losses', []),
            getattr(trainer, 'val_ious', []),
        )

    # Safely get losses
    if result is None:
        print("ℹ️ Trainer.train() returned None. Using trainer's internal attributes.")
        train_losses, val_losses, val_ious = _history_from_trainer()
    else:
        # Try unpacking first 3 elements
        try:
            train_losses, val_losses, val_ious = result[:3]
        except Exception:
            print("⚠️ Could not unpack result from trainer.train(), using internal attributes instead.")
            train_losses, val_losses, val_ious = _history_from_trainer()

    # Plot results: one panel per curve. A plotting problem is reported but
    # must not throw away a trainer that finished training.
    panels = (
        (train_losses, 'Train Loss', 'Training Loss', 'Loss', None),
        (val_losses, 'Val Loss', 'Validation Loss', 'Loss', 'orange'),
        (val_ious, 'Val IoU', 'Validation IoU', 'IoU', 'green'),
    )
    try:
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (series, label, title, ylabel, color) in zip(axes, panels):
            ax.plot(series, label=label, color=color)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
        fig.tight_layout()
        plt.show()
    except Exception as e:
        print(f"⚠️ Training finished but plotting failed: {e}")

    print("🎉 Training completed successfully!")
    print(f"📁 Models saved in: {config.CHECKPOINT_DIR}")

    return trainer

# Run the whole pipeline; trainer_obj is the LakeTrainer returned by main(), or None if it bailed out.
trainer_obj = main()


🌊 Lake Detection Model Training
📁 Looking for data at: /content/drive/MyDrive/SAR_model_data
✅ Created checkpoint directory: ./lake_checkpoints
🚀 Starting Lake Detection Training...
🔍 Found 282 images and 245 masks
✅ Using 244 matched image-mask pairs
❌ Ignoring 38 images without masks
📁 Training samples: 195
📁 Validation samples: 49

📸 Sample pairs being used:
   1. S1A_IW_20210122T001130_DVP_RTC20_G_gpufed_5E3C_VV.tif_clipped_to_chamlangTshoAOI.geojson.tif
   2. S1A_IW_20210122T001130_DVP_RTC20_G_gpufed_5E3C_VV.tif_clipped_to_gokyoTshoAOI.geojson.tif
   3. S1A_IW_20210122T001130_DVP_RTC20_G_gpufed_5E3C_VV.tif_clipped_to_imjaTshoAOI.geojson.tif
✅ Data type check - Image: torch.float32, Mask: torch.float32
✅ Shape check - Image: torch.Size([1, 256, 256]), Mask: torch.Size([256, 256])
✅ Model setup complete: efficientnet-b3 backbone
✅ Using custom LakeDetectionLoss (BCE + Dice + Focal).
✅ Using Optimizer: AdamW.
✅ Using device: cuda
🔧 Model: U-Net/DeepLabV3+ with efficientnet-b3
🎯 Targe

🏋️ Epoch 1/40:   0%|          | 0/122 [00:00<?, ?it/s]

🧪 Validating Epoch 1:   0%|          | 0/2 [00:00<?, ?it/s]


📈 Epoch 01/40 Summary:
   Train Loss: 0.3633
   Val Loss:   0.3465
   Val IoU:    0.8058
   LR:         0.000100
   🎯 NEW BEST! IoU: 0.8058
--------------------------------------------------


🏋️ Epoch 2/40:   0%|          | 0/122 [00:00<?, ?it/s]

🧪 Validating Epoch 2:   0%|          | 0/2 [00:00<?, ?it/s]


📈 Epoch 02/40 Summary:
   Train Loss: 0.2151
   Val Loss:   0.1843
   Val IoU:    0.8631
   LR:         0.000100
   🎯 NEW BEST! IoU: 0.8631
--------------------------------------------------


🏋️ Epoch 3/40:   0%|          | 0/122 [00:00<?, ?it/s]

🧪 Validating Epoch 3:   0%|          | 0/2 [00:00<?, ?it/s]


📈 Epoch 03/40 Summary:
   Train Loss: 0.1468
   Val Loss:   0.1197
   Val IoU:    0.8265
   LR:         0.000100
--------------------------------------------------


🏋️ Epoch 4/40:   0%|          | 0/122 [00:00<?, ?it/s]

🧪 Validating Epoch 4:   0%|          | 0/2 [00:00<?, ?it/s]


📈 Epoch 04/40 Summary:
   Train Loss: 0.0980
   Val Loss:   0.0747
   Val IoU:    0.8888
   LR:         0.000100
   🎯 NEW BEST! IoU: 0.8888
--------------------------------------------------


🏋️ Epoch 5/40:   0%|          | 0/122 [00:00<?, ?it/s]

In [14]:
# Colab only: mount Google Drive at /content/drive.
# NOTE(review): this cell sits last in the notebook, but on Colab it has to run
# before any cell that reads data from "/content/drive/MyDrive/SAR_model_data".
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
