# Test Model Setup

## 1.1 Create Model Wrapper (SMP)

**Objective:** Instantiate segmentation models through a lightweight wrapper around `segmentation_models_pytorch` (SMP), verify output shapes and parameter counts, and confirm that a single training step works.

**What we test:**
- Build U-Net and DeepLabV3+ models with `resnet34`/`resnet50` encoders.
- Run a forward pass on a dummy input and print the `(B, C, H, W)` output shape and the parameter count.
- Run one training step with `CrossEntropyLoss` to validate the logits and gradients.

Notes:
- Use `activation=None` for training (logits); apply softmax during evaluation as needed.
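- If you hit memory limits, lower the dummy input size (e.g., 256×256), keeping height and width multiples of 32, since the ResNet encoders downsample by a total factor of 32.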
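Because the models emit raw logits (`activation=None`), softmax is only needed when you want probabilities; hard labels come straight from an argmax over the class axis, since softmax preserves the per-pixel ordering of scores. A minimal NumPy sketch with random stand-in logits (the `softmax` helper is illustrative, not part of the project code):

```python
import numpy as np

def softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Numerically stable softmax along the class axis (axis=1 for (B, C, H, W))."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 7, 4, 4))  # stand-in for raw model output, (B, C, H, W)

probs = softmax(logits, axis=1)             # per-pixel class probabilities
pred_from_probs = probs.argmax(axis=1)      # (B, H, W) class IDs
pred_from_logits = logits.argmax(axis=1)    # same IDs, no softmax needed

print(np.allclose(probs.sum(axis=1), 1.0), (pred_from_probs == pred_from_logits).all())
```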

In [2]:
import sys, os
import torch

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.models.segmentation_model import create_model, count_parameters

print("Testing model architectures...\n")

models_to_test = [
    ('unet', 'resnet34'),
    ('unet', 'resnet50'),
    ('deeplabv3plus', 'resnet34'),
    ('deeplabv3plus', 'resnet50'),
]

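for arch, encoder in models_to_test:
    # encoder_weights='imagenet' downloads pretrained encoder weights on first use
    model = create_model(
        architecture=arch,
        encoder=encoder,
        num_classes=7,
        encoder_weights='imagenet'
    )
    model.eval()  # inference mode: BatchNorm uses running stats, dropout is disabled

    # Forward pass on a dummy batch of two RGB 512x512 images
    dummy_input = torch.randn(2, 3, 512, 512)
    with torch.no_grad():
        output = model(dummy_input)  # expected shape: (B, num_classes, H, W)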

    num_params = count_parameters(model)

    print(f"{arch:15s} + {encoder:12s}:")
    print(f"  Input shape:  {tuple(dummy_input.shape)}")
    print(f"  Output shape: {tuple(output.shape)}")
    print(f"  Parameters:   {num_params:,}")
    print()

print("☑️ All models working correctly!")


Testing model architectures...

unet            + resnet34    :
  Input shape:  (2, 3, 512, 512)
  Output shape: (2, 7, 512, 512)
  Parameters:   24,437,239

unet            + resnet50    :
  Input shape:  (2, 3, 512, 512)
  Output shape: (2, 7, 512, 512)
  Parameters:   32,521,975

deeplabv3plus   + resnet34    :
  Input shape:  (2, 3, 512, 512)
  Output shape: (2, 7, 512, 512)
  Parameters:   22,438,999

deeplabv3plus   + resnet50    :
  Input shape:  (2, 3, 512, 512)
  Output shape: (2, 7, 512, 512)
  Parameters:   26,679,127

☑️ All models working correctly!
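
The cell above only runs forward passes. The single training step listed under "What we test" could look like the PyTorch sketch below. The tiny `nn.Sequential` network is a hypothetical stand-in for `create_model(...)` so the example runs without downloading encoder weights; the `AdamW` optimizer and `ignore_index=6` (skipping the unknown class, as in the loss section) are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 7

# Hypothetical stand-in for create_model(...): maps (B, 3, H, W) images
# to (B, num_classes, H, W) logits, like the SMP models above.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, num_classes, kernel_size=1),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=6)  # assumed: skip 'unknown' pixels

images = torch.randn(2, 3, 64, 64)
masks = torch.randint(0, num_classes, (2, 64, 64))  # int64 class IDs, (B, H, W)

model.train()
optimizer.zero_grad()
logits = model(images)           # raw logits, (B, C, H, W)
loss = criterion(logits, masks)  # CrossEntropyLoss expects logits, not probabilities
loss.backward()

# Every trainable parameter should have received a finite gradient.
grads_ok = all(
    p.grad is not None and torch.isfinite(p.grad).all()
    for p in model.parameters()
)
optimizer.step()

print(f"logits: {tuple(logits.shape)}, loss: {loss.item():.4f}, grads ok: {grads_ok}")
```

Swapping the stand-in for a real `create_model(...)` instance should leave the loop unchanged, since `nn.CrossEntropyLoss` only needs `(B, C, H, W)` logits and `(B, H, W)` integer targets.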


## 1.2 Set Up Loss Functions

**Objective:** Configure and validate loss functions for semantic segmentation with class imbalance.

**Included losses:**
- **CombinedLoss:** Cross-Entropy + Dice (multiclass, with `ignore_index=6`)
- **FocalLoss:** Down-weights easy examples; optional per-class `alpha`
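To make the two loss families concrete, here is a NumPy sketch of one common formulation of each: multiclass focal loss, $-\alpha_t (1 - p_t)^\gamma \log p_t$ averaged over non-ignored pixels, and soft Dice averaged over the six valid classes. These helpers are illustrative assumptions, not the project's `CombinedLoss`/`FocalLoss`; note that `alpha` here multiplies each pixel's loss directly, so its overall scale changes the loss value.

```python
import numpy as np

IGNORE_INDEX = 6  # 'unknown' class, as in the list above

def softmax(logits, axis=1):
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def focal_loss(logits, targets, alpha=None, gamma=2.0, ignore_index=IGNORE_INDEX):
    """Mean over valid pixels of -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    probs = softmax(logits, axis=1)                       # (B, C, H, W)
    valid = targets != ignore_index                       # (B, H, W)
    safe_t = np.where(valid, targets, 0)                  # placeholder index for ignored pixels
    p_t = np.take_along_axis(probs, safe_t[:, None], axis=1)[:, 0]  # (B, H, W)
    a_t = np.ones_like(p_t) if alpha is None else np.asarray(alpha)[safe_t]
    loss = -a_t * (1.0 - p_t) ** gamma * np.log(np.clip(p_t, 1e-12, 1.0))
    return loss[valid].mean()

def soft_dice_loss(logits, targets, num_classes, ignore_index=IGNORE_INDEX, eps=1e-6):
    """1 - mean soft Dice over classes; ignored pixels and the ignored class are excluded."""
    probs = softmax(logits, axis=1)
    valid = (targets != ignore_index)[:, None]                     # (B, 1, H, W)
    one_hot = np.eye(num_classes)[targets].transpose(0, 3, 1, 2)  # (B, C, H, W)
    probs, one_hot = probs * valid, one_hot * valid
    inter = (probs * one_hot).sum(axis=(0, 2, 3))
    denom = probs.sum(axis=(0, 2, 3)) + one_hot.sum(axis=(0, 2, 3))
    dice = (2 * inter + eps) / (denom + eps)
    keep = [c for c in range(num_classes) if c != ignore_index]
    return 1.0 - dice[keep].mean()

rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 7, 8, 8))
targets = rng.integers(0, 7, size=(2, 8, 8))

ce_like = focal_loss(logits, targets, gamma=0.0)  # gamma=0 and no alpha: plain cross-entropy
fl = focal_loss(logits, targets, gamma=2.0)       # easy pixels (high p_t) are down-weighted
dice = soft_dice_loss(logits, targets, num_classes=7)
print(f"CE: {ce_like:.4f}  focal: {fl:.4f}  dice: {dice:.4f}  CE+Dice: {ce_like + dice:.4f}")
```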

**What we do:**
- Estimate per-class pixel counts from a subset of tiles (the smoke test below uses fixed mock counts).
- Compute inverse-frequency class weights (with smoothing, ignoring the unknown class, ID 6).
- Run a forward/backward step with `CombinedLoss` and evaluate `FocalLoss`.
- Save a class-weight visualization to `outputs/figures/class_weights.png`.
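If `compute_class_weights` returns raw inverse counts, every weight is at most 1/100,000 = 0.00001 and rounds to 0.0000 at four decimals, which would match the output of the cell below. A minimal NumPy sketch of normalized inverse-frequency weights, assuming frequencies are taken over the six valid classes, additive smoothing `eps`, and rescaling to mean 1; the helper name is hypothetical and the project's smoothing may differ.

```python
import numpy as np

def inverse_frequency_weights(counts, ignore_index=6, eps=1e-6):
    """Hypothetical helper: inverse-frequency class weights rescaled to mean 1.

    Frequencies are computed over the valid classes only; the ignored class
    gets weight 0. `eps` is additive smoothing that keeps rare classes finite.
    """
    counts = np.asarray(counts, dtype=np.float64)
    valid = np.arange(len(counts)) != ignore_index
    freq = counts / counts[valid].sum()
    weights = np.zeros_like(freq)
    weights[valid] = 1.0 / (freq[valid] + eps)
    weights[valid] /= weights[valid].mean()   # mean weight of 1 over valid classes
    return weights

# Same mock pixel counts as the cell below (class 6 = unknown).
class_counts = [1_000_000, 5_000_000, 800_000, 2_000_000, 500_000, 300_000, 100_000]
weights = inverse_frequency_weights(class_counts)
for cls, wt in enumerate(weights):
    print(f"  Class {cls}: {wt:.4f}")
```

With these counts the rarest valid class (class 5, 300,000 pixels) receives the largest weight and the most common (class 1, 5,000,000 pixels) the smallest, while class 6 is zeroed out.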

In [3]:
import sys, os
import torch

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.training.losses import compute_class_weights, CombinedLoss, FocalLoss

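# Mock data: random logits and random class-ID targets
batch_size = 4
num_classes = 7
h, w = 512, 512

preds = torch.randn(batch_size, num_classes, h, w, requires_grad=True)
targets = torch.randint(0, num_classes, (batch_size, h, w))

# Test class weights computation (mock per-class pixel counts; class 6 = unknown)
class_counts = torch.tensor([1000000, 5000000, 800000, 2000000, 500000, 300000, 100000])
weights = compute_class_weights(class_counts)
# If these print as 0.0000, the weights are likely unnormalized (e.g. raw 1/count).
# nn.CrossEntropyLoss(weight=...) with reduction='mean' is scale-invariant, but a focal
# loss that multiplies each pixel by alpha is not, which would explain Focal Loss = 0.0000.
print("Class weights:")
for i, wt in enumerate(weights):  # `wt`, not `w`, so the image width above is not shadowed
    print(f"  Class {i}: {wt:.4f}")
print()

# Test losses
print("Testing loss functions...\n")

# Combined Loss: forward and backward to confirm gradients reach the logits
combined_loss = CombinedLoss(class_weights=weights)
loss_val = combined_loss(preds, targets)
loss_val.backward()
assert preds.grad is not None and torch.isfinite(preds.grad).all()
print(f"Combined Loss: {loss_val.item():.4f}")

# Focal Loss
focal_loss = FocalLoss(alpha=weights, gamma=2.0)
loss_val = focal_loss(preds, targets)
print(f"Focal Loss: {loss_val.item():.4f}")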

print("\n☑️ Loss functions working correctly!")

Class weights:
  Class 0: 0.0000
  Class 1: 0.0000
  Class 2: 0.0000
  Class 3: 0.0000
  Class 4: 0.0000
  Class 5: 0.0000
  Class 6: 0.0000

Testing loss functions...

Combined Loss: 3.0683
Focal Loss: 0.0000

☑️ Loss functions working correctly!


## 1.3 Set Up Metrics

**Objective:** Configure and validate segmentation metrics that follow the DeepGlobe protocol.

**Metrics included:**
- **mIoU:** mean IoU excluding the `unknown` class (ID=6)
- **Per-class IoU:** `urban, agriculture, rangeland, forest, water, barren`
- **Overall accuracy:** pixel accuracy (micro-averaged over all pixels), excluding `unknown`
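The metrics above can all be derived from a confusion matrix. A minimal NumPy sketch, assuming the convention that pixels whose ground truth is `unknown` are dropped while predicting `unknown` on a valid pixel counts as a miss; the helper names are hypothetical, not the API of `SegmentationMetrics`. With perfect predictions every class scores 1.0, whereas averaging the same six 1.0 values together with a seventh 0 for `unknown` gives 6/7 ≈ 0.8571.

```python
import numpy as np

CLASS_NAMES = ["urban", "agriculture", "rangeland", "forest", "water", "barren", "unknown"]
NUM_CLASSES, IGNORE_INDEX = 7, 6

def confusion_matrix(pred, target, num_classes=NUM_CLASSES):
    """cm[g, p] = number of pixels with ground truth g predicted as p."""
    idx = target.reshape(-1) * num_classes + pred.reshape(-1)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def deepglobe_scores(cm, ignore_index=IGNORE_INDEX):
    """Per-class IoU, mIoU and pixel accuracy, excluding ground-truth 'unknown' pixels.

    Predicting 'unknown' on a valid pixel still counts as a miss for the true class.
    """
    valid = np.arange(cm.shape[0]) != ignore_index
    cm = cm[valid]                       # drop rows whose ground truth is 'unknown'
    tp = np.diagonal(cm[:, valid])       # correctly labeled pixels per valid class
    gt = cm.sum(axis=1)                  # ground-truth pixels per valid class
    pred = cm[:, valid].sum(axis=0)      # pixels predicted as each valid class
    denom = gt + pred - tp               # union = |gt| + |pred| - |intersection|
    iou = np.divide(tp, denom, out=np.full(tp.shape, np.nan), where=denom > 0)
    names = [n for n, keep in zip(CLASS_NAMES, valid) if keep]
    return {
        "iou": dict(zip(names, iou)),
        "mIoU": float(np.nanmean(iou)),  # mean over the 6 valid classes only
        "accuracy": float(tp.sum() / cm.sum()),
    }

rng = np.random.default_rng(0)
target = rng.integers(0, NUM_CLASSES, size=(4, 64, 64))
scores = deepglobe_scores(confusion_matrix(target.copy(), target))  # perfect predictions
print(f"mIoU over 6 classes: {scores['mIoU']:.4f}, accuracy: {scores['accuracy']:.4f}")
print(f"mean with 'unknown' counted as IoU 0: {sum(scores['iou'].values()) / 7:.4f}")
```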

**What we do:**
- Build a validation DataLoader.
- Run a quick evaluation loop with a model to accumulate metrics.
- Report mIoU/accuracy and per-class IoU; save a bar chart to `outputs/figures/metrics_per_class.png`.
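The quick evaluation loop can be sketched as follows: iterate over at most `max_batches` batches with `itertools.islice`, argmax the logits, and accumulate a confusion matrix. The batch generator and the random linear "model" are hypothetical stand-ins for the validation DataLoader and a segmentation network, kept in NumPy so the sketch runs on its own.

```python
from itertools import islice

import numpy as np

NUM_CLASSES, IGNORE_INDEX = 7, 6

def fake_val_batches(n_batches, rng, batch_size=2, size=32):
    """Hypothetical stand-in for a validation DataLoader yielding (images, masks)."""
    for _ in range(n_batches):
        images = rng.standard_normal((batch_size, 3, size, size))
        masks = rng.integers(0, NUM_CLASSES, size=(batch_size, size, size))
        yield images, masks

def make_fake_model(rng):
    """Hypothetical 'model': a fixed per-pixel linear map from 3 channels to class logits."""
    weight = rng.standard_normal((NUM_CLASSES, 3))
    return lambda images: np.einsum("kc,bchw->bkhw", weight, images)

def evaluate(model, batches, max_batches=None):
    """Accumulate a confusion matrix over at most `max_batches` batches (None = all)."""
    cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
    n_batches = 0
    for images, masks in islice(batches, max_batches):
        preds = model(images).argmax(axis=1)  # logits -> class IDs
        idx = masks.reshape(-1) * NUM_CLASSES + preds.reshape(-1)
        cm += np.bincount(idx, minlength=NUM_CLASSES ** 2).reshape(NUM_CLASSES, NUM_CLASSES)
        n_batches += 1
    return cm, n_batches

rng = np.random.default_rng(0)
model = make_fake_model(rng)
cm, n_batches = evaluate(model, fake_val_batches(10, rng), max_batches=3)

valid = np.arange(NUM_CLASSES) != IGNORE_INDEX
accuracy = np.trace(cm[valid][:, valid]) / cm[valid].sum()  # 'unknown' ground truth excluded
print(f"batches evaluated: {n_batches}, pixels counted: {cm.sum()}, accuracy: {accuracy:.4f}")
```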

Notes:
- Metrics accept logits or class IDs; we pass logits and let the wrapper `argmax` internally.
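- For speed, evaluation is limited to a subset of batches; increase if needed.
- The quick check below skips the DataLoader and model: it feeds noisy one-hot logits built from random targets. Noise this small leaves the argmax unchanged, so per-class IoU, mIoU, and accuracy should all be 1.0 under the definitions above.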
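Accepting either logits or class IDs usually comes down to checking the array's rank. A hedged NumPy sketch (the `to_class_ids` helper is hypothetical, not the wrapper's actual code), which also rebuilds a noisy one-hot input like the cell's to show that noise with standard deviation 0.1 leaves the argmax unchanged: flipping a pixel would need a noise gap above 1 between two channels, roughly seven standard deviations.

```python
import numpy as np

def to_class_ids(preds):
    """Return (B, H, W) class IDs from (B, C, H, W) scores or (B, H, W) integer IDs."""
    preds = np.asarray(preds)
    if preds.ndim == 4:                                   # logits or probabilities
        return preds.argmax(axis=1)
    if preds.ndim == 3 and np.issubdtype(preds.dtype, np.integer):
        return preds                                      # already class IDs
    raise ValueError(f"expected (B, C, H, W) scores or (B, H, W) integer IDs, got {preds.shape}")

rng = np.random.default_rng(0)
targets = rng.integers(0, 7, size=(2, 16, 16))
one_hot = np.eye(7)[targets].transpose(0, 3, 1, 2)       # (B, H, W, C) -> (B, C, H, W)
noisy_logits = one_hot + 0.1 * rng.standard_normal(one_hot.shape)

ids_from_logits = to_class_ids(noisy_logits)
ids_from_ids = to_class_ids(targets)
print((ids_from_logits == targets).all(), (ids_from_ids == targets).all())
```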

In [5]:
import sys, os
import torch

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.training.metrics import SegmentationMetrics

print("Testing segmentation metrics...\n")

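# Mock data
batch_size = 4
num_classes = 7
h, w = 512, 512

# Build "perfect" logits by one-hot encoding the targets
targets = torch.randint(0, num_classes, (batch_size, h, w))
preds = torch.nn.functional.one_hot(targets, num_classes=num_classes)
preds = preds.permute(0, 3, 1, 2).float()  # (B, H, W, C) -> (B, C, H, W)

# Add small noise (std 0.1); it practically never changes the argmax, so predictions stay perfect
noise = torch.randn_like(preds) * 0.1
preds = preds + noise

# Initialize metrics
metrics = SegmentationMetrics(num_classes=num_classes, device='cpu')

# Update metrics (logits are argmax-ed internally)
metrics.update(preds, targets)

# Compute. With perfect predictions, mIoU over the 6 valid classes should be 1.0;
# a value of 6/7 = 0.8571 suggests 'unknown' is being averaged in with IoU 0.
results = metrics.compute()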

print("Metrics Results:")
print(f"  mIoU: {results['mIoU']:.4f}")
print(f"  Accuracy: {results['accuracy']:.4f}")
print("\nPer-class IoU:")
for key, value in results.items():
    if key.startswith('IoU_'):
        print(f"  {key}: {value:.4f}")

print("\n☑️ Metrics working correctly!")

Testing segmentation metrics...

Metrics Results:
  mIoU: 0.8571
  Accuracy: 1.0000

Per-class IoU:
  IoU_urban: 1.0000
  IoU_agriculture: 1.0000
  IoU_rangeland: 1.0000
  IoU_forest: 1.0000
  IoU_water: 1.0000
  IoU_barren: 1.0000

☑️ Metrics working correctly!
