# imports

In [1]:
# 初始化整個實驗環境
import sys
sys.path.append("C:/Users/GAI/Desktop/Scott/NCA_Research")

from E4_PI_NCA.init_notebook_imports import *

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)
set_global_seed(1234)

✅ Environment initialized. Use show_env_info() to check details.
Using device: cuda
[INFO] Global seed set to 1234


# data process

## func

In [2]:
# ================================================================
# 輔助函式
# ================================================================
def nan_to_masked_CHW(arr):
    """將 NaN 轉成 0，並新增一個 geo_mask channel（1 表有效）"""
    mask = (~np.isnan(arr[0])).astype(np.float32)[np.newaxis, ...]
    arr = np.nan_to_num(arr, nan=0.0)
    return np.concatenate([mask, arr], axis=0)

def add_coord_channels(arr):
    """在最前面新增 normalized coord_y, coord_x"""
    _, H, W = arr.shape
    y = np.linspace(-1, 1, H)[..., None] * np.ones((1, W))
    x = np.linspace(-1, 1, W)[None, ...] * np.ones((H, 1))
    coords = np.stack([y, x], axis=0)
    return np.concatenate([coords, arr], axis=0)

def center_crop(arr, target_h, target_w):
    """以中心為基準裁切 (C, H, W)"""
    C, H, W = arr.shape
    start_h = (H - target_h) // 2
    start_w = (W - target_w) // 2
    return arr[:, start_h:start_h+target_h, start_w:start_w+target_w]

# process urbantales cases into npz file 

In [3]:
from pathlib import Path
import numpy as np
import xarray as xr
import torch
import torch.nn.functional as F

# ================================================================
# 基本設定
# ================================================================
folder = Path("../dataset")
output_path = folder / "all_cases_BCHW.npz"



# ================================================================
# 主流程：讀取與處理所有 case
# ================================================================
all_cases = []
case_names = []
channel_names_ref = None
from tqdm import tqdm
count=0
for ped_file in tqdm(list(folder.rglob("*ped.nc")), desc="Processing ped files"):
    case_name = ped_file.parent.name


    

    # ------------------------------------------------------------
    # 解析風向 (從資料夾名稱中的 _d 取得角度)
    # ------------------------------------------------------------
    wind_dir = float(case_name.split("_d")[-1])

    # ------------------------------------------------------------
    # 讀取 NetCDF
    # ------------------------------------------------------------
    with xr.open_dataset(ped_file) as ds:
        arrays = [ds[var].values for var in ds.data_vars]
        ped_np = np.stack(arrays, axis=-1)[::-1, :, :]  # (H, W, C)
        vars_names = list(ds.data_vars.keys())

    # ------------------------------------------------------------
    # 加入 wind direction (sin, cos)
    # ------------------------------------------------------------
    wind_np = np.zeros_like(ped_np[..., 0:2])
    wind_np[..., 0] = np.sin(np.deg2rad(wind_dir))
    wind_np[..., 1] = np.cos(np.deg2rad(wind_dir))

    # ------------------------------------------------------------
    # 加入 topo 資料（若存在）
    # ------------------------------------------------------------
    topo_file = next(ped_file.parent.glob("*_topo"), None)
    if topo_file:
        topo = np.loadtxt(topo_file)[:, :, np.newaxis]
        ped_np = np.concatenate([topo, wind_np, ped_np], axis=-1)
        channel_names = ["topo", "windInitX", "windInitY"] + vars_names
    else:
        ped_np = np.concatenate([wind_np, ped_np], axis=-1)
        channel_names = ["windInitX", "windInitY"] + vars_names

    # ------------------------------------------------------------
    # 後處理
    # ------------------------------------------------------------
    # HWC → CHW
    ped_np = np.transpose(ped_np, (2, 0, 1))
    # NaN 處理與 mask
    ped_np = nan_to_masked_CHW(ped_np)
    # 加上 coord channel
    ped_np = add_coord_channels(ped_np)
    # 更新 channel 名稱
    channel_names = ["coord_y", "coord_x", "geo_mask"] + channel_names

    # ------------------------------------------------------------
    # Resize / Crop
    # ------------------------------------------------------------
    ped_t = torch.from_numpy(ped_np).unsqueeze(0)  # (1, C, H, W)
    _, C, H, W = ped_t.shape
    # scale = 3
    # final_size = 96
    # ped_i = F.interpolate(ped_t, size=(H//scale, W//scale), mode="nearest")
    # ped_np = ped_i.squeeze(0).numpy()
    # ped_np = center_crop(ped_np, final_size,final_size)  # 保留最終尺寸一致
    # print(f"處理 {case_name} ... shape = {ped_t.shape}  | shape = {ped_i.shape}  | shape = {ped_np.shape}")


    ped_i = F.interpolate(ped_t, size=(256,256), mode="bicubic")
    ped_np = ped_i.squeeze(0).numpy()


    # ------------------------------------------------------------
    # 累積
    # ------------------------------------------------------------
    all_cases.append(ped_np)
    case_names.append(case_name)
    if channel_names_ref is None:
        channel_names_ref = channel_names
    
    # count+=1
    # if count>4:
    #     break

# ================================================================
# 統一合併為 BCHW
# ================================================================
data_BCHW = np.stack(all_cases, axis=0)  # (B, C, H, W)
print_tensor_stats(data_BCHW)
# ================================================================
# 存檔
# ================================================================
np.savez_compressed(
    output_path,
    data=data_BCHW,
    case_names=np.array(case_names),
    channel_names=np.array(channel_names_ref)
)

# ================================================================
# 驗證輸出
# ================================================================
print(f"\n✅ 已儲存：{output_path}")
print(f"共 {len(case_names)} 個 case")
print(f"資料形狀：{data_BCHW.shape} (B, C, H, W)")
print(" 通道名稱 channel:", channel_names_ref)
print("範例 case:", case_names[0])



Processing ped files: 100%|██████████| 72/72 [00:09<00:00,  7.81it/s]


Tensor Channel-wise stats (共 11 個 channel):
ch           min         q1       mean         q3        max
------------------------------------------------------------
0      -1.000184  -0.499782   0.000000   0.499782   1.000184
1      -1.001059  -0.500806  -0.000000   0.500806   1.001059
2       1.000000   1.000000   1.000000   1.000000   1.000000
3     -48.800720   0.000000   4.752388   2.923584 225.204889
4       0.000000   0.000000   0.366546   0.707107   1.000000
5      -1.000000   0.439705   0.642144   1.000000   1.000000
6      -2.027470  -0.066149   0.075328   0.184772   1.807298
7      -1.832599  -0.043820   0.059269   0.159883   1.698117
8      -0.254789   0.057104   0.311276   0.485398   2.165388
9      -0.062029   0.015168   0.053031   0.079356   0.728069
10     -0.190695  -0.003430  -0.001935   0.000000   0.276433

✅ 已儲存：..\dataset\all_cases_BCHW.npz
共 72 個 case
資料形狀：(72, 11, 256, 256) (B, C, H, W)
 通道名稱 channel: ['coord_y', 'coord_x', 'geo_mask', 'topo', 'windInitX', 'windI

'uped', 'vped' vel_ped 平均風速(不包含湍流) Uped 平均風速(包含湍流)  TKEped(湍流項)



# plot

## func

In [4]:
import re
def plot_BCHW_channels(
    data_BCHW,
    case_names,
    channel_names,
    max_channels: int = 9,
    filter_patterns: list = []
):
    """
    將 BCHW 資料繪製成每個 case 一行、每個 channel 一列的圖。
    支援正規表達式過濾樣本名稱。

    Args:
        data_BCHW : np.ndarray | torch.Tensor
            shape = (B, C, H, W)
        case_names : list[str] or np.ndarray
            對應每個樣本的名稱
        channel_names : list[str] or np.ndarray
            每個 channel 的名稱
        max_channels : int, optional
            每個樣本最多顯示多少 channel (default: 9)
        filter_patterns : list[str], optional
            要排除的樣本名稱 regex pattern
    """

    # ------------------------------------------------------------
    # 型態與資料檢查
    # ------------------------------------------------------------
    if data_BCHW.ndim != 4:
        raise ValueError(f"輸入資料需為 4 維 (B,C,H,W)，但得到 ndim={data_BCHW.ndim}")

    if isinstance(data_BCHW, torch.Tensor):
        data_BCHW = data_BCHW.detach().cpu().numpy()

    B, C, H, W = data_BCHW.shape

    case_names = np.array(case_names)
    channel_names = np.array(channel_names)

    # ------------------------------------------------------------
    # 過濾樣本 (使用 regex)
    # ------------------------------------------------------------
    valid_indices = [
        i for i, name in enumerate(case_names)
        if not any(re.match(pat, name) for pat in filter_patterns)
    ]
    data_BCHW = data_BCHW[valid_indices]
    case_names = case_names[valid_indices]

    n_samples = len(case_names)
    n_show = min(max_channels, C)

    print(f"繪製 {n_samples} 個樣本，每個顯示前 {n_show} 個 channel")

    # ------------------------------------------------------------
    # 建立 subplot
    # ------------------------------------------------------------
    ncols = n_show + 2  # case name + 原圖 + channels
    fig, axes = plt.subplots(nrows=n_samples, ncols=ncols, figsize=(3*ncols, 3*n_samples))

    if n_samples == 1:
        axes = np.expand_dims(axes, 0)
    if ncols == 1:
        axes = np.expand_dims(axes, 1)

    # ------------------------------------------------------------
    # 主繪圖迴圈
    # ------------------------------------------------------------
    for i, (case_name, chw) in enumerate(zip(case_names, data_BCHW)):
        HWC = np.transpose(chw, (1, 2, 0))

        # 第一欄：case 名稱
        axes[i][0].text(0.5, 0.5, str(case_name), ha="center", va="center",
                        fontsize=12, weight="bold")
        axes[i][0].axis("off")

        # 第二欄：前三 channel 合成原圖
        if C >= 3:
            axes[i][1].imshow(HWC[:, :, :3])
        else:
            axes[i][1].imshow(HWC[:, :, 0], cmap="gray")
        axes[i][1].set_title("原圖")
        axes[i][1].axis("off")

        # 其餘欄：各 channel
        for j in range(n_show):
            ch_data = chw[j]
            ax = axes[i][j + 2]
            im = ax.imshow(ch_data, cmap="jet")
            title = channel_names[j] if j < len(channel_names) else f"ch{j}"
            ax.set_title(title)
            ax.axis("off")

            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

            # 若通道為常數 → 顯示數值
            if np.all(ch_data == ch_data.flat[0]):
                ax.text(W // 2, H // 2, f"{ch_data.flat[0]:.3f}",
                        color="white", ha="center", va="center",
                        fontsize=10, weight="bold",
                        bbox=dict(facecolor="black", alpha=0.6, boxstyle="round,pad=0.2"))

    plt.tight_layout()
    plt.savefig("channels.png",dpi=300)
    plt.show()

## viz

In [5]:
# 範例呼叫
npz_path="../dataset/all_cases_BCHW.npz"
npz = np.load(npz_path, allow_pickle=True)
data_BCHW = npz["data"]
case_names = npz["case_names"]
channel_names = npz["channel_names"]
print_tensor_stats(data_BCHW)
plot_BCHW_channels(
    data_BCHW=data_BCHW,
    case_names=case_names,
    channel_names=channel_names,
    max_channels=11,
    filter_patterns=["^CN-"]
)
print("=== 0-1 scale 後 ===")

data_BCHW = minmax_scale_channelwise(data_BCHW)
print_tensor_stats(data_BCHW)

# 呼叫繪圖函式
plot_BCHW_channels(
    data_BCHW=data_BCHW,
    case_names=case_names,
    channel_names=channel_names,
    max_channels=11,
    filter_patterns=["^CN-"]
)


Tensor Channel-wise stats (共 11 個 channel):
ch           min         q1       mean         q3        max
------------------------------------------------------------
0      -1.000184  -0.499782   0.000000   0.499782   1.000184
1      -1.001059  -0.500806  -0.000000   0.500806   1.001059
2       1.000000   1.000000   1.000000   1.000000   1.000000
3     -48.800720   0.000000   4.752388   2.923584 225.204889
4       0.000000   0.000000   0.366546   0.707107   1.000000
5      -1.000000   0.439705   0.642144   1.000000   1.000000
6      -2.027470  -0.066149   0.075328   0.184772   1.807298
7      -1.832599  -0.043820   0.059269   0.159883   1.698117
8      -0.254789   0.057104   0.311276   0.485398   2.165388
9      -0.062029   0.015168   0.053031   0.079356   0.728069
10     -0.190695  -0.003430  -0.001935   0.000000   0.276433
繪製 58 個樣本，每個顯示前 11 個 channel


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.9994703642425048..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.9991396157522389..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.9991396157522389..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.9991396157522389..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.9991396157522389..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.9991396157522389..1.0].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Go

KeyboardInterrupt: 