# HFANet Training Notebook
**High-Frequency Attention Network for Building Change Detection**

This notebook trains, validates, and tests the HFANet variant built on a `segmentation-models-pytorch` (smp) encoder backbone.

---

## 1. Setup & Imports

In [None]:
import sys
sys.path.append('..')  # Add root directory to path

import torch
import torch.nn as nn
import torch.optim as optim
import os
from datetime import datetime

# Project imports
from data.dataset import get_dataloader
from models.HFANet.hfanet import HFANet
from utils.losses import DiceLoss, FocalLoss, SoftIoULoss
from utils.training import train_one_epoch, CombinedLoss
from utils.evaluation import validate, evaluate_on_loader

print("All imports successful!")

## 2. Configuration

In [None]:
# ============================================================
# CONFIGURATION - Modify these parameters as needed
# ============================================================

CONFIG = {
    # Reproducibility
    'seed': 42,                    # Seed passed to torch.manual_seed below

    # Data
    'data_dir': '../dataset',      # Path to dataset root
    'img_size': 256,               # Input image size
    'batch_size': 8,               # Batch size
    'num_workers': 4,              # DataLoader workers
    
    # Model
    'model_name': 'hfanet',        # Model identifier
    'backbone': 'resnet34',        # Backbone encoder
    'pretrained': 'imagenet',      # Pretrained weights
    'classes': 1,                  # Output classes (binary)
    
    # Training
    'epochs': 100,                 # Max epochs
    'lr': 1e-4,                    # Learning rate
    'weight_decay': 1e-4,          # Weight decay
    'patience': 15,                # Early stopping patience
    'threshold': 0.5,              # Binarization threshold
    
    # Loss
    'loss': 'bce+dice',            # Loss function
    
    # Scheduler
    'scheduler': 'plateau',        # LR scheduler type
    
    # Checkpoints
    'checkpoint_dir': '../checkpoints',
}

# Seed torch's random number generators so weight initialisation and
# torch-driven shuffling repeat across "Restart & Run All".
# NOTE(review): this does not by itself make cuDNN kernels deterministic, and
# any other RNG used inside the project's dataset code is not seeded here.
torch.manual_seed(CONFIG['seed'])

# Device setup: prefer the GPU when one is visible, otherwise run on CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 3. Load Data

In [None]:
# Build the train and validation loaders from one shared set of options
print("Loading datasets...")

loader_options = {
    'batch_size': CONFIG['batch_size'],
    'img_size': CONFIG['img_size'],
    'num_workers': CONFIG['num_workers'],
}

train_loader = get_dataloader(CONFIG['data_dir'], split='train', **loader_options)
val_loader = get_dataloader(CONFIG['data_dir'], split='val', **loader_options)

# Report dataset sizes first, then the number of batches per epoch
loader_by_split = (('Train', train_loader), ('Val', val_loader))
for split_name, split_loader in loader_by_split:
    print(f"{split_name} samples: {len(split_loader.dataset)}")
for split_name, split_loader in loader_by_split:
    print(f"{split_name} batches: {len(split_loader)}")

## 4. Visualize Sample Data

In [None]:
import matplotlib.pyplot as plt

def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization so an image tensor can be displayed.

    Computes ``tensor * std + mean`` channel-wise and clamps the result to
    [0, 1]. The input tensor is never modified.

    Args:
        tensor: Image tensor laid out as (C, H, W) or batched as
            (N, C, H, W); the channel axis is the third from the end.
        mean: Per-channel means, at least C entries (extras are ignored).
        std: Per-channel standard deviations, at least C entries.

    Returns:
        A new tensor with the same shape as ``tensor`` and values in [0, 1].

    Raises:
        ValueError: If ``mean`` or ``std`` has fewer entries than channels.
    """
    channels = tensor.shape[-3]
    mean, std = list(mean)[:channels], list(std)[:channels]
    if len(mean) < channels or len(std) < channels:
        raise ValueError(
            f"need {channels} mean/std values, got {len(mean)} and {len(std)}"
        )

    # Integer inputs are promoted to float so the statistics are not truncated
    stat_dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    # Shape (C, 1, 1) broadcasts over both (C, H, W) and (N, C, H, W)
    mean_t = torch.as_tensor(mean, dtype=stat_dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=stat_dtype, device=tensor.device).view(-1, 1, 1)
    return (tensor * std_t + mean_t).clamp(0, 1)

# Pull one batch from the training loader for a visual sanity check
sample_batch = next(iter(train_loader))

# Turn the first sample into arrays matplotlib can draw; permute(1, 2, 0)
# reorders the image axes for imshow (presumably CHW -> HWC, verify against the dataset)
img_A = denormalize(sample_batch['image_A'][0]).permute(1, 2, 0).numpy()
img_B = denormalize(sample_batch['image_B'][0]).permute(1, 2, 0).numpy()
label = sample_batch['label'][0].squeeze().numpy()

# One row, three panels: before image, after image, change mask
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
sample_panels = (
    (img_A, 'Time 1 (Before)', None),
    (img_B, 'Time 2 (After)', None),
    (label, 'Change Mask', 'gray'),
)
for panel_ax, (panel_image, panel_title, panel_cmap) in zip(axes, sample_panels):
    panel_ax.imshow(panel_image, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

# Shapes of the full batch tensors, not just the first sample
for shape_name, batch_key in (('Image', 'image_A'), ('Label', 'label')):
    print(f"{shape_name} shape: {sample_batch[batch_key].shape}")

## 5. Create Model

In [None]:
# Build HFANet from the configured backbone, then move it to the selected device
model = HFANet(
    encoder_name=CONFIG['backbone'],
    classes=CONFIG['classes'],
    pretrained=CONFIG['pretrained'],
)
model = model.to(device)

print("Model: HFANet")
print(f"Backbone: {CONFIG['backbone']}")

# Tally total and trainable parameter counts in a single pass
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 6. Test Forward Pass

In [None]:
# Smoke-test the network on the sample batch before committing to training
model.eval()

img_A = sample_batch['image_A'].to(device)
img_B = sample_batch['image_B'].to(device)

# No gradients are needed for a shape/range check
with torch.no_grad():
    output = model(img_A, img_B)

output_min = output.min().item()
output_max = output.max().item()

print(f"Input shape: {img_A.shape}")
print(f"Output shape: {output.shape}")
print(f"Output range: [{output_min:.3f}, {output_max:.3f}]")

print("\n✓ Forward pass successful!")

## 7. Setup Loss, Optimizer, Scheduler

In [None]:
# Loss function, selected by CONFIG['loss'].
if CONFIG['loss'] == 'bce+dice':
    criterion = CombinedLoss([
        (nn.BCEWithLogitsLoss(), 1.0),
        (DiceLoss(), 1.0),
    ])
elif CONFIG['loss'] == 'bce+focal':
    criterion = CombinedLoss([
        (nn.BCEWithLogitsLoss(), 1.0),
        (FocalLoss(), 1.0),
    ])
elif CONFIG['loss'] == 'dice':
    criterion = DiceLoss()
elif CONFIG['loss'] == 'focal':
    criterion = FocalLoss()
else:
    # Plain BCE is both the explicit 'bce' choice and the fallback. A typo in
    # CONFIG['loss'] would otherwise train with BCE while the log line below
    # still echoes the misspelled name, so say so out loud.
    if CONFIG['loss'] != 'bce':
        print(f"WARNING: unrecognized loss '{CONFIG['loss']}', falling back to BCEWithLogitsLoss")
    criterion = nn.BCEWithLogitsLoss()

print(f"Loss function: {CONFIG['loss']}")

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay']
)
print(f"Optimizer: AdamW (lr={CONFIG['lr']}, weight_decay={CONFIG['weight_decay']})")

# Scheduler, selected by CONFIG['scheduler'].
if CONFIG['scheduler'] == 'plateau':
    # mode='max' because the training loop steps this scheduler with the
    # validation IoU, where higher is better. The `verbose` flag is not passed:
    # it is deprecated in current PyTorch; read the learning rate from
    # optimizer.param_groups[0]['lr'] instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
elif CONFIG['scheduler'] == 'cosine':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CONFIG['epochs'], eta_min=1e-6
    )
else:
    # Any other value disables scheduling; warn unless that was clearly intended
    if CONFIG['scheduler'] not in (None, 'none'):
        print(f"WARNING: unrecognized scheduler '{CONFIG['scheduler']}', training with a constant LR")
    scheduler = None

print(f"Scheduler: {CONFIG['scheduler']}")

## 8. Training Loop

In [None]:
# Create checkpoint directory and fix the path the best weights are written to
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)
save_path = os.path.join(CONFIG['checkpoint_dir'], f"best_{CONFIG['model_name']}.pth")

# Training state.
# Start from -inf rather than 0.0 so the first epoch always writes a
# checkpoint: with a 0.0 baseline, a run whose val IoU never rises above zero
# would save nothing and the evaluation cell would fail or load a stale file
# left behind by an earlier run.
best_val_iou = float('-inf')
patience_counter = 0
history = {
    'train_loss': [], 'train_iou': [], 'train_f1': [],
    'val_loss': [], 'val_iou': [], 'val_f1': []
}

print("="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Epochs: {CONFIG['epochs']}")
print(f"Early stopping patience: {CONFIG['patience']}")
print("="*60 + "\n")

for epoch in range(1, CONFIG['epochs'] + 1):
    # Train one epoch
    train_loss, train_iou, train_f1 = train_one_epoch(
        model=model,
        dataloader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        epoch=epoch,
        threshold=CONFIG['threshold']
    )
    
    # Validate
    val_loss, val_iou, val_f1 = validate(
        model=model,
        dataloader=val_loader,
        criterion=criterion,
        device=device,
        threshold=CONFIG['threshold']
    )
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_iou'].append(train_iou)
    history['train_f1'].append(train_f1)
    history['val_loss'].append(val_loss)
    history['val_iou'].append(val_iou)
    history['val_f1'].append(val_f1)
    
    # Print epoch summary
    print(f"\nEpoch [{epoch}/{CONFIG['epochs']}] Summary:")
    print(f"  Train - Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | F1: {train_f1:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | F1: {val_f1:.4f}")
    
    # Update scheduler: the plateau scheduler needs the monitored metric,
    # every other scheduler steps once per epoch without arguments
    if scheduler is not None:
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_iou)
        else:
            scheduler.step()
    
    # Report the learning rate explicitly so changes made by the scheduler are
    # visible without relying on the scheduler's own logging
    print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")
    
    # Save best model
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        patience_counter = 0
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Stored as plain Python floats so the checkpoint holds only
            # primitive types and tensors, whatever type validate() returned
            'val_iou': float(val_iou),
            'val_f1': float(val_f1),
            'config': CONFIG,
        }, save_path)
        print(f"  ✓ New best model saved! (Val IoU: {best_val_iou:.4f})")
    else:
        patience_counter += 1
        print(f"  No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
    
    # Early stopping
    if patience_counter >= CONFIG['patience']:
        print(f"\nEarly stopping triggered after {epoch} epochs.")
        break
    
    print()

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"Best Val IoU: {best_val_iou:.4f}")

## 9. Plot Training History

In [None]:
# One panel per tracked metric, each comparing the train and validation curves
epochs_range = range(1, len(history['train_loss']) + 1)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (history key suffix, y-axis label, panel title)
metric_panels = (
    ('loss', 'Loss', 'Loss'),
    ('iou', 'IoU', 'IoU'),
    ('f1', 'F1', 'F1 Score'),
)
for metric_ax, (metric_key, metric_ylabel, metric_title) in zip(axes, metric_panels):
    metric_ax.plot(epochs_range, history[f'train_{metric_key}'], label='Train')
    metric_ax.plot(epochs_range, history[f'val_{metric_key}'], label='Val')
    metric_ax.set_xlabel('Epoch')
    metric_ax.set_ylabel(metric_ylabel)
    metric_ax.set_title(metric_title)
    metric_ax.legend()
    metric_ax.grid(True, alpha=0.3)

plt.tight_layout()

# Keep a copy of the figure next to the checkpoints
history_plot_path = os.path.join(
    CONFIG['checkpoint_dir'],
    f"{CONFIG['model_name']}_training_history.png",
)
plt.savefig(history_plot_path, dpi=300)
plt.show()

## 10. Evaluate on Test Set

In [None]:
# Load best model
best_model_path = os.path.join(CONFIG['checkpoint_dir'], f"best_{CONFIG['model_name']}.pth")
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True;
# confirm every value stored in the checkpoint dict is loadable in that mode.
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']}")
print(f"Val IoU at save: {checkpoint['val_iou']:.4f}")

# Create test loader (batch size 1, so each batch holds a single sample)
test_loader = get_dataloader(
    CONFIG['data_dir'],
    batch_size=1,
    split='test',
    img_size=CONFIG['img_size'],
    num_workers=CONFIG['num_workers']
)
print(f"\nTest samples: {len(test_loader.dataset)}")

# Evaluate
results_dir = os.path.join(CONFIG['checkpoint_dir'], f"{CONFIG['model_name']}_results")
# Create the results directory up front: nothing else in this notebook creates
# it, and a later cell saves a figure directly into it.
os.makedirs(results_dir, exist_ok=True)
metrics = evaluate_on_loader(
    model=model,
    dataloader=test_loader,
    device=device,
    threshold=CONFIG['threshold'],
    save_results_path=os.path.join(results_dir, 'predictions'),
    save_pr_curve_path=os.path.join(results_dir, 'pr_curve.png')
)

## 11. Visualize Predictions

In [None]:
# Visualize some predictions: one row per test sample, with columns for the
# two input images, the ground-truth mask, and the thresholded prediction
model.eval()
num_samples = 4

# squeeze=False keeps `axes` two-dimensional even when num_samples == 1,
# so the axes[i, j] indexing below is always valid
fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples), squeeze=False)

# zip() stops at the shorter input, so a test set with fewer than num_samples
# batches no longer raises StopIteration
rows_drawn = 0
for i, batch in zip(range(num_samples), test_loader):
    img_A = batch['image_A'].to(device)
    img_B = batch['image_B'].to(device)
    label = batch['label']
    
    with torch.no_grad():
        output = model(img_A, img_B)
        # The model output is passed through a sigmoid and then binarized
        pred = torch.sigmoid(output).cpu()
        pred_binary = (pred > CONFIG['threshold']).float()
    
    # Denormalize images
    img_A_vis = denormalize(batch['image_A'][0]).permute(1, 2, 0).numpy()
    img_B_vis = denormalize(batch['image_B'][0]).permute(1, 2, 0).numpy()
    label_vis = label[0].squeeze().numpy()
    pred_vis = pred_binary[0].squeeze().numpy()
    
    axes[i, 0].imshow(img_A_vis)
    axes[i, 0].set_title('Time 1')
    axes[i, 0].axis('off')
    
    axes[i, 1].imshow(img_B_vis)
    axes[i, 1].set_title('Time 2')
    axes[i, 1].axis('off')
    
    axes[i, 2].imshow(label_vis, cmap='gray')
    axes[i, 2].set_title('Ground Truth')
    axes[i, 2].axis('off')
    
    axes[i, 3].imshow(pred_vis, cmap='gray')
    axes[i, 3].set_title('Prediction')
    axes[i, 3].axis('off')
    
    rows_drawn = i + 1

# Blank out any rows left unused because the test set ran out of samples
for unused_ax in axes[rows_drawn:].ravel():
    unused_ax.axis('off')

plt.tight_layout()
# Make sure the target directory exists before writing the figure
os.makedirs(results_dir, exist_ok=True)
plt.savefig(os.path.join(results_dir, 'sample_predictions.png'), dpi=300)
plt.show()

## 12. Summary

In [None]:
# Final report: run settings, best validation IoU, then the test-set metrics
summary_rule = "=" * 60

print(summary_rule)
print("EXPERIMENT SUMMARY")
print(summary_rule)

run_settings = (
    ("Model", "HFANet (smp)"),
    ("Backbone", CONFIG['backbone']),
    ("Loss", CONFIG['loss']),
)
for setting_name, setting_value in run_settings:
    print(f"{setting_name}: {setting_value}")
print(f"Best Val IoU: {best_val_iou:.4f}")

print("\nTest Results:")
# Prefixes carry the exact spacing used in the printed report
test_metric_rows = (
    ("  IoU: ", 'iou'),
    ("  F1:  ", 'f1'),
    ("  Precision: ", 'precision'),
    ("  Recall: ", 'recall'),
)
for row_prefix, metric_key in test_metric_rows:
    print(f"{row_prefix}{metrics[metric_key]:.4f}")

print(summary_rule)