In [1]:
pip install rasterio

Collecting rasterio
  Downloading rasterio-1.5.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (8.6 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Downloading rasterio-1.5.0-cp312-cp312-manylinux_2_28_x86_64.whl (37.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.6/37.6 MB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading affine-2.4.0-py3-none-any.whl (15 kB)
Installing collected packages: affine, rasterio
Successfully installed affine-2.4.0 rasterio-1.5.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install transformers

Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.nn.parallel import DataParallel
import numpy as np
import cv2
import rasterio
from pathlib import Path
import copy
import random
from sklearn.model_selection import StratifiedKFold
from tqdm.auto import tqdm
import pandas as pd

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_gpus    = torch.cuda.device_count()
DATA_ROOT = "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/ICPR02/kaggle"
SSL_ROOT  = "/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/archive/share"

print(f"Device: {device} | GPUs: {n_gpus}")
print(f"Data root: {DATA_ROOT}")
print(f"SSL root:  {SSL_ROOT}")

Device: cuda | GPUs: 2
Data root: /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/ICPR02/kaggle
SSL root:  /kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/archive/share


In [4]:
device

device(type='cuda')

# Temporal Dataset

In [5]:
class TemporalSSLDataset(Dataset):
    """
    For each field, samples two date acquisitions as a positive pair.
    If a field only has one date, pairs with itself using augmentation.
    Includes both train (crop/field/date) and val (field/bands) structures.
    """
    def __init__(self, ssl_root, target_size=(224, 224)):
        self.target_size = target_size
        self.bands = [
            'B1','B2','B3','B4','B5','B6',
            'B7','B8','B8A','B9','B11','B12'
        ]

        # field_id → list of paths that contain B1.tif
        self.fields = {}

        ssl_root = Path(ssl_root)

        # train: share/train/{crop}/{field_id}/{date_folder}/B1.tif
        train_dir = ssl_root / "train"
        if train_dir.exists():
            for crop_dir in sorted(train_dir.iterdir()):
                if not crop_dir.is_dir():
                    continue
                for field_dir in sorted(crop_dir.iterdir()):
                    if not field_dir.is_dir():
                        continue
                    key = f"{crop_dir.name}_{field_dir.name}"
                    dates = []
                    for date_dir in sorted(field_dir.iterdir()):
                        if date_dir.is_dir() and (date_dir / "B1.tif").exists():
                            dates.append(date_dir)
                    if dates:
                        self.fields[key] = dates

        # val: share/val/{field_id}/B1.tif
        val_dir = ssl_root / "val"
        if val_dir.exists():
            for field_dir in sorted(val_dir.iterdir()):
                if field_dir.is_dir() and (field_dir / "B1.tif").exists():
                    key = f"val_{field_dir.name}"
                    self.fields[key] = [field_dir]

        self.field_keys = list(self.fields.keys())
        print(f"SSL dataset: {len(self.field_keys)} fields, "
              f"{sum(len(v) for v in self.fields.values())} total acquisitions")
        multi = sum(1 for v in self.fields.values() if len(v) >= 2)
        print(f"  Fields with 2+ dates (temporal pairs): {multi}")
        print(f"  Fields with 1 date  (self-pairs):      {len(self.field_keys)-multi}")

    def __len__(self):
        return len(self.field_keys)

    def _load(self, path):
        bands = []
        for b in self.bands:
            with rasterio.open(path / f"{b}.tif") as src:
                x = src.read(1).astype(np.float32)
                if x.shape != self.target_size:
                    x = cv2.resize(x, self.target_size)
                bands.append(x)
        return torch.from_numpy(np.stack(bands))   # (12, H, W)

    def _augment(self, x):
        """Light augmentation used for self-pairs only"""
        _, H, W = x.shape
        crop = torch.randint(int(0.7*H), H+1, (1,)).item()
        i    = torch.randint(0, H-crop+1, (1,)).item()
        j    = torch.randint(0, W-crop+1, (1,)).item()
        x    = x[:, i:i+crop, j:j+crop]
        x    = F.interpolate(x.unsqueeze(0), size=self.target_size,
                             mode='bilinear', align_corners=False).squeeze(0)
        if torch.rand(1) > 0.5:
            x = torch.flip(x, [2])
        x = x + 0.01 * torch.randn_like(x)
        return x

    def __getitem__(self, idx):
        key   = self.field_keys[idx]
        dates = self.fields[key]

        if len(dates) >= 2:
            # real temporal pair — sample two different dates
            d1, d2 = random.sample(dates, 2)
            view1  = self._load(d1)
            view2  = self._load(d2)
        else:
            # self-pair with augmentation
            view1 = self._load(dates[0])
            view2 = self._augment(self._load(dates[0]))

        return view1, view2

# Disease Dataset Class

In [6]:
class S2Disease(Dataset):
    def __init__(self, root_dir, is_eval=False, target_size=(224, 224)):
        self.root_dir    = Path(root_dir)
        self.is_eval     = is_eval
        self.target_size = target_size
        self.bands = [
            'B1','B2','B3','B4','B5','B6',
            'B7','B8','B8A','B9','B11','B12'
        ]

        if is_eval:
            self.samples      = list((self.root_dir / "evaluation").glob("*/"))
            self.classes      = ['Aphid', 'Blast', 'RPH', 'Rust']
            self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        else:
            all_dirs          = [d for d in self.root_dir.iterdir() if d.is_dir()]
            self.classes      = sorted([d.name for d in all_dirs if d.name != "evaluation"])
            self.class_to_idx = {n: i for i, n in enumerate(self.classes)}
            self.samples      = []
            for cls in self.classes:
                self.samples.extend(list((self.root_dir / cls).glob("*/")))

        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
        self.num_classes  = len(self.classes)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        band_data   = []
        for band in self.bands:
            with rasterio.open(sample_path / f"{band}.tif") as src:
                data = src.read(1).astype(np.float32)
                if data.shape != self.target_size:
                    data = cv2.resize(data, self.target_size,
                                      interpolation=cv2.INTER_LINEAR)
                band_data.append(data)

        image = torch.from_numpy(np.stack(band_data))

        if self.is_eval:
            label = torch.zeros(self.num_classes)
        else:
            cls_name        = sample_path.parent.name
            label           = torch.zeros(self.num_classes)
            label[self.class_to_idx[cls_name]] = 1.0

        return {
            'image':     image,
            'label':     label,
            'sample_id': sample_path.name
        }

# Spectral Band

In [7]:
class SpectralMixer(nn.Module):
    """
    1x1 conv: learns cross-band relationships before patch embedding.
    Projects 12 bands → 64 learned spectral features.
    """
    def __init__(self, in_bands=12, out_channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_bands, out_channels, kernel_size=1, bias=False)
        self.bn   = nn.BatchNorm2d(out_channels)
        self.act  = nn.GELU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class PatchEmbed(nn.Module):
    """Standard ViT patch embedding on mixed spectral features"""
    def __init__(self, in_channels=64, patch_size=16, img_size=224, embed_dim=384):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches  = (img_size // patch_size) ** 2
        self.proj       = nn.Conv2d(in_channels, embed_dim,
                                    kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) → (B, n_patches, embed_dim)
        x = self.proj(x)                    # (B, embed_dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)   # (B, n_patches, embed_dim)
        return x


def compute_spectral_variance(x, patch_size=16):
    """
    Computes band variance per patch from raw 12-band input.
    Used to identify low-information patches (uniform soil, water).

    Args:
        x: (B, 12, H, W) raw bands
    Returns:
        variance: (B, n_patches) per-patch spectral variance
    """
    B, C, H, W = x.shape
    n_ph = H // patch_size
    n_pw = W // patch_size

    # reshape into patches: (B, C, n_ph, patch_size, n_pw, patch_size)
    x_p = x.reshape(B, C, n_ph, patch_size, n_pw, patch_size)
    # (B, n_ph, n_pw, C, patch_size, patch_size)
    x_p = x_p.permute(0, 2, 4, 1, 3, 5)
    # (B, n_patches, C * patch_size * patch_size)
    x_p = x_p.reshape(B, n_ph * n_pw, -1)

    # variance across spectral+spatial dimensions per patch
    var = x_p.var(dim=-1)   # (B, n_patches)
    return var


class SpectraMaskViT(nn.Module):
    """
    Full architecture:
      1. SpectralMixer:   12 bands → 64 learned spectral features
      2. PatchEmbed:      64 → 384-dim tokens (16x16 patches → 196 tokens)
      3. EntropyMask:     remove bottom 30% lowest-variance patches
      4. ViT-Small:       6 layers, 384 dim, 6 heads
      5. Head:            DINO projection (SSL) or MLP (fine-tuning)
    """
    def __init__(self, img_size=224, patch_size=16,
                 embed_dim=384, depth=6, n_heads=6,
                 mlp_ratio=4.0, mask_ratio=0.30):
        super().__init__()

        self.patch_size  = patch_size
        self.mask_ratio  = mask_ratio
        self.embed_dim   = embed_dim
        self.n_patches   = (img_size // patch_size) ** 2   # 196

        # spectral mixing layer
        self.spectral_mixer = SpectralMixer(in_bands=12, out_channels=64)

        # patch embedding
        self.patch_embed = PatchEmbed(
            in_channels=64, patch_size=patch_size,
            img_size=img_size, embed_dim=embed_dim
        )

        # learnable CLS token and position embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # ViT-Small transformer blocks
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model    = embed_dim,
                nhead      = n_heads,
                dim_feedforward = int(embed_dim * mlp_ratio),
                dropout    = 0.0,
                activation = 'gelu',
                batch_first = True,
                norm_first  = True    # pre-norm (more stable)
            )
            for _ in range(depth)
        ])

        self.norm     = nn.LayerNorm(embed_dim)
        self.num_features = embed_dim   # used by DINOHead

    def forward(self, x):
        B = x.shape[0]

        # 1. compute spectral variance BEFORE mixing (from raw bands)
        var = compute_spectral_variance(x, self.patch_size)   # (B, 196)

        # 2. spectral mixing
        x = self.spectral_mixer(x)   # (B, 64, H, W)

        # 3. patch embedding
        x = self.patch_embed(x)      # (B, 196, 384)

        # 4. entropy masking — remove bottom mask_ratio% patches by variance
        n_keep   = int(self.n_patches * (1 - self.mask_ratio))   # keep top 70%
        # get indices of top-n_keep patches by variance
        _, keep_idx = torch.topk(var, n_keep, dim=1)              # (B, n_keep)
        keep_idx    = keep_idx.sort(dim=1).values                 # keep order
        # gather kept patches
        keep_idx_exp = keep_idx.unsqueeze(-1).expand(-1, -1, self.embed_dim)
        x            = torch.gather(x, 1, keep_idx_exp)          # (B, n_keep, 384)

        # 5. prepend CLS token and add position embeddings
        cls    = self.cls_token.expand(B, -1, -1)                 # (B, 1, 384)
        x      = torch.cat([cls, x], dim=1)                       # (B, n_keep+1, 384)

        # position embed: CLS gets pos 0, kept patches get their original positions
        cls_pos  = self.pos_embed[:, :1]                          # (1, 1, 384)
        patch_pos = self.pos_embed[:, 1:].expand(B, -1, -1)      # (B, 196, 384)
        kept_pos  = torch.gather(
            patch_pos, 1,
            keep_idx.unsqueeze(-1).expand(-1, -1, self.embed_dim)
        )                                                          # (B, n_keep, 384)
        pos = torch.cat([cls_pos.expand(B, -1, -1), kept_pos], dim=1)
        x   = x + pos

        # 6. transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # return CLS token as representation
        return x[:, 0]   # (B, 384)


# DINO Encoder

In [9]:
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim=65536):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 2048),
            nn.GELU(),
            nn.Linear(2048, out_dim)
        )
    def forward(self, x):
        return self.mlp(x)


class DINOLoss(nn.Module):
    def __init__(self, out_dim, temp_s=0.1, temp_t=0.04):
        super().__init__()
        self.temp_s = temp_s
        self.temp_t = temp_t
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student, teacher):
        t    = F.softmax((teacher - self.center) / self.temp_t, dim=-1)
        s    = student / self.temp_s
        loss = torch.sum(-t * F.log_softmax(s, dim=-1), dim=-1).mean()
        self.center = 0.9 * self.center + 0.1 * teacher.mean(dim=0, keepdim=True)
        return loss


@torch.no_grad()
def update_teacher(student, teacher, m=0.996):
    for ps, pt in zip(student.parameters(), teacher.parameters()):
        pt.data = m * pt.data + (1 - m) * ps.data

# SSL Model Initialization

In [10]:
ssl_dataset = TemporalSSLDataset(ssl_root=SSL_ROOT, target_size=(224, 224))
ssl_loader  = DataLoader(
    ssl_dataset,
    batch_size = 32 * max(n_gpus, 1),   # scale batch with GPU count
    shuffle    = True,
    num_workers = 0
)

print(f"SSL batches per epoch: {len(ssl_loader)}")

# build student and teacher
encoder_s = SpectraMaskViT().to(device)
encoder_t = copy.deepcopy(encoder_s).to(device)
for p in encoder_t.parameters():
    p.requires_grad = False

head_s = DINOHead(encoder_s.num_features).to(device)
head_t = DINOHead(encoder_t.num_features).to(device)

# DataParallel for dual T4
if n_gpus > 1:
    encoder_s = DataParallel(encoder_s)
    encoder_t = DataParallel(encoder_t)
    head_s    = DataParallel(head_s)
    head_t    = DataParallel(head_t)
    print(f"✓ DataParallel enabled across {n_gpus} GPUs")

criterion_ssl = DINOLoss(out_dim=65536).to(device)
optimizer_ssl = torch.optim.AdamW(
    list(encoder_s.parameters()) + list(head_s.parameters()),
    lr=1e-4, weight_decay=0.05
)
scaler = torch.amp.GradScaler('cuda')

print("✓ SSL models ready")

SSL dataset: 1164 fields, 2960 total acquisitions
  Fields with 2+ dates (temporal pairs): 722
  Fields with 1 date  (self-pairs):      442
SSL batches per epoch: 19
✓ DataParallel enabled across 2 GPUs
✓ SSL models ready


# Encoder Training Loop

In [12]:
SSL_EPOCHS = 50

print(f"Starting SSL pretraining for {SSL_EPOCHS} epochs...")

for ep in range(SSL_EPOCHS):
    encoder_s.train()
    total_loss = 0

    for view1, view2 in tqdm(ssl_loader, desc=f"SSL {ep+1}/{SSL_EPOCHS}", leave=True):
        view1, view2 = view1.to(device), view2.to(device)

        optimizer_ssl.zero_grad()

        with torch.amp.autocast('cuda'):
            # student sees both views
            f1    = encoder_s(view1)
            f2    = encoder_s(view2)
            s_out = head_s((f1 + f2) / 2)

            # teacher sees both views (no grad)
            with torch.no_grad():
                t1    = encoder_t(view1)
                t2    = encoder_t(view2)
                t_out = head_t((t1 + t2) / 2).detach()

            loss = criterion_ssl(s_out, t_out)

        scaler.scale(loss).backward()
        scaler.step(optimizer_ssl)
        scaler.update()

        # EMA teacher update
        update_teacher(encoder_s, encoder_t)
        update_teacher(head_s,    head_t)

        total_loss += loss.item()

    print(f"SSL Epoch {ep+1}/{SSL_EPOCHS} | Loss: {total_loss/len(ssl_loader):.4f}")

# unwrap DataParallel before saving
encoder_to_save = encoder_s.module if isinstance(encoder_s, DataParallel) else encoder_s
torch.save(encoder_to_save.state_dict(), 'ssl_encoder.pth')
print("✓ SSL pretraining complete — saved to ssl_encoder.pth")

Starting SSL pretraining for 50 epochs...


SSL 1/50:   0%|          | 0/19 [00:00<?, ?it/s]

  dataset = DatasetReader(path, driver=driver, sharing=sharing, thread_safe=thread_safe, **kwargs)


SSL Epoch 1/50 | Loss: 9.9786


SSL 2/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 2/50 | Loss: 9.2339


SSL 3/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 3/50 | Loss: 9.1277


SSL 4/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 4/50 | Loss: 9.2203


SSL 5/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 5/50 | Loss: 9.2727


SSL 6/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 6/50 | Loss: 9.2725


SSL 7/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 7/50 | Loss: 9.1786


SSL 8/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 8/50 | Loss: 9.0392


SSL 9/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 9/50 | Loss: 8.8624


SSL 10/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 10/50 | Loss: 8.5780


SSL 11/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 11/50 | Loss: 8.2909


SSL 12/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 12/50 | Loss: 7.9639


SSL 13/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 13/50 | Loss: 7.4618


SSL 14/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 14/50 | Loss: 7.0945


SSL 15/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 15/50 | Loss: 6.6955


SSL 16/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 16/50 | Loss: 6.3046


SSL 17/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 17/50 | Loss: 5.9631


SSL 18/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 18/50 | Loss: 5.4132


SSL 19/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 19/50 | Loss: 5.0578


SSL 20/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 20/50 | Loss: 4.6652


SSL 21/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 21/50 | Loss: 4.2335


SSL 22/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 22/50 | Loss: 3.8254


SSL 23/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 23/50 | Loss: 3.5820


SSL 24/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 24/50 | Loss: 3.2230


SSL 25/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 25/50 | Loss: 2.9232


SSL 26/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 26/50 | Loss: 2.6489


SSL 27/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 27/50 | Loss: 2.2735


SSL 28/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 28/50 | Loss: 2.0453


SSL 29/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 29/50 | Loss: 1.9663


SSL 30/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 30/50 | Loss: 1.6947


SSL 31/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 31/50 | Loss: 1.4331


SSL 32/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 32/50 | Loss: 1.2997


SSL 33/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 33/50 | Loss: 1.3064


SSL 34/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 34/50 | Loss: 1.2739


SSL 35/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 35/50 | Loss: 1.3166


SSL 36/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 36/50 | Loss: 1.0306


SSL 37/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 37/50 | Loss: 1.0104


SSL 38/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 38/50 | Loss: 1.0303


SSL 39/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 39/50 | Loss: 0.9295


SSL 40/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 40/50 | Loss: 0.7220


SSL 41/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 41/50 | Loss: 0.8439


SSL 42/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 42/50 | Loss: 0.7251


SSL 43/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 43/50 | Loss: 0.6592


SSL 44/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 44/50 | Loss: 0.6179


SSL 45/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 45/50 | Loss: 0.5843


SSL 46/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 46/50 | Loss: 0.6063


SSL 47/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 47/50 | Loss: 0.7526


SSL 48/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 48/50 | Loss: 0.6623


SSL 49/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 49/50 | Loss: 0.5696


SSL 50/50:   0%|          | 0/19 [00:00<?, ?it/s]

SSL Epoch 50/50 | Loss: 0.5045
✓ SSL pretraining complete — saved to ssl_encoder.pth


# Classifier Head

In [13]:
class SpectraMaskViTClassifier(nn.Module):
    """
    Pretrained SpectraMaskViT encoder + MLP classification head.
    Unfreezes last N transformer blocks for fine-tuning.
    """
    def __init__(self, encoder, num_classes, unfreeze_last_n=2):
        super().__init__()
        self.encoder = encoder

        # freeze everything
        for p in self.encoder.parameters():
            p.requires_grad = False

        # unfreeze spectral mixer — it needs to adapt to disease signals
        for p in self.encoder.spectral_mixer.parameters():
            p.requires_grad = True

        # unfreeze last N transformer blocks
        for block in self.encoder.blocks[-unfreeze_last_n:]:
            for p in block.parameters():
                p.requires_grad = True

        # unfreeze final norm
        for p in self.encoder.norm.parameters():
            p.requires_grad = True

        hidden = self.encoder.num_features   # 384

        self.head = nn.Sequential(
            nn.Linear(hidden, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        feats = self.encoder(x)
        return self.head(feats)


# load pretrained SSL encoder
ssl_encoder = SpectraMaskViT()
ssl_encoder.load_state_dict(torch.load('ssl_encoder.pth'))
ssl_encoder = ssl_encoder.to(device)
print("✓ SSL encoder loaded")

# build classifier
train_dataset = S2Disease(root_dir=DATA_ROOT, is_eval=False, target_size=(224, 224))

model = SpectraMaskViTClassifier(
    encoder         = ssl_encoder,
    num_classes     = train_dataset.num_classes,
    unfreeze_last_n = 2
).to(device)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"✓ Classifier ready | Trainable: {trainable:,} / {total:,}")
print(f"  Unfrozen: spectral_mixer + last 2 blocks + norm")

✓ SSL encoder loaded
✓ Classifier ready | Trainable: 3,881,604 / 17,347,332
  Unfrozen: spectral_mixer + last 2 blocks + norm


# Balanced Dataset (Train + Test) by Augmentation

In [14]:
def val_normalize(x):
    mean = x.mean(dim=(1, 2), keepdim=True)
    std  = x.std(dim=(1, 2), keepdim=True) + 1e-6
    return (x - mean) / std

def augment_strong(x):
    _, H, W = x.shape
    crop = torch.randint(int(0.6*H), H+1, (1,)).item()
    i    = torch.randint(0, H-crop+1, (1,)).item()
    j    = torch.randint(0, W-crop+1, (1,)).item()
    x    = x[:, i:i+crop, j:j+crop]
    x    = F.interpolate(x.unsqueeze(0), size=(H, W),
                         mode='bilinear', align_corners=False).squeeze(0)
    if torch.rand(1) > 0.5: x = torch.flip(x, [2])
    if torch.rand(1) > 0.5: x = torch.flip(x, [1])
    k = torch.randint(0, 4, (1,)).item()
    x = torch.rot90(x, k, [1, 2])
    if torch.rand(1) > 0.6:
        drop = torch.randperm(12)[:torch.randint(1, 3, (1,)).item()]
        x    = x.clone(); x[drop] = 0.0
    x = x + 0.02 * torch.randn_like(x)
    return val_normalize(x)

def augment_light(x):
    _, H, W = x.shape
    crop = torch.randint(int(0.8*H), H+1, (1,)).item()
    i    = torch.randint(0, H-crop+1, (1,)).item()
    j    = torch.randint(0, W-crop+1, (1,)).item()
    x    = x[:, i:i+crop, j:j+crop]
    x    = F.interpolate(x.unsqueeze(0), size=(H, W),
                         mode='bilinear', align_corners=False).squeeze(0)
    if torch.rand(1) > 0.5: x = torch.flip(x, [2])
    x = x + 0.01 * torch.randn_like(x)
    return val_normalize(x)


class TrainDataset(Dataset):
    def __init__(self, samples, labels, aug_flags, num_classes, target_size=(224,224)):
        self.samples     = samples
        self.labels      = labels
        self.aug_flags   = aug_flags
        self.num_classes = num_classes
        self.target_size = target_size
        self.bands = ['B1','B2','B3','B4','B5','B6',
                      'B7','B8','B8A','B9','B11','B12']

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        band_data = []
        for b in self.bands:
            with rasterio.open(self.samples[idx] / f"{b}.tif") as src:
                d = src.read(1).astype(np.float32)
                if d.shape != self.target_size:
                    d = cv2.resize(d, self.target_size, interpolation=cv2.INTER_LINEAR)
                band_data.append(d)
        image = torch.from_numpy(np.stack(band_data))
        image = augment_strong(image) if self.aug_flags[idx] else augment_light(image)
        label = torch.zeros(self.num_classes)
        label[self.labels[idx]] = 1.0
        return {'image': image, 'label': label, 'sample_id': self.samples[idx].name}


class ValDataset(Dataset):
    def __init__(self, samples, labels, num_classes, target_size=(224,224)):
        self.samples     = samples
        self.labels      = labels
        self.num_classes = num_classes
        self.target_size = target_size
        self.bands = ['B1','B2','B3','B4','B5','B6',
                      'B7','B8','B8A','B9','B11','B12']

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        band_data = []
        for b in self.bands:
            with rasterio.open(self.samples[idx] / f"{b}.tif") as src:
                d = src.read(1).astype(np.float32)
                if d.shape != self.target_size:
                    d = cv2.resize(d, self.target_size, interpolation=cv2.INTER_LINEAR)
                band_data.append(d)
        image = torch.from_numpy(np.stack(band_data))
        image = val_normalize(image)
        label = torch.zeros(self.num_classes)
        label[self.labels[idx]] = 1.0
        return {'image': image, 'label': label, 'sample_id': self.samples[idx].name}


# stratified split on original samples
labels_orig    = [train_dataset.class_to_idx[s.parent.name] for s in train_dataset.samples]
skf            = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
tr_idx, va_idx = list(skf.split(train_dataset.samples, labels_orig))[0]

orig_train_samples = [train_dataset.samples[i] for i in tr_idx]
orig_train_labels  = [labels_orig[i] for i in tr_idx]
val_samples        = [train_dataset.samples[i] for i in va_idx]
val_labels         = [labels_orig[i] for i in va_idx]

# oversample training split — all classes to max count
train_counts  = {}
for cls in train_dataset.classes:
    idx = train_dataset.class_to_idx[cls]
    train_counts[cls] = orig_train_labels.count(idx)

max_count = max(train_counts.values())
bal_samples, bal_labels, bal_flags = [], [], []

for cls in train_dataset.classes:
    idx      = train_dataset.class_to_idx[cls]
    cls_samp = [s for s, l in zip(orig_train_samples, orig_train_labels) if l == idx]
    n        = len(cls_samp)
    for s in cls_samp:
        bal_samples.append(s); bal_labels.append(idx); bal_flags.append(False)
    for i in range(max_count - n):
        bal_samples.append(cls_samp[i % n]); bal_labels.append(idx); bal_flags.append(True)

print("Balanced training class counts:")
for cls in train_dataset.classes:
    idx = train_dataset.class_to_idx[cls]
    print(f"  {cls}: {bal_labels.count(idx)}")
print(f"Total train: {len(bal_samples)} | Val: {len(val_samples)}")

train_ds = TrainDataset(bal_samples, bal_labels, bal_flags,
                        train_dataset.num_classes)
val_ds   = ValDataset(val_samples, val_labels, train_dataset.num_classes)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, num_workers=0)

print("✓ Val set: original samples only, no augmentation")

Balanced training class counts:
  Aphid: 396
  Blast: 396
  RPH: 396
  Rust: 396
Total train: 1584 | Val: 180
✓ Val set: original samples only, no augmentation


In [15]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    def forward(self, inputs, targets):
        ce  = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt  = torch.exp(-ce)
        return ((1 - pt) ** self.gamma * ce).mean()


# class weights (p=0.5)
counts  = {c: len(list((Path(DATA_ROOT)/c).glob("*/"))) for c in train_dataset.classes}
total_n = sum(counts.values())
weights = torch.tensor(
    [np.power(total_n / counts[c], 0.5) for c in train_dataset.classes],
    dtype=torch.float32
).to(device)

print("Class weights:")
for c, w in zip(train_dataset.classes, weights):
    print(f"  {c}: {w:.3f}")

criterion = FocalLoss(alpha=weights, gamma=2.0).to(device)

optimizer = torch.optim.AdamW([
    {'params': model.encoder.spectral_mixer.parameters(), 'lr': 1e-4},
    {'params': model.encoder.blocks[-2:].parameters(),    'lr': 5e-5},
    {'params': model.encoder.norm.parameters(),           'lr': 5e-5},
    {'params': model.head.parameters(),                   'lr': 1e-3},
], weight_decay=0.05)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

print("✓ Loss, optimizer, scheduler ready")

Class weights:
  Aphid: 1.762
  Blast: 3.464
  RPH: 1.348
  Rust: 4.743
✓ Loss, optimizer, scheduler ready


# Decoder Head Training Loop

In [16]:
epochs           = 50
best_val_acc     = 0
patience_counter = 0
patience         = 7

print(f"Starting fine-tuning for {epochs} epochs...")
print("=" * 70)

for ep in range(epochs):
    # train
    model.train()
    tr_loss = tr_correct = tr_total = 0

    for batch in tqdm(train_loader, desc=f"Epoch {ep+1}/{epochs} [Train]", leave=False):
        x = batch['image'].to(device)
        y = torch.argmax(batch['label'], dim=1).to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss   = criterion(logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        tr_loss    += loss.item()
        tr_correct += (logits.argmax(1) == y).sum().item()
        tr_total   += y.size(0)

    # validate
    model.eval()
    va_loss = va_correct = va_total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {ep+1}/{epochs} [Val]", leave=False):
            x = batch['image'].to(device)
            y = torch.argmax(batch['label'], dim=1).to(device)

            logits   = model(x)
            loss     = criterion(logits, y)

            va_loss    += loss.item()
            va_correct += (logits.argmax(1) == y).sum().item()
            va_total   += y.size(0)

    scheduler.step()

    tr_acc = tr_correct / tr_total
    va_acc = va_correct / va_total

    print(f"Epoch {ep+1}/{epochs} | "
          f"Train Loss: {tr_loss/len(train_loader):.4f} | Train Acc: {tr_acc:.4f} | "
          f"Val Loss: {va_loss/len(val_loader):.4f} | Val Acc: {va_acc:.4f}")

    if va_acc > best_val_acc:
        best_val_acc = va_acc
        torch.save(model.state_dict(), 'best_model.pth')
        patience_counter = 0
        print(f"  ✓ New best! Saved.")
    else:
        patience_counter += 1
        print(f"  No improvement ({patience_counter}/{patience})")
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {ep+1}")
            break

print("=" * 70)
print(f"✓ Fine-tuning complete! Best Val Acc: {best_val_acc:.4f}")
model.load_state_dict(torch.load('best_model.pth'))
print("✓ Best model loaded!")

Starting fine-tuning for 50 epochs...


Epoch 1/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 1/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 1/50 | Train Loss: 3.0829 | Train Acc: 0.3737 | Val Loss: 3.6381 | Val Acc: 0.1778
  ✓ New best! Saved.


Epoch 2/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 2/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 2/50 | Train Loss: 2.4308 | Train Acc: 0.4343 | Val Loss: 2.1515 | Val Acc: 0.2500
  ✓ New best! Saved.


Epoch 3/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 3/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 3/50 | Train Loss: 2.1389 | Train Acc: 0.5088 | Val Loss: 2.0586 | Val Acc: 0.3000
  ✓ New best! Saved.


Epoch 4/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 4/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 4/50 | Train Loss: 1.8759 | Train Acc: 0.5303 | Val Loss: 2.2599 | Val Acc: 0.3278
  ✓ New best! Saved.


Epoch 5/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 5/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 5/50 | Train Loss: 1.7648 | Train Acc: 0.5732 | Val Loss: 1.9357 | Val Acc: 0.3778
  ✓ New best! Saved.


Epoch 6/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 6/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 6/50 | Train Loss: 1.6579 | Train Acc: 0.5922 | Val Loss: 1.7983 | Val Acc: 0.3667
  No improvement (1/7)


Epoch 7/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 7/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 7/50 | Train Loss: 1.5715 | Train Acc: 0.6136 | Val Loss: 1.8477 | Val Acc: 0.3500
  No improvement (2/7)


Epoch 8/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 8/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 8/50 | Train Loss: 1.5031 | Train Acc: 0.6332 | Val Loss: 1.9073 | Val Acc: 0.2944
  No improvement (3/7)


Epoch 9/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 9/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 9/50 | Train Loss: 1.4376 | Train Acc: 0.6269 | Val Loss: 1.7616 | Val Acc: 0.3111
  No improvement (4/7)


Epoch 10/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 10/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 10/50 | Train Loss: 1.4157 | Train Acc: 0.6383 | Val Loss: 1.8064 | Val Acc: 0.4444
  ✓ New best! Saved.


Epoch 11/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 11/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 11/50 | Train Loss: 1.3423 | Train Acc: 0.6427 | Val Loss: 1.7700 | Val Acc: 0.4389
  No improvement (1/7)


Epoch 12/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 12/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 12/50 | Train Loss: 1.2902 | Train Acc: 0.6490 | Val Loss: 1.8380 | Val Acc: 0.4167
  No improvement (2/7)


Epoch 13/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 13/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 13/50 | Train Loss: 1.2013 | Train Acc: 0.6654 | Val Loss: 1.8211 | Val Acc: 0.4556
  ✓ New best! Saved.


Epoch 14/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 14/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 14/50 | Train Loss: 1.2560 | Train Acc: 0.6850 | Val Loss: 1.6775 | Val Acc: 0.4444
  No improvement (1/7)


Epoch 15/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 15/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 15/50 | Train Loss: 1.1642 | Train Acc: 0.6862 | Val Loss: 1.9336 | Val Acc: 0.4389
  No improvement (2/7)


Epoch 16/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 16/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 16/50 | Train Loss: 1.0734 | Train Acc: 0.6982 | Val Loss: 1.6452 | Val Acc: 0.4944
  ✓ New best! Saved.


Epoch 17/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 17/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 17/50 | Train Loss: 1.1477 | Train Acc: 0.6894 | Val Loss: 1.8991 | Val Acc: 0.4167
  No improvement (1/7)


Epoch 18/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 18/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 18/50 | Train Loss: 1.0187 | Train Acc: 0.7134 | Val Loss: 1.7876 | Val Acc: 0.3722
  No improvement (2/7)


Epoch 19/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 19/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 19/50 | Train Loss: 1.0313 | Train Acc: 0.7039 | Val Loss: 1.7416 | Val Acc: 0.4722
  No improvement (3/7)


Epoch 20/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 20/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 20/50 | Train Loss: 1.0734 | Train Acc: 0.7146 | Val Loss: 1.8467 | Val Acc: 0.4556
  No improvement (4/7)


Epoch 21/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 21/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 21/50 | Train Loss: 0.9721 | Train Acc: 0.7134 | Val Loss: 1.7428 | Val Acc: 0.4556
  No improvement (5/7)


Epoch 22/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 22/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 22/50 | Train Loss: 1.0193 | Train Acc: 0.7367 | Val Loss: 1.6976 | Val Acc: 0.4722
  No improvement (6/7)


Epoch 23/50 [Train]:   0%|          | 0/99 [00:00<?, ?it/s]

Epoch 23/50 [Val]:   0%|          | 0/6 [00:00<?, ?it/s]

Epoch 23/50 | Train Loss: 0.9641 | Train Acc: 0.7348 | Val Loss: 1.7694 | Val Acc: 0.4889
  No improvement (7/7)

Early stopping at epoch 23
✓ Fine-tuning complete! Best Val Acc: 0.4944
✓ Best model loaded!


# Evaluation

In [20]:
eval_dataset = S2Disease(root_dir=DATA_ROOT, is_eval=True, target_size=(224, 224))
eval_loader  = DataLoader(eval_dataset, batch_size=32, shuffle=False, num_workers=0)

print(f"Evaluation samples: {len(eval_dataset)}")

model.eval()
predictions = []
sample_ids  = []

with torch.no_grad():
    for batch in tqdm(eval_loader, desc="Predicting"):
        x      = batch['image'].to(device)
        # normalize eval images
        mean   = x.mean(dim=(2, 3), keepdim=True)
        std    = x.std(dim=(2, 3), keepdim=True) + 1e-6
        x      = (x - mean) / std
        logits = model(x)
        preds  = logits.argmax(dim=1)
        predictions.extend(preds.cpu().numpy())
        sample_ids.extend(batch['sample_id'])

print(f"✓ Generated {len(predictions)} predictions")

Evaluation samples: 40


Predicting:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Generated 40 predictions


In [23]:
submission = pd.DataFrame({
    'sample_id':  sample_ids,
    'prediction': [train_dataset.idx_to_class[p] for p in predictions]
})

print("\nPrediction distribution:")
print(submission['prediction'].value_counts())

submission.to_csv('/kaggle/working/submission.csv', index=False)
print("\n✓ Saved to submission.csv")


Prediction distribution:
prediction
RPH      14
Aphid    12
Rust      8
Blast     6
Name: count, dtype: int64

✓ Saved to submission.csv


In [22]:
submission

Unnamed: 0,sample_id,prediction
0,994b5409c8e946538d87109a99897659,RPH
1,1a419acc1ecc467897d5477a47353fa8,RPH
2,8662df21b2c94788adce4a885ae2b4dc,Blast
3,a564868c3d8c4d4fabde67a536f178ad,Rust
4,796e611aaf8a4f0db57cb79be058f3ae,Rust
5,e77d3a0965fe46d9b3275a7d7f34dbe2,RPH
6,a39dcd0a21824289bb38b40ddf98da89,Aphid
7,e427f07618794fd58dfc9e6c786e3743,Rust
8,13739e32e7a84f669e6ef1284715e93b,RPH
9,b6eeb2bfd281476883fc273b61133e60,RPH
