In [1]:
"""
Purpose: To Practice implementation of 
DGL Tree LSTM models

Documentation: https://docs.dgl.ai/tutorials/models/2_small_graph/3_tree-lstm.html?highlight=treelstm
Github: https://github.com/dmlc/dgl/tree/master/examples/pytorch/tree_lstm
"""

'\nPurpose: To Practice implementation of \nDGL Tree LSTM models\n'

In [4]:

from dgl.data.tree import SSTDataset

# Playing with the Dataset

In [7]:
verbose = True

In [38]:
tree_data = SSTDataset()

Dataset("sst", num_graphs=8544, save_path=/root/.dgl/sst)

In [8]:
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode='tiny')  # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes

if verbose:
    print(f"num_vocabs = {num_vocabs}")
    print(f"num_classes = {num_classes}")

num_vocabs = 19536
num_classes = 5


In [9]:
vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word

a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():
    if token != trainset.PAD_WORD:
        print(inv_vocab[token], end=" ")

the rock is destined to be the 21st century 's new `` conan '' and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal . 

In [12]:
a_tree.ndata["x"]

tensor([-1, -1,  0,  1, -1, -1,  2, -1,  3, -1, -1, -1, -1,  4, -1,  5, -1,  0,
        -1,  6, -1, -1,  7,  8, -1,  9, -1, 10, 11, 12, 13, -1, 14, -1, 15, -1,
         8, -1, 16, -1,  4, -1, -1, 17, -1, -1, 18, 19, -1, 20, 21, -1, 22, -1,
        -1, -1, -1, -1, 23, 24, 25, -1, 26, -1, 27, 28, 29, -1, 30, 31, 32])

In [28]:
tiny_sst[0].ndata

{'mask': tensor([0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1,
        0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1,
        0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1]), 'x': tensor([-1, -1,  0,  1, -1, -1,  2, -1,  3, -1, -1, -1, -1,  4, -1,  5, -1,  0,
        -1,  6, -1, -1,  7,  8, -1,  9, -1, 10, 11, 12, 13, -1, 14, -1, 15, -1,
         8, -1, 16, -1,  4, -1, -1, 17, -1, -1, 18, 19, -1, 20, 21, -1, 22, -1,
        -1, -1, -1, -1, 23, 24, 25, -1, 26, -1, 27, 28, 29, -1, 30, 31, 32]), 'y': tensor([3, 2, 2, 2, 4, 3, 2, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 3, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 4, 3, 2, 3, 3, 2, 3,
        2, 2, 3, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

# Step 0: Batching is done by just combining the disconnected trees into one tree and doing message passing on that

# Step 1: The Model

In [19]:
import torch as th
import torch.nn as nn
from collections import namedtuple
import dgl

class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data['c']
        # equation (6)
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}

import dgl.function as fn
import torch as th

class TreeLSTM(nn.Module):
    def __init__(
        self,
        #num_vocabs,
        dataset_num_node_features,
        dataset_num_classes,
        #h_size,
        n_hidden_channels=64,
        dropout=0.5,
        ):
        
        super(TreeLSTM, self).__init__()
        #self.x_size = x_size
        #self.embedding = nn.Embedding(num_vocabs, x_size)
#         if pretrained_emb is not None:
#             print('Using glove')
#             self.embedding.weight.data.copy_(pretrained_emb)
#             self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(
            n_hidden_channels, 
            dataset_num_classes)
        self.cell = TreeLSTMCell(
            dataset_num_node_features,
            n_hidden_channels)

    def encode(
        self,
        batch,
        h,
        c,
        embeddings):
        
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        # to heterogenous graph
        g = dgl.graph(g.edges())
        # feed embedding
        #embeds = self.embedding(batch.wordid * batch.mask)
        embeds = embeddings
        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds))# * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        return h
        
    def forward(
        self,
        batch,
        h,
        c,
        embeddings):
        
        h = self.encode(self,
        batch=batch,
        h=h,
        c=c,
        embeddings=embeddings)
        
        logits = self.linear(h)
        return F.softmax(logits, dim=-1)

In [None]:
num_vocabs,
x_size,
h_size,
num_classes,
dropout,

In [21]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

device = th.device('cpu')
# hyper parameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(trainset.num_vocabs,
                 x_size,
                 h_size,
                 trainset.num_classes,
                 dropout)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
                          lr=lr,
                          weight_decay=weight_decay)

def batcher(dev):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(graph=batch_trees,
                        mask=batch_trees.ndata['mask'].to(device),
                        wordid=batch_trees.ndata['x'].to(device),
                        label=batch_trees.ndata['y'].to(device))
    return batcher_dev

train_loader = DataLoader(dataset=tiny_sst,
                          batch_size=5,
                          collate_fn=batcher(device),
                          shuffle=False,
                          num_workers=0)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction='sum')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
            epoch, step, loss.item(), acc))

TreeLSTM(
  (embedding): Embedding(19536, 256)
  (dropout): Dropout(p=0.5, inplace=False)
  (linear): Linear(in_features=256, out_features=5, bias=True)
  (cell): TreeLSTMCell(
    (W_iou): Linear(in_features=256, out_features=768, bias=False)
    (U_iou): Linear(in_features=512, out_features=768, bias=False)
    (U_f): Linear(in_features=512, out_features=512, bias=True)
  )
)




Epoch 00000 | Step 00000 | Loss 443.8172 | Acc 0.1575 |
Epoch 00001 | Step 00000 | Loss 274.1146 | Acc 0.7216 |
Epoch 00002 | Step 00000 | Loss 327.7868 | Acc 0.6081 |
Epoch 00003 | Step 00000 | Loss 493.5061 | Acc 0.7839 |
Epoch 00004 | Step 00000 | Loss 427.5103 | Acc 0.6300 |
Epoch 00005 | Step 00000 | Loss 205.8990 | Acc 0.8242 |
Epoch 00006 | Step 00000 | Loss 105.0090 | Acc 0.8864 |
Epoch 00007 | Step 00000 | Loss 93.6178 | Acc 0.8938 |
Epoch 00008 | Step 00000 | Loss 62.3927 | Acc 0.9414 |
Epoch 00009 | Step 00000 | Loss 54.9089 | Acc 0.9341 |


In [57]:
batch.label.shape

torch.Size([273])

In [55]:
G

NameError: name 'G' is not defined