In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv

# Synthetic heterogeneous graph: node types A, B1, B2; relations A->B1, B1->B2.
# Seed the global torch RNG so Restart & Run All reproduces the same graph.
SEED = 42
torch.manual_seed(SEED)

# Node features: each A node has 5 random features; B1 and B2 nodes have no
# features and are represented as (num_nodes, 0) tensors.
num_nodes_A = 1000
num_nodes_B1 = 500
num_nodes_B2 = 500

x_A = torch.randn((num_nodes_A, 5))  # 5 features for each node A
x_B1 = torch.zeros((num_nodes_B1, 0))  # no features for nodes B1
x_B2 = torch.zeros((num_nodes_B2, 0))  # no features for nodes B2

# Edges: 2000 edges between A and B1, and 3000 edges between B1 and B2.
num_edges_A_B1 = 2000
num_edges_B1_B2 = 3000


def random_edge_index(num_src, num_dst, num_edges):
    """Return a random (2, num_edges) int64 edge index for a src->dst relation.

    Row 0 holds source node ids drawn from ``[0, num_src)`` and row 1 holds
    destination node ids drawn from ``[0, num_dst)``. Each row is sampled from
    its own node set, so ids stay valid when the two node types differ in size.
    """
    src = torch.randint(0, num_src, (num_edges,))
    dst = torch.randint(0, num_dst, (num_edges,))
    return torch.stack([src, dst])


# The destination row must come from the destination node type: B1 has fewer
# nodes than A, so sampling it from the A range produced out-of-range B1 ids.
edge_index_A_B1 = random_edge_index(num_nodes_A, num_nodes_B1, num_edges_A_B1)
edge_index_B1_B2 = random_edge_index(num_nodes_B1, num_nodes_B2, num_edges_B1_B2)

# Edge labels for binary classification: one 0/1 label per edge.
y_A_B1 = torch.randint(0, 2, (num_edges_A_B1,))
y_B1_B2 = torch.randint(0, 2, (num_edges_B1_B2,))

# Assemble the heterogeneous graph: one store per node type holding features
# ``x``, and one store per (src, relation, dst) edge type holding ``edge_index``
# and per-edge labels ``y``.
data = HeteroData()

# Node stores, filled in the order A, B1, B2.
data['A'].x, data['B1'].x, data['B2'].x = x_A, x_B1, x_B2

# Edge stores: connectivity first, then labels, for each relation in turn.
data[('A', 'connects', 'B1')].edge_index, data[('A', 'connects', 'B1')].y = (
    edge_index_A_B1,
    y_A_B1,
)
data[('B1', 'connects', 'B2')].edge_index, data[('B1', 'connects', 'B2')].y = (
    edge_index_B1_B2,
    y_B1_B2,
)

In [8]:
# Inspect B1's feature matrix: B1 nodes carry no features, so this is an
# empty (num_nodes_B1, 0) tensor.
x_B1

tensor([], size=(500, 0))

In [7]:
# Connectivity of the ('B1', 'connects', 'B2') relation as a 2-row index tensor.
data[('B1', 'connects', 'B2')].edge_index

tensor([[173, 436, 416,  ...,  48,  74, 306],
        [289, 254, 185,  ...,  68, 143, 399]])

In [9]:
# Per-edge binary (0/1) labels of the ('A', 'connects', 'B1') relation.
data[('A', 'connects', 'B1')].y

tensor([0, 1, 1,  ..., 0, 1, 1])

In [10]:
# NOTE(review): scratch cell unrelated to the graph pipeline; consider removing.
torch.full((10,), 1.0)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [None]:
class GNN(torch.nn.Module):
    """Two-layer graph encoder: ``GCNConv -> ReLU -> GCNConv``.

    Args:
        hidden_channels: width of the intermediate representation.
        out_channels: width of the returned node embeddings.

    Both convolutions pass ``-1`` as the input width (lazy sizing in PyG) and
    disable self-loops.

    NOTE(review): ``forward`` passes the same ``x`` and ``edge_index`` to both
    layers, i.e. it assumes one shared node index space. ``Model`` uses it on
    B2->B1 and B1->A relations; confirm a homogeneous conv is intended there.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(-1, hidden_channels, add_self_loops=False)
        self.conv2 = GCNConv(-1, out_channels, add_self_loops=False)

    def forward(self, x, edge_index):
        """Return ``conv2(relu(conv1(x, edge_index)), edge_index)``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

class Classifier(torch.nn.Module):
    """Score B1->A edges by the dot product of their endpoint embeddings.

    ``forward(x, edge_index)``:
        x: mapping with keys ``"B1"`` and ``"A"``; each value is a node
            embedding tensor indexed by node id along dim 0.
        edge_index: either the index tensor of the B1->A relation (row 0 =
            B1 node ids, row 1 = A node ids), or a mapping keyed by edge type
            that contains ``("B1", "infects", "A")``.

    Returns one score per edge: the elementwise product of the two endpoint
    embeddings summed over the last dimension.
    """

    edge_type = ("B1", "infects", "A")

    def forward(self, x, edge_index):
        # Model passes the bare B1->A tensor, while this scorer originally only
        # looked its edges up by edge type; accept both so either caller works
        # (indexing a tensor with a tuple of strings raises).
        if not isinstance(edge_index, torch.Tensor):
            edge_index = edge_index[self.edge_type]
        edge_feat_B1 = x["B1"][edge_index[0]]
        edge_feat_A = x["A"][edge_index[1]]
        return (edge_feat_B1 * edge_feat_A).sum(dim=-1)


class Model(torch.nn.Module):
    """Two-hop encoder (B2 -> B1 -> A) followed by a B1->A dot-product scorer.

    Args:
        hidden_channels: hidden width passed to both ``GNN`` encoders.
        out_channels: output width of both encoders; the classifier takes dot
            products over this dimension.

    NOTE(review): ``forward`` reads the edge types ('B2', 'expressed', 'B1')
    and ('B1', 'infects', 'A'), but the HeteroData built earlier in this
    notebook only defines ('A', 'connects', 'B1') and ('B1', 'connects', 'B2').
    Either rebuild the data with these relations or rename them here.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.gnn_B2_B1 = GNN(hidden_channels, out_channels)
        self.gnn_B1_A = GNN(hidden_channels, out_channels)
        self.classifier = Classifier()

    def forward(self, graph_data):
        """Return one score per ('B1', 'infects', 'A') edge of ``graph_data``.

        FIXME(review): each ``GNN`` call receives the features of a single node
        type together with an edge_index whose rows refer to two different node
        types (B2->B1, then B1->A). If GNN treats both rows as indices into the
        same ``x`` (as a homogeneous conv would), the outputs are indexed by the
        source type rather than by B1 / A. A ids at or beyond the number of
        rows in ``x_B1_from_B2`` would also be out of range. Confirm this, and
        consider a bipartite-capable conv fed ``(x_src, x_dst)``.
        """
        # Propagate B2 features to B1
        x_B2 = graph_data['B2'].x
        edge_index_B2_B1 = graph_data[('B2', 'expressed', 'B1')].edge_index
        x_B1_from_B2 = self.gnn_B2_B1(x_B2, edge_index_B2_B1)
        # Propagate new B1 features to A
        edge_index_B1_A = graph_data[('B1', 'infects', 'A')].edge_index
        x_A_from_B1 = self.gnn_B1_A(x_B1_from_B2, edge_index_B1_A)
        # Classification based on new features
        x = {'B1': x_B1_from_B2, 'A': x_A_from_B1}
        # The classifier looks its edges up by edge type, so pass a mapping
        # rather than the bare tensor (indexing a tensor with a tuple of
        # strings fails).
        pred = self.classifier(x, {('B1', 'infects', 'A'): edge_index_B1_A})

        return pred