# Land Cover Change Detection - ISRO Production Training

## Complete Training Pipeline for Satellite Image Change Detection

### Features:
- SNUNet Architecture with CBAM Attention
- Hybrid Loss (BCE + Dice + Focal + Tversky)
- EMA (Exponential Moving Average) for stable predictions
- 8x Test-Time Augmentation (TTA) for best accuracy
- Real-time visualization and GPU memory monitoring
- Auto file list generation from dataset structure
- Save a checkpoint every epoch plus the best model (older non-best checkpoints are pruned)

### Target: F1 Score > 0.85

---
## Cell 1: Install Dependencies

In [None]:
!pip install -q tqdm matplotlib seaborn pandas scikit-learn albumentations

---
## Cell 2: Import Libraries

In [None]:
import os
import sys
import time
import warnings
import subprocess
import copy
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import random
import glob
import shutil

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
from IPython.display import display, clear_output, HTML
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms

from sklearn.metrics import f1_score, jaccard_score, cohen_kappa_score, precision_score, recall_score

warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with *seed*.

    Also forces cuDNN into deterministic mode and disables its autotuner,
    trading some speed for run-to-run reproducibility.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once at import time so every later cell is reproducible.
set_seed(42)

# Environment banner: PyTorch version and, when present, the GPU model and its VRAM.
print("="*60)
print("LAND COVER CHANGE DETECTION - KAGGLE TRAINING")
print("="*60)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

---
## Cell 3: Configuration

In [None]:
@dataclass
class Config:
    """All training hyperparameters.

    Paths default to the Kaggle input/working directories. The three
    ``*_list`` fields hold the paths of the generated train/val/test file
    lists; they default to ``""`` and are filled in after the lists are
    generated. An empty string means that split was not found.
    """
    # DATA PATHS - KAGGLE DATASET
    data_root: str = "/kaggle/input/dataset1new"
    output_dir: str = "/kaggle/working"
    checkpoint_dir: str = "/kaggle/working/checkpoints"
    models_dir: str = "/kaggle/working/models"
    patch_size: int = 256
    
    # TRAINING
    batch_size: int = 16
    num_workers: int = 2
    # NOTE(review): `.to(device, non_blocking=True)` host->GPU copies can only
    # overlap with compute when the source batch is in pinned memory; consider
    # enabling this when training on CUDA.
    pin_memory: bool = False
    epochs: int = 50
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    
    # MODEL
    base_channel: int = 32
    use_attention: bool = True
    
    # LOSS WEIGHTS (Hybrid Loss)
    bce_weight: float = 0.3
    dice_weight: float = 0.3
    focal_weight: float = 0.2
    tversky_weight: float = 0.2
    focal_gamma: float = 2.0
    # Tversky: alpha weights false positives, beta weights false negatives.
    tversky_alpha: float = 0.3
    tversky_beta: float = 0.7
    
    # ADVANCED TECHNIQUES
    use_ema: bool = True
    ema_decay: float = 0.999
    use_tta: bool = True
    tta_transforms: int = 8
    gradient_clip: float = 1.0
    
    # MODEL SAVING
    # NOTE(review): both flags default to True; how they interact is decided by
    # the training loop that reads them.
    save_every_epoch: bool = True
    save_best_only: bool = True
    max_models_to_keep: int = 10
    
    # SCHEDULER
    scheduler_type: str = "cosine_warm"
    warmup_epochs: int = 3
    T_0: int = 10
    T_mult: int = 2
    
    # EARLY STOPPING
    patience: int = 15
    min_delta: float = 0.001
    
    # LOGGING
    log_interval: int = 50
    save_every: int = 5
    
    # Evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # FILE LISTS (declared so they always exist; populated after generation).
    train_list: str = ""
    val_list: str = ""
    test_list: str = ""

config = Config()
# Create the checkpoint and model output folders before anything writes into them.
os.makedirs(config.checkpoint_dir, exist_ok=True)
os.makedirs(config.models_dir, exist_ok=True)

# Echo the key settings so the run's output records them.
print("\nConfiguration loaded")
print(f"Data Root: {config.data_root}")
print(f"Device: {config.device}")
print(f"Batch Size: {config.batch_size}")
print(f"Epochs: {config.epochs}")
print(f"Learning Rate: {config.learning_rate}")
print(f"Save Every Epoch: {config.save_every_epoch}")
print(f"Save Best Only: {config.save_best_only}")

---
## Cell 4: Auto-Generate File Lists

In [None]:
def generate_file_lists(data_root, output_dir, pattern='*.png'):
    """
    Auto-generate train/val/test file lists from dataset structure.

    Expected structure:
    - data_root/train/A/, train/B/, train/label/
    - data_root/val/A/, val/B/, val/label/
    - data_root/test/A/, test/B/, test/label/

    Every image matching *pattern* in ``<split>/A`` is paired by file name
    with ``<split>/B`` and ``<split>/label``. One line per complete triple is
    written to ``<output_dir>/<split>_list.txt`` in the form
    ``A_path B_path label_path``, with paths relative to *data_root*.
    Images whose B or label counterpart is missing are skipped, because they
    would fail when loaded. Names containing whitespace are also skipped,
    because the list format is whitespace-separated. A warning with the
    skipped count is printed in either case.

    Args:
        data_root: Dataset root containing the split folders.
        output_dir: Directory receiving the list files (created if missing).
        pattern: Glob pattern selecting images inside each ``A`` folder.

    Returns:
        Dict mapping split name -> list-file path, for every split whose
        ``A`` folder exists.
    """
    splits = ['train', 'val', 'test']
    file_lists = {}
    
    print("\nGenerating file lists...")
    print(f"Dataset root: {data_root}")
    os.makedirs(output_dir, exist_ok=True)
    
    for split in splits:
        a_dir = os.path.join(data_root, split, 'A')
        
        if not os.path.exists(a_dir):
            print(f"Warning: {split} folder not found, skipping...")
            continue
        
        lines = []
        skipped = 0
        for img_path in sorted(glob.glob(os.path.join(a_dir, pattern))):
            filename = os.path.basename(img_path)
            # The dataset splits each line on whitespace, so such names would be mis-parsed.
            if any(ch.isspace() for ch in filename):
                skipped += 1
                continue
            partners = (os.path.join(data_root, split, sub, filename) for sub in ('B', 'label'))
            if not all(os.path.exists(p) for p in partners):
                skipped += 1
                continue
            lines.append(f"{split}/A/{filename} {split}/B/{filename} {split}/label/{filename}\n")
        
        list_path = os.path.join(output_dir, f'{split}_list.txt')
        with open(list_path, 'w', encoding='utf-8') as f:
            f.writelines(lines)
        
        file_lists[split] = list_path
        print(f"{split}: {len(lines)} samples -> {list_path}")
        if skipped:
            print(f"Warning: {split}: skipped {skipped} file(s) with a missing B/label "
                  f"counterpart or whitespace in the name")
    
    return file_lists

file_lists = generate_file_lists(config.data_root, config.output_dir)

# Attach the generated list paths to the config. A split whose folder was not
# found gets '' — NOTE(review): building a dataset from that empty path will
# fail, and the success message below is printed regardless.
config.train_list = file_lists.get('train', '')
config.val_list = file_lists.get('val', '')
config.test_list = file_lists.get('test', '')

print("\nFile lists generated successfully!")

---
## Cell 5: Enhanced Utility Classes

In [None]:
class ModelManager:
    """Save model checkpoints into *models_dir* and prune old non-best ones.

    Every saved checkpoint is tracked in ``saved_models`` as a
    ``(path, epoch, f1, is_best)`` tuple. When more than
    ``max_models_to_keep`` entries are tracked, the oldest (lowest-epoch)
    non-best checkpoints are deleted. Checkpoints saved with ``is_best=True``
    are never pruned, so they may push the total above the limit.
    """
    def __init__(self, models_dir, max_models_to_keep=10):
        self.models_dir = models_dir
        self.max_models_to_keep = max_models_to_keep
        self.saved_models = []
        
    def save_model(self, model, epoch, metrics, is_best=False, ema_shadow=None):
        """Save ``model.state_dict()`` plus metrics to a timestamped ``.pth`` file.

        Args:
            model: Module whose ``state_dict()`` is saved.
            epoch: Epoch number (also used to order pruning).
            metrics: Dict with at least ``f1``, ``iou``, ``precision``,
                ``recall`` and ``loss``.
            is_best: Prefix the file name with ``BEST_`` and exempt it from pruning.
            ema_shadow: Optional EMA weights stored under ``'ema_shadow'``.

        Returns:
            Path of the written checkpoint.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        prefix = "BEST_model" if is_best else "model"
        filename = f"{prefix}_epoch{epoch:03d}_f1_{metrics['f1']:.4f}_{timestamp}.pth"
        filepath = os.path.join(self.models_dir, filename)
        
        save_dict = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'f1': metrics['f1'],
            'iou': metrics['iou'],
            'precision': metrics['precision'],
            'recall': metrics['recall'],
            'loss': metrics['loss'],
            'timestamp': timestamp,
            'is_best': is_best
        }
        
        # Explicit None check: a dict of tensors (or a tensor) must not be
        # judged by truthiness.
        if ema_shadow is not None:
            save_dict['ema_shadow'] = ema_shadow
        
        torch.save(save_dict, filepath)
        self.saved_models.append((filepath, epoch, metrics['f1'], is_best))
        
        self._cleanup_old_models()
        
        return filepath
    
    def _cleanup_old_models(self):
        """Delete the oldest non-best checkpoints until the keep-limit is met."""
        excess = len(self.saved_models) - self.max_models_to_keep
        if excess <= 0:
            return
        self.saved_models.sort(key=lambda m: m[1])
        stale = [m for m in self.saved_models if not m[3]][:excess]
        for model_info in stale:
            if os.path.exists(model_info[0]):
                os.remove(model_info[0])
            # Drop the entry even if the file had already disappeared; otherwise
            # the tracking list desyncs from disk and grows past the limit.
            self.saved_models.remove(model_info)
    
    def get_best_model_path(self):
        """Return the path of the highest-F1 best checkpoint, or None if none exists."""
        best_models = [m for m in self.saved_models if m[3]]
        if best_models:
            best_models.sort(key=lambda x: x[2], reverse=True)
            return best_models[0][0]
        return None
    
    def list_models(self):
        """Print every tracked checkpoint in epoch order."""
        print(f"\nSaved Models ({len(self.saved_models)}):")
        print("-" * 80)
        for path, epoch, f1, is_best in sorted(self.saved_models, key=lambda x: x[1]):
            status = "[BEST]" if is_best else ""
            print(f"Epoch {epoch:03d} | F1: {f1:.4f} {status} | {os.path.basename(path)}")


class AverageMeter:
    """Track the latest value, a count-weighted running mean and the raw history."""
    def __init__(self, name: str = ""):
        self.name = name
        self.reset()

    def reset(self):
        """Zero all statistics and clear the recorded history."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.history = []

    def update(self, val, n=1):
        """Record *val* as observed *n* times (e.g. a batch mean with batch size n)."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        self.history.append(val)


class EMA:
    """Exponential Moving Average for model weights.

    Keeps a shadow copy of every trainable parameter, updated as
    ``shadow = decay * shadow + (1 - decay) * param``. ``apply_shadow()``
    swaps the averaged weights into the model, for evaluation or saving, and
    ``restore()`` puts the training weights back.

    Only parameters with ``requires_grad`` are averaged; buffers such as
    BatchNorm running statistics are left untouched.
    """
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        self._register()

    def _register(self):
        """Initialise the shadow weights as copies of the current parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Blend the current parameters into the shadow weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # The expression already yields a fresh tensor, so no extra clone is
                # needed; the shadow tensors are replaced, never mutated, which keeps
                # any previously taken references stable.
                self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * param.data

    def apply_shadow(self):
        """Back up the live parameters and load the averaged weights into the model."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                # Copy instead of rebinding `param.data`, so the parameter never
                # aliases the shadow tensor. In-place changes to the model while
                # the shadow is applied therefore cannot corrupt the EMA.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Load the backed-up live parameters into the model and clear the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])
        self.backup = {}


class GPUMonitor:
    """Snapshot PyTorch's CUDA memory usage and remember every snapshot taken.

    Despite its name, the ``'util'`` entry is the percentage of device 0's
    total memory currently allocated by PyTorch tensors. It is NOT the GPU's
    compute utilization.
    """
    def __init__(self):
        self.history = []
        
    def get_stats(self) -> Dict:
        """Return ``{'util', 'mem_used', 'mem_total'}`` (GB for memory values).

        Without CUDA, returns all zeros and records nothing in ``history``.
        """
        if torch.cuda.is_available():
            used_gb = torch.cuda.memory_allocated() / 1e9
            total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            snapshot = {
                'util': (used_gb / total_gb) * 100,
                'mem_used': used_gb,
                'mem_total': total_gb
            }
            self.history.append(snapshot)
            return snapshot
        return {'util': 0, 'mem_used': 0, 'mem_total': 0}


class TrainingLogger:
    """Collect per-epoch training/validation metrics for plotting and ETA estimates."""

    # (history column, key in the validation-metrics dict), in logging order.
    _VAL_COLUMNS = (
        ('val_loss', 'loss'),
        ('val_f1', 'f1'),
        ('val_iou', 'iou'),
        ('val_precision', 'precision'),
        ('val_recall', 'recall'),
    )

    def __init__(self, config):
        self.config = config
        self.history = defaultdict(list)
        self.epoch_times = []
        self.gpu = GPUMonitor()
        
    def log_epoch(self, epoch, train_loss, val_metrics, epoch_time, lr):
        """Append one epoch's values to every history column."""
        self.epoch_times.append(epoch_time)
        self.history['epoch'].append(epoch)
        self.history['train_loss'].append(train_loss)
        for column, key in self._VAL_COLUMNS:
            self.history[column].append(val_metrics[key])
        self.history['lr'].append(lr)
        self.history['epoch_time'].append(epoch_time)
        
    def get_dataframe(self):
        """Return the history as a DataFrame with one row per logged epoch."""
        return pd.DataFrame(self.history)
    
    def estimate_remaining(self, current, total):
        """Estimate time left as ``H:MM:SS`` from the mean logged epoch duration."""
        if self.epoch_times:
            seconds_left = (total - current) * np.mean(self.epoch_times)
            return str(timedelta(seconds=int(seconds_left)))
        return "Calculating..."

# Marks in the notebook output that this cell's classes were defined.
print("Utility classes loaded")

---
## Cell 6: Dataset with Advanced Augmentations

In [None]:
class ChangeDetectionDataset(Dataset):
    """
    Optimized dataset for change detection with proper label handling.

    Each line of *list_path* with at least three whitespace-separated fields
    is a sample. The first three fields are paths, relative to *root_dir*, of
    the first image, the second image and the change mask. Items are dicts
    with ImageNet-normalized ``image1``/``image2`` tensors, a binary (0/1)
    float ``label`` tensor and ``name`` (basename of the first image).

    Args:
        root_dir: Directory the list paths are relative to.
        list_path: Text file listing the samples.
        mode: ``'train'`` enables random synchronized augmentation; any other
            value returns samples unaugmented.
        patch_size: Stored on the instance. NOTE(review): this class does no
            cropping or resizing with it.
    """
    def __init__(self, root_dir, list_path, mode='train', patch_size=256):
        self.root_dir = root_dir
        self.mode = mode
        self.patch_size = patch_size
        
        self.files = []
        with open(list_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 3:
                    self.files.append(parts)
        
        # ImageNet channel statistics used to normalize both RGB images.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        
        print(f"{mode.upper()}: {len(self.files)} samples loaded")

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _load(path, mode):
        """Open an image, convert it to *mode* and close the file handle."""
        # convert() returns a new, fully loaded image, so closing the source is safe.
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, idx):
        a_rel, b_rel, label_rel = self.files[idx][:3]
        img1 = self._load(os.path.join(self.root_dir, a_rel), 'RGB')
        img2 = self._load(os.path.join(self.root_dir, b_rel), 'RGB')
        label = self._load(os.path.join(self.root_dir, label_rel), 'L')

        if self.mode == 'train':
            img1, img2, label = self._augment(img1, img2, label)
        
        img1 = TF.normalize(TF.to_tensor(img1), self.mean, self.std)
        img2 = TF.normalize(TF.to_tensor(img2), self.mean, self.std)
        
        # to_tensor scales an 8-bit 'L' mask into [0, 1]. Masks stored as 0/255
        # therefore peak at 1.0, while masks stored as 0/1 peak at 1/255 (~0.004).
        # Pick the threshold that binarizes either encoding. (The former
        # `max() > 1` branch was unreachable and used the same 0.5 threshold.)
        label = TF.to_tensor(label)
        threshold = 0.5 if label.max() > 0.1 else 0.001
        label = (label > threshold).float()

        return {
            'image1': img1,
            'image2': img2,
            'label': label,
            'name': os.path.basename(a_rel)
        }

    def _augment(self, img1, img2, label):
        """Apply synchronized augmentations.

        Flips and 90-degree rotations are applied identically to both images
        and the mask. Brightness/contrast jitter uses the same factors for
        both images and is not applied to the mask.
        NOTE(review): rotate() keeps the original canvas size, which is only
        lossless for square patches.
        """
        if random.random() > 0.5:
            img1 = TF.hflip(img1)
            img2 = TF.hflip(img2)
            label = TF.hflip(label)
        
        if random.random() > 0.5:
            img1 = TF.vflip(img1)
            img2 = TF.vflip(img2)
            label = TF.vflip(label)
        
        if random.random() > 0.5:
            angle = random.choice([90, 180, 270])
            img1 = TF.rotate(img1, angle)
            img2 = TF.rotate(img2, angle)
            label = TF.rotate(label, angle)
        
        if random.random() > 0.5:
            brightness = random.uniform(0.8, 1.2)
            contrast = random.uniform(0.8, 1.2)
            img1 = TF.adjust_brightness(img1, brightness)
            img1 = TF.adjust_contrast(img1, contrast)
            img2 = TF.adjust_brightness(img2, brightness)
            img2 = TF.adjust_contrast(img2, contrast)
        
        return img1, img2, label

# Marks in the notebook output that the dataset class was defined.
print("Dataset class loaded")

---
## Cell 7: SNUNet Model with CBAM Attention

In [None]:
class ChannelAttention(nn.Module):
    """Channel attention module for CBAM.

    Global average- and max-pooled channel descriptors pass through a shared
    two-layer 1x1-conv bottleneck. Their sum goes through a sigmoid, giving
    per-channel gates in (0, 1) of shape (B, C, 1, 1).

    Args:
        in_planes: Number of input channels C.
        ratio: Bottleneck reduction factor; the hidden width is
            ``max(in_planes // ratio, 1)``.
    """
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        # Plain floor division gives a zero-width bottleneck when
        # in_planes < ratio, and the gate could then no longer depend on the input.
        hidden = max(in_planes // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """CBAM spatial gate: a per-pixel weight in (0, 1) from channel-pooled maps.

    The channel-wise mean and max maps are stacked into a 2-channel image and
    passed through a single ``kernel_size`` conv and a sigmoid, yielding a
    (B, 1, H, W) gate.
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        same_padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        peak_map = torch.max(x, dim=1, keepdim=True)[0]
        stacked = torch.cat([mean_map, peak_map], dim=1)
        return self.sigmoid(self.conv(stacked))


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gating, then spatial gating.

    Args:
        in_planes: Channel count of the input feature map.
        ratio: Reduction factor of the channel-attention bottleneck.
        kernel_size: Convolution size of the spatial-attention gate.
    """
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_refined = self.ca(x) * x
        return self.sa(channel_refined) * channel_refined


class ConvBlock(nn.Module):
    """Two stages of (3x3 conv -> BatchNorm -> ReLU), optionally followed by CBAM.

    Spatial size is preserved (padding=1); channels go ``in_ch -> out_ch``.
    """
    def __init__(self, in_ch, out_ch, use_cbam=False):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)
        self.cbam = CBAM(out_ch) if use_cbam else None

    def forward(self, x):
        features = self.conv(x)
        return features if self.cbam is None else self.cbam(features)


class SNUNet(nn.Module):
    """
    SNUNet: Siamese Nested U-Net for Change Detection

    Both input images run through the same encoder modules (shared weights).
    A nested decoder then fuses the two feature pyramids through dense skip
    connections, and a 1x1 conv produces ``num_classes`` logit maps at the
    input resolution.

    NOTE(review): there are four 2x poolings and fixed 2x upsampling, so H and
    W presumably must be divisible by 16 for the concatenations to line up.
    """
    def __init__(self, in_ch=3, num_classes=1, base_ch=32, use_attention=True):
        super().__init__()
        c1, c2, c4, c8, c16 = (base_ch * k for k in (1, 2, 4, 8, 16))
        attn = use_attention

        # Shared encoder (registration order is kept stable for checkpoints).
        self.conv0_0 = ConvBlock(in_ch, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv1_0 = ConvBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2)
        self.conv2_0 = ConvBlock(c2, c4)
        self.pool3 = nn.MaxPool2d(2)
        self.conv3_0 = ConvBlock(c4, c8)
        self.pool4 = nn.MaxPool2d(2)
        self.conv4_0 = ConvBlock(c8, c16)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Input widths mirror the tensors concatenated in forward().
        # Column 1: both skips + both upsampled deeper encoder features.
        self.conv0_1 = ConvBlock(2 * c1 + 2 * c2, c1, use_cbam=attn)
        self.conv1_1 = ConvBlock(2 * c2 + 2 * c4, c2, use_cbam=attn)
        self.conv2_1 = ConvBlock(2 * c4 + 2 * c8, c4, use_cbam=attn)
        self.conv3_1 = ConvBlock(2 * c8 + 2 * c16, c8, use_cbam=attn)

        # Column 2: both skips + column-1 node + upsampled deeper column-1 node.
        self.conv0_2 = ConvBlock(3 * c1 + c2, c1, use_cbam=attn)
        self.conv1_2 = ConvBlock(3 * c2 + c4, c2, use_cbam=attn)
        self.conv2_2 = ConvBlock(3 * c4 + c8, c4, use_cbam=attn)

        # Column 3: both skips + columns 1-2 + upsampled deeper column-2 node.
        self.conv0_3 = ConvBlock(4 * c1 + c2, c1, use_cbam=attn)
        self.conv1_3 = ConvBlock(4 * c2 + c4, c2, use_cbam=attn)

        # Column 4 (no attention): both skips + columns 1-3 + upsampled deeper node.
        self.conv0_4 = ConvBlock(5 * c1 + c2, c1)

        self.final = nn.Conv2d(c1, num_classes, kernel_size=1)

    def _encode(self, x):
        """Run the shared encoder and return features at depths 0 through 4."""
        f0 = self.conv0_0(x)
        f1 = self.conv1_0(self.pool1(f0))
        f2 = self.conv2_0(self.pool2(f1))
        f3 = self.conv3_0(self.pool3(f2))
        f4 = self.conv4_0(self.pool4(f3))
        return f0, f1, f2, f3, f4

    def forward(self, x1, x2):
        a0, a1, a2, a3, a4 = self._encode(x1)
        b0, b1, b2, b3, b4 = self._encode(x2)
        up = self.up

        n0_1 = self.conv0_1(torch.cat([a0, b0, up(a1), up(b1)], 1))
        n1_1 = self.conv1_1(torch.cat([a1, b1, up(a2), up(b2)], 1))
        n2_1 = self.conv2_1(torch.cat([a2, b2, up(a3), up(b3)], 1))
        n3_1 = self.conv3_1(torch.cat([a3, b3, up(a4), up(b4)], 1))

        n0_2 = self.conv0_2(torch.cat([a0, b0, n0_1, up(n1_1)], 1))
        n1_2 = self.conv1_2(torch.cat([a1, b1, n1_1, up(n2_1)], 1))
        n2_2 = self.conv2_2(torch.cat([a2, b2, n2_1, up(n3_1)], 1))

        n0_3 = self.conv0_3(torch.cat([a0, b0, n0_1, n0_2, up(n1_2)], 1))
        n1_3 = self.conv1_3(torch.cat([a1, b1, n1_1, n1_2, up(n2_2)], 1))

        n0_4 = self.conv0_4(torch.cat([a0, b0, n0_1, n0_2, n0_3, up(n1_3)], 1))

        return self.final(n0_4)

# Build one model at import time to sanity-check the architecture and report
# its size. NOTE(review): this instance is separate from the one train()
# constructs for training.
model = SNUNet(3, 1, config.base_channel, config.use_attention)
param_count = sum(p.numel() for p in model.parameters())
print(f"SNUNet loaded: {param_count:,} parameters")

---
## Cell 8: Hybrid Loss Functions

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, pooling every pixel of the batch into one set.

    Returns ``1 - (2*|P∩T| + s) / (|P| + |T| + s)`` with ``P = sigmoid(pred)``
    and smoothing ``s``.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        probs = torch.sigmoid(pred).view(-1)
        truth = target.view(-1)
        overlap = (probs * truth).sum()
        denominator = probs.sum() + truth.sum() + self.smooth
        return 1 - (2. * overlap + self.smooth) / denominator


class FocalLoss(nn.Module):
    """Binary focal loss on logits: ``alpha * (1 - p_t)^gamma * BCE``, averaged.

    ``alpha`` is a single scale applied to every pixel (not a per-class
    weight); ``gamma`` down-weights pixels that are already well classified.
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        per_pixel_ce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        p_true = torch.exp(-per_pixel_ce)
        modulation = (1 - p_true) ** self.gamma
        return (self.alpha * modulation * per_pixel_ce).mean()


class TverskyLoss(nn.Module):
    """Tversky loss on logits with tunable false-positive / false-negative weights.

    Score = ``(TP + s) / (TP + alpha*FP + beta*FN + s)`` over all pixels in the
    batch, using soft counts from ``sigmoid(pred)``; the loss is ``1 - score``.
    """
    def __init__(self, alpha=0.3, beta=0.7, smooth=1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, pred, target):
        p = torch.sigmoid(pred).view(-1)
        t = target.view(-1)

        true_pos = (p * t).sum()
        false_pos = ((1 - t) * p).sum()
        false_neg = (t * (1 - p)).sum()

        weighted = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth
        return 1 - (true_pos + self.smooth) / weighted


class HybridLoss(nn.Module):
    """
    Weighted sum of BCE, Dice, Focal and Tversky losses, all computed on logits.

    Args:
        bce_w, dice_w, focal_w, tversky_w: Weights of the four terms.
        focal_gamma: Focusing exponent of the focal term (its alpha keeps the
            FocalLoss default).
        tversky_alpha, tversky_beta: FP / FN weights of the Tversky term.
    """
    def __init__(self, bce_w=0.3, dice_w=0.3, focal_w=0.2, tversky_w=0.2, 
                 focal_gamma=2.0, tversky_alpha=0.3, tversky_beta=0.7):
        super().__init__()
        self.bce_w = bce_w
        self.dice_w = dice_w
        self.focal_w = focal_w
        self.tversky_w = tversky_w
        
        self.dice = DiceLoss()
        self.focal = FocalLoss(gamma=focal_gamma)
        self.tversky = TverskyLoss(alpha=tversky_alpha, beta=tversky_beta)

    def forward(self, pred, target):
        weighted_terms = (
            (self.bce_w, F.binary_cross_entropy_with_logits(pred, target)),
            (self.dice_w, self.dice(pred, target)),
            (self.focal_w, self.focal(pred, target)),
            (self.tversky_w, self.tversky(pred, target)),
        )
        return sum(weight * term for weight, term in weighted_terms)

# Marks in the notebook output that the loss classes were defined.
print("Hybrid Loss loaded")

---
## Cell 9: Visualization Dashboard

In [None]:
def display_dashboard(logger, epoch, total_epochs, best_f1, checkpoint_dir):
    """Render the 2x3 training dashboard, save it as PNG and show it inline.

    The six panels show:
    - train/val loss;
    - validation F1 and IoU, with the 0.85 target line;
    - validation precision and recall;
    - learning rate (log scale);
    - per-epoch duration;
    - a text summary.

    The figure is written to ``<checkpoint_dir>/dashboard.png`` and then
    closed, so calling this once per epoch does not accumulate open figures.
    Does nothing if nothing has been logged yet.

    Args:
        logger: Object providing ``get_dataframe()``, ``gpu.get_stats()`` and
            ``estimate_remaining(current, total)``.
        epoch: Current epoch number (for titles and the ETA).
        total_epochs: Planned number of epochs.
        best_f1: Best validation F1 so far.
        checkpoint_dir: Directory receiving ``dashboard.png``.
    """
    df = logger.get_dataframe()
    if len(df) == 0:
        return
    
    clear_output(wait=True)
    
    fig = plt.figure(figsize=(16, 10))
    fig.suptitle(f'Training Dashboard - Epoch {epoch}/{total_epochs} | Best F1: {best_f1:.4f}', 
                 fontsize=14, fontweight='bold')
    
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)
    
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(df['epoch'], df['train_loss'], 'b-', label='Train', lw=2)
    ax1.plot(df['epoch'], df['val_loss'], 'r-', label='Val', lw=2)
    ax1.fill_between(df['epoch'], df['train_loss'], df['val_loss'], alpha=0.2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Loss Curves')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(df['epoch'], df['val_f1'], 'g-', label='F1', lw=2, marker='o', ms=3)
    ax2.plot(df['epoch'], df['val_iou'], 'm-', label='IoU', lw=2, marker='s', ms=3)
    ax2.axhline(y=0.85, color='gold', ls='--', alpha=0.7, label='Target (0.85)')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Score')
    ax2.set_title('F1 & IoU')
    ax2.legend(loc='lower right')
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 1)
    
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.plot(df['epoch'], df['val_precision'], 'c-', label='Precision', lw=2)
    ax3.plot(df['epoch'], df['val_recall'], 'y-', label='Recall', lw=2)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Score')
    ax3.set_title('Precision & Recall')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_ylim(0, 1)
    
    ax4 = fig.add_subplot(gs[1, 0])
    ax4.plot(df['epoch'], df['lr'], 'orange', lw=2)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('LR')
    ax4.set_title('Learning Rate Schedule')
    ax4.grid(True, alpha=0.3)
    ax4.set_yscale('log')
    
    ax5 = fig.add_subplot(gs[1, 1])
    avg_epoch_time = np.mean(df['epoch_time'])
    ax5.bar(df['epoch'], df['epoch_time'], color='steelblue', alpha=0.7)
    ax5.axhline(y=avg_epoch_time, color='red', ls='--', label=f'Avg: {avg_epoch_time:.1f}s')
    ax5.set_xlabel('Epoch')
    ax5.set_ylabel('Time (s)')
    ax5.set_title('Epoch Duration')
    ax5.legend()
    ax5.grid(True, alpha=0.3)
    
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.axis('off')
    
    gpu = logger.gpu.get_stats()
    remaining = logger.estimate_remaining(epoch, total_epochs)
    
    stats_text = f"""
    TRAINING STATUS
    {'='*30}
    
    Epoch: {epoch}/{total_epochs}
    Best F1: {best_f1:.4f}
    Current F1: {df['val_f1'].iloc[-1]:.4f}
    Current IoU: {df['val_iou'].iloc[-1]:.4f}
    
    GPU Memory: {gpu['mem_used']:.1f}/{gpu['mem_total']:.1f} GB
    ETA: {remaining}
    
    Train Loss: {df['train_loss'].iloc[-1]:.4f}
    Val Loss: {df['val_loss'].iloc[-1]:.4f}
    """
    ax6.text(0.1, 0.5, stats_text, fontsize=11, family='monospace', va='center',
             bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.3))
    
    plt.tight_layout()
    plt.savefig(os.path.join(checkpoint_dir, 'dashboard.png'), dpi=100)
    plt.show()
    # Release the figure: this runs every epoch, and figures left open keep
    # their memory (and trigger matplotlib's "too many figures" warning).
    plt.close(fig)

# Marks in the notebook output that the dashboard function was defined.
print("Visualization loaded")

---
## Cell 10: Training & Validation Functions

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, config, ema=None):
    """Run one training pass over *loader* and return the sample-weighted mean loss.

    For each batch, the step is: forward, loss, backward, then optional
    gradient-norm clipping (``config.gradient_clip > 0``), an optimizer step
    and an optional EMA update. The progress bar shows the running loss every
    ``config.log_interval`` batches.
    """
    model.train()
    running = AverageMeter()
    progress = tqdm(loader, desc='Training', leave=False)

    for step, batch in enumerate(progress):
        img1, img2, label = (
            batch[key].to(device, non_blocking=True)
            for key in ('image1', 'image2', 'label')
        )

        optimizer.zero_grad()
        loss = criterion(model(img1, img2), label)
        loss.backward()

        clip = config.gradient_clip
        if clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        if ema is not None:
            ema.update()

        running.update(loss.item(), img1.size(0))

        if step % config.log_interval == 0:
            progress.set_postfix({'Loss': f'{running.avg:.4f}'})

    return running.avg


def _tta_predict(model, img1, img2):
    """Return *model* logits averaged over 8 flip/rotation views of the image pair.

    Each view optionally applies a horizontal flip, a vertical flip and a
    90-degree rotation to both images. The prediction is mapped back by
    undoing those steps in reverse order before averaging.
    """
    outputs = []
    for flip_h in (False, True):
        for flip_v in (False, True):
            for rot in (0, 90):
                a, b = img1, img2
                if flip_h:
                    a, b = torch.flip(a, [3]), torch.flip(b, [3])
                if flip_v:
                    a, b = torch.flip(a, [2]), torch.flip(b, [2])
                if rot == 90:
                    a, b = torch.rot90(a, 1, [2, 3]), torch.rot90(b, 1, [2, 3])

                out = model(a, b)

                # Undo the augmentation in reverse order.
                if rot == 90:
                    out = torch.rot90(out, -1, [2, 3])
                if flip_v:
                    out = torch.flip(out, [2])
                if flip_h:
                    out = torch.flip(out, [3])

                outputs.append(out)
    return torch.mean(torch.stack(outputs), dim=0)


def _binary_metrics(tp, fp, fn, tn):
    """Compute F1, IoU, precision, recall and Cohen's kappa from confusion counts.

    A ratio whose denominator is zero is reported as 0.0, matching
    scikit-learn's ``zero_division=0``. Kappa is also 0.0 when it is
    undefined, e.g. when prediction and target are both entirely one class.
    """
    def ratio(num, den):
        return num / den if den else 0.0

    total = tp + fp + fn + tn
    # Chance agreement scaled by total**2, kept in exact integer arithmetic.
    chance = (tp + fp) * (tp + fn) + (tn + fn) * (tn + fp)
    return {
        'f1': ratio(2 * tp, 2 * tp + fp + fn),
        'iou': ratio(tp, tp + fp + fn),
        'precision': ratio(tp, tp + fp),
        'recall': ratio(tp, tp + fn),
        # kappa = (p_o - p_e) / (1 - p_e), multiplied through by total**2.
        'kappa': ratio((tp + tn) * total - chance, total * total - chance),
    }


def validate(model, loader, criterion, device, use_tta=False, threshold=0.5):
    """Validate with optional TTA.

    Runs *model* in eval mode without gradients over *loader*. Returns a dict
    with keys ``loss``, ``f1``, ``iou``, ``precision``, ``recall`` and
    ``kappa``: the sample-weighted mean loss plus pixel-level binary metrics.

    Metrics come from confusion counts accumulated batch by batch, so memory
    stays constant. Collecting every pixel into Python lists would need
    memory proportional to (#images x #pixels).

    Args:
        model: Two-input change-detection model returning logits.
        loader: Yields dicts with ``image1``, ``image2`` and ``label``.
        criterion: Loss applied to (logits, label).
        device: Device the batches are moved to.
        use_tta: Average logits over 8 flip/rotation views.
        threshold: Probability above which a pixel counts as "changed".
    """
    model.eval()
    losses = AverageMeter()
    tp = fp = fn = tn = 0
    
    with torch.no_grad():
        for batch in tqdm(loader, desc='Validating', leave=False):
            img1 = batch['image1'].to(device, non_blocking=True)
            img2 = batch['image2'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)
            
            output = _tta_predict(model, img1, img2) if use_tta else model(img1, img2)
            
            loss = criterion(output, label)
            losses.update(loss.item(), img1.size(0))
            
            pred = torch.sigmoid(output) > threshold
            # Labels are expected to be 0/1 floats (the dataset binarizes them).
            target = label > 0.5
            batch_tp = int((pred & target).sum())
            batch_fp = int((pred & ~target).sum())
            batch_fn = int((~pred & target).sum())
            tp += batch_tp
            fp += batch_fp
            fn += batch_fn
            tn += pred.numel() - batch_tp - batch_fp - batch_fn
    
    metrics = {'loss': losses.avg}
    metrics.update(_binary_metrics(tp, fp, fn, tn))
    return metrics

# Marks in the notebook output that the train/validate functions were defined.
print("Training functions loaded")

---
## Cell 11: Main Training Loop

In [None]:
def train(config):
    """
    Main training function with all advanced techniques.

    For each epoch this function:
    - trains one epoch;
    - validates, with EMA weights swapped in when ``config.use_ema`` is set
      and TTA when ``config.use_tta`` is set;
    - steps the cosine warm-restart scheduler and logs metrics;
    - refreshes the dashboard;
    - saves checkpoints;
    - applies early stopping on validation F1.

    Args:
        config: Configuration object. It supplies the data and loader
            settings, model/EMA/loss/optimizer hyper-parameters, epoch count,
            checkpoint directories, saving cadence (``save_every_epoch``,
            ``save_every``), and early-stopping settings (``min_delta``,
            ``patience``).

    Returns:
        Tuple ``(model, logger, best_f1, model_manager)``. ``model`` is the
        in-memory network as left after the last epoch's ``ema.restore()``
        (i.e. not the EMA weights).
    """
    print("\n" + "="*70)
    print("STARTING TRAINING")
    print("="*70)

    device = torch.device(config.device)

    print("\nLoading datasets...")
    train_ds = ChangeDetectionDataset(config.data_root, config.train_list, 'train', config.patch_size)
    val_ds = ChangeDetectionDataset(config.data_root, config.val_list, 'val', config.patch_size)

    train_loader = DataLoader(
        train_ds,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        drop_last=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory
    )

    print(f"Train: {len(train_ds)} samples ({len(train_loader)} batches)")
    print(f"Val: {len(val_ds)} samples")

    print("\nInitializing model...")
    model = SNUNet(3, 1, config.base_channel, config.use_attention).to(device)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    ema = EMA(model, config.ema_decay) if config.use_ema else None
    if ema:
        print(f"EMA: Enabled (decay={config.ema_decay})")

    model_manager = ModelManager(config.models_dir, config.max_models_to_keep)

    criterion = HybridLoss(
        bce_w=config.bce_weight,
        dice_w=config.dice_weight,
        focal_w=config.focal_weight,
        tversky_w=config.tversky_weight,
        focal_gamma=config.focal_gamma,
        tversky_alpha=config.tversky_alpha,
        tversky_beta=config.tversky_beta
    )

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config.T_0,
        T_mult=config.T_mult,
        eta_min=1e-6
    )

    print(f"\nLoss: Hybrid (BCE:{config.bce_weight} + Dice:{config.dice_weight} + Focal:{config.focal_weight} + Tversky:{config.tversky_weight})")
    print(f"Optimizer: AdamW (lr={config.learning_rate}, wd={config.weight_decay})")
    print(f"Scheduler: CosineAnnealingWarmRestarts (T_0={config.T_0}, T_mult={config.T_mult})")

    logger = TrainingLogger(config)
    best_f1 = 0.0
    best_epoch = None  # epoch whose weights are in best_model.pth; None until the first save
    patience_counter = 0

    print("\n" + "="*70)
    print("TRAINING STARTED")
    print("="*70 + "\n")

    for epoch in range(1, config.epochs + 1):
        epoch_start = time.time()
        # LR used during this epoch (the scheduler is stepped after validation).
        lr = optimizer.param_groups[0]['lr']

        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, config, ema)

        # Validate with the EMA weights swapped in, then swap the trained weights
        # back so the next epoch keeps optimizing them.
        if ema:
            ema.apply_shadow()

        val_metrics = validate(model, val_loader, criterion, device, use_tta=config.use_tta)

        if ema:
            ema.restore()

        scheduler.step()

        epoch_time = time.time() - epoch_start

        logger.log_epoch(epoch, train_loss, val_metrics, epoch_time, lr)

        # Decide on improvement before refreshing the dashboard so it shows the
        # up-to-date best F1. The first epoch always counts as "best" so that
        # best_model.pth exists even if F1 never rises above min_delta.
        is_best = best_epoch is None or val_metrics['f1'] > best_f1 + config.min_delta
        if is_best:
            best_f1 = val_metrics['f1']
            best_epoch = epoch
            patience_counter = 0
        else:
            patience_counter += 1

        display_dashboard(logger, epoch, config.epochs, best_f1, config.checkpoint_dir)

        ema_shadow = ema.shadow if ema else None

        # Save model every epoch
        if config.save_every_epoch:
            model_path = model_manager.save_model(model, epoch, val_metrics, is_best=False, ema_shadow=ema_shadow)
            print(f"Model saved: {os.path.basename(model_path)}")

        # Persist the best model
        if is_best:
            best_path = model_manager.save_model(model, epoch, val_metrics, is_best=True, ema_shadow=ema_shadow)

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'best_f1': best_f1,
                'optimizer_state_dict': optimizer.state_dict(),
                'ema_shadow': ema_shadow
            }, os.path.join(config.checkpoint_dir, 'best_model.pth'))

            print(f"\nNEW BEST! F1: {best_f1:.4f} | IoU: {val_metrics['iou']:.4f}")
            print(f"Best model saved to: {best_path}")

        print(f"\nEpoch {epoch}/{config.epochs} | "
              f"Loss: {train_loss:.4f}/{val_metrics['loss']:.4f} | "
              f"F1: {val_metrics['f1']:.4f} | IoU: {val_metrics['iou']:.4f} | "
              f"LR: {lr:.2e} | Time: {epoch_time:.1f}s")

        # Periodic resumable checkpoint. Written before the early-stopping check
        # so the stopping epoch is not skipped; save_every <= 0 disables it
        # (instead of raising ZeroDivisionError).
        if config.save_every and config.save_every > 0 and epoch % config.save_every == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'f1': val_metrics['f1'],
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'ema_shadow': ema_shadow,
            }, os.path.join(config.checkpoint_dir, f'checkpoint_epoch_{epoch}.pth'))

        if patience_counter >= config.patience:
            print(f"\nEarly stopping at epoch {epoch} (no improvement for {config.patience} epochs)")
            break

    print("\n" + "="*70)
    print("TRAINING COMPLETE!")
    print("="*70)
    print(f"\nBest F1 Score: {best_f1:.4f}")
    if best_epoch is not None:
        print(f"Best Epoch: {best_epoch}")
    print(f"Best Model: {config.checkpoint_dir}/best_model.pth")

    model_manager.list_models()

    logger.get_dataframe().to_csv(os.path.join(config.checkpoint_dir, 'training_history.csv'), index=False)
    print(f"\nTraining history: {config.checkpoint_dir}/training_history.csv")

    return model, logger, best_f1, model_manager

print("Main training function loaded")

---
## Cell 12: START TRAINING

In [None]:
# Launch the full training run. Unpacks four globals used by later cells:
# - the in-memory model as left by train();
# - the metrics logger;
# - the best validation F1;
# - the checkpoint manager.
model, logger, best_f1, model_manager = train(config)

---
## Cell 13: Final Results & Analysis

In [None]:
# Reload the per-epoch history written by train() and summarize it.
history_path = os.path.join(config.checkpoint_dir, 'training_history.csv')
df = pd.read_csv(history_path)

rule = "=" * 60
print("\n" + rule)
print("TRAINING SUMMARY")
print(rule)

# Row label of the epoch with the highest validation F1.
best_idx = df['val_f1'].idxmax()
print(f"\nBest Epoch: {df.loc[best_idx, 'epoch']}")
best_rows = (
    ("F1 Score:   ", 'val_f1'),
    ("IoU:        ", 'val_iou'),
    ("Precision:  ", 'val_precision'),
    ("Recall:     ", 'val_recall'),
)
for caption, column in best_rows:
    print(f"{caption}{df.loc[best_idx, column]:.4f}")

epoch_times = df['epoch_time']
print("\nTraining Statistics:")
print(f"Total Epochs: {len(df)}")
print(f"Avg Epoch Time: {epoch_times.mean():.1f}s")
print(f"Total Time: {epoch_times.sum()/60:.1f} min")

f1_history = df['val_f1']
print("\nProgress:")
print(f"Initial F1: {f1_history.iloc[0]:.4f}")
print(f"Final F1:   {f1_history.iloc[-1]:.4f}")
print(f"Best F1:    {f1_history.max():.4f}")

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
loss_ax, score_ax, pr_ax = axes

loss_ax.plot(df['epoch'], df['train_loss'], 'b-', label='Train')
loss_ax.plot(df['epoch'], df['val_loss'], 'r-', label='Val')

score_ax.plot(df['epoch'], df['val_f1'], 'g-', label='F1', lw=2)
score_ax.plot(df['epoch'], df['val_iou'], 'm-', label='IoU', lw=2)
score_ax.axhline(y=0.85, color='gold', ls='--', label='Target')

pr_ax.plot(df['epoch'], df['val_precision'], 'c-', label='Precision')
pr_ax.plot(df['epoch'], df['val_recall'], 'y-', label='Recall')

# Shared axis decoration; score panels are clamped to the [0, 1] range.
panel_specs = (
    (loss_ax, 'Loss Curves', 'Loss', False),
    (score_ax, 'F1 & IoU', 'Score', True),
    (pr_ax, 'Precision & Recall', 'Score', True),
)
for ax, title, ylabel, unit_range in panel_specs:
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    if unit_range:
        ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig(os.path.join(config.checkpoint_dir, 'final_results.png'), dpi=150)
plt.show()

print(f"\nResults saved to {config.checkpoint_dir}")

---
## Cell 14: Test Set Evaluation

In [None]:
# Evaluate the best checkpoint on the held-out test split (8x TTA), if a test list exists.
if config.test_list and os.path.exists(config.test_list):
    print("\n" + "="*60)
    print("TEST SET EVALUATION")
    print("="*60)

    device = torch.device(config.device)

    best_model_path = model_manager.get_best_model_path()
    if best_model_path:
        # map_location allows a checkpoint written on one device to load on another.
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Best-model selection in train() validated with EMA weights swapped in,
        # so evaluate with those same weights when the checkpoint carries them.
        # NOTE(review): assumes the EMA shadow maps parameter names (as in
        # model.named_parameters()) to tensors - confirm against the EMA class.
        ema_shadow = checkpoint.get('ema_shadow')
        if isinstance(ema_shadow, dict) and ema_shadow:
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in ema_shadow:
                        param.copy_(ema_shadow[name].to(param.device))
            print("Applied EMA weights from checkpoint")

        print(f"\nLoaded best model (F1: {checkpoint['f1']:.4f})")
    else:
        print("\nNo best model found; evaluating the current in-memory model.")

    test_ds = ChangeDetectionDataset(config.data_root, config.test_list, 'test', config.patch_size)
    test_loader = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False,
                             num_workers=config.num_workers, pin_memory=config.pin_memory)

    # Use the same loss configuration as training so the reported test loss is
    # comparable with the validation loss.
    criterion = HybridLoss(
        bce_w=config.bce_weight,
        dice_w=config.dice_weight,
        focal_w=config.focal_weight,
        tversky_w=config.tversky_weight,
        focal_gamma=config.focal_gamma,
        tversky_alpha=config.tversky_alpha,
        tversky_beta=config.tversky_beta
    )
    test_metrics = validate(model, test_loader, criterion, device, use_tta=True)

    print(f"\nTest Results (with 8x TTA):")
    print(f"F1 Score:   {test_metrics['f1']:.4f}")
    print(f"IoU:        {test_metrics['iou']:.4f}")
    print(f"Precision:  {test_metrics['precision']:.4f}")
    print(f"Recall:     {test_metrics['recall']:.4f}")
    print(f"Kappa:      {test_metrics['kappa']:.4f}")
else:
    print("\nTest list not found, skipping test evaluation.")

---
## Cell 15: Visualize Predictions

In [None]:
def visualize_predictions(model, dataset, device, num_samples=4, save_dir=None):
    """Visualize model predictions.

    Draws one row per randomly chosen sample. Each row has four panels:
    image1 ("Before"), image2 ("After"), the label ("Ground Truth"), and
    sigmoid(model(image1, image2)) ("Prediction"). The figure is saved as
    ``predictions.png`` and then shown.

    Args:
        model: Network called as ``model(img1, img2)`` on batched inputs.
        dataset: Indexable dataset whose items are dicts with 'image1',
            'image2' and 'label' tensors (images appear to be 3xHxW, given the
            3-channel de-normalization below).
        device: Device the model lives on.
        num_samples: Maximum number of rows to draw; capped at ``len(dataset)``.
        save_dir: Directory for ``predictions.png``; defaults to
            ``config.checkpoint_dir`` (notebook global).

    Returns:
        None. Returns early, without plotting, when the dataset is empty.
    """
    n_rows = min(num_samples, len(dataset))
    if n_rows <= 0:
        print("No samples to visualize.")
        return

    model.eval()

    indices = random.sample(range(len(dataset)), n_rows)

    # squeeze=False always yields a 2-D axes grid, so axes[i][j] is valid for
    # any n_rows (including 1) and no empty rows are drawn for small datasets.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

    # ImageNet per-channel mean/std, used to undo input normalization for display.
    # NOTE(review): assumes the dataset normalizes with these same values.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for i, idx in enumerate(indices):
        sample = dataset[idx]
        img1 = sample['image1'].unsqueeze(0).to(device)
        img2 = sample['image2'].unsqueeze(0).to(device)

        with torch.no_grad():
            pred = torch.sigmoid(model(img1, img2))

        # CHW -> HWC and clamp to [0, 1] so imshow renders the de-normalized RGB.
        img1_vis = np.clip((sample['image1'] * std + mean).numpy().transpose(1, 2, 0), 0, 1)
        img2_vis = np.clip((sample['image2'] * std + mean).numpy().transpose(1, 2, 0), 0, 1)
        label_vis = sample['label'].squeeze().numpy()
        pred_vis = pred.squeeze().cpu().numpy()

        axes[i][0].imshow(img1_vis)
        axes[i][0].set_title('Before')
        axes[i][0].axis('off')

        axes[i][1].imshow(img2_vis)
        axes[i][1].set_title('After')
        axes[i][1].axis('off')

        axes[i][2].imshow(label_vis, cmap='gray')
        axes[i][2].set_title('Ground Truth')
        axes[i][2].axis('off')

        axes[i][3].imshow(pred_vis, cmap='RdYlGn', vmin=0, vmax=1)
        axes[i][3].set_title('Prediction')
        axes[i][3].axis('off')

    plt.tight_layout()
    out_dir = config.checkpoint_dir if save_dir is None else save_dir
    plt.savefig(os.path.join(out_dir, 'predictions.png'), dpi=150)
    plt.show()

print("\nSample Predictions:")
# `val_ds` is a local variable of train(). A global of that name exists only if
# some other cell created one, so build the validation dataset explicitly here.
vis_dataset = ChangeDetectionDataset(config.data_root, config.val_list, 'val', config.patch_size)
visualize_predictions(model, vis_dataset, torch.device(config.device), num_samples=4)

---
## Cell 16: Save Final Model for Submission

In [None]:
# Export the best checkpoint with the model/training config needed to rebuild it.
best_model_path = model_manager.get_best_model_path()
if best_model_path:
    final_save_path = os.path.join(config.checkpoint_dir, 'final_model_isro.pth')

    # Load onto CPU so the exported file does not embed CUDA device placement
    # and can be reloaded on machines without a GPU.
    checkpoint = torch.load(best_model_path, map_location='cpu')

    # The reported best F1 was measured with EMA weights swapped in (see train()),
    # so export those weights when the checkpoint carries an EMA shadow.
    # NOTE(review): assumes the shadow maps state_dict parameter names to tensors.
    state_dict = checkpoint['model_state_dict']
    ema_shadow = checkpoint.get('ema_shadow')
    if isinstance(ema_shadow, dict):
        for name, tensor in ema_shadow.items():
            if name in state_dict:
                state_dict[name] = tensor

    torch.save({
        'model_state_dict': state_dict,
        'model_config': {
            'in_ch': 3,
            'num_classes': 1,
            'base_ch': config.base_channel,
            'use_attention': config.use_attention
        },
        'best_f1': checkpoint['f1'],
        'training_config': {
            'epochs': config.epochs,
            'batch_size': config.batch_size,
            'learning_rate': config.learning_rate,
            'loss': 'HybridLoss (BCE+Dice+Focal+Tversky)',
            'ema': config.use_ema,
            'tta': config.use_tta
        }
    }, final_save_path)

    print(f"\nFinal model saved: {final_save_path}")
    print(f"Best F1: {checkpoint['f1']:.4f}")
    print("\nFor ISRO submission, use this model file.")
else:
    print("\nNo best model found.")

---
## Cell 17: List Output Files

In [None]:
def _print_file_sizes(directory, names):
    """Print each named entry of ``directory`` with its size in megabytes."""
    for entry in names:
        size_mb = os.path.getsize(os.path.join(directory, entry)) / (1024*1024)
        print(f"{entry:<40} {size_mb:.2f} MB")


# Everything under the checkpoint directory.
print("\nOutput Files:")
print("="*50)
_print_file_sizes(config.checkpoint_dir, sorted(os.listdir(config.checkpoint_dir)))

# First ten saved models (alphabetical), plus a count of the rest.
print("\nModels Directory:")
print("="*50)
model_entries = sorted(os.listdir(config.models_dir))
_print_file_sizes(config.models_dir, model_entries[:10])
remaining = len(model_entries) - 10
if remaining > 0:
    print(f"... and {remaining} more models")

print("\nTraining complete! Download best_model.pth for deployment.")