#### **From Geospatial Patches to Model-Ready Batches** 

This notebook shows how the patch locations generated in Notebook 01 can be connected to a standard PyTorch `Dataset` and `DataLoader`, producing batches suitable for training a segmentation model. 

The example uses: 
- a USGS DEM subset of the Colorado River corridor (Grand Canyon);
- derived slope-class labels;
- TorchGeo’s `GridGeoSampler` to define patch extents; and
- a minimal custom `Dataset` for extracting aligned `(X, y)` tensors.
  
The aim is not to train a model, but to illustrate the final step of the data pipeline: 

> **Raster → (geospatial sampler) → PyTorch Dataset → DataLoader → batches ready for a model.** 

This pattern is the same for many spatial ML tasks, such as: 
- DEM-based terrain or susceptibility models 
- UAV orthomosaics for condition mapping 
- DSM/DTM segmentation 
- Rasterised point-cloud derivatives 

The focus here is on **clean data flow and alignment**, rather than modelling. 

In [None]:
# Imports and config
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import rasterio
from rasterio.windows import from_bounds
from torchgeo.samplers import GridGeoSampler, Units
from torchgeo.datasets import RasterDataset
from pathlib import Path

# Paths to DEM and slope classes
data_dir = Path('data_out')
usgs_dem_path = Path(r'DEME_Zone1-Zone15_2021/DEME_Zone3_2021.tif')
# Path.stem drops the ".tif" suffix (and avoids the un-imported `os` module)
dem_path = data_dir / f'{usgs_dem_path.stem}_clip.tif'
slope_class_path = data_dir / f'{usgs_dem_path.stem}_clip_slope_classes.tif'

print("✓ Imports complete")

In [15]:
# Load DEM metadata
with rasterio.open(dem_path) as src:
    dem_crs = src.crs
    dem_res = src.res
    dem_bounds = src.bounds

# Create TorchGeo dataset
dem_dataset = RasterDataset(paths=str(dem_path), crs=dem_crs, res=dem_res[0])

# Create GridGeoSampler (256x256 px, 50% overlap)
patch_size_px = 256
stride_px = 128
patch_size_m = patch_size_px * dem_res[0]
stride_m = stride_px * dem_res[0]
grid_sampler = GridGeoSampler(dataset=dem_dataset, size=patch_size_m, stride=stride_m, units=Units.CRS)
bboxes = list(grid_sampler)

print(f"✓ Created GridGeoSampler with {len(bboxes)} patches of size {patch_size_px}x{patch_size_px} px")

✓ Created GridGeoSampler with 961 patches of size 256x256 px
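
The patch count follows from the patch size and stride. The sketch below checks the arithmetic, assuming the sampler places patches per axis using the usual `ceil((extent - size) / stride) + 1` rule. The 4096-pixel extent is a hypothetical value chosen because it reproduces 961 patches; the real dimensions of the clipped DEM are not shown in this notebook. `chips_per_axis` is an illustrative helper, not a TorchGeo function.

```python
import math


def chips_per_axis(extent_px: int, size_px: int, stride_px: int) -> int:
    """Patch positions along one axis of a regular grid (illustrative helper)."""
    if extent_px < size_px:
        raise ValueError("extent must be at least one patch wide")
    return math.ceil((extent_px - size_px) / stride_px) + 1


extent_px = 4096  # hypothetical raster width/height in pixels
per_axis = chips_per_axis(extent_px, size_px=256, stride_px=128)
print(per_axis, per_axis ** 2)  # 31 961
```

With 50% overlap, the stride is half the patch size, so the count roughly quadruples compared with non-overlapping tiles.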


#### 03: Patches to PyTorch DataLoader

This section demonstrates how the geospatial patch extents produced by the sampler can be wrapped in a lightweight `Dataset` and iterated through a `DataLoader`. 

The goal is to show how raster-based patches become model-ready tensors with consistent shapes, types, and label alignment. 

This is a minimal, practical example of the interface between geospatial preprocessing and deep-learning workflows. 

In [16]:
# Minimal PyTorch Dataset that reads aligned DEM and slope-class patches on demand
class RasterPatchDataset(Dataset):
    def __init__(self, dem_path, class_path, bboxes):
        self.dem_path = dem_path
        self.class_path = class_path
        self.bboxes = bboxes

    def __len__(self):
        return len(self.bboxes)

    def __getitem__(self, idx):
        bbox = self.bboxes[idx]
        with rasterio.open(self.dem_path) as dem_src:
            dem_window = from_bounds(bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, dem_src.transform)
            dem_patch = dem_src.read(1, window=dem_window).astype(np.float32)
        with rasterio.open(self.class_path) as class_src:
            class_window = from_bounds(bbox.minx, bbox.miny, bbox.maxx, bbox.maxy, class_src.transform)
            class_patch = class_src.read(1, window=class_window).astype(np.int64)
        # Normalise DEM per patch (min-max to [0, 1]); zeros are treated as nodata
        # and excluded from the min/max statistics
        dem_valid = dem_patch[dem_patch != 0]
        if dem_valid.size > 0:
            dem_min, dem_max = dem_valid.min(), dem_valid.max()
            dem_patch = (dem_patch - dem_min) / (dem_max - dem_min + 1e-6)
            # Clip so nodata pixels (negative after scaling) fall back to 0
            dem_patch = np.clip(dem_patch, 0, 1)
        # Convert to torch tensors
        X = torch.from_numpy(dem_patch).unsqueeze(0)  # (1, H, W)
        y = torch.from_numpy(class_patch)  # (H, W)
        return X, y
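
The `from_bounds` call in `__getitem__` turns a geographic bounding box into a pixel window. The sketch below shows the underlying arithmetic in plain Python, assuming a north-up raster with square pixels. `bounds_to_pixel_window` and the origin and resolution values are hypothetical and for illustration only. Rounding to whole pixels is an added assumption to keep patch shapes fixed when extents do not land exactly on pixel edges.

```python
def bounds_to_pixel_window(minx, miny, maxx, maxy, origin_x, origin_y, res):
    """Return (row_off, col_off, height, width) for a north-up raster.

    origin_x / origin_y are the coordinates of the raster's top-left corner;
    res is the (square) pixel size in CRS units. Illustrative helper only.
    """
    col_off = round((minx - origin_x) / res)
    row_off = round((origin_y - maxy) / res)  # rows count downwards from the top
    width = round((maxx - minx) / res)
    height = round((maxy - miny) / res)
    return row_off, col_off, height, width


# Hypothetical raster: top-left corner at (500000, 4000000), 10 m pixels
window = bounds_to_pixel_window(
    minx=501280.0, miny=3997440.0, maxx=503840.0, maxy=4000000.0,
    origin_x=500000.0, origin_y=4000000.0, res=10.0,
)
print(window)  # (0, 128, 256, 256)
```

Because the DEM and the slope-class raster share the same grid, the same bounding box resolves to the same window in both files, which is what keeps `X` and `y` aligned.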

#### DataLoader: batching and inspecting model-ready patches

This cell wraps the patch dataset in a PyTorch `DataLoader`, which handles batching and shuffling. We then fetch a single batch and print the shapes, dtypes, and value ranges for both the input (DEM) and target (slope class) tensors. This confirms that the data is correctly formatted and ready to be passed to a segmentation model.

In [17]:
# DataLoader wiring
dataset = RasterPatchDataset(dem_path, slope_class_path, bboxes)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Iterate one batch
batch = next(iter(loader))
X, y = batch
print('X shape:', X.shape)  # (batch, 1, H, W)
print('y shape:', y.shape)  # (batch, H, W)
print('X dtype:', X.dtype)
print('y dtype:', y.dtype)
print('X min/max:', X.min().item(), X.max().item())
print('y min/max:', y.min().item(), y.max().item())

X shape: torch.Size([4, 1, 256, 256])
y shape: torch.Size([4, 256, 256])
X dtype: torch.float32
y dtype: torch.int64
X min/max: 0.0 1.0
y min/max: 0 4
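
The `X min/max` of 0.0 and 1.0 reflects the per-patch min-max scaling. One side effect, when valid elevations are positive, is that zero-valued (nodata) pixels are clipped to 0, the same value as the lowest valid elevation. The sketch below is one possible variant, not part of the notebook's `Dataset`, that also returns a validity mask so the two cases can be told apart downstream. `normalise_patch` and its `nodata` parameter are hypothetical names.

```python
import numpy as np


def normalise_patch(dem, nodata=0.0, eps=1e-6):
    """Min-max scale valid pixels to [0, 1]; return (scaled, valid_mask)."""
    dem = dem.astype(np.float32)
    valid = dem != nodata
    out = np.zeros_like(dem)  # nodata pixels stay at 0
    if valid.any():
        lo, hi = dem[valid].min(), dem[valid].max()
        out[valid] = (dem[valid] - lo) / (hi - lo + eps)
    return np.clip(out, 0.0, 1.0), valid


# Tiny synthetic patch: one nodata pixel and three valid elevations
patch = np.array([[0.0, 100.0], [150.0, 200.0]])
scaled, mask = normalise_patch(patch)
print(scaled)
print(mask)
```

The mask could be used, for example, to exclude nodata pixels from a loss calculation.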


#### Dummy model forward pass: verifying plug-and-play compatibility

This cell defines a minimal convolutional neural network (two Conv2d layers with ReLU activation) and runs a batch of patches through it. This demonstrates that the data produced by the DataLoader is already in the correct shape and type for direct use in a segmentation model, such as a U-Net. The output shape and dtype are printed to confirm compatibility.

In [18]:
# Optional: Dummy model forward pass
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1)
)

out = model(X)  # (batch, 1, H, W)
print('Model output shape:', out.shape)
print('Model output dtype:', out.dtype)

Model output shape: torch.Size([4, 1, 256, 256])
Model output dtype: torch.float32
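
The dummy model above emits a single output channel, which is enough to check shapes. For the five slope classes seen in the batch (labels 0 to 4), a segmentation head would normally emit one logit channel per class. The sketch below runs synthetic tensors through such a head and computes a cross-entropy loss, assuming five classes and using random stand-ins for a `DataLoader` batch. It is an illustration of the tensor contract, not a training recipe.

```python
import torch
import torch.nn as nn

num_classes = 5  # assumption: labels run from 0 to 4

model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, num_classes, kernel_size=3, padding=1),  # one channel per class
)

# Synthetic stand-ins for one batch: X is float32, y is int64
X = torch.rand(4, 1, 64, 64)
y = torch.randint(0, num_classes, (4, 64, 64))

logits = model(X)                        # (batch, num_classes, H, W)
loss = nn.CrossEntropyLoss()(logits, y)  # target shape is (batch, H, W)
print(logits.shape, loss.item())
```

`nn.CrossEntropyLoss` expects integer class indices as targets, which is why the `Dataset` casts the label patch to `int64`.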


---

#### **Summary**

This notebook shows how the geospatial patch sampling layer from Notebook 01 connects to a standard PyTorch batching workflow. 

- The `Dataset` reads DEM and slope-class tiles on demand, normalises the input, and returns aligned `(X, y)` tensors. 
- A `DataLoader` then provides batches with consistent shapes that can be passed directly to a segmentation model. 
- Although no training loop is included here, this structure is equivalent to what would be used for a U-Net or similar architecture. 

Key points: 
1. The sampler defines patch locations in geographic space; the `Dataset` resolves these to pixel windows. 
2. Each batch contains consistently shaped tensors (`X: [B, 1, H, W]`, `y: [B, H, W]`). 
3. The example keeps the implementation minimal, focusing on clarity of data flow rather than modelling details. 
4. The same pattern applies to DEMs, UAV imagery, DSM/DTMs, or rasterised point-cloud products. 

Together with Notebook 01, this notebook provides a clear template for building geospatial deep-learning data pipelines using TorchGeo and PyTorch. 