In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
import torchvision.models as models

# Competition data (read-only) and the writable output directory.
DATA_DIR = '/kaggle/input/csiro-biomass'
WORKING_DIR = '/kaggle/working'

# Local checkpoint files for the backbones; used only when the flag is on.
USE_PRETRAINED_WEIGHTS = True
PRETRAINED_WEIGHTS_DIR = '/kaggle/input/pretrained-weights/pretrained_weights'

# Seed the global numpy and torch RNGs so runs are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

has_cuda = torch.cuda.is_available()
print('torch', torch.__version__)
print('cuda', has_cuda)
if has_cuda:
    print('device', torch.cuda.get_device_name(0))


In [None]:
train_df = pd.read_csv(f'{DATA_DIR}/train.csv')
test_df = pd.read_csv(f'{DATA_DIR}/test.csv')

# The order of this list fixes the column index of every target in model outputs:
# 0 Clover, 1 Dead, 2 Green, 3 Total, 4 GDM.
TARGET_COLS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Dry_Total_g', 'GDM_g']

# train.csv holds one row per (image, target_name); pivot to one row per image.
pivot = (
    train_df.pivot_table(index='image_path', columns='target_name', values='target', aggfunc='first')
    .reset_index()
)
for t in TARGET_COLS:
    if t not in pivot.columns:
        pivot[t] = np.nan
pivot = pivot[['image_path'] + TARGET_COLS]

# Per-image metadata: first value of every non-target column within each image group.
# The grouping key is kept out of the column selection; as_index=False re-inserts it
# as a column (selecting it as well is version-dependent in pandas).
meta_cols = [c for c in train_df.columns if c not in ('sample_id', 'target_name', 'target', 'image_path')]
meta_first = train_df.groupby('image_path', as_index=False)[meta_cols].first()
train_data = meta_first.merge(pivot, on='image_path', how='left')

train_data['full_image_path'] = train_data['image_path'].apply(lambda p: os.path.join(DATA_DIR, p))

# Integer-encode the categorical auxiliary targets (sorted, so codes are stable across runs).
species_codes, species_uniques = pd.factorize(train_data['Species'].astype('string'), sort=True)
state_codes, state_uniques = pd.factorize(train_data['State'].astype('string'), sort=True)
train_data['species_idx'] = species_codes.astype('int64')
train_data['state_idx'] = state_codes.astype('int64')

train_data['ndvi'] = train_data['Pre_GSHH_NDVI'].astype('float32')
train_data['height'] = train_data['Height_Ave_cm'].astype('float32')

train_data = train_data.drop(columns=['Pre_GSHH_NDVI', 'Height_Ave_cm'])

# Sanity checks before training: a NaN target makes the Huber loss NaN, and a
# missing category is factorized to -1, which CrossEntropyLoss rejects as a target.
n_missing_targets = int(train_data[TARGET_COLS].isna().sum().sum())
if n_missing_targets:
    print(f'WARNING: {n_missing_targets} missing target values in train_data')
if (species_codes < 0).any() or (state_codes < 0).any():
    print('WARNING: missing Species/State values were encoded as -1')

print('train_data', train_data.shape)
print('n_species', len(species_uniques), 'n_state', len(state_uniques))
print(train_data[['image_path'] + TARGET_COLS + ['ndvi','height','species_idx','state_idx']].head())

# One row per test image; the long-format test_df is re-joined when writing the submission.
test_images = test_df.groupby('image_path').first().reset_index()
test_images['full_image_path'] = test_images['image_path'].apply(lambda p: os.path.join(DATA_DIR, p))
print('test_images', test_images.shape)


In [None]:
IMG_SIZE = 384

# Standard ImageNet channel statistics used for input normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Final steps shared by every pipeline: PIL image -> float tensor -> normalized tensor."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Training: fixed-size resize plus light geometric and photometric augmentation.
transform_train = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
    ]
    + _tensor_and_normalize()
)

# Validation / test: deterministic resize followed by the same normalization.
transform_val = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE))] + _tensor_and_normalize())

def mixup_data(x, y, alpha=0.2):
    """Blend each sample of a batch with a randomly permuted partner.

    Args:
        x: batch tensor whose first dimension is the batch.
        y: targets indexed along the same first dimension as ``x``.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing
            (lam is 1, so ``mixed_x`` equals ``x``).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y`` itself and ``y_b`` is ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    return lam * x + (1 - lam) * x[perm], y, y[perm], lam

class MultiTaskDataset(Dataset):
    """Image dataset yielding ``(image, biomass_targets, aux)`` for multi-task models.

    Args:
        df: one row per image with a ``full_image_path`` column; when
            ``has_targets`` is True it must also contain the TARGET_COLS columns
            plus ``ndvi``, ``height``, ``species_idx`` and ``state_idx``.
        transform: callable applied to the RGB PIL image (skipped if falsy).
        has_targets: False for unlabeled data; targets and aux values are then
            zero placeholders so batches collate exactly like labeled ones.
    """

    def __init__(self, df, transform, has_targets: bool):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.has_targets = has_targets

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _placeholder_aux():
        """Zero-valued aux dict with the same keys/dtypes as the labeled case."""
        return {
            'ndvi': torch.tensor(0.0, dtype=torch.float32),
            'height': torch.tensor(0.0, dtype=torch.float32),
            'species': torch.tensor(0, dtype=torch.long),
            'state': torch.tensor(0, dtype=torch.long),
        }

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Open in a context manager so the file handle is released deterministically;
        # convert() returns a new, independent image that outlives the `with`.
        with Image.open(row['full_image_path']) as im:
            img = im.convert('RGB')
        x = self.transform(img) if self.transform else img

        if not self.has_targets:
            return x, torch.zeros(len(TARGET_COLS), dtype=torch.float32), self._placeholder_aux()

        y = torch.tensor([row[c] for c in TARGET_COLS], dtype=torch.float32)
        aux = {
            'ndvi': torch.tensor(float(row['ndvi']), dtype=torch.float32),
            'height': torch.tensor(float(row['height']), dtype=torch.float32),
            'species': torch.tensor(int(row['species_idx']), dtype=torch.long),
            'state': torch.tensor(int(row['state_idx']), dtype=torch.long),
        }
        return x, y, aux

print('dataset ready')

In [None]:
def load_pretrained_weights(model, weights_path):
    """Load a checkpoint into ``model``, skipping classifier weights.

    Handles checkpoints wrapped as ``{'state_dict': ...}`` and keys carrying a
    ``module.`` prefix (e.g. saved from ``nn.DataParallel``). Keys starting with
    ``fc.`` or ``classifier.`` are dropped so a checkpoint whose head has a
    different number of outputs can still be used. Loading is non-strict:
    missing/unexpected keys are tolerated, shape mismatches still raise.

    Args:
        model: the ``nn.Module`` to load into (modified in place).
        weights_path: path to a file readable by ``torch.load``.

    Returns:
        True if the file existed and was loaded, False if it does not exist.
    """
    if not os.path.exists(weights_path):
        return False
    # NOTE(security): torch.load unpickles the file -- only point this at trusted checkpoints.
    state_dict = torch.load(weights_path, map_location='cpu')
    if isinstance(state_dict, dict) and 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    prefix = 'module.'
    # Strip the wrapper prefix only at the start of a key; str.replace would also
    # mangle nested sub-modules that happen to be named "module".
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
    filtered = {k: v for k, v in state_dict.items() if not k.startswith(('fc.', 'classifier.'))}
    model.load_state_dict(filtered, strict=False)
    return True

class _MultiTaskBase(nn.Module):
    """Shared multi-task heads and forward pass on top of ``self.backbone``.

    Subclasses create ``self.backbone``, replace its classification layer with
    ``nn.Identity`` (so it returns pooled features) and then call
    ``_init_heads``. Attribute names are identical across subclasses, so the
    state_dict keys are ``backbone.*``, ``biomass_head.*``, ``reg_head.*``,
    ``species_head.*`` and ``state_head.*``.
    """

    def _load_pretrained(self, pretrained_path, label):
        """Load backbone weights if a path is given; warn when nothing was loaded."""
        if pretrained_path and not load_pretrained_weights(self.backbone, pretrained_path):
            print(f'WARNING: {label} pretrained weights not found at {pretrained_path}; using random init')

    def _init_heads(self, feat_dim: int, n_species: int, n_state: int) -> None:
        """Build the four task heads on a ``feat_dim``-sized feature vector."""
        # 5 biomass outputs, in TARGET_COLS order.
        self.biomass_head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 5),
        )
        # 2 auxiliary regressions: column 0 -> NDVI, column 1 -> height (as used in training).
        self.reg_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )
        self.species_head = nn.Linear(feat_dim, n_species)
        self.state_head = nn.Linear(feat_dim, n_state)

    def forward(self, x):
        """Return ``(biomass, reg, species_logits, state_logits)`` for a batch."""
        feat = self.backbone(x)
        y = self.biomass_head(feat)
        reg = self.reg_head(feat)
        sp = self.species_head(feat)
        st = self.state_head(feat)
        return y, reg, sp, st


class MultiTaskEffNetB0(_MultiTaskBase):
    """EfficientNet-B0 backbone with the shared multi-task heads."""

    def __init__(self, n_species: int, n_state: int, pretrained_path: str | None):
        super().__init__()
        self.backbone = models.efficientnet_b0(weights=None)
        self._load_pretrained(pretrained_path, 'effnetb0')
        feat_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()
        self._init_heads(feat_dim, n_species, n_state)


class MultiTaskResNet34(_MultiTaskBase):
    """ResNet-34 backbone with the shared multi-task heads."""

    def __init__(self, n_species: int, n_state: int, pretrained_path: str | None):
        super().__init__()
        self.backbone = models.resnet34(weights=None)
        self._load_pretrained(pretrained_path, 'resnet34')
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self._init_heads(feat_dim, n_species, n_state)


class MultiTaskResNet50(_MultiTaskBase):
    """ResNet-50 backbone with the shared multi-task heads."""

    def __init__(self, n_species: int, n_state: int, pretrained_path: str | None):
        super().__init__()
        self.backbone = models.resnet50(weights=None)
        self._load_pretrained(pretrained_path, 'resnet50')
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self._init_heads(feat_dim, n_species, n_state)

print('models ready')

In [None]:
from sklearn.model_selection import KFold

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 12
num_workers = 2 if torch.cuda.is_available() else 0
# Pinned host memory only speeds up host->GPU copies; on CPU it just emits a warning.
pin_memory = torch.cuda.is_available()

kfold = KFold(n_splits=5, shuffle=True, random_state=SEED)

# backbone name -> local checkpoint path, keeping only files that actually exist.
weight_paths = {}
if USE_PRETRAINED_WEIGHTS:
    candidate_paths = {
        'effnetb0': os.path.join(PRETRAINED_WEIGHTS_DIR, 'efficientnet_b0_rwightman-3dd342df.pth'),
        'resnet34': os.path.join(PRETRAINED_WEIGHTS_DIR, 'resnet34-b627a593.pth'),
        'resnet50': os.path.join(PRETRAINED_WEIGHTS_DIR, 'resnet50-0676ba61.pth'),
    }
    weight_paths = {k: v for k, v in candidate_paths.items() if os.path.exists(v)}

loss_biomass = nn.HuberLoss(delta=1.0)
loss_reg = nn.MSELoss()
loss_ce = nn.CrossEntropyLoss()
loss_mse = nn.MSELoss()

ALPHA_REG = 0.2          # weight of the NDVI/height regression losses
BETA_CLS = 0.1           # weight of the species/state classification losses
GAMMA_CONSTRAINT = 0.01  # weight of the GDM/Total consistency penalty
MIXUP_ALPHA = 0.2
USE_MIXUP = False

LEARNING_RATE = 8e-4
WEIGHT_DECAY = 1e-4
WARMUP_EPOCHS = 3
TOTAL_EPOCHS = 50
PATIENCE = 10  # epochs without a new best validation loss before early stopping

model_states = {}
val_losses = {}

backbone_configs = [
    ('effnetb0', MultiTaskEffNetB0, weight_paths.get('effnetb0')),
    ('resnet34', MultiTaskResNet34, weight_paths.get('resnet34')),
    ('resnet50', MultiTaskResNet50, weight_paths.get('resnet50')),
]


def _snapshot_state(model):
    """Return an independent CPU copy of ``model.state_dict()``.

    ``.cpu()`` is a no-op for tensors already on the CPU, so without ``.clone()``
    a CPU-only run would keep references to the live parameters and every
    "best" snapshot would silently follow the latest weights.
    """
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


def _consistency_loss(y_hat):
    """Penalty for violating GDM = Clover + Green and Total = GDM + Dead (TARGET_COLS order)."""
    return (loss_mse(y_hat[:, 4], y_hat[:, 0] + y_hat[:, 2])
            + loss_mse(y_hat[:, 3], y_hat[:, 4] + y_hat[:, 1]))


def _batch_loss(model, x, y, ndvi, height, sp, st):
    """Forward one training batch and return the weighted multi-task loss."""
    if USE_MIXUP and np.random.random() < 0.5:
        # Pack every per-sample target so mixup_data permutes them with the same
        # index as the images; otherwise the auxiliary losses would pair mixed
        # images with the targets of only one of the two source images.
        # Class ids are small integers, so they survive the float round-trip exactly.
        n_bio = y.size(1)
        packed = torch.cat(
            [y, ndvi[:, None], height[:, None], sp[:, None].float(), st[:, None].float()], dim=1
        )
        x, t_a, t_b, lam = mixup_data(x, packed, MIXUP_ALPHA)
        y_hat, reg_hat, sp_hat, st_hat = model(x)

        def mixed(loss_fn, pred, cols, as_long=False):
            a, b = t_a[:, cols], t_b[:, cols]
            if as_long:
                a, b = a.long(), b.long()
            return lam * loss_fn(pred, a) + (1 - lam) * loss_fn(pred, b)

        l_main = mixed(loss_biomass, y_hat, slice(0, n_bio))
        l_reg = mixed(loss_reg, reg_hat[:, 0], n_bio) + mixed(loss_reg, reg_hat[:, 1], n_bio + 1)
        l_cls = mixed(loss_ce, sp_hat, n_bio + 2, as_long=True) + mixed(loss_ce, st_hat, n_bio + 3, as_long=True)
    else:
        y_hat, reg_hat, sp_hat, st_hat = model(x)
        l_main = loss_biomass(y_hat, y)
        l_reg = loss_reg(reg_hat[:, 0], ndvi) + loss_reg(reg_hat[:, 1], height)
        l_cls = loss_ce(sp_hat, sp) + loss_ce(st_hat, st)
    return l_main + ALPHA_REG * l_reg + BETA_CLS * l_cls + GAMMA_CONSTRAINT * _consistency_loss(y_hat)


def _evaluate(model, loader):
    """Sample-weighted mean biomass Huber loss of ``model`` over ``loader``."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)
            y_hat, _, _, _ = model(x)
            # HuberLoss averages within a batch; re-weight by batch size so a
            # short final batch does not count as much as a full one.
            total += loss_biomass(y_hat, y).item() * x.size(0)
            count += x.size(0)
    return total / max(1, count)


def _train_fold(backbone_name, ModelClass, weight_path, fold_idx, tr_idx, va_idx):
    """Train one backbone on one CV fold.

    Returns ``(best, best_state, second_best, second_best_state)``: the two lowest
    validation losses seen and CPU snapshots of the matching weights (a state is
    None if no epoch reached that rank).
    """
    tr = train_data.iloc[tr_idx].reset_index(drop=True)
    va = train_data.iloc[va_idx].reset_index(drop=True)

    train_loader = DataLoader(MultiTaskDataset(tr, transform_train, has_targets=True), batch_size=batch_size,
                              shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(MultiTaskDataset(va, transform_val, has_targets=True), batch_size=batch_size,
                            shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    model = ModelClass(len(species_uniques), len(state_uniques), weight_path).to(device)
    opt = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # Linear warm-up for WARMUP_EPOCHS, then cosine decay; both stepped once per epoch.
    warmup_sched = optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=WARMUP_EPOCHS)
    cosine_sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=TOTAL_EPOCHS - WARMUP_EPOCHS, eta_min=1e-6)

    best, best_state = float('inf'), None
    second_best, second_best_state = float('inf'), None
    no_improve = 0

    for epoch in range(TOTAL_EPOCHS):
        model.train()
        for x, y, aux in train_loader:
            opt.zero_grad()
            loss = _batch_loss(
                model, x.to(device), y.to(device),
                aux['ndvi'].to(device), aux['height'].to(device),
                aux['species'].to(device), aux['state'].to(device),
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

        val_main = _evaluate(model, val_loader)
        (warmup_sched if epoch < WARMUP_EPOCHS else cosine_sched).step()

        if val_main < best:
            second_best, second_best_state = best, best_state
            best, best_state = val_main, _snapshot_state(model)
            no_improve = 0
        else:
            if val_main < second_best:
                second_best, second_best_state = val_main, _snapshot_state(model)
            no_improve += 1

        print(f'{backbone_name} fold {fold_idx+1} epoch {epoch+1}/{TOTAL_EPOCHS} val_main {val_main:.4f} lr {opt.param_groups[0]["lr"]:.6f}')

        if no_improve >= PATIENCE:
            print(f'Early stopping at epoch {epoch+1}')
            break

    return best, best_state, second_best, second_best_state


for backbone_name, ModelClass, weight_path in backbone_configs:
    if weight_path is None and USE_PRETRAINED_WEIGHTS:
        print(f'Skipping {backbone_name} - no pretrained weights')
        continue

    for fold_idx, (tr_idx, va_idx) in enumerate(kfold.split(train_data)):
        best, best_state, second_best, second_best_state = _train_fold(
            backbone_name, ModelClass, weight_path, fold_idx, tr_idx, va_idx
        )
        key = f'{backbone_name}_fold{fold_idx+1}'
        if best_state is None:
            # No epoch produced a finite validation loss (e.g. all NaN): nothing to keep.
            print(f'WARNING: {key} produced no usable validation loss; not added to the ensemble')
            continue

        model_states[key] = best_state
        val_losses[key] = best

        if second_best_state is not None:
            model_states[f'{key}_2nd'] = second_best_state
            val_losses[f'{key}_2nd'] = second_best

print('done', len(model_states))

In [None]:
def post_process_predictions(preds):
    """Clip predictions at zero and softly enforce the biomass hierarchy.

    Columns follow TARGET_COLS order: 0 clover, 1 dead, 2 green, 3 total, 4 GDM.
    GDM is raised to at least 95% of (clover + green); total is then raised to
    at least 95% of (adjusted GDM + dead). Columns 0-2 are only clipped.

    Args:
        preds: 2-D array of shape (n_samples, 5).

    Returns:
        A new array; the input is not modified.
    """
    # np.maximum allocates a fresh array, so the caller's preds stay untouched.
    out = np.maximum(preds, 0.0)
    clover, dead, green = out[:, 0], out[:, 1], out[:, 2]
    out[:, 4] = np.maximum(out[:, 4], 0.95 * (green + clover))
    out[:, 3] = np.maximum(out[:, 3], 0.95 * (out[:, 4] + dead))
    return out

class BiomassEnsemble:
    """Inverse-validation-loss weighted ensemble of trained multi-task models.

    Args:
        model_states: mapping ``'<backbone>_fold<k>'`` (optionally with a ``_2nd``
            suffix) -> state_dict produced by the training cell.
        val_losses: mapping of the same names -> validation loss. Each loaded
            member is weighted proportionally to ``1 / val_loss``; weights sum to 1.

    Members with an unknown backbone, a None state, or no recorded validation
    loss are skipped. Raises ValueError if no member can be loaded.
    """

    def __init__(self, model_states: dict, val_losses: dict):
        model_classes = {
            'effnetb0': MultiTaskEffNetB0,
            'resnet34': MultiTaskResNet34,
            'resnet50': MultiTaskResNet50,
        }
        self.models = {}
        for name, state in model_states.items():
            ModelClass = model_classes.get(name.split('_fold')[0])
            if ModelClass is None or state is None or name not in val_losses:
                print(f'Skipping ensemble member {name}')
                continue
            # pretrained_path=None: load_state_dict below overwrites every weight,
            # so re-reading the backbone checkpoint from disk would be wasted I/O.
            m = ModelClass(len(species_uniques), len(state_uniques), None).to(device)
            m.load_state_dict(state)
            m.eval()
            self.models[name] = m

        if not self.models:
            raise ValueError('BiomassEnsemble: no usable model states were provided')

        # Normalise over the members actually loaded so the weights always sum to 1.
        inv = {name: 1.0 / (val_losses[name] + 1e-8) for name in self.models}
        total = sum(inv.values())
        self.weights = {name: w / total for name, w in inv.items()}

    def _weighted_forward(self, x):
        """Weighted sum of every member's biomass predictions for batch ``x``, as numpy."""
        out = None
        for name, m in self.models.items():
            y_hat, _, _, _ = m(x)
            weighted = y_hat * self.weights[name]
            out = weighted if out is None else out + weighted
        return out.cpu().numpy()

    def predict(self, loader, use_tta=True):
        """Predict biomass for every batch in ``loader`` and post-process the result.

        With ``use_tta`` the prediction is the mean over the identity view, a
        horizontal flip and a vertical flip of each batch.

        Returns:
            Array of shape (n_samples, len(TARGET_COLS)) after post_process_predictions.
        """
        views = [lambda img: img]
        if use_tta:
            # For NCHW batches dim 3 is width (horizontal flip), dim 2 is height
            # (vertical flip). NOTE(review): training augmentation only used
            # horizontal flips, so the vertical-flip view assumes invariance the
            # model never saw -- confirm it helps on the validation folds.
            views += [lambda img: torch.flip(img, [3]), lambda img: torch.flip(img, [2])]

        all_preds = []
        with torch.no_grad():
            for x, _, _ in loader:
                x = x.to(device)
                all_preds.append(np.mean([self._weighted_forward(view(x)) for view in views], axis=0))

        if not all_preds:
            return post_process_predictions(np.zeros((0, len(TARGET_COLS)), dtype=np.float32))
        return post_process_predictions(np.concatenate(all_preds, axis=0))


# Unlabeled test images: the dataset yields zero placeholders for targets/aux.
test_ds = MultiTaskDataset(test_images, transform_val, has_targets=False)
test_loader = DataLoader(
    test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
    # Pinned memory only helps host->GPU copies; skip it on CPU to avoid the warning.
    pin_memory=torch.cuda.is_available(),
)

ens = BiomassEnsemble(model_states, val_losses)
all_test_preds_cnn = ens.predict(test_loader, use_tta=True)

# Wide format: one row per test image, one column per target (TARGET_COLS order).
preds_df_cnn = pd.DataFrame(
    {
        'image_path': test_images['image_path'].values,
        **{c: all_test_preds_cnn[:, i] for i, c in enumerate(TARGET_COLS)},
    }
)

print('CNN predictions done, shape:', all_test_preds_cnn.shape)
print(f'Ensemble has {len(ens.models)} models')

USE_TABULAR_BLEND = False
CNN_WEIGHT = 1.0
TAB_WEIGHT = 0.0
TABULAR_FILE = f'{WORKING_DIR}/tabular_predictions.csv'


def blend_with_tabular(cnn_df, tabular_df, cnn_weight, tab_weight):
    """Blend wide CNN predictions with wide tabular predictions per target.

    Only targets present in ``tabular_df`` are blended. Where the tabular model
    has no prediction for an image, the CNN value is kept unchanged rather than
    being blended with 0. Tabular columns are renamed explicitly, so a missing
    tabular target never changes the CNN column names.

    Returns:
        DataFrame with columns ``['image_path'] + TARGET_COLS``.
    """
    tab_cols = [t for t in TARGET_COLS if t in tabular_df.columns]
    tab = (
        tabular_df[['image_path'] + tab_cols]
        # Duplicate image rows would duplicate submission rows after the merge.
        .drop_duplicates('image_path')
        .rename(columns={t: f'{t}_tab' for t in tab_cols})
    )
    blended = cnn_df.merge(tab, on='image_path', how='left')
    for t in tab_cols:
        tab_vals = blended[f'{t}_tab']
        mix = cnn_weight * blended[t] + tab_weight * tab_vals
        blended[t] = mix.where(tab_vals.notna(), blended[t])
    return blended[['image_path'] + TARGET_COLS]


if USE_TABULAR_BLEND:
    print('Attempting tabular blend...')
    try:
        if os.path.exists(TABULAR_FILE):
            CNN_WEIGHT, TAB_WEIGHT = 0.90, 0.10
            preds_df = blend_with_tabular(preds_df_cnn, pd.read_csv(TABULAR_FILE), CNN_WEIGHT, TAB_WEIGHT)
            print(f'Blended CNN ({CNN_WEIGHT*100:.0f}%) + Tabular ({TAB_WEIGHT*100:.0f}%)')
        else:
            print('Tabular predictions not found, using CNN only')
            preds_df = preds_df_cnn
    except Exception as e:
        # Best effort: a malformed tabular file must not block the CNN-only submission.
        print(f'Tabular blending failed: {e}, using CNN only')
        preds_df = preds_df_cnn
else:
    preds_df = preds_df_cnn


def build_submission(long_df, wide_df):
    """Map wide per-image predictions onto the long ``(sample_id, target)`` format.

    Each row of ``long_df`` takes the prediction column named by its
    ``target_name``; negative and NaN values become 0.0. Row order follows
    ``long_df``. Raises KeyError for a ``target_name`` not in TARGET_COLS.
    """
    merged = long_df.merge(wide_df, on='image_path', how='left')
    col_idx = merged['target_name'].map({c: i for i, c in enumerate(TARGET_COLS)})
    if col_idx.isna().any():
        unknown = sorted(merged.loc[col_idx.isna(), 'target_name'].astype(str).unique())
        raise KeyError(f'Unknown target_name values: {unknown}')
    values = merged[TARGET_COLS].to_numpy(dtype=np.float64)[
        np.arange(len(merged)), col_idx.to_numpy(dtype=np.int64)
    ]
    # fmax treats NaN as missing, so NaN -> 0.0 just like Python's max(0.0, nan).
    return pd.DataFrame({'sample_id': merged['sample_id'].to_numpy(), 'target': np.fmax(values, 0.0)})


sub = build_submission(test_df, preds_df)
sub.to_csv(f'{WORKING_DIR}/submission.csv', index=False)
print(sub.head())
print('saved', f'{WORKING_DIR}/submission.csv')