In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D   # noqa: F401

# Per-transmitter channel frequency response H_unit comes from an earlier
# computation; the reference call is kept here for provenance. The `# (T, F)`
# comments on the sums below suggest H_unit is indexed
# (tx, time step, subcarrier) — TODO confirm.
# H_unit = paths.cfr(
#     frequencies         = freqs,
#     sampling_frequency  = 1/ofdm_symbol_duration,
#     num_time_steps      = num_ofdm_symbols,
#     normalize_delays    = False,
#     normalize           = False,
#     out_type            = "numpy"
#     ).squeeze()
# print("H_unit.shape", H_unit.shape)

# Aggregate channels over all / desired / jammer transmitters.
# NOTE(review): these sums use H_unit *before* the transmit-power scaling
# below, so they carry no Tx power weighting, while the delay-Doppler plots
# later in this cell use the power-scaled H_unit. Confirm this is intended.
H_all = H_unit.sum(axis=0)
H_des = H_unit[idx_des].sum(axis=0)   # (T, F)
H_jam = H_unit[idx_jam].sum(axis=0)   # (T, F)

# Transmit powers: dBm -> mW via 10**(dBm/10), then mW -> W via /1e3.
tx_p_lin = 10**(np.array([tx.power_dbm for tx in all_txs]) / 10) / 1e3
# Drop singleton axes (e.g. per-Tx values stored as shape (1,)), but keep at
# least one axis. With a single transmitter, np.squeeze alone yields a 0-d
# array, and the [:, None, None] indexing below would raise IndexError.
tx_p_lin = np.atleast_1d(np.squeeze(tx_p_lin))
print("tx_p_lin.shape", tx_p_lin.shape)

# Amplitude scale per Tx. Assuming one power value per Tx, this has shape
# (num_tx, 1, 1) and broadcasts over the trailing (T, F) axes of H_unit.
sqrtP = np.sqrt(tx_p_lin)[:, None, None]
print("sqrtP.shape", sqrtP.shape)

# NOTE(review): this rebinds H_unit to the scaled version. Because the
# generating call above is commented out, re-running this cell applies the
# scaling again on top of the already-scaled array.
H_unit = H_unit * sqrtP
print("H_unit.shape", H_unit.shape)



# 11) Compute the delay-Doppler representation
def to_delay_doppler(H_tf):
    """Transform a time-frequency channel response into the delay-Doppler domain.

    Parameters
    ----------
    H_tf : np.ndarray
        Complex array whose last two axes are (time step, subcarrier). Any
        leading axes are treated as a batch and transformed independently.
        The subcarrier axis is assumed to be ordered by centred frequency
        (most negative first, DC in the middle) — TODO confirm against the
        frequency grid used to build it.

    Returns
    -------
    np.ndarray
        Complex array with the same shape as ``H_tf``.
        - Last axis: delay bins, with delay 0 at index 0.
        - Second-to-last axis: Doppler bins after ``fftshift``, with zero
          Doppler at index ``T // 2``.
        Both transforms use ``norm="ortho"``, so energy is preserved.
    """
    # Move DC to index 0 before the IFFT. ifftshift is the exact inverse of
    # the centred ordering. fftshift (used previously) differs only for an
    # odd number of subcarriers, where it leaves a per-bin phase error in
    # the delay response; magnitudes are unaffected either way.
    Hf = np.fft.ifftshift(H_tf, axes=-1)
    h_delay = np.fft.ifft(Hf, axis=-1, norm="ortho")      # frequency -> delay
    h_dd = np.fft.fft(h_delay, axis=-2, norm="ortho")     # time -> Doppler
    return np.fft.fftshift(h_dd, axes=-2)                 # centre zero Doppler
# ========= 1) Per-Tx delay-Doppler transform =========
# One complex (T, F) array per transmitter; list index = Tx index in H_unit.
Hdd_complex = [to_delay_doppler(H_unit[i]) for i in range(H_unit.shape[0])]

# ========= 2) Per-Tx DD magnitude =========
# Derived from the transform above. Previously every FFT was computed twice
# and the first (complex) result was thrown away.
Hdd_list = [np.abs(h) for h in Hdd_complex]

# ========= 3) Build the individual and composite grids =========
grids = []
labels = []

# How the "ALL" composites combine transmitters:
#   "magnitude" - sum of per-Tx magnitudes, sum_i |DD_i| (original behaviour)
#   "coherent"  - magnitude of the summed response, |sum_i DD_i|
#   "power"     - root of summed powers, sqrt(sum_i |DD_i|^2)
# NOTE(review): "magnitude" is neither the coherent composite channel nor a
# power sum; choose the mode that matches what the plot is meant to show.
COMPOSITE_MODE = "magnitude"


def _composite(indices):
    """Combine the DD responses of the Tx in `indices` according to COMPOSITE_MODE."""
    if COMPOSITE_MODE == "magnitude":
        return np.sum([Hdd_list[i] for i in indices], axis=0)
    if COMPOSITE_MODE == "coherent":
        return np.abs(np.sum([Hdd_complex[i] for i in indices], axis=0))
    if COMPOSITE_MODE == "power":
        return np.sqrt(np.sum([Hdd_list[i] ** 2 for i in indices], axis=0))
    raise ValueError(f"unknown COMPOSITE_MODE: {COMPOSITE_MODE!r}")


# Axis values matching to_delay_doppler's layout:
# - rows are Doppler bins, with zero Doppler at index T // 2 (fftshift);
# - columns are delay bins starting at 0.
# Both axes are built from integer bin indices so each has exactly T / F
# entries; np.arange with a float step can gain or lose an element to rounding.
doppler_bins = (np.arange(num_ofdm_symbols) - num_ofdm_symbols // 2) * doppler_resolution
delay_bins = np.arange(num_subcarriers) * delay_resolution / 1e-9   # seconds -> ns
x, y = np.meshgrid(delay_bins, doppler_bins)   # both (T, F): x = delay, y = Doppler

# Zoom window: 2*offset Doppler rows centred on zero Doppler, plus the first
# `offset` delay bins.
# - x_start/x_end index rows (Doppler, length T), so the centre comes from
#   num_ofdm_symbols. It previously used num_subcarriers, which is only
#   correct when T == F.
# - y_start/y_end index columns (delay, length F).
offset = 20
x_start = max(int(num_ofdm_symbols) // 2 - offset, 0)   # clamp: a negative start would wrap around
x_end = int(num_ofdm_symbols) // 2 + offset
y_start = 0
y_end = offset
window = (slice(x_start, x_end), slice(y_start, y_end))
x_grid = x[window]
y_grid = y[window]

# --- Desired, per Tx ---
for i in idx_des:
    grids.append(Hdd_list[i][window])   # (2*offset, offset) zoom window
    labels.append(f"Des Tx{i}")

# --- Jammer, per Tx ---
for i in idx_jam:
    grids.append(Hdd_list[i][window])
    labels.append(f"Jam Tx{i}")

# --- Desired, all Tx combined ---
# Use len() rather than truthiness so idx_des may be a list or a NumPy array.
if len(idx_des) > 0:
    Z_des_all = _composite(idx_des)
    grids.append(Z_des_all[window])
    labels.append("Des ALL")

# --- Jammer, all Tx combined ---
if len(idx_jam) > 0:
    Z_jam_all = _composite(idx_jam)
    grids.append(Z_jam_all[window])
    labels.append("Jam ALL")

# --- All transmitters ---
Z_all = _composite(range(len(Hdd_list)))
grids.append(Z_all[window])
labels.append("ALL Tx")

# ========= 4) Shared Z range for every panel =========
# All panels use the same z-limits so magnitudes are visually comparable;
# 5 % headroom above the largest value in any grid.
z_min = 0
z_max = 1.05 * max(map(np.max, grids))

# ========= 5) Grid layout: at most 3 panels per row =========
n_plots = len(grids)
cols = 3
rows = -(-n_plots // cols)            # ceiling division
figsize = (4.5 * cols, 4.5 * rows)

fig = plt.figure(figsize=figsize)

# One 3-D surface per grid, drawn over the shared (delay, Doppler) mesh.
for panel, (Z, label) in enumerate(zip(grids, labels)):
    ax = fig.add_subplot(rows, cols, panel + 1, projection='3d')
    ax.plot_surface(x_grid, y_grid, Z, cmap='viridis', edgecolor='none')
    ax.set_title(f"Delay–Doppler |{label}|", pad=8)
    ax.set_xlabel("Delay (ns)")
    ax.set_ylabel("Doppler (Hz)")
    ax.set_zlabel("|H|")
    ax.set_zlim(z_min, z_max)
    # Optional fixed viewpoint: ax.view_init(elev=53, azim=-32)

plt.tight_layout()
plt.show()
