# Preprocessing `.tif` Imagery for CNN Inference using Rasterio

Satellite imagery is commonly distributed as `.tif` (GeoTIFF) files, which store geospatial raster data that deep learning models, such as those in TorchGeo, cannot consume directly. Models such as DOFA expect PyTorch tensors that have been normalized and resized to fixed dimensions (e.g., 224x224). My specific task in the final project is to integrate pretrained CNNs (DOFA, CopernicusFM) for inference on user-supplied `.tif` data. This notebook demonstrates how to preprocess `.tif` files into tensors ready for model input.

---

## Summary
- Load `.tif` satellite images using `rasterio`
- Extract RGB bands and convert them into a NumPy array
- Preprocess and normalize the data using `torchvision.transforms`
- Convert the processed image into a PyTorch tensor
- Run the tensor through a TorchGeo model (DOFA)

In [18]:
# Load libraries
import torch
import rasterio
import numpy as np
from torchvision import transforms
from torchgeo.models import DOFA
import os
from typing import Tuple, List, Optional
import warnings

### 1. Load `.tif` as Tensor Function

This function converts a `.tif` image file into a normalized PyTorch tensor of shape `(1, C, H, W)`, ready to be passed to a model.

In [19]:
def load_tif_as_tensor(
    path: str, 
    size: Tuple[int, int] = (224, 224),
    bands: List[int] = [1, 2, 3],  # RGB bands by default
    band_order: str = 'RGB',  # 'RGB', 'BGR', or custom
    normalization_params: Optional[dict] = None,
    handle_nodata: bool = True,
    nodata_fill: float = 0.0,
    data_range: Optional[Tuple[float, float]] = None,
    clip_percentiles: Optional[Tuple[float, float]] = None
) -> torch.Tensor:
    """
    Load a .tif file and preprocess it into a normalized PyTorch tensor.

    Args:
        path: Path to the .tif file
        size: Target size for resizing (height, width)
        bands: List of band indices to read (1-indexed)
        band_order: Band arrangement - 'RGB', 'BGR', or 'custom'
        normalization_params: Dict with 'mean' and 'std' for normalization
        handle_nodata: Whether to handle nodata values
        nodata_fill: Value to fill nodata pixels with
        data_range: Expected data range (min, max) for scaling
        clip_percentiles: Percentiles (low, high) for clipping extreme values
    
    Returns:
        PyTorch tensor with shape (1, C, H, W)
    """
    
    # Error Handling & Validation
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    
    if not path.lower().endswith(('.tif', '.tiff')):
        warnings.warn(f"File extension suggests this might not be a TIFF file: {path}")
    
    try:
        with rasterio.open(path) as src:
            # Validate bands (rasterio band indices are 1-based)
            if min(bands) < 1 or max(bands) > src.count:
                raise ValueError(
                    f"Requested bands {bands} but file only has bands 1 to {src.count}"
                )
            
            if len(bands) < 3:
                raise ValueError(f"Need at least 3 bands for RGB processing, got {len(bands)}")
            
            # Read the specified bands
            img = src.read(bands)  # Shape: (C, H, W)
            
            # Get nodata value
            nodata_value = src.nodata
            
            # Get data type info
            dtype = src.dtypes[0]
            print(f"Original data type: {dtype}")
            print(f"Image shape: {img.shape}")
            print(f"Data range: {img.min()} to {img.max()}")
            
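    except ValueError:
        # Band validation errors raised above should propagate unchanged
        raise
    except rasterio.errors.RasterioIOError as e:
        raise RuntimeError(f"Failed to read rasterio file: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error reading file: {e}") from e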
    
    # Handle different data types and ranges
    img = img.astype(np.float32)
    
    # Handle nodata values
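    if handle_nodata and nodata_value is not None:
        # NaN never compares equal to itself, so it needs a dedicated check
        if np.isnan(nodata_value):
            nodata_mask = np.isnan(img)
        else:
            nodata_mask = img == nodata_value
        img[nodata_mask] = nodata_fill
        print(f"Handled nodata value: {nodata_value}")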
    
    # Handle different data ranges
    if data_range is not None:
        # Scale from data_range to [0, 1]
        img = (img - data_range[0]) / (data_range[1] - data_range[0])
        img = np.clip(img, 0, 1)
        print(f"Scaled data from range {data_range} to [0, 1]")
    elif dtype in ['uint8']:
        # 8-bit data: scale from [0, 255] to [0, 1]
        img = img / 255.0
        print("Scaled 8-bit data to [0, 1]")
    elif dtype in ['uint16']:
        # 16-bit data: scale from [0, 65535] to [0, 1]
        img = img / 65535.0
        print("Scaled 16-bit data to [0, 1]")
    elif img.max() > 1.0:
        # Assume it needs scaling if max > 1
        max_val = img.max()  # Store before dividing so the log shows the real value
        img = img / max_val
        print(f"Scaled data by max value: {max_val}")
    
    # Clip extreme values if specified
    if clip_percentiles is not None:
        # Percentiles are computed jointly over all bands
        low_val = np.percentile(img, clip_percentiles[0])
        high_val = np.percentile(img, clip_percentiles[1])
        img = np.clip(img, low_val, high_val)
        # Renormalize after clipping; skip for constant images to avoid dividing by zero
        if high_val > low_val:
            img = (img - low_val) / (high_val - low_val)
        print(f"Clipped to {clip_percentiles} percentiles and renormalized")
    
    # Convert to HWC format
    img = np.transpose(img, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    
    # Handle different band orders
    if band_order == 'BGR' and img.shape[2] >= 3:
        # Convert BGR to RGB
        img = img[:, :, [2, 1, 0]]
        print("Converted BGR to RGB")
    elif band_order == 'custom':
        print("Using custom band order (no reordering applied)")
    
    # Set up normalization parameters
    if normalization_params is None:
        # Use ImageNet statistics by default
        normalization_params = {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225]
        }
        print("Using ImageNet normalization parameters")
    else:
        print(f"Using custom normalization: mean={normalization_params['mean']}, std={normalization_params['std']}")
    
    # Preprocessing pipeline
    try:
        # PIL RGB images are 8-bit, and ToPILImage's handling of float NumPy
        # arrays differs between torchvision versions, so convert explicitly
        img_uint8 = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=normalization_params['mean'],
                std=normalization_params['std']
            )
        ])
        
        tensor = transform(img_uint8).unsqueeze(0)  # Add batch dimension
        print(f"Final tensor shape: {tensor.shape}")
        return tensor
        
    except Exception as e:
        raise RuntimeError(f"Error in preprocessing pipeline: {e}") from e
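
The percentile clipping step in `load_tif_as_tensor` is easier to see on a small synthetic array. The following is a minimal NumPy-only sketch, not the notebook's implementation: `stretch_to_unit_range` is a hypothetical helper, and the random array merely stands in for one band of a scaled image.

```python
import numpy as np


def stretch_to_unit_range(img, low_pct=2.0, high_pct=98.0):
    """Clip an array to the given percentiles and rescale it to [0, 1].

    Hypothetical helper for illustration; it mirrors the idea behind the
    clip_percentiles option of load_tif_as_tensor.
    """
    low_val = np.percentile(img, low_pct)
    high_val = np.percentile(img, high_pct)
    if high_val <= low_val:
        # Constant image: nothing to stretch, return zeros
        return np.zeros_like(img, dtype=np.float32)
    clipped = np.clip(img, low_val, high_val)
    return ((clipped - low_val) / (high_val - low_val)).astype(np.float32)


# Synthetic single-band "image": mostly mid-range values plus two outliers
rng = np.random.default_rng(seed=0)
band = rng.uniform(0.2, 0.4, size=(50, 50)).astype(np.float32)
band[0, 0] = 0.0  # dark outlier
band[0, 1] = 1.0  # bright outlier (e.g. a saturated pixel)

stretched = stretch_to_unit_range(band)
print("before:", float(band.min()), float(band.max()))
print("after: ", float(stretched.min()), float(stretched.max()))
```

Without the stretch, the two outliers would force the bulk of the pixels into a narrow part of the [0, 1] range; after clipping, the mid-range values use the full range.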

### 2. Normalization Function

This function returns normalization parameters (per-band mean and standard deviation) keyed by imagery type. Because sensors differ in spectral characteristics, normalizing with statistics that match the imagery, and ideally the statistics the model was trained with, keeps preprocessing consistent and can improve model performance.

In [20]:
def create_satellite_normalization_params(imagery_type: str = 'sentinel2') -> dict:
    """    
    Return normalization parameters (mean and std) for a given imagery type.

    Args:
        imagery_type: Type of satellite imagery ('sentinel2', 'landsat', 'imagenet')
    
    Returns:
        Dictionary with mean and std values
    """
    if imagery_type.lower() == 'sentinel2':
        # Placeholder: these are the ImageNet statistics, not values measured on Sentinel-2 data
        return {
            'mean': [0.485, 0.456, 0.406],  
            'std': [0.229, 0.224, 0.225]
        }
    elif imagery_type.lower() == 'landsat':
        # Placeholder: generic mid-range values, not values measured on Landsat data
        return {
            'mean': [0.5, 0.5, 0.5],  
            'std': [0.25, 0.25, 0.25]
        }
    else:
        # Default ImageNet statistics
        return {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225]
        }
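
Where sensor-specific statistics are needed, they can be estimated from a sample of the imagery itself. The sketch below is one possible approach, not part of TorchGeo or of this notebook: `compute_band_stats` is a hypothetical helper, and the random arrays stand in for real tiles that have already been scaled to [0, 1].

```python
import numpy as np


def compute_band_stats(images):
    """Return per-band mean and std over a list of (C, H, W) arrays.

    Hypothetical helper for illustration only.
    """
    # Flatten every image to (C, H*W) and concatenate along the pixel axis
    pixels = np.concatenate(
        [img.reshape(img.shape[0], -1) for img in images], axis=1
    )
    return {
        'mean': pixels.mean(axis=1).tolist(),
        'std': pixels.std(axis=1).tolist(),
    }


# Two synthetic 3-band tiles standing in for scaled satellite images
rng = np.random.default_rng(seed=0)
tiles = [rng.uniform(0.0, 1.0, size=(3, 32, 32)) for _ in range(2)]

norm_params = compute_band_stats(tiles)
print(norm_params)

# Applying (x - mean) / std per band yields zero mean and unit std per band
mean = np.array(norm_params['mean']).reshape(3, 1, 1)
std = np.array(norm_params['std']).reshape(3, 1, 1)
normalized = (np.stack(tiles) - mean) / std
print(normalized.mean(axis=(0, 2, 3)), normalized.std(axis=(0, 2, 3)))
```

The returned dictionary has the same `'mean'` and `'std'` keys that `load_tif_as_tensor` reads from `normalization_params`, so it could be passed in directly.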

### 3. Test Run Inference with DOFA (TorchGeo)

This function loads a `.tif` image, preprocesses it into a tensor, and runs it through the DOFA model for inference.

In [21]:
def run_inference_pipeline(image_path: str, imagery_type: str = 'sentinel2') -> torch.Tensor:
    """Preprocess a .tif file and run it through DOFA, returning the model output."""
    try:
        # Create appropriate normalization parameters
        norm_params = create_satellite_normalization_params(imagery_type)
        
        # Load and preprocess the image
        image_tensor = load_tif_as_tensor(
            path=image_path,
            size=(224, 224),
            bands=[1, 2, 3],  # RGB bands
            band_order='RGB',
            normalization_params=norm_params,
            handle_nodata=True,
            clip_percentiles=(2, 98),  # Clip extreme 2% on each end
            data_range=None  # Let the function auto-detect
        )
        
        print("Image tensor shape:", image_tensor.shape)
        
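        # Initialize DOFA model
        # Note: constructing DOFA directly builds the architecture with randomly
        # initialized weights; load pretrained weights before relying on the outputs
        try:
            model = DOFA(img_size=224, patch_size=16)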
            model.eval()
            print("Model initialized successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to initialize DOFA model: {e}")
        
        # Run inference
        with torch.no_grad():
            try:
                # Use approximate wavelengths for RGB bands (in micrometers)
                output = model(image_tensor, wavelengths=[0.64, 0.56, 0.48])
                print("Inference complete!")
                print("Output shape:", output.shape)
                return output
            except Exception as e:
                raise RuntimeError(f"Inference failed: {e}")
                
    except Exception as e:
        print(f"Pipeline failed: {e}")
        raise

In [22]:
# Usage
if __name__ == "__main__":
    image_path = "example.tif"
    try:
        output = run_inference_pipeline(image_path, imagery_type='sentinel2')
    except Exception as e:
        print(f"Error: {e}")

Original data type: uint8
Image shape: (3, 1768, 1636)
Data range: 0 to 255
Scaled 8-bit data to [0, 1]
Clipped to (2, 98) percentiles and renormalized
Using custom normalization: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
Final tensor shape: torch.Size([1, 3, 224, 224])
Image tensor shape: torch.Size([1, 3, 224, 224])
Model initialized successfully
Inference complete!
Output shape: torch.Size([1, 45])
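
The output has shape `(1, 45)`: one vector of 45 values for the single input image. Assuming these values are unnormalized class scores (logits) from a classification head, which is an assumption about how the model is configured, they could be turned into probabilities with a softmax. The sketch below uses NumPy and random numbers as a stand-in for `output[0]`; `top_k_classes` is a hypothetical helper.

```python
import numpy as np


def top_k_classes(logits, k=3):
    """Return (class_index, probability) pairs for the k highest scores."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    top = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in top]


# Stand-in for output[0].numpy(): 45 random scores
rng = np.random.default_rng(seed=0)
fake_logits = rng.normal(size=45)

for class_index, prob in top_k_classes(fake_logits):
    print(f"class {class_index}: {prob:.3f}")
```

Such probabilities are only meaningful once the model carries trained weights for a task-specific head.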


### Documentation

- [rasterio](https://rasterio.readthedocs.io/en/latest/)
- [torchgeo](https://torchgeo.readthedocs.io/en/stable/)
- [DOFA Model](https://torchgeo.readthedocs.io/en/stable/api/models.html#torchgeo.models.DOFA)

### AI Disclaimer  
Parts of this code, including debugging, error handling, and text refinement, were developed with the assistance of AI tools. AI was used to help correct code logic, fix grammar, and improve clarity in documentation.