In [12]:
from torch_geometric.nn import HANConv
from torch_geometric.data import Data
import torch

# 定义输入维度和输出维度
in_channels = {'paper': 128, 'author': 64}
out_channels = 32

# 定义元路径
metadatas = [
    ['author', 'paper'],
    [('paper', 'cites', 'paper'),('paper', 'written-by', 'author'), ('author', 'written-by', 'paper')]
]

# 初始化 HANConv
conv = HANConv(in_channels, out_channels, metadatas)

# 示例数据
x_dict = {'paper': torch.randn(4, 128), 'author': torch.randn(2, 64)}
edge_index_dict = {
    ('paper', 'written-by', 'author'): torch.tensor([[0, 3], [0, 1]], dtype=torch.long),
    ('author', 'written-by', 'paper'): torch.tensor([[0, 1], [0, 1]], dtype=torch.long),
    ('paper', 'cites', 'paper'): torch.tensor([[0, 1], [1, 2]], dtype=torch.long),
}

# 数据对象
data = Data(x_dict=x_dict, edge_index_dict=edge_index_dict)

# 前向传播
out = conv(x_dict, edge_index_dict)

# 输出结果
print(out)


{'paper': tensor([[0.0082, 0.0461, 0.1396, 0.0509, 0.1497, 0.0000, 0.0000, 0.0592, 0.0000,
         0.0000, 0.0000, 0.2453, 0.0871, 0.0000, 0.2320, 0.2721, 0.0000, 0.0000,
         0.0000, 0.0000, 0.2303, 0.0000, 0.3004, 0.0000, 0.0000, 0.0000, 0.2314,
         0.0000, 0.0000, 0.3704, 0.0080, 0.0000],
        [0.0000, 0.0000, 0.4850, 0.1628, 0.0493, 0.2898, 0.0000, 0.0000, 0.2588,
         0.0678, 0.3150, 0.4700, 0.0000, 0.2587, 0.2546, 0.3685, 0.0936, 0.0643,
         0.8218, 0.3512, 0.0654, 0.1216, 0.3058, 0.0000, 0.3133, 0.0973, 0.0000,
         0.0000, 1.0420, 0.8384, 0.2642, 0.0000],
        [0.0000, 0.1478, 0.0000, 0.0245, 0.0000, 0.0000, 0.0000, 0.0000, 0.1340,
         0.0000, 0.0000, 0.0000, 0.0416, 0.1043, 0.0000, 0.0077, 0.2696, 0.0000,
         0.3268, 0.2465, 0.5880, 0.1105, 0.0000, 0.2060, 0.0000, 0.2005, 0.0000,
         0.4250, 0.3597, 0.6593, 0.0000, 0.0410],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.000

In [9]:
out['paper'].shape

torch.Size([4, 32])

In [10]:
out['author'].shape

torch.Size([2, 32])