In [1]:
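"""
Layer weight sharing

Weight sharing means reusing a single layer instance at several points in a network, so that every point uses the same weights.

Use cases:
1. Siamese networks: comparing how similar two inputs are
2. Fewer parameters: a shared layer's weights are stored once, however many branches call it
3. Processing paired inputs: e.g. question answering, image comparison, signature verification

Key points:
- The same Layer object is called multiple times
- All calls use the same weight tensors
- Gradients from every branch that calls the layer contribute to the update of those shared weights

This example: a model that judges whether a question and an answer are similar
"""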

from keras import layers
from keras.layers import Input
from keras.models import Model
import numpy as np

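# Seed NumPy's RNG so the synthetic data below is reproducible
# (this does not seed Keras weight initialization)
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)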

# Instantiate one LSTM layer (this instance will be shared)
shared_lstm = layers.LSTM(32)

# Left branch: processes the first input sequence
left_input = Input(shape=(None, 128), name='left_input')
left_output = shared_lstm(left_input)

# Right branch: processes the second input sequence, reusing the same LSTM layer
right_input = Input(shape=(None, 128), name='right_input')
right_output = shared_lstm(right_input)

# Concatenate the outputs of the two branches
merged = layers.concatenate([left_output, right_output], axis=-1)

# Classification head: predicts whether the two sequences are similar
predictions = layers.Dense(1, activation='sigmoid')(merged)

# Build a model with two inputs and one output
model = Model([left_input, right_input], predictions)
model.compile(optimizer='rmsprop', 
              loss='binary_crossentropy', 
              metrics=['acc'])

model.summary()

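# Generate synthetic data to check that the model builds and trains
num_samples = 1000
timesteps = 10

# Simulated left and right input sequences
left_data = np.random.randn(num_samples, timesteps, 128).astype('float32')
right_data = np.random.randn(num_samples, timesteps, 128).astype('float32')

# Simulated binary labels (0 = not similar, 1 = similar). The labels are random,
# so there is no real pattern to learn and validation accuracy should stay near 0.5
targets = np.random.randint(0, 2, size=(num_samples, 1)).astype('float32')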

# Train the model
print("\nStarting training...")
history = model.fit([left_data, right_data], targets, 
                    epochs=3, 
                    batch_size=64, 
                    validation_split=0.2)

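# Check the effect of weight sharing on the parameter count
print(f"\nTotal model parameters: {model.count_params()}")
print(f"LSTM parameters: {shared_lstm.count_params()}")
print("Note: because of weight sharing, both branches use a single set of LSTM parameters")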

Starting training...
Epoch 1/3
 1/13 ━━━━━━━━━━━━━━━━━━━━ 9s 755ms/step - acc: 0.4688 - loss: 0.7470
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - acc: 0.4812 - loss: 0.7120 - val_acc: 0.4900 - val_loss: 0.6987

Epoch 2/3
 1/13 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - acc: 0.5156 - loss: 0.6902
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - acc: 0.5838 - loss: 0.6661 - val_acc: 0.4850 - val_loss: 0.6990

Epoch 3/3
 1/13 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - acc: 0.6562 - loss: 0.6385
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - acc: 0.6725 - loss: 0.6355 - val_acc: 0.5150 - val_loss: 0.6999

Total model parameters: 20673
LSTM parameters: 20608
Note: because of weight sharing, both branches use a single set of LSTM parameters
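The printed counts can be checked by hand. An LSTM layer has four gates, each with an input kernel, a recurrent kernel and a bias, so it holds 4 × (units × (input_dim + units) + units) parameters (assuming the default `use_bias=True`). The sketch below recomputes the numbers above and, for comparison, what two independent LSTM layers would cost. The helper names `lstm_params` and `dense_params` are hypothetical, introduced only for this illustration.

```python
# Hypothetical helpers for illustration; they only do the arithmetic,
# they do not call Keras.
def lstm_params(input_dim: int, units: int) -> int:
    # Four gates (input, forget, cell, output), each with an input kernel,
    # a recurrent kernel and a bias vector
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim: int, units: int) -> int:
    # Weight matrix plus bias vector
    return input_dim * units + units


lstm = lstm_params(128, 32)
head = dense_params(2 * 32, 1)        # the concatenation yields 64 features
shared_total = lstm + head            # one LSTM instance, called twice
unshared_total = 2 * lstm + head      # two independent LSTM instances

print(lstm, head, shared_total, unshared_total)  # → 20608 65 20673 41281
```

The shared total matches `model.count_params()` above; with two separate LSTM layers the model would need 41281 parameters, roughly twice as many.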

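The claim that gradients from every branch flow into the shared weights can be illustrated without Keras. This NumPy sketch assumes, purely for brevity, that a single linear map `W` stands in for the LSTM. It applies `W` to both inputs, concatenates the results and feeds them to a linear head `v`. The gradient with respect to `W` is the sum of the two per-branch contributions, which a central finite-difference check confirms. In Keras, automatic differentiation performs this summation for you.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, units = 4, 3

# Assumption: a linear map W stands in for the shared LSTM
W = rng.normal(size=(units, in_dim))   # one shared weight matrix
v = rng.normal(size=2 * units)         # head weights on the concatenated features
x_left = rng.normal(size=in_dim)
x_right = rng.normal(size=in_dim)


def loss(W):
    # Both branches use the same W, then their outputs are concatenated
    h = np.concatenate([W @ x_left, W @ x_right])
    return float(v @ h)


# Analytic gradient: each branch contributes its own outer product
v_left, v_right = v[:units], v[units:]
grad_left = np.outer(v_left, x_left)
grad_right = np.outer(v_right, x_right)
grad_shared = grad_left + grad_right

# Central finite-difference estimate of dL/dW
eps = 1e-6
grad_numeric = np.zeros_like(W)
for i in range(units):
    for j in range(in_dim):
        E = np.zeros_like(W)
        E[i, j] = eps
        grad_numeric[i, j] = (loss(W + E) - loss(W - E)) / (2 * eps)

print(np.allclose(grad_shared, grad_numeric, atol=1e-6))  # → True
```

Using only `grad_left` or only `grad_right` would not match the numeric gradient: a shared layer is updated by the combined signal from all branches that call it.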