## Trying out DINOv3

The purpose of this notebook is to fine-tune a DINOv3 model on a dataset of 64x64x4 image chips and their corresponding canopy height models (CHMs) from G-LiHT.

Uses DEV kernel
mjf 11/25/2025

In [None]:
# Standard library imports
import os
import warnings
from datetime import datetime
import glob

# Set environment variables and warnings
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMBA_NUM_THREADS'] = '1'
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.nn.parallel.parallel_apply")

# Third-party scientific computing
import numpy as np
import pandas as pd
import rasterio
from osgeo import gdal
from tiler import Tiler, Merger

# PyTorch and related
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Computer vision and image processing
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

# Hugging Face
from transformers import AutoModel
from datasets import load_dataset, load_from_disk

# Visualization
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap, ListedColormap, BoundaryNorm
from plotnine import *

# Progress bars
from tqdm import tqdm

## 1. Set up

In [None]:
# Define Custom loss block
class AsymmetricMSELoss(nn.Module):
    """
    Squared-error loss that charges more for predicting too low than too high.

    Each element contributes ``w * (predicted - target)^2`` where
    ``w = underestimation_penalty`` when ``predicted < target`` and ``w = 1``
    otherwise.

    Parameters:
    -----------
    underestimation_penalty : float
        Weight applied to elements where the prediction is below the target
        (default: 2.0)
    reduction : str
        How to collapse the per-element losses: 'none' | 'mean' | 'sum'
        (default: 'mean')
    """

    def __init__(self, underestimation_penalty=2.0, reduction='mean'):
        super().__init__()
        self.underestimation_penalty = underestimation_penalty
        self.reduction = reduction

    def forward(self, predicted, target):
        """
        Compute the asymmetric loss.

        Args:
            predicted (torch.Tensor): Model output
            target (torch.Tensor): Reference values

        Returns:
            torch.Tensor: Scalar for 'mean'/'sum', per-element tensor for 'none'

        Raises:
            ValueError: If ``self.reduction`` is not one of the supported modes
        """
        residual_sq = (predicted - target) ** 2

        # Weight of 1 everywhere, bumped to the penalty where we under-predict
        too_low = predicted < target
        weights = torch.ones_like(residual_sq).masked_fill(too_low, self.underestimation_penalty)

        per_element = residual_sq * weights

        if self.reduction == 'none':
            return per_element
        if self.reduction == 'mean':
            return per_element.mean()
        if self.reduction == 'sum':
            return per_element.sum()
        raise ValueError(f"Invalid reduction mode: {self.reduction}")

    def __repr__(self):
        return f"AsymmetricMSELoss(underestimation_penalty={self.underestimation_penalty}, reduction='{self.reduction}')"

class AsymmetricFocalMSELoss(nn.Module):
    """
    Asymmetric Focal MSE Loss that:
    1. Penalizes underestimation more heavily than overestimation
    2. Uses focal weighting to focus more on hard examples (larger errors)

    Per element the loss is ``w * |e|^(2 + focal_gamma)`` with
    ``e = predicted - target`` and ``w = underestimation_penalty`` where
    ``predicted < target`` (1 otherwise). Where ``target > height_threshold``
    the exponent is increased by ``threshold_gamma``.

    Parameters:
    -----------
    underestimation_penalty : float
        Multiplier for underestimation errors (default: 2.0)
    focal_gamma : float
        Focusing parameter - higher values increase focus on large errors (default: 2.0)
    height_threshold : float
        Height above which to increase focus (optional, set to None to disable)
    threshold_gamma : float
        Additional gamma for samples above height_threshold (default: 1.0)
    reduction : str
        Specifies the reduction to apply: 'none' | 'mean' | 'sum' (default: 'mean')
    """

    def __init__(self, underestimation_penalty=2.0, focal_gamma=2.0,
                 height_threshold=None, threshold_gamma=1.0, reduction='mean'):
        super().__init__()
        self.underestimation_penalty = underestimation_penalty
        self.focal_gamma = focal_gamma
        self.height_threshold = height_threshold
        self.threshold_gamma = threshold_gamma
        self.reduction = reduction

    def forward(self, predicted, target):
        """
        Compute the asymmetric focal loss.

        Args:
            predicted (torch.Tensor): Predicted values
            target (torch.Tensor): Ground truth values

        Returns:
            torch.Tensor: Scalar for 'mean'/'sum', per-element tensor for 'none'

        Raises:
            ValueError: If ``self.reduction`` is not one of the supported modes
        """
        errors = torch.abs(predicted - target)

        # Asymmetric penalty: 1 everywhere, larger where the model under-predicts
        penalty_weights = torch.ones_like(errors).masked_fill(
            predicted < target, self.underestimation_penalty
        )

        # squared_error * error**gamma is computed as a single power. Evaluating
        # error**gamma on its own has an infinite derivative at error == 0 when
        # gamma < 1, which autograd multiplies by zero and turns into NaN
        # gradients. The combined exponent keeps the gradient finite there.
        base_exponent = 2.0 + self.focal_gamma
        focal_errors = errors ** base_exponent

        # Optional extra focus on tall vegetation
        if self.height_threshold is not None:
            tall_vegetation_mask = target > self.height_threshold
            focal_errors = torch.where(
                tall_vegetation_mask,
                errors ** (base_exponent + self.threshold_gamma),
                focal_errors,
            )

        weighted_errors = focal_errors * penalty_weights

        # Apply reduction
        if self.reduction == 'mean':
            return weighted_errors.mean()
        elif self.reduction == 'sum':
            return weighted_errors.sum()
        elif self.reduction == 'none':
            return weighted_errors
        else:
            raise ValueError(f"Invalid reduction mode: {self.reduction}")

    def __repr__(self):
        return (f"AsymmetricFocalMSELoss(underestimation_penalty={self.underestimation_penalty}, "
                f"focal_gamma={self.focal_gamma}, height_threshold={self.height_threshold}, "
                f"threshold_gamma={self.threshold_gamma}, reduction='{self.reduction}')")

In [None]:
# ====================================================================
# CONFIGURATION CELL - Modify parameters here
# ====================================================================

def _bands_from_stats_name(stats_name, known_bands=('nrg', 'rgb')):
    """Return the band code embedded in a stats-file name.

    The name is split on underscores and the last part that matches one of
    ``known_bands`` is returned. Taking the last three characters is not
    reliable when the name carries a suffix: 'maxmin_ak_100k_both_nrg_final'
    ends in 'nal', not 'nrg'. If no known code is present, the last three
    characters are returned, which is the previous behaviour.
    """
    for part in reversed(stats_name.split('_')):
        if part in known_bands:
            return part
    return stats_name[-3:]


# Dataset Configuration
DATA_CONFIG = {
    'data_name': 'chm_npy_dataset',
    'image_dir': '/explore/nobackup/people/mmacande/srlite/chm_model/20231014_chm/train_merged_nodtm_npy/images/',
    'label_dir': '/explore/nobackup/people/mmacande/srlite/chm_model/20231014_chm/train_merged_nodtm_npy/labels/',
    'stats_path': '/explore/nobackup/people/mfrost2/projects/boreal_chm_dino/numpy_stats/',
    'np_stats': 'maxmin_ak_100k_both_nrg_final',  # Or whatever stats you're using
    # Per-band raw value ranges used for min-max scaling
    'nir_min': 0,
    'nir_max': 7142,
    'red_min': 0,
    'red_max': 5893,
    'green_min': 0,
    'green_max': 5387,
}

# Load means and standard deviations based on config
DATA_CONFIG['means'] = np.load(f"{DATA_CONFIG['stats_path']}channel_means_{DATA_CONFIG['np_stats']}.npy")
DATA_CONFIG['stds'] = np.load(f"{DATA_CONFIG['stats_path']}channel_stds_{DATA_CONFIG['np_stats']}.npy")
DATA_CONFIG['input_bands'] = _bands_from_stats_name(DATA_CONFIG['np_stats'])

# Print configuration info
print(f"Input bands: {DATA_CONFIG['input_bands']}")
print(f"Channel means: {DATA_CONFIG['means']}")
print(f"Channel stds: {DATA_CONFIG['stds']}")

# Training Configuration
TRAINING_CONFIG = {
    'n_epochs': 100,
    'patience': 10,
    'lr': 5e-5,
    # Options: nn.MSELoss(), nn.L1Loss(), nn.SmoothL1Loss(), AsymmetricMSELoss(),
    # AsymmetricFocalMSELoss(underestimation_penalty=1.5, focal_gamma=.5,
    #                        height_threshold=10.0, threshold_gamma=0.5)
    'loss_criterion': AsymmetricMSELoss(underestimation_penalty=3.0),
    # Never commit access tokens. The token is read from the environment and is
    # None when HF_TOKEN is unset. The token that was previously hard-coded in
    # this cell should be revoked on the Hugging Face side.
    'hf_token': os.environ.get('HF_TOKEN'),
    'gradient_clip_norm': 1.0,
    'weight_decay': 1e-3,  # previously 1e-4
    'scheduler_patience': 3,
    'scheduler_factor': 0.5,
    'dropout_rate': 0.2
}


# Model Configuration - CHANGE 'CURRENT_MODEL' TO SWITCH MODELS
MODEL_CONFIGS = {
    'large': {
        'model_name': 'facebook/dinov3-vitl16-pretrain-sat493m',
        'description': 'DINOv3-Large (1024 dim, ~300M params)',
        'base_batch_size': 16,
        'memory_efficient': True
    },
    '7b': {
        'model_name': 'facebook/dinov3-vit7b16-pretrain-sat493m',
        'description': 'DINOv3-7B (4096 dim, ~7B params)',
        'base_batch_size': 4,  # Smaller batch size for memory
        'memory_efficient': False
    }
}

# SELECT WHICH MODEL TO USE HERE
CURRENT_MODEL = 'large'  # Options: 'large' or '7b'

# ====================================================================
# Derived configurations (automatically set based on above)
# ====================================================================
config = MODEL_CONFIGS[CURRENT_MODEL]
input_bands = DATA_CONFIG['input_bands']

# Print configuration summary
print("=== CONFIGURATION SUMMARY ===")
print(f"Input bands: {input_bands}")
print(f"Selected model: {config['description']}")
print(f"Model name: {config['model_name']}")
print(f"Training epochs: {TRAINING_CONFIG['n_epochs']}")
print(f"Learning rate: {TRAINING_CONFIG['lr']}")
print(f"Loss function: {type(TRAINING_CONFIG['loss_criterion']).__name__}")

In [None]:
## Define CHM colormap
def create_custom_binned_colormap(colors, vmax=35):
    """
    Build a discrete colormap whose colors map onto fixed height bins.

    A ListedColormap is paired with a BoundaryNorm so each color covers one
    interval between consecutive bin edges; the last edge is ``vmax``.

    Returns:
        tuple: (ListedColormap, BoundaryNorm, list of bin edges)
    """
    # Bin edges: a sliver at zero, fine steps at low heights, coarse above
    bin_edges = [0, 0.001, .5, 1, 2, 3, 5, 10, vmax]

    cmap = ListedColormap(colors)
    bin_norm = BoundaryNorm(bin_edges, len(colors))

    return cmap, bin_norm, bin_edges

# Color ramp for the canopy-height bins, listed in bin order
colors = [
    '#636363',
    '#fc8d59',
    '#fee08b',
    '#ffffbf',
    '#d9ef8b',
    '#91cf60',
    '#1a9850',
    '#005a32',
]

# Shared colormap / norm / bin edges used by the CHM plots
forest_ht_cmap, forest_ht_norm, boundaries = create_custom_binned_colormap(colors, vmax=35)

## 2. Data Preprocessing

In [None]:
# ====================================================================
# DATA PREPROCESSING AND DATASET PIPELINE
# ====================================================================

# Define transforms for your images
def create_transform(means=None, stds=None):
    """Return the image transform to hand to the dataset.

    Always returns None: ``means`` and ``stds`` are accepted but unused, since
    normalization is performed inside the dataset itself.
    """
    return None

class NPYDataset(Dataset):
    """
    Paired image/label dataset backed by ``.npy`` files.

    Every ``<name>.npy`` in ``image_dir`` is matched with ``<name>.npy`` in
    ``label_dir``. Images are 4-band arrays from which bands 3, 2, 1 are taken
    (documented in this file as NIR, Red, Green of a Blue/Green/Red/NIR stack).
    """

    def __init__(self, image_dir, label_dir, transform=None, means=None, stds=None, skip_zero_chm=True):
        """
        Index the image/label pairs found on disk.

        Args:
            image_dir: Folder holding the .npy image files
            label_dir: Folder holding the .npy label files
            transform: Optional callable applied to the (C, H, W) image tensor
            means: Per-channel means [NIR, Red, Green] applied after min-max scaling
            stds: Per-channel standard deviations [NIR, Red, Green] applied after min-max scaling
            skip_zero_chm: When True, drop pairs whose label has no pixel > 0
        """
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.means = means
        self.stds = stds
        self.skip_zero_chm = skip_zero_chm

        # Positions of NIR, Red, Green inside the 4-band stack (0-indexed)
        self.channel_indices = [3, 2, 1]

        # Sorted so sample indices are stable between runs
        self.image_files = sorted(glob.glob(os.path.join(image_dir, "*.npy")))

        self.valid_pairs = []
        skipped_count = 0

        for img_path in self.image_files:
            stem = os.path.splitext(os.path.basename(img_path))[0]
            label_path = os.path.join(label_dir, f"{stem}.npy")

            if not os.path.exists(label_path):
                print(f"Warning: No label found for {img_path}")
                continue

            if self.skip_zero_chm:
                # The label is loaded here purely to test for all-non-positive tiles
                try:
                    candidate = np.load(label_path)
                    if len(candidate.shape) > 2:
                        candidate = np.squeeze(candidate)
                    if np.all(candidate <= 0):
                        skipped_count += 1
                        continue
                except Exception as e:
                    print(f"Error checking label {label_path}: {e}")
                    continue

            self.valid_pairs.append((img_path, label_path))

        print(f"Dataset contains {len(self.valid_pairs)} samples")
        if self.skip_zero_chm:
            print(f"Skipped {skipped_count} files with all CHM pixels <= 0")

    def __len__(self):
        return len(self.valid_pairs)

    def _read_nrg(self, img_path):
        """Load an image file and return the three selected bands as (H, W, 3).

        Any failure is reported and replaced by a 64x64 block of zeros.
        """
        try:
            raw = np.load(img_path)
            if raw.ndim != 3:
                raise ValueError(f"Expected 3D image array, got shape: {raw.shape}")
            if raw.shape[0] == 4:
                # Channels-first (4, H, W): select bands, then move them last
                return np.transpose(raw[self.channel_indices], (1, 2, 0))
            if raw.shape[-1] == 4:
                # Channels-last (H, W, 4)
                return raw[:, :, self.channel_indices]
            raise ValueError(f"Unexpected image shape: {raw.shape}")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return np.zeros((64, 64, 3), dtype=np.float32)

    def _read_label(self, label_path):
        """Load a label file, squeezing singleton axes when it has more than 2 dims.

        Any failure is reported and replaced by a 64x64 block of zeros.
        """
        try:
            chm = np.load(label_path)
            if len(chm.shape) > 2:
                chm = np.squeeze(chm)
            return chm
        except Exception as e:
            print(f"Error loading label {label_path}: {e}")
            return np.zeros((64, 64), dtype=np.float32)

    def __getitem__(self, idx):
        img_path, label_path = self.valid_pairs[idx]

        image = torch.from_numpy(self._read_nrg(img_path).copy()).float()
        label = torch.from_numpy(self._read_label(label_path).copy()).float()

        # STEP 1: per-band min-max scaling into [0, 1], clipped at both ends.
        # Channel order is NIR (0), Red (1), Green (2).
        range_keys = (
            ('nir_min', 'nir_max'),
            ('red_min', 'red_max'),
            ('green_min', 'green_max'),
        )
        for ch, (lo_key, hi_key) in enumerate(range_keys):
            lo = DATA_CONFIG[lo_key]
            hi = DATA_CONFIG[hi_key]
            image[:, :, ch] = (image[:, :, ch] - lo) / (hi - lo)
            image[:, :, ch] = torch.clamp(image[:, :, ch], 0, 1)

        # STEP 2: z-score with the dataset statistics, when supplied
        if self.means is not None and self.stds is not None:
            for ch in range(3):
                image[:, :, ch] = (image[:, :, ch] - self.means[ch]) / self.stds[ch]

        # PyTorch expects (C, H, W)
        image = image.permute(2, 0, 1)

        if self.transform:
            image = self.transform(image)

        return image, label

def create_flexible_dataloaders(train_dataset, test_dataset, model_config='large'):
    """Build the train and test DataLoaders sized for the chosen model.

    Batch size is the model's ``base_batch_size`` times the GPU count (at
    least 1). Worker count is 4 per GPU for models flagged
    ``memory_efficient`` and 2 per GPU otherwise. Only the train loader
    shuffles.

    Args:
        train_dataset: Dataset used for training
        test_dataset: Dataset used for evaluation
        model_config: Key into MODEL_CONFIGS

    Returns:
        tuple: (train_loader, test_loader)
    """
    model_cfg = MODEL_CONFIGS[model_config]

    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
    else:
        num_gpus = 1

    base_batch_size = model_cfg['base_batch_size']
    train_batch_size = test_batch_size = base_batch_size * max(1, num_gpus)

    workers_per_gpu = 4 if model_cfg['memory_efficient'] else 2
    num_workers = workers_per_gpu * num_gpus

    print(f"DataLoader setup for {model_cfg['description']}:")
    print(f"  Number of GPUs: {num_gpus}")
    print(f"  Base batch size: {base_batch_size}")
    print(f"  Train batch size: {train_batch_size}")
    print(f"  Test batch size: {test_batch_size}")
    print(f"  Workers: {num_workers}")

    # Options common to both loaders
    shared_kwargs = dict(num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, **shared_kwargs)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **shared_kwargs)

    return train_loader, test_loader

# ====================================================================
# DATA LOADING AND PREPROCESSING
# ====================================================================

# Transform pipeline (create_transform currently returns None)
transform = create_transform(means=DATA_CONFIG.get('means'), stds=DATA_CONFIG.get('stds'))

# Index every image/label pair, dropping tiles whose CHM is entirely <= 0
full_dataset = NPYDataset(
    DATA_CONFIG['image_dir'],
    DATA_CONFIG['label_dir'],
    transform=transform,
    means=DATA_CONFIG.get('means'),
    stds=DATA_CONFIG.get('stds'),
    skip_zero_chm=True,
)

print(f"Total dataset size: {len(full_dataset)}")

# 80/20 train/test split with a fixed seed so the partition is reproducible
from torch.utils.data import random_split

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    full_dataset, [train_size, test_size], generator=split_generator
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Loaders sized for the currently selected model
train_loader, test_loader = create_flexible_dataloaders(
    train_dataset, test_dataset, model_config=CURRENT_MODEL
)

## 3. Preview dataset

In [None]:
# Alternative: Visualize raw data (before normalization)
def visualize_raw_image_and_chm(dataset, idx=0):
    """
    Plot one sample's NIR-Red-Green image next to its CHM, read straight from disk.

    The files are loaded from the paths stored in ``dataset.valid_pairs``, so
    none of the dataset's normalization is applied. The image is min-max
    stretched with the per-band ranges in DATA_CONFIG for display only.

    Args:
        dataset: Object exposing ``valid_pairs`` (list of (image_path, label_path)).
            If it also has ``channel_indices`` those bands are used; otherwise
            bands [3, 2, 1] are taken.
        idx: Index of the sample to display

    Raises:
        ValueError: If the image is not a 3-D array with 4 bands on its first
            or last axis
    """
    # Get the file paths for the sample
    img_path, label_path = dataset.valid_pairs[idx]

    print(f"Loading raw files:")
    print(f"Image: {os.path.basename(img_path)}")
    print(f"Label: {os.path.basename(label_path)}")

    # Load raw image
    raw_image = np.load(img_path)

    # Use the same band selection as the dataset so the preview matches training
    band_order = list(getattr(dataset, 'channel_indices', [3, 2, 1]))

    # Extract NIR, Red, Green channels
    if raw_image.ndim == 3 and raw_image.shape[0] == 4:
        # Shape is (4, H, W) - channels first
        nrg_image = np.transpose(raw_image[band_order], (1, 2, 0))  # Convert to (H, W, 3)
    elif raw_image.ndim == 3 and raw_image.shape[-1] == 4:
        # Shape is (H, W, 4) - channels last
        nrg_image = raw_image[:, :, band_order]
    else:
        # Without this branch nrg_image would be unbound and fail further down
        raise ValueError(f"Unexpected image shape: {raw_image.shape}")

    # Load raw CHM
    raw_chm = np.load(label_path)
    if len(raw_chm.shape) > 2:
        raw_chm = np.squeeze(raw_chm)

    print(f"Raw image shape: {nrg_image.shape}")
    print(f"Raw CHM shape: {raw_chm.shape}")
    print(f"Raw image range: {nrg_image.min():.1f} to {nrg_image.max():.1f}")
    print(f"Raw CHM range: {raw_chm.min():.3f} to {raw_chm.max():.3f}")

    # Apply min-max stretching for visualization
    stretched_image = np.zeros_like(nrg_image, dtype=np.float32)
    range_keys = (
        ('nir_min', 'nir_max'),
        ('red_min', 'red_max'),
        ('green_min', 'green_max'),
    )
    for ch, (lo_key, hi_key) in enumerate(range_keys):
        lo = DATA_CONFIG[lo_key]
        hi = DATA_CONFIG[hi_key]
        stretched_image[:, :, ch] = (nrg_image[:, :, ch] - lo) / (hi - lo)

    stretched_image = np.clip(stretched_image, 0, 1)

    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Display stretched image
    ax1.imshow(stretched_image)
    ax1.set_title(f'Raw NIR-Red-Green Image (Sample {idx})')
    ax1.axis('off')

    # Display CHM
    chm_plot = ax2.imshow(raw_chm, cmap=forest_ht_cmap, norm=forest_ht_norm)
    ax2.set_title(f'Raw Canopy Height Model (Sample {idx})')
    ax2.axis('off')
    plt.colorbar(chm_plot, ax=ax2, label='Height (m)')

    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first indexed sample as stored on disk: the function reloads the
# files from dataset.valid_pairs, so no dataset normalization is applied.
print("\n=== Visualizing Raw Data ===")
visualize_raw_image_and_chm(full_dataset, idx=0)

## 4. Model Structure

In [None]:
## Make DepthHead for fine-tuning, with attention, Leaky RelU, and extra layers

class ChannelAttention(nn.Module):
    """Per-channel gating: rescales each feature channel by a learned weight in (0, 1).

    The gate is computed from the globally average-pooled feature map through a
    two-layer 1x1-conv bottleneck (``channels // reduction`` hidden channels)
    ending in a sigmoid.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = channels // reduction
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),               # (B, C, H, W) -> (B, C, 1, 1)
            nn.Conv2d(channels, bottleneck, 1),
            nn.ReLU(),
            nn.Conv2d(bottleneck, channels, 1),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)

    def forward(self, x):
        gate = self.attention(x)
        # The (B, C, 1, 1) gate broadcasts over the spatial dimensions
        return x * gate

class DINOv3DepthHead(nn.Module):
    """
    DINOv3 backbone followed by a convolutional decoder that regresses a
    64x64 map from a 64x64 input.

    The backbone's patch tokens (a 4x4 grid for 64x64 input with 16x16
    patches) are reshaped into a feature map and upsampled by four 2x stages
    back to 64x64. Outputs are clamped to be non-negative.

    Args:
        model_name: Hugging Face model id of the backbone
        freeze_backbone: If True, backbone parameters are frozen and the
            backbone forward pass runs under ``torch.no_grad()``
        output_channels: Number of channels produced by the final conv
        token: Optional Hugging Face access token for gated models
        input_bands: 'nrg' remaps the patch-embedding weights for
            NIR-Red-Green input; any other value leaves them untouched
        training_config: Optional dict; only 'dropout_rate' is read
            (default 0.1)
    """

    def __init__(self, model_name="facebook/dinov3-vitl16-pretrain-sat493m",
                 freeze_backbone=True, output_channels=1, token=None, input_bands='rgb',
                 training_config=None):

        super().__init__()

        # Store training config for use in decoder
        self.training_config = training_config or {}
        self.dropout_rate = self.training_config.get('dropout_rate', 0.1)  # Default to 0.1

        # Store output_channels for use in decoder
        self.output_channels = output_channels

        # Load DINOv3 backbone
        if token:
            self.backbone = AutoModel.from_pretrained(model_name, token=token)
        else:
            self.backbone = AutoModel.from_pretrained(model_name)

        # Modify input weights for NRG if specified
        if input_bands == 'nrg':
            self._modify_input_weights_for_nrg()

        # Get actual model dimensions dynamically
        self._determine_model_dimensions()

        # Freeze backbone if specified. The attribute is only created when the
        # backbone is frozen; forward() checks for its presence.
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
            self._freeze_backbone = True

        print(f"DINOv3 DepthHead initialized:")
        print(f"  Model: {model_name}")
        print(f"  Input bands: {input_bands}")
        print(f"  Embedding dim: {self.embed_dim}")
        print(f"  Total tokens: {self.total_tokens}")
        print(f"  Spatial patches: {self.num_spatial_patches}")
        print(f"  Using tokens {self.spatial_start}:{self.spatial_end}")

        # Build improved decoder
        self.depth_head = self._build_improved_decoder()

        # Start the output bias at a target-scale height value. The original
        # annotation on this call reads: 70% -> 1.72, 90% -> 8.42.
        self._initialize_output_layer(target_percentile=8.42)

    def _determine_model_dimensions(self):
        """Probe the backbone with a dummy 64x64 input to learn its token layout.

        Sets ``total_tokens``, ``embed_dim``, ``patches_per_side``,
        ``num_spatial_patches``, ``spatial_start`` and ``spatial_end``.
        """
        # Use 64x64 dummy input to match the actual data
        dummy_input = torch.randn(1, 3, 64, 64)

        with torch.no_grad():
            outputs = self.backbone(dummy_input)
            features = outputs.last_hidden_state

        # Get actual dimensions
        self.total_tokens = features.shape[1]
        self.embed_dim = features.shape[2]

        # Calculate spatial patch info for 64x64 input with 16x16 patches
        self.patches_per_side = 64 // 16  # = 4
        self.num_spatial_patches = self.patches_per_side ** 2  # = 16

        # The patch tokens are taken to be the last num_spatial_patches tokens,
        # whatever precedes them: with 1 non-spatial token ([CLS]) this starts
        # at 1, with 5 ([CLS] + 4 register tokens) it starts at 5.
        non_spatial_tokens = self.total_tokens - self.num_spatial_patches
        self.spatial_start = non_spatial_tokens
        self.spatial_end = self.total_tokens

        print(f"Detected model structure:")
        print(f"  Total tokens: {self.total_tokens}")
        print(f"  Embedding dim: {self.embed_dim}")
        print(f"  Non-spatial tokens: {non_spatial_tokens}")
        print(f"  Input size: 64x64")
        print(f"  Patch grid: {self.patches_per_side}x{self.patches_per_side}")

    def _modify_input_weights_for_nrg(self):
        """Re-map the patch-embedding input weights for NIR-Red-Green input.

        The first Conv2d with 3 input channels found in the backbone is
        rewritten so that new channels 0 and 1 both copy original channel 2
        and new channel 2 copies original channel 1.
        """
        print("Modifying input weights for NIR-Red-Green bands...")

        # Find the first Conv2d layer with 3 input channels (this is the patch embedding)
        for name, module in self.backbone.named_modules():
            if isinstance(module, nn.Conv2d) and module.in_channels == 3:
                print(f"Found patch embedding at: {name}")

                with torch.no_grad():
                    original_weights = module.weight.data.clone()

                    # NOTE(review): the labels below assume the pretrained
                    # weights are in BGR order (index 2 = Red). If the backbone
                    # expects RGB, index 2 is Blue and "Red" should come from
                    # index 0. Confirm against the backbone's preprocessing.
                    new_weights = torch.zeros_like(original_weights)
                    new_weights[:, 0, :, :] = original_weights[:, 2, :, :]  # NIR <- Red
                    new_weights[:, 1, :, :] = original_weights[:, 2, :, :]  # Red <- Red
                    new_weights[:, 2, :, :] = original_weights[:, 1, :, :]  # Green <- Green

                    module.weight.data = new_weights
                    print("✅ Weights modified successfully")
                break

    def _build_improved_decoder(self):
        """
        Build the decoder that turns the 4x4 patch-feature grid into a 64x64 map.

        Four stages each double the resolution (4 -> 8 -> 16 -> 32 -> 64) with a
        transposed conv followed by two residual blocks; channel attention is
        enabled in the last two stages. A final conv stack produces
        ``self.output_channels`` channels without changing the spatial size.

        Returns:
            nn.Sequential: The decoder
        """

        class ResidualBlock(nn.Module):
            """Residual block with LeakyReLU and attention"""
            def __init__(self, channels, use_attention=True, dropout_rate=0.1):
                super().__init__()
                self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
                self.bn1 = nn.BatchNorm2d(channels)
                self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
                self.bn2 = nn.BatchNorm2d(channels)
                self.activation = nn.LeakyReLU(0.1, inplace=True)

                # Add dropout after activation
                self.dropout = nn.Dropout2d(dropout_rate)

                self.attention = ChannelAttention(channels) if use_attention else nn.Identity()

            def forward(self, x):
                residual = x
                out = self.activation(self.bn1(self.conv1(x)))
                out = self.dropout(out)  # Apply dropout after activation
                out = self.bn2(self.conv2(out))
                out = self.attention(out)
                out += residual  # Skip connection
                return self.activation(out)

        # Get dropout rate from training config
        dropout_rate = self.dropout_rate

        layers = []
        current_channels = self.embed_dim

        # Need 4 upsampling steps: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        target_channels = [512, 256, 128, 64]

        for i, out_channels in enumerate(target_channels):
            # Upsampling block (2x upsampling each time)
            upsample_block = nn.Sequential(
                nn.ConvTranspose2d(current_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Dropout2d(dropout_rate)  # Use config dropout rate
            )
            layers.append(upsample_block)

            # Residual refinement blocks with dropout
            refinement_block = nn.Sequential(
                ResidualBlock(out_channels, use_attention=(i >= 2), dropout_rate=dropout_rate),
                ResidualBlock(out_channels, use_attention=(i >= 2), dropout_rate=dropout_rate)
            )
            layers.append(refinement_block)

            current_channels = out_channels

        # Final layers - these do NOT change spatial dimensions;
        # after 4 upsampling steps the map is already 64x64
        final_layers = nn.Sequential(
            nn.Conv2d(current_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout2d(dropout_rate),  # Use config dropout rate
            ChannelAttention(128),

            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout2d(dropout_rate),  # Use config dropout rate

            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),

            # No dropout before final prediction layer
            nn.Conv2d(32, self.output_channels, kernel_size=1)
        )
        layers.append(final_layers)

        return nn.Sequential(*layers)

    def forward(self, x):
        """Predict a non-negative map for a batch of 64x64 inputs.

        Args:
            x (torch.Tensor): Input batch whose last two dimensions are 64x64

        Returns:
            torch.Tensor: Decoder output with dimension 1 squeezed (when it has
            size 1) and values clamped to >= 0

        Raises:
            ValueError: If the input is not 64x64 or the number of spatial
                tokens does not match the detected layout
        """
        # Validate input size
        if x.shape[-2:] != (64, 64):
            raise ValueError(f"Expected input size 64x64, got {x.shape[-2:]}")

        # Get features from DINOv3 (no autograd graph when the backbone is frozen)
        if hasattr(self, '_freeze_backbone'):
            with torch.no_grad():
                outputs = self.backbone(x)
        else:
            outputs = self.backbone(x)

        features = outputs.last_hidden_state

        # Extract the spatial (patch) tokens: 16 of them for a 64x64 input
        spatial_tokens = features[:, self.spatial_start:self.spatial_end, :]
        batch_size = spatial_tokens.shape[0]

        # Verify we have the right number of spatial patches
        expected_spatial_tokens = self.spatial_end - self.spatial_start
        if spatial_tokens.shape[1] != expected_spatial_tokens:
            raise ValueError(f"Expected {expected_spatial_tokens} spatial tokens, got {spatial_tokens.shape[1]}")

        # Reshape to spatial grid (4x4 for 64x64 input)
        spatial_features = spatial_tokens.transpose(1, 2).reshape(
            batch_size, self.embed_dim, self.patches_per_side, self.patches_per_side
        )

        # Pass through decoder (4x4 -> 64x64)
        depth_map = self.depth_head(spatial_features)
        depth_map = depth_map.squeeze(1)

        depth_map = torch.clamp(depth_map, min=0.0)

        return depth_map

    def _initialize_output_layer(self, target_mean=None, target_percentile=None):
        """
        Initialize the decoder's final conv so early predictions sit near a
        chosen target-scale value.

        Weights are drawn from N(0, 0.02); the bias is filled with
        ``target_mean`` if given, else ``target_percentile`` if given, else 3.0.

        Args:
            target_mean (float, optional): Mean of the targets, in target units
            target_percentile (float, optional): A target *value* at some
                percentile, in target units (e.g. the 90th-percentile height),
                not a fraction such as 0.9
        """
        # Search the decoder only, so the result does not depend on the order
        # in which submodules were registered on this module. Fall back to the
        # whole module if the decoder has not been built yet.
        search_root = getattr(self, 'depth_head', self)

        final_conv = None
        for module in search_root.modules():
            if isinstance(module, nn.Conv2d):
                final_conv = module

        if final_conv is None:
            print("Warning: Could not find final convolutional layer for initialization")
            return

        with torch.no_grad():
            # Initialize weights with a slightly higher standard deviation
            # to encourage more diverse predictions
            nn.init.normal_(final_conv.weight, mean=0.0, std=0.02)

            # Initialize the bias to address the skew
            if final_conv.bias is not None:
                if target_mean is not None:
                    final_conv.bias.fill_(target_mean)
                elif target_percentile is not None:
                    final_conv.bias.fill_(target_percentile)
                else:
                    # Default value; should be on the same scale as the targets
                    final_conv.bias.fill_(3.0)

In [None]:
def _dinov3_checkpoint_payload(model_to_save, optimizer, epoch, train_loss, val_loss,
                               val_mae, train_losses, val_losses, model_config, config,
                               final_epoch=False):
    """Assemble the dictionary that is written to disk for a best/last checkpoint."""
    payload = {
        'epoch': epoch,
        'model_state_dict': model_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_mae': val_mae,
        'train_losses': train_losses,  # complete history up to this epoch
        'val_losses': val_losses,      # complete history up to this epoch
    }
    if final_epoch:
        payload['final_epoch'] = True  # marks the "last" (not necessarily best) model
    payload['model_config'] = {
        'model_type': model_config,
        'model_name': config['model_name'],
        'embed_dim': model_to_save.embed_dim,
        'total_tokens': model_to_save.total_tokens,
        'description': config['description']
    }
    return payload


def run_dinov3_training_flexible(
    train_loader, 
    test_loader,
    data_config,
    training_config,
    model_config='large'
):
    """
    Train the depth head of a DINOv3DepthHead model with early stopping.

    The backbone is created frozen; only ``depth_head`` parameters are handed to
    the optimizer. The model with the lowest validation loss is saved as
    ``best_dinov3_{model_config}_{data_name}.pth`` and the model from the last
    executed epoch as ``last_dinov3_{model_config}_{data_name}.pth``.

    Parameters:
    -----------
    train_loader : DataLoader
        PyTorch DataLoader for training data
    test_loader : DataLoader
        PyTorch DataLoader for validation/test data
    data_config : dict
        Dataset details. Reads 'data_name' and 'np_stats'.
    training_config : dict
        Training parameters. Requires 'loss_criterion' and 'lr'; optionally
        'weight_decay' (1e-4), 'patience' (10), 'n_epochs' (30),
        'hf_token' (None) and 'gradient_clip_norm' (1.0, None disables clipping).
    model_config : str
        Key into MODEL_CONFIGS (e.g. 'large' or '7b')

    Returns:
    --------
    tuple : (model, train_losses, val_losses). When no GPU is available the
        function does not train and returns (None, [], []).

    Raises:
    -------
    ValueError : if 'n_epochs' is smaller than 1.
    """
    
    # Get model configuration
    config = MODEL_CONFIGS[model_config]
    
    # Extract required variables from configs
    loss_criterion = training_config['loss_criterion']
    base_lr = training_config['lr']
    weight_decay = training_config.get('weight_decay', 1e-4)
    patience = training_config.get('patience', 10)
    num_epochs = training_config.get('n_epochs', 30)
    hf_token = training_config.get('hf_token', None)
    gradient_clip_norm = training_config.get('gradient_clip_norm', 1.0)
    
    # Get data-specific information
    data_name = data_config['data_name']
    # NOTE(review): takes the last three characters of 'np_stats' as the band
    # code (presumably e.g. 'nrg' / 'rgb') -- confirm against the config cell.
    input_bands = data_config.get('np_stats')[-3:]
    
    print("="*60)
    print(f"STARTING DINOV3 MULTI-GPU TRAINING")
    print(f"Model: {config['description']}")
    print("="*60)
       
    # Check available GPUs
    if torch.cuda.is_available():
        device = torch.device('cuda')
        num_gpus = torch.cuda.device_count()
        print(f"Found {num_gpus} GPUs available:")
        for i in range(num_gpus):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        device = torch.device('cpu')
        print("No GPUs available, using CPU")
        return None, [], []
    
    # The final save/summary below needs at least one completed epoch
    if num_epochs < 1:
        raise ValueError(f"n_epochs must be >= 1, got {num_epochs}")
    
    # Clear GPU memory
    torch.cuda.empty_cache()
    
    # Create model with selected configuration
    model = DINOv3DepthHead(
        model_name=config['model_name'],
        freeze_backbone=True,
        token=hf_token,
        input_bands=input_bands
    )
    
    # Move to primary GPU first
    model = model.to(device)
    
    # Wrap with DataParallel for multi-GPU training
    if num_gpus > 1:
        print(f"Using DataParallel across {num_gpus} GPUs")
        model = nn.DataParallel(model)
        # DataParallel scatters each loader batch across the GPUs; it does not
        # enlarge the batch.
        print(f"Batch size per step: {train_loader.batch_size} (split across {num_gpus} GPUs)")
    
    # Training setup: only the depth head is optimized
    criterion = loss_criterion
    mae_metric = nn.L1Loss()
    depth_head = model.module.depth_head if hasattr(model, 'module') else model.depth_head
    
    # Use a 10x smaller learning rate for anything other than the 'large' model
    adjusted_lr = base_lr if model_config == 'large' else base_lr * 0.1
    optimizer = torch.optim.AdamW(depth_head.parameters(), lr=adjusted_lr, weight_decay=weight_decay)
    
    # Halve the learning rate after 3 epochs without validation improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, min_lr=1e-6)
    
    print(f"Training setup:")
    print(f"  Learning rate: {adjusted_lr} (base: {base_lr})")
    print(f"  Patience: {patience}")
    
    # Number of leading validation batches used for the min/max range report
    range_sample_batches = 5
    
    # Training loop
    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs} (Patience: {patience_counter}/{patience})")
        print("-" * 50)
        
        # Training phase
        model.train()
        train_loss = 0.0
        train_mae = 0.0
        
        for batch_idx, (images, targets) in enumerate(train_loader):
            images, targets = images.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            
            loss.backward()
            if gradient_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip_norm)
            optimizer.step()
            
            train_loss += loss.item()
            train_mae += mae_metric(outputs.detach(), targets).item()
            
            if batch_idx % 500 == 0:
                print(f"  Batch {batch_idx}/{len(train_loader)}: Loss = {loss.item():.4f}")
            
            # More frequent memory cleanup for configs not flagged memory_efficient
            if not config['memory_efficient'] and batch_idx % 25 == 0:
                torch.cuda.empty_cache()
        
        avg_train_loss = train_loss / len(train_loader)
        avg_train_mae = train_mae / len(train_loader)
        train_losses.append(avg_train_loss)
        
        # Validation phase (also records prediction/target ranges of the first
        # few batches so no second pass over the loader is needed)
        model.eval()
        val_loss = 0.0
        val_mae = 0.0
        epoch_pred_ranges = []
        epoch_target_ranges = []
        
        with torch.no_grad():
            for val_idx, (images, targets) in enumerate(test_loader):
                images, targets = images.to(device), targets.to(device)
                outputs = model(images)
                
                val_loss += criterion(outputs, targets).item()
                val_mae += mae_metric(outputs, targets).item()
                
                if val_idx < range_sample_batches:
                    epoch_pred_ranges.append([outputs.min().item(), outputs.max().item()])
                    epoch_target_ranges.append([targets.min().item(), targets.max().item()])
        
        avg_val_loss = val_loss / len(test_loader)
        avg_val_mae = val_mae / len(test_loader)
        val_losses.append(avg_val_loss)
        
        # Calculate epoch range statistics
        avg_pred_min = np.mean([r[0] for r in epoch_pred_ranges])
        avg_pred_max = np.mean([r[1] for r in epoch_pred_ranges])
        avg_target_min = np.mean([r[0] for r in epoch_target_ranges])
        avg_target_max = np.mean([r[1] for r in epoch_target_ranges])
        
        print(f"\nEpoch {epoch+1} Range Analysis:")
        print(f"  Predictions: [{avg_pred_min:.1f}, {avg_pred_max:.1f}] (range: {avg_pred_max - avg_pred_min:.1f})")
        print(f"  Ground Truth: [{avg_target_min:.1f}, {avg_target_max:.1f}] (range: {avg_target_max - avg_target_min:.1f})")
        
        scheduler.step(avg_val_loss) 
        
        # Check for best model and save if improved
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            
            # Save BEST model with config info (unwrap DataParallel first)
            model_to_save = model.module if hasattr(model, 'module') else model
            torch.save(
                _dinov3_checkpoint_payload(
                    model_to_save, optimizer, epoch, avg_train_loss, avg_val_loss,
                    avg_val_mae, train_losses, val_losses, model_config, config
                ),
                f'best_dinov3_{model_config}_{data_name}.pth'
            )
            
            improvement = "‚≠ê NEW BEST"
        else:
            patience_counter += 1
            improvement = f"No improvement ({patience_counter}/{patience})"
        
        print(f"  Train Loss: {avg_train_loss:.4f} | Train MAE: {avg_train_mae:.4f}")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val MAE:   {avg_val_mae:.4f}")
        # sqrt of the criterion value; only a true RMSE when the criterion is plain MSE
        print(f"  Val RMSE:   {np.sqrt(avg_val_loss):.4f}")
        print(f"  Status: {improvement}")
        
        # Early stopping
        if patience_counter >= patience:
            print(f"\nüõë EARLY STOPPING at epoch {epoch+1}")
            break
        
        # Memory cleanup
        torch.cuda.empty_cache()
    
    # Save LAST model (regardless of performance)
    print(f"\nüíæ Saving final model from epoch {epoch+1}")
    model_to_save = model.module if hasattr(model, 'module') else model
    torch.save(
        _dinov3_checkpoint_payload(
            model_to_save, optimizer, epoch, avg_train_loss, avg_val_loss,
            avg_val_mae, train_losses, val_losses, model_config, config,
            final_epoch=True
        ),
        f'last_dinov3_{model_config}_{data_name}.pth'
    )
    
    # Report the file names that were actually written (keyed by model_config)
    print(f"‚úÖ Training complete!")
    print(f"   Best model saved as: best_dinov3_{model_config}_{data_name}.pth")
    print(f"   Last model saved as: last_dinov3_{model_config}_{data_name}.pth")
    print(f"   Best validation loss: {best_val_loss:.4f}")
    print(f"   Final validation loss: {avg_val_loss:.4f}")

    return model, train_losses, val_losses

## 5. Train Model

In [None]:
import time
from datetime import timedelta

# Set to True to (re)train the model. When False, only the dataloaders are
# built and the evaluation section loads previously saved checkpoints instead.
RUN_TRAINING = False

# Start timer
start_time = time.time()

# Create dataloaders for current model selection (also used by the evaluation cells)
train_loader_current, test_loader_current = create_flexible_dataloaders(
    train_dataset, test_dataset, model_config=CURRENT_MODEL
)

if RUN_TRAINING:
    print(f"Starting training at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Run training with selected model
    trained_model, train_losses, val_losses = run_dinov3_training_flexible(
        train_loader=train_loader_current,
        test_loader=test_loader_current,
        data_config=DATA_CONFIG,
        training_config=TRAINING_CONFIG,
        model_config=CURRENT_MODEL
    )

# End timer and calculate duration
end_time = time.time()
duration = end_time - start_time
duration_formatted = str(timedelta(seconds=int(duration)))

if RUN_TRAINING:
    print(f"Training completed at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Total training time: {duration_formatted}")
else:
    # Do not report a "training time" when nothing was trained
    print(f"Training skipped (RUN_TRAINING is False); dataloaders ready at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Dataloader setup time: {duration_formatted}")

## 6. Evaluate

In [None]:
## Load Model
def load_flexible_model(checkpoint_path, device='cuda', input_bands='nrg', token=None):
    """
    Load a DINOv3DepthHead checkpoint regardless of which model type it is.

    The backbone name is read from the checkpoint's 'model_config' entry, the
    model is rebuilt with a frozen backbone, the saved weights are loaded and
    the model is put in eval mode.

    Parameters:
    -----------
    checkpoint_path : str
        Path to a checkpoint written by the training function.
    device : str or torch.device
        Target device. A CUDA device is replaced by 'cpu' when CUDA is unavailable.
    input_bands : str
        Band code forwarded to DINOv3DepthHead.
    token : str, optional
        Hugging Face access token. Defaults to the HF_TOKEN environment
        variable; never hard-code tokens in the notebook.

    Returns:
    --------
    tuple : (model, checkpoint) where checkpoint is the loaded dictionary.
    """
    
    print(f"Loading model from: {checkpoint_path}")
    
    # Load checkpoint on CPU first to inspect the model type.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_config = checkpoint['model_config']
    
    print(f"Loading {model_config['description']}")
    print(f"  Model name: {model_config['model_name']}")
    print(f"  Embedding dim: {model_config['embed_dim']}")
    print(f"  Total tokens: {model_config['total_tokens']}")
    
    # SECURITY: a token used to be hard-coded here; that token should be revoked.
    # Read credentials from the environment instead (None is passed through as-is).
    if token is None:
        token = os.environ.get('HF_TOKEN')
    
    # Create model with correct configuration
    model = DINOv3DepthHead(
        model_name=model_config['model_name'],
        freeze_backbone=True,
        token=token,
        input_bands=input_bands
    )
    
    # Fall back to CPU instead of failing on machines without CUDA
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("CUDA not available; loading model on CPU instead")
        device = 'cpu'
    
    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    
    # 'final_epoch' is only set on "last" checkpoints, which are not necessarily the best
    label = "Final" if checkpoint.get('final_epoch') else "Best"
    print(f"‚úÖ Model loaded successfully!")
    print(f"  {label} epoch: {checkpoint['epoch']}")
    print(f"  {label} validation loss: {checkpoint['val_loss']:.4f}")
    
    return model, checkpoint

In [None]:
# Defaults used when no checkpoint from a previous run is found
trained_model, train_losses, val_losses = None, [], []

# Restore the last saved model and, if the checkpoint carries it, the loss history
model_filename = f'last_dinov3_{CURRENT_MODEL}_{DATA_CONFIG["data_name"]}.pth'
if model_filename in os.listdir('.'):
    trained_model, checkpoint = load_flexible_model(model_filename)
    train_losses = checkpoint.get('train_losses', train_losses)
    val_losses = checkpoint.get('val_losses', val_losses)

# Debug: report how much history was recovered
print(f"Loaded {len(train_losses)} training loss values")
print(f"Loaded {len(val_losses)} validation loss values")

In [None]:
# Loss and RMSE curves for the loaded training history
if trained_model is not None and len(train_losses) > 0:
    epochs = range(1, len(train_losses) + 1)
    best_epoch = np.argmin(val_losses) + 1
    train_rmse = [np.sqrt(loss) for loss in train_losses]
    val_rmse = [np.sqrt(loss) for loss in val_losses]
    
    history_fig, (loss_ax, rmse_ax) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Left panel: raw loss, with the best validation epoch marked
    loss_ax.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
    loss_ax.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
    loss_ax.axvline(x=best_epoch, color='green', linestyle='--', alpha=0.7, label=f'Best epoch: {best_epoch}')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss (MSE)')
    loss_ax.set_title(f'{MODEL_CONFIGS[CURRENT_MODEL]["description"]} - Training Progress')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)
    
    # Right panel: square root of the loss
    rmse_ax.plot(epochs, train_rmse, 'b-', label='Training RMSE', linewidth=2)
    rmse_ax.plot(epochs, val_rmse, 'r-', label='Validation RMSE', linewidth=2)
    rmse_ax.axvline(x=best_epoch, color='green', linestyle='--', alpha=0.7)
    rmse_ax.set_xlabel('Epoch')
    rmse_ax.set_ylabel('RMSE (meters)')
    rmse_ax.set_title('RMSE Progress')
    rmse_ax.legend()
    rmse_ax.grid(True, alpha=0.3)
    
    history_fig.tight_layout()
    plt.show()
    
    print(f"\nTraining Summary ({MODEL_CONFIGS[CURRENT_MODEL]['description']}):")
    print(f"  Total epochs run: {len(train_losses)}")
    print(f"  Best epoch: {best_epoch}")
    print(f"  Best validation loss: {min(val_losses):.4f}")
    print(f"  Best validation RMSE: {np.sqrt(min(val_losses)):.4f} meters")
    
    # Fewer recorded epochs than configured means early stopping kicked in
    if len(train_losses) < TRAINING_CONFIG['n_epochs']:
        print(f"  ‚úÖ Early stopping triggered - saved {TRAINING_CONFIG['n_epochs'] - len(train_losses)} epochs!")

In [None]:
# Plot Heatmap correlation plots
def safe_corrcoef(x, y):
    """
    Pearson correlation of two array-likes that never raises.

    Inputs are flattened and pairs containing NaN/inf are dropped. Returns 0.0
    when the lengths differ, fewer than two finite pairs remain, either series
    has zero variance, or any exception occurs (a warning is printed in that
    case). Otherwise the coefficient, clipped to [-1, 1], is returned.
    """
    try:
        a = np.asarray(x).flatten()
        b = np.asarray(y).flatten()

        # Need two equally long series with at least two points
        if len(a) != len(b) or len(a) < 2:
            return 0.0

        # Keep only pairs where both values are finite
        finite = np.isfinite(a) & np.isfinite(b)
        if np.sum(finite) < 2:
            return 0.0
        a, b = a[finite], b[finite]

        # A constant series has no defined correlation
        if np.var(a) == 0 or np.var(b) == 0:
            return 0.0

        # Computed by hand (instead of np.corrcoef) to avoid BLAS issues
        dev_a = a - np.mean(a)
        dev_b = b - np.mean(b)
        covariation = np.sum(dev_a * dev_b)
        scale = np.sqrt(np.sum(dev_a**2) * np.sum(dev_b**2))
        if scale == 0:
            return 0.0

        # Clip away rounding error just outside the valid range
        return np.clip(covariation / scale, -1.0, 1.0)

    except Exception as e:
        print(f"Warning: Correlation calculation failed with error: {e}")
        return 0.0

def collect_predictions(model, dataloader, device, max_samples=None):
    """
    Run the model over a dataloader and gather pixel-wise targets and predictions.

    Parameters:
    -----------
    model : torch model
        Model used for inference (switched to eval mode here).
    dataloader : DataLoader
        Yields (images, depths) batches.
    device : str or torch.device
        Device the images are moved to before inference.
    max_samples : int, optional
        Stop once at least this many samples (batch items, not pixels) have
        been processed. The last batch is kept whole, so slightly more than
        max_samples items may be used. None processes the whole loader.

    Returns:
    --------
    pd.DataFrame : columns 'actual' and 'predicted', one row per finite pixel
        pair. Empty when nothing valid was collected.
    """
    model.eval()
    actual_chunks = []
    predicted_chunks = []
    
    # Size the progress bar from the dataset length instead of iterating the
    # whole loader once just to count samples (which loads all data twice).
    try:
        total_samples = len(dataloader.dataset)
    except (AttributeError, TypeError):
        total_samples = None  # unknown length: tqdm shows a plain counter
    if max_samples:
        total_samples = max_samples if total_samples is None else min(total_samples, max_samples)
    
    pbar = tqdm(
        total=total_samples,
        desc="Collecting predictions",
        position=0,
        leave=True
    )
    
    samples_collected = 0
    try:
        with torch.no_grad():
            for images, depths in dataloader:
                # Stop if we've collected enough samples
                if max_samples and samples_collected >= max_samples:
                    break
                
                batch_size = images.shape[0]
                
                # Generate predictions and bring them back to the CPU
                predictions = model(images.to(device)).cpu()
                
                # Flatten for pixel-wise comparison
                actual_flat = depths.detach().cpu().numpy().flatten()
                pred_flat = predictions.numpy().flatten()
                
                # Drop NaN/inf pairs before storing
                valid_mask = np.isfinite(actual_flat) & np.isfinite(pred_flat)
                if np.any(valid_mask):
                    actual_chunks.append(actual_flat[valid_mask])
                    predicted_chunks.append(pred_flat[valid_mask])
                
                samples_collected += batch_size
                pbar.update(batch_size)
    finally:
        # Always release the progress bar, even if inference fails
        pbar.close()
    
    # Combine all data
    if actual_chunks and predicted_chunks:
        all_actual = np.concatenate(actual_chunks)
        all_predicted = np.concatenate(predicted_chunks)
    else:
        print("Warning: No valid predictions collected!")
        all_actual = np.array([])
        all_predicted = np.array([])
    
    return pd.DataFrame({
        'actual': all_actual,
        'predicted': all_predicted
    })

def create_comparison_plot_matplotlib(train_df, val_df, bins=50, title='Actual vs. Predicted Height', cmap='plasma'):
    """
    Create side-by-side 2D histograms of actual vs. predicted height.

    Both panels use the same value range, the same bin range and the same
    log colour limits, so the single colorbar is valid for either panel.

    Parameters:
    -----------
    train_df, val_df : pd.DataFrame
        Frames with 'actual' and 'predicted' columns.
    bins : int or array-like
        Passed to ``hist2d``.
    title : str
        Figure title.
    cmap : str
        Matplotlib colormap name.

    Returns:
    --------
    matplotlib Figure, or None when a frame is empty or plotting fails
    (the partially built figure is closed in that case).
    """
    # Check if dataframes have data
    if len(train_df) == 0 or len(val_df) == 0:
        print("Error: Empty dataframes provided to plotting function")
        return None
    
    # Determine common range for both plots
    min_val = min(train_df['actual'].min(), train_df['predicted'].min(), 
                 val_df['actual'].min(), val_df['predicted'].min())
    max_val = max(train_df['actual'].max(), train_df['predicted'].max(), 
                 val_df['actual'].max(), val_df['predicted'].max())
    
    # Per-panel metrics; R^2 is the squared Pearson correlation
    panels = []
    for label, frame in (('Training', train_df), ('Validation', val_df)):
        corr = safe_corrcoef(frame['actual'].values, frame['predicted'].values)
        rmse = np.sqrt(np.mean((frame['actual'] - frame['predicted'])**2))
        panels.append((label, frame, corr**2, rmse))
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    
    try:
        meshes = []
        peak_count = 1
        for ax, (label, frame, r2, rmse) in zip(axes, panels):
            # Same bin range in both panels so cell counts are comparable
            counts, _, _, mesh = ax.hist2d(
                frame['actual'], frame['predicted'], bins=bins,
                range=[[min_val, max_val], [min_val, max_val]],
                cmap=cmap, norm=plt.matplotlib.colors.LogNorm()
            )
            meshes.append(mesh)
            peak_count = max(peak_count, counts.max())
            
            # 1:1 reference line
            ax.plot([min_val, max_val], [min_val, max_val], 'b--', linewidth=1.5)
            ax.set_xlabel('Actual Height (m)', fontsize=12)
            ax.set_ylabel('Predicted Height (m)', fontsize=12)
            ax.set_title(f'{label}\nR$^2$ = {r2:.3f}, RMSE = {rmse:.3f}', fontsize=14)
            ax.set_xlim(min_val, max_val)
            ax.set_ylim(min_val, max_val)
            ax.set_aspect('equal')
            ax.grid(alpha=0.3)
        
        # Share colour limits (1 .. largest cell count) across both panels
        for mesh in meshes:
            mesh.set_clim(1, peak_count)
        
        # Adjust layout to make room for colorbar
        plt.tight_layout()
        
        # Dedicated axis for the colorbar on the right edge
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(meshes[0], cax=cbar_ax)
        cbar.set_label('Count (log scale)', fontsize=12)
        cbar.ax.tick_params(labelsize=10)
        
        # Add main title with adjusted positioning
        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.98)
        
        # Keep the panels clear of the colorbar and the title
        plt.subplots_adjust(right=0.9, top=0.9)
        
    except Exception as e:
        print(f"Error creating plots: {e}")
        plt.close(fig)  # do not leak a half-built figure
        return None
    
    return fig

def evaluate_and_visualize_model(model, train_loader, test_loader, device, 
                                max_samples=10000, title='Actual vs. Predicted Height',
                                model_name='model', data_name='dataset'):
    """
    Collect predictions on both loaders, plot them, print metrics and save the figure.

    Returns (fig, train_df, val_df). When no predictions could be collected the
    result is (None, None, None); when only the plot failed it is
    (None, train_df, val_df). A failure to save the PNG is reported but does
    not change the return value.
    """
    def _score(frame):
        # (R^2 as squared Pearson correlation, RMSE, MAE) for one split
        r2 = safe_corrcoef(frame['actual'].values, frame['predicted'].values)**2
        residual = frame['actual'] - frame['predicted']
        return r2, np.sqrt(np.mean(residual**2)), np.mean(np.abs(residual))

    print("Collecting training predictions...")
    train_df = collect_predictions(model, train_loader, device, max_samples)

    print("Collecting validation predictions...")
    val_df = collect_predictions(model, test_loader, device, max_samples)

    if len(train_df) == 0 or len(val_df) == 0:
        print("‚ùå No valid predictions collected. Check your data and model.")
        return None, None, None

    print("Creating visualization...")
    fig = create_comparison_plot_matplotlib(train_df, val_df, title=title, cmap='plasma')
    if fig is None:
        print("‚ùå Failed to create visualization.")
        return None, train_df, val_df

    print(f"Training samples: {len(train_df)}")
    print(f"Validation samples: {len(val_df)}")

    # Additional metrics for both splits
    train_r2, train_rmse, train_mae = _score(train_df)
    val_r2, val_rmse, val_mae = _score(val_df)

    print(f"\nDetailed Metrics:")
    print(f"Training   - R¬≤: {train_r2:.4f}, RMSE: {train_rmse:.4f}, MAE: {train_mae:.4f}")
    print(f"Validation - R¬≤: {val_r2:.4f}, RMSE: {val_rmse:.4f}, MAE: {val_mae:.4f}")

    # Saving is best-effort: report the problem and still return the figure
    try:
        model_desc = f"{model_name}_depth_estimation"
        filename = f'dinov3_depth_prediction_comparison_{model_desc}_{data_name}.png'
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Plot saved as: {filename}")
    except Exception as e:
        print(f"Warning: Could not save plot: {e}")

    return fig, train_df, val_df

# Evaluate the best checkpoint on the train and test loaders
print("=" * 60)
print("COMPREHENSIVE MODEL EVALUATION")
print("=" * 60)

# The best checkpoint is looked up in the current working directory
model_filename = f'best_dinov3_{CURRENT_MODEL}_{DATA_CONFIG["data_name"]}.pth' 

if model_filename not in os.listdir('.'):
    print(f"‚ùå No trained model found: {model_filename}")
    print(f"Available .pth files: {[f for f in os.listdir('.') if f.endswith('.pth')]}")
else:
    print(f"Loading model from checkpoint: {model_filename}")
    model_for_eval, checkpoint = load_flexible_model(model_filename)
    
    if model_for_eval is None:
        print("‚ùå Failed to load model.")
    else:
        result = evaluate_and_visualize_model(
            model=model_for_eval,
            train_loader=train_loader_current,
            test_loader=test_loader_current,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            max_samples=50000,
            title=f'DINOv3-{CURRENT_MODEL.title()} Height Estimation: Actual vs. Predicted CHM',
            model_name=CURRENT_MODEL,
            data_name=DATA_CONFIG['data_name']
        )
        
        # result[0] is the figure; None means the visualization step failed
        if result[0] is None:
            print("‚ùå Evaluation failed during visualization.")
        else:
            fig, train_df, val_df = result
            plt.show()
            print("\n‚úÖ Comprehensive evaluation completed!")

In [None]:
# Visualize sample images
def visualize_predictions_fixed(model, dataloader, data_config, num_samples=20, device='cuda'):
    """
    Plot input image, ground-truth CHM and predicted CHM for the first batch.

    One row is drawn per sample, for at most ``num_samples`` samples and never
    more than the batch contains. Per-sample MSE/MAE/RMSE are printed and the
    figure is saved as ``dinov3_sample_predictions_{model}_depth_estimation.png``.
    
    Parameters:
    -----------
    model : torch model
        Trained model for inference (a DataParallel wrapper is unwrapped)
    dataloader : DataLoader
        Test data loader; only its first batch is used
    data_config : dict
        Reads 'means' and 'stds' (per-channel normalisation stats) and the
        optional 'current_model' label used in the title and file name
    num_samples : int
        Maximum number of samples to visualize
    device : str or torch.device
        Device for computation
    """
    model.eval()
    
    # Extract values from config
    means = data_config['means']
    stds = data_config['stds']
    
    def stretch_nrg_image_std(nrg_image, means, stds, n_std=2):
        """
        Apply mean/std stretching to the first three channels for display.
        
        Parameters:
        -----------
        nrg_image : np.ndarray
            Image array of shape [H, W, C] where C >= 3
        means : sequence
            Mean values for each channel
        stds : sequence
            Standard deviation values for each channel
        n_std : float
            Number of standard deviations to stretch to (default: 2)
        
        Returns:
        --------
        np.ndarray : Stretched image clipped to [0, 1]
        """
        stretched_image = np.zeros_like(nrg_image[:,:,:3], dtype=np.float32)
        
        for c in range(3):
            # Stretch range: mean +/- n_std * std
            min_val = means[c] - n_std * stds[c]
            max_val = means[c] + n_std * stds[c]
            
            channel_range = max_val - min_val
            if channel_range > 0:
                stretched_image[:, :, c] = (nrg_image[:, :, c] - min_val) / channel_range
            else:
                # Degenerate range: show mid-grey
                stretched_image[:, :, c] = 0.5
        
        # Clip to [0, 1] range
        return np.clip(stretched_image, 0, 1)
    
    # Only the first batch is visualised
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        print("Warning: dataloader yielded no batches; nothing to visualize")
        return
    images, depths = first_batch
    
    with torch.no_grad():
        images = images.to(device)
        depths = depths.cpu().numpy()
        
        # Get predictions - handle DataParallel wrapper
        inference_model = model.module if hasattr(model, 'module') else model
        depth_pred = inference_model(images).cpu().numpy()
    
    # Denormalize images for visualization
    images = images.cpu().numpy()
    images = np.transpose(images, (0, 2, 3, 1))  # [B, C, H, W] -> [B, H, W, C]
    images = images * np.array(stds) + np.array(means)
    
    # Size the grid to the samples actually drawn so no empty rows are left over
    n_rows = min(num_samples, images.shape[0])
    if n_rows < 1:
        print("Warning: no samples to visualize")
        return
    
    # squeeze=False keeps axes 2-D even for a single row
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, n_rows*3), squeeze=False)
    
    for i in range(n_rows):
        # Original image - apply mean/std stretching
        stretched_image = stretch_nrg_image_std(images[i], means, stds)
        
        axes[i, 0].imshow(stretched_image)
        axes[i, 0].set_title(f'Input Image {i+1} (NRG Stretched)')
        axes[i, 0].axis('off')
        
        # Ground truth depth
        axes[i, 1].imshow(depths[i], cmap=forest_ht_cmap, norm=forest_ht_norm)
        axes[i, 1].set_title(f'Ground Truth CHM\n(max: {depths[i].max():.1f}m)')
        axes[i, 1].axis('off')
        
        # Predicted depth  
        pred_im = axes[i, 2].imshow(depth_pred[i], cmap=forest_ht_cmap, norm=forest_ht_norm)
        axes[i, 2].set_title(f'Predicted CHM\n(max: {depth_pred[i].max():.1f}m)')
        axes[i, 2].axis('off')
        
        # Add colorbar to each row's predicted image
        cbar = plt.colorbar(pred_im, ax=axes[i, 2], shrink=0.8, aspect=20)
        cbar.set_label('Canopy Height (m)', rotation=270, labelpad=20)
        
        # Customize colorbar ticks to show the boundary values
        cbar.set_ticks([0, 0.001, 0.5, 1, 2, 3, 5, 10, 35])
        cbar.set_ticklabels(['0', '0.001', '0.5', '1', '2', '3', '5', '10', '35'])
        
        # Calculate metrics for this sample
        mse = np.mean((depths[i] - depth_pred[i])**2)
        mae = np.mean(np.abs(depths[i] - depth_pred[i]))
        print(f"Sample {i+1} - MSE: {mse:.3f}, MAE: {mae:.3f}, RMSE: {np.sqrt(mse):.3f}")
    
    # Use dynamic model description from config
    model_name = data_config.get('current_model', 'model')
    model_desc = f"{model_name}_depth_estimation"
    
    # Add suptitle, positioned above the first row
    plt.suptitle(f'DINOv3-{model_name.title()} Sample Predictions ({model_desc})', 
                fontsize=16, y=0.98)
    
    # Adjust layout with tighter spacing
    plt.subplots_adjust(top=0.96, right=0.7, hspace=0.2, wspace=0.1)
    
    plt.savefig(f'dinov3_sample_predictions_{model_desc}.png', dpi=300, bbox_inches='tight')
    plt.show()

# Visualize sample predictions with whichever model is available
print("\nVisualizing individual sample predictions:")

# Arguments shared by every call below
_viz_kwargs = dict(
    num_samples=15,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Preference order: model from the evaluation cell, then the in-memory
# trained model, then a fresh load of the best checkpoint.
if globals().get('model_for_eval') is not None:
    visualize_predictions_fixed(model_for_eval, test_loader_current, DATA_CONFIG, **_viz_kwargs)

elif globals().get('trained_model') is not None:
    # Unwrap DataParallel if present
    model_to_use = getattr(trained_model, 'module', trained_model)
    visualize_predictions_fixed(model_to_use, test_loader_current, DATA_CONFIG, **_viz_kwargs)

else:
    try:
        checkpoint_path = f'best_dinov3_{CURRENT_MODEL}_{DATA_CONFIG["data_name"]}.pth'
        loaded_model, _ = load_flexible_model(checkpoint_path)
        visualize_predictions_fixed(loaded_model, test_loader_current, DATA_CONFIG, **_viz_kwargs)
    except Exception as e:
        print(f"‚ùå No model available for visualization: {e}")

## 7. Do Inference on Specific TIFs

In [None]:
# Inference settings derived from the data preprocessing configuration.

# Per-band min/max pairs are the bounds the inference pipeline uses to rescale
# raw band values to 0-1. NIR/red/green are copied from DATA_CONFIG; the blue
# bounds are hard-coded here (NOTE(review): no blue bounds are read from
# DATA_CONFIG in this cell - confirm where 5681 comes from).
INFERENCE_CONFIG = {
    'blue_min': 0,
    'blue_max': 5681,
    **{
        f'{band}_{bound}': DATA_CONFIG[f'{band}_{bound}']
        for band in ('nir', 'red', 'green')
        for bound in ('min', 'max')
    },
    'img_size': 64,          # tile edge length in pixels
    'overlap': 0.50,         # overlap passed to the tiler
    'batch_size': 32,        # tiles per forward pass
    'visualize_first': True  # plot a preview for the first few files
}

# Channel indices are selected by the last three characters of the dataset
# name; any other suffix keeps all four channels in their stored order.
_CHANNELS_BY_SUFFIX = {
    'rgb': [2, 1, 0],
    'nrg': [3, 2, 1],
}
channels = _CHANNELS_BY_SUFFIX.get(DATA_CONFIG['data_name'][-3:], [0, 1, 2, 3])
print(channels)


In [None]:
# Scenes to run inference on. Every file lives in the same directory and
# carries the same '-sr-02m.tif' suffix, so only the scene IDs are listed.
_SRLITE_DIR = '/explore/nobackup/projects/above/misc/ABoVE_Shrubs/srlite/002m_2023'
_SCENE_IDS = (
    'WV02_20200722_M1BS_10300100AB1FD100',
    'WV03_20200704_M1BS_104001005EADB900',
    'WV03_20190822_M1BS_104001005054F700',
    'WV03_20170719_M1BS_104001002F7A2E00',
    'WV02_20160817_M1BS_103001005B484100',
    'WV02_20190820_M1BS_1030010099BF2F00',
    'WV02_20110727_M1BS_103001000D914900',
    'WV02_20190820_M1BS_10300100968FB300',
    'WV02_20190820_M1BS_10300100954DE000',
    'WV02_20120721_M1BS_103001001AB23900',
)
tif_list = [f'{_SRLITE_DIR}/{scene_id}-sr-02m.tif' for scene_id in _SCENE_IDS]



In [None]:
#Function to viz .tifs
def display_array_as_rgb_mean(arr, channels=[2, 1, 0], figsize=(15, 10), title="RGB Image", 
                             z_score_threshold=3, mask_negative=True, save_path=None):
    """
    Display a multi-channel array as an RGB image using a per-channel mean±k·std stretch.

    Each selected channel is mapped linearly so that
    [mean - k*std, mean + k*std] becomes [0, 1] (k = z_score_threshold), then
    clipped to [0, 1]. Statistics are computed on non-NaN pixels only.

    Args:
        arr: numpy array of shape (C, H, W)
        channels: sequence of channel indices to use as [R, G, B]
        figsize: figure size for matplotlib
        title: title for the plot
        z_score_threshold: number of standard deviations to include (default 3)
        mask_negative: if True, values < 0 become NaN; if False they become 0
        save_path: path to save the image (optional)

    Returns:
        float64 array of shape (H, W, len(channels)) with the stretched values
        in [0, 1]; masked pixels stay NaN. The input array is not modified.
    """

    # Work on a private list: the shared default is never touched, and a tuple
    # behaves as a channel list (a bare tuple would index a single element).
    channels = list(channels)

    # Fancy indexing copies, so the caller's array is left untouched.
    # float64 is used so NaN can represent masked pixels. Result is (H, W, 3).
    rgb_image = np.transpose(arr[channels].astype(np.float64), (1, 2, 0))

    print(f"RGB image shape: {rgb_image.shape}")

    # Handle negative values
    negative_mask = rgb_image < 0
    total_negative = np.sum(negative_mask)
    total_pixels = rgb_image.size

    if total_negative > 0:
        print(f"\nFound {total_negative:,} negative values ({100*total_negative/total_pixels:.2f}% of all pixel values)")

        if mask_negative:
            rgb_image[negative_mask] = np.nan
            print("Negative values masked as NaN (will appear white/transparent)")
        else:
            rgb_image[negative_mask] = 0
            print("Negative values set to 0")
    else:
        print("No negative values found")

    # Print statistics and stretch each channel independently
    for i, ch in enumerate(channels):
        channel_data = rgb_image[:, :, i]
        valid_mask = ~np.isnan(channel_data)
        valid_data = channel_data[valid_mask]  # copy of the non-NaN values

        if len(valid_data) > 0:
            mean_val = valid_data.mean()
            std_val = valid_data.std()

            print(f"\nChannel {ch+1} statistics:")
            print(f"  Valid pixels: {len(valid_data):,} / {channel_data.size:,}")
            print(f"  Min: {valid_data.min():.3f}, Max: {valid_data.max():.3f}")
            print(f"  Mean: {mean_val:.3f}, Std: {std_val:.3f}")

            lower_bound = mean_val - z_score_threshold * std_val
            upper_bound = mean_val + z_score_threshold * std_val
            span = upper_bound - lower_bound

            if span > 0:
                rgb_image[valid_mask, i] = (valid_data - lower_bound) / span
            else:
                # Constant channel (std == 0) or z_score_threshold == 0: the
                # stretch would divide by zero and turn the channel into NaN.
                # Every valid value sits at the mean, which the regular
                # stretch maps to the middle of the range.
                rgb_image[valid_mask, i] = 0.5

            print(f"  Normalization bounds: [{lower_bound:.3f}, {upper_bound:.3f}]")
        else:
            print(f"\nChannel {ch+1}: No valid data after masking!")

    # Clip to [0, 1]; np.clip propagates NaN, so masked pixels stay masked
    rgb_image = np.clip(rgb_image, 0, 1)

    # Display the image
    plt.figure(figsize=figsize)
    plt.imshow(rgb_image)
    plt.title(f'{title}\nChannels: {[c+1 for c in channels]} (R,G,B) - Mean¬±{z_score_threshold}œÉ Normalization')
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Image saved to: {save_path}")

    plt.show()

    # Hand back the stretched image so callers can reuse it
    return rgb_image

# Simple wrapper function
def show_rgb_mean(arr):
    """Display RGB using mean±3σ normalization with negative masking.

    Thin wrapper around display_array_as_rgb_mean that fixes channels to
    [2, 1, 0] and leaves every other option at its default.

    Args:
        arr: array passed straight through to display_array_as_rgb_mean.

    Returns:
        Whatever display_array_as_rgb_mean returns (the normalized image
        array), so that `rgb_image = show_rgb_mean(arr)` yields a usable value
        instead of None.
    """
    return display_array_as_rgb_mean(
        arr, 
        channels=[2, 1, 0],
        title="RGB with Mean¬±3œÉ Normalization (Channels 3-2-1)"
    )

# rgb_image = show_rgb_mean(arr)

In [None]:
def complete_dinov3_inference_pipeline_batch(tif_list, model, data_config, inference_config):
    """
    Run sliding-window DINOv3 inference over a list of GeoTIFFs, writing one CHM GeoTIFF per input.

    For each file the function reads bands 4/3/2 (treated as NIR/Red/Green),
    rescales them to 0-1 with the configured min/max, standardizes them with
    the training means/stds, predicts tile by tile, merges the tiles, restores
    NoData, converts to int16 decimeters and saves the result under
    ./predict/<data_name>/. A failure on one file is printed with its
    traceback and recorded as None; the remaining files are still processed.

    Parameters:
    -----------
    tif_list : list
        List of paths to TIF files to process
    model : torch.nn.Module
        Trained DINOv3 model. It is switched to eval mode and inputs are moved
        to the device of its first parameter.
    data_config : dict
        Must provide 'data_name', 'means' and 'stds'. means/stds are indexed
        0..2 and applied to NIR, Red, Green respectively
        (NOTE(review): assumes the stats were computed in that band order).
    inference_config : dict
        Must provide 'nir_min', 'nir_max', 'red_min', 'red_max', 'green_min',
        'green_max', 'img_size', 'overlap', 'batch_size', 'visualize_first'.

    Returns:
    --------
    list : one entry per input path, in order. Each entry is the merged 2-D
        prediction array (presumably meters, per the variable naming) with
        NoData cells set to -9999, or None if that file raised.
    """
    
    # Extract configuration values
    data_name = data_config['data_name']
    means = data_config['means']
    stds = data_config['stds']
    
    nir_min = inference_config['nir_min']
    nir_max = inference_config['nir_max']
    red_min = inference_config['red_min']
    red_max = inference_config['red_max']
    green_min = inference_config['green_min']
    green_max = inference_config['green_max']
    img_size = inference_config['img_size']
    overlap = inference_config['overlap']
    batch_size = inference_config['batch_size']
    visualize_first = inference_config['visualize_first']
    
    print("üöÄ Starting Batch DINOv3 Inference Pipeline")
    print("=" * 60)
    
    # Create output directory
    output_dir = f"./predict/{data_name}"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created folder: {output_dir}")
    else:
        print(f"Folder already exists: {output_dir}")
    
    print(f"Total TIF files to process: {len(tif_list)}")
    print("=" * 60)
    
    # Set up model
    model.eval()
    device = next(model.parameters()).device
    
    results = []
    # Model output tile size, probed lazily on the first file that reaches inference
    output_h, output_w = None, None
    
    for file_idx, tif_path in enumerate(tif_list, 1):
        print(f"\nüîÑ Processing TIF {file_idx}/{len(tif_list)}: {os.path.basename(tif_path)}")
        print("-" * 50)
        
        try:
            # Step 1: Load and preprocess imagery
            print("üìÇ Step 1: Loading and preprocessing imagery...")
            
            with rasterio.open(tif_path) as src:
                print(f"   Input shape: {src.shape}")
                print(f"   Bands: {src.count}")
                
                # Extract NIR, Red, Green channels (assuming bands 4, 3, 2)
                nir_band = src.read(4)   # NIR
                red_band = src.read(3)   # Red
                green_band = src.read(2) # Green
                
                # Stack into NRG array (channels first)
                nrg_array = np.stack([nir_band, red_band, green_band], axis=0)
            
            print(f"   Extracted NRG shape: {nrg_array.shape}")
            
            # Preprocess data: -9999 is the NoData sentinel, tracked as NaN from here on
            nrg_array = nrg_array.astype(np.float32)
            nrg_array[nrg_array == -9999] = np.nan
            
            # Scale to 0-1 range first (like in training)
            nrg_array[0] = (nrg_array[0] - nir_min) / (nir_max - nir_min)
            nrg_array[1] = (nrg_array[1] - red_min) / (red_max - red_min) 
            nrg_array[2] = (nrg_array[2] - green_min) / (green_max - green_min)
            
            print("nir_mean: ", np.nanmean(nrg_array[0]), " red_mean: ", np.nanmean(nrg_array[1]), " green_mean: ", np.nanmean(nrg_array[2]))            
            
            # Clip to ensure 0-1 range (NaN is preserved by np.clip)
            nrg_array = np.clip(nrg_array, 0, 1)
            
            # Normalize using training statistics
            nrg_array[0] = (nrg_array[0] - means[0]) / stds[0]
            nrg_array[1] = (nrg_array[1] - means[1]) / stds[1]
            nrg_array[2] = (nrg_array[2] - means[2]) / stds[2]
            
            print(f"   ‚úì Preprocessing complete")
            print(f"   NaN pixels: {np.isnan(nrg_array).sum():,}")
            
            # Step 2: Set up sliding window inference
            print(f"üß† Step 2: Running DINOv3 inference...")
            
            # Transpose to (H, W, C) for tiler
            xraster = nrg_array.transpose(1, 2, 0)
            
            # Set up image tiler
            tiler_image = Tiler(
                data_shape=xraster.shape,
                tile_shape=(img_size, img_size, 3),
                channel_dimension=-1,
                overlap=overlap,
                mode='reflect'
            )
            
            # Probe the model once for its output tile size. This is keyed on
            # the cached value rather than on the file index: if the first file
            # fails before reaching this point, the size must still be probed
            # for the next file instead of staying None for the whole batch.
            if output_h is None or output_w is None:
                test_input = torch.randn(1, 3, img_size, img_size).to(device)
                with torch.no_grad():
                    test_output = model(test_input)
                # NOTE(review): assumes the model returns (B, H, W) - confirm
                output_h, output_w = test_output.shape[1], test_output.shape[2]
                print(f"   Model output size: {output_h}x{output_w}")
            
            # Set up merger for combining tile predictions
            tiler_mask = Tiler(
                data_shape=(xraster.shape[0], xraster.shape[1], 1),
                tile_shape=(output_h, output_w, 1),
                channel_dimension=-1,
                overlap=overlap,
                mode='reflect'
            )
            merger = Merger(tiler=tiler_mask, window='triang')
            
            # Process tiles in batches
            total_tiles = len(tiler_image)
            print(f"   Processing {total_tiles} tiles...")
            
            for batch_id, batch_i in tiler_image(xraster, batch_size=batch_size):
                actual_batch_size = len(batch_i)
                
                # Prepare input batch: (B, H, W, C) -> (B, C, H, W); NoData is fed as 0
                input_batch = batch_i.transpose(0, 3, 1, 2).astype('float32')
                input_batch_tensor = torch.from_numpy(input_batch).to(device)
                input_batch_tensor = torch.nan_to_num(input_batch_tensor, nan=0.0)
                
                # Run inference
                with torch.no_grad():
                    y_batch = model(input_batch_tensor)
                    y_batch_numpy = y_batch.cpu().numpy()
                    
                    # Format for merger (B, H, W, C)
                    if len(y_batch_numpy.shape) == 3:
                        formatted_output = np.expand_dims(y_batch_numpy, axis=-1)
                    else:
                        formatted_output = y_batch_numpy
                
                # Add predictions to merger; tile ids are global across batches
                for j in range(actual_batch_size):
                    tile_id = batch_id * batch_size + j
                    merger.add(tile_id, formatted_output[j])
            
            # Merge tile predictions
            print("   üîÑ Merging tiles...")
            predictions_raw = merger.merge(unpad=True)
            predictions_meters = np.squeeze(predictions_raw)
            
            print(f"   ‚úì Inference complete")
            print(f"   Final prediction shape: {predictions_meters.shape}")
            
            # Step 3: Post-process predictions
            print(f"üîß Step 3: Post-processing predictions...")
            
            # Restore original NoData locations.
            # NOTE(review): only the NIR band's NaN mask is used here.
            original_nodata = np.isnan(nrg_array[0])
            predictions_meters[original_nodata] = -9999
            
            # Convert to decimeters (int16 for storage efficiency)
            predictions_decimeters = convert_predictions_to_decimeters_int16(predictions_meters)
            print(f"   ‚úì Converted to decimeters (int16)")
            
            # Step 4: Save as GeoTIFF
            print(f"üíæ Step 4: Saving GeoTIFF...")
            
            # Generate output filename: keep the 36 characters that precede the
            # trailing 'sr-02m.tif' of a 46+ character name, otherwise just
            # drop the 4-character extension.
            base_name = os.path.basename(tif_path)
            if len(base_name) >= 46:
                string = base_name[-46:-10]
            else:
                string = base_name[:-4]
            output_tif = os.path.join(output_dir, f'{string}sr-02m.chm.tif')
            
            save_predictions_as_geotiff(predictions_decimeters, tif_path, output_tif)
            
            # Step 5: Optional visualization (first three files only)
            if visualize_first and file_idx <= 3:
                print(f"üìä Step 5: Generating visualization for file {file_idx}...")
                
                fig = plt.figure(figsize=(20, 6))
                
                # Original NRG composite, subsampled to roughly 1000 px on the short side
                plt.subplot(1, 3, 1)
                nrg_display = nrg_array.transpose(1, 2, 0)
                display_step = max(1, min(nrg_display.shape[0], nrg_display.shape[1]) // 1000)
                nrg_sub = nrg_display[::display_step, ::display_step, :]
                
                # Normalize for display with a 2-98 percentile stretch per channel
                nrg_norm = np.zeros_like(nrg_sub)
                for ch in range(3):
                    channel = nrg_sub[:, :, ch]
                    valid = ~np.isnan(channel)
                    if valid.any():
                        vmin, vmax = np.nanpercentile(channel[valid], [2, 98])
                        if vmax > vmin:
                            nrg_norm[:, :, ch] = np.clip((channel - vmin)/(vmax - vmin), 0, 1)
                
                plt.imshow(nrg_norm)
                plt.title(f'NRG Composite\n{os.path.basename(tif_path)}')
                plt.axis('off')
                
                # Predicted CHM
                plt.subplot(1, 3, 2)
                pred_display = predictions_meters.copy()
                pred_display[pred_display == -9999] = np.nan
                pred_sub = pred_display[::display_step, ::display_step]
                
                valid_pred = pred_display[~np.isnan(pred_display)]
                im1 = plt.imshow(pred_sub, cmap=forest_ht_cmap, norm=forest_ht_norm)
                if len(valid_pred) > 0:
                    # A colorbar is only drawn when there is something to scale
                    plt.colorbar(im1, label='Height (m)')
                
                plt.title('Predicted CHM (meters)')
                plt.axis('off')
                
                # Statistics histogram
                plt.subplot(1, 3, 3)
                if len(valid_pred) > 0:
                    plt.hist(valid_pred, bins=50, alpha=0.7, edgecolor='black')
                    plt.xlabel('Predicted Height (m)')
                    plt.ylabel('Frequency')
                    plt.title(f'Height Distribution\nMean: {valid_pred.mean():.2f}m')
                    plt.grid(True, alpha=0.3)
                
                plt.tight_layout()
                plt.savefig(f'{output_dir}/visualization_file_{file_idx:03d}.png', 
                           dpi=150, bbox_inches='tight')
                plt.show()
                # Release the figure so repeated runs do not accumulate open figures
                plt.close(fig)
            
            # Add to results
            results.append(predictions_meters)
            
            # Print summary for this file
            valid_pixels = (predictions_meters != -9999).sum()
            nodata_pixels = (predictions_meters == -9999).sum()
            
            print(f"‚úÖ Completed TIF {file_idx}/{len(tif_list)}")
            print(f"   üìä Valid predictions: {valid_pixels:,} pixels")
            print(f"   üö´ NoData pixels: {nodata_pixels:,} pixels")
            
            if valid_pixels > 0:
                valid_heights = predictions_meters[predictions_meters != -9999]
                print(f"   üå≤ Height range: {valid_heights.min():.2f}m to {valid_heights.max():.2f}m")
                print(f"   üìà Mean height: {valid_heights.mean():.2f}m ¬± {valid_heights.std():.2f}m")
            
        except Exception as e:
            # Best-effort batch: report the failure, keep a placeholder so
            # results stays aligned with tif_list, and move on to the next file.
            print(f"‚ùå Error processing {tif_path}: {str(e)}")
            import traceback
            traceback.print_exc()
            results.append(None)
    
    # Print final summary
    successful_files = sum(1 for r in results if r is not None)
    failed_files = len(results) - successful_files
    
    print("\n" + "="*60)
    print("‚≠ê BATCH INFERENCE COMPLETE ‚≠ê")
    print("="*60)
    print(f"üìÅ Output directory: {output_dir}")
    print(f"‚úÖ Successfully processed: {successful_files}/{len(tif_list)} files")
    if failed_files > 0:
        print(f"‚ùå Failed: {failed_files} files")
    print("="*60)
    
    return results

In [None]:
# Convert tif to decimeters and make 0 min
def convert_predictions_to_decimeters_int16(predictions):
    """
    Convert height predictions from meters to int16 decimeters.

    Valid values are multiplied by 10, rounded to the nearest integer and
    clamped to [0, 32767] before the cast, so negative heights become 0 and
    values too large for int16 saturate instead of wrapping around. Cells equal
    to -9999, and NaN cells in floating-point input, are written as -9999.

    Parameters:
    predictions: numpy array with height predictions in meters (may contain
        -9999 and/or NaN as NoData). The input is not modified.

    Returns:
    converted_predictions: numpy array of int16 type with values in decimeters
        and -9999 at NoData cells.
    """
    predictions = np.asarray(predictions)

    # NoData is the -9999 sentinel; NaN is treated the same way because casting
    # NaN to an integer type is undefined and would produce an arbitrary height.
    nodata_mask = predictions == -9999
    if np.issubdtype(predictions.dtype, np.floating):
        nodata_mask = nodata_mask | np.isnan(predictions)

    # Neutralize NoData cells (their value is overwritten below), then convert
    # meters -> decimeters and round to the nearest integer.
    decimeters = np.round(np.where(nodata_mask, 0, predictions) * 10)

    # Clamp while still in the source dtype: casting first would let large
    # values wrap to negative int16 numbers and then be zeroed by the floor.
    int16_max = np.iinfo(np.int16).max
    converted_predictions = np.clip(decimeters, 0, int16_max).astype(np.int16)

    # Restore NoData values
    converted_predictions[nodata_mask] = -9999

    print(f"Converted to decimeters. NoData pixels preserved: {np.sum(nodata_mask):,}")

    return converted_predictions

In [None]:
#Compress and Save Tif
def save_predictions_as_geotiff(predictions_decimeters, reference_tif_path, output_path):
    """
    Write a 2-D prediction array to a single-band Int16 GeoTIFF.

    The geotransform and projection are copied from the reference raster. The
    output is created with LZW compression and internal tiling, -9999 is
    registered as the NoData value, and band/dataset metadata record that the
    values are decimeters.

    Parameters:
    predictions_decimeters: 2-D numpy array of predictions in decimeters;
        NaN cells in floating-point input are written as -9999
    reference_tif_path: path to the original .tif file for geospatial reference
    output_path: path where to save the output GeoTIFF

    Raises:
    ValueError: if the reference cannot be opened or the output cannot be created
    """
    # Georeferencing comes from the source scene
    reference = gdal.Open(reference_tif_path)
    if reference is None:
        raise ValueError(f"Could not open reference file: {reference_tif_path}")

    geotransform = reference.GetGeoTransform()
    projection = reference.GetProjection()

    n_rows, n_cols = predictions_decimeters.shape

    # Floating-point input may still carry NaN; map it to the NoData sentinel
    if np.issubdtype(predictions_decimeters.dtype, np.floating):
        raster = np.where(np.isnan(predictions_decimeters), -9999, predictions_decimeters)
    else:
        raster = predictions_decimeters.copy()
    raster = raster.astype(np.int16)

    # One Int16 band, LZW-compressed and tiled
    gtiff_driver = gdal.GetDriverByName('GTiff')
    target = gtiff_driver.Create(
        output_path,
        n_cols,
        n_rows,
        1,
        gdal.GDT_Int16,
        options=['COMPRESS=LZW', 'TILED=YES'],
    )
    if target is None:
        raise ValueError(f"Could not create output file: {output_path}")

    target.SetGeoTransform(geotransform)
    target.SetProjection(projection)

    # Pixel data first, then NoData and the unit metadata on the band
    chm_band = target.GetRasterBand(1)
    chm_band.WriteArray(raster)
    chm_band.SetNoDataValue(-9999)
    chm_band.SetDescription("Canopy Height Model - Values in decimeters")

    band_metadata = (
        ("UNITS", "decimeters"),
        ("DESCRIPTION", "Predicted canopy heights in decimeters (1 meter = 10 decimeters)"),
    )
    for key, value in band_metadata:
        chm_band.SetMetadataItem(key, value)

    dataset_metadata = (
        ("PROCESSING", "DINOv3 model prediction"),
        ("NODATA_VALUE", "-9999"),
        ("UNITS", "decimeters"),
    )
    for key, value in dataset_metadata:
        target.SetMetadataItem(key, value)

    # Flush, then drop the references (band before dataset) so GDAL closes the files
    chm_band.FlushCache()
    target.FlushCache()
    chm_band = None
    target = None
    reference = None

    report_lines = (
        f"Successfully saved predictions to: {output_path}",
        f"- Compression: LZW",
        f"- NoData value: -9999",
        f"- Units: decimeters",
        f"- Data type: int16",
        f"- Dimensions: {n_rows} x {n_cols}",
    )
    for line in report_lines:
        print(line)

In [None]:
# Usage: run batch inference over tif_list with the evaluated model, if any.
# The existence check mirrors the guard used by the visualization cell: on a
# kernel where evaluation never ran, `model_for_eval` is undefined, and an
# unguarded read would raise NameError instead of reaching the message below.
if 'model_for_eval' in globals() and model_for_eval is not None:
    # Show which checkpoint is being used, when the evaluation cell recorded it
    if 'model_filename' in globals():
        print(model_filename)
    batch_results = complete_dinov3_inference_pipeline_batch(
        tif_list=tif_list,
        model=model_for_eval,
        data_config=DATA_CONFIG,
        inference_config=INFERENCE_CONFIG
    )
else:
    print("‚ùå No model available for inference")