# CSIRO Image2Biomass - V5: RGB+Depth Fusion (Optimized)

**Optimized for speed:**
- Depth model loaded ONCE and shared across all folds
- Depth maps pre-computed for each TTA transform
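- Only the lightweight fusion network (RGB encoder, small depth-feature CNN and regression heads) runs per fold; the large depth estimator does not (much faster)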

## Setup
1. Attach the trained-model dataset and the competition data as notebook inputs
2. **Set Internet to OFF**
3. Run all cells

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
import gc

# Report the runtime environment: PyTorch version and whether a GPU is visible.
print(f"PyTorch: {torch.__version__}", f"CUDA: {torch.cuda.is_available()}", sep="\n")

In [None]:
# ---------------------------------------------------------------------------
# Competition data (read-only Kaggle inputs)
# ---------------------------------------------------------------------------
TEST_CSV = '/kaggle/input/csiro-biomass/test.csv'
TEST_IMG_DIR = '/kaggle/input/csiro-biomass/test'
TRAIN_CSV = '/kaggle/input/csiro-biomass/train.csv'

# ---------------------------------------------------------------------------
# Trained fold checkpoints and the bundled depth model
# ---------------------------------------------------------------------------
MODEL_BASE = '/kaggle/input/image2biomass-depth-fusion-model/pytorch/default/1'
DEPTH_MODEL_PATH = MODEL_BASE + '/depth_anything_v2_small'
N_FOLDS = 5

# Regression targets, one head per name (in this order).
TARGET_NAMES = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Dry_Total_g', 'GDM_g']

# Model / inference hyper-parameters (must match the training run).
CONFIG = dict(
    backbone='efficientnetv2_rw_m',
    image_size=384,
    batch_size=4,
    dropout=0.3,
)

# Test-time augmentation: 4 flip views; turning rotations on doubles this to 8
# views (and roughly doubles inference time).
USE_TTA = True
TTA_ROTATIONS = False

# ImageNet per-channel normalisation statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Device: {DEVICE}")

## Shared Depth Estimator (loaded once)

In [None]:
class SharedDepthEstimator(nn.Module):
    """Frozen monocular depth model, built once and reused by every fold.

    Wraps a Hugging Face ``AutoModelForDepthEstimation`` loaded from
    ``model_path``, switches it to eval mode and disables gradients on all of
    its parameters.
    """

    def __init__(self, model_path):
        super().__init__()
        self.model = AutoModelForDepthEstimation.from_pretrained(model_path)
        self.model.eval()
        # Freeze every parameter of the wrapped model.
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, images):
        """Predict depth for a batch and rescale each map to [0, 1].

        Args:
            images: 4-D batch tensor, unpacked as (B, C, H, W).

        Returns:
            Tensor of shape (B, 1, H, W). Each sample is min-max normalised
            independently; the 1e-8 term avoids division by zero for a
            constant map, which therefore becomes all zeros.
        """
        batch, _, height, width = images.shape
        raw = self.model(images).predicted_depth
        # Assumes predicted_depth is (B, h, w): add a channel axis, then resize
        # back to the input resolution.
        depth = F.interpolate(
            raw.unsqueeze(1), size=(height, width), mode='bilinear', align_corners=False
        )

        flat = depth.view(batch, -1)
        lo = flat.min(dim=1, keepdim=True).values.view(batch, 1, 1, 1)
        hi = flat.max(dim=1, keepdim=True).values.view(batch, 1, 1, 1)
        return (depth - lo) / (hi - lo + 1e-8)

# Build the depth estimator a single time; its outputs are cached below and
# reused by every fold, so it never has to be reloaded.
print("Loading shared depth model...")
shared_depth_model = SharedDepthEstimator(DEPTH_MODEL_PATH)
shared_depth_model = shared_depth_model.to(DEVICE)
print("Depth model loaded!")

## Lightweight Fusion Model (uses pre-computed depth)

In [None]:
class LightweightFusionModel(nn.Module):
    """RGB + depth fusion regressor that consumes pre-computed depth maps.

    Unlike the training model, it holds no depth estimator: the caller
    supplies depth maps directly. Sub-module names and ``nn.Sequential`` layer
    indices are kept identical to the training architecture so that
    checkpoint keys (``rgb_encoder.*``, ``depth_encoder.<i>.*``,
    ``heads.<target>.<i>.*``) line up.
    """

    def __init__(self, rgb_backbone, dropout=0.3, pretrained=False):
        super().__init__()
        self.target_names = TARGET_NAMES

        # RGB branch: timm backbone with the classifier removed
        # (num_classes=0) and global average pooling, so it returns a flat
        # feature vector of length `num_features`.
        self.rgb_encoder = timm.create_model(
            rgb_backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg',
        )

        def conv_bn_relu(in_ch, out_ch, kernel_size, stride, padding):
            # Conv -> BatchNorm -> ReLU triple (occupies three Sequential slots).
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Depth branch: small CNN mapping a 1-channel depth map to a
        # 256-dim vector. Layer order fixes the state_dict indices.
        depth_dim = 256
        self.depth_encoder = nn.Sequential(
            *conv_bn_relu(1, 32, 7, 2, 3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *conv_bn_relu(32, 64, 3, 2, 1),
            *conv_bn_relu(64, 128, 3, 2, 1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, depth_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )

        # One independent MLP head per target on the concatenated features.
        in_dim = self.rgb_encoder.num_features + depth_dim
        self.heads = nn.ModuleDict()
        for target in self.target_names:
            self.heads[target] = nn.Sequential(
                nn.Linear(in_dim, 256),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(256, 64),
                nn.ReLU(inplace=True),
                nn.Linear(64, 1),
            )

    def forward(self, images, depth_maps):
        """Return ``{target_name: prediction}`` for a batch.

        Each prediction has its trailing singleton dimension squeezed away.
        """
        fused = torch.cat((self.rgb_encoder(images), self.depth_encoder(depth_maps)), dim=1)
        outputs = {}
        for target in self.target_names:
            outputs[target] = self.heads[target](fused).squeeze(-1)
        return outputs

print("LightweightFusionModel defined")

## Dataset and Transforms

In [None]:
class BiomassTestDataset(Dataset):
    """Test-set dataset yielding one item per unique image.

    test.csv can list the same image several times (one row per
    ``sample_id``); the image id is the part of ``sample_id`` before the
    first ``'__'``, and each id is returned once.

    Args:
        csv_path: path to the test CSV (needs ``sample_id`` and ``image_path``).
        img_dir: directory holding the image files; only the file name of
            ``image_path`` is used, joined onto this directory.
        transform: optional albumentations-style callable invoked as
            ``transform(image=array)['image']``.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.df = pd.read_csv(csv_path)
        self.df['image_id'] = self.df['sample_id'].str.split('__').str[0]
        # Unique ids in order of first appearance in the CSV.
        self.image_ids = self.df['image_id'].unique()
        self.image_paths = self.df.groupby('image_id')['image_path'].first().to_dict()

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``{'image': transformed image, 'image_id': str}``."""
        image_id = self.image_ids[idx]
        image_path = self.img_dir / Path(self.image_paths[image_id]).name
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image, so the array stays valid after the block.
        with Image.open(image_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            image = self.transform(image=image)['image']
        return {'image': image, 'image_id': image_id}

def get_tta_transforms(image_size=384, include_rotations=False):
    """Build the deterministic test-time augmentation pipelines.

    Every pipeline resizes to ``image_size`` x ``image_size``, applies one
    horizontal/vertical flip combination (and, when ``include_rotations`` is
    true, optionally a fixed 90-degree rotation), then normalises with the
    ImageNet statistics and converts to a tensor.

    Returns:
        A list of ``A.Compose`` pipelines: 4 without rotations, 8 with. The
        first entry is always the un-flipped, un-rotated view.
    """
    angles = [0, 90] if include_rotations else [0]
    flips = [(False, False), (True, False), (False, True), (True, True)]

    def _pipeline(hflip, vflip, angle):
        # Assemble one fixed augmentation chain.
        steps = [A.Resize(image_size, image_size)]
        if hflip:
            steps.append(A.HorizontalFlip(p=1.0))
        if vflip:
            steps.append(A.VerticalFlip(p=1.0))
        if angle:
            # limit=(angle, angle) pins the rotation to exactly `angle` degrees.
            steps.append(A.Rotate(limit=(angle, angle), p=1.0, border_mode=0))
        steps += [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2()]
        return A.Compose(steps)

    # Flip is the outer loop and angle the inner one.
    return [_pipeline(h, v, ang) for (h, v) in flips for ang in angles]

tta_transforms = get_tta_transforms(CONFIG['image_size'], include_rotations=TTA_ROTATIONS)
if not USE_TTA:
    # Honour the USE_TTA switch: keep only the identity view (no flip, no
    # rotation), which get_tta_transforms always emits first.
    tta_transforms = tta_transforms[:1]
print(f"Using {len(tta_transforms)} TTA transforms")

## Get Target Statistics

In [None]:
train_df = pd.read_csv(TRAIN_CSV)
# sample_id has the form "<image_id>__<target_name>"; keep the part before "__".
train_df['image_id'] = train_df['sample_id'].str.split('__').str[0]
# Reshape to one row per image and one column per target.
train_wide = train_df.pivot_table(
    index='image_id', columns='target_name', values='target', aggfunc='first'
)

# Per-target mean/std used to map standardised model outputs back to grams.
# np.std is the population std (ddof=0); the 1e-8 keeps it strictly positive.
target_stats = {
    name: {
        'mean': float(np.mean(train_wide[name].values)),
        'std': float(np.std(train_wide[name].values)) + 1e-8,
    }
    for name in TARGET_NAMES
}

print("Target stats loaded")

## Pre-compute Depth Maps for All TTA Transforms

In [None]:
print("Pre-computing depth maps for all TTA transforms...")
# tta_idx -> {image_id: depth map slice of shape (1, 1, H, W), on CPU}
all_depth_maps = {}
# tta_idx -> {image_id: transformed image slice with a leading batch dim of 1, on CPU}
all_images = {}

for tta_idx, transform in enumerate(tta_transforms):
    print(f"  TTA {tta_idx + 1}/{len(tta_transforms)}...")

    dataset = BiomassTestDataset(TEST_CSV, TEST_IMG_DIR, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=0,
    )

    tta_depths, tta_images = {}, {}
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(DEVICE)
            image_ids = batch['image_id']
            depth_maps = shared_depth_model(images)

            # Move per-image results to the CPU so the GPU only ever holds
            # one batch at a time.
            for img_id, depth, image in zip(image_ids, depth_maps, images):
                tta_depths[img_id] = depth.unsqueeze(0).cpu()
                tta_images[img_id] = image.unsqueeze(0).cpu()

    all_depth_maps[tta_idx] = tta_depths
    all_images[tta_idx] = tta_images

# The depth estimator is no longer needed: release it and its GPU memory.
del shared_depth_model
torch.cuda.empty_cache()
gc.collect()

print(f"Pre-computed depth maps for {len(all_depth_maps[0])} images x {len(tta_transforms)} TTA transforms")

## Load Models and Generate Predictions

In [None]:
def denormalize(pred_dict, stats):
    """Undo per-target standardisation.

    Args:
        pred_dict: ``{target_name: standardised value}``.
        stats: ``{target_name: {'mean': float, 'std': float}}``.

    Returns:
        New dict with ``value * std + mean`` per target, in ``pred_dict`` order.
    """
    restored = {}
    for target, z in pred_dict.items():
        target_stats_entry = stats[target]
        restored[target] = z * target_stats_entry['std'] + target_stats_entry['mean']
    return restored

def _load_fold_weights(model, checkpoint):
    """Copy fusion-model weights from a training checkpoint into ``model``.

    The checkpoint may also contain weights ``model`` does not own (e.g. the
    training-time depth estimator); those are ignored. A ``'module.'`` prefix
    (as written by ``nn.DataParallel``) is stripped. Any parameter or buffer
    of ``model`` that the checkpoint lacks raises ``RuntimeError`` instead of
    silently staying randomly initialised. The only exception is the BatchNorm
    ``num_batches_tracked`` counters, which are unused in eval mode.
    """
    prefix = 'module.'
    ckpt_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['model_state_dict'].items()
    }
    model_state = model.state_dict()

    missing = [
        key for key in model_state
        if key not in ckpt_state and not key.endswith('num_batches_tracked')
    ]
    if missing:
        raise RuntimeError(
            f"Checkpoint is missing {len(missing)} LightweightFusionModel weights, "
            f"e.g. {missing[:5]}"
        )

    for key in model_state:
        if key in ckpt_state:
            model_state[key] = ckpt_state[key]
    model.load_state_dict(model_state)


def _predict_fold(model):
    """Return ``{target: {image_id: TTA-averaged, denormalised prediction}}``.

    Images are processed in batches of ``CONFIG['batch_size']`` using the
    cached per-TTA images and depth maps.
    """
    n_tta = len(tta_transforms)
    batch_size = CONFIG['batch_size']
    sums = {name: {img_id: 0.0 for img_id in image_ids_order} for name in TARGET_NAMES}

    with torch.no_grad():
        for tta_idx in range(n_tta):
            for start in range(0, len(image_ids_order), batch_size):
                batch_ids = image_ids_order[start:start + batch_size]
                images = torch.cat([all_images[tta_idx][i] for i in batch_ids]).to(DEVICE)
                depth_maps = torch.cat([all_depth_maps[tta_idx][i] for i in batch_ids]).to(DEVICE)

                pred = model(images, depth_maps)
                # One device->host copy per target instead of one .item() per image.
                raw = {name: pred[name].cpu().tolist() for name in TARGET_NAMES}

                for j, img_id in enumerate(batch_ids):
                    pred_denorm = denormalize({n: raw[n][j] for n in TARGET_NAMES}, target_stats)
                    for name in TARGET_NAMES:
                        sums[name][img_id] += pred_denorm[name]

    # Average over TTA views.
    return {
        name: {img_id: total / n_tta for img_id, total in per_image.items()}
        for name, per_image in sums.items()
    }


all_fold_predictions = []
image_ids_order = list(all_images[0].keys())

for fold_idx in range(N_FOLDS):
    print(f"\nFold {fold_idx + 1}/{N_FOLDS}...")

    # Create lightweight model (no depth estimator)
    model = LightweightFusionModel(
        rgb_backbone=CONFIG['backbone'],
        dropout=CONFIG['dropout'],
        pretrained=False
    ).to(DEVICE)

    # Load the checkpoint on the CPU: it also holds depth-estimator weights we
    # never use, and load_state_dict copies the needed tensors onto DEVICE.
    # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
    # because these checkpoints come from our own trusted model dataset.
    checkpoint_path = Path(MODEL_BASE) / f'fold_{fold_idx}' / 'best_model.pth'
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    _load_fold_weights(model, checkpoint)
    model.eval()
    val_loss = checkpoint.get('best_val_loss')
    if val_loss is not None:
        print(f"  Loaded (Val Loss: {val_loss:.4f})")
    else:
        print("  Loaded (Val Loss: n/a)")

    all_fold_predictions.append(_predict_fold(model))

    # Free memory before the next fold's model and checkpoint are created.
    del model, checkpoint
    torch.cuda.empty_cache()

print(f"\nGenerated predictions from {N_FOLDS} folds")

## Ensemble and Apply Constraints

In [None]:
def _apply_biological_constraints(clover, dead, green, gdm, total):
    """Reconcile one image's predictions into a consistent, non-negative set.

    Enforces GDM = Green + Clover and Total = GDM + Dead:
      * every value is clipped at 0;
      * GDM becomes the mean of its own prediction and Green + Clover, and
        Green/Clover are rescaled proportionally to sum to it;
      * Total becomes the mean of its own prediction and GDM + Dead, and Dead
        absorbs the remainder (never negative; Total is at least GDM).

    Returns:
        ``{target_name: adjusted value}`` for all five targets.
    """
    clover, dead, green, gdm, total = (
        max(0.0, float(v)) for v in (clover, dead, green, gdm, total)
    )

    # GDM = Green + Clover
    gdm_calc = green + clover
    if gdm_calc > 0:
        adj_gdm = (gdm + gdm_calc) / 2
        scale = adj_gdm / gdm_calc
        green *= scale
        clover *= scale
    else:
        # Both components are zero, so there is no green/clover split to
        # scale into. Averaging would leave GDM > 0 with zero components and
        # break GDM = Green + Clover; collapse GDM to the component sum.
        adj_gdm = 0.0

    # Total = GDM + Dead
    total_calc = adj_gdm + dead
    adj_total = (total + total_calc) / 2
    if adj_total > adj_gdm:
        dead = adj_total - adj_gdm
    else:
        dead = 0.0
        adj_total = adj_gdm

    return {
        'Dry_Clover_g': clover,
        'Dry_Dead_g': dead,
        'Dry_Green_g': green,
        'Dry_Total_g': adj_total,
        'GDM_g': adj_gdm,
    }


# Average across folds
ensemble_predictions = {name: {} for name in TARGET_NAMES}
for img_id in image_ids_order:
    for name in TARGET_NAMES:
        fold_vals = [fp[name][img_id] for fp in all_fold_predictions]
        ensemble_predictions[name][img_id] = float(np.mean(fold_vals))

# Apply biological constraints
print("Applying biological constraints...")
for img_id in image_ids_order:
    adjusted = _apply_biological_constraints(
        clover=ensemble_predictions['Dry_Clover_g'][img_id],
        dead=ensemble_predictions['Dry_Dead_g'][img_id],
        green=ensemble_predictions['Dry_Green_g'][img_id],
        gdm=ensemble_predictions['GDM_g'][img_id],
        total=ensemble_predictions['Dry_Total_g'][img_id],
    )
    for name, value in adjusted.items():
        ensemble_predictions[name][img_id] = value

print("\nPredictions summary:")
for name in TARGET_NAMES:
    vals = list(ensemble_predictions[name].values())
    print(f"  {name:<15} mean: {np.mean(vals):>8.2f}  min: {np.min(vals):>8.2f}  max: {np.max(vals):>8.2f}")

## Create Submission

In [None]:
test_df = pd.read_csv(TEST_CSV)

# Drive the submission from test.csv so it contains exactly the requested
# sample_ids, in test.csv order, and nothing else. sample_id has the form
# "<image_id>__<target_name>".
submission_rows = []
for sample_id in test_df['sample_id']:
    img_id, _, name = sample_id.partition('__')
    submission_rows.append({
        'sample_id': sample_id,
        'target': ensemble_predictions[name][img_id]
    })

submission_df = pd.DataFrame(submission_rows, columns=['sample_id', 'target'])
submission_df.to_csv('submission.csv', index=False)

print("Submission created!")
print(submission_df)