In [1]:
import os
from pathlib import Path
import numpy as np
from collections import Counter
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
import tifffile as tiff
import torchvision.models as models
import wandb
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

# ===== PATHS =====
HS_DIR = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\Kaggle_Prepared\train\HS"
TEST_HS_DIR = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\Kaggle_Prepared\val\HS"
CHECKPOINT_DIR = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\checkpoints"
CKPT_PATH = os.path.join(CHECKPOINT_DIR, "best_hs_trimmed_resnet18.pth")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ===== TRIMMED DATA SETTINGS =====
ORIGINAL_BANDS = 125
START_BAND = 10      # Skip first 10
END_BAND = 125 - 14  # Skip last 14 (111)
TARGET_BANDS = END_BAND - START_BAND # 101
TARGET_HW = (64, 64) 

# ===== SPLIT =====
VAL_RATIO = 0.2
SEED = 42

# ===== TRAIN =====
EPOCHS =10
BATCH_SIZE = 32
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 0

print(f"Target Bands: {TARGET_BANDS} (Indices {START_BAND} to {END_BAND})")


Target Bands: 101 (Indices 10 to 111)


In [2]:
def compute_trimmed_stats(img_dir):
    print("Computing global stats for TRIMMED bands...")
    files = sorted([f for f in os.listdir(img_dir) if f.endswith(('.tif', '.tiff'))])
    
    pixel_num = 0
    channel_sum = np.zeros(TARGET_BANDS, dtype=np.float64)
    channel_sum_sq = np.zeros(TARGET_BANDS, dtype=np.float64)

    for f in tqdm(files):
        path = os.path.join(img_dir, f)
        # Load
        img = tiff.imread(path).astype(np.float32)
        
        # Strictly check for HWC format based on known Band count (125)
        # If shape is (H, W, 125), transpose to (125, H, W)
        if img.ndim == 3 and img.shape[-1] == ORIGINAL_BANDS:
             img = np.transpose(img, (2, 0, 1))
        elif img.ndim == 2:
             img = img[None, :, :]
        
        # Pad if short (rare) or Trim
        if img.shape[0] < ORIGINAL_BANDS:
             pad = np.zeros((ORIGINAL_BANDS - img.shape[0], img.shape[1], img.shape[2]), dtype=np.float32)
             img = np.concatenate([img, pad], axis=0)
        
        # === TRIMMING ===
        # Slice [10:111]
        img = img[START_BAND:END_BAND, :, :]

        pixels = img.reshape(TARGET_BANDS, -1)
        
        channel_sum += pixels.sum(axis=1)
        channel_sum_sq += (pixels ** 2).sum(axis=1)
        pixel_num += pixels.shape[1]
    
    mean = channel_sum / pixel_num
    std = np.sqrt(channel_sum_sq / pixel_num - mean ** 2)
    
    return mean, std

def label_from_filename(fname: str) -> str:
    return os.path.basename(fname).split("_")[0]


In [3]:
class HSTrimmedDataset(Dataset):
    def __init__(self, img_dir, file_list=None, augment=False, mean=None, std=None):
        self.img_dir = img_dir
        self.augment = augment
        self.mean = torch.tensor(mean).view(TARGET_BANDS, 1, 1).float() if mean is not None else torch.zeros(TARGET_BANDS, 1, 1)
        self.std = torch.tensor(std).view(TARGET_BANDS, 1, 1).float() if std is not None else torch.ones(TARGET_BANDS, 1, 1)

        if file_list is not None:
            self.files = file_list
        else:
            self.files = sorted([f for f in os.listdir(img_dir) if f.lower().endswith(('.tif', '.tiff'))])
        
        labels = sorted({label_from_filename(f) for f in self.files})
        self.class_to_idx = {c: i for i, c in enumerate(labels)}
        self.idx_to_class = {i: c for c, i in self.class_to_idx.items()}
        self.y = [self.class_to_idx[label_from_filename(f)] for f in self.files]

    def __len__(self): return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        path = os.path.join(self.img_dir, fname)
        arr = tiff.imread(path).astype(np.float32)
        
        # Format fix: Check strict band location
        if arr.ndim == 3 and arr.shape[-1] == ORIGINAL_BANDS:
             arr = np.transpose(arr, (2, 0, 1))
        elif arr.ndim == 2: 
             arr = arr[None, :, :]
        
        # Pad to 125 if needed (robustness)
        if arr.shape[0] < ORIGINAL_BANDS:
             pad = np.zeros((ORIGINAL_BANDS - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=np.float32)
             arr = np.concatenate([arr, pad], axis=0)
            
        # === TRIM ===
        arr = arr[START_BAND:END_BAND, :, :]

        x = torch.from_numpy(arr)
        
        # Resize to 64x64
        if x.shape[1:] != TARGET_HW:
            x = x.unsqueeze(0)
            x = F.interpolate(x, size=TARGET_HW, mode="bilinear", align_corners=False)
            x = x.squeeze(0)

        # Normalize
        x = (x - self.mean) / (self.std + 1e-8)
        
        if self.augment:
            if torch.rand(1) > 0.5: x = torch.flip(x, dims=[2])
            if torch.rand(1) > 0.5: x = torch.flip(x, dims=[1])
            k = torch.randint(0, 4, (1,)).item()
            x = torch.rot90(x, k, dims=[1, 2])

        return x, self.y[idx]

class HSTrimmedTestDataset(Dataset):
    def __init__(self, img_dir, mean=None, std=None):
        self.img_dir = img_dir
        self.files = sorted([f for f in os.listdir(img_dir) if f.lower().endswith(('.tif', '.tiff'))])
        self.mean = torch.tensor(mean).view(TARGET_BANDS, 1, 1).float() if mean is not None else torch.zeros(TARGET_BANDS, 1, 1)
        self.std = torch.tensor(std).view(TARGET_BANDS, 1, 1).float() if std is not None else torch.ones(TARGET_BANDS, 1, 1)
    
    def __len__(self): return len(self.files)
    
    def __getitem__(self, idx):
        fname = self.files[idx]
        path = os.path.join(self.img_dir, fname)
        arr = tiff.imread(path).astype(np.float32)
        
        if arr.ndim == 3 and arr.shape[-1] == ORIGINAL_BANDS:
             arr = np.transpose(arr, (2, 0, 1))
        elif arr.ndim == 2: 
             arr = arr[None, :, :]
        
        if arr.shape[0] < ORIGINAL_BANDS:
             pad = np.zeros((ORIGINAL_BANDS - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=np.float32)
             arr = np.concatenate([arr, pad], axis=0)
             
        arr = arr[START_BAND:END_BAND, :, :]
        x = torch.from_numpy(arr)
        
        if x.shape[1:] != TARGET_HW:
            x = x.unsqueeze(0)
            x = F.interpolate(x, size=TARGET_HW, mode="bilinear", align_corners=False)
            x = x.squeeze(0)
            
        x = (x - self.mean) / (self.std + 1e-8)
        return x, fname


In [4]:
mean_stats, std_stats = compute_trimmed_stats(HS_DIR)
print("Trimmed Stats Shape:", mean_stats.shape)

all_files = sorted([f for f in os.listdir(HS_DIR) if f.endswith(('.tif', '.tiff'))])
labels = [label_from_filename(f) for f in all_files]
indices = np.arange(len(all_files))

train_idx, val_idx = train_test_split(
    indices, test_size=VAL_RATIO, random_state=SEED, stratify=labels
)

train_files = [all_files[i] for i in train_idx]
val_files = [all_files[i] for i in val_idx]

train_ds = HSTrimmedDataset(HS_DIR, train_files, augment=True, mean=mean_stats, std=std_stats)
val_ds   = HSTrimmedDataset(HS_DIR, val_files, augment=False, mean=mean_stats, std=std_stats)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f"Train samples: {len(train_ds)}")
x, y = next(iter(train_loader))
print("Batch Shape:", x.shape) # Should be (B, 101, 64, 64)


Computing global stats for TRIMMED bands...


100%|██████████| 600/600 [00:01<00:00, 453.97it/s]


Trimmed Stats Shape: (101,)
Train samples: 480
Batch Shape: torch.Size([32, 101, 64, 64])


In [5]:
num_classes = len(train_ds.class_to_idx)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Modify conv1 for 101 channels
old_conv = model.conv1
model.conv1 = nn.Conv2d(TARGET_BANDS, old_conv.out_channels, 
                        kernel_size=old_conv.kernel_size, stride=old_conv.stride, 
                        padding=old_conv.padding, bias=False)

# Init weights smartly: take mean of RGB weights and replicate 101 times
with torch.no_grad():
    model.conv1.weight[:] = old_conv.weight.mean(dim=1, keepdim=True).repeat(1, TARGET_BANDS, 1, 1)

model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)


In [6]:
wandb.init(project="beyond-visible-spectrum", name="baseline_hs_trimmed")

def train_one_epoch(loader):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        total_loss += loss.item() * x.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total

best_acc = 0.0
for epoch in range(1, EPOCHS+1):
    train_loss, train_acc = train_one_epoch(train_loader)
    val_loss, val_acc = evaluate(val_loader)
    scheduler.step(val_acc)
    
    print(f"Epoch {epoch} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    wandb.log({"train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc})
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), CKPT_PATH)
        print(f"Saved best model: {val_acc:.4f}")

wandb.finish()


[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from C:\Users\ADMIN\_netrc.
[34m[1mwandb[0m: Currently logged in as: [33mphucga150625[0m ([33mphucga15062005[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1 | Train Acc: 0.4688 | Val Acc: 0.4583
Saved best model: 0.4583
Epoch 2 | Train Acc: 0.5667 | Val Acc: 0.5000
Saved best model: 0.5000
Epoch 3 | Train Acc: 0.5958 | Val Acc: 0.5750
Saved best model: 0.5750
Epoch 4 | Train Acc: 0.6521 | Val Acc: 0.5167
Epoch 5 | Train Acc: 0.6562 | Val Acc: 0.5750
Epoch 6 | Train Acc: 0.6604 | Val Acc: 0.5333
Epoch 7 | Train Acc: 0.6896 | Val Acc: 0.5667
Epoch 8 | Train Acc: 0.7396 | Val Acc: 0.6000
Saved best model: 0.6000
Epoch 9 | Train Acc: 0.7479 | Val Acc: 0.6083
Saved best model: 0.6083
Epoch 10 | Train Acc: 0.7417 | Val Acc: 0.5750


0,1
train_acc,▁▃▄▆▆▆▇███
train_loss,█▆▅▄▃▃▂▂▁▁
val_acc,▁▃▆▄▆▅▆██▆
val_loss,█▆▃█▃▂▄▁▁▂

0,1
train_acc,0.74167
train_loss,0.60042
val_acc,0.575
val_loss,0.87805


In [7]:
if os.path.exists(TEST_HS_DIR):
    print("Running Inference on Test Set...")
    model.load_state_dict(torch.load(CKPT_PATH))
    model.eval()
    
    test_ds = HSTrimmedTestDataset(TEST_HS_DIR, mean=mean_stats, std=std_stats)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
    
    preds = []
    ids = []
    class_names = [train_ds.idx_to_class[i] for i in range(num_classes)]
    
    with torch.no_grad():
        for x, fname in tqdm(test_loader):
            x = x.to(device)
            out = model(x)
            p_idx = out.argmax(1).cpu().numpy()
            preds.extend([class_names[i] for i in p_idx])
            ids.extend(fname)
            
    import pandas as pd
    df = pd.DataFrame({"Id": ids, "Category": preds})
    csv_path = os.path.join(CHECKPOINT_DIR, "submission_hs_trimmed.csv")
    df.to_csv(csv_path, index=False)
    print(f"Saved {csv_path}")


Running Inference on Test Set...


100%|██████████| 10/10 [00:01<00:00,  5.37it/s]

Saved D:\HocTap\NCKH_ThayDoNhuTai\Challenges\checkpoints\submission_hs_trimmed.csv



