# 05. Visualize Full Deforestation Maps - All 3 CNN Models

This notebook runs full-image inference for **all 3 CNN models** and visualizes a comparison of their outputs:

**Models:**
1. Spatial Context CNN
2. Multi-Scale CNN
3. Shallow U-Net

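**Outputs:**
- Probability maps for each model
- Binary classification maps for each model (probability > 0.5)
- Side-by-side comparison (3 models)
- Statistics comparison (deforestation %, area, mean probability)
- GeoTIFF exports of every map plus a plain-text summary (in `outputs/`)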

## 1. Setup

In [1]:
import sys
from pathlib import Path

# Make the repository root and its `src/` folder importable so the project
# modules (`models`, `preprocessing`) resolve when the notebook runs from its
# own directory. Each path is prepended only once; `src` ends up first.
project_root = Path.cwd().parent
src_path = project_root / 'src'

for extra_path in (project_root, src_path):
    extra_str = str(extra_path)
    if extra_str not in sys.path:
        sys.path.insert(0, extra_str)

print(f"Project root: {project_root}")
print(f"Source dir: {src_path}")

Project root: d:\HaiDang\25-26_HKI_DATN_21021411_DangNH
Source dir: d:\HaiDang\25-26_HKI_DATN_21021411_DangNH\src


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.plot import show
import pandas as pd
from matplotlib.colors import ListedColormap
from matplotlib.patches import Rectangle
import seaborn as sns

# Plot styling: start from matplotlib's stock rcParams, then layer seaborn's
# "husl" colour cycle on top. Keep this order, because style.use() resets the
# colour cycle.
plt.style.use('default')
sns.set_palette("husl")
print("Libraries imported successfully!")

Libraries imported successfully!


## 2. Run Full-Image Inference for All 3 CNN Models

This section runs inference on the full study area for **all 3 CNN models**:
- Spatial Context CNN
- Multi-Scale CNN
- Shallow U-Net

**Expected time:** ~4-6 minutes total on GPU (~1.5-2 min per model)

In [3]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from models import get_model
from preprocessing import normalize_band, handle_nan

print("Inference libraries imported!")

Inference libraries imported!


In [4]:
# Use the first CUDA GPU when one is visible to PyTorch; otherwise use the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

Using device: cuda
GPU: NVIDIA RTX A4000
GPU Memory: 17.17 GB


### 2.1. Load Full Image Stack (14 channels - S2 only)

In [5]:
# Paths to the two Sentinel-2 acquisitions that form the change pair.
s2_2024_path = project_root / 'data' / 'raw' / 'sentinel2' / 'S2_2024_01_30.tif'
s2_2025_path = project_root / 'data' / 'raw' / 'sentinel2' / 'S2_2025_02_28.tif'

# Each date must carry 7 bands (B, G, R, NIR, NDVI, NBR, NDMI). Together they
# give the 14-channel input the models are built with (`in_channels=14`).
BANDS_PER_DATE = 7

print("Loading TIFF files (S2 only - 14 channels)...")

# Load S2 2024. Its georeferencing is reused for every output GeoTIFF.
with rasterio.open(s2_2024_path) as src:
    s2_2024 = src.read()  # (bands, H, W)
    transform = src.transform
    crs = src.crs
    height, width = src.height, src.width

# Load S2 2025. Keep its georeferencing so the two grids can be compared.
with rasterio.open(s2_2025_path) as src:
    s2_2025 = src.read()  # (bands, H, W)
    transform_2025 = src.transform
    crs_2025 = src.crs

# Pixel-wise stacking only makes sense if both dates share one grid. A silent
# mismatch would pair pixels from different ground locations.
for label, stack in (('2024', s2_2024), ('2025', s2_2025)):
    if stack.shape[0] != BANDS_PER_DATE:
        raise ValueError(f"S2 {label}: expected {BANDS_PER_DATE} bands, got {stack.shape[0]}")
if s2_2025.shape[1:] != s2_2024.shape[1:]:
    raise ValueError(
        f"Raster sizes differ: 2024 {s2_2024.shape[1:]} vs 2025 {s2_2025.shape[1:]}"
    )
if crs_2025 != crs or not transform_2025.almost_equals(transform):
    raise ValueError("S2 2024 and S2 2025 are not on the same grid (transform/CRS differ)")

# Stack only Sentinel-2: (14, H, W) = 7 S2 (2024) + 7 S2 (2025)
all_bands = np.concatenate([s2_2024, s2_2025], axis=0)
# The per-date arrays are no longer needed. Each is half the size of the stack,
# so release them instead of keeping a second full copy in memory.
del s2_2024, s2_2025

# Channels-last view: (H, W, 14)
all_bands = np.transpose(all_bands, (1, 2, 0))

print(f"Loaded: {all_bands.shape} ({all_bands.dtype})")
print(f"Channels: {all_bands.shape[2]} (S2 only)")
print(f"Transform: {transform}")
print(f"CRS: {crs}")


Loading TIFF files (S2 only - 14 channels)...
Loaded: (10917, 12547, 14) (float32)
Channels: 14 (S2 only)
Transform: | 10.00, 0.00, 465450.00|
| 0.00,-10.00, 1055820.00|
| 0.00, 0.00, 1.00|
CRS: EPSG:32648


### 2.2. Normalize Bands

In [6]:
# Per-channel normalization, mirroring the preprocessing used during training.
# Channel layout of the 14-channel stack:
#   0-6  : S2 2024 -> B, G, R, NIR (reflectance), then NDVI, NBR, NDMI (indices)
#   7-13 : S2 2025 -> same order
# Reflectance channels are clipped to [0, 1]. Index channels are mapped
# linearly from [-1, 1] onto [0, 1] via (x + 1) / 2.
# NOTE(review): this rewrites `all_bands` in place. Running the cell twice
# without reloading the stack applies the index rescaling twice.
REFLECTANCE_CHANNELS = {0, 1, 2, 3, 7, 8, 9, 10}

print("Normalizing bands...")
for band_idx in tqdm(range(14), desc="Normalize", unit="band"):
    # Replace NaNs (via the project's `handle_nan`) before any scaling.
    if np.isnan(all_bands[:, :, band_idx]).any():
        all_bands[:, :, band_idx] = handle_nan(all_bands[:, :, band_idx], method='fill')

    if band_idx not in REFLECTANCE_CHANNELS:
        # Spectral index: [-1, 1] -> [0, 1]
        all_bands[:, :, band_idx] = (all_bands[:, :, band_idx] + 1) / 2
        continue

    # Surface reflectance: clip to [0, 1]
    all_bands[:, :, band_idx] = normalize_band(
        all_bands[:, :, band_idx], method='clip', clip_range=(0, 1)
    )

print("Normalization complete!")


Normalizing bands...


Normalize:   0%|          | 0/14 [00:00<?, ?band/s]

Normalization complete!


### 2.3. Load All 3 Models

In [None]:
# Model configurations. Each entry gives the checkpoint file, the architecture
# key passed to `get_model`, and the colour used for this model in the charts.
models_config = [
    {
        'name': 'spatial_cnn',
        'display_name': 'Spatial Context CNN',
        'checkpoint': 'spatial_context_cnn_best.pth',
        'model_type': 'spatial_context_cnn',
        'color': 'steelblue'
    },
    {
        'name': 'multiscale_cnn',
        'display_name': 'Multi-Scale CNN',
        'checkpoint': 'multiscale_cnn_best.pth',
        'model_type': 'multiscale_cnn',
        'color': 'coral'
    },
    {
        'name': 'shallow_unet',
        'display_name': 'Shallow U-Net',
        'checkpoint': 'shallow_unet_best.pth',
        'model_type': 'shallow_unet',
        'color': 'mediumseagreen'
    }
]

# Load every model whose checkpoint exists. Missing ones are skipped, and later
# cells only use the models present in `models`.
models = {}
print("Loading models...\n")

for config in models_config:
    model_path = project_root / 'checkpoints' / config['checkpoint']

    if not model_path.exists():
        print(f"⚠️  {config['display_name']}: Checkpoint not found, skipping...")
        continue

    print(f"Loading {config['display_name']}...")
    model = get_model(config['model_type'], in_channels=14)
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        # Training checkpoint: weights plus metadata.
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"  Epoch: {checkpoint.get('epoch', 'unknown')}")
        val_acc = checkpoint.get('val_acc')
        # A missing key means "not recorded", not 0% accuracy.
        if val_acc is None:
            print("  Val Accuracy: unknown")
        else:
            print(f"  Val Accuracy: {val_acc*100:.2f}%")
    else:
        # Presumably a bare state_dict.
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()  # switch dropout / batch-norm layers to inference behaviour
    models[config['name']] = model
    print(f"✓ {config['display_name']} loaded\n")

print(f"\nTotal models loaded: {len(models)}/{len(models_config)}")

if len(models) == 0:
    raise RuntimeError(
        'No model checkpoints found!\n'
        'Please train the models first by running notebook 03_train_models.ipynb'
    )

### 2.4. Run Sliding Window Inference for All Models

In [None]:
# Inference parameters
window_size = 128
stride = 64  # 50% overlap
batch_size = 32 if device.type == 'cuda' else 8
threshold = 0.5  # probability cut-off for the binary deforestation map

h, w, c = all_bands.shape


def _window_starts(length, window, step):
    """Return window offsets along one axis that cover ``[0, length)`` completely.

    Offsets ``0, step, 2*step, ...`` are used while a full window still fits.
    If the last of them stops short of ``length``, one extra window aligned to
    the far edge is appended. This way the edge pixels get predictions instead
    of zero coverage.

    Raises:
        ValueError: if ``length < window`` (no full window fits).
    """
    if length < window:
        raise ValueError(f"Image dimension {length} is smaller than window_size={window}")
    starts = list(range(0, length - window + 1, step))
    if starts[-1] + window < length:
        starts.append(length - window)
    return starts


row_starts = _window_starts(h, window_size, stride)
col_starts = _window_starts(w, window_size, stride)
n_rows, n_cols = len(row_starts), len(col_starts)
total_windows = n_rows * n_cols

print(f"Image size: {h} x {w}")
print(f"Total windows: {total_windows:,}")
print(f"Batch size: {batch_size}")
print()

# Extract all windows once; every model reuses them. torch.from_numpy shares
# memory with `all_bands`, and permute() is a view. Assuming `all_bands` is
# float32, .float() is a no-op, so the actual copy happens per batch in torch.stack.
print("Extracting windows...")
windows = []
positions = []

for y in tqdm(row_starts, desc="Extracting", unit="row"):
    for x in col_starts:
        patch = all_bands[y:y+window_size, x:x+window_size, :]  # (128, 128, 14)
        # (H, W, C) -> (C, H, W) channel-first layout for the models
        windows.append(torch.from_numpy(patch).permute(2, 0, 1).float())
        positions.append((y, x))

print(f"Extracted {len(windows):,} windows\n")

# Run inference for each model
prob_maps = {}
binary_maps = {}

n_batches = (len(windows) + batch_size - 1) // batch_size

for config in models_config:
    model_name = config['name']

    if model_name not in models:
        print(f"⚠️  Skipping {config['display_name']} (not loaded)\n")
        continue

    print("=" * 80)
    print(f"Running inference: {config['display_name']}")
    print("=" * 80)

    model = models[model_name]

    # Accumulators: summed probabilities, and the number of windows covering each pixel.
    prob_sum = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.int32)

    with torch.no_grad():
        for batch_idx in tqdm(range(n_batches), desc="Inference", unit="batch"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(windows))

            batch_patches = torch.stack(windows[start_idx:end_idx]).to(device)

            outputs = model(batch_patches)  # (B, 1, 128, 128) logits
            probs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()  # (B, 128, 128)

            for k, (y, x) in enumerate(positions[start_idx:end_idx]):
                prob_sum[y:y+window_size, x:x+window_size] += probs[k]
                count_map[y:y+window_size, x:x+window_size] += 1

    # Average overlapping predictions. `out=` is required: with `where=` alone,
    # np.divide leaves the masked-out entries uninitialised (arbitrary values).
    # The float32 `out` also keeps the map float32 instead of upcasting to float64.
    prob_map = np.divide(prob_sum, count_map, out=np.zeros_like(prob_sum), where=count_map > 0)

    # Create binary map
    binary_map = (prob_map > threshold).astype(np.uint8)

    # Store results
    prob_maps[model_name] = prob_map
    binary_maps[model_name] = binary_map

    print(f"\n✓ {config['display_name']} complete")
    print(f"  Probability range: [{prob_map.min():.4f}, {prob_map.max():.4f}]")
    print(f"  Deforestation: {binary_map.sum() / binary_map.size * 100:.2f}%\n")

print("\n" + "=" * 80)
print(f"ALL INFERENCE COMPLETED - {len(prob_maps)} models")
print("=" * 80)

### 2.5. Statistics Comparison

In [None]:
# Take the pixel footprint from the raster's geotransform (10 m x 10 m for this
# Sentinel-2 stack) so areas stay correct if the input resolution changes.
# The area is |det| of the affine's linear part: |a*e - b*d|.
pixel_size_m = abs(transform.a)
pixel_area_m2 = abs(transform.a * transform.e - transform.b * transform.d)

# One row of summary statistics per model that produced a map. 'Color' is
# carried along for the comparison charts.
stats_data = []

for config in models_config:
    model_name = config['name']
    if model_name not in prob_maps:
        continue

    prob_map = prob_maps[model_name]
    binary_map = binary_maps[model_name]

    total_pixels = prob_map.size
    defor_pixels = binary_map.sum()
    defor_percentage = (defor_pixels / total_pixels) * 100
    defor_area_km2 = defor_pixels * pixel_area_m2 / 1e6
    total_area_km2 = total_pixels * pixel_area_m2 / 1e6

    stats_data.append({
        'Model': config['display_name'],
        'Deforestation (%)': defor_percentage,
        'Area (km²)': defor_area_km2,
        'Total Area (km²)': total_area_km2,
        'Mean Probability': prob_map.mean(),
        'Color': config['color']
    })

# Display statistics
banner = "=" * 80
print(banner)
print("DEFORESTATION STATISTICS - ALL MODELS")
print(banner)
print()

for row in stats_data:
    print(f"{row['Model']}:")
    print(f"  Total area: {row['Total Area (km²)']:.2f} km²")
    print(f"  Deforestation area: {row['Area (km²)']:.2f} km²")
    print(f"  Deforestation: {row['Deforestation (%)']:.2f}%")
    print(f"  Mean probability: {row['Mean Probability']:.4f}")
    print()

print(banner)

### 2.6. Save All Results as GeoTIFF

In [None]:
# Write each model's maps as single-band, LZW-compressed GeoTIFFs on the same
# grid (transform/CRS) as the input Sentinel-2 stack, plus a text summary.
output_dir = project_root / 'outputs'
output_dir.mkdir(exist_ok=True)

# Raster profile shared by every output file; only the dtype differs.
geotiff_profile = dict(
    driver='GTiff',
    height=h,
    width=w,
    count=1,
    crs=crs,
    transform=transform,
    compress='lzw',
)

print("Saving results...\n")

for config in models_config:
    model_name = config['name']
    if model_name not in prob_maps:
        continue

    prob_path = output_dir / f'{model_name}_probability_map.tif'
    binary_path = output_dir / f'{model_name}_binary_map.tif'

    with rasterio.open(prob_path, 'w', dtype=rasterio.float32, **geotiff_profile) as dst:
        dst.write(prob_maps[model_name].astype(np.float32), 1)

    with rasterio.open(binary_path, 'w', dtype=rasterio.uint8, **geotiff_profile) as dst:
        dst.write(binary_maps[model_name], 1)

    print(f"{config['display_name']}:")
    print(f"  ✓ {prob_path.name}")
    print(f"  ✓ {binary_path.name}")

# Plain-text comparison summary built from `stats_data`.
summary_path = output_dir / '3_models_comparison_summary.txt'
divider = "=" * 80 + "\n"
with open(summary_path, 'w', encoding='utf-8') as f:
    f.write(divider)
    f.write("3 CNN MODELS COMPARISON SUMMARY\n")
    f.write(divider + "\n")

    for row in stats_data:
        f.write(
            f"{row['Model']}:\n"
            f"  Total area: {row['Total Area (km²)']:.2f} km²\n"
            f"  Deforestation area: {row['Area (km²)']:.2f} km²\n"
            f"  Deforestation percentage: {row['Deforestation (%)']:.2f}%\n"
            f"  Mean probability: {row['Mean Probability']:.4f}\n\n"
        )

print(f"\n✓ Summary: {summary_path.name}")
print("\nAll inference outputs saved!")

---
## 3. Compare Probability Maps (Side-by-Side)

Visualize probability maps from all 3 models side-by-side.

In [None]:
# Side-by-side probability maps, one panel per model that produced a map.
# Panels follow the available models rather than positions in `models_config`.
# Indexing `axes` by config position raises IndexError (or leaves an empty
# panel) as soon as one checkpoint is missing.
available_configs = [cfg for cfg in models_config if cfg['name'] in prob_maps]
n_models = len(available_configs)

# Defined unconditionally so later plotting cells can rely on it.
figures_dir = project_root / 'figures'
figures_dir.mkdir(exist_ok=True)

if n_models > 0:
    # squeeze=False always returns a 2-D array of Axes, even for one model.
    fig, axes = plt.subplots(1, n_models, figsize=(20, 7), squeeze=False)

    for ax, config in zip(axes[0], available_configs):
        im = ax.imshow(prob_maps[config['name']], cmap='RdYlGn_r', vmin=0, vmax=1)
        ax.set_title(f"{config['display_name']}\nProbability Map",
                     fontsize=13, fontweight='bold')
        ax.set_xlabel('X (pixels)', fontsize=10)
        ax.set_ylabel('Y (pixels)', fontsize=10)

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Probability', fontsize=10, rotation=270, labelpad=15)

    fig.suptitle(f'Deforestation Probability Maps - {n_models} CNN Models Comparison',
                 fontsize=16, fontweight='bold', y=0.98)
    fig.tight_layout()

    prob_fig_path = figures_dir / '3_models_probability_comparison.png'
    fig.savefig(prob_fig_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Saved: {prob_fig_path}")
else:
    print("No models available for comparison")

## 4. Compare Binary Maps (Side-by-Side)

In [None]:
# Side-by-side binary maps (0 = no deforestation, 1 = deforestation), one panel
# per model that produced a map. Iterating over the available models, rather
# than indexing `axes` by position in `models_config`, keeps panels aligned
# when a checkpoint was missing. The cell computes its own model count and
# output folder, so it does not depend on state from the previous cell.
binary_configs = [cfg for cfg in models_config if cfg['name'] in binary_maps]
n_binary = len(binary_configs)

figures_dir = project_root / 'figures'
figures_dir.mkdir(exist_ok=True)

if n_binary > 0:
    fig, axes = plt.subplots(1, n_binary, figsize=(20, 7), squeeze=False)

    # Two-colour map over [0, 1]: class 0 is green, class 1 is red.
    cmap = ListedColormap(['#2ecc71', '#e74c3c'])  # Green, Red

    for ax, config in zip(axes[0], binary_configs):
        im = ax.imshow(binary_maps[config['name']], cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"{config['display_name']}\nBinary Classification",
                     fontsize=13, fontweight='bold')
        ax.set_xlabel('X (pixels)', fontsize=10)
        ax.set_ylabel('Y (pixels)', fontsize=10)

        # Ticks at 0.25 / 0.75 sit in the middle of the two colour bands.
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[0.25, 0.75])
        cbar.ax.set_yticklabels(['No Deforestation', 'Deforestation'], fontsize=9)

    fig.suptitle(f'Binary Deforestation Maps - {n_binary} CNN Models Comparison',
                 fontsize=16, fontweight='bold', y=0.98)
    fig.tight_layout()

    binary_fig_path = figures_dir / '3_models_binary_comparison.png'
    fig.savefig(binary_fig_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Saved: {binary_fig_path}")

## 5. Statistics Comparison Bar Charts

In [None]:
def _annotated_bar(ax, labels, values, colors, ylabel, title, fmt):
    """Draw a per-model bar chart on `ax` with each value printed above its bar.

    Args:
        ax: matplotlib Axes to draw on.
        labels: x tick labels (model display names).
        values: bar heights, one per label (non-negative here: %, km², probability).
        colors: bar colours, one per label.
        ylabel: y-axis label.
        title: panel title.
        fmt: ``str.format`` pattern for the value labels, e.g. ``'{:.2f}%'``.
    """
    xs = list(range(len(labels)))
    ax.bar(xs, values, color=colors, edgecolor='black', alpha=0.8)
    ax.set_xticks(xs)
    ax.set_xticklabels(labels, rotation=15, ha='right')
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)

    # Offset the value labels relative to the tallest bar so the spacing adapts
    # to each metric's scale. Leave headroom so the labels stay inside the axes.
    top = max(values) if max(values) > 0 else 1.0
    for x, v in zip(xs, values):
        ax.text(x, v + 0.02 * top, fmt.format(v), ha='center', va='bottom',
                fontsize=10, fontweight='bold')
    ax.set_ylim(0, top * 1.15)


# Bar charts comparing the per-model statistics computed earlier.
if len(stats_data) > 0:
    figures_dir = project_root / 'figures'
    figures_dir.mkdir(exist_ok=True)

    model_names = [s['Model'] for s in stats_data]
    colors_list = [s['Color'] for s in stats_data]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    _annotated_bar(axes[0], model_names, [s['Deforestation (%)'] for s in stats_data],
                   colors_list, 'Deforestation (%)', 'Deforestation Percentage', '{:.2f}%')
    _annotated_bar(axes[1], model_names, [s['Area (km²)'] for s in stats_data],
                   colors_list, 'Area (km²)', 'Deforestation Area', '{:.1f}')
    _annotated_bar(axes[2], model_names, [s['Mean Probability'] for s in stats_data],
                   colors_list, 'Mean Probability', 'Average Probability', '{:.3f}')

    fig.suptitle(f'Statistics Comparison - {len(stats_data)} CNN Models',
                 fontsize=16, fontweight='bold', y=1.02)
    fig.tight_layout()

    stats_fig_path = figures_dir / '3_models_statistics_comparison.png'
    fig.savefig(stats_fig_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Saved: {stats_fig_path}")

## 6. Summary

This notebook has run full-image inference for **all 3 CNN models** and generated:

**Outputs:**
- Individual probability and binary maps for each model (GeoTIFF format)
- Side-by-side probability maps comparison
- Side-by-side binary maps comparison
- Statistics comparison bar charts
- Plain-text summary of per-model statistics (`outputs/3_models_comparison_summary.txt`)

**Models compared:**
1. Spatial Context CNN
2. Multi-Scale CNN
3. Shallow U-Net

**Next steps:**
- Run notebook `07_compare_all_models.ipynb` to compare with Random Forest
- Analyze model agreement and differences
- Use the generated maps for further analysis

All figures have been saved to the `figures/` directory. The GeoTIFF maps and the text summary are in `outputs/`.