Tree-LSTM and msg/reduce/update functions #27

@jermainewang

We ran into a problem when implementing Tree-LSTM with the new reduce function. The Child-Sum Tree-LSTM is described as follows:

[image: Child-Sum Tree-LSTM equations (2)-(8)]
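For reference, since the image is not reproduced here, these are the Child-Sum Tree-LSTM equations as given by Tai et al. (2015), with the numbering that the discussion below relies on. Here C(j) denotes the children of node j, and the W, U, b terms are the per-gate parameters:

```latex
\begin{align}
\tilde{h}_j &= \sum_{k \in C(j)} h_k \tag{2} \\
i_j &= \sigma\left(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)}\right) \tag{3} \\
f_{jk} &= \sigma\left(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}\right) \tag{4} \\
o_j &= \sigma\left(W^{(o)} x_j + U^{(o)} \tilde{h}_j + b^{(o)}\right) \tag{5} \\
u_j &= \tanh\left(W^{(u)} x_j + U^{(u)} \tilde{h}_j + b^{(u)}\right) \tag{6} \\
c_j &= i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k \tag{7} \\
h_j &= o_j \odot \tanh(c_j) \tag{8}
\end{align}
```

Note that (4) is the only equation that combines the receiver input x_j with an individual child state h_k, which is what makes it awkward to place.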

@GaiYu0 has implemented this for the case where there are only msg and update functions (see the code here). However, with the new reduce function, it gets tricky. Recall that the msg/reduce/update function signatures are defined as follows:

def message_func(src, dst, edge):
   # return the message to be sent along the edge.
   ...

def reduce_func(msgs):
   # return the reduced msg
   ...

def update_func(node, msg_reduced):
   # return the new node state
   ...

Let's look at the equations. Equation (2) is a summation, so it belongs in the reduce function. Equations (3), (5), (6) rely on the node state of the receiver node (x_j) and the reduced hidden state (h_tild), so they belong in the update function. Equations (4) and (7) are problematic. Equation (4) computes a "pair-wise" transformation using the node states of both the receiver and the sender, and the result is used as the weight in the summation in equation (7). This computation pattern is similar to the GAT model but without the softmax. As a result, equations (4) and (7) should be put in the reduce function, but the reduce function has no access to the receiver state (x_j). Under the current signatures, this means the user also needs to include the dst node state (x_j) in every message. @ylfdq1118 also mentioned this problem for the GAT model before. We therefore propose to change the msg/reduce/update signatures to the following:

def message_func(src, edge):
   # return the message to be sent along the edge.
   ...

def update_edge_func(src, dst, edge):
   # return the new edge state
   ...

def reduce_func(node, msgs):
   # return the reduced msg
   ...

def update_func(node, msg_reduced):
   # return the new node state
   ...

Change#1: The message function no longer has access to the dst node state. By contrast, the update edge function still has access to both end points.
Change#2: The reduce function now has access to the receiver node state.
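To make the two changes concrete, here is a minimal pure-Python sketch of how a framework might call the proposed functions for a single receiver node. The `recv` driver, the dict-based node states, and the scalar pairwise weight are all assumptions made for illustration; they are not the library's scheduler or API, and `update_edge_func` is omitted for brevity.

```python
import math

def message_func(src, edge):
    # Change #1: the message sees only the sender state and the edge state.
    return {'h': src['h']}

def reduce_func(node, msgs):
    # Change #2: the reduce function sees the receiver state, so it can
    # weight each message by a pairwise score (a scalar stand-in for eq. (4)).
    weights = [1.0 / (1.0 + math.exp(-(node['x'] * m['h']))) for m in msgs]
    return {'acc': sum(w * m['h'] for w, m in zip(weights, msgs))}

def update_func(node, msg_reduced):
    # Keep the input x and store the reduced value as the new hidden state.
    return {'x': node['x'], 'h': msg_reduced['acc']}

def recv(nodes, in_edges, dst):
    # Hypothetical one-node driver: message -> reduce -> update.
    msgs = [message_func(nodes[src], edge) for src, edge in in_edges[dst]]
    reduced = reduce_func(nodes[dst], msgs)
    return update_func(nodes[dst], reduced)

nodes = {0: {'x': 0.0, 'h': 1.0},
         1: {'x': 0.0, 'h': 3.0},
         2: {'x': 0.0, 'h': 0.0}}
in_edges = {2: [(0, {}), (1, {})]}  # dst -> list of (src, edge state)
new_state = recv(nodes, in_edges, 2)
```

With x = 0 at the receiver, every weight is sigmoid(0) = 0.5, so the new hidden state is the average of the two sender states. The point is only that the receiver state no longer has to be copied into each message.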

This may look like a step back, since the new reduce function signature is the same as the old update function signature from before we decided to add reduce. It also makes the reduce function very powerful, because the user can now fuse the message and update functions into a single reduce function:

def message_func(src, edge):
  # Dummy message function that simply returns all the inputs.
  return {'src' : src, 'edge': edge}
def reduce_func(node, msgs):
  # The reduce function now has access to both the dst node and all the src nodes.
  ...
def update_func(node, msg_reduced):
  # Dummy update function that simply returns the reduced msg.
  return msg_reduced

We think this change is necessary, as the message/reduce/update functions are batched differently: the message function is batched over all the edges; the reduce function is batched over nodes with the same degree; the update function is batched over all the nodes. Since batching over all the edges/nodes is more efficient than the constrained batching of the reduce function, it is the user's responsibility to keep the reduce function as light as possible and put more logic in the message/update functions. The best case is when the reduce function is a built-in reducer (e.g. "sum"); the worst case is when all the logic is fused into the reduce function. For example, a good implementation of Tree-LSTM using the new signatures is as follows. You can see how the equations are split across the functions.

import torch as th

# W_f, U_f, b_f, W_iou, U_iou, b_iou are the model parameters, defined elsewhere.
def message_func(src, edge):
  return {'h' : src['h'], 'c' : src['c']}
def reduce_func(node, msgs):
  # msgs['h'] and msgs['c'] are batched as (num_nodes, degree, hidden_size)
  # equation (2)
  h_tild = th.sum(msgs['h'], 1)
  # equation (4); th.matmul is needed because msgs['h'] is 3-D
  wx = th.mm(node['x'], W_f).unsqueeze(1)
  uh = th.matmul(msgs['h'], U_f)
  f = th.sigmoid(wx + uh + b_f)
  # equation (7) second term
  c_tild = th.sum(f * msgs['c'], 1)
  return {'h_tild' : h_tild, 'c_tild' : c_tild}
def update_func(node, accum):
  # equation (3), (5), (6)
  iou = th.mm(node['x'], W_iou) + th.mm(accum['h_tild'], U_iou) + b_iou
  i, o, u = th.chunk(iou, 3, 1)
  i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
  # equation (7)
  c = i * u + accum['c_tild']
  # equation (8)
  h = o * th.tanh(c)
  return {'h' : h, 'c' : c}
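The batching constraint on the reduce function can be sketched as follows. This is a minimal pure-Python illustration of grouping receiver nodes by in-degree so that their incoming messages form a rectangular (num_nodes, degree) batch. The helper names and the `in_edges` layout are hypothetical and are not the library's actual scheduler.

```python
from collections import defaultdict

def bucket_by_in_degree(in_edges):
    # Group destination nodes by in-degree; nodes with no incoming
    # edges receive nothing to reduce and are skipped.
    buckets = defaultdict(list)
    for dst, srcs in in_edges.items():
        if srcs:
            buckets[len(srcs)].append(dst)
    return dict(buckets)

def stack_messages(in_edges, h, bucket):
    # Every node in a bucket has the same degree, so the result is a
    # rectangular nested list that could be turned into one tensor.
    return [[h[src] for src in in_edges[dst]] for dst in bucket]

# A small tree: leaves 0, 1 feed node 3; leaf 2 feeds node 4; 3 and 4 feed root 5.
in_edges = {3: [0, 1], 4: [2], 5: [3, 4], 0: [], 1: [], 2: []}
h = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0}

buckets = bucket_by_in_degree(in_edges)
degree2_batch = stack_messages(in_edges, h, buckets[2])
```

Each bucket needs its own reduce call, whereas the message and update functions can each run once over all edges or all nodes. That is why logic moved out of the reduce function tends to batch better.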

This may be the same problem that @ylfdq1118 found when implementing the GAT model.
