In [None]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from model.HierNet import WSIHierNet
from dataset import prepare_dataset
from utils import *
from types import SimpleNamespace
import yaml
from eval import evaluator
import deepspeed


def run_test(cfg):
    """Evaluate the saved "model-best" checkpoint on the test split.

    Builds a ``WSIHierNet`` from ``cfg``, wraps it in a DeepSpeed engine,
    restores the checkpoint tagged ``"model-best"`` from ``cfg['save_path']``,
    runs inference over the test split and prints the resulting C-index.

    Args:
        cfg: Dict-like configuration. Keys read here: ``dims``,
            ``magnification``, ``dropout``, ``pool``, ``join``, ``fusion``,
            ``emb_x20_backbone``, ``emb_x20_dw_conv``, ``emb_x20_ksize``,
            ``emb_x5_backbone``, ``emb_x5_ksize``, ``tra_backbone``,
            ``tra_ksize``, ``tra_dw_conv``, ``tra_nhead``, ``tra_num_layers``,
            ``tra_epsilon``, ``save_path``, ``path_data_split``,
            ``seed_data_split``, ``batch_size`` and ``num_workers``.
            ``ds_config_path`` is optional and overrides the default
            DeepSpeed JSON config location.

    Raises:
        RuntimeError: If the DeepSpeed checkpoint could not be loaded.

    NOTE(review): a recorded run of this script ended with
    ``ModuleNotFoundError: No module named 'mpi4py'``. Presumably this comes
    from DeepSpeed's MPI discovery when the script is not started through a
    distributed launcher -- confirm against the installed DeepSpeed version.
    """
    # === Parse layer dimensions and the magnification ratio ===
    dims = [int(x) for x in cfg['dims'].split('-')]

    print(cfg)
    scales = list(map(int, cfg['magnification'].split('-')))
    # Ratio between the second and the first magnification, truncated to int.
    scale = int(scales[1] / scales[0])
    print(f"Scale for magnifications {scales} is {scale}")

    # === Sub-module configurations ===
    cfg_x20_emb = SimpleNamespace(
        backbone=cfg['emb_x20_backbone'], in_dim=dims[0], out_dim=dims[1],
        scale=scale, dropout=cfg['dropout'], dw_conv=cfg['emb_x20_dw_conv'], ksize=cfg['emb_x20_ksize']
    )
    cfg_x5_emb = SimpleNamespace(
        backbone=cfg['emb_x5_backbone'], in_dim=dims[0], out_dim=dims[1],
        scale=1, dropout=cfg['dropout'], dw_conv=False, ksize=cfg['emb_x5_ksize']
    )
    cfg_tra_backbone = SimpleNamespace(
        backbone=cfg['tra_backbone'], ksize=cfg['tra_ksize'], dw_conv=cfg['tra_dw_conv'],
        d_model=dims[1], d_out=dims[2], nhead=cfg['tra_nhead'],
        dropout=cfg['dropout'], num_layers=cfg['tra_num_layers'], epsilon=cfg['tra_epsilon']
    )

    model = WSIHierNet(dims, cfg_x20_emb, cfg_x5_emb, cfg_tra_backbone,
                       dropout=cfg['dropout'], pool=cfg['pool'], join=cfg['join'], fusion=cfg['fusion'])

    # === DeepSpeed initialisation ===
    # The config location may be overridden through cfg; the previous
    # hard-coded path remains the default so existing configs keep working.
    ds_config_path = cfg.get(
        'ds_config_path',
        '/work/u6658716/TCGA-LUAD/DSCA/config/ds_config.json',
    )
    model_engine, _, _, _ = deepspeed.initialize(
        model=model,
        config_params=ds_config_path,
        dist_init_required=False
    )

    # === Load the best model (DeepSpeed checkpoint format) ===
    load_path = cfg['save_path']
    print(f"[INFO] Loading DeepSpeed checkpoint from {load_path}")
    loaded, _ = model_engine.load_checkpoint(load_path, tag="model-best")
    # A falsy first return value is treated as a failed load.
    if not loaded:
        raise RuntimeError("[ERROR] Failed to load DeepSpeed checkpoint")

    model_engine.eval()

    # === Prepare the test data ===
    path_split = cfg['path_data_split'].format(cfg['seed_data_split'])
    _, _, pids_test = read_datasplit_npz(path_split)
    test_set = prepare_dataset(pids_test, cfg, cfg['magnification'])
    test_loader = DataLoader(test_set, batch_size=cfg['batch_size'], num_workers=cfg['num_workers'], pin_memory=True)

    # === Inference ===
    result = {'y': None, 'y_hat': None}
    target = model_engine.local_rank
    with torch.no_grad():
        for fx, fx5, cx5, y in test_loader:
            # Move every tensor of the batch to the engine's local rank.
            fx, fx5, cx5, y = (t.to(target) for t in (fx, fx5, cx5, y))
            y_hat = model_engine(fx, fx5, cx5)
            result = collect_tensor(result, y.detach().cpu(), y_hat.detach().cpu())

    y_true, y_pred = result['y'], result['y_hat']
    c_index = evaluator(y_true, y_pred, metrics='cindex')
    print(f"[RESULT] Test C-index: {c_index:.4f}")


def get_config(config_path="config/config.yml"):
    """Load a YAML configuration file and return the parsed object.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the file's content.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as setting:
        # NOTE(security): FullLoader can construct more than plain data
        # types. If config files may come from untrusted sources, consider
        # switching to yaml.safe_load after checking that the configs use
        # no Python-specific tags.
        config = yaml.load(setting, Loader=yaml.FullLoader)
    return config


if __name__ == "__main__":
    # Script entry point: read the multi-scale config, then evaluate the
    # saved checkpoint on the test split.
    config = get_config(
        config_path='/work/u6658716/TCGA-LUAD/DSCA/config/config_ms.yaml',
    )
    run_test(cfg=config)


{'task': 'HierSurv', 'experiment': 'sim', 'seed_data_split': 0, 'path_data_split': './data_split/tcga_luad_merged/tcga_luad_merged-seed42-fold{}.npz', 'csv_path': '/work/u6658716/TCGA-LUAD/DSCA/data_split/tcga_luad_merged/tcga_luad_merged_path_full.csv', 'h5_dir': '/work/u6658716/TCGA-LUAD/PATCHES/LUAD/tiles-10x-s224', 'slide_dir': '/work/u6658716/TCGA-LUAD/DATASETS/TCGA/LUAD', 'ckpt_path': '/work/u6658716/TCGA-LUAD/CLAM/checkpoints/conch/pytorch_model.bin', 'lora_checkpoint': None, 'target_patch_size': 224, 'dataset_name': 'tcga_luad_merged', 'magnification': '5-10', 'path_patchx20': '/work/u6658716/TCGA-LUAD/PATCHES/LUAD/tiles-5x-s224/tiles-10x-s224/feats-RN50-B-s224/pt_files', 'path_patchx5': '/work/u6658716/TCGA-LUAD/PATCHES/LUAD/tiles-5x-s224/feats-RN50-B-s224/pt_files', 'path_coordx5': '/work/u6658716/TCGA-LUAD/PATCHES/LUAD/tiles-5x-s224/patches', 'path_label': './data_split/tcga_luad_merged/tcga_luad_merged_path_full.csv', 'label_discrete': False, 'bins_discrete': 4, 'feat_forma

ModuleNotFoundError: No module named 'mpi4py'