## Tree-LSTM networks for sentiment analysis 

We will use a binary constituency Tree-LSTM (instead of the Child-Sum variant). Tree-LSTM is a generalization of long short-term memory (LSTM) networks to tree-structured network topologies; we use DGL to build the "latent tree".

### Stanford Sentiment Treebank Dataset

The dataset provides fine-grained, tree-level sentiment annotations. There are five classes (very negative, negative, neutral, positive, and very positive), each indicating the sentiment of the current subtree. Note: non-leaf nodes in a constituency tree do not contain words.

In [23]:
import dgl
from dgl.data.tree import SST
from dgl.data import SSTBatch

import networkx as nx
import matplotlib.pyplot as pl

import torch as th
import torch.nn as nn

import dgl.function as fn
import torch as th

from torch.utils.data import DataLoader
import torch.nn.functional as F

In [10]:
# Every sample in the dataset is one constituency tree. Leaf nodes carry
# words, stored as int ids in the "x" node field; non-leaf nodes hold the
# special PAD_WORD instead. Sentiment labels live in the "y" node field.
trainset = SST(mode='tiny')  # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs, num_classes = trainset.num_vocabs, trainset.num_classes

# Vocabulary dict (key -> id) and its inverse (id -> word).
vocab = trainset.vocab
inv_vocab = dict((idx, word) for word, idx in vocab.items())

Preprocessing...
Dataset creation finished. #Trees: 5


In [11]:
# Print the sentence stored in the first tree: map each node's word id back
# to its token, skipping nodes that only hold the PAD_WORD placeholder.
a_tree = tiny_sst[0]
pad_id = trainset.PAD_WORD
sentence_ids = (wid for wid in a_tree.ndata['x'].tolist() if wid != pad_id)
for wid in sentence_ids:
    print(inv_vocab[wid], end=" ")

the rock is destined to be the 21st century 's new `` conan '' and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal . 

In [12]:
# Combine all trees of the tiny set into a single batched graph.
graph = dgl.batch(tiny_sst)

In [13]:
def plot_tree(g):
    """Draw a graph with the Graphviz ``dot`` layout and show the figure.

    Parameters
    ----------
    g : networkx graph
        The graph to draw, e.g. ``graph.to_networkx()``.

    Notes
    -----
    The layout step requires the ``pygraphviz`` package.
    """
    # The file-level import binds pyplot to ``pl``, so ``plt`` would be
    # undefined here; import it locally under the name this function uses.
    import matplotlib.pyplot as plt

    pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
    nx.draw(g, pos, with_labels=False, node_size=10,
            node_color=[[.5, .5, .5]], arrowsize=4)
    plt.show()

In [32]:
#plot_tree(graph.to_networkx())

In [17]:
class TreeLSTMCell(nn.Module):
    """Binary Tree-LSTM cell exposed as message / reduce / apply-node hooks.

    Parameters
    ----------
    x_size : int
        Size of the input vectors fed to ``W_iou``.
    h_size : int
        Size of the hidden and cell states.
    """

    def __init__(self, x_size, h_size):
        super().__init__()
        gate_size = 3 * h_size      # i, o and u pre-activations stacked
        children_size = 2 * h_size  # two child hidden states concatenated
        self.W_iou = nn.Linear(x_size, gate_size, bias=False)
        self.U_iou = nn.Linear(children_size, gate_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, gate_size))
        self.U_f = nn.Linear(children_size, children_size)

    def message_func(self, edges):
        """Forward the source node's hidden and cell state along each edge."""
        source = edges.src
        return {'h': source['h'], 'c': source['c']}

    def reduce_func(self, nodes):
        """Combine the incoming child states into ``iou`` and a partial ``c``."""
        child_h = nodes.mailbox['h']
        child_c = nodes.mailbox['c']
        # flatten the per-child hidden states; used by equations (1)-(4)
        h_cat = child_h.view(child_h.size(0), -1)
        # equation (2): forget gates, reshaped back to one gate per child
        forget = th.sigmoid(self.U_f(h_cat)).view(*child_h.size())
        # second term of equation (5)
        c_from_children = (forget * child_c).sum(1)
        return {'iou': self.U_iou(h_cat), 'c': c_from_children}

    def apply_node_func(self, nodes):
        """Turn the accumulated ``iou`` and ``c`` into the node's new ``h``/``c``."""
        # equations (1), (3), (4)
        gates = nodes.data['iou'] + self.b_iou
        in_gate, out_gate, update = gates.chunk(3, 1)
        in_gate = th.sigmoid(in_gate)
        out_gate = th.sigmoid(out_gate)
        update = th.tanh(update)
        # equation (5)
        cell_state = in_gate * update + nodes.data['c']
        # equation (6)
        hidden = out_gate * th.tanh(cell_state)
        return {'h': hidden, 'c': cell_state}

In [18]:
# Show the node frontiers produced by the topological generator, first for a
# single tree and then for the batched graph.
for title, target in (('Traversing one tree:', a_tree),
                      ('Traversing many trees at the same time:', graph)):
    print(title)
    print(dgl.topological_nodes_generator(target))

Traversing one tree:
(tensor([ 2,  3,  6,  8, 13, 15, 17, 19, 22, 23, 25, 27, 28, 29, 30, 32, 34, 36,
        38, 40, 43, 46, 47, 49, 50, 52, 58, 59, 60, 62, 64, 65, 66, 68, 69, 70]), tensor([ 1, 21, 26, 45, 48, 57, 63, 67]), tensor([24, 44, 56, 61]), tensor([20, 42, 55]), tensor([18, 54]), tensor([16, 53]), tensor([14, 51]), tensor([12, 41]), tensor([11, 39]), tensor([10, 37]), tensor([35]), tensor([33]), tensor([31]), tensor([9]), tensor([7]), tensor([5]), tensor([4]), tensor([0]))
Traversing many trees at the same time:
(tensor([  2,   3,   6,   8,  13,  15,  17,  19,  22,  23,  25,  27,  28,  29,
         30,  32,  34,  36,  38,  40,  43,  46,  47,  49,  50,  52,  58,  59,
         60,  62,  64,  65,  66,  68,  69,  70,  74,  76,  78,  79,  82,  83,
         85,  88,  90,  92,  93,  95,  96, 100, 102, 103, 105, 109, 110, 112,
        113, 117, 118, 119, 121, 125, 127, 129, 130, 132, 133, 135, 138, 140,
        141, 142, 143, 150, 152, 153, 155, 158, 159, 161, 162, 164, 168, 170,
  

In [20]:
# Toy message-passing setup: give every node a feature 'a' holding a single 1.
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
# Register DGL built-ins on the graph. Going by their names, the message
# function copies the source node's 'a' and the reducer sums the incoming 'a'
# messages back into 'a' -- confirm against the DGL built-in function docs.
graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a'))

In [21]:
# Propagate over the batched graph following the node frontiers yielded by the
# topological generator (presumably using the message/reduce functions
# registered on the graph -- confirm against the DGL docs).
traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order)

In [22]:
class TreeLSTM(nn.Module):
    """Tree-LSTM classifier: word embedding -> TreeLSTMCell -> linear layer.

    Parameters
    ----------
    num_vocabs : int
        Number of rows in the embedding table.
    x_size : int
        Embedding size.
    h_size : int
        Hidden state size of the Tree-LSTM cell.
    num_classes : int
        Number of output classes of the final linear layer.
    dropout : float
        Dropout probability, applied to the embeddings and to the final
        hidden states.
    pretrained_emb : Tensor, optional
        If given, copied into the embedding weights.
    """

    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super().__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            emb_weight = self.embedding.weight
            emb_weight.data.copy_(pretrained_emb)
            emb_weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Run the Tree-LSTM over a batch and classify every node.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; its ``graph``, ``wordid`` and ``mask`` fields
            are read here.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        graph = batch.graph
        cell = self.cell
        graph.register_message_func(cell.message_func)
        graph.register_reduce_func(cell.reduce_func)
        graph.register_apply_node_func(cell.apply_node_func)

        # Feed the embeddings. The mask multiplies both the word ids and the
        # resulting input term, so masked-out nodes contribute a zero 'iou'.
        mask = batch.mask
        word_vecs = self.embedding(batch.wordid * mask)
        input_iou = cell.W_iou(self.dropout(word_vecs))
        graph.ndata['iou'] = input_iou * mask.float().unsqueeze(-1)
        graph.ndata['h'] = h
        graph.ndata['c'] = c

        # Propagate through the graph in topological order.
        dgl.prop_nodes_topo(graph)

        # Classify each node from its final hidden state.
        final_h = graph.ndata.pop('h')
        return self.linear(self.dropout(final_h))

In [24]:
device = th.device('cpu')  # device the batch tensors are moved to by the collate function
# hyper parameters
x_size = 256  # word-embedding size
h_size = 256  # Tree-LSTM hidden / cell state size
dropout = 0.5  # dropout probability used inside the model
lr = 0.05  # learning rate passed to the optimizer
weight_decay = 1e-4  # weight decay passed to the optimizer
epochs = 10  # number of passes over the training loader

In [25]:
# create the model
# Create the model and place it on `device`, the same device the collate
# function moves every batch tensor to; without this the model would stay on
# the default device when `device` is changed away from the CPU.
model = TreeLSTM(trainset.num_vocabs,
                 x_size,
                 h_size,
                 trainset.num_classes,
                 dropout).to(device)
print(model)

TreeLSTM(
  (embedding): Embedding(19536, 256)
  (dropout): Dropout(p=0.5, inplace=False)
  (linear): Linear(in_features=256, out_features=5, bias=True)
  (cell): TreeLSTMCell(
    (W_iou): Linear(in_features=256, out_features=768, bias=False)
    (U_iou): Linear(in_features=512, out_features=768, bias=False)
    (U_f): Linear(in_features=512, out_features=512, bias=True)
  )
)


In [26]:
# create the optimizer
# Adagrad over all model parameters, using the learning rate and weight decay
# from the hyper-parameter cell.
optimizer = th.optim.Adagrad(model.parameters(),
                          lr=lr,
                          weight_decay=weight_decay)

In [28]:
def batcher(dev):
    """Build a collate function that batches trees and moves data to ``dev``.

    Parameters
    ----------
    dev : torch.device
        Device the mask, word-id and label tensors are moved to.

    Returns
    -------
    callable
        Function taking a list of trees and returning an ``SSTBatch`` with
        the batched graph and its ``mask`` / ``x`` / ``y`` node data.
    """
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        node_data = batch_trees.ndata
        # Use the `dev` argument rather than a module-level global, so the
        # device requested by the caller is the one actually honoured.
        return SSTBatch(graph=batch_trees,
                        mask=node_data['mask'].to(dev),
                        wordid=node_data['x'].to(dev),
                        label=node_data['y'].to(dev))
    return batcher_dev

In [29]:
# Loader over the tiny set, unshuffled and loaded in the main process
# (num_workers=0). With batch_size=5 and the five trees of the tiny set,
# each epoch consists of a single batch.
train_loader = DataLoader(dataset=tiny_sst,
                          batch_size=5,
                          collate_fn=batcher(device),
                          shuffle=False,
                          num_workers=0)

In [30]:
# training loop
# Training loop: one optimizer step per batch, reporting the summed loss and
# the per-node accuracy.
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        # Zero initial hidden / cell states, one row per node. Allocate them
        # on `device` so they live alongside the batch tensors instead of
        # always on the default device.
        h = th.zeros((n, h_size), device=device)
        c = th.zeros((n, h_size), device=device)
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        # Summed (not averaged) negative log-likelihood over all nodes.
        loss = F.nll_loss(logp, batch.label, reduction='sum')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
            epoch, step, loss.item(), acc))

Epoch 00000 | Step 00000 | Loss 442.1177 | Acc 0.2125 |
Epoch 00001 | Step 00000 | Loss 276.3337 | Acc 0.7436 |
Epoch 00002 | Step 00000 | Loss 1202.9537 | Acc 0.3370 |
Epoch 00003 | Step 00000 | Loss 502.2912 | Acc 0.5311 |
Epoch 00004 | Step 00000 | Loss 170.1425 | Acc 0.8352 |
Epoch 00005 | Step 00000 | Loss 232.8949 | Acc 0.7436 |
Epoch 00006 | Step 00000 | Loss 167.7142 | Acc 0.8425 |
Epoch 00007 | Step 00000 | Loss 105.0815 | Acc 0.8755 |
Epoch 00008 | Step 00000 | Loss 82.5724 | Acc 0.9377 |
Epoch 00009 | Step 00000 | Loss 62.3340 | Acc 0.9377 |


Source:
    - https://github.com/dmlc/dgl/tree/master/examples/pytorch/tree_lstm
    - https://docs.dgl.ai/tutorials/models/2_small_graph/3_tree-lstm.html#sphx-glr-tutorials-models-2-small-graph-3-tree-lstm-py