In [1]:
import os
from pathlib import Path
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
import tifffile as tiff
import torchvision.models as models
import wandb
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

# ===== PATHS =====
HS_DIR = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\Kaggle_Prepared\train\HS"
TEST_HS_DIR = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\Kaggle_Prepared\val\HS"
CHECKPOINT_DIR = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\checkpoints"
CKPT_PATH = os.path.join(CHECKPOINT_DIR, "best_hs125_resnet18.pth")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ===== DATA SETTINGS =====
TARGET_BANDS = 125
TARGET_HW = (64, 64)       # Resizing to 64x64 for consistency

# ===== SPLIT =====
VAL_RATIO = 0.2
SEED = 42

# ===== TRAIN =====
EPOCHS = 10
BATCH_SIZE = 32
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 0


In [2]:
prefixes = sorted({
    fn.split("_")[0]
    for fn in os.listdir(HS_DIR)
    if fn.endswith(".tif")
})

print("MS classes:", prefixes)
print("NUM_CLASSES =", len(prefixes))

MS classes: ['Health', 'Other', 'Rust']
NUM_CLASSES = 3


In [3]:
from tqdm import tqdm

def compute_global_stats(img_dir, bands=125):
    print("Computing global stats per band...")
    files = sorted([f for f in os.listdir(img_dir) if f.endswith(('.tif', '.tiff'))])
    
    pixel_num = 0
    channel_sum = np.zeros(bands, dtype=np.float64)
    channel_sum_sq = np.zeros(bands, dtype=np.float64)

    for f in tqdm(files):
        path = os.path.join(img_dir, f)
        # Read image (H, W, C) or (C, H, W)
        img = tiff.imread(path).astype(np.float32)
        
        # Ensure (C, H, W) format for consistency
        if img.ndim == 2:
            img = img[None, :, :]
        elif img.ndim == 3:
            # Check if channels are last (H, W, C) -> transpose to (C, H, W)
            # Heuristic: if last dim is `bands` (125 or near it) and first is not
            if img.shape[-1] == bands or img.shape[-1] > img.shape[0]:
                img = np.transpose(img, (2, 0, 1))
        
        # Clip to target bands if necessary
        if img.shape[0] > bands:
            img = img[:bands]
        elif img.shape[0] < bands:
            padding = np.zeros((bands - img.shape[0], img.shape[1], img.shape[2]), dtype=np.float32)
            img = np.concatenate([img, padding], axis=0)

        # Flatten spatial dims: (C, H*W)
        pixels = img.reshape(bands, -1)
        
        channel_sum += pixels.sum(axis=1)
        channel_sum_sq += (pixels ** 2).sum(axis=1)
        pixel_num += pixels.shape[1]
    
    mean = channel_sum / pixel_num
    std = np.sqrt(channel_sum_sq / pixel_num - mean ** 2)
    
    # Convert to torch tensor (C, 1, 1)
    return mean, std

def label_from_filename(fname: str) -> str:
    return os.path.basename(fname).split("_")[0]


In [4]:
class HSDataset(Dataset):
    """
    Hyperspectral Dataset with Global Z-score Normalization
    Resizes to TARGET_HW (64, 64)
    """
    def __init__(self, img_dir, file_list=None, target_bands=125, target_hw=(64,64), 
                 augment=False, mean=None, std=None):
        self.img_dir = img_dir
        self.target_bands = target_bands
        self.target_hw = target_hw
        self.augment = augment
        
        # Normalization Stats
        self.mean = torch.tensor(mean).view(target_bands, 1, 1).float() if mean is not None else torch.zeros(target_bands, 1, 1)
        self.std = torch.tensor(std).view(target_bands, 1, 1).float() if std is not None else torch.ones(target_bands, 1, 1)

        if file_list is not None:
            self.files = file_list
        else:
            self.files = sorted([f for f in os.listdir(img_dir) if f.lower().endswith((".tif", ".tiff"))])
        
        # Label mapping
        labels = sorted({label_from_filename(f) for f in self.files})
        self.class_to_idx = {c: i for i, c in enumerate(labels)}
        self.idx_to_class = {i: c for c, i in self.class_to_idx.items()}
        self.y = [self.class_to_idx[label_from_filename(f)] for f in self.files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        label = self.y[idx]
        path = os.path.join(self.img_dir, fname)
        
        # Load
        arr = tiff.imread(path).astype(np.float32)
        
        # Ensure (C, H, W)
        if arr.ndim == 2: arr = arr[None, :, :]
        elif arr.ndim == 3 and (arr.shape[-1] == self.target_bands or arr.shape[-1] < arr.shape[0]):
             arr = np.transpose(arr, (2, 0, 1))
             
        # Fix bands
        c = arr.shape[0]
        if c > self.target_bands:
            arr = arr[:self.target_bands]
        elif c < self.target_bands:
            pad = np.zeros((self.target_bands - c, arr.shape[1], arr.shape[2]), dtype=np.float32)
            arr = np.concatenate([arr, pad], axis=0)

        x = torch.from_numpy(arr)
        
        # Resize
        if x.shape[1:] != self.target_hw:
            x = x.unsqueeze(0)
            x = F.interpolate(x, size=self.target_hw, mode="bilinear", align_corners=False)
            x = x.squeeze(0)

        # Normalize (Global Z-score)
        # Optional: Clip outlier pixels before norm? 
        # For baseline, standard (x-u)/s is fine. Could start by clipping to 0-1 range if raw data is huge.
        # Assuming raw data is uint16-like but loaded as float.
        x = (x - self.mean) / (self.std + 1e-8)
        
        # Augmentation
        if self.augment:
            if torch.rand(1) > 0.5: x = torch.flip(x, dims=[2])
            if torch.rand(1) > 0.5: x = torch.flip(x, dims=[1])
            k = torch.randint(0, 4, (1,)).item()
            x = torch.rot90(x, k, dims=[1, 2])

        return x, label


In [5]:
print("Calculating stats...")
mean_stats, std_stats = compute_global_stats(HS_DIR, bands=TARGET_BANDS)
print("Mean[0:5]:", mean_stats[:5])
print("Std [0:5]:", std_stats[:5])

# Split
all_files = sorted([f for f in os.listdir(HS_DIR) if f.endswith(('.tif', '.tiff'))])
labels = [label_from_filename(f) for f in all_files]
indices = np.arange(len(all_files))

train_idx, val_idx = train_test_split(
    indices, test_size=VAL_RATIO, random_state=SEED, stratify=labels
)

train_files = [all_files[i] for i in train_idx]
val_files = [all_files[i] for i in val_idx]

train_ds = HSDataset(HS_DIR, train_files, TARGET_BANDS, TARGET_HW, augment=True, 
                     mean=mean_stats, std=std_stats)
val_ds   = HSDataset(HS_DIR, val_files, TARGET_BANDS, TARGET_HW, augment=False, 
                     mean=mean_stats, std=std_stats)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")
xb, yb = next(iter(train_loader))
print("Batch shape:", xb.shape)


Calculating stats...
Computing global stats per band...


100%|██████████| 600/600 [00:00<00:00, 965.56it/s] 


Mean[0:5]: [264.61113607 320.74743815 348.97407389 366.44861979 378.22931315]
Std [0:5]: [330.00531389 347.41296516 352.80715128 355.75865957 360.99587943]
Train: 480 | Val: 120
Batch shape: torch.Size([32, 125, 64, 64])


In [6]:
num_classes = len(train_ds.class_to_idx)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Adjust first conv for 125 channels
old_conv = model.conv1
model.conv1 = nn.Conv2d(TARGET_BANDS, old_conv.out_channels, 
                        kernel_size=old_conv.kernel_size, stride=old_conv.stride, 
                        padding=old_conv.padding, bias=False)

# Init weights: average RGB weights and replicate
with torch.no_grad():
    model.conv1.weight[:] = old_conv.weight.mean(dim=1, keepdim=True).repeat(1, TARGET_BANDS, 1, 1)

model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)


In [7]:
wandb.init(project="beyond-visible-spectrum", name="baseline_hs125_fixed")

def train_one_epoch(loader):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        total_loss += loss.item() * x.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total

best_acc = 0.0
for epoch in range(1, EPOCHS+1):
    train_loss, train_acc = train_one_epoch(train_loader)
    val_loss, val_acc = evaluate(val_loader)
    scheduler.step(val_acc)
    
    print(f"Epoch {epoch} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    wandb.log({"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc})
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), CKPT_PATH)
        print(f"Saved best model: {val_acc:.4f}")

wandb.finish()


[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from C:\Users\ADMIN\_netrc.
[34m[1mwandb[0m: Currently logged in as: [33mphucga150625[0m ([33mphucga15062005[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1 | Train Acc: 0.4750 | Val Acc: 0.5000
Saved best model: 0.5000
Epoch 2 | Train Acc: 0.5896 | Val Acc: 0.4833
Epoch 3 | Train Acc: 0.5958 | Val Acc: 0.6167
Saved best model: 0.6167
Epoch 4 | Train Acc: 0.6000 | Val Acc: 0.5583
Epoch 5 | Train Acc: 0.6417 | Val Acc: 0.5417
Epoch 6 | Train Acc: 0.6500 | Val Acc: 0.5833
Epoch 7 | Train Acc: 0.6792 | Val Acc: 0.5417
Epoch 8 | Train Acc: 0.6771 | Val Acc: 0.5000
Epoch 9 | Train Acc: 0.6833 | Val Acc: 0.5583
Epoch 10 | Train Acc: 0.6896 | Val Acc: 0.6333
Saved best model: 0.6333


0,1
train_acc,▁▅▅▅▆▇████
train_loss,█▅▄▃▂▂▁▂▁▁
val_acc,▂▁▇▅▄▆▄▂▅█
val_loss,█▆▃▁▁▂▄▃▂▂

0,1
train_acc,0.68958
train_loss,0.65435
val_acc,0.63333
val_loss,0.82208


In [8]:
class HSTestDataset(Dataset):
    def __init__(self, img_dir, target_bands=125, target_hw=(64,64), mean=None, std=None):
        self.img_dir = img_dir
        self.target_bands = target_bands
        self.target_hw = target_hw
        self.mean = torch.tensor(mean).view(target_bands, 1, 1).float() if mean is not None else torch.zeros(target_bands, 1, 1)
        self.std = torch.tensor(std).view(target_bands, 1, 1).float() if std is not None else torch.ones(target_bands, 1, 1)
        self.files = sorted([f for f in os.listdir(img_dir) if f.endswith(('.tif', '.tiff'))])

    def __len__(self): return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        path = os.path.join(self.img_dir, fname)
        arr = tiff.imread(path).astype(np.float32)
        if arr.ndim == 2: arr = arr[None, :, :]
        elif arr.ndim == 3 and (arr.shape[-1] == self.target_bands or arr.shape[-1] < arr.shape[0]):
             arr = np.transpose(arr, (2, 0, 1))
        
        c = arr.shape[0]
        if c > self.target_bands: arr = arr[:self.target_bands]
        elif c < self.target_bands:
            pad = np.zeros((self.target_bands - c, arr.shape[1], arr.shape[2]), dtype=np.float32)
            arr = np.concatenate([arr, pad], axis=0)
            
        x = torch.from_numpy(arr)
        if x.shape[1:] != self.target_hw:
            x = x.unsqueeze(0)
            x = F.interpolate(x, size=self.target_hw, mode="bilinear", align_corners=False)
            x = x.squeeze(0)
            
        x = (x - self.mean) / (self.std + 1e-8)
        return x, fname

if os.path.exists(TEST_HS_DIR):
    model.load_state_dict(torch.load(CKPT_PATH))
    model.eval()
    
    test_ds = HSTestDataset(TEST_HS_DIR, TARGET_BANDS, TARGET_HW, mean=mean_stats, std=std_stats)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
    
    preds = []
    ids = []
    class_names = [train_ds.idx_to_class[i] for i in range(num_classes)]
    
    with torch.no_grad():
        for x, fname in tqdm(test_loader):
            x = x.to(device)
            out = model(x)
            p_idx = out.argmax(1).cpu().numpy()
            preds.extend([class_names[i] for i in p_idx])
            ids.extend(fname)
            
    import pandas as pd
    df = pd.DataFrame({"Id": ids, "Category": preds})
    df.to_csv(os.path.join(CHECKPOINT_DIR, "submission_hs.csv"), index=False)
    print("Saved submission_hs.csv")


100%|██████████| 10/10 [00:05<00:00,  1.95it/s]

Saved submission_hs.csv



