In [11]:
import torch, os, math
import numpy as np
import matplotlib.pyplot as plt

# Render a temporal reuse mask stored in a .pt file as (1) one grid image and
# (2) one PNG per frame.

pt_path = "temporal_reuse_mask_frames_ep08.pt"  # path to your .pt file
out_dir = "mask_viz"; os.makedirs(out_dir, exist_ok=True)
# The "_08" suffix mirrors the "ep08" in pt_path -- update both together.
grid_name = "reuse_mask_grid_08.png"

# weights_only=True: only a tensor is expected here, so refuse to unpickle
# arbitrary objects (this also silences torch's FutureWarning about the default).
mask = torch.load(pt_path, map_location="cpu", weights_only=True)  # expected (T-1, H, W), bool or 0/1
if mask.ndim == 2:                                      # a single (H, W) mask -> treat as one frame
    mask = mask.unsqueeze(0)
if mask.ndim != 3 or mask.shape[0] == 0:
    raise ValueError(f"expected a non-empty (T-1, H, W) mask in {pt_path!r}, got shape {tuple(mask.shape)}")
if mask.dtype != torch.bool:
    mask = mask > 0.5                                   # binarize: values above 0.5 count as True
mask_np = mask.numpy().astype(np.float32)               # 0.0 / 1.0 for imshow

T_1, H, W = mask_np.shape
ncols = 8
nrows = math.ceil(T_1 / ncols)

# ---- Full-page grid ----
fig, axes = plt.subplots(nrows, ncols, figsize=(2.2*ncols, 2.2*nrows), squeeze=False)
for i, ax in enumerate(axes.flat):                      # row-major, same order as divmod(i, ncols)
    ax.axis('off')                                      # also blanks the unused trailing cells
    if i < T_1:
        ax.imshow(mask_np[i], cmap="gray", vmin=0, vmax=1)
        ax.set_title(f"F={i+1}")   # author's note: titles are frame=1..T-1 (each compared with frame 0)
fig.tight_layout()
fig.savefig(os.path.join(out_dir, grid_name), dpi=180)
plt.close(fig)

# ---- One image per frame ----
frame_dir = os.path.join(out_dir, "frames"); os.makedirs(frame_dir, exist_ok=True)
for i in range(T_1):
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(mask_np[i], cmap="gray", vmin=0, vmax=1)
    ax.axis('off'); ax.set_title(f"F={i+1}")
    fig.tight_layout()
    fig.savefig(os.path.join(frame_dir, f"mask_f{i+1:03d}.png"), dpi=160)
    plt.close(fig)                                      # free each figure; many frames otherwise pile up
print("done ->", out_dir)


  mask = torch.load(pt_path, map_location="cpu")          # (T-1, H, W), bool 或 0/1


done -> mask_viz


In [4]:
import torch, os, math
import matplotlib.pyplot as plt
import numpy as np
import imageio.v2 as imageio

# Visualize a per-frame difference map stored in a .pt file as a grid image,
# per-frame heatmaps with colorbars, and an animated GIF.

pt_path = "second_order_gap.pt"
out_dir = "viz_maxdiff"; os.makedirs(out_dir, exist_ok=True)
cmap_name = "viridis"  # shared by the grid, the per-frame heatmaps and the GIF

# ---- Load ----
# weights_only=True: only a tensor is expected, so refuse to unpickle arbitrary
# objects (this also silences torch's FutureWarning about the default).
max_diff_t = torch.load(pt_path, map_location="cpu", weights_only=True)  # expected (T-1, H, W)
# detach(): .numpy() raises on tensors that were saved with requires_grad=True.
max_diff = max_diff_t.detach().float().numpy()
if max_diff.ndim == 2:                       # a single (H, W) map -> treat as one frame
    max_diff = max_diff[None]
if max_diff.ndim != 3 or max_diff.shape[0] == 0:
    raise ValueError(f"expected a non-empty (T-1, H, W) tensor in {pt_path!r}, got shape {max_diff.shape}")

T_1, H, W = max_diff.shape

# Shared color limits from the 1st/99th percentiles so extreme values don't
# dominate; nanpercentile ignores NaN entries instead of returning NaN limits.
vmin, vmax = np.nanpercentile(max_diff, 1), np.nanpercentile(max_diff, 99)
# Constant input gives vmin == vmax; avoid dividing by zero in the GIF normalization.
span = vmax - vmin if vmax > vmin else 1.0

# ---- Grid figure ----
ncols = 8
nrows = math.ceil(T_1 / ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(2.2*ncols, 2.2*nrows), squeeze=False)
for i, ax in enumerate(axes.flat):           # row-major, same order as divmod(i, ncols)
    ax.axis('off')                           # also blanks the unused trailing cells
    if i < T_1:
        ax.imshow(max_diff[i], cmap=cmap_name, vmin=vmin, vmax=vmax)
        ax.set_title(f"F={i+1}", fontsize=8)  # author's note: frame i+1 relative to frame 0
fig.tight_layout()
fig.savefig(os.path.join(out_dir, "maxdiff_grid.png"), dpi=180)
plt.close(fig)
print("saved -> maxdiff_grid.png")

# ---- Per-frame heatmaps ----
for i in range(T_1):
    fig, ax = plt.subplots(figsize=(3, 3))
    im = ax.imshow(max_diff[i], cmap=cmap_name, vmin=vmin, vmax=vmax)
    # NOTE(review): label assumes the stored values are L2 distances -- confirm against the producer.
    fig.colorbar(im, ax=ax, label="L2 difference")
    ax.set_title(f"Frame {i+1} vs Frame 0")
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"maxdiff_f{i+1:03d}.png"), dpi=160)
    plt.close(fig)

# ---- GIF ----
# Same color limits as the figures; map the normalized floats straight through
# the colormap (no intermediate uint8 quantization). NaNs are drawn as the low end.
cmap = plt.get_cmap(cmap_name)
frames = []
for frame in max_diff:
    norm = np.nan_to_num(np.clip((frame - vmin) / span, 0, 1))
    frames.append((cmap(norm)[..., :3] * 255).astype(np.uint8))   # drop alpha -> (H, W, 3) RGB
imageio.mimsave(os.path.join(out_dir, "maxdiff.gif"), frames, fps=5)
print("saved -> maxdiff.gif")


  max_diff = torch.load(pt_path, map_location="cpu")  # (T-1, H, W)


saved -> maxdiff_grid.png
saved -> maxdiff.gif
