-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Description
I have a graph where the number of source nodes (users) is not the same as the number of item nodes. Since multiple users can share the same item. When doing the forward pass "Many to One " edges and the mismatched number of features are causing this error.
What is the standard way to handle this Many to One edge relationship scenario. As the number of items are less..I have assigned it to 8 rows with 3 features each.
`import dgl
import dgl.nn as dglnn
import dgl.function as fn
import torch
import torch.nn as nn
num_features = 3
Define the nodes and edges
num_users = 10
num_items = 8
g = dgl.heterograph({
('user', 'connects', 'item'): (range(num_users), range(num_users, num_users + num_items))
})
Set node features for user and item nodes
g.nodes['user'].data['user_features'] = torch.randn(num_users, num_features)
g.nodes['item'].data['item_features'] = torch.randn(num_items, num_features)
print(g)
class HAN(nn.Module):
def init(self, g, in_size, hidden_size, out_size, num_heads):
super().init()
# Define the layers for each node type
self.layers = nn.ModuleDict({
'user': dglnn.SAGEConv(in_size, out_size, 'mean'),
'item': dglnn.SAGEConv(in_size, out_size, 'mean')
})
def forward(self, g, inputs):
h_dict = inputs
for ntype in g.ntypes:
layer = self.layers[ntype]
g.nodes[ntype].data['h'] = h_dict[ntype]
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_agg'))
h_dict[ntype] = layer(g.nodes[ntype], h_dict[ntype])
return h_dict
Example usage
model = HAN(g, num_features, 16, 16, 2)
output = model(g, {'user': g.nodes['user'].data, 'item': g.nodes['item'].data})
`
`---------------------------------------------------------------------------
DGLError Traceback (most recent call last)
in <cell line: 11>()
9 num_users = 10
10 num_items = 8
---> 11 g = dgl.heterograph({
12 ('user', 'connects', 'item'): (range(num_users), range(num_users, num_users + num_items))
13 })
2 frames
/usr/local/lib/python3.10/dist-packages/dgl/heterograph_index.py in create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col, formats, row_sorted, col_sorted)
1290 if isinstance(formats, str):
1291 formats = [formats]
-> 1292 return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
1293 int(num_ntypes),
1294 int(num_src),
dgl/_ffi/_cython/./function.pxi in dgl._ffi._cy3.core.FunctionBase.call()
dgl/_ffi/_cython/./function.pxi in dgl._ffi._cy3.core.FuncCall()
DGLError: [18:56:58] /opt/dgl/src/graph/unit_graph.cc:71: Check failed: src->shape[0] == dst->shape[0] (10 vs. 8) : Input arrays should have the same length.`