In [1]:
import torch

# Graph Class Definition

In [3]:
class Graph:
    """Mutable, batched graph state shared by the model components.

    Every tensor starts empty (no nodes, no edges) and is allocated on ``device``.
    Constructing with ``batch_size=None`` returns an instance with no attributes
    set at all (presumably so a caller can populate one manually — TODO confirm).
    """

    def __init__(self, batch_size, embed_size, device):
        if batch_size is not None:
            def zeros(*shape, dtype):
                # All storage is zero-initialised on the target device.
                return torch.zeros(*shape, dtype=dtype, device=device)

            self.batch_size = batch_size
            self.device = device

            # Node table: one float32 row of width ``embed_size`` per node.
            self.nodes = zeros(0, embed_size, dtype=torch.float32)
            # One type code per node (the original notes say internal vs. leaf).
            self.node_types = zeros(0, dtype=torch.uint8)
            # Edge list: edge i runs from node edge_source[i] to node edge_dest[i].
            self.edge_source = zeros(0, dtype=torch.long)
            self.edge_dest = zeros(0, dtype=torch.long)

            # FIXME: no edge features yet.
            # (batch_size, num_nodes) mask relating batch items to nodes; exact
            # semantics still to be settled by the author.
            self.owner_masks = zeros(batch_size, 0, dtype=torch.uint8)
            # Presumably the index of the most recently inserted node per batch item.
            self.last_inserted_node = zeros(batch_size, dtype=torch.long)
            # Per-batch-item flag, all 1 initially: which graphs are still running.
            self.running = torch.ones(batch_size, dtype=torch.uint8, device=device)


# Model Definition

## Propagator:
The message-passing part of the model: it updates each node vector using messages exchanged along the edges of the current graph.

In [4]:
class Propagator(torch.nn.Module):
    def __init__(self, embed_size, dropout):
        """Build the message and node-update layers.

        Args:
            embed_size: width of each node embedding (one row of ``graph.nodes``).
            dropout: drop probability handed to ``torch.nn.Dropout``.
        """
        super().__init__()
        # Messages live in a space twice as wide as the node embeddings.
        msg_dim = 2 * embed_size
        self.message_size = msg_dim

        # Keep this registration order: it fixes parameter ordering and the
        # sequence of random draws used for default initialisation.
        # GRU cell folding an aggregated message (msg_dim) into a node state (embed_size).
        self.node_update_fn = torch.nn.GRUCell(input_size=msg_dim, hidden_size=embed_size)
        # Bias-free projection of node embeddings into message space.
        self.message_node = torch.nn.Linear(in_features=embed_size, out_features=msg_dim, bias=False)
        # Nonlinearity followed by an affine map, applied to each per-edge message.
        self.message_layer = torch.nn.Sequential(
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=msg_dim, out_features=msg_dim),
        )
        self.dropout = torch.nn.Dropout(p=dropout)

        # NOTE(review): no `_reset_parameters` is visible in this class (the only
        # visible candidate is an unfinished `reset_param`); unless it is defined
        # elsewhere, this call raises AttributeError.
        self._reset_parameters(embed_size)

    @staticmethod
    def _node_update_mask(graph: Graph, mask_override: torch.ByteTensor):
        """Return a bool vector, one entry per node, marking nodes that may be updated.

        Rows of ``graph.owner_masks`` (one per batch item) are selected by
        ``graph.running``, or by ``mask_override`` when it is given. A node is
        selected when any selected row has a non-zero entry in its column.

        Args:
            graph: graph providing ``owner_masks`` and ``running``.
            mask_override: optional row selector used instead of ``graph.running``.

        Returns:
            Bool tensor of shape ``(num_nodes,)``.
        """
        selector = graph.running if mask_override is None else mask_override
        # Indexing with a torch.uint8 mask is deprecated (and warns on every call);
        # the equivalent bool mask selects exactly the same rows.
        if isinstance(selector, torch.Tensor) and selector.dtype == torch.uint8:
            selector = selector.bool()
        return graph.owner_masks[selector].sum(0) > 0

    def forward(self, graph: Graph, mask_override: torch.ByteTensor = None):
        """Run one round of message passing and update ``graph.nodes`` in place.

        For every edge, the projected features of its two endpoints are summed
        and passed through ``message_layer`` and dropout. Each resulting message
        is then accumulated onto BOTH endpoints, so edges act as undirected for
        aggregation. The per-node aggregate (after a second dropout) drives a
        GRUCell update of the node embeddings. Only nodes selected by
        ``_node_update_mask`` take the new value; the rest keep their old one.

        Args:
            graph: graph to update; its ``nodes`` attribute is replaced.
            mask_override: optional row selector into ``graph.owner_masks`` used
                instead of ``graph.running`` to pick which nodes are updated.

        Returns:
            The same ``graph`` object.
        """
        num_nodes = graph.nodes.shape[0]
        # Nothing to propagate without at least one node and one edge.
        if num_nodes == 0 or graph.edge_source.numel() == 0:
            return graph

        # Project every node embedding into message space.
        node_features = self.message_node(graph.nodes)
        # Gather the endpoint features of every edge.
        src_features = node_features.index_select(dim=0, index=graph.edge_source)
        dst_features = node_features.index_select(dim=0, index=graph.edge_dest)
        messages = self.dropout(self.message_layer(src_features + dst_features))

        # Sum the messages arriving at each node, counting both edge endpoints.
        inputs = torch.zeros(num_nodes, self.message_size, device=graph.nodes.device,
                             dtype=graph.nodes.dtype)
        inputs.index_add_(0, graph.edge_dest, messages)
        inputs.index_add_(0, graph.edge_source, messages)
        inputs = self.dropout(inputs)

        # GRU update: aggregated messages are the input, current embeddings the state.
        updated_nodes = self.node_update_fn(inputs, graph.nodes)
        # Only masked nodes are updated; unmasked ones keep their original embedding.
        update_mask = self._node_update_mask(graph, mask_override).unsqueeze(-1)
        graph.nodes = torch.where(update_mask, updated_nodes, graph.nodes)

        return graph

    def reset_param()