
Input feature mismatch when using a heterogeneous DGL graph with source nodes having multiple edges to a target node #6494

@kai5gabriel

Description


I have a graph where the number of source nodes (users) differs from the number of target nodes (items), since multiple users can share the same item. These many-to-one edges and the mismatched node counts cause the error below, which is raised while the graph is being built, before the forward pass runs.
What is the standard way to handle this many-to-one edge relationship? Because there are fewer items, I have given the item features 8 rows with 3 features each.

```python
import dgl
import dgl.nn as dglnn
import dgl.function as fn
import torch
import torch.nn as nn

num_features = 3

# Define the nodes and edges
num_users = 10
num_items = 8
g = dgl.heterograph({
    ('user', 'connects', 'item'): (range(num_users), range(num_users, num_users + num_items))
})

# Set node features for user and item nodes
g.nodes['user'].data['user_features'] = torch.randn(num_users, num_features)
g.nodes['item'].data['item_features'] = torch.randn(num_items, num_features)

print(g)

class HAN(nn.Module):
    def __init__(self, g, in_size, hidden_size, out_size, num_heads):
        super().__init__()

        # Define the layers for each node type
        self.layers = nn.ModuleDict({
            'user': dglnn.SAGEConv(in_size, out_size, 'mean'),
            'item': dglnn.SAGEConv(in_size, out_size, 'mean')
        })

    def forward(self, g, inputs):
        h_dict = inputs
        for ntype in g.ntypes:
            layer = self.layers[ntype]
            g.nodes[ntype].data['h'] = h_dict[ntype]
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_agg'))
            h_dict[ntype] = layer(g.nodes[ntype], h_dict[ntype])
        return h_dict

# Example usage
model = HAN(g, num_features, 16, 16, 2)
output = model(g, {'user': g.nodes['user'].data, 'item': g.nodes['item'].data})
```

```
---------------------------------------------------------------------------
DGLError                                  Traceback (most recent call last)
in <cell line: 11>()
      9 num_users = 10
     10 num_items = 8
---> 11 g = dgl.heterograph({
     12     ('user', 'connects', 'item'): (range(num_users), range(num_users, num_users + num_items))
     13 })

2 frames
/usr/local/lib/python3.10/dist-packages/dgl/heterograph_index.py in create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col, formats, row_sorted, col_sorted)
   1290     if isinstance(formats, str):
   1291         formats = [formats]
-> 1292     return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
   1293         int(num_ntypes),
   1294         int(num_src),

dgl/_ffi/_cython/./function.pxi in dgl._ffi._cy3.core.FunctionBase.__call__()

dgl/_ffi/_cython/./function.pxi in dgl._ffi._cy3.core.FuncCall()

DGLError: [18:56:58] /opt/dgl/src/graph/unit_graph.cc:71: Check failed: src->shape[0] == dst->shape[0] (10 vs. 8) : Input arrays should have the same length.
```
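The failing check compares the lengths of the source and destination ID arrays: DGL expects one `(src, dst)` pair per edge, so `range(10)` and `range(10, 18)` cannot be paired into edges. Separately, item IDs 10 to 17 would not line up with an 8-row item feature tensor even if the lengths matched, because IDs are numbered from 0 within each node type and, when `num_nodes_dict` is not given, DGL infers the node count from the largest ID. The plain-Python helper below (`check_edge_arrays`, a hypothetical name, not a DGL function) sketches both checks before the arrays reach `dgl.heterograph`. Treating an out-of-range ID as an error is a design choice of this sketch, stricter than DGL's node-count inference.

```python
def check_edge_arrays(src, dst, num_src, num_dst):
    """Hypothetical pre-flight check for a (src, dst) edge list; returns the edge count."""
    src, dst = list(src), list(dst)
    # Same condition as the DGL check: one source ID and one destination ID per edge.
    if len(src) != len(dst):
        raise ValueError(
            f"src has {len(src)} entries but dst has {len(dst)}; "
            "each edge needs exactly one (src, dst) pair"
        )
    # IDs are per node type and start at 0, so they must lie in [0, num_nodes).
    if src and max(src) >= num_src:
        raise ValueError(f"source ID {max(src)} is out of range for {num_src} nodes")
    if dst and max(dst) >= num_dst:
        raise ValueError(f"destination ID {max(dst)} is out of range for {num_dst} nodes")
    return len(src)


num_users, num_items = 10, 8

# The edge list from the failing snippet: 10 user IDs vs 8 item IDs (10..17).
try:
    check_edge_arrays(range(num_users), range(num_users, num_users + num_items),
                      num_users, num_items)
except ValueError as err:
    print(err)  # prints the length-mismatch message

# A many-to-one edge list (hypothetical assignment): user u -> item u % num_items.
src = list(range(num_users))
dst = [u % num_items for u in src]
print(check_edge_arrays(src, dst, num_users, num_items))  # 10
```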
