In [None]:
# ================== 1. 导入依赖 ==================
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the "3d" projection on older Matplotlib)
from torch.utils.data import DataLoader

from data_loader import SevenScenesDataset
from models import PoseNet
from utils import pose_error, tq_to_pose
from geometry_baseline import GeometryBaseline


In [None]:
# ================== 2. Configuration ==================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Dataset root directory and the scene to evaluate.
root_dir = "7-scenes-dataset"
scene = "chess"

# Image preprocessing pipeline applied by the dataset.
transform = T.Compose(
    [
        T.ToPILImage(),                       # input -> PIL image
        T.Resize((224, 224)),                 # fixed 224x224 network input size
        T.ToTensor(),                         # PIL image -> float tensor
        T.Normalize(                          # per-channel mean/std normalisation
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Camera intrinsics (default values for the 7-Scenes Kinect)
fx = fy = 585.0
cx, cy = 320.0, 240.0


In [None]:
# ================== 3. Load the test split ==================
# return_depth=True is requested because later cells read sample["depth"].
test_set = SevenScenesDataset(
    root_dir,
    scene=scene,
    split="test",
    transform=transform,
    return_depth=True,
)
print("Test samples:", len(test_set))


In [None]:
# ================== 4. Load the PoseNet model ==================
# One sample per batch, in dataset order (no shuffling), so per-frame results
# stay aligned with the test sequence.
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

model = PoseNet(backbone="resnet18", pretrained=False).to(device)

# NOTE(security): torch.load unpickles the checkpoint file, which can execute
# arbitrary code. Only load checkpoints from a trusted source; on PyTorch
# versions that support it, consider passing weights_only=True.
state_dict = torch.load("posenet_best.pth", map_location=device)
model.load_state_dict(state_dict)

# Switch to evaluation mode before running inference.
model.eval()


In [None]:
# ================== 5. Evaluate PoseNet ==================
t_errs_posenet = []
r_errs_posenet = []

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    for sample_batch in test_loader:
        image_batch = sample_batch["image"].to(device)
        # Ground-truth pose with the batch dimension (size 1) removed.
        gt_pose = sample_batch["pose_matrix"].squeeze(0).numpy()

        # Network output: the first three values are used as the translation,
        # the remaining values as a quaternion that is rescaled to unit length.
        output = model(image_batch).squeeze(0).cpu().numpy()
        translation = output[:3]
        quaternion = output[3:] / np.linalg.norm(output[3:])
        predicted_pose = tq_to_pose(translation, quaternion)

        trans_err, rot_err = pose_error(gt_pose, predicted_pose)
        t_errs_posenet.append(trans_err)
        r_errs_posenet.append(rot_err)

print(f"PoseNet mean T_err = {np.mean(t_errs_posenet):.3f} m")
print(f"PoseNet mean R_err = {np.mean(r_errs_posenet):.2f} °")


In [None]:
# ================== 6. Evaluate PnP+RANSAC ==================
def _tensor_to_uint8_rgb(image_tensor):
    """Convert a preprocessed CHW image tensor back to an HWC uint8 array.

    The evaluation ``transform`` contains a ``T.Normalize`` step, so the tensor
    handed out by the dataset is mean/std-normalised rather than in [0, 1].
    Multiplying it by 255 and casting to uint8 directly wraps negative and
    out-of-range values, corrupting the image. The normalisation is therefore
    undone first (using the mean/std stored on the transform) and the result
    is clipped to [0, 1] before scaling.

    Args:
        image_tensor: CPU tensor indexed as (channel, row, column).

    Returns:
        ``numpy.ndarray`` of dtype uint8 indexed as (row, column, channel).
    """
    image = image_tensor.permute(1, 2, 0).numpy().astype(np.float64)
    normalize = next(
        (step for step in getattr(transform, "transforms", []) if isinstance(step, T.Normalize)),
        None,
    )
    if normalize is not None:
        # Inverse of (x - mean) / std, broadcast over the trailing channel axis.
        image = image * np.asarray(normalize.std) + np.asarray(normalize.mean)
    return (np.clip(image, 0.0, 1.0) * 255.0).round().astype(np.uint8)


baseline = GeometryBaseline([fx, fy, cx, cy], method="ORB")

t_errs_pnp, r_errs_pnp = [], []

# Use each pair of consecutive test frames (i, i+1) as input to the baseline.
# NOTE(review): the images have been resized to 224x224 by `transform`, while
# the intrinsics passed to GeometryBaseline are the Kinect defaults - confirm
# that estimate_pose (and the depth maps) are consistent with that resolution.
for i in range(len(test_set) - 1):
    sample1 = test_set[i]
    sample2 = test_set[i + 1]

    img1 = _tensor_to_uint8_rgb(sample1["image"])
    depth1 = sample1["depth"].numpy()

    img2 = _tensor_to_uint8_rgb(sample2["image"])
    depth2 = sample2["depth"].numpy()

    # NOTE(review): the estimate is scored against frame i's ground-truth pose;
    # confirm this matches the frame/convention estimate_pose returns.
    pose_gt = sample1["pose_matrix"].numpy()

    pose_pred = baseline.estimate_pose(img1, depth1, img2, depth2)
    if pose_pred is None:
        # The baseline produced no pose for this pair; it is left out of the statistics.
        continue

    t_err, r_err = pose_error(pose_gt, pose_pred)
    t_errs_pnp.append(t_err)
    r_errs_pnp.append(r_err)

# Report how many pairs contributed, since failed pairs are skipped above.
print(f"PnP+RANSAC valid estimates: {len(t_errs_pnp)} / {max(len(test_set) - 1, 0)}")
if t_errs_pnp:
    print(f"PnP+RANSAC mean T_err = {np.mean(t_errs_pnp):.3f} m")
    print(f"PnP+RANSAC mean R_err = {np.mean(r_errs_pnp):.2f} °")
else:
    # np.mean([]) would emit a RuntimeWarning and print NaN.
    print("PnP+RANSAC produced no valid pose estimates; mean errors are undefined.")


In [None]:
# ================== 7. Error comparison plots ==================
plt.figure(figsize=(10,4))

# One box-plot panel per error type: (subplot index, data, y-axis label, title).
panels = [
    (1, [t_errs_posenet, t_errs_pnp], "Translation Error (m)", "Translation Error Comparison"),
    (2, [r_errs_posenet, r_errs_pnp], "Rotation Error (°)", "Rotation Error Comparison"),
]
for position, errors, y_label, panel_title in panels:
    plt.subplot(1, 2, position)
    plt.boxplot(errors, labels=["PoseNet", "PnP+RANSAC"])
    plt.ylabel(y_label)
    plt.title(panel_title)

plt.tight_layout()
plt.show()


In [None]:
# ================== 8. Camera trajectory visualisation ==================
poses_gt, poses_pred = [], []

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    for batch in test_loader:
        img = batch["image"].to(device)
        pose_gt = batch["pose_matrix"].squeeze(0).numpy()
        poses_gt.append(pose_gt[:3, 3])  # translation column (x, y, z)

        pred = model(img).squeeze(0).cpu().numpy()
        # First three outputs are the translation; the rest is a quaternion
        # rescaled to unit length before building the pose.
        t_pred, q_pred = pred[:3], pred[3:] / np.linalg.norm(pred[3:])
        pose_pred = tq_to_pose(t_pred, q_pred)
        poses_pred.append(pose_pred[:3, 3])

# reshape(-1, 3) keeps the arrays two-dimensional even when the test set is
# empty, so the column indexing below cannot raise IndexError.
poses_gt = np.array(poses_gt).reshape(-1, 3)
poses_pred = np.array(poses_pred).reshape(-1, 3)

fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection="3d")

# Ground-truth trajectory (blue)
ax.plot(poses_gt[:,0], poses_gt[:,1], poses_gt[:,2], "-o", color="blue", label="Ground Truth")

# PoseNet predicted trajectory (red)
ax.plot(poses_pred[:,0], poses_pred[:,1], poses_pred[:,2], "-o", color="red", label="PoseNet Predicted")

ax.set_xlabel("X (m)")
ax.set_ylabel("Y (m)")
ax.set_zlabel("Z (m)")
ax.set_title(f"Trajectory Comparison ({scene})")
ax.legend()
plt.show()


In [None]:
# ================== 9. Three-trajectory plot (GT, PoseNet, PnP+RANSAC) ==================
def _tensor_to_uint8_rgb(image_tensor):
    """Convert a preprocessed CHW image tensor back to an HWC uint8 array.

    The evaluation ``transform`` contains a ``T.Normalize`` step, so the tensor
    handed out by the dataset is mean/std-normalised rather than in [0, 1].
    Multiplying it by 255 and casting to uint8 directly wraps negative and
    out-of-range values, corrupting the image. The normalisation is therefore
    undone first (using the mean/std stored on the transform) and the result
    is clipped to [0, 1] before scaling.

    Args:
        image_tensor: CPU tensor indexed as (channel, row, column).

    Returns:
        ``numpy.ndarray`` of dtype uint8 indexed as (row, column, channel).
    """
    image = image_tensor.permute(1, 2, 0).numpy().astype(np.float64)
    normalize = next(
        (step for step in getattr(transform, "transforms", []) if isinstance(step, T.Normalize)),
        None,
    )
    if normalize is not None:
        # Inverse of (x - mean) / std, broadcast over the trailing channel axis.
        image = image * np.asarray(normalize.std) + np.asarray(normalize.mean)
    return (np.clip(image, 0.0, 1.0) * 255.0).round().astype(np.uint8)


poses_gt, poses_posenet, poses_pnp = [], [], []
num_frames = len(test_set)

with torch.no_grad():
    # Iterate over every frame so the last one is included in the ground-truth
    # and PoseNet trajectories; only the PnP step needs a following frame.
    for i in range(num_frames):
        sample = test_set[i]
        img = sample["image"].unsqueeze(0).to(device)  # add a batch dimension of 1
        pose_gt = sample["pose_matrix"].numpy()
        poses_gt.append(pose_gt[:3, 3])

        # ---- PoseNet prediction ----
        pred = model(img).squeeze(0).cpu().numpy()
        t_pred, q_pred = pred[:3], pred[3:] / np.linalg.norm(pred[3:])
        pose_pred = tq_to_pose(t_pred, q_pred)
        poses_posenet.append(pose_pred[:3, 3])

        # ---- PnP+RANSAC (needs frame i+1, so skipped for the last frame) ----
        if i < num_frames - 1:
            sample_next = test_set[i + 1]

            img1 = _tensor_to_uint8_rgb(sample["image"])
            depth1 = sample["depth"].numpy()

            img2 = _tensor_to_uint8_rgb(sample_next["image"])
            depth2 = sample_next["depth"].numpy()

            # NOTE(review): the translation of the returned pose is plotted
            # directly as a trajectory point; confirm estimate_pose returns a
            # pose in the same frame as the ground truth.
            pose_pred_pnp = baseline.estimate_pose(img1, depth1, img2, depth2)
            if pose_pred_pnp is not None:
                poses_pnp.append(pose_pred_pnp[:3, 3])
            else:
                poses_pnp.append([np.nan, np.nan, np.nan])  # placeholder for a failed estimate

# reshape(-1, 3) keeps the arrays two-dimensional even when a list is empty,
# so the column indexing below cannot raise IndexError.
poses_gt = np.array(poses_gt).reshape(-1, 3)
poses_posenet = np.array(poses_posenet).reshape(-1, 3)
poses_pnp = np.array(poses_pnp, dtype=float).reshape(-1, 3)

fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection="3d")

# Ground Truth
ax.plot(poses_gt[:,0], poses_gt[:,1], poses_gt[:,2], "-o", color="blue", label="Ground Truth")

# PoseNet
ax.plot(poses_posenet[:,0], poses_posenet[:,1], poses_posenet[:,2], "-o", color="red", label="PoseNet")

# PnP+RANSAC (rows holding the NaN placeholder are left out)
mask = ~np.isnan(poses_pnp[:,0])
ax.plot(poses_pnp[mask,0], poses_pnp[mask,1], poses_pnp[mask,2], "-o", color="green", label="PnP+RANSAC")

ax.set_xlabel("X (m)")
ax.set_ylabel("Y (m)")
ax.set_zlabel("Z (m)")
ax.set_title(f"Trajectory Comparison: {scene}")
ax.legend()
plt.show()


In [None]:
# ================== 10. 2D trajectory plot (XY plane) ==================
plt.figure(figsize=(8,6))

# PnP+RANSAC rows that hold the NaN placeholder are excluded from the plot.
mask = ~np.isnan(poses_pnp[:,0])

# (points, line colour, legend label) for each trajectory, drawn in this order.
series = [
    (poses_gt, "blue", "Ground Truth"),
    (poses_posenet, "red", "PoseNet"),
    (poses_pnp[mask], "green", "PnP+RANSAC"),
]
for points, line_color, legend_label in series:
    plt.plot(points[:,0], points[:,1], "-o", color=line_color, label=legend_label, markersize=3)

plt.xlabel("X (m)")
plt.ylabel("Y (m)")
plt.title(f"Trajectory Comparison (XY plane) - {scene}")
plt.legend()
plt.axis("equal")   # same scale on both axes
plt.grid(True)
plt.show()
