In [4]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# =========================
# 0) CONFIG - SỬA CHỖ NÀY
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

RGB_DIR = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\train\RGB"
MS_DIR  = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\train\MS"
HS_DIR  = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\data\raw\train\HS"

CKPT_RGB = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\checkpoints\best_rgb_resnet18.pth"
CKPT_MS  = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\checkpoints\best_ms_resnet18.pth"
CKPT_HS  = r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\checkpoints\best_hs_topK20_resnet18.pth"  # hoặc hs-full/hs-gated

BATCH_SIZE = 32
NUM_WORKERS = 2

# Val file list (quan trọng: phải giống nhau cho 3 modality)
# Ví dụ bạn đã lưu:
val_files = np.load(r"D:\HocTap\NCKH_ThayDoNhuTai\Challenges\Notebooks\split\splits\val_idx.npy", allow_pickle=True).tolist()

# Nếu bạn chỉ có val_idx.npy thì bạn phải map ra files từ base dataset trước.
# Ở đây assume bạn đã có val_files list.
assert isinstance(val_files, list) and len(val_files) > 0, "Bạn cần val_files (list filenames) để align 3 loader."

In [5]:
# =========================
# 1) LOAD MODEL HELPER
# =========================
def load_checkpoint_into(model, ckpt_path, device):
    ckpt = torch.load(ckpt_path, map_location=device)
    state_dict = ckpt["model_state"] if isinstance(ckpt, dict) and "model_state" in ckpt else ckpt
    model.load_state_dict(state_dict, strict=True)
    model.to(device).eval()
    return model


In [6]:
# =========================
# 2) BUILD MODELS (ResNet18 ví dụ)
#    Nếu bạn dùng ConvNeXt/timm thì đổi phần này.
# =========================
import torchvision.models as models

def build_resnet18(in_ch, num_classes=3, pretrained=True):
    m = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
    old = m.conv1
    m.conv1 = nn.Conv2d(in_ch, old.out_channels, kernel_size=old.kernel_size,
                        stride=old.stride, padding=old.padding, bias=False)
    if pretrained:
        with torch.no_grad():
            w_mean = old.weight.mean(dim=1, keepdim=True)
            m.conv1.weight.copy_(w_mean.repeat(1, in_ch, 1, 1))
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

# SỬA đúng số kênh:
model_rgb = build_resnet18(in_ch=3,   num_classes=3, pretrained=True)
model_ms  = build_resnet18(in_ch=5,   num_classes=3, pretrained=True)
model_hs  = build_resnet18(in_ch=20,  num_classes=3, pretrained=True)  # HS-top20
# nếu HS-full thì in_ch=125

model_rgb = load_checkpoint_into(model_rgb, CKPT_RGB, device)
model_ms  = load_checkpoint_into(model_ms,  CKPT_MS,  device)
model_hs  = load_checkpoint_into(model_hs,  CKPT_HS,  device)

In [7]:
# =========================
# 3) DATASETS / LOADERS (aligned)
#    Bạn cần 3 dataset class: RGBDataset, MSDataset, HSDataset
#    - RGBDataset: đọc PNG RGB
#    - MSDataset : đọc MS tif 5 kênh
#    - HSDataset : class bạn đã có (trả 125,H,W) -> cần wrapper select top20 để ra 20,H,W
# =========================

# --- HS-top20 wrapper (cắt band sau normalize) ---
class SelectBandsDataset(torch.utils.data.Dataset):
    def __init__(self, base_ds, band_idx):
        self.base = base_ds
        self.band_idx = torch.tensor(band_idx, dtype=torch.long)
        self.files = getattr(base_ds, "files", None)
        self.y = getattr(base_ds, "y", None)
        self.idx_to_class = getattr(base_ds, "idx_to_class", None)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]  # (125,H,W) đã normalize
        x = x.index_select(0, self.band_idx)  # (K,H,W)
        return x, y

TOP20 = [7, 32, 43, 48, 50, 58, 72, 84, 92, 97, 98, 99, 101, 105, 110, 111, 112, 114, 117, 122]

# --- Tạo dataset theo val_files ---
# Bạn thay transform/args đúng theo code của bạn.
val_ds_rgb = RGBDataset(RGB_DIR, file_list=val_files, transform=tfm_val_rgb)
val_ds_ms  = MSDataset(MS_DIR,   file_list=val_files, augment=False, mean=mean_ms, std=std_ms)

val_ds_hs_full = HSDataset(HS_DIR, file_list=val_files, target_bands=125, target_hw=(64,64),
                           augment=False, mean=mean_hs, std=std_hs)
val_ds_hs = SelectBandsDataset(val_ds_hs_full, TOP20)  # => (20,H,W)

val_loader_rgb = DataLoader(val_ds_rgb, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
val_loader_ms  = DataLoader(val_ds_ms,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
val_loader_hs  = DataLoader(val_ds_hs,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Quick check alignment
xb, yb = next(iter(val_loader_hs))
print("HS batch shape:", xb.shape)  # [B,20,H,W]

# =========================
# 4) FUSION EVAL (avg logits)
# =========================
@torch.no_grad()
def fusion_eval(model_rgb, model_ms, model_hs, loader_rgb, loader_ms, loader_hs, device, weights=(1.0,1.0,1.0)):
    model_rgb.eval(); model_ms.eval(); model_hs.eval()
    wr, wm, wh = weights

    all_preds, all_labels = [], []

    for (xr, y), (xm, y2), (xh, y3) in zip(loader_rgb, loader_ms, loader_hs):
        # đảm bảo align
        if not (torch.equal(y, y2) and torch.equal(y, y3)):
            raise RuntimeError("Loader mismatch: labels khác nhau -> file_list/order không aligned!")

        xr = xr.to(device, non_blocking=True)
        xm = xm.to(device, non_blocking=True)
        xh = xh.to(device, non_blocking=True)
        y  = y.to(device, non_blocking=True)

        lr = model_rgb(xr)  # logits
        lm = model_ms(xm)
        lh = model_hs(xh)

        logits = wr*lr + wm*lm + wh*lh
        preds = logits.argmax(1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(y.cpu().numpy())

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_labels)

    acc = accuracy_score(y_true, y_pred)
    f1m = f1_score(y_true, y_pred, average="macro")
    cm  = confusion_matrix(y_true, y_pred)
    return acc, f1m, cm, y_true, y_pred

acc, f1m, cm, y_true, y_pred = fusion_eval(
    model_rgb, model_ms, model_hs,
    val_loader_rgb, val_loader_ms, val_loader_hs,
    device,
    weights=(1.0, 1.0, 1.0)  # thử (0.8,1.2,1.0) nếu MS mạnh hơn
)

print("=== FUSION A (Avg Logits) ===")
print(f"Acc: {acc:.4f} | F1-macro: {f1m:.4f}")
print("Confusion Matrix:\n", cm)

# Tên lớp nếu có
if hasattr(val_ds_rgb, "idx_to_class"):
    target_names = [val_ds_rgb.idx_to_class[i] for i in range(3)]
    print("\nClassification Report:\n")
    print(classification_report(y_true, y_pred, target_names=target_names, digits=4))
else:
    print("\nClassification Report:\n")
    print(classification_report(y_true, y_pred, digits=4))


NameError: name 'RGBDataset' is not defined