Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 11, 2019
1 parent d2b816b commit 6ca3dfc
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions torch_struct/networks/TreeLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch.nn as nn
import dgl


class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
super(TreeLSTMCell, self).__init__()
Expand All @@ -14,27 +13,21 @@ def __init__(self, x_size, h_size):
self.U_f = nn.Linear(2 * h_size, 2 * h_size)

def message_func(self, edges):
return {"h": edges.src["h"], "c": edges.src["c"]}
return {'h': edges.src['h'], 'c': edges.src['c']}

def reduce_func(self, nodes):
# concatenate h_jl for equation (1), (2), (3), (4)
h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
# equation (2)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
# second term of equation (5)
c = th.sum(f * nodes.mailbox["c"], 1)
return {"iou": self.U_iou(h_cat), "c": c}
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_cat), 'c': c}

def apply_node_func(self, nodes):
# equation (1), (3), (4)
iou = nodes.data["iou"] + self.b_iou
iou = nodes.data['iou'] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (5)
c = i * u + nodes.data["c"]
# equation (6)
c = i * u + nodes.data['c']
h = o * th.tanh(c)
return {"h": h, "c": c}
return {'h' : h, 'c' : c}


def run(cell, graph, iou, h, c):
Expand Down

0 comments on commit 6ca3dfc

Please sign in to comment.