# CSIRO Image2Biomass - V5: RGB+Depth Fusion Ensemble

This notebook generates predictions using a **5-fold ensemble** of RGB+Depth Fusion models:
- **Architecture**: Dual encoder (EfficientNetV2-M for RGB + a small CNN over depth maps predicted by a frozen Depth Anything v2 Small model)
- **Fusion**: Concatenation of RGB and depth features
- Mean validation loss: 3.06 +/- 1.55 (MSE)
- Image size: 384x384
- TTA: 8 transforms (4 flip combinations × 2 rotations: 0° and 90°)
- Biological constraints enforced in post-processing

## Setup Instructions
1. Add the model dataset (image2biomass-depth-fusion-model)
2. Add the competition data
3. **Set Internet to OFF** (required for submission)
4. Run all cells to generate submission

## 1. Imports

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

# Record the library versions and whether CUDA is visible at the top of the
# run log, so a failed or slow run can be traced back to its environment.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"timm version: {timm.__version__}")

## 2. Configuration

In [None]:
# --- Competition data (Kaggle input mounts) ---
TRAIN_CSV = '/kaggle/input/csiro-biomass/train.csv'
TRAIN_IMG_DIR = '/kaggle/input/csiro-biomass/train'
TEST_CSV = '/kaggle/input/csiro-biomass/test.csv'
TEST_IMG_DIR = '/kaggle/input/csiro-biomass/test'

# --- Model dataset: per-fold checkpoints plus the local depth model ---
N_FOLDS = 5
MODEL_BASE = '/kaggle/input/image2biomass-depth-fusion-model/pytorch/default/1'
DEPTH_MODEL_PATH = MODEL_BASE + '/depth_anything_v2_small'

# --- Regression targets (the model builds one head per name, in this order) ---
TARGET_NAMES = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'Dry_Total_g',
    'GDM_g',
]

# --- Model / inference hyper-parameters ---
CONFIG = dict(
    backbone='efficientnetv2_rw_m',
    depth_model='depth_anything_v2_small',
    fusion_type='concat',
    image_size=384,
    batch_size=8,
    num_workers=0,
    dropout=0.3,
)

# --- Test-time augmentation switch ---
USE_TTA = True

# --- ImageNet channel statistics used by the Normalize transform ---
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# --- Run on the GPU when one is visible, otherwise fall back to the CPU ---
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")

## 3. Depth Estimator

In [None]:
class DepthEstimator(nn.Module):
    """Depth Anything v2 wrapped as a module that emits per-image [0, 1] depth maps.

    Args:
        model_path: Local directory passed to the Hugging Face
            ``from_pretrained`` loaders for both the image processor and
            the depth-estimation model.
        freeze: When true, the depth model's parameters are marked as not
            requiring gradients and inference runs under ``torch.no_grad()``.
    """

    def __init__(self, model_path, freeze=True):
        super().__init__()
        self.freeze = freeze

        # Processor and weights both come from the same local directory.
        self.processor = AutoImageProcessor.from_pretrained(model_path)
        self.model = AutoModelForDepthEstimation.from_pretrained(model_path)

        if self.freeze:
            self.model.requires_grad_(False)

    def forward(self, images):
        """Predict a depth map for each image and rescale it to [0, 1].

        Args:
            images: 4-D batch ``(B, C, H, W)``; it is passed to the depth
                model as-is.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` where each image's depth has been
            min-max normalised independently.
        """
        batch, _, height, width = images.shape

        grad_mode = torch.no_grad() if self.freeze else torch.enable_grad()
        with grad_mode:
            raw_depth = self.model(images).predicted_depth
            # Bring the prediction back to the input resolution.
            depth = F.interpolate(
                raw_depth.unsqueeze(1),
                size=(height, width),
                mode='bilinear',
                align_corners=False,
            )

        # Per-image extrema, reshaped so they broadcast over (1, H, W).
        per_image = depth.view(batch, -1)
        lowest = per_image.min(dim=1, keepdim=True).values.view(batch, 1, 1, 1)
        highest = per_image.max(dim=1, keepdim=True).values.view(batch, 1, 1, 1)

        # The epsilon keeps a perfectly flat depth map from dividing by zero.
        return (depth - lowest) / (highest - lowest + 1e-8)

print("DepthEstimator defined")

## 4. RGB+Depth Fusion Model

In [None]:
class RGBDepthFusionEncoder(nn.Module):
    """Dual-encoder regressor that fuses RGB features with depth features.

    The RGB branch is a timm backbone with global average pooling. The depth
    branch runs a frozen ``DepthEstimator`` on the same input batch and encodes
    the resulting single-channel map with a small CNN into a 256-d vector. The
    two vectors are fused according to ``fusion_type`` and fed to one MLP
    regression head per entry of ``TARGET_NAMES``.

    Args:
        rgb_backbone: timm model name for the RGB encoder.
        depth_model_path: Local directory of the depth-estimation model.
        fusion_type: One of ``SUPPORTED_FUSION_TYPES``:
            ``'concat'`` concatenates both vectors, ``'add'`` sums 512-d
            projections, ``'attention'`` self-attends over the two projected
            tokens and averages them.
        dropout: Dropout probability used in the depth encoder and heads.
        pretrained: Forwarded to ``timm.create_model`` for the RGB encoder.

    Raises:
        ValueError: If ``fusion_type`` is not supported.
    """

    SUPPORTED_FUSION_TYPES = ('concat', 'add', 'attention')

    def __init__(self, rgb_backbone, depth_model_path, fusion_type='concat',
                 dropout=0.3, pretrained=True):
        super().__init__()

        # Fail fast with a clear message; otherwise an unknown value only
        # surfaces after both encoders are built, as an unbound-local error.
        if fusion_type not in self.SUPPORTED_FUSION_TYPES:
            raise ValueError(
                f"Unsupported fusion_type {fusion_type!r}; "
                f"expected one of {self.SUPPORTED_FUSION_TYPES}"
            )

        self.fusion_type = fusion_type
        self.target_names = TARGET_NAMES

        # RGB Encoder (num_classes=0 drops the classifier, leaving pooled features)
        self.rgb_encoder = timm.create_model(
            rgb_backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg'
        )
        rgb_features = self.rgb_encoder.num_features

        # Depth Estimator (always frozen here)
        self.depth_estimator = DepthEstimator(depth_model_path, freeze=True)

        # Depth Feature Encoder: 1-channel depth map -> 256-d vector.
        # Layer order must not change: checkpoint keys are positional.
        self.depth_encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )
        depth_features = 256

        # Fusion
        if fusion_type == 'concat':
            fused_features = rgb_features + depth_features
            self.fusion = nn.Identity()
        elif fusion_type == 'add':
            self.rgb_proj = nn.Linear(rgb_features, 512)
            self.depth_proj = nn.Linear(depth_features, 512)
            fused_features = 512
        else:  # 'attention' (validated above)
            self.rgb_proj = nn.Linear(rgb_features, 512)
            self.depth_proj = nn.Linear(depth_features, 512)
            self.attention = nn.MultiheadAttention(512, num_heads=8, batch_first=True)
            fused_features = 512

        # Regression heads: one independent MLP per target
        self.heads = nn.ModuleDict()
        for target_name in self.target_names:
            self.heads[target_name] = nn.Sequential(
                nn.Linear(fused_features, 256),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(256, 64),
                nn.ReLU(inplace=True),
                nn.Linear(64, 1)
            )

    def forward(self, images):
        """Return a dict mapping each target name to a ``(B,)`` prediction tensor.

        Args:
            images: Image batch fed to both the RGB encoder and the depth
                estimator.
        """
        rgb_features = self.rgb_encoder(images)

        # Depth maps are treated as fixed inputs: no gradient flows into the
        # depth estimator, only into the depth feature encoder that follows.
        with torch.no_grad():
            depth_maps = self.depth_estimator(images)
        depth_features = self.depth_encoder(depth_maps)

        if self.fusion_type == 'concat':
            fused = torch.cat([rgb_features, depth_features], dim=1)
        elif self.fusion_type == 'add':
            fused = self.rgb_proj(rgb_features) + self.depth_proj(depth_features)
        else:  # 'attention'
            # Treat the two modalities as a 2-token sequence.
            rgb_proj = self.rgb_proj(rgb_features).unsqueeze(1)
            depth_proj = self.depth_proj(depth_features).unsqueeze(1)
            combined = torch.cat([rgb_proj, depth_proj], dim=1)
            attended, _ = self.attention(combined, combined, combined)
            fused = attended.mean(dim=1)

        # squeeze(-1) drops only the trailing size-1 output dim, never the batch dim.
        return {
            target_name: self.heads[target_name](fused).squeeze(-1)
            for target_name in self.target_names
        }

print("RGBDepthFusionEncoder defined")

## 5. Dataset Class

In [None]:
class BiomassTestDataset(Dataset):
    """One item per unique test image, built from the long-format test CSV.

    The CSV is expected to provide ``sample_id`` (``<image_id>__<target>``) and
    ``image_path`` columns; several rows may refer to the same image.

    Args:
        csv_path: Path to the test CSV.
        img_dir: Directory holding the image files; only the file name of each
            ``image_path`` entry is used, resolved against this directory.
        transform: Optional callable invoked as ``transform(image=array)`` that
            returns a mapping with an ``'image'`` entry.
        target_stats: Stored on the instance; not used by this class.
    """

    def __init__(self, csv_path, img_dir, transform=None, target_stats=None):
        self.csv_path = csv_path
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.target_stats = target_stats

        frame = pd.read_csv(csv_path)
        # The image id is everything before the first "__" in sample_id.
        frame['image_id'] = frame['sample_id'].str.split('__').str[0]
        self.df = frame

        # Unique ids in order of first appearance; this order defines item indices.
        self.image_ids = frame['image_id'].unique()
        self.image_paths = frame.groupby('image_id')['image_path'].first().to_dict()

    def __len__(self):
        """Number of distinct images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``{'image', 'image_id'}``."""
        image_id = self.image_ids[idx]
        image_path = self.img_dir / Path(self.image_paths[image_id]).name

        image = np.array(Image.open(image_path).convert('RGB'))

        if self.transform is None:
            return {'image': image, 'image_id': image_id}

        return {'image': self.transform(image=image)['image'], 'image_id': image_id}

print("BiomassTestDataset defined")

## 6. TTA Transforms

In [None]:
def get_tta_transforms(image_size=384):
    """Build the 8 test-time-augmentation pipelines (4 flip combos x 2 rotations).

    Every pipeline resizes to ``image_size`` x ``image_size``, applies its
    fixed flip/rotation, then normalises with the ImageNet statistics and
    converts to a tensor. The list is ordered by flip combination (none,
    horizontal, vertical, both), each at 0 degrees and then 90 degrees.
    """

    def build_pipeline(hflip, vflip, angle):
        # Deterministic view: every geometric op fires with probability 1.
        steps = [A.Resize(image_size, image_size)]
        if hflip:
            steps.append(A.HorizontalFlip(p=1.0))
        if vflip:
            steps.append(A.VerticalFlip(p=1.0))
        if angle != 0:
            # limit=(angle, angle) pins the rotation to exactly `angle` degrees.
            steps.append(A.Rotate(limit=(angle, angle), p=1.0, border_mode=0))
        steps += [
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ]
        return A.Compose(steps)

    flip_combos = (
        (False, False),  # No flip
        (True, False),   # Horizontal flip
        (False, True),   # Vertical flip
        (True, True),    # Both flips
    )
    angles = (0, 90)

    return [
        build_pipeline(hflip, vflip, angle)
        for hflip, vflip in flip_combos
        for angle in angles
    ]

def get_val_transform(image_size=384):
    """Return the plain inference pipeline: resize, ImageNet-normalise, to tensor."""
    steps = [
        A.Resize(image_size, image_size),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# With TTA off, fall back to a single un-augmented view so the prediction
# loop can treat both cases as "a list of transforms".
tta_transforms = (
    get_tta_transforms(CONFIG['image_size'])
    if USE_TTA
    else [get_val_transform(CONFIG['image_size'])]
)
print(f"Using {len(tta_transforms)} TTA transforms" if USE_TTA else "TTA disabled")

## 7. Get Target Statistics

In [None]:
# The models predict standardised targets, so the per-target mean/std of the
# training labels are needed to map predictions back to grams.
train_df = pd.read_csv(TRAIN_CSV)
train_df['image_id'] = train_df['sample_id'].str.split('__').str[0]

# Long -> wide: one row per image, one column per target.
train_wide = train_df.pivot_table(
    index='image_id',
    columns='target_name',
    values='target',
    aggfunc='first',
)

# The epsilon on std guards against division by zero for a constant target.
target_stats = {
    target_name: {
        'mean': float(np.mean(train_wide[target_name].values)),
        'std': float(np.std(train_wide[target_name].values)) + 1e-8,
    }
    for target_name in TARGET_NAMES
}

print("Target normalization statistics:")
for target_name in TARGET_NAMES:
    stats = target_stats[target_name]
    print(f"  {target_name:<20} mean: {stats['mean']:>8.2f}  std: {stats['std']:>8.2f}")

## 8. Load Fold Models

In [None]:
# Instantiate one model per fold and restore its best checkpoint.
fold_models = []
checkpoint_root = Path(MODEL_BASE)

print(f"Loading {N_FOLDS} fold models...")
for fold_idx in range(N_FOLDS):
    checkpoint_path = checkpoint_root / f'fold_{fold_idx}' / 'best_model.pth'

    print(f"\nFold {fold_idx + 1}/{N_FOLDS}:")
    print(f"  Loading from: {checkpoint_path}")

    # pretrained=False: every weight comes from the checkpoint, and the
    # notebook is meant to run with internet access disabled.
    model = RGBDepthFusionEncoder(
        rgb_backbone=CONFIG['backbone'],
        depth_model_path=DEPTH_MODEL_PATH,
        fusion_type=CONFIG['fusion_type'],
        dropout=CONFIG['dropout'],
        pretrained=False,
    ).to(DEVICE)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a source you trust.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    # eval() disables dropout and freezes batch-norm statistics for inference.
    model.eval()

    print(f"  Val Loss: {checkpoint['best_val_loss']:.4f}")

    fold_models.append(model)

print(f"\nLoaded {len(fold_models)} models")

## 9. Generate Predictions with TTA

In [None]:
def denormalize(pred_dict, target_stats):
    """Map standardised predictions back to the original target scale.

    Args:
        pred_dict: Mapping of target name to a standardised prediction.
        target_stats: Mapping of target name to ``{'mean': ..., 'std': ...}``.

    Returns:
        New dict with ``value * std + mean`` for every entry of ``pred_dict``.
    """
    restored = {}
    for name, value in pred_dict.items():
        stats = target_stats[name]
        restored[name] = (value * stats['std']) + stats['mean']
    return restored

# Build one DataLoader per TTA view up front. The dataset depends only on the
# transform, so rebuilding it (and re-reading the CSV) for every fold is
# wasted work; shuffle=False keeps image order identical across views/folds.
tta_loaders = [
    DataLoader(
        BiomassTestDataset(
            csv_path=TEST_CSV,
            img_dir=TEST_IMG_DIR,
            transform=tta_transform,
            target_stats=target_stats,
        ),
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=CONFIG['num_workers'],
    )
    for tta_transform in tta_transforms
]
n_tta = len(tta_loaders)

# One entry per fold: {target name: list of TTA-averaged predictions per image}
all_fold_predictions = []

for fold_idx, model in enumerate(fold_models):
    print(f"\nFold {fold_idx + 1}/{N_FOLDS}...")

    # Running per-target sum of denormalised predictions over the TTA views.
    tta_sums = {}

    with torch.no_grad():
        for test_loader in tta_loaders:
            batch_outputs = {name: [] for name in TARGET_NAMES}

            for batch in test_loader:
                predictions = model(batch['image'].to(DEVICE))
                for name in TARGET_NAMES:
                    # Move a whole batch at once instead of one .item() per
                    # element; float64 matches the precision of Python floats.
                    batch_outputs[name].append(
                        predictions[name].detach().cpu().double().reshape(-1).numpy()
                    )

            for name in TARGET_NAMES:
                chunks = batch_outputs[name]
                normalized = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float64)
                view_pred = denormalize({name: normalized}, target_stats)[name]
                tta_sums[name] = tta_sums[name] + view_pred if name in tta_sums else view_pred

    # Average over TTA views; plain lists keep the structure later cells index into.
    fold_predictions = {
        name: (tta_sums[name] / n_tta).tolist() if name in tta_sums else []
        for name in TARGET_NAMES
    }

    all_fold_predictions.append(fold_predictions)
    print(f"  Generated predictions for {len(fold_predictions[TARGET_NAMES[0]])} samples")

print(f"\nGenerated predictions from all {len(fold_models)} folds")

## 10. Ensemble and Apply Constraints

In [None]:
def _enforce_biomass_constraints(raw):
    """Return one sample's predictions adjusted to satisfy the biomass identities.

    Args:
        raw: Mapping with keys ``Dry_Clover_g``, ``Dry_Dead_g``,
            ``Dry_Green_g``, ``Dry_Total_g`` and ``GDM_g``.

    Returns:
        Dict with the same keys where all values are >= 0,
        ``GDM = Green + Clover`` and ``Total = GDM + Dead``.
    """
    # Negative masses are not physical.
    clipped = {name: max(0.0, value) for name, value in raw.items()}
    clover = clipped['Dry_Clover_g']
    dead = clipped['Dry_Dead_g']
    green = clipped['Dry_Green_g']
    gdm = clipped['GDM_g']
    total = clipped['Dry_Total_g']

    # GDM = Green + Clover: meet halfway between the direct GDM prediction and
    # the sum of its parts, then rescale the parts to that value.
    gdm_calc = green + clover
    adjusted_gdm = (gdm + gdm_calc) / 2
    if gdm_calc > 0:
        scale = adjusted_gdm / gdm_calc
        green *= scale
        clover *= scale
    else:
        # Both parts are zero, so there is no ratio to preserve. Attribute
        # the adjusted GDM to Green so the identity still holds (previously
        # GDM stayed positive while Green + Clover remained 0).
        green = adjusted_gdm

    # Total = GDM + Dead: same halfway rule; Dead absorbs the difference and
    # is floored at zero, in which case Total collapses onto GDM.
    total_calc = adjusted_gdm + dead
    adjusted_total = (total + total_calc) / 2
    if adjusted_total > adjusted_gdm:
        dead = adjusted_total - adjusted_gdm
    else:
        dead = 0.0
        adjusted_total = adjusted_gdm

    return {
        'Dry_Clover_g': clover,
        'Dry_Dead_g': dead,
        'Dry_Green_g': green,
        'Dry_Total_g': adjusted_total,
        'GDM_g': adjusted_gdm,
    }


# Average across folds
print("Averaging predictions across folds...")
n_samples = len(all_fold_predictions[0][TARGET_NAMES[0]])
ensemble_predictions = {
    name: [
        np.mean([fp[name][sample_idx] for fp in all_fold_predictions])
        for sample_idx in range(n_samples)
    ]
    for name in TARGET_NAMES
}

# Apply biological constraints
print("Applying biological constraints...")
for sample_idx in range(n_samples):
    constrained = _enforce_biomass_constraints(
        {name: ensemble_predictions[name][sample_idx] for name in TARGET_NAMES}
    )
    for name in TARGET_NAMES:
        ensemble_predictions[name][sample_idx] = constrained[name]

print("\nPredictions summary:")
for name in TARGET_NAMES:
    values = ensemble_predictions[name]
    print(f"  {name:<15} mean: {np.mean(values):>8.2f}  min: {np.min(values):>8.2f}  max: {np.max(values):>8.2f}")

## 11. Create Submission

In [None]:
# Load test.csv and answer exactly the sample_ids it lists, in file order.
test_df = pd.read_csv(TEST_CSV)

# Predictions are indexed by image in order of first appearance in test.csv
# (the same derivation the test dataset uses), so rebuild that index here.
test_image_ids = test_df['sample_id'].str.split('__').str[0].unique()
image_index = {image_id: img_idx for img_idx, image_id in enumerate(test_image_ids)}

# Create submission rows
submission_rows = []
for sample_id in test_df['sample_id']:
    # sample_id is "<image_id>__<target name>"; split on the first separator.
    image_id, _, name = sample_id.partition('__')
    if name not in ensemble_predictions:
        raise ValueError(f"Unexpected target in sample_id {sample_id!r}")
    prediction = ensemble_predictions[name][image_index[image_id]]
    submission_rows.append({'sample_id': sample_id, 'target': prediction})

# Explicit columns keep the header intact even if test.csv has no rows.
submission_df = pd.DataFrame(submission_rows, columns=['sample_id', 'target'])
submission_df.to_csv('submission.csv', index=False)

print("Submission file created!")
print(f"Shape: {submission_df.shape}")
print("\nSubmission preview:")
print(submission_df)

## 12. Constraint Verification

In [None]:
# Sanity check: after post-processing, both identities should hold, so the
# printed differences are expected to be (numerically) zero.
print("Constraint verification:")
for sample_idx in range(n_samples):
    clover, dead, green, gdm, total = (
        ensemble_predictions[key][sample_idx]
        for key in ('Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g')
    )
    gdm_calc = green + clover   # GDM rebuilt from its parts
    total_calc = gdm + dead     # Total rebuilt from its parts

    print(f"Sample {sample_idx}:")
    print(f"  GDM diff: {abs(gdm - gdm_calc):.6f}")
    print(f"  Total diff: {abs(total - total_calc):.6f}")

banner = "="*70
print("\n" + banner)
print("Submission ready: submission.csv")
print("RGB+Depth Fusion 5-Fold Ensemble with TTA")
print(banner)