In [21]:
import os
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
import tifffile as tiff

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Root folder holding the hyperspectral training TIFFs.  Set the HS_DIR
# environment variable to point somewhere else without editing the notebook;
# the default is the original local path.
HS_DIR = os.environ.get(
    "HS_DIR",
    r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\Kaggle_Prepared\train\HS",
)

# --- data ---------------------------------------------------------------
TARGET_BANDS = 125          # every cube is cropped / zero-padded to this many bands
TARGET_HW    = (32, 32)     # matches the current dataset; change it if the data changes
VAL_RATIO    = 0.2          # fraction of samples held out for validation
SEED         = 42

# --- optimisation -------------------------------------------------------
BATCH_SIZE   = 32
EPOCHS       = 30
LR           = 1e-3
WD           = 1e-4         # AdamW weight decay

NUM_WORKERS  = 0  # keep 0 while debugging; raise to 2 once everything runs


In [22]:
class HSDataset(Dataset):
    """Folder-backed dataset of hyperspectral TIFF cubes.

    Every ``*.tif`` / ``*.tiff`` file found recursively under ``root_dir`` is
    one sample; its label is the name of the directory that directly contains
    the file (ImageFolder-style layout: ``root/class_name/*.tif``).

    Each item is a ``(x, y)`` pair where ``x`` is a ``float32`` tensor of shape
    ``(target_bands, *target_hw)`` and ``y`` is a scalar ``long`` tensor.

    Args:
        root_dir: Directory that is scanned recursively for TIFF files.
        target_bands: Number of bands every sample is cropped or zero-padded to.
        target_hw: ``(height, width)`` every sample is resized to.
        normalize: ``"minmax"``, ``"zscore"`` or ``None`` (no normalisation).
            The statistics are computed per band, per sample.

    Raises:
        RuntimeError: If no TIFF file is found under ``root_dir``.

    Warns:
        RuntimeWarning: If all files share one parent directory name, i.e. the
            dataset has a single class and cannot be used to train a classifier.
    """

    def __init__(self, root_dir, target_bands=125, target_hw=(32,32), normalize="minmax"):
        self.root_dir = Path(root_dir)
        self.target_bands = target_bands
        self.target_hw = target_hw
        self.normalize = normalize

        tif_files = sorted(list(self.root_dir.rglob("*.tif")) + list(self.root_dir.rglob("*.tiff")))
        if len(tif_files) == 0:
            raise RuntimeError(f"No .tif/.tiff found in {root_dir}")

        # ImageFolder-style: HS/class_name/*.tif
        class_names = sorted({p.parent.name for p in tif_files})
        self.class_to_idx = {c:i for i,c in enumerate(class_names)}

        if len(class_names) < 2:
            # A single label makes every metric trivially perfect, so say so
            # loudly instead of letting training "succeed" on nothing.
            import warnings

            warnings.warn(
                f"HSDataset found a single class {class_names} under {root_dir}; "
                "labels are taken from each file's parent directory, so a "
                "classifier trained on this dataset has nothing to learn.",
                RuntimeWarning,
                stacklevel=2,
            )

        self.samples = [(p, self.class_to_idx[p.parent.name]) for p in tif_files]

    def __len__(self):
        """Return the number of TIFF files discovered at construction time."""
        return len(self.samples)

    @staticmethod
    def _to_chw(arr: np.ndarray) -> np.ndarray:
        """Move the band axis of ``arr`` to the front.

        A 2-D array is treated as a single band and gets a leading axis of
        size 1.  For a 3-D array the longest axis is taken to be the band axis.

        NOTE: this heuristic only holds while the band count is larger than
        both spatial dimensions; on a tie the first longest axis is used.

        Raises:
            ValueError: If ``arr`` is neither 2-D nor 3-D.
        """
        if arr.ndim == 2:
            return arr[None, :, :]  # (1,H,W)
        if arr.ndim != 3:
            raise ValueError(f"Unexpected ndim={arr.ndim}, shape={arr.shape}")

        band_axis = int(np.argmax(arr.shape))      # longest axis (e.g. 125/126 bands)
        arr = np.moveaxis(arr, band_axis, 0)       # bands first -> (C,?,?)
        return arr

    def _fix_bands(self, x: torch.Tensor) -> torch.Tensor:
        """Crop trailing bands or append zero bands so ``x`` has ``target_bands``."""
        c, h, w = x.shape
        tb = self.target_bands
        if c > tb:
            x = x[:tb]  # e.g. 126 -> 125
        elif c < tb:
            # Allocate the padding next to the data so this also works for
            # tensors that do not live on the CPU.
            pad = torch.zeros((tb - c, h, w), dtype=x.dtype, device=x.device)
            x = torch.cat([x, pad], dim=0)
        return x

    def _resize(self, x: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``x`` (C,H,W) to ``target_hw``; no-op if it already matches."""
        th, tw = self.target_hw
        if x.shape[1:] == (th, tw):
            return x
        x = x.unsqueeze(0)  # interpolate expects a batch axis: (1,C,H,W)
        x = F.interpolate(x, size=(th, tw), mode="bilinear", align_corners=False)
        return x.squeeze(0)

    def _normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise each band of ``x`` independently according to ``self.normalize``.

        Raises:
            ValueError: If ``self.normalize`` is not ``"minmax"``, ``"zscore"`` or ``None``.
        """
        if self.normalize is None:
            return x
        if self.normalize == "minmax":
            mn = x.amin(dim=(1,2), keepdim=True)
            mx = x.amax(dim=(1,2), keepdim=True)
            # The epsilon keeps constant bands from dividing by zero.
            return (x - mn) / (mx - mn + 1e-6)
        if self.normalize == "zscore":
            mean = x.mean(dim=(1,2), keepdim=True)
            std  = x.std(dim=(1,2), keepdim=True)
            return (x - mean) / (std + 1e-6)
        raise ValueError("normalize must be: 'minmax', 'zscore', or None")

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(x, y)`` after preprocessing."""
        path, y = self.samples[idx]
        arr = tiff.imread(str(path))
        arr = self._to_chw(arr).astype(np.float32)  # (C,H,W)

        x = torch.from_numpy(arr)
        x = self._fix_bands(x)      # force exactly target_bands bands
        x = self._resize(x)
        x = self._normalize(x)

        return x, torch.tensor(y, dtype=torch.long)


In [23]:
# Seed NumPy and PyTorch so the split and the weight init are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)

full_ds = HSDataset(HS_DIR, target_bands=TARGET_BANDS, target_hw=TARGET_HW, normalize="minmax")

# Stratified hold-out split over sample positions (labels come straight from
# the dataset's (path, label) list, so no file is read here).
indices = np.arange(len(full_ds))
labels  = np.array([label for _, label in full_ds.samples])

train_idx, val_idx = train_test_split(
    indices, test_size=VAL_RATIO, random_state=SEED, stratify=labels
)

train_ds, val_ds = (Subset(full_ds, part.tolist()) for part in (train_idx, val_idx))

print("Classes:", full_ds.class_to_idx)
print("Train size:", len(train_ds), "Val size:", len(val_ds))

# Sanity check: count the band dimension over the first 200 samples
from collections import Counter
def check_bands(ds, n=200):
    """Tally the leading-dimension size of the first ``n`` samples of ``ds``.

    Args:
        ds: Indexable collection whose items are ``(x, label)`` pairs; only
            ``x.shape[0]`` is inspected.
        n: Maximum number of samples to look at.

    Returns:
        A ``Counter`` mapping each observed ``x.shape[0]`` to how often it occurred.
    """
    limit = min(len(ds), n)
    return Counter(int(ds[pos][0].shape[0]) for pos in range(limit))

print("Train bands sample:", check_bands(train_ds))
print("Val   bands sample:", check_bands(val_ds))

# Pinned host memory only speeds up host-to-GPU copies; asking for it on a
# CPU-only machine just makes the DataLoader emit a warning.
pin_memory = DEVICE == "cuda"

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=pin_memory)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=pin_memory)

# Pull one batch to confirm the tensors collate to the expected shapes.
xb, yb = next(iter(train_loader))
print("Batch:", xb.shape, yb.shape)


Classes: {'HS': 0}
Train size: 480 Val size: 120
Train bands sample: Counter({125: 200})
Val   bands sample: Counter({125: 120})


  super().__init__(loader)


Batch: torch.Size([32, 125, 32, 32]) torch.Size([32])


In [24]:
import torchvision.models as models

# One output logit per label discovered by the dataset.
num_classes = len(full_ds.class_to_idx)

# ResNet-18 without pretrained weights; the stem is replaced so it accepts
# TARGET_BANDS input channels and the head so it emits num_classes logits.
model = models.resnet18(weights=None)
model.conv1 = nn.Conv2d(
    in_channels=TARGET_BANDS,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
model = model.to(DEVICE)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=WD)
criterion = nn.CrossEntropyLoss()


In [26]:
import wandb
import torch
import torch.nn as nn
import torch.optim as optim

# Take logged hyper-parameters from the notebook constants wherever the values
# agree, so the W&B record cannot drift from what the loaders/model really use.
in_channels = TARGET_BANDS
RUN_NAME = "baseline_hs125_resnet18"
CKPT_PATH = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\checkpoints\best_hs125_resnet18.pth"

wandb.init(
    project="beyond-visible-spectrum",
    name=RUN_NAME,
    config={
        # NOTE(review): epochs and lr keep the values this run was launched
        # with; they differ from EPOCHS (30) and LR (1e-3) in the config cell —
        # confirm which pair is intended.
        "epochs": 10,
        "lr": 1e-4,
        "batch_size": BATCH_SIZE,
        "in_channels": in_channels,
        "wd": WD,
    },
)

# Fresh loss and optimizer for this run; the optimizer follows the W&B config
# so a sweep can override lr / wd.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=wandb.config.lr, weight_decay=wandb.config.wd)

def run_epoch(loader, train=True):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Operates on the notebook-level ``model``, ``criterion`` and ``optimizer``
    and moves every batch to ``DEVICE``.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.
        train: When true the model is put in training mode and the optimizer
            is stepped after every batch; otherwise the model is put in eval
            mode and no gradients are tracked.

    Returns:
        Tuple of the sample-weighted mean loss and the fraction of correctly
        classified samples over the whole pass.
    """
    model.train(train)
    total_loss, correct, total = 0.0, 0, 0

    # Gradients are only needed while training; skipping the autograd graph
    # during evaluation saves memory and time.
    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)

            if train:
                optimizer.zero_grad(set_to_none=True)

            logits = model(x)
            loss = criterion(logits, y)

            if train:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean, so weight it by the batch size
            # to get a correct average when the last batch is smaller.
            total_loss += loss.item() * x.size(0)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += x.size(0)

    return total_loss/total, correct/total

# torch.save does not create missing folders, so make sure the checkpoint
# directory exists before the first save.
Path(CKPT_PATH).parent.mkdir(parents=True, exist_ok=True)

best_val_acc = -1.0

for epoch in range(1, wandb.config.epochs + 1):
    tr_loss, tr_acc = run_epoch(train_loader, train=True)
    va_loss, va_acc = run_epoch(val_loader,   train=False)

    metrics = {
        "epoch": epoch,
        "train_loss": tr_loss, "train_acc": tr_acc,
        "val_loss": va_loss,   "val_acc": va_acc
    }

    # Keep only the weights of the best epoch seen so far.
    if va_acc > best_val_acc:
        best_val_acc = va_acc
        torch.save(model.state_dict(), CKPT_PATH)
        metrics["best_val_acc"] = best_val_acc

    # Log once per epoch: every wandb.log call advances the run's step, so a
    # separate call for best_val_acc would land on a different step than the
    # epoch metrics it belongs to.
    wandb.log(metrics)

    print(f"Epoch {epoch:02d} | train_acc={tr_acc:.3f} val_acc={va_acc:.3f}")

wandb.finish()
print("Best val_acc:", best_val_acc)


Epoch 01 | train_acc=1.000 val_acc=1.000
Epoch 02 | train_acc=1.000 val_acc=1.000
Epoch 03 | train_acc=1.000 val_acc=1.000
Epoch 04 | train_acc=1.000 val_acc=1.000
Epoch 05 | train_acc=1.000 val_acc=1.000
Epoch 06 | train_acc=1.000 val_acc=1.000
Epoch 07 | train_acc=1.000 val_acc=1.000
Epoch 08 | train_acc=1.000 val_acc=1.000
Epoch 09 | train_acc=1.000 val_acc=1.000
Epoch 10 | train_acc=1.000 val_acc=1.000


0,1
best_val_acc,▁
epoch,▁▂▃▃▄▅▆▆▇█
train_acc,▁▁▁▁▁▁▁▁▁▁
train_loss,▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▁▁▁▁▁▁▁▁
val_loss,▁▁▁▁▁▁▁▁▁▁

0,1
best_val_acc,1
epoch,10
train_acc,1
train_loss,0
val_acc,1
val_loss,0


Best val_acc: 1.0
