In [1]:
%run 'tree2seq_dataloader.ipynb'

100%|██████████| 1000/1000 [00:00<00:00, 33466.62it/s]
100%|██████████| 1000/1000 [00:00<00:00, 2112999.50it/s]
100%|██████████| 1000/1000 [00:00<00:00, 7106.15it/s]

71


Columns 0 to 12 
    1     0     0     0     0     0     0     0     0     0     0     0     0

Columns 13 to 25 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 26 to 38 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 39 to 51 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 52 to 64 
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 65 to 70 
    0     0     0     0     0     0
[torch.FloatTensor of size 1x71]

(* (* (+ 55 (+ 56 53)) 31) (- 5 (* 54 9)))
( * ( * ( + 55 ( + 56 53 ) ) 31 ) ( - 5 ( * 54 9 ) ) )

(* (* (+ 55 (+ 56 53)) 31) (- 5 (* 54 9)))
['55', '56', '53', '31', '5', '54', '9']
(12 (6 (4 0 (3 1 2)) 5) (11 7 (10 8 9)))
(* (* (+ 55 (+ 56 53)) 31) (- 5 (* 54 9)))





In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as Var

# module for childsumtreelstm
class ChildSumTreeLSTM(nn.Module):
    def __init__(self, in_dim, mem_dim):
        super(ChildSumTreeLSTM, self).__init__()
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.ioux = nn.Linear(self.in_dim, 3 * self.mem_dim)
        self.iouh = nn.Linear(self.mem_dim, 3 * self.mem_dim)
        self.fx = nn.Linear(self.in_dim, self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

    def node_forward(self, inputs, child_c, child_h):
        child_h_sum = torch.sum(child_h, dim=0, keepdim=True)

        iou = self.ioux(inputs) + self.iouh(child_h_sum)
        i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
        i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u)

        f = F.sigmoid(
            self.fh(child_h) +
            self.fx(inputs).repeat(len(child_h), 1)
        )
        fc = torch.mul(f, child_c)

        c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
        h = torch.mul(o, F.tanh(c))
        return c, h

    def forward(self, tree, inputs):
        if isinstance(tree, Tree):
            child_states = []
            for child in tree:
                child_states.append(self.forward(child, inputs))
                
            child_c, child_h = zip(* map(lambda x: x, child_states))
            child_c, child_h = torch.cat(child_c, dim=0), torch.cat(child_h, dim=0)
            state = self.node_forward(inputs[tree.label()], child_c, child_h)
        else:
            child_c = Var(inputs[0].data.new(1, self.mem_dim).fill_(0.))
            child_h = Var(inputs[0].data.new(1, self.mem_dim).fill_(0.))
            state = self.node_forward(inputs[tree], child_c, child_h)
        return state


In [12]:
if __name__ == '__main__':
    print(dir())
    dataset = Dataset()
    dataset.create_vocab('/data2/t2t/train.orig')
    one_hot_dict = dataset.create_one_hot()

    trees = dataset.read_trees('/data2/t2t/train.orig')
    seqs = dataset.read_seqs('/data2/t2t/train.orig')

100%|██████████| 1000/1000 [00:00<00:00, 21926.19it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1489454.55it/s]

['ChildSumTreeLSTM', 'Dataset', 'F', 'In', 'Out', 'Tree', 'Var', '_', '__', '___', '__builtin__', '__builtins__', '__doc__', '__loader__', '__name__', '__package__', '__spec__', '_dh', '_i', '_i1', '_i10', '_i11', '_i12', '_i2', '_i3', '_i4', '_i5', '_i6', '_i7', '_i8', '_i9', '_ih', '_ii', '_iii', '_oh', 'copy', 'cst', 'data', 'dataset', 'exit', 'get_ipython', 'nn', 'np', 'one_hot_dict', 'operator', 'ptr_trees', 'quit', 'seqs', 'tmp', 'torch', 'tqdm', 'trees']
71





In [13]:
cst = ChildSumTreeLSTM(dataset.vector_dim, 64)

In [16]:
cst.forward(ptr_trees[1][0], Var(ptr_trees[1][1]))

(Variable containing:
 
 Columns 0 to 9 
 -0.2279 -0.4095  0.0287  0.0377  0.1976  0.0180 -0.2470  0.0252  0.0021  0.2113
 
 Columns 10 to 19 
  0.0532 -0.1306 -0.1550 -0.0267  0.0747  0.3380  0.0241  0.0096 -0.2594 -0.1655
 
 Columns 20 to 29 
  0.0813 -0.0133 -0.1468  0.0634  0.1857 -0.2634  0.0988  0.1450 -0.1498  0.2546
 
 Columns 30 to 39 
 -0.1359 -0.4108  0.2300  0.0162 -0.1968  0.0431  0.3946 -0.1576  0.0574  0.1402
 
 Columns 40 to 49 
  0.0936  0.1136  0.2470  0.2659 -0.1649  0.0587  0.0767  0.0589 -0.1570 -0.0027
 
 Columns 50 to 59 
 -0.1235  0.1413  0.3275  0.0608  0.2070  0.4714 -0.1290 -0.3743  0.1217 -0.0848
 
 Columns 60 to 63 
  0.0487  0.3895  0.0129 -0.2850
 [torch.FloatTensor of size 1x64], Variable containing:
 
 Columns 0 to 9 
 -0.1072 -0.1857  0.0126  0.0191  0.1055  0.0087 -0.1178  0.0117  0.0011  0.1119
 
 Columns 10 to 19 
  0.0264 -0.0679 -0.0872 -0.0132  0.0349  0.1558  0.0128  0.0047 -0.1152 -0.0801
 
 Columns 20 to 29 
  0.0366 -0.0065 -0.0765  0.0323  0