# Vector-HaSH 类比推理模型

本notebook实现了基于多Grid的Vector-HaSH模型用于类比推理任务。

## 模型框架

- **多个Grid模块**：每个关系类型对应一个独立的Grid模块
- **共享HPC空间**：所有Grid共享同一个HPC（Place Cell）空间
- **双向连接**：
  - Grid ↔ HPC: 通过权重矩阵 W_pg 和 W_gp
  - 维护W_pg[grid_index]和W_gp[grid_index]
  - HPC ↔ Sensory: 通过权重矩阵 W_sp 和 W_ps

- **Relation Classifier**：决定哪个Grid模块更新状态



In [2]:
# === 导入库 ===
import numpy as np
import matplotlib.pyplot as plt
from numpy.random import randn, randint
from tqdm import tqdm
from collections import defaultdict
from sentence_transformers import SentenceTransformer
import torch
import pandas as pd
import json
import os
import sys

# 导入自定义模块
sys.path.append(os.path.abspath('.'))
try:
    from src.assoc_utils_np import train_gcpc
    from src.assoc_utils_np_2D import gen_gbook_2d, path_integration_Wgg_2d, module_wise_NN_2d
    from src.seq_utils import nonlin, sensorymap
except ImportError:
    print("警告: 无法导入src模块，将使用本地实现")

# 绘图设置
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False

print("✓ 导入完成")

  from .autonotebook import tqdm as notebook_tqdm


✓ 导入完成


## 1. 超参数配置和模型结构初始化


In [3]:

nruns = 1
Np = 342  # Place Cell数量
lambdas = [3, 4, 5, 7]  # Grid模块周期
Ng = np.sum(np.square(lambdas))  
Npos = np.prod(lambdas) 

print(f"Grid空间大小 (Npos): {Npos} x {Npos} = {Npos*Npos}")
print(f"Grid Cell数量 (Ng): {Ng}")
print(f"Place Cell数量 (Np): {Np}")

# === 2. 加载关系类型映射 ===
label_map_path = 'label_map.json'
with open(label_map_path, 'r', encoding='utf-8') as f:
    label_map = json.load(f)

# 关系类型列表（grid_index -> 关系名称）
RELATION_TYPES = [label_map[str(i)] for i in range(len(label_map))]
N_GRIDS = len(RELATION_TYPES)
RELATION_TO_IDX = {rel: idx for idx, rel in enumerate(RELATION_TYPES)}


print(f"\n=== 多Grid配置 ===")
print(f"Grid数量: {N_GRIDS}")
print(f"关系类型: {RELATION_TYPES[:5]}... (共{len(RELATION_TYPES)}种)")


Grid空间大小 (Npos): 420 x 420 = 176400
Grid Cell数量 (Ng): 99
Place Cell数量 (Np): 342

=== 多Grid配置 ===
Grid数量: 19
关系类型: ['Uncategorized', '主体-产物', '主体-动作', '亲属关系', '位置/空间']... (共19种)


## 2. 生成gbooks和共享pbook

In [4]:
# 生成gbooks
gbooks = []
gbooks_flattened = []

for i in range(N_GRIDS):
    gbook = gen_gbook_2d(lambdas, Ng, Npos)
    gbooks.append(gbook)
    gbook_transposed = np.transpose(gbook, (0, 2, 1))  
    gbooks_flattened.append(gbook_transposed.reshape(Ng, Npos*Npos))

gbook_init =  gen_gbook_2d(lambdas, Ng, Npos)
# 生成共享pbook---全局共享hpc
Wpg = randn(Np, Ng)   
c = 0.10  # 连接概率
prune = int((1-c)*Np*Ng)
mask = np.ones((Np, Ng))
mask[randint(low=0, high=Np, size=prune), randint(low=0, high=Ng, size=prune)] = 0
Wpg = np.multiply(mask, Wpg)
thresh = 2.0

# 使用第一个gbook作为参考生成pbook（起始的gbook都是一样的）
pbook = nonlin(np.einsum('ij,jlm->ilm', Wpg, gbooks[0]), thresh=thresh)  #(Np,Npos,Npos)
pbook_transposed = np.transpose(pbook, (0, 2, 1))
pbook_flattened = pbook_transposed.reshape(Np, Npos*Npos)

print(f"\n✓ 共享HPC (pbook) 生成完成: shape {pbook.shape}")
print(f"  W_pg shape: {Wpg.shape}")

# 伪逆学习连接矩阵 Wgp
Wgp_list = []
for grid_idx in range(N_GRIDS):
    P_matrix = pbook_flattened[:, :]  # (Np, Npos*Npos) - 所有位置的Place向量
    G_matrix = gbooks_flattened[grid_idx]  # (Ng, Npos*Npos) - 所有位置的Grid向量
    P_pinv = np.linalg.pinv(P_matrix)
    Wgp_learned = G_matrix @ P_pinv  # (Ng, Np)
    Wgp_list.append(Wgp_learned)


✓ 共享HPC (pbook) 生成完成: shape (342, 420, 420)
  W_pg shape: (342, 99)


In [5]:
# 辅助函数 

def flat_idx(x, y, Npos=Npos):
    """将2D坐标转换为1D索引"""
    return int(y * Npos + x)

def get_pvec(x, y, pbook_flattened=pbook_flattened):
    """获取指定坐标的Place Cell向量"""
    return pbook_flattened[0, :, flat_idx(x, y)]

def cosine_sim(v1, v2):
    """计算余弦相似度"""
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-9)

def rls_step(W, theta, a, y):
    """
    递归最小二乘(RLS)更新步骤
    """
    a = np.asarray(a).reshape(-1, 1)
    y = np.asarray(y).reshape(-1, 1)
    
    pred_before = W @ a
    err_before = y - pred_before
    
    denom = 1.0 + (a.T @ theta @ a).item()
    bk = (theta @ a) / denom
    
    theta_new = theta - (theta @ a @ bk.T)
    W_new = W + (err_before @ bk.T)
    
    pred_after = W_new @ a
    err_after = y - pred_after
    
    return W_new, theta_new, float(np.linalg.norm(bk)), float(np.linalg.norm(err_before)), float(np.linalg.norm(err_after))


## 3. 数据加载和预处理


In [6]:
# 加载e-kar数据
data_path = 'obj_grid_index.csv'
df = pd.read_csv(data_path)

print(f"加载数据: {len(df)} 条词对")
print(f"数据预览:")
print(df.head(10))

# 统计每个grid的数据量
grid_counts = df['grid_index'].value_counts().sort_index()
print(f"\n各Grid数据量统计:")
for grid_idx, count in grid_counts.items():
    print(f"  Grid {grid_idx} ({RELATION_TYPES[grid_idx]}): {count} 条")

# 提取所有唯一的对象（用于生成Sensory向量）
all_objects = set(df['Obj_A'].tolist() + df['Obj_B'].tolist())
C = len(all_objects)  # 对象总数
print(f"\n唯一对象数量: {C}")


加载数据: 1675 条词对
数据预览:
  Obj_A Obj_B  grid_index
0    稻谷    大米           1
1    核桃    桃酥          12
2    棉花    棉籽           1
3    西瓜    瓜子           1
4    花生   花生酱           1
5  旗开得胜  马到成功          16
6  拨乱反正  沉冤昭雪          16
7  牛高马大  虎穴得子          16
8  水到渠成  瓜熟蒂落          16
9  缘木求鱼  鹰击长空          16

各Grid数据量统计:
  Grid 0 (Uncategorized): 41 条
  Grid 1 (主体-产物): 141 条
  Grid 2 (主体-动作): 294 条
  Grid 3 (亲属关系): 3 条
  Grid 4 (位置/空间): 89 条
  Grid 5 (包含/种属): 313 条
  Grid 6 (反义/对立): 115 条
  Grid 7 (因果/依赖): 102 条
  Grid 8 (属性/特征): 94 条
  Grid 9 (工具-功能): 85 条
  Grid 10 (师生传承): 3 条
  Grid 11 (并列/同类): 49 条
  Grid 12 (材料-成品): 53 条
  Grid 13 (等级/排序): 10 条
  Grid 14 (组成/整体): 47 条
  Grid 15 (职业-对象): 5 条
  Grid 16 (象征/比喻): 56 条
  Grid 17 (近义/同一): 127 条
  Grid 18 (顺序/过程): 48 条

唯一对象数量: 2641


## 4. sbook 生成以及 与pbook的双向连接框架
- 使用中文embedding模型为每个对象生成嵌入向量
- 模型为`shibing624/text2vec-base-chinese`


In [7]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "shibing624/text2vec-base-chinese"
print(f"  尝试加载模型: {model_name}")
embedding_model = SentenceTransformer(model_name, device=device)
print(f"  ✓ 成功加载模型: {model_name} (设备: {device})")
# 为每个对象生成Sensory向量
object_to_id = {obj: idx for idx, obj in enumerate(sorted(all_objects))}
id_to_object = {idx: obj for obj, idx in object_to_id.items()}
events = []  # Sensory向量列表
events_with_word = {}  #词到向量的字典

if embedding_model is not None:
    print(f"\n正在为 {C} 个对象生成embedding向量...")
    # 批量生成embedding（更高效）
    object_list = sorted(all_objects)
    embeddings = embedding_model.encode(
        object_list,
        batch_size=128,
        show_progress_bar=True,
        #normalize_embeddings=True  # L2归一化
    )
    
    # 转换为numpy数组并存储
    for idx, (obj, emb) in enumerate(zip(object_list, embeddings)):
        vec = emb
        events.append(vec)
        events_with_word[obj] = vec
    
    Ns = embeddings.shape[1]  # 实际embedding维度
    print(f"✓ 使用embedding模型生成 {C} 个对象的Sensory向量")
    print(f"  Embedding维度: {Ns}")
    print(f"  模型: {model_name}")
else:
    # Fallback: 使用随机向量
    raise ImportError("Embedding模型加载失败")

  尝试加载模型: shibing624/text2vec-base-chinese
  ✓ 成功加载模型: shibing624/text2vec-base-chinese (设备: cuda)

正在为 2641 个对象生成embedding向量...


Batches: 100%|██████████| 21/21 [00:00<00:00, 50.44it/s]

✓ 使用embedding模型生成 2641 个对象的Sensory向量
  Embedding维度: 768
  模型: shibing624/text2vec-base-chinese





In [8]:
# 初始化权重矩阵 
epsilon = 0.05
# W_ps:  Sensory -> Place 
Wps = np.zeros((Np, Ns), dtype=float)

# W_sp:  Place -> Sensory
Wsp = np.zeros((Ns, Np), dtype=float)

print(f"✓ 权重矩阵初始化完成")
print(f"  W_gp: {len(Wgp_list)} 个矩阵，每个 shape {Wgp_list[0].shape}")
print(f"  W_ps: shape {Wps.shape}")
print(f"  W_sp: shape {Wsp.shape}")

✓ 权重矩阵初始化完成
  W_gp: 19 个矩阵，每个 shape (99, 342)
  W_ps: shape (342, 768)
  W_sp: shape (768, 342)


In [9]:
def nearest_neighbor(gin, gbook):
    
    # 计算相似度: (n_positions,)
    similarities = gin @ gbook
    
    # 找到所有最大值的索引（处理平局）
    max_val = np.max(similarities)
    max_indices = np.argwhere(similarities == max_val).flatten()  # 所有最大值位置
    
    # 随机选一个
    idx = np.random.choice(max_indices)
    
    return gbook[:, idx]

In [10]:
# === 10. 全局HPC位置分配 ===
# 每个对象在HPC中只有一个全局唯一位置（与Grid无关）

object_hpc_locs = {}  # object_id -> (x, y)
hpc_next_x = 0
hpc_next_y = 0

# 记录每行已分配的对象数量，用于判断是否需要跳过最后一个位置
row_counts = {}       # y -> count

def allocate_global_hpc_position(obj_id):
    """
    为对象分配全局唯一的HPC位置
    分配规则：从(0,0)开始，按行优先连续分配（先向右，再向上）。
    约束：当一行放满切换到下一行时，该行已分配的位置数必须为偶数。
          若分配最后一列会导致该行位置数为奇数，则主动跳过该位置并换行。
    """
    global hpc_next_x, hpc_next_y, row_counts, object_hpc_locs

    # 若对象已分配，直接返回已有位置
    if obj_id in object_hpc_locs:
        return object_hpc_locs[obj_id]

    # 确保HPC空间足够（调用前应已检测，此处为防御）
    if hpc_next_y >= Npos:
        raise RuntimeError(f"HPC空间不足，无法为对象{obj_id}分配位置！")

    while True:
        y = hpc_next_y
        current_row_count = row_counts.get(y, 0)

        # 如果当前位于本行最后一列，且分配此位置会使该行位置数变为奇数，则跳过该列
        if hpc_next_x == Npos - 1 and (current_row_count + 1) % 2 == 1:
            # 跳过最后一列，换行
            hpc_next_x = 0
            hpc_next_y += 1
            if hpc_next_y >= Npos:
                raise RuntimeError(f"HPC空间不足，无法为对象{obj_id}分配位置！")
            continue   # 重新检查新行的状态
        else:
            break

    x, y = hpc_next_x, hpc_next_y
    object_hpc_locs[obj_id] = (x, y)

    # 更新行计数
    row_counts[y] = row_counts.get(y, 0) + 1

    # 更新下一个可用位置
    hpc_next_x += 1
    if hpc_next_x >= Npos:
        hpc_next_x = 0
        hpc_next_y += 1
        # 换行时，刚结束的行必然已满足位置数为偶数（由跳过逻辑保证）
        # 可加断言验证，但非必须
        # assert row_counts.get(y, 0) % 2 == 0, f"行{y}位置数不是偶数！"

    return (x, y)


def check_hpc_space_enough(num_objects):
    """
    检测HPC空间是否足够容纳指定数量的对象
    新规则下，每行最多可分配的位置数为：
        - 若Npos为偶数：可放满整行，位置数为Npos（偶数）
        - 若Npos为奇数：不能放满最后1列，每行最多Npos-1个位置（偶数）
    总容量 = 每行最大个数 × 总行数(Npos)
    """
    if Npos % 2 == 0:
        max_per_row = Npos
    else:
        max_per_row = Npos - 1
    total_available = max_per_row * Npos
    return num_objects <= total_available


# ========== 使用示例 ==========
# 假设全局常量 Npos 已定义（例如 Npos = 5）
# 以下代码与原有逻辑保持一致

# 1. 先检测空间是否足够
if not check_hpc_space_enough(C):
    print(f"【错误】HPC空间不足，无法为{C}个对象分配位置！")
else:
    # 2. 空间足够时才分配位置
    for obj_id in range(C):
        allocate_global_hpc_position(obj_id)
    print(f"【成功】已为{C}个对象分配HPC位置，位置信息：{object_hpc_locs}")

【成功】已为2641个对象分配HPC位置，位置信息：{0: (0, 0), 1: (1, 0), 2: (2, 0), 3: (3, 0), 4: (4, 0), 5: (5, 0), 6: (6, 0), 7: (7, 0), 8: (8, 0), 9: (9, 0), 10: (10, 0), 11: (11, 0), 12: (12, 0), 13: (13, 0), 14: (14, 0), 15: (15, 0), 16: (16, 0), 17: (17, 0), 18: (18, 0), 19: (19, 0), 20: (20, 0), 21: (21, 0), 22: (22, 0), 23: (23, 0), 24: (24, 0), 25: (25, 0), 26: (26, 0), 27: (27, 0), 28: (28, 0), 29: (29, 0), 30: (30, 0), 31: (31, 0), 32: (32, 0), 33: (33, 0), 34: (34, 0), 35: (35, 0), 36: (36, 0), 37: (37, 0), 38: (38, 0), 39: (39, 0), 40: (40, 0), 41: (41, 0), 42: (42, 0), 43: (43, 0), 44: (44, 0), 45: (45, 0), 46: (46, 0), 47: (47, 0), 48: (48, 0), 49: (49, 0), 50: (50, 0), 51: (51, 0), 52: (52, 0), 53: (53, 0), 54: (54, 0), 55: (55, 0), 56: (56, 0), 57: (57, 0), 58: (58, 0), 59: (59, 0), 60: (60, 0), 61: (61, 0), 62: (62, 0), 63: (63, 0), 64: (64, 0), 65: (65, 0), 66: (66, 0), 67: (67, 0), 68: (68, 0), 69: (69, 0), 70: (70, 0), 71: (71, 0), 72: (72, 0), 73: (73, 0), 74: (74, 0), 75: (75, 0), 76: (

In [11]:
# 学习Wps和Wsp
Npatts = C  # 对象数量
path_pbook = np.zeros((Np, Npatts))
for obj_id in range(C):
    x, y = object_hpc_locs[obj_id]
    path_pbook[:, obj_id] = pbook_flattened[:, flat_idx(x, y)]
print(f"✓ 提取完成: path_pbook shape {path_pbook.shape}，包含{C}个对象的Place向量")

sbook = np.array(events).T  # shape: (Ns, C)
sbook_pinv = np.linalg.pinv(sbook)
Wps = path_pbook @ sbook_pinv  # shape: (Np, Ns)
path_pbook_pinv = np.linalg.pinv(path_pbook)
Wsp = sbook @ path_pbook_pinv  # shape: (Ns, Np)
print(f"\n✓ W_ps和W_sp伪逆学习完成")

✓ 提取完成: path_pbook shape (342, 2641)，包含2641个对象的Place向量

✓ W_ps和W_sp伪逆学习完成


In [12]:
#初始化对象在gbook中的位置映射（使用HPC位置）
# gbook_obj_positions[grid_idx][obj_id] = (x, y) 记录obj在该Grid的gbook(i)中的位置
# gbook_pos_obj[grid_idx][flat_idx(x,y)] = obj_id 记录该位置上是哪个对象
gbook_obj_positions = {} 
gbook_pos_obj = {}

for grid_idx in range(N_GRIDS):
    gbook_obj_positions[grid_idx] = {}
    gbook_pos_obj[grid_idx] = {}
    
    # 为每个对象使用HPC位置来初始化
    for obj_id in range(C):
        # 从HPC位置获取该对象的坐标
        x, y = object_hpc_locs[obj_id]
        flat_i = flat_idx(x, y)
        
        # 在所有Grid中都使用相同的HPC位置
        gbook_obj_positions[grid_idx][obj_id] = (x, y)
        gbook_pos_obj[grid_idx][flat_i] = obj_id

print(f"✓ 位置映射初始化完成")

✓ 位置映射初始化完成


In [36]:
print(gbook_pos_obj[2])
print(gbook_pos_obj[2].get(1210))
print(gbook_obj_positions[2])
gbook_obj_positions[2][None] = (1,1)


{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14, 15: 15, 16: 16, 17: 17, 18: 18, 19: 19, 20: 20, 21: 21, 22: 22, 23: 23, 24: 24, 25: 25, 26: 26, 27: 27, 28: 28, 29: 29, 30: 30, 31: 31, 32: 32, 33: 33, 34: 34, 35: 35, 36: 36, 37: 37, 38: 38, 39: 39, 40: 40, 41: 41, 42: 42, 43: 43, 44: 44, 45: 45, 46: 46, 47: 47, 48: 48, 49: 49, 50: 50, 51: 51, 52: 52, 53: 53, 54: 54, 55: 55, 56: 56, 57: 57, 58: 58, 59: 59, 60: 60, 61: 61, 62: 62, 63: 63, 64: 64, 65: 65, 66: 66, 67: 67, 68: 68, 69: 69, 70: 70, 71: 71, 72: 72, 73: 73, 74: 74, 75: 75, 76: 76, 77: 77, 78: 78, 79: 79, 80: 80, 81: 81, 82: 82, 83: 83, 84: 84, 85: 85, 86: 86, 87: 87, 88: 88, 89: 89, 90: 90, 91: 91, 92: 92, 93: 93, 94: 94, 95: 95, 96: 96, 97: 97, 98: 98, 99: 99, 100: 100, 101: 101, 102: 102, 103: 103, 104: 104, 105: 105, 106: 106, 107: 107, 108: 108, 109: 109, 110: 110, 111: 111, 112: 112, 113: 113, 114: 114, 115: 115, 116: 116, 117: 117, 118: 118, 119: 119, 120: 120, 121: 121,

In [13]:
# 调整gbooks[grid_index] 以及对应的W_pg  W_gp
theta_gp_list = [(1.0 / (epsilon**2)) * np.eye(Np) for _ in range(N_GRIDS)]
theta_pg_list = [(1.0 / (epsilon**2)) * np.eye(Ng) for _ in range(N_GRIDS)]

# 将W_pg扩展为每个Grid一份（初始化为原始的Wpg）
Wpg_list = []
for grid_idx in range(N_GRIDS):
    Wpg_grid = Wpg[:, :].copy()
    Wpg_list.append(Wpg_grid)

print(f"初始化 {N_GRIDS} 个Grid的Wpg矩阵")

# 为每个Grid维护位置计数器
pair_position_counter = {i: 0 for i in range(N_GRIDS)}

def allocate_pair_position(grid_idx, pair_position_counter, Npos):
    """
    为数据对分配目标位置
    :param grid_idx: Grid索引
    :param pair_position_counter: 位置计数器字典
    :param Npos: 网格空间大小
    :return: 返回(x_a, y_a, x_b, y_b)坐标
    """
    pairs_per_row = Npos // 2
    total_pairs = pairs_per_row * Npos
    if pair_position_counter[grid_idx] >= total_pairs:
        print(f"  Grid {grid_idx} 容量已满，无法分配更多位置")
        return None
    pair_pos = pair_position_counter[grid_idx]
    
    row_idx = pair_pos  // pairs_per_row
    col_in_row = pair_pos % pairs_per_row
    
    x = col_in_row * 2
    y = row_idx
    
    return x, y, x + 1, y

def swap_gbook_vectors(grid_idx, obj_a_id, obj_b_id, target_x_a, target_y_a, target_x_b, target_y_b):
    """
    交换gbook中的向量并更新位置映射
    :return: 返回新的位置信息
    """
    # 获取当前位置
    curr_pos_a = gbook_obj_positions[grid_idx].get(obj_a_id)
    curr_pos_b = gbook_obj_positions[grid_idx].get(obj_b_id)

    curr_flat_a = flat_idx(*curr_pos_a)
    curr_flat_b = flat_idx(*curr_pos_b)
    target_flat_a = flat_idx(target_x_a, target_y_a)
    target_flat_b = flat_idx(target_x_b, target_y_b)
    
    # 获取目标位置上原有的对象
    obj_at_target_a = gbook_pos_obj[grid_idx].get(target_flat_a)
    obj_at_target_b = gbook_pos_obj[grid_idx].get(target_flat_b)
    
    # ===== 处理对象a =====
    if curr_flat_a != target_flat_a:
        # 交换gbook中两处的向量
        temp_vec = gbooks_flattened[grid_idx][:, curr_flat_a].copy()
        gbooks_flattened[grid_idx][:, curr_flat_a] = gbooks_flattened[grid_idx][:, target_flat_a]
        gbooks_flattened[grid_idx][:, target_flat_a] = temp_vec
    
        # 更新位置映射
        gbook_obj_positions[grid_idx][obj_a_id] = (target_x_a, target_y_a)
        gbook_pos_obj[grid_idx][target_flat_a] = obj_a_id
        if obj_at_target_a is not None:
            gbook_obj_positions[grid_idx][obj_at_target_a] = curr_pos_a
            gbook_pos_obj[grid_idx][curr_flat_a] = obj_at_target_a
        else:
            # 原位置清空
            gbook_pos_obj[grid_idx].pop(curr_flat_a, None)

    # ===== 处理对象b =====
    if curr_flat_b != target_flat_b:
        # 交换gbook中两处的向量
        temp_vec = gbooks_flattened[grid_idx][:, curr_flat_b].copy()
        gbooks_flattened[grid_idx][:, curr_flat_b] = gbooks_flattened[grid_idx][:, target_flat_b]
        gbooks_flattened[grid_idx][:, target_flat_b] = temp_vec
        
        # 更新位置映射
        gbook_obj_positions[grid_idx][obj_b_id] = (target_x_b, target_y_b)
        gbook_pos_obj[grid_idx][target_flat_b] = obj_b_id
        if obj_at_target_b is not None:
            gbook_obj_positions[grid_idx][obj_at_target_b] = curr_pos_b
            gbook_pos_obj[grid_idx][curr_flat_b] = obj_at_target_b
        else:
            gbook_pos_obj[grid_idx].pop(curr_flat_b, None)
    # 返回新位置
    return target_x_a, target_y_a, target_x_b, target_y_b

# 迭代训练数据
train_count = 0
for idx, row in df.iterrows():
    obj_a = row['Obj_A']
    obj_b = row['Obj_B']
    grid_idx = int(row['grid_index'])
    
    # 获取对象ID
    obj_a_id = object_to_id.get(obj_a)
    obj_b_id = object_to_id.get(obj_b)
    
    if obj_a_id is None or obj_b_id is None:
        continue
    
    # 分配位置，若容量已满则跳过本条数据
    new_pos = allocate_pair_position(grid_idx, pair_position_counter, Npos)
    if new_pos is None:
        continue   # 该 Grid 已满，跳过，不影响其他 Grid
    target_x_a, target_y_a, target_x_b, target_y_b = new_pos
    # 交换gbook中的向量，更新位置映射
    new_x_a, new_y_a, new_x_b, new_y_b = swap_gbook_vectors(
        grid_idx, obj_a_id, obj_b_id, target_x_a, target_y_a, target_x_b, target_y_b
    )
    
    #获取id
    idx_a = flat_idx(new_x_a, new_y_a)
    idx_b = flat_idx(new_x_b, new_y_b)
    
    # 从新位置提取Grid向量
    g_a = gbooks_flattened[grid_idx][:, idx_a]
    g_b = gbooks_flattened[grid_idx][:, idx_b]
    
    # 提取place向量
    x_a, y_a = object_hpc_locs[obj_a_id]
    x_b, y_b = object_hpc_locs[obj_b_id]
    
    idx_p_a = flat_idx(x_a, y_a)
    idx_p_b = flat_idx(x_b, y_b)
    p_a = pbook_flattened[:, idx_p_a]
    p_b = pbook_flattened[:, idx_p_b]
    
    # ===== 使用RLS更新 W_gp[grid_idx] =====
    # g = W_gp @ p
    Wgp_list[grid_idx], theta_gp_list[grid_idx], _, _, _ = rls_step(
        Wgp_list[grid_idx], theta_gp_list[grid_idx], p_a, g_a
    )
    Wgp_list[grid_idx], theta_gp_list[grid_idx], _, _, _ = rls_step(
        Wgp_list[grid_idx], theta_gp_list[grid_idx], p_b, g_b
    )
    
    # ===== 使用RLS更新 W_pg[grid_idx] =====
    # p = W_pg @ g
    Wpg_list[grid_idx], theta_pg_list[grid_idx], _, _, _ = rls_step(
        Wpg_list[grid_idx], theta_pg_list[grid_idx], g_a, p_a
    )
    Wpg_list[grid_idx], theta_pg_list[grid_idx], _, _, _ = rls_step(
        Wpg_list[grid_idx], theta_pg_list[grid_idx], g_b, p_b
    )
    
    train_count += 1
    pair_position_counter[grid_idx] += 1
    
    if (train_count + 1) % 500 == 0:
        print(f"  已处理 {train_count} 条训练数据")
        # 显示某个Grid的位置分布示例
        sample_grid = 0
        sample_positions = list(gbook_obj_positions[sample_grid].values())[:5]
        print(f"    (Grid {sample_grid}样本位置: {sample_positions})")

print(f"\n✓ 训练完成，共处理 {train_count} 条数据")
print(f"  W_gp列表已更新，共 {len(Wgp_list)} 个Grid模块")
print(f"  W_pg列表已更新，共 {len(Wpg_list)} 个Grid模块")
print(f"\n各Grid的数据处理统计:")
for grid_idx in range(N_GRIDS):
    pair_count = pair_position_counter[grid_idx]
    if pair_count > 0:
        pairs_per_row = Npos // 2
        last_row = (pair_count - 1) // pairs_per_row
        print(f"  Grid {grid_idx} ({RELATION_TYPES[grid_idx]}): {pair_count}对数据，占用行0-{last_row}")



初始化 19 个Grid的Wpg矩阵
  已处理 499 条训练数据
    (Grid 0样本位置: [(394, 0), (71, 2), (174, 2), (168, 2), (412, 3)])
  已处理 999 条训练数据
    (Grid 0样本位置: [(394, 0), (71, 2), (174, 2), (168, 2), (412, 3)])
  已处理 1499 条训练数据
    (Grid 0样本位置: [(394, 0), (71, 2), (174, 2), (168, 2), (412, 3)])

✓ 训练完成，共处理 1675 条数据
  W_gp列表已更新，共 19 个Grid模块
  W_pg列表已更新，共 19 个Grid模块

各Grid的数据处理统计:
  Grid 0 (Uncategorized): 41对数据，占用行0-0
  Grid 1 (主体-产物): 141对数据，占用行0-0
  Grid 2 (主体-动作): 294对数据，占用行0-1
  Grid 3 (亲属关系): 3对数据，占用行0-0
  Grid 4 (位置/空间): 89对数据，占用行0-0
  Grid 5 (包含/种属): 313对数据，占用行0-1
  Grid 6 (反义/对立): 115对数据，占用行0-0
  Grid 7 (因果/依赖): 102对数据，占用行0-0
  Grid 8 (属性/特征): 94对数据，占用行0-0
  Grid 9 (工具-功能): 85对数据，占用行0-0
  Grid 10 (师生传承): 3对数据，占用行0-0
  Grid 11 (并列/同类): 49对数据，占用行0-0
  Grid 12 (材料-成品): 53对数据，占用行0-0
  Grid 13 (等级/排序): 10对数据，占用行0-0
  Grid 14 (组成/整体): 47对数据，占用行0-0
  Grid 15 (职业-对象): 5对数据，占用行0-0
  Grid 16 (象征/比喻): 56对数据，占用行0-0
  Grid 17 (近义/同一): 127对数据，占用行0-0
  Grid 18 (顺序/过程): 48对数据，占用行0-0


In [14]:
# 计算Wgg_list：通过伪逆学习
def create_shifted_gbook(gbook_flat, Npos):
    """
    创建向右移动一格的gbook版本
    位置(x,y)的向量移动到((x+1)%Npos, y)
    """
    
    G_next = np.zeros_like(gbook_flat)
    
    for y in range(Npos):
        for x in range(Npos):
            # 向量来自 (x, y)，去到 (x+1, y)
            new_x = (x + 1) % Npos
            old_flat = flat_idx(x, y)
            new_flat = flat_idx(new_x, y)
            G_next[:, new_flat] = gbook_flat[:, old_flat]
    
    return G_next

# 初始化Wgg_list
Wgg_list = []

# 为每个Grid计算Wgg
for grid_idx in range(N_GRIDS):
    G_current = gbooks_flattened[grid_idx]
    G_next = create_shifted_gbook(G_current, Npos)
    
    # 通过伪逆学习：Wgg = G_next @ pinv(G_current)
    Wgg = G_next @ np.linalg.pinv(G_current)
    Wgg_list.append(Wgg)
    
    print(f"✓ Grid {grid_idx} ({RELATION_TYPES[grid_idx]}) 的Wgg矩阵计算完成，shape: {Wgg.shape}")

print(f"\n✓ 所有Grid的Wgg_list计算完成，共 {len(Wgg_list)} 个矩阵")

✓ Grid 0 (Uncategorized) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 1 (主体-产物) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 2 (主体-动作) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 3 (亲属关系) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 4 (位置/空间) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 5 (包含/种属) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 6 (反义/对立) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 7 (因果/依赖) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 8 (属性/特征) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 9 (工具-功能) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 10 (师生传承) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 11 (并列/同类) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 12 (材料-成品) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 13 (等级/排序) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 14 (组成/整体) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 15 (职业-对象) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 16 (象征/比喻) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 17 (近义/同一) 的Wgg矩阵计算完成，shape: (99, 99)
✓ Grid 18 (顺序/过程) 的Wgg矩阵计算完成，shape: (99, 99)

✓ 所有Grid的Wgg_list计算完成，共 19 个矩阵


In [15]:
#test
gbook_init_transposed = np.transpose(gbook_init, (0, 2, 1))
gbook_flattened_init = gbook_init_transposed.reshape(Ng,Npos*Npos)
print(object_to_id["无业"])
id = object_to_id["无业"]

s_true = sbook[:, id]
print(s_true.shape)

x, y = object_hpc_locs[id]
idx = flat_idx(x, y)
print(idx)
p_true = pbook_flattened[:, idx]
p_input = Wps @ s_true

grid_idx = 7
gbook_current = gbooks_flattened[grid_idx]

gbook_pos = gbook_obj_positions[grid_idx][id]
g_true = gbooks_flattened[grid_idx][:, flat_idx(*gbook_pos)]
g_true2 = gbook_flattened_init[:, idx]
Niter = 1 

p = p_input
for i in range(Niter):
        g_in = Wgp_list[grid_idx] @ p

        # 使用nearest_neighbor函数，针对当前grid的gbook进行最近邻搜索
        g = nearest_neighbor(g_in, gbook_current) 

        p = nonlin(Wpg_list[grid_idx] @ g, thresh)
print("g_cleaned = ",g)
print("g_true = ",g_true)
print("g_true2 = ",g_true2)

1161
(768,)
1161
g_cleaned =  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0.]
g_true =  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0.]
g_true2 =  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0.]


## 导出VH模型状态

In [None]:
# === 导出模型状态供推理脚本使用 ===
import pickle
import os

os.makedirs('./model_state', exist_ok=True)

# 创建完整的状态字典
model_state = {
    'Wps': Wps,
    'Wsp': Wsp,
    'Wpg_list': Wpg_list,
    'Wgp_list': Wgp_list,
    'Wgg_list': Wgg_list,
    'gbooks_flat': gbooks_flattened,
    'pbook_flat': pbook_flattened,
    
    'Np': Np,
    'Ng': Ng,
    'Npos': Npos,
    'Ns': Ns,
    'thresh': thresh,
    
    'object_hpc_locs': object_hpc_locs,
    'object_to_id': object_to_id,
    'id_to_object': id_to_object,
    
    'RELATION_TYPES': RELATION_TYPES,
    'N_GRIDS': N_GRIDS,
}

# 保存为 pickle
with open('./model_state/model_state.pkl', 'wb') as f:
    pickle.dump(model_state, f)

# 定义JSON序列化转换函数
def convert_to_serializable(obj):
    """将 numpy 类型转换为 Python 原生类型"""
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

# 保存元数据为 JSON（使用自定义转换函数处理 numpy 类型）
metadata = {
    'Np': Np,
    'Ng': Ng,
    'Npos': Npos,
    'Ns': Ns,
    'thresh': thresh,
    'N_GRIDS': N_GRIDS,
    'object_hpc_locs': {str(k): v for k, v in object_hpc_locs.items()},
    'object_to_id': object_to_id,
    'RELATION_TYPES': RELATION_TYPES,
}

with open('./model_state/metadata.json', 'w', encoding='utf-8') as f:
    json.dump(metadata, f, ensure_ascii=False, indent=2, default=convert_to_serializable)

print("✓ 模型状态已导出到 ./model_state/")


✓ 模型状态已导出到 ./model_state/


In [16]:
# 测试Wgg_list：验证每个gbook的每个位置通过Wgg能正确预测下一个位置的向量
import numpy as np

grid_idx = 7
Wgg = Wgg_list[grid_idx]
gbook_flat = gbooks_flattened[grid_idx]
    
x, y = (0,0)

idx_current = flat_idx(x, y)
g_current = gbook_flat[:, idx_current]

x_next = (x + 1) % Npos
idx_next = flat_idx(x_next, y)
g_next_actual = gbook_flat[:, idx_next]
            
            # 预测下一个向量
g_next_pred = Wgg @ g_current
g = nearest_neighbor(g_next_pred, gbook_flat)


print("实际的下一个向量:", g_next_actual)
print("预测的下一个向量:", g)
print("原本的向量:", g_current)

实际的下一个向量: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]
预测的下一个向量: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]
原本的向量: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]


In [17]:
id = 1161
grid_idx = 7
x,y = gbook_obj_positions[grid_idx][id]
g_current = gbooks_flattened[grid_idx][:, flat_idx(x, y)]
print(g_current)

g_next_actual = gbooks_flattened[grid_idx][:, flat_idx((x + 1) % Npos, y)]
print(g_next_actual)

p_true1 = pbook_flattened[:, flat_idx(x, y)]
p_true2 = pbook_flattened[:, flat_idx((x + 1) % Npos, y)]
import numpy as np

# 将所有事件向量堆叠成矩阵 (C, Ns)
event_vectors = np.array([events_with_word[obj] for obj in object_list])

def nearest_neighbor2(query, prototype_matrix):
    """
    寻找与 query 最相似的原型向量（基于余弦相似度）。
    
    参数:
        query: 1D numpy 数组，形状 (dim,)
        prototype_matrix: 2D numpy 数组，形状 (dim, n_prototypes)，每一列是一个原型向量
    
    返回:
        best_prototype: 最佳匹配的原型向量 (形状 (dim,))
        best_index: 该原型在矩阵中的列索引
    """
    # 计算点积 (形状: (n_prototypes,))
    dot_products = query @ prototype_matrix
    
    # 计算 query 和每个原型向量的 L2 范数
    query_norm = np.linalg.norm(query)
    prototype_norms = np.linalg.norm(prototype_matrix, axis=0)
    
    # 处理零向量情况：query 为零向量时所有相似度为 0
    if query_norm == 0:
        similarities = np.zeros(prototype_matrix.shape[1])
    else:
        # 分母 = ||query|| * ||prototype||
        denom = query_norm * prototype_norms
        # 使用 np.divide 避免除零警告，并将分母为零的位置的相似度设为 0
        similarities = np.divide(dot_products, denom,
                                 out=np.zeros_like(dot_products),
                                 where=denom != 0)
    
    # 找到最大相似度对应的索引
    best_idx = np.argmax(similarities)
    
    # 返回最佳原型向量及其索引
    return prototype_matrix[:, best_idx]

def cosine_similarity(vec, mat):
    """计算一维向量 vec 与矩阵 mat 每一行的余弦相似度"""
    # 归一化 vec
    vec_norm = np.linalg.norm(vec)
    if vec_norm > 0:
        vec_normed = vec / vec_norm
    else:
        vec_normed = vec
    
    # 归一化 mat 的每一行
    mat_norms = np.linalg.norm(mat, axis=1, keepdims=True)
    mat_norms[mat_norms == 0] = 1  # 避免除零
    mat_normed = mat / mat_norms
    
    # 计算点积（余弦相似度）
    similarities = mat_normed @ vec_normed
    return similarities

p_out1 = nearest_neighbor2(Wpg_list[grid_idx] @ g_current , pbook_flattened)
p_out2 = nearest_neighbor2(Wpg_list[grid_idx] @ g_next_actual, pbook_flattened)

s_out1 = Wsp @ p_out1
s_out2 = Wsp @ p_out2
print("p_out1 = ", p_out1)
print("p_true1 = ", p_true1)
print("p_out2 = ", p_out2)
print("p_true2 = ", p_true2)
print("s_out1 = ", s_out1)
print("s_out2 = ", s_out2)


# 计算相似度并找出最佳匹配
sim1 = cosine_similarity(s_out1, event_vectors)
best_idx1 = np.argmax(sim1)
best_word1 = object_list[best_idx1]

sim2 = cosine_similarity(s_out2, event_vectors)
best_idx2 = np.argmax(sim2)
best_word2 = object_list[best_idx2]

print(f"s_out1 最相似的词: {best_word1} (相似度: {sim1[best_idx1]:.4f})")
print(f"s_out2 最相似的词: {best_word2} (相似度: {sim2[best_idx2]:.4f})")

[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.]
p_out1 =  [-0.         -0.         -0.         -0.         -0.         -0.
 -0.         -0.         -0.         -0.         -0.         -0.
 -0.         -0.         -0.         -0.         -0.         -0.
 -0.         -0.         -0.         -0.         -0.         -0.
 -0.         -0.         -0.         -0.         -0.         -0.
 -0.         -0.         -0.          0.00603632 -0.       