# 🚀 Land Cover Change Detection - Complete Training Pipeline

## 📊 Features:
- ✅ Real-time GPU monitoring
- ✅ Live training dashboard
- ✅ Go/No-Go checkpoints
- ✅ Prediction visualization (planned: `visualize_interval` is not used by the cells below yet)
- ✅ Comprehensive metrics tracking

## 📦 Cell 1: Install Dependencies

In [None]:
# Install the third-party packages imported in the next cell (-q = quiet pip output).
!pip install -q tqdm matplotlib seaborn pandas scikit-learn

## 📚 Cell 2: Import Libraries

In [None]:
import os
import sys
import time
import warnings
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import random

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
from IPython.display import display, clear_output, HTML
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim
import torchvision.transforms.functional as TF

from sklearn.metrics import f1_score, jaccard_score, cohen_kappa_score, precision_score, recall_score

# Hide library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Report the runtime environment once.
_cuda_available = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {_cuda_available}")
if _cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## ⚙️ Cell 3: Configuration

In [None]:
@dataclass
class TrainingConfig:
    """All hyperparameters, paths and thresholds for one training run.

    Every field has a default, so ``TrainingConfig()`` is usable as-is;
    override individual fields by keyword. The ``device`` default is
    evaluated once, when this class body executes.
    """
    # --- Data (edit these for your Kaggle dataset) ---
    # Directory that the relative paths inside the split lists are joined onto
    data_root: str = "/kaggle/input/dataset1new"
    # Split lists: one sample per line, "<image_t1> <image_t2> <label>" (whitespace-separated)
    train_list: str = "/kaggle/working/train_list.txt"
    val_list: str = "/kaggle/working/val_list.txt"
    # NOTE(review): test_list is not referenced anywhere in this notebook
    test_list: str = "/kaggle/working/test_list.txt"
    # Passed to ChangeDetectionDataset, which stores it but does not crop or resize
    patch_size: int = 256
    
    # --- Optimisation ---
    batch_size: int = 16
    num_workers: int = 4           # DataLoader worker processes
    pin_memory: bool = True
    epochs: int = 100              # also the cosine-annealing period (T_max)
    learning_rate: float = 3e-4    # AdamW initial learning rate
    weight_decay: float = 1e-4     # AdamW weight decay
    
    # --- Model ---
    base_channel: int = 32         # SNUNet base width C
    use_attention: bool = True     # CBAM on the SNUNet decoder blocks
    
    # --- Loss: bce_weight * (focal or BCE term) + dice_weight * Dice ---
    bce_weight: float = 0.7        # weight of the focal term when use_focal=True
    dice_weight: float = 0.3
    focal_gamma: float = 2.0
    use_focal: bool = True         # False -> plain BCE-with-logits term
    
    # --- Early stopping, checkpoints & logging ---
    patience: int = 15             # early-stopping patience, in epochs
    checkpoint_dir: str = "/kaggle/working/checkpoints"   # best_model.pth, history.csv, dashboard.png
    save_every: int = 5            # periodic-checkpoint interval, in epochs
    log_interval: int = 10         # batches between GPU polls / progress-bar updates
    # NOTE(review): visualize_interval is not referenced in this notebook
    visualize_interval: int = 5
    
    # --- Go/No-Go thresholds (read by check_go_nogo) ---
    epoch_3_f1_threshold: float = 0.60    # val F1 below this at epoch 3 -> stop
    epoch_10_f1_threshold: float = 0.70   # val F1 below this at epochs 10 and 15 -> warning only
    overfitting_patience: int = 5         # stop when the last N val losses are strictly increasing
    
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
# Instantiate the defaults and create the checkpoint folder up front.
config = TrainingConfig()
os.makedirs(config.checkpoint_dir, exist_ok=True)

print("\n".join([
    "‚úÖ Configuration loaded",
    f"üì± Device: {config.device}",
    f"üìÇ Data Root: {config.data_root}",
]))

## 🔧 Cell 4: Utility Classes

In [None]:
class AverageMeter:
    """Keep the most recent value and a sample-weighted running mean.

    ``history`` stores every raw value passed to ``update`` (unweighted).
    """
    def __init__(self, name: str = ""):
        self.name = name
        self.reset()

    def reset(self):
        """Zero all statistics and forget the history; ``name`` is kept."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.history = []

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.history.append(val)


class GPUMonitor:
    """Best-effort GPU telemetry sampled from ``nvidia-smi``.

    Only the first GPU listed by ``nvidia-smi`` is reported. Any failure
    (binary missing, timeout, non-zero exit, unparseable output such as
    ``[N/A]``) yields all-zero stats and records nothing in the histories,
    so monitoring can never interrupt training.
    """
    _FIELDS = ('gpu_util', 'memory_used', 'memory_total', 'temperature')

    def __init__(self):
        self.utilization_history = []  # 'gpu_util' of every successful sample
        self.memory_history = []       # 'memory_used' of every successful sample

    @classmethod
    def _parse_stats(cls, stdout: str) -> Optional[Dict]:
        """Parse ``--format=csv,noheader,nounits`` output for the first GPU.

        nvidia-smi prints one line per GPU. Splitting the whole multi-line
        output would fuse the last field of one GPU with the first of the
        next, so only the first non-empty line is used.

        Returns:
            Dict with the four ``_FIELDS`` as floats, or None if unparseable.
        """
        lines = [line for line in stdout.splitlines() if line.strip()]
        if not lines:
            return None
        fields = [field.strip() for field in lines[0].split(',')]
        if len(fields) != len(cls._FIELDS):
            return None
        try:
            return dict(zip(cls._FIELDS, map(float, fields)))
        except ValueError:  # e.g. "[N/A]" or "[Not Supported]"
            return None

    def get_gpu_stats(self) -> Dict:
        """Sample the GPU once; append to the histories on success."""
        try:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu',
                 '--format=csv,noheader,nounits'],
                capture_output=True, text=True, timeout=5
            )
        except (OSError, subprocess.SubprocessError):
            # nvidia-smi absent (CPU runtime) or hung: monitoring is optional.
            result = None

        stats = None
        if result is not None and result.returncode == 0:
            stats = self._parse_stats(result.stdout)
        if stats is None:
            return {'gpu_util': 0, 'memory_used': 0, 'memory_total': 0, 'temperature': 0}

        self.utilization_history.append(stats['gpu_util'])
        self.memory_history.append(stats['memory_used'])
        return stats

    def get_avg_utilization(self) -> float:
        """Mean of all recorded utilisation samples (0 if none were recorded)."""
        return np.mean(self.utilization_history) if self.utilization_history else 0


class TrainingLogger:
    """Collect per-epoch metrics and timings for one training run.

    ``metrics_history`` maps column name -> list with one entry per logged
    epoch, so it converts directly into a DataFrame.
    """
    def __init__(self, config):
        self.config = config
        self.metrics_history = defaultdict(list)
        self.epoch_times = []
        self.gpu_monitor = GPUMonitor()

    def log_epoch(self, epoch, train_metrics, val_metrics, epoch_time, lr):
        """Append one epoch of results.

        ``train_metrics`` needs 'loss'; ``val_metrics`` needs 'loss', 'f1',
        'iou', 'precision' and 'recall'. ``loss_gap`` is the absolute
        difference between the train and validation losses.
        """
        self.epoch_times.append(epoch_time)
        row = {
            'epoch': epoch,
            'train_loss': train_metrics['loss'],
            'val_loss': val_metrics['loss'],
            'val_f1': val_metrics['f1'],
            'val_iou': val_metrics['iou'],
            'val_precision': val_metrics['precision'],
            'val_recall': val_metrics['recall'],
            'lr': lr,
            'epoch_time': epoch_time,
            'loss_gap': abs(train_metrics['loss'] - val_metrics['loss']),
        }
        for column, value in row.items():
            self.metrics_history[column].append(value)

    def get_dataframe(self):
        """Return the logged history as a DataFrame (one row per epoch)."""
        return pd.DataFrame(self.metrics_history)

    def estimate_remaining_time(self, current_epoch, total_epochs):
        """ETA string ('H:MM:SS') from the mean epoch duration so far."""
        if self.epoch_times:
            seconds_left = (total_epochs - current_epoch) * np.mean(self.epoch_times)
            return str(timedelta(seconds=int(seconds_left)))
        return "Calculating..."

print("‚úÖ Utility classes loaded")

## 📦 Cell 5: Dataset

In [None]:
class ChangeDetectionDataset(Dataset):
    """Bi-temporal image pairs and change masks listed in a text file.

    Each line of ``list_path`` with at least three whitespace-separated
    fields is a sample ``<image_t1> <image_t2> <label>`` (paths relative to
    ``root_dir``); other lines are skipped.

    Args:
        root_dir: directory the listed paths are joined onto.
        list_path: split file (train / val / test).
        mode: ``'train'`` enables random flips and 90-degree rotations.
        patch_size: stored only; images are NOT cropped or resized here,
            so presumably they are already patch_size x patch_size.
    """
    def __init__(self, root_dir, list_path, mode='train', patch_size=256):
        self.root_dir = root_dir
        self.mode = mode
        self.patch_size = patch_size

        with open(list_path, 'r', encoding='utf-8') as f:
            self.files = [parts for parts in (line.split() for line in f) if len(parts) >= 3]

        # Per-channel normalisation (the usual ImageNet mean/std values)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        print(f"üìÇ Loaded {len(self.files)} samples for {mode}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``{'image1', 'image2', 'label', 'name'}`` for sample ``idx``.

        Images are normalised float tensors; ``label`` is a float tensor with
        values exactly 0.0 or 1.0.
        """
        rel_img1, rel_img2, rel_label = self.files[idx][:3]
        img1 = self._load(rel_img1, 'RGB')
        img2 = self._load(rel_img2, 'RGB')
        label = self._load(rel_label, 'L')

        if self.mode == 'train':
            img1, img2, label = self._augment(img1, img2, label)

        img1 = TF.normalize(TF.to_tensor(img1), self.mean, self.std)
        img2 = TF.normalize(TF.to_tensor(img2), self.mean, self.std)
        # to_tensor divides 8-bit masks by 255: 0/255 masks become 0/1, but
        # masks stored as 0/1 would become 0 and 1/255 (all "no change").
        # Binarising makes both encodings yield exact {0, 1} targets.
        label = self._binarize(TF.to_tensor(label))

        return {'image1': img1, 'image2': img2, 'label': label, 'name': rel_img1}

    def _load(self, rel_path, mode):
        """Open ``root_dir/rel_path``, convert it to ``mode`` and close the file."""
        with Image.open(os.path.join(self.root_dir, rel_path)) as image:
            return image.convert(mode)

    @staticmethod
    def _binarize(mask):
        """Map every non-zero mask value to 1.0 and zero to 0.0.

        Assumes lossless masks: any non-zero pixel counts as change.
        """
        return (mask > 0).float()

    def _augment(self, img1, img2, label):
        """Apply the same random flips / 90-degree rotation to all three images."""
        if random.random() > 0.5:
            img1, img2, label = TF.hflip(img1), TF.hflip(img2), TF.hflip(label)
        if random.random() > 0.5:
            img1, img2, label = TF.vflip(img1), TF.vflip(img2), TF.vflip(label)
        if random.random() > 0.5:
            angle = random.choice([90, 180, 270])
            img1, img2, label = TF.rotate(img1, angle), TF.rotate(img2, angle), TF.rotate(label, angle)
        return img1, img2, label

print("‚úÖ Dataset class loaded")

## 🧠 Cell 6: SNUNet Model with CBAM

In [None]:
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate: per-channel weights in (0, 1), shape (N, C, 1, 1).

    Args:
        in_planes: number of input channels C.
        ratio: reduction factor of the shared bottleneck MLP. The hidden width
            is clamped to at least 1; with ``in_planes < ratio`` it would
            otherwise be 0 and the gate would collapse.
    """
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared two-layer MLP implemented with 1x1 convolutions
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same MLP scores the average- and the max-pooled channel descriptors
        return self.sigmoid(self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x)))


class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate: a one-channel map in (0, 1) per pixel."""
    def __init__(self, kernel_size=7):
        super().__init__()
        # Two descriptor maps (channel mean, channel max) -> one gate map;
        # padding of kernel_size // 2 keeps H and W for odd kernel sizes.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptors = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv(descriptors))


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then spatial gate.

    As in Woo et al. (2018), the channel-refined features
    ``F' = Mc(F) * F`` are gated by the spatial map computed from them:
    ``F'' = Ms(F') * F'``. The output has the same shape as the input.
    """
    def __init__(self, in_planes):
        super().__init__()
        self.ca = ChannelAttention(in_planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        # Multiply the channel-refined features (not the raw input) by the
        # spatial gate, so both attentions shape the returned tensor.
        refined = x * self.ca(x)
        return refined * self.sa(refined)


class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages, optionally followed by CBAM.

    Spatial size is preserved (padding=1); channels go ``in_ch -> out_ch``.
    """
    def __init__(self, in_ch, out_ch, use_cbam=False):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)
        self.cbam = CBAM(out_ch) if use_cbam else None

    def forward(self, x):
        features = self.conv(x)
        if self.cbam is None:
            return features
        return self.cbam(features)


class SNUNet(nn.Module):
    """Siamese nested-UNet (UNet++-style) change-detection network.

    Both input images go through the *same* encoder modules (shared
    weights). A dense decoder fuses the two feature pyramids, and a 1x1 conv
    produces ``num_classes`` raw logit maps at input resolution.

    Args:
        in_ch: channels of each input image.
        num_classes: output channels (no activation applied).
        C: base width; encoder level k has ``C * 2**k`` channels.
        use_attn: attach CBAM to every decoder node except ``conv0_4``.

    Four 2x poolings followed by 2x upsampling and skip concatenation mean
    H and W must be divisible by 16 for the feature maps to line up.
    """
    def __init__(self, in_ch=3, num_classes=1, C=32, use_attn=True):
        super().__init__()
        # Shared encoder: width doubles after each of the four poolings
        self.conv0_0 = ConvBlock(in_ch, C)
        self.pool1 = nn.MaxPool2d(2)
        self.conv1_0 = ConvBlock(C, C*2)
        self.pool2 = nn.MaxPool2d(2)
        self.conv2_0 = ConvBlock(C*2, C*4)
        self.pool3 = nn.MaxPool2d(2)
        self.conv3_0 = ConvBlock(C*4, C*8)
        self.pool4 = nn.MaxPool2d(2)
        self.conv4_0 = ConvBlock(C*8, C*16)

        # Nested decoder. Node i_j sits at level i, column j; its input width is
        # the sum of the two encoder skips, earlier nodes in the row and the
        # upsampled node from the level below (see forward).
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv0_1 = ConvBlock(C*2 + C*4, C, use_cbam=use_attn)
        self.conv1_1 = ConvBlock(C*4 + C*8, C*2, use_cbam=use_attn)
        self.conv2_1 = ConvBlock(C*8 + C*16, C*4, use_cbam=use_attn)
        self.conv3_1 = ConvBlock(C*16 + C*32, C*8, use_cbam=use_attn)
        self.conv0_2 = ConvBlock(C*2 + C*2 + C, C, use_cbam=use_attn)
        self.conv1_2 = ConvBlock(C*4 + C*4 + C*2, C*2, use_cbam=use_attn)
        self.conv2_2 = ConvBlock(C*8 + C*8 + C*4, C*4, use_cbam=use_attn)
        self.conv0_3 = ConvBlock(C*2 + C*2 + C + C, C, use_cbam=use_attn)
        self.conv1_3 = ConvBlock(C*4 + C*4 + C*2 + C*2, C*2, use_cbam=use_attn)
        self.conv0_4 = ConvBlock(C*2 + C*2 + C + C + C, C)
        self.final = nn.Conv2d(C, num_classes, 1)

    def _encode(self, x):
        """Run one image through the shared encoder; return levels 0..4."""
        e0 = self.conv0_0(x)
        e1 = self.conv1_0(self.pool1(e0))
        e2 = self.conv2_0(self.pool2(e1))
        e3 = self.conv3_0(self.pool3(e2))
        e4 = self.conv4_0(self.pool4(e3))
        return e0, e1, e2, e3, e4

    def forward(self, x1, x2):
        # Image 1 is encoded completely before image 2 (BatchNorm update order).
        a0, a1, a2, a3, a4 = self._encode(x1)
        b0, b1, b2, b3, b4 = self._encode(x2)
        up = self.up

        # Column 1: both skips at level i + both upsampled level-(i+1) features
        x0_1 = self.conv0_1(torch.cat([a0, b0, up(a1), up(b1)], 1))
        x1_1 = self.conv1_1(torch.cat([a1, b1, up(a2), up(b2)], 1))
        x2_1 = self.conv2_1(torch.cat([a2, b2, up(a3), up(b3)], 1))
        x3_1 = self.conv3_1(torch.cat([a3, b3, up(a4), up(b4)], 1))
        # Later columns: skips + previous nodes in the row + upsampled node below
        x0_2 = self.conv0_2(torch.cat([a0, b0, x0_1, up(x1_1)], 1))
        x1_2 = self.conv1_2(torch.cat([a1, b1, x1_1, up(x2_1)], 1))
        x2_2 = self.conv2_2(torch.cat([a2, b2, x2_1, up(x3_1)], 1))
        x0_3 = self.conv0_3(torch.cat([a0, b0, x0_1, x0_2, up(x1_2)], 1))
        x1_3 = self.conv1_3(torch.cat([a1, b1, x1_1, x1_2, up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([a0, b0, x0_1, x0_2, x0_3, up(x1_3)], 1))
        return self.final(x0_4)

print("‚úÖ SNUNet model loaded")

## 📉 Cell 7: Loss Functions

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities, pooled over the whole batch.

    Args:
        smooth: additive smoothing in numerator and denominator; keeps the
            ratio defined and stabilises the gradient for near-empty masks.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return ``1 - Dice(sigmoid(inputs), targets)`` as a scalar tensor.

        ``inputs`` are raw logits and ``targets`` must have the same number of
        elements. ``reshape`` (rather than ``view``) accepts non-contiguous
        tensors such as permuted or channels_last outputs.
        """
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)
        intersection = (probs * targets).sum()
        return 1 - (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)


class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over all elements.

    Each element's BCE is scaled by ``alpha * (1 - p_t) ** gamma`` with
    ``p_t = exp(-BCE)``, which for 0/1 targets is the probability the model
    assigns to the true class.

    NOTE(review): ``alpha`` multiplies every element equally, so it only
    rescales the loss. It does not weight positives against negatives as in
    the alpha-balanced formulation of Lin et al. (2017).
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = torch.exp(-per_element_bce)
        focal = self.alpha * (1 - p_t) ** self.gamma * per_element_bce
        return focal.mean()


class HybridLoss(nn.Module):
    """Weighted sum of a pixel-wise term and a Dice term.

    The pixel-wise term is FocalLoss when ``use_focal`` is true, otherwise
    plain BCE-with-logits; ``bce_weight`` weights it in either case.
    Inputs are raw logits.
    """
    def __init__(self, bce_weight=0.7, dice_weight=0.3, use_focal=True, gamma=2.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.dice = DiceLoss()
        self.focal = FocalLoss(gamma=gamma) if use_focal else None
        self.use_focal = use_focal

    def forward(self, inputs, targets):
        if self.use_focal:
            pixel_term = self.focal(inputs, targets)
        else:
            pixel_term = F.binary_cross_entropy_with_logits(inputs, targets)
        dice_term = self.dice(inputs, targets)
        return self.bce_weight * pixel_term + self.dice_weight * dice_term

print("‚úÖ Loss functions loaded")

## 📈 Cell 8: Visualization Dashboard

In [None]:
def display_dashboard(logger, epoch, total_epochs, gpu_stats, checkpoint_dir):
    """Redraw the live training dashboard and save it as ``dashboard.png``.

    Clears the cell output, then plots the history held by ``logger`` (losses,
    F1/IoU, precision/recall, learning rate, train/val loss gap) next to text
    panels for GPU stats, run status and the Go/No-Go checks. If no epoch has
    been logged yet, only the output is cleared.

    Args:
        logger: TrainingLogger. Its ``config`` (when present) supplies the
            Go/No-Go F1 thresholds so the plot agrees with ``check_go_nogo``;
            otherwise 0.60 / 0.70 are used.
        epoch: current 1-based epoch.
        total_epochs: planned number of epochs.
        gpu_stats: dict with 'gpu_util', 'memory_used', 'memory_total' and
            'temperature'.
        checkpoint_dir: directory the PNG is written to.
    """
    clear_output(wait=True)
    df = logger.get_dataframe()
    if len(df) == 0:
        return

    cfg = getattr(logger, 'config', None)
    f1_target_3 = getattr(cfg, 'epoch_3_f1_threshold', 0.60)
    f1_target_10 = getattr(cfg, 'epoch_10_f1_threshold', 0.70)
    f1_target_final = 0.85
    gap_threshold = 0.3

    # Scope the dark theme to this figure instead of changing the global style.
    with plt.style.context('dark_background'):
        fig = plt.figure(figsize=(16, 10))
        fig.suptitle(f'üöÄ Training Dashboard - Epoch {epoch}/{total_epochs}', fontsize=16, fontweight='bold')
        gs = fig.add_gridspec(3, 4, hspace=0.35, wspace=0.3)

        # Loss curves (shaded area = train/val difference)
        ax1 = fig.add_subplot(gs[0, :2])
        ax1.plot(df['epoch'], df['train_loss'], 'b-', label='Train', lw=2)
        ax1.plot(df['epoch'], df['val_loss'], 'r-', label='Val', lw=2)
        ax1.fill_between(df['epoch'], df['train_loss'], df['val_loss'], alpha=0.2, color='yellow')
        ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.set_title('üìâ Loss Curves')
        ax1.legend(); ax1.grid(True, alpha=0.3)

        # F1 & IoU against the Go/No-Go targets
        ax2 = fig.add_subplot(gs[0, 2:])
        ax2.plot(df['epoch'], df['val_f1'], 'g-', label='F1', lw=2, marker='o', ms=4)
        ax2.plot(df['epoch'], df['val_iou'], 'm-', label='IoU', lw=2, marker='s', ms=4)
        ax2.axhline(y=f1_target_3, color='yellow', ls='--', alpha=0.5, label='Epoch3 Target')
        ax2.axhline(y=f1_target_10, color='orange', ls='--', alpha=0.5, label='Epoch10 Target')
        ax2.axhline(y=f1_target_final, color='lime', ls='--', alpha=0.5, label='Final Target')
        ax2.set_xlabel('Epoch'); ax2.set_ylabel('Score'); ax2.set_title('üìä F1 & IoU')
        ax2.legend(loc='lower right'); ax2.grid(True, alpha=0.3); ax2.set_ylim(0, 1)

        # Precision & Recall
        ax3 = fig.add_subplot(gs[1, :2])
        ax3.plot(df['epoch'], df['val_precision'], 'c-', label='Precision', lw=2)
        ax3.plot(df['epoch'], df['val_recall'], 'y-', label='Recall', lw=2)
        ax3.set_xlabel('Epoch'); ax3.set_ylabel('Score'); ax3.set_title('üéØ Precision & Recall')
        ax3.legend(); ax3.grid(True, alpha=0.3); ax3.set_ylim(0, 1)

        # Learning rate
        ax4 = fig.add_subplot(gs[1, 2])
        ax4.plot(df['epoch'], df['lr'], 'orange', lw=2)
        ax4.set_xlabel('Epoch'); ax4.set_ylabel('LR'); ax4.set_title('üìà Learning Rate')
        ax4.grid(True, alpha=0.3)

        # Train/val loss gap; bars at or above the threshold are drawn red
        ax5 = fig.add_subplot(gs[1, 3])
        colors = ['green' if g < gap_threshold else 'red' for g in df['loss_gap']]
        ax5.bar(df['epoch'], df['loss_gap'], color=colors, alpha=0.7)
        ax5.axhline(y=gap_threshold, color='red', ls='--', label='Overfit Threshold')
        ax5.set_xlabel('Epoch'); ax5.set_ylabel('Gap'); ax5.set_title('‚ö†Ô∏è Loss Gap')
        ax5.legend(); ax5.grid(True, alpha=0.3)

        # GPU stats panel
        ax6 = fig.add_subplot(gs[2, 0])
        ax6.axis('off')
        gpu_text = f"üñ•Ô∏è GPU MONITOR\n{'‚îÅ'*16}\nUtil: {gpu_stats['gpu_util']:.1f}%\nMem: {gpu_stats['memory_used']:.0f}/{gpu_stats['memory_total']:.0f}MB\nTemp: {gpu_stats['temperature']:.0f}¬∞C"
        ax6.text(0.1, 0.5, gpu_text, fontsize=11, family='monospace', va='center', color='cyan')

        # Status panel, coloured by the best F1 so far against the targets
        ax7 = fig.add_subplot(gs[2, 1])
        ax7.axis('off')
        best_f1 = df['val_f1'].max()
        if epoch >= 3 and best_f1 < f1_target_3:
            status_color = 'red'
        elif epoch >= 10 and best_f1 < f1_target_10:
            status_color = 'orange'
        else:
            status_color = 'lime'
        remaining = logger.estimate_remaining_time(epoch, total_epochs)
        status_text = f"üìä STATUS\n{'‚îÅ'*16}\nEpoch: {epoch}/{total_epochs}\nBest F1: {best_f1:.4f}\nBest IoU: {df['val_iou'].max():.4f}\nETA: {remaining}"
        ax7.text(0.1, 0.5, status_text, fontsize=11, family='monospace', va='center', color=status_color)

        # Go/No-Go checkpoint panel
        ax8 = fig.add_subplot(gs[2, 2:])
        ax8.axis('off')
        checks = []
        if epoch >= 3:
            f1_3 = df[df['epoch'] <= 3]['val_f1'].max()
            checks.append(f"Epoch 3 (F1>={f1_target_3:.2f}): {'‚úÖ' if f1_3 >= f1_target_3 else '‚ùå'} ({f1_3:.3f})")
        if epoch >= 10:
            f1_10 = df[df['epoch'] <= 10]['val_f1'].max()
            checks.append(f"Epoch 10 (F1>={f1_target_10:.2f}): {'‚úÖ' if f1_10 >= f1_target_10 else '‚ùå'} ({f1_10:.3f})")
        if checks:
            check_text = "üìã GO/NO-GO CHECKPOINTS\n" + "‚îÅ" * 24 + "\n" + "\n".join(checks)
        else:
            check_text = "üìã Checkpoints pending..."
        ax8.text(0.1, 0.5, check_text, fontsize=10, family='monospace', va='center', color='white')

        plt.tight_layout()
        fig.savefig(os.path.join(checkpoint_dir, 'dashboard.png'), dpi=100, facecolor='black')
        plt.show()
    # Release the figure; otherwise non-inline backends keep one open per epoch.
    plt.close(fig)

print("‚úÖ Visualization loaded")

## 🏋️ Cell 9: Training & Validation Functions

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, logger, log_interval=10):
    """Run one optimisation pass over ``loader``.

    Gradients are clipped to a global norm of 1.0 before each step. Every
    ``log_interval`` batches the GPU is polled through ``logger.gpu_monitor``
    and the progress bar shows the running loss and utilisation.

    Returns:
        ``{'loss': mean training loss weighted by batch size}``.
    """
    model.train()
    loss_meter = AverageMeter()
    progress = tqdm(loader, desc='Training', leave=False)

    for step, batch in enumerate(progress):
        image_t1, image_t2, target = (
            batch[key].to(device, non_blocking=True) for key in ('image1', 'image2', 'label')
        )

        optimizer.zero_grad()
        loss = criterion(model(image_t1, image_t2), target)
        loss.backward()
        # Bound the global gradient norm so a single batch cannot cause a huge update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_meter.update(loss.item(), image_t1.size(0))
        if step % log_interval != 0:
            continue
        gpu = logger.gpu_monitor.get_gpu_stats()
        progress.set_postfix({'Loss': f'{loss_meter.avg:.4f}', 'GPU': f'{gpu["gpu_util"]:.0f}%'})

    return {'loss': loss_meter.avg}


def validate(model, loader, criterion, device):
    model.eval()
    losses = AverageMeter()
    all_preds, all_targets = [], []
    
    with torch.no_grad():
        for batch in tqdm(loader, desc='Validating', leave=False):
            img1 = batch['image1'].to(device, non_blocking=True)
            img2 = batch['image2'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)
            
            output = model(img1, img2)
            losses.update(criterion(output, label).item(), img1.size(0))
            all_preds.append((torch.sigmoid(output) > 0.5).cpu())
            all_targets.append(label.cpu())
    
    preds = torch.cat(all_preds).numpy().flatten().astype(int)
    targets = torch.cat(all_targets).numpy().flatten().astype(int)
    
    return {
        'loss': losses.avg,
        'f1': f1_score(targets, preds, zero_division=0),
        'iou': jaccard_score(targets, preds, zero_division=0),
        'precision': precision_score(targets, preds, zero_division=0),
        'recall': recall_score(targets, preds, zero_division=0),
        'kappa': cohen_kappa_score(targets, preds) if len(np.unique(preds)) > 1 else 0
    }


def check_go_nogo(epoch, val_f1, config, val_loss_history):
    """Decide whether training should continue after ``epoch``.

    Hard stops are evaluated before warnings, so a warning epoch (10 or 15)
    cannot mask an overfitting stop.

    Args:
        epoch: 1-based epoch just completed.
        val_f1: validation F1 of that epoch.
        config: object with ``epoch_3_f1_threshold``, ``epoch_10_f1_threshold``
            and ``overfitting_patience``.
        val_loss_history: validation losses of all epochs so far, oldest first.

    Returns:
        ``(should_continue, message)``.
    """
    # Stop: F1 too low at epoch 3.
    if epoch == 3 and val_f1 < config.epoch_3_f1_threshold:
        return False, f"üö® STOP: Epoch 3 F1 ({val_f1:.3f}) < {config.epoch_3_f1_threshold}"

    # Stop: the last `patience` validation losses are strictly increasing.
    # "Increasing" needs at least two values (with fewer it is vacuously
    # true), so patience < 2 disables this check.
    patience = config.overfitting_patience
    if patience >= 2 and len(val_loss_history) >= patience:
        recent = val_loss_history[-patience:]
        if all(later > earlier for earlier, later in zip(recent, recent[1:])):
            return False, "üõë STOP: Overfitting detected!"

    # Warn but continue: F1 below target at epochs 10 and 15.
    if epoch in (10, 15) and val_f1 < config.epoch_10_f1_threshold:
        return True, f"‚ö†Ô∏è WARNING: Epoch {epoch} F1 ({val_f1:.3f}) < {config.epoch_10_f1_threshold}"
    return True, "‚úÖ Training OK"

print("‚úÖ Training functions loaded")

## 🚀 Cell 10: Main Training Loop

In [None]:
def _build_loaders(config):
    """Create the train/val datasets and DataLoaders described by ``config``."""
    train_ds = ChangeDetectionDataset(config.data_root, config.train_list, 'train', config.patch_size)
    val_ds = ChangeDetectionDataset(config.data_root, config.val_list, 'val', config.patch_size)

    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                              num_workers=config.num_workers, pin_memory=config.pin_memory, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False,
                            num_workers=config.num_workers, pin_memory=config.pin_memory)

    print(f"üìä Train: {len(train_ds)} samples | Val: {len(val_ds)} samples")
    return train_loader, val_loader


def _sanity_check(model, loader, criterion, device, logger):
    """Push one training batch through the model and report loss and GPU load.

    Raises:
        RuntimeError: if the loader yields no batch (with drop_last=True this
            happens when the split has fewer samples than batch_size).
    """
    print("\nüî¨ PIPELINE SANITY CHECK")
    try:
        batch = next(iter(loader))
    except StopIteration:
        raise RuntimeError(
            "Training loader is empty: check train_list and batch_size (drop_last=True)"
        ) from None
    # The model is still in train mode here, so BatchNorm sees this batch.
    with torch.no_grad():
        loss = criterion(model(batch['image1'].to(device), batch['image2'].to(device)), batch['label'].to(device))
    gpu = logger.gpu_monitor.get_gpu_stats()
    print(f"üìä First batch loss: {loss.item():.4f} | GPU: {gpu['gpu_util']:.1f}%")


def _save_resume_checkpoint(path, epoch, model, optimizer, scheduler, best_f1):
    """Write model, optimizer and scheduler state so training can be resumed."""
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_f1': best_f1,
    }, path)


def train(config):
    """Train SNUNet end-to-end with the settings in ``config``.

    Each epoch runs one training pass, validation and a scheduler step, then:
    - refreshes the dashboard;
    - saves ``best_model.pth`` whenever val F1 improves;
    - rewrites ``history.csv``;
    - every ``config.save_every`` epochs, writes ``last_checkpoint.pth`` (model,
      optimizer and scheduler state) for resuming.

    Training then stops on a Go/No-Go stop or after ``config.patience``
    epochs without F1 improvement. ``save_every``/``patience`` values <= 0
    disable those features.

    Returns:
        ``(model, logger)``: the model after the last completed epoch (not
        necessarily the best one) and the TrainingLogger with all metrics.
    """
    print("="*70)
    print("üöÄ LAND COVER CHANGE DETECTION TRAINING")
    print("="*70)

    device = torch.device(config.device)
    train_loader, val_loader = _build_loaders(config)

    # Model, loss, optimiser, cosine LR schedule over all epochs
    model = SNUNet(3, 1, config.base_channel, config.use_attention).to(device)
    print(f"üìä Parameters: {sum(p.numel() for p in model.parameters()):,}")

    criterion = HybridLoss(config.bce_weight, config.dice_weight, config.use_focal, config.focal_gamma)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs, eta_min=1e-6)

    logger = TrainingLogger(config)
    _sanity_check(model, train_loader, criterion, device, logger)

    history_path = os.path.join(config.checkpoint_dir, 'history.csv')
    best_f1 = 0
    val_loss_history = []
    patience = getattr(config, 'patience', 0) or 0
    save_every = getattr(config, 'save_every', 0) or 0
    epochs_since_best = 0

    print("\nüèãÔ∏è STARTING TRAINING")
    for epoch in range(1, config.epochs + 1):
        epoch_start = time.time()
        lr = optimizer.param_groups[0]['lr']

        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, device, logger, config.log_interval)
        val_metrics = validate(model, val_loader, criterion, device)
        scheduler.step()

        epoch_time = time.time() - epoch_start
        logger.log_epoch(epoch, train_metrics, val_metrics, epoch_time, lr)
        val_loss_history.append(val_metrics['loss'])

        gpu = logger.gpu_monitor.get_gpu_stats()
        display_dashboard(logger, epoch, config.epochs, gpu, config.checkpoint_dir)

        # Save the best model BEFORE any stop decision, so a stopping epoch
        # that is also the best epoch still gets checkpointed.
        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            epochs_since_best = 0
            torch.save({'epoch': epoch, 'model': model.state_dict(), 'best_f1': best_f1},
                       os.path.join(config.checkpoint_dir, 'best_model.pth'))
            print(f"üíæ New best! F1: {best_f1:.4f}")
        else:
            epochs_since_best += 1

        # Persist the history every epoch so an interrupted session keeps it.
        logger.get_dataframe().to_csv(history_path, index=False)
        if save_every > 0 and epoch % save_every == 0:
            _save_resume_checkpoint(os.path.join(config.checkpoint_dir, 'last_checkpoint.pth'),
                                    epoch, model, optimizer, scheduler, best_f1)

        print(f"üìä E{epoch} | Loss: {train_metrics['loss']:.4f}/{val_metrics['loss']:.4f} | F1: {val_metrics['f1']:.4f} | IoU: {val_metrics['iou']:.4f} | {epoch_time:.1f}s")

        # Go/No-Go check
        should_continue, msg = check_go_nogo(epoch, val_metrics['f1'], config, val_loss_history)
        print(msg)
        if not should_continue:
            print("üõë Training stopped!")
            break

        # Early stopping on validation F1
        if patience > 0 and epochs_since_best >= patience:
            print(f"Early stopping: val F1 has not improved for {patience} epochs")
            break

    logger.get_dataframe().to_csv(history_path, index=False)
    print(f"\nüèÜ Training Complete! Best F1: {best_f1:.4f}")
    return model, logger

print("‚úÖ Main training function loaded")

## ▶️ Cell 11: Run Training

In [None]:
# 🚀 START TRAINING
# Runs train() with the module-level `config`. The returned model holds the
# weights after the last completed epoch; the best-F1 weights are saved
# separately as best_model.pth in config.checkpoint_dir.
model, logger = train(config)

## 📊 Cell 12: Final Analysis

In [None]:
# Reload the per-epoch history that train() wrote to disk and summarise it.
df = pd.read_csv(os.path.join(config.checkpoint_dir, 'history.csv'))
print("üìä Training Summary:")
print(df.describe())

# Report the epoch with the highest validation F1 and its metrics.
best_idx = df['val_f1'].idxmax()
print(f"\nüèÜ Best Epoch: {df.loc[best_idx, 'epoch']}")
for _metric_name, _metric_col in (("F1", 'val_f1'), ("IoU", 'val_iou'),
                                  ("Precision", 'val_precision'), ("Recall", 'val_recall')):
    print(f"   {_metric_name}: {df.loc[best_idx, _metric_col]:.4f}")