# CSIRO Image2Biomass - Optimized ResNet50 Inference

This notebook generates predictions using an optimized ResNet50 model trained with:
- Huber loss (delta=2.0)
- Aggressive augmentation
- AdamW optimizer
- Validation loss: 0.4387 (69% improvement over baseline)

## Setup Instructions
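1. Upload the trained model checkpoint to Kaggle and attach it to this notebook; make sure `CHECKPOINT_PATH` in the Configuration cell points to where it is mounted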
2. Add the competition data
3. Run all cells to generate submission

## 1. Imports

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from tqdm import tqdm

# Record the runtime environment in the cell output for later debugging.
print("PyTorch version: {}".format(torch.__version__))
print("CUDA available: {}".format(torch.cuda.is_available()))



PyTorch version: 2.8.0+cu126
CUDA available: False


## 2. Configuration

In [2]:
# --- Competition data, as mounted inside the Kaggle kernel ---
TRAIN_CSV = '/kaggle/input/csiro-biomass/train.csv'
TRAIN_IMG_DIR = '/kaggle/input/csiro-biomass/train'
TEST_CSV = '/kaggle/input/csiro-biomass/test.csv'
TEST_IMG_DIR = '/kaggle/input/csiro-biomass/test'

# --- Trained weights to run inference with ---
CHECKPOINT_PATH = '/kaggle/input/biomass-resnet50-optimized/pytorch/default/1/best_model.pth'

# One regression output per entry; the order here drives head creation,
# statistics computation and reporting throughout the notebook.
TARGET_NAMES = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Dry_Total_g', 'GDM_g']

# Hyper-parameters (selected via Optuna) plus data-loading settings.
CONFIG = dict(
    backbone='resnet50',
    pretrained=True,
    dropout=0.1,
    head_hidden_dim=256,
    image_size=384,
    batch_size=16,
    num_workers=2,
)

# Channel-wise statistics used to normalise input images (ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print("Using device: {}".format(DEVICE))

Using device: cpu


## 3. Model Architecture

In [3]:
class MultiTaskResNet(nn.Module):
    """Shared image backbone with one small regression head per target.

    The backbone is built by ``timm`` and its feature vector is fed to an
    independent MLP head for every name in ``TARGET_NAMES``.

    Args:
        backbone: timm model name used for the feature extractor.
        num_targets: accepted for interface compatibility; the set of heads is
            determined by ``TARGET_NAMES``, not by this value.
        pretrained: forwarded to ``timm.create_model``.
        dropout: dropout probability used inside every head.
        head_hidden_dim: width of the hidden layer inside every head.
    """

    def __init__(self, backbone='resnet50', num_targets=5, pretrained=True,
                 dropout=0.3, head_hidden_dim=256):
        super().__init__()

        # num_classes=0: use the backbone purely as a feature extractor whose
        # output goes straight into the heads below.
        self.backbone = timm.create_model(backbone, pretrained=pretrained, num_classes=0)
        feature_dim = self.backbone.num_features

        # One head per target, registered under the target's own name.
        per_target_heads = {}
        for name in TARGET_NAMES:
            per_target_heads[name] = self._make_head(feature_dim, head_hidden_dim, dropout)
        self.heads = nn.ModuleDict(per_target_heads)

    def _make_head(self, in_features, hidden_dim, dropout):
        """Build a Linear -> ReLU -> Dropout -> Linear head with a single output."""
        layers = [
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return a dict mapping each target name to its predictions for ``x``.

        The trailing singleton dimension produced by each head is squeezed away.
        """
        shared = self.backbone(x)
        outputs = {}
        for name in TARGET_NAMES:
            outputs[name] = self.heads[name](shared).squeeze(-1)
        return outputs

print("Model architecture defined")

Model architecture defined


## 4. Dataset Class

In [4]:
class BiomassTestDataset(Dataset):
    """Image-level dataset built from a long-format CSV (one row per sample_id).

    The image identifier is the part of ``sample_id`` before the first ``'__'``;
    each distinct identifier becomes one dataset item.

    Args:
        csv_path: CSV with at least ``sample_id`` and ``image_path`` columns.
        img_dir: directory that holds the image files.
        transform: optional callable invoked as ``transform(image=array)`` whose
            result is indexed with ``['image']``.
        target_stats: stored on the instance; not used by this class itself.
    """

    def __init__(self, csv_path, img_dir, transform=None, target_stats=None):
        self.csv_path = csv_path
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.target_stats = target_stats

        # Read the table and derive the per-row image identifier.
        table = pd.read_csv(csv_path)
        table['image_id'] = table['sample_id'].str.split('__').str.get(0)
        self.df = table

        # Distinct images in order of first appearance, plus the path recorded
        # on the first row of each image.
        self.image_ids = table['image_id'].unique()
        self.image_paths = table.groupby('image_id')['image_path'].first().to_dict()

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]

        # Only the file name from the CSV path is used; the directory comes
        # from ``img_dir``.
        file_name = Path(self.image_paths[image_id]).name
        pixels = np.array(Image.open(self.img_dir / file_name).convert('RGB'))

        if self.transform is None:
            return {'image': pixels, 'image_id': image_id}

        return {
            'image': self.transform(image=pixels)['image'],
            'image_id': image_id
        }

print("Dataset class defined")

Dataset class defined


## 5. Get Target Statistics from Training Data

In [5]:
# Load training data to compute normalization statistics
train_df = pd.read_csv(TRAIN_CSV)
train_df['image_id'] = train_df['sample_id'].str.split('__').str[0]

# Pivot to wide format
train_wide = train_df.pivot_table(
    index='image_id',
    columns='target_name',
    values='target',
    aggfunc='first'
)

# Compute statistics
target_stats = {}
for target_name in TARGET_NAMES:
    values = train_wide[target_name].values
    target_stats[target_name] = {
        'mean': float(np.mean(values)),
        'std': float(np.std(values)) + 1e-8
    }

print("Target normalization statistics:")
for target_name, stats in target_stats.items():
    print(f"  {target_name:<20} mean: {stats['mean']:>8.2f}  std: {stats['std']:>8.2f}")

Target normalization statistics:
  Dry_Clover_g         mean:     6.65  std:    12.10
  Dry_Dead_g           mean:    12.04  std:    12.38
  Dry_Green_g          mean:    26.62  std:    25.37
  Dry_Total_g          mean:    45.32  std:    27.94
  GDM_g                mean:    33.27  std:    24.90


## 6. Create Test Dataset and DataLoader

In [6]:
# Define validation transforms
val_transform = A.Compose([
    A.Resize(CONFIG['image_size'], CONFIG['image_size']),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2()
])

# Create test dataset
test_dataset = BiomassTestDataset(
    csv_path=TEST_CSV,
    img_dir=TEST_IMG_DIR,
    transform=val_transform,
    target_stats=target_stats
)

# Create dataloader
test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)

print(f"Test dataset: {len(test_dataset)} samples")
print(f"Test batches: {len(test_loader)}")

Test dataset: 1 samples
Test batches: 1


## 7. Load Model and Checkpoint

In [7]:
# Create model
print(f"Creating model: {CONFIG['backbone']}...")
model = MultiTaskResNet(
    backbone=CONFIG['backbone'],
    num_targets=len(TARGET_NAMES),
    pretrained=False,  # We'll load trained weights
    dropout=CONFIG['dropout'],
    head_hidden_dim=CONFIG['head_hidden_dim']
)
model = model.to(DEVICE)

# Load checkpoint
print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Checkpoint loaded successfully!")
print(f"  Epoch: {checkpoint['epoch']}")
print(f"  Best Val Loss: {checkpoint['best_val_loss']:.4f}")

n_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {n_params:,}")

Creating model: resnet50...
Loading checkpoint from /kaggle/input/biomass-resnet50-optimized/pytorch/default/1/best_model.pth...
Checkpoint loaded successfully!
  Epoch: 48
  Best Val Loss: 0.4387
Total parameters: 26,132,037


## 8. Generate Predictions

In [8]:
def denormalize_predictions(pred_dict, target_stats):
    """Map normalized predictions back to the original scale.

    Args:
        pred_dict: mapping of target name -> normalized value.
        target_stats: mapping of target name -> ``{'mean': ..., 'std': ...}``.

    Returns:
        A new dict with ``value * std + mean`` for every entry of ``pred_dict``.

    Raises:
        KeyError: if a target in ``pred_dict`` has no entry in ``target_stats``.
    """
    return {
        name: (normalized * target_stats[name]['std']) + target_stats[name]['mean']
        for name, normalized in pred_dict.items()
    }

def enforce_constraint(predictions, method='average'):
    """Reconcile predictions so Dry_Total = Dry_Clover + Dry_Dead + Dry_Green.

    With ``method='average'`` the reconciled total is the mean of the predicted
    total and the sum of the three predicted components; the components are
    then rescaled proportionally so that they add up to that total.

    Masses cannot be negative, so the three components and the total are
    floored at zero before reconciliation. If every component is non-positive
    there is nothing to distribute the total over, so the components and the
    total are all set to zero; this keeps the constraint satisfied in that
    case as well.

    Args:
        predictions: mapping of image id -> dict of target name -> value. Each
            inner dict must contain ``Dry_Clover_g``, ``Dry_Dead_g``,
            ``Dry_Green_g`` and ``Dry_Total_g``. Other keys (e.g. ``GDM_g``)
            are passed through unchanged.
        method: reconciliation strategy; only ``'average'`` is supported.

    Returns:
        A new mapping with reconciled copies of the inner dicts; the input is
        not modified.

    Raises:
        ValueError: if ``method`` is not a supported strategy.
    """
    if method != 'average':
        # Fail loudly instead of silently returning unreconciled predictions.
        raise ValueError(f"Unknown constraint method: {method!r} (expected 'average')")

    enforced = {}

    for image_id, pred_dict in predictions.items():
        pred = pred_dict.copy()

        # Floor at zero: a negative mass prediction is not meaningful and would
        # otherwise distort (or flip the sign of) the proportional rescaling.
        clover = max(pred['Dry_Clover_g'], 0.0)
        dead = max(pred['Dry_Dead_g'], 0.0)
        green = max(pred['Dry_Green_g'], 0.0)
        total = max(pred['Dry_Total_g'], 0.0)

        component_sum = clover + dead + green

        if component_sum > 0:
            # Meet halfway between the two estimates of the total, then scale
            # the components so they sum to it.
            new_total = (total + component_sum) / 2
            scale = new_total / component_sum
            pred['Dry_Clover_g'] = clover * scale
            pred['Dry_Dead_g'] = dead * scale
            pred['Dry_Green_g'] = green * scale
            pred['Dry_Total_g'] = new_total
        else:
            # No positive component to scale: zero everything so the total
            # still equals the sum of the components.
            pred['Dry_Clover_g'] = 0.0
            pred['Dry_Dead_g'] = 0.0
            pred['Dry_Green_g'] = 0.0
            pred['Dry_Total_g'] = 0.0

        enforced[image_id] = pred

    return enforced

# Run inference
print("Running inference...")
predictions = {}

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Generating predictions"):
        images = batch['image'].to(DEVICE)
        image_ids = batch['image_id']
        
        # Get predictions
        pred = model(images)
        
        # Store predictions for each image
        for i, image_id in enumerate(image_ids):
            pred_dict = {
                target_name: pred[target_name][i].cpu().item()
                for target_name in TARGET_NAMES
            }
            # Denormalize
            pred_dict = denormalize_predictions(pred_dict, target_stats)
            predictions[image_id] = pred_dict

print(f"Generated predictions for {len(predictions)} images")

# Apply constraint enforcement
print("Applying constraint enforcement...")
predictions = enforce_constraint(predictions, method='average')

# Check constraint violations
violations = []
for image_id, pred in predictions.items():
    total = pred['Dry_Total_g']
    component_sum = pred['Dry_Clover_g'] + pred['Dry_Dead_g'] + pred['Dry_Green_g']
    violation = abs(total - component_sum)
    violations.append(violation)

print(f"Constraint violations:")
print(f"  Mean: {np.mean(violations):.6f}g")
print(f"  Max: {np.max(violations):.6f}g")
print(f"  All exact: {all(v < 1e-6 for v in violations)}")

Running inference...


Generating predictions: 100%|██████████| 1/1 [00:00<00:00,  1.37it/s]

Generated predictions for 1 images
Applying constraint enforcement...
Constraint violations:
  Mean: 0.000000g
  Max: 0.000000g
  All exact: True





## 9. Create Submission File

In [9]:
# Load test.csv to get correct sample_id ordering
test_df = pd.read_csv(TEST_CSV)

# Create submission rows
submission_rows = []
for _, row in test_df.iterrows():
    sample_id = row['sample_id']
    image_id = sample_id.split('__')[0]
    target_name = row['target_name']
    
    # Get prediction
    pred_value = predictions[image_id][target_name]
    
    submission_rows.append({
        'sample_id': sample_id,
        'target': pred_value
    })

# Create DataFrame and save
submission_df = pd.DataFrame(submission_rows)
submission_df.to_csv('submission.csv', index=False)

print("Submission file created!")
print(f"Shape: {submission_df.shape}")
print("\nFirst few predictions:")
print(submission_df.head(10))
print("\nSummary statistics:")
print(submission_df['target'].describe())

Submission file created!
Shape: (5, 2)

First few predictions:
                    sample_id     target
0  ID1001187975__Dry_Clover_g   1.158229
1    ID1001187975__Dry_Dead_g  27.936712
2   ID1001187975__Dry_Green_g  24.517426
3   ID1001187975__Dry_Total_g  53.612367
4         ID1001187975__GDM_g  25.084344

Summary statistics:
count     5.000000
mean     26.461816
std      18.609657
min       1.158229
25%      24.517426
50%      25.084344
75%      27.936712
max      53.612367
Name: target, dtype: float64


## 10. Display Predictions Summary

In [10]:
# Per-target mean / min / max over all predicted images, as a quick sanity check.
print("\nPredictions by target:")
for target_name in TARGET_NAMES:
    values = np.asarray([pred[target_name] for pred in predictions.values()])
    print(f"  {target_name:<20} mean: {values.mean():>8.2f}  "
          f"min: {values.min():>8.2f}  max: {values.max():>8.2f}")

# Closing banner.
print("\n" + "="*70, "Submission file ready: submission.csv", "="*70, sep="\n")


Predictions by target:
  Dry_Clover_g         mean:     1.16  min:     1.16  max:     1.16
  Dry_Dead_g           mean:    27.94  min:    27.94  max:    27.94
  Dry_Green_g          mean:    24.52  min:    24.52  max:    24.52
  Dry_Total_g          mean:    53.61  min:    53.61  max:    53.61
  GDM_g                mean:    25.08  min:    25.08  max:    25.08

Submission file ready: submission.csv
