In [None]:
import torch
import numpy as np
import cvxpy as cp
import tritrain
import matplotlib.pyplot as plt
from dataset import read_all_of_huaweicup, tensor_file_name_converter
from anchorplt import calculate_expanded_hull, plt_point_hull

In [None]:
# Assumes the embedding helpers (map_embeddings, map_embeddings_to_coords) are
# defined elsewhere in this notebook; anch_pos is loaded from the dataset below.
num_points = 20000
embedding_dim = 64
bs_pos, tol_samp_num, anch_samp_num, port_num, ant_num, sc_num, anch_pos, H, d_geo = read_all_of_huaweicup(1, 1)

# Split the complex channel matrix into real/imaginary parts and fold that pair
# into the channel axis expected by the network.
# NOTE(review): the hard-coded 4 assumes H is (tol_samp_num, port_num, ant_num, sc_num)
# with port_num == 2, i.e. 2 ports x (real, imag) = 4 channels — confirm against
# read_all_of_huaweicup and the TripletNet input layer.
H_real = H.real
H_imag = H.imag
H_combined = np.stack((H_real, H_imag), axis=2)
H_combined = H_combined.reshape(tol_samp_num, 4, ant_num, sc_num)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained embedding network.
model_path = './CompetitionData1/Round1NET11721287896.9979763.pth'
embed_net_loaded = tritrain.TripletNet().to(device)
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a GPU can still be loaded when the script falls back to the CPU.
embed_net_loaded.load_state_dict(torch.load(model_path, map_location=device))
embed_net_loaded.eval()  # switch to evaluation mode (disables dropout / batch-norm updates)

print("Model loaded from Round1NET11721287896.9979763.pth")
embeddings = map_embeddings(embed_net_loaded, H_combined, device)

# Map the embeddings into the coordinate frame of the anchor points.
mapped_coords = map_embeddings_to_coords(embeddings, anch_pos)

# Check the mapping result before touching it; indexing a None result would
# raise a TypeError instead of reporting the failure.
if mapped_coords is not None:
    # Print the first five mapped results.
    print("Mapped Coordinates for the first 5 points:")
    print(mapped_coords[:5])

    # Column 0 of anch_pos is used directly as a row index into mapped_coords;
    # columns 1: hold the anchor's known coordinates.
    # NOTE(review): if the dataset's sample indices are 1-based this lookup is
    # off by one — confirm the indexing convention of anch_pos.
    indices = anch_pos[:, 0].astype(int)
    real_coords = anch_pos[:, 1:]

    # Compare the first few anchors against their mapped positions.
    print("Original Known Coordinates:")
    print(real_coords[:5])
    print("Mapped Coordinates for Known Embeddings:")
    mapped_known_coords = mapped_coords[indices[:5]]
    print(mapped_known_coords)

    # Plot all mapped points together with the anchor points.
    plt.figure(figsize=(10, 6))
    plt.scatter(mapped_coords[:, 0], mapped_coords[:, 1], s=10, label='Mapped Points')
    plt.scatter(anch_pos[:, 1], anch_pos[:, 2], s=50, label='Anchor Points', marker='x')
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')
    plt.title('Mapped Coordinates vs Anchor Points')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Per-anchor deviation between mapped and known coordinates.
    diff = mapped_coords[indices] - real_coords
    # Sum of absolute per-axis deviations over all anchors (an L1 total);
    # np.linalg.norm(diff) would give the Frobenius norm instead.
    total_diff_sum = np.sum(np.abs(diff))

    print("Differences between Mapped and Real Coordinates:", diff[:5])
    print("Total deviation of anchor points:", total_diff_sum)
else:
    print("Mapping failed.")
