# Large Image Land Cover Classification

This notebook demonstrates how to perform land cover classification on large GeoTIFF images using a sliding window approach.

The process:
1. Load a large GeoTIFF image
2. Break it into 256x256 chips (overlapping whenever `STRIDE` is smaller than `CHIP_SIZE`)
3. Run inference on each chip
4. Stitch predictions back together
5. Save as a georeferenced GeoTIFF

In [None]:
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import rioxarray as rxr
import torch
import yaml
from box import Box
from matplotlib.colors import ListedColormap

from claymodel.finetune.segment.chesapeake_model import ChesapeakeSegmentor
from large_image_inference import LargeImageSegmentor  # Now in repo root

## 1. Define Paths and Parameters

In [None]:
ROOT_PATH = Path.cwd()  # Current directory (repo root)

# Model checkpoints
CHESAPEAKE_CHECKPOINT = ROOT_PATH / "checkpoints" / "segment" / "chesapeake-7class-segment_epoch-39_val-iou-0.8765.ckpt"
CLAY_CHECKPOINT = ROOT_PATH / "checkpoints" / "clay-v1.5.ckpt"
METADATA_PATH = ROOT_PATH / "configs" / "metadata.yaml"

# Input/output paths - UPDATE THESE WITH YOUR FILE PATHS
INPUT_IMAGE_DIR = ROOT_PATH / "data/cvpr/files/train"
SEED = 42  # Fixes which scene is picked so "Restart & Run All" reproduces the same outputs

# Sort the glob results: filesystem iteration order is arbitrary, so a seed alone
# would not make the selection reproducible across machines.
INPUT_IMAGES = sorted(INPUT_IMAGE_DIR.glob("*naip*.tif"))
if not INPUT_IMAGES:
    raise FileNotFoundError(
        f"No '*naip*.tif' files found in {INPUT_IMAGE_DIR}; update INPUT_IMAGE_DIR."
    )
# A private Random instance keeps the choice seeded without touching global random state.
INPUT_IMAGE = random.Random(SEED).choice(INPUT_IMAGES)
print(f"Selected input image: {INPUT_IMAGE}")

OUTPUT_IMAGE = ROOT_PATH / "output" / "prediction.tif"  # Where to save the prediction
# Make sure the destination directory exists before the prediction GeoTIFF is written.
OUTPUT_IMAGE.parent.mkdir(parents=True, exist_ok=True)

# Inference parameters
CHIP_SIZE = 256  # Size of each chip
STRIDE = 256  # Sliding-window step: == CHIP_SIZE means no overlap; CHIP_SIZE // 2 (128) gives 50% overlap
BATCH_SIZE = 16  # Number of chips to process at once

# Pick the fastest available backend. To force one, reassign DEVICE to "cpu", "mps", or "cuda".
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Parallel workers - MUST be 0 for GPU devices (MPS/CUDA)
NUM_WORKERS = 0 if DEVICE in ["mps", "cuda"] else 4
USE_PARALLEL = NUM_WORKERS > 0

## 2. Load Model and Metadata

In [None]:
# Restore the fine-tuned Chesapeake segmentation model from its checkpoint; the
# Clay checkpoint (clay-v1.5) is passed alongside it as `ckpt_path`.
print("Loading model...")
model = ChesapeakeSegmentor.load_from_checkpoint(
    ckpt_path=CLAY_CHECKPOINT,
    checkpoint_path=CHESAPEAKE_CHECKPOINT,
)

# Per-platform metadata from configs/metadata.yaml, wrapped in a Box for
# attribute-style access (the segmentor uses it for normalization).
metadata = Box(yaml.safe_load(METADATA_PATH.read_text()))

print("Model loaded successfully!")

## 3. Create the Large Image Segmentor

In [None]:
# Wrap the model in the sliding-window driver: CHIP_SIZE-pixel chips stepped by
# STRIDE, inferred BATCH_SIZE at a time on DEVICE.
segmentor = LargeImageSegmentor(
    model=model,
    metadata=metadata,
    platform="naip",
    chip_size=CHIP_SIZE,
    stride=STRIDE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    use_parallel=USE_PARALLEL,
    device=DEVICE,
)

# Echo the effective configuration, one setting per line.
overlap_pct = 100 * (1 - STRIDE/CHIP_SIZE)
workers_label = NUM_WORKERS if USE_PARALLEL else 'Disabled'
summary_lines = [
    "Segmentor configured:",
    f"  - Chip size: {CHIP_SIZE}x{CHIP_SIZE}",
    f"  - Stride: {STRIDE} (overlap: {overlap_pct:.0f}%)",
    f"  - Batch size: {BATCH_SIZE}",
    f"  - Parallel workers: {workers_label}",
    f"  - Device: {DEVICE}",
]
print("\n".join(summary_lines))

## 4. Run Inference on Large Image

In [None]:
# Full-scene inference following the process outlined at the top of the notebook
# (chip -> infer -> stitch); OUTPUT_IMAGE is handed over as the save destination.
prediction = segmentor.predict_large_image(
    INPUT_IMAGE,
    OUTPUT_IMAGE,
)

In [None]:
# DEBUG: Test a single chip prediction
print("Testing single chip prediction...")
image, geo_data = segmentor.load_image(INPUT_IMAGE)
chips, positions = segmentor.extract_chips(image)

test_chip = chips[0]
print(f"Chip shape: {test_chip.shape}")
print(f"Chip min/max: {test_chip.min():.2f} / {test_chip.max():.2f}")

# Soft output for a batch of one chip. The sum-to-one check runs over axis 1, i.e.
# axis 1 is treated as the class axis (presumably (N, C, H, W) — confirm).
prob_output = segmentor.predict_chips([test_chip], return_probs=True)
prob_lo, prob_hi = prob_output.min(), prob_output.max()
sums_to_one = np.allclose(prob_output.sum(axis=1), 1.0)
print(f"\nProbability output shape: {prob_output.shape}")
print(f"Probability output min/max: {prob_lo:.4f} / {prob_hi:.4f}")
print(f"Probabilities sum to 1? {sums_to_one}")

# Hard labels: arg-max over the class axis of the single chip.
class_pred = np.argmax(prob_output[0], axis=0)
labels, counts = np.unique(class_pred, return_counts=True)
print(f"\nClass prediction shape: {class_pred.shape}")
print(f"Unique classes in single chip: {labels}")
print("Class distribution:")
for c, count in zip(labels, counts):
    pct = 100 * count / class_pred.size
    print(f"  Class {c}: {pct:.1f}%")

In [None]:
# DEBUG: Test the stitching process
print("\nTesting stitching with first 10 chips...")
n_test = 10
test_chips, test_positions = chips[:n_test], positions[:n_test]

probs = segmentor.predict_chips(test_chips, return_probs=True)
print(f"Probabilities shape: {probs.shape}")
print(f"Each chip has {probs.shape[1]} classes")

# Peek at the hard labels of (up to) the first three chips before they are merged.
for i, chip_probs in enumerate(probs[:3]):
    chip_pred = np.argmax(chip_probs, axis=0)
    print(f"Chip {i} classes: {np.unique(chip_pred)}")

# Stitch onto a canvas the size of the full scene (image is indexed as
# bands, rows, cols elsewhere in this notebook).
# NOTE(review): only n_test chips are stitched, so most of the canvas has no
# prediction; the distribution below largely reflects how stitch_predictions
# fills uncovered pixels — confirm before reading too much into it.
height, width = image.shape[1], image.shape[2]
stitched = segmentor.stitch_predictions(probs, test_positions, (height, width))
print(f"\nStitched result shape: {stitched.shape}")

stitched_classes, stitched_counts = np.unique(stitched, return_counts=True)
print(f"Stitched unique classes: {stitched_classes}")
print("Stitched class distribution:")
for c, count in zip(stitched_classes, stitched_counts):
    pct = 100 * count / stitched.size
    print(f"  Class {c}: {pct:.1f}%")

In [None]:
# Visualize the stitched test result next to the source pixels.
# Palette for class IDs 0-6 as 0-1 RGBA floats (same values as the section 5 colormap).
colors_normalized = [(0, 0, 1, 1), (34/255, 139/255, 34/255, 1),
                     (154/255, 205/255, 50/255, 1), (210/255, 180/255, 140/255, 1),
                     (169/255, 169/255, 169/255, 1), (105/255, 105/255, 105/255, 1),
                     (1, 1, 1, 1)]
cmap = ListedColormap(colors_normalized)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# First three bands of the top-left 512x512 window, moved to (rows, cols, bands) for imshow.
test_region = image[:3, :512, :512].transpose(1, 2, 0)
# Scale by the window maximum for display; an all-zero window would otherwise
# divide by zero and render as NaNs.
region_max = float(test_region.max())
test_region_vis = np.clip(test_region / (region_max if region_max > 0 else 1.0), 0, 1)
axes[0].imshow(test_region_vis)
axes[0].set_title("Original Image (First 512x512)")
axes[0].axis('off')

# Show the stitched prediction for that region
axes[1].imshow(stitched[:512, :512], cmap=cmap, vmin=0, vmax=6)
axes[1].set_title("Stitched Prediction (First 512x512)")
axes[1].axis('off')

plt.tight_layout()
plt.show()

# \u2713 is a check mark; written as an escape so the source stays ASCII-safe.
print("\n\u2713 Stitching is working! Classes are preserved.")

In [None]:
# Check the results: per-class pixel counts over the full-scene prediction.
print(f"Prediction shape: {prediction.shape}")
class_values, class_counts = np.unique(prediction, return_counts=True)
print(f"Unique classes: {class_values}")
print("\nClass distribution:")
total_pixels = prediction.size
for class_id, count in zip(class_values, class_counts):
    percentage = 100 * count / total_pixels
    print(f"  Class {class_id}: {count:,} pixels ({percentage:.2f}%)")

In [None]:
# Bar chart of per-class pixel counts in the full-scene prediction.
class_ids, counts = np.unique(prediction, return_counts=True)

fig, ax = plt.subplots()
ax.bar(class_ids, counts)
# One tick per class actually present, so IDs never show as fractional positions.
ax.set_xticks(class_ids)
ax.set_xlabel("Class ID")
ax.set_ylabel("Pixel Count")
ax.set_title("Class Distribution in Prediction")
# Render explicitly; otherwise the cell echoes the repr of its last expression.
plt.show()

In [None]:
# Debug: compare hard-label predictions of a few raw chips with the class set
# of the full stitched prediction.
print("Debugging chip predictions...")
image, geo_data = segmentor.load_image(INPUT_IMAGE)
chips, positions = segmentor.extract_chips(image)

test_chips = chips[:5]
test_preds = segmentor.predict_chips(test_chips)

report = [f"\nTesting {len(test_chips)} individual chips:"]
report += [
    f"  Chip {i}: Classes found: {np.unique(pred)}, shape: {pred.shape}"
    for i, pred in enumerate(test_preds)
]
print("\n".join(report))

print(f"\nFull prediction unique classes: {np.unique(prediction)}")

## 5. Visualize the Results

In [None]:
# Class names and display colors for the 7 land-cover classes, kept side by side
# so each label stays next to its color. List position == class ID == colormap index.
class_styles = [
    ("0: Water", (0, 0, 255, 255)),                  # Deep Blue
    ("1: Tree Canopy", (34, 139, 34, 255)),          # Forest Green
    ("2: Low Vegetation", (154, 205, 50, 255)),      # Yellow Green
    ("3: Barren Land", (210, 180, 140, 255)),        # Tan
    ("4: Impervious (Other)", (169, 169, 169, 255)), # Dark Gray
    ("5: Impervious (Road)", (105, 105, 105, 255)),  # Dim Gray
    ("6: No Data", (255, 255, 255, 255)),            # White
]
class_labels = [label for label, _ in class_styles]
colors = [rgba for _, rgba in class_styles]  # RGBA, 0-255 per channel

# Matplotlib expects 0-1 floats per channel.
colors_normalized = [tuple(channel / 255 for channel in rgba) for rgba in colors]
cmap = ListedColormap(colors_normalized)

In [None]:
# Load the original image for comparison. The context manager closes the raster
# file once the RGB bands have been read into memory with .values.
with rxr.open_rasterio(INPUT_IMAGE) as original_image:
    # First three bands, moved to (rows, cols, bands) for imshow.
    rgb_image = original_image[:3, :, :].values.transpose(1, 2, 0)

# Normalize by the global maximum for display; guard an all-zero raster against divide-by-zero.
rgb_max = float(rgb_image.max())
rgb_image = np.clip(rgb_image / (rgb_max if rgb_max > 0 else 1.0), 0, 1)

# Display side by side
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

axes[0].imshow(rgb_image)
axes[0].set_title("Original Image", fontsize=14)
axes[0].axis('off')

axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=6)
axes[1].set_title("Land Cover Prediction", fontsize=14)
axes[1].axis('off')

# One legend marker per class, colored from the section 5 palette.
handles = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10)
    for color in colors_normalized
]
fig.legend(handles, class_labels, loc='lower center', ncol=4, fontsize=10)

# tight_layout ignores figure-level legends, so reserve a strip at the bottom;
# otherwise the legend is drawn on top of the images.
fig.tight_layout(rect=(0, 0.1, 1, 1))
plt.show()

## 6. Optional: Create a Zoomed-in View

In [None]:
# Define a region to zoom into (adjust these coordinates)
row_start, row_end = 0, 512
col_start, col_end = 0, 512

# The same (rows, cols) window is applied to the display RGB array and the class map.
window = (slice(row_start, row_end), slice(col_start, col_end))
rgb_zoom = rgb_image[window]
pred_zoom = prediction[window]

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Each panel: (array, title, extra imshow options); the class map shares the 0-6 palette.
panels = [
    (rgb_zoom, "Original (Zoomed)", {}),
    (pred_zoom, "Prediction (Zoomed)", {"cmap": cmap, "vmin": 0, "vmax": 6}),
]
for ax, (data, title, imshow_kwargs) in zip(axes, panels):
    ax.imshow(data, **imshow_kwargs)
    ax.set_title(title, fontsize=12)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Notes

### Performance Tips:
- **GPU**: Set `DEVICE = "cuda"` or `DEVICE = "mps"` if you have a GPU available - this will be much faster
- **Parallel Processing**: `USE_PARALLEL` is derived from `NUM_WORKERS` (it is `True` whenever `NUM_WORKERS > 0`), so enable parallel data loading by adjusting `NUM_WORKERS` (typically 4-8)
  - For CPU: Parallel workers load and normalize chips while the model runs inference
  - For GPU (MPS/CUDA): the configuration cell forces `NUM_WORKERS = 0`, so parallel loading is not used on these devices
  - Set `NUM_WORKERS = 0` (which disables parallel loading) if you encounter memory issues
- **Batch Size**: Increase if you have more memory available
- **Stride**: Larger stride = faster processing but less accurate at boundaries
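- **Overlap**: The notebook's default `STRIDE = 256` equals `CHIP_SIZE`, i.e. no overlap. Setting `STRIDE = CHIP_SIZE // 2` (128) gives 50% overlap, which usually improves quality at chip boundaries at roughly 4x the number of chips to process.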

### Memory Considerations:
- Very large images may need to be processed in tiles
- Monitor memory usage and adjust batch size accordingly
- The output prediction is kept in memory - for very large images, consider processing in sections
- Parallel workers use additional memory - reduce `NUM_WORKERS` if needed

### Output Format:
- The output is saved as a GeoTIFF with the same CRS and georeference as the input
- Can be opened in QGIS, ArcGIS, or any GIS software
- Pixel values correspond to class IDs (0-6)