In [None]:
!pip install -qU timm tifffile

## Imports


In [None]:
import os, re, random
from dataclasses import dataclass
import numpy as np
import pandas as pd
import cv2
import tifffile as tiff
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
import timm

## Seed


In [None]:
def seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN behaviour.

    Covers the RNGs this notebook draws from: ``random`` (AgriDataset augmentation),
    ``np.random`` (the Beta draw in mixup) and torch on CPU and every CUDA device.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # `deterministic` alone does not stop cuDNN's benchmark mode from timing and picking
    # different convolution algorithms from run to run; pin it off so the choice repeats.
    torch.backends.cudnn.benchmark = False


print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

## Config


In [None]:
@dataclass
class CFG:
    """Experiment configuration, read as class attributes (``CFG.X``) throughout the notebook."""

    # --- data locations ---
    ROOT: str = "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026/Kaggle_Prepared"
    OUT_DIR: str = "/kaggle/working/"

    # --- input geometry ---
    IMG_SIZE: int = 128  # RGB and spectral cubes are resized to IMG_SIZE x IMG_SIZE
    BATCH_SIZE: int = 32
    HS_CHANNELS: int = 30  # hyperspectral bands kept after trimming / resampling

    # --- optimisation ---
    EPOCHS: int = 35
    LR: float = 3e-4
    WD: float = 1e-3
    FOLDS: int = 5
    SEED: int = 42

    # --- model / augmentation ---
    RGB_BACKBONE: str = "tf_efficientnetv2_s.in21k"
    MIXUP_ALPHA: float = 1.0

    # --- pseudo-labelling of the unlabelled ("val") split ---
    USE_PSEUDO_LABELING: bool = True
    PSEUDO_THRESH: float = 0.90

    # Unannotated on purpose: plain class attributes, not dataclass fields.
    LABELS = ["Health", "Rust", "Other"]
    LBL2ID = dict(zip(LABELS, range(len(LABELS))))
    ID2LBL = dict(zip(LBL2ID.values(), LBL2ID.keys()))


seed_everything(CFG.SEED)
os.makedirs(CFG.OUT_DIR, exist_ok=True)

## Utilities


In [None]:
def robust_minmax(img):
    """Min-max scale an array to [0, 1] using only its finite values.

    Non-finite entries (NaN / +-inf) are ignored when computing the range and are set
    to 0 in the output, so one bad pixel cannot turn the whole array into NaN. Arrays
    with no finite values (including empty arrays), or whose finite range is below
    1e-8, map to all zeros.

    Args:
        img: array-like of any shape; converted to float32.

    Returns:
        float32 array with the same shape as ``img``.
    """
    img = np.asarray(img, dtype=np.float32)
    finite = np.isfinite(img)
    if not finite.any():
        return np.zeros_like(img)
    values = img[finite]
    mn, mx = values.min(), values.max()
    if mx - mn < 1e-8:
        return np.zeros_like(img)
    return np.where(finite, (img - mn) / (mx - mn), 0.0).astype(np.float32)


def read_rgb(path, size):
    """Read an image file as RGB, resize to ``size`` x ``size`` and scale to [0, 1].

    Returns a float32 ``(size, size, 3)`` array. A missing or unreadable file yields an
    all-zero array of that shape instead of raising.
    """
    bgr = cv2.imread(path)
    if bgr is not None:
        # OpenCV decodes as BGR; convert before resizing.
        rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (size, size))
        return rgb.astype(np.float32) / 255.0
    return np.zeros((size, size, 3), dtype=np.float32)


def read_spectral(ms_path, hs_path, size, target_hs_ch):
    """Load the MS and HS cubes of one sample and stack them channel-last.

    Output is a ``(size, size, 9 + target_hs_ch)`` float32 array:

    * first 9 channels -- the MS block: the MS bands min-max scaled jointly to [0, 1],
      followed by NDVI, GNDVI, NDRE and MCARI, each scaled to [0, 1] over its own range.
      The 9-channel zero fallback implies 5 MS bands. NOTE(review): an MS TIFF with a
      different band count would change the channel total; confirm against the data.
    * last ``target_hs_ch`` channels -- HS bands with the first 10 and last 14 dropped,
      then evenly subsampled (or zero-padded) to ``target_hs_ch``, resized and min-max
      scaled to [0, 1].

    A missing file contributes an all-zero block of the same shape.
    NOTE(review): assumes both TIFFs decode as (H, W, bands); a band-first layout would
    be mis-resized by cv2 -- confirm.
    """
    eps = 1e-8

    if os.path.exists(ms_path):
        ms = tiff.imread(ms_path).astype(np.float32)
        ms = cv2.resize(ms, (size, size), interpolation=cv2.INTER_CUBIC)

        # Band roles used by the index formulas: 1=green, 2=red, 3=red edge, 4=NIR.
        # `red_edge` rather than `re` so the local does not shadow the `re` module.
        green, red, red_edge, nir = ms[..., 1], ms[..., 2], ms[..., 3], ms[..., 4]
        ndvi = (nir - red) / (nir + red + eps)
        gndvi = (nir - green) / (nir + green + eps)
        ndre = (nir - red_edge) / (nir + red_edge + eps)
        mcari = ((red_edge - red) - 0.2 * (red_edge - green)) * (red_edge / (red + eps))

        ms = robust_minmax(ms)
        indices = np.stack([ndvi, gndvi, ndre, mcari], axis=-1)
        # Scale each index over its own spatial range. A single joint min/max lets the
        # unbounded MCARI (or a near-zero denominator) set the range and squash the
        # normalised-difference indices into a sliver of [0, 1].
        lo = indices.min(axis=(0, 1), keepdims=True)
        hi = indices.max(axis=(0, 1), keepdims=True)
        indices = (indices - lo) / (hi - lo + eps)
        ms_block = np.concatenate([ms, indices], axis=-1)
    else:
        ms_block = np.zeros((size, size, 9), dtype=np.float32)

    if os.path.exists(hs_path):
        hs = tiff.imread(hs_path).astype(np.float32)
        # Drop the first 10 and last 14 bands (presumably noisy edge bands -- confirm).
        hs = hs[..., 10:-14]

        if hs.shape[2] > target_hs_ch:
            # Evenly spaced subset of the remaining bands.
            idx = np.linspace(0, hs.shape[2] - 1, target_hs_ch).astype(int)
            hs = hs[..., idx]
        elif hs.shape[2] < target_hs_ch:
            pad = np.zeros((hs.shape[0], hs.shape[1], target_hs_ch - hs.shape[2]), dtype=np.float32)
            hs = np.concatenate([hs, pad], axis=-1)

        hs = cv2.resize(hs, (size, size), interpolation=cv2.INTER_CUBIC)
        hs = robust_minmax(hs)
    else:
        hs = np.zeros((size, size, target_hs_ch), dtype=np.float32)

    return np.concatenate([ms_block, hs], axis=-1)


def get_file_dataframe(root, split):
    """Build one row per RGB image in ``<root>/<split>/RGB`` with its MS/HS paths.

    Args:
        root: dataset root containing one sub-directory per split.
        split: split name, e.g. ``"train"`` or ``"val"``.

    Returns:
        DataFrame with columns ``base_id, label, rgb_path, ms_path, hs_path``, ordered by
        file name. ``label`` is parsed from a ``Health_`` / ``Rust_`` / ``Other_`` name
        prefix for the ``"train"`` split and is ``"Unknown"`` otherwise (or when no prefix
        matches). MS/HS paths are constructed, not checked for existence.
    """
    rgb_dir = os.path.join(root, split, "RGB")
    ms_dir = os.path.join(root, split, "MS")
    hs_dir = os.path.join(root, split, "HS")
    columns = ["base_id", "label", "rgb_path", "ms_path", "hs_path"]

    records = []
    # os.listdir order is filesystem-dependent; sorting keeps the row order -- and hence
    # the seeded StratifiedKFold split built from it -- reproducible across machines.
    for fname in sorted(os.listdir(rgb_dir)):
        if not fname.endswith(".png"):
            continue
        # Strip only the trailing extension (str.replace would remove every ".png").
        bid = fname[: -len(".png")]
        label = "Unknown"
        if split == "train":
            match = re.match(r"^(Health|Rust|Other)_", bid)
            if match:
                label = match.group(1)

        records.append(
            {
                "base_id": bid,
                "label": label,
                "rgb_path": os.path.join(rgb_dir, fname),
                "ms_path": os.path.join(ms_dir, bid + ".tif"),
                "hs_path": os.path.join(hs_dir, bid + ".tif"),
            }
        )
    # Explicit columns so an empty directory still yields the expected schema.
    return pd.DataFrame(records, columns=columns)

## Dataset


In [None]:
class AgriDataset(Dataset):
    """Map-style dataset yielding ``(rgb, spec, label)`` tensors, one per DataFrame row.

    Rows need ``rgb_path``, ``ms_path``, ``hs_path`` and ``label`` columns. ``rgb`` and
    ``spec`` are float32 channel-first tensors of spatial size ``CFG.IMG_SIZE``; ``label``
    is the ``CFG.LBL2ID`` index of the row's label, or -1 for labels outside
    ``CFG.LABELS`` (e.g. ``"Unknown"``). With ``train=True`` both images receive the same
    random flips and 90-degree-multiple rotation.
    """

    def __init__(self, df, train=True):
        # reset_index returns a new frame with a 0..n-1 index; the caller's df is untouched.
        self.df = df.reset_index(drop=True)
        self.train = train

    def __len__(self):
        return self.df.shape[0]

    @staticmethod
    def _augment(rgb, spec):
        """Apply identical random horizontal/vertical flips and quarter-turn rotation."""
        for flip in (np.fliplr, np.flipud):
            if random.random() > 0.5:
                rgb, spec = flip(rgb), flip(spec)
        if random.random() > 0.5:
            quarter_turns = random.randint(1, 3)
            rgb, spec = np.rot90(rgb, quarter_turns), np.rot90(spec, quarter_turns)
        return rgb, spec

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        rgb = read_rgb(record["rgb_path"], CFG.IMG_SIZE)
        spec = read_spectral(record["ms_path"], record["hs_path"], CFG.IMG_SIZE, CFG.HS_CHANNELS)

        if self.train:
            rgb, spec = self._augment(rgb, spec)

        # HWC -> CHW; the contiguous copy also materialises strides left by flips/rot90.
        rgb_t = torch.from_numpy(np.ascontiguousarray(rgb.transpose(2, 0, 1))).float()
        spec_t = torch.from_numpy(np.ascontiguousarray(spec.transpose(2, 0, 1))).float()
        target = CFG.LBL2ID.get(record["label"], -1)
        return rgb_t, spec_t, torch.tensor(target, dtype=torch.long)

## Model


In [None]:
class DualStreamModel(nn.Module):
    """Late-fusion classifier over an RGB image and a stacked spectral cube.

    Two independent branches produce feature vectors that are concatenated and fed
    to an MLP head:

    * ``rgb_net`` -- the timm backbone named by ``CFG.RGB_BACKBONE``, created with
      ``num_classes=0`` so it returns pooled features of width
      ``rgb_net.num_features`` instead of logits.
    * ``spec_net`` -- a small CNN (1x1 projection, then depthwise-3x3 + pointwise
      stages with one 2x max-pool downsample) ending in global average pooling,
      which yields a 256-d vector.

    NOTE: attribute names and ``nn.Sequential`` indices define the ``state_dict``
    keys the saved ``model_c*_f*.pt`` checkpoints are loaded by; restructuring these
    modules breaks loading of existing checkpoints.

    NOTE(review): ``read_rgb`` feeds this backbone values scaled to [0, 1] with no
    mean/std normalisation; pretrained timm weights usually expect their own
    normalisation (see ``timm.data.resolve_data_config``) -- confirm this is intended.
    """

    def __init__(self, spec_ch: int = 39, n_classes: int = 3):
        """Build both branches and the fusion head.

        Args:
            spec_ch: channel count of the spectral input. The default 39 equals
                ``9 + CFG.HS_CHANNELS`` under the default config, which is how the
                notebook computes ``SPEC_CH``.
            n_classes: number of output logits.
        """
        super().__init__()

        # pretrained=True presumably fetches weights from the timm hub / local cache,
        # so it needs internet access or a pre-populated cache.
        self.rgb_net = timm.create_model(CFG.RGB_BACKBONE, pretrained=True, num_classes=0)

        self.spec_net = nn.Sequential(
            # 1x1 conv: mix the spectral channels into 64 feature maps.
            nn.Conv2d(spec_ch, 64, 1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            # Depthwise 3x3 (groups == channels) followed by a pointwise 1x1.
            nn.Conv2d(64, 64, 3, padding=1, groups=64),
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.SiLU(),
            nn.MaxPool2d(2),
            # Second depthwise-separable stage.
            nn.Conv2d(128, 128, 3, padding=1, groups=128),
            nn.Conv2d(128, 256, 1),
            nn.BatchNorm2d(256),
            nn.SiLU(),
            # Global average pool -> (batch, 256).
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

        combined_dim = self.rgb_net.num_features + 256
        # BatchNorm1d in training mode needs batch size > 1; the training DataLoaders
        # in this notebook use drop_last=True, which avoids a size-1 final batch.
        self.head = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.4),
            nn.SiLU(),
            nn.Linear(512, n_classes),
        )

    def forward(self, rgb, spec):
        """Return class logits of shape ``(batch, n_classes)``.

        ``rgb`` and ``spec`` are expected as NCHW batches (3 and ``spec_ch`` channels,
        as built by ``AgriDataset``); shapes are not checked here.
        """
        f1 = self.rgb_net(rgb)  # (batch, rgb_net.num_features)
        f2 = self.spec_net(spec)  # (batch, 256)
        return self.head(torch.cat([f1, f2], dim=1))

## Training


In [None]:
def mixup(rgb, spec, y, alpha):
    """Blend every sample with a randomly chosen partner from the same batch.

    A single coefficient ``lam`` ~ Beta(alpha, alpha) is shared by the whole batch
    (``lam`` is 1, i.e. no blending, when ``alpha`` is not positive). The same
    permutation pairs the RGB and spectral inputs so both streams stay aligned.

    Returns:
        ``(mixed_rgb, mixed_spec, y, y_partner, lam)``; the training loss combines them
        as ``lam * CE(out, y) + (1 - lam) * CE(out, y_partner)``.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    partner = torch.randperm(rgb.size(0)).to(rgb.device)
    other = 1 - lam
    mixed_rgb = lam * rgb + other * rgb[partner]
    mixed_spec = lam * spec + other * spec[partner]
    return mixed_rgb, mixed_spec, y, y[partner], lam


def train_epoch(model, loader, opt, scaler, device):
    """Run one mixup training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(rgb, spec)`` returning logits.
        loader: iterable of ``(rgb, spec, y)`` batches.
        opt: optimizer stepping ``model``'s parameters.
        scaler: ``torch.amp.GradScaler`` used for the backward pass and step.
        device: device the batches are moved to; mixed precision (autocast) is only
            enabled when it is a CUDA device.

    Returns:
        Mean of the per-batch losses, or NaN when ``loader`` yields no batches
        (e.g. ``drop_last=True`` with fewer samples than one batch).
    """
    model.train()
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    loss_sum = 0.0
    n_batches = 0
    for rgb, spec, y in loader:
        rgb, spec, y = rgb.to(device), spec.to(device), y.to(device)
        rgb, spec, y_a, y_b, lam = mixup(rgb, spec, y, CFG.MIXUP_ALPHA)

        opt.zero_grad(set_to_none=True)
        # Autocast on the device actually in use instead of a hard-coded "cuda".
        with torch.amp.autocast(device_type, enabled=use_amp):
            out = model(rgb, spec)
            # Mixup loss: interpolate the two targets with the same lam as the inputs.
            loss = lam * F.cross_entropy(out, y_a) + (1 - lam) * F.cross_entropy(out, y_b)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        loss_sum += loss.item()
        n_batches += 1

    # Counting batches avoids ZeroDivisionError on an empty loader and works for
    # iterables without __len__.
    return loss_sum / n_batches if n_batches else float("nan")


@torch.no_grad()
def infer_loop(model, loader, device):
    """Predict class probabilities for every sample, averaging a flip TTA.

    Args:
        model: module called as ``model(rgb, spec)``; switched to eval mode here.
        loader: iterable of ``(rgb, spec, label)`` batches; labels are ignored.
        device: device the inputs are moved to before the forward passes.

    Returns:
        NumPy array of softmax probabilities, one row per sample in loader order: the
        mean of the plain prediction and the prediction on inputs flipped along dim 3
        (the width axis for the NCHW batches AgriDataset produces).
    """
    model.eval()
    collected = []
    for rgb_batch, spec_batch, _labels in loader:
        rgb_batch = rgb_batch.to(device)
        spec_batch = spec_batch.to(device)
        plain = model(rgb_batch, spec_batch).softmax(1)
        mirrored = model(rgb_batch.flip(3), spec_batch.flip(3)).softmax(1)
        collected.append((plain + mirrored) / 2)
    stacked = torch.cat(collected)
    return stacked.cpu().numpy()

## Execution


In [None]:
train_df = get_file_dataframe(CFG.ROOT, "train")
test_df = get_file_dataframe(CFG.ROOT, "val")

# AgriDataset maps any label outside CFG.LABELS to -1, which F.cross_entropy rejects as a
# target, and StratifiedKFold would treat "Unknown" as an extra class. Drop such rows up
# front (visibly) instead of failing mid-training.
has_label = train_df["label"].isin(CFG.LABELS)
if not has_label.all():
    print(f"Dropping {int((~has_label).sum())} training files without a known label prefix.")
    train_df = train_df[has_label].reset_index(drop=True)

SPEC_CH = 9 + CFG.HS_CHANNELS  # 9-channel MS block + resampled HS bands (see read_spectral)
N_CLASSES = len(CFG.LABELS)
PSEUDO_EPOCHS = 10  # cycle-2 fine-tuning epochs per fold
PSEUDO_LR_SCALE = 0.5  # cycle-2 learning rate as a fraction of CFG.LR

skf = StratifiedKFold(n_splits=CFG.FOLDS, shuffle=True, random_state=CFG.SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"


def checkpoint_path(cycle, fold):
    """Path of the saved state_dict for training ``cycle`` (1 or 2) and ``fold``."""
    return os.path.join(CFG.OUT_DIR, f"model_c{cycle}_f{fold}.pt")


def fold_accuracy(model, v_idx):
    """Flip-TTA accuracy of ``model`` on the held-out training rows ``v_idx``."""
    dl_va = DataLoader(
        AgriDataset(train_df.iloc[v_idx], train=False),
        batch_size=CFG.BATCH_SIZE * 2,
        num_workers=2,
    )
    pred = infer_loop(model, dl_va, device).argmax(1)
    target = train_df["label"].iloc[v_idx].map(CFG.LBL2ID).to_numpy()
    return float((pred == target).mean())


# --- CYCLE 1 ---
print(">>> Cycle 1: Training")
for fold, (t_idx, v_idx) in enumerate(skf.split(train_df, train_df["label"])):
    ds_tr = AgriDataset(train_df.iloc[t_idx], train=True)
    dl_tr = DataLoader(ds_tr, batch_size=CFG.BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)

    model = DualStreamModel(SPEC_CH, N_CLASSES).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WD)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    train_loss = float("nan")
    for _ in range(CFG.EPOCHS):
        train_loss = train_epoch(model, dl_tr, opt, scaler, device)

    torch.save(model.state_dict(), checkpoint_path(1, fold))
    # The held-out split was previously never evaluated; report it per fold.
    print(f"Fold {fold} done. train loss {train_loss:.4f} | val acc {fold_accuracy(model, v_idx):.4f}")

# --- PSEUDO LABELING ---
if CFG.USE_PSEUDO_LABELING:
    print("\n>>> Generating Pseudo Labels")
    ds_test = AgriDataset(test_df, train=False)
    dl_test = DataLoader(ds_test, batch_size=CFG.BATCH_SIZE * 2, num_workers=2)

    # Fold-averaged flip-TTA probabilities of the cycle-1 models on the unlabelled split
    # (infer_loop switches each model to eval mode itself).
    avg_preds = np.zeros((len(test_df), N_CLASSES))
    for fold in range(CFG.FOLDS):
        m = DualStreamModel(SPEC_CH, N_CLASSES).to(device)
        m.load_state_dict(torch.load(checkpoint_path(1, fold), map_location=device))
        avg_preds += infer_loop(m, dl_test, device) / CFG.FOLDS

    conf = avg_preds.max(1)
    pseudo_mask = conf > CFG.PSEUDO_THRESH
    pseudo_df = test_df[pseudo_mask].copy()
    pseudo_df["label"] = [CFG.ID2LBL[p] for p in avg_preds[pseudo_mask].argmax(1)]
    print(f"Added {len(pseudo_df)} pseudo-samples.")

    # --- CYCLE 2 ---
    # Same seeded split as cycle 1, so fold k fine-tunes model_c1_f{k} on its own training
    # rows plus the pseudo-labelled samples.
    # NOTE: val acc here may be optimistic -- the pseudo labels come from an ensemble that
    # includes models trained on this fold's validation rows.
    print("\n>>> Cycle 2: Retraining")
    for fold, (t_idx, v_idx) in enumerate(skf.split(train_df, train_df["label"])):
        tr_curr = pd.concat([train_df.iloc[t_idx], pseudo_df]).reset_index(drop=True)
        ds_tr = AgriDataset(tr_curr, train=True)
        dl_tr = DataLoader(
            ds_tr,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True,
            num_workers=2,
            drop_last=True,
        )

        model = DualStreamModel(SPEC_CH, N_CLASSES).to(device)
        model.load_state_dict(torch.load(checkpoint_path(1, fold), map_location=device))
        opt = torch.optim.AdamW(model.parameters(), lr=CFG.LR * PSEUDO_LR_SCALE, weight_decay=CFG.WD)
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

        train_loss = float("nan")
        for _ in range(PSEUDO_EPOCHS):
            train_loss = train_epoch(model, dl_tr, opt, scaler, device)

        torch.save(model.state_dict(), checkpoint_path(2, fold))
        print(f"Fold {fold} refined. train loss {train_loss:.4f} | val acc {fold_accuracy(model, v_idx):.4f}")

## Inference


In [None]:
print("\n>>> Creating Submission")
ds_test = AgriDataset(test_df, train=False)
dl_test = DataLoader(ds_test, batch_size=CFG.BATCH_SIZE * 2, num_workers=2)

# Ensemble the fold checkpoints of the last cycle that was trained: mean of each model's
# flip-TTA probabilities (infer_loop puts the model in eval mode itself).
n_classes = len(CFG.LABELS)
prefix = "model_c2_" if CFG.USE_PSEUDO_LABELING else "model_c1_"
avg_preds = np.zeros((len(test_df), n_classes))

for fold in range(CFG.FOLDS):
    m = DualStreamModel(SPEC_CH, n_classes).to(device)
    state = torch.load(os.path.join(CFG.OUT_DIR, f"{prefix}f{fold}.pt"), map_location=device)
    m.load_state_dict(state)
    avg_preds += infer_loop(m, dl_test, device) / CFG.FOLDS

final_cats = [CFG.ID2LBL[p] for p in avg_preds.argmax(1)]
# NOTE(review): Ids carry a ".tif" suffix although the indexed images are PNGs --
# presumably the competition's expected Id format; confirm against the sample submission.
final_ids = [bid + ".tif" for bid in test_df["base_id"]]

sub = pd.DataFrame({"Id": final_ids, "Category": final_cats})
sub.to_csv("submission.csv", index=False)
sub.head()