In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class treeEncoder(nn.Module):
    """Child-Sum Tree-LSTM encoder for trees whose children are stored in a mapping.

    Each node is expected to expose ``uid`` (used to index ``userVects``) and
    ``children`` (a mapping whose *values* are child nodes).  :meth:`forward`
    sets ``node.state = (c, h)`` on every node it visits.

    :param inp_dim: size of a node's input vector (last dim of ``userVects[uid]``)
    :param hid_dim: size of the memory cell / hidden state
    :param cuda: if True, child-state buffers are moved to ``device``
    :param device: device used when ``cuda`` is True (defaults to ``'cuda:1'``,
        the value that used to be hard-coded)
    """

    def __init__(self, inp_dim, hid_dim, cuda, device='cuda:1'):
        super(treeEncoder, self).__init__()
        self.inp_dim = inp_dim
        self.hid_dim = hid_dim
        self.cudaAvailable = cuda
        self.device = device

        # input gate
        self.ix = nn.Linear(self.inp_dim, self.hid_dim)
        self.ih = nn.Linear(self.hid_dim, self.hid_dim)

        # candidate update
        self.ux = nn.Linear(self.inp_dim, self.hid_dim)
        self.uh = nn.Linear(self.hid_dim, self.hid_dim)

        # output gate
        self.ox = nn.Linear(self.inp_dim, self.hid_dim)
        self.oh = nn.Linear(self.hid_dim, self.hid_dim)

        # per-child forget gate
        self.fx = nn.Linear(self.inp_dim, self.hid_dim)
        self.fh = nn.Linear(self.hid_dim, self.hid_dim)

    def getParameters(self):
        """Return all gate weights and biases flattened into one 1-D tensor."""
        modules = [self.ix, self.ih, self.fx, self.fh, self.ox, self.oh, self.ux, self.uh]
        return torch.cat([p.view(p.numel()) for m in modules for p in m.parameters()])

    def get_child_states(self, node):
        """Collect the ``(c, h)`` states of ``node``'s children.

        :return: ``(child_c, child_h)``, each of shape (num_children, 1, hid_dim);
            a single all-zero row when the node is a leaf.
        """
        children = list(node.children.values())
        if not children:
            child_c = torch.zeros(1, 1, self.hid_dim)
            child_h = torch.zeros(1, 1, self.hid_dim)
        else:
            # reshape tolerates child states stored as (hid,), (1, hid) or (1, 1, hid)
            child_c = torch.stack([child.state[0].reshape(1, self.hid_dim) for child in children])
            child_h = torch.stack([child.state[1].reshape(1, self.hid_dim) for child in children])

        if self.cudaAvailable:
            child_c, child_h = child_c.to(self.device), child_h.to(self.device)
        return child_c, child_h

    def node_forward(self, inputs, child_c, child_h):
        """One Child-Sum Tree-LSTM cell step.

        :param inputs: node input; the shape arithmetic below assumes
            (1, inp_dim) — TODO confirm against callers
        :param child_c: (num_children, 1, hid_dim) child memory cells
        :param child_h: (num_children, 1, hid_dim) child hidden states
        :return: ``(c, h)`` for this node
        """
        child_h_sum = torch.sum(torch.squeeze(child_h, 1), 0)

        i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))

        # one forget gate per child; the extra singleton dim lines fx up with child rows
        fx = torch.unsqueeze(self.fx(inputs), 1)
        f = torch.sigmoid(torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0))
        fc = torch.squeeze(f * child_c, 1)

        c = i * u + torch.sum(fc, 0)
        h = o * torch.tanh(c)
        # Return the freshly computed state (this used to return child_c, child_h).
        return c, h

    def forward(self, node, userVects):
        """Encode the subtree rooted at ``node`` bottom-up (children first).

        :param node: root of the subtree
        :param userVects: per-node inputs indexed by ``node.uid``
        :return: the root's ``(c, h)``; every visited node also gets ``.state``
        """
        for child in node.children.values():
            self.forward(child, userVects)

        child_c, child_h = self.get_child_states(node)
        node.state = self.node_forward(userVects[node.uid], child_c, child_h)
        return node.state

## Direct ChildSumTreeLSTM Code

In [3]:
class OutputModule(nn.Module):
    """Linear classifier head followed by log-softmax over the class (last) axis.

    :param cuda: if True, the linear layer is moved to ``device``
    :param mem_dim: size of the input feature vector (last dimension of ``vec``)
    :param num_classes: number of output classes
    :param device: target device used when ``cuda`` is True
    :param dropout: if True, dropout is applied to ``vec`` (active only when
        ``forward`` is called with ``training=True``)
    """

    def __init__(self, cuda, mem_dim, num_classes, device, dropout = False):
        super(OutputModule, self).__init__()
        self.cudaFlag = cuda
        self.mem_dim = mem_dim
        self.num_classes = num_classes
        self.dropout = dropout
        self.device = device

        self.l1 = nn.Linear(self.mem_dim, self.num_classes)
        # Normalise explicitly over the class axis.  The implicit-dim form is
        # deprecated and would pick dim 0 for 1-D and 3-D inputs.
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        if self.cudaFlag:
            self.l1 = self.l1.to(device)

    def forward(self, vec, training = False):
        """Return log-probabilities over ``num_classes`` for ``vec``.

        :param vec: features whose last dimension is ``mem_dim``
        :param training: passed to ``F.dropout`` when dropout is enabled
        :return: tensor shaped like ``vec`` with the last dim replaced by ``num_classes``
        """
        if self.dropout:
            vec = F.dropout(vec, training = training)
        return self.logsoftmax(self.l1(vec))

In [4]:
class fixedTreeEncoder(nn.Module):
    """Child-Sum Tree-LSTM encoder with a per-node classification head and loss.

    Nodes are expected to expose ``uid`` (index into ``userVects``),
    ``num_children`` and ``childrenList``; :meth:`forward` also reads
    ``node.label`` (looked up in ``labelMap``).  Visiting a node sets
    ``node.state = (c, h)``; :meth:`forward` additionally sets ``node.output``.

    :param cuda: if True, internally created tensors are moved to ``device``
    :param in_dim: size of each node's input vector
    :param mem_dim: size of the memory cell / hidden state
    :param userVects: per-node input vectors, indexed by ``node.uid``
    :param labels: stored as ``self.labels`` (not read by this class)
    :param labelMap: maps ``node.label`` to a class index
    :param criterion: loss, called as ``criterion(log_probs, target)``
    :param device: target device when ``cuda`` is True
    :param num_classes: number of classes of the output head (default 4, the
        value that used to be hard-coded)
    """

    def __init__(self, cuda, in_dim, mem_dim, userVects, labels, labelMap, criterion, device, num_classes=4):
        super(fixedTreeEncoder, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.device = device
        self.labels = labels
        self.labelMap = labelMap
        self.criterion = criterion
        self.num_classes = num_classes

        # input gate
        self.ix = nn.Linear(self.in_dim, self.mem_dim)
        self.ih = nn.Linear(self.mem_dim, self.mem_dim)

        # per-child forget gate
        self.fx = nn.Linear(self.in_dim, self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

        # candidate update
        self.ux = nn.Linear(self.in_dim, self.mem_dim)
        self.uh = nn.Linear(self.mem_dim, self.mem_dim)

        # output gate
        self.ox = nn.Linear(self.in_dim, self.mem_dim)
        self.oh = nn.Linear(self.mem_dim, self.mem_dim)

        self.userVects = userVects
        self.outputModule = OutputModule(self.cudaFlag, mem_dim, self.num_classes, self.device, dropout=False)

    def _zeros(self, *shape):
        """Zero tensor of ``shape``, placed on ``self.device`` when CUDA is enabled."""
        t = torch.zeros(*shape)
        return t.to(self.device) if self.cudaFlag else t

    def _encode_node(self, node):
        """Compute and store ``node.state`` from its input and its children's states."""
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.userVects[node.uid], child_c, child_h)
        return node.state

    def predict(self, node):
        """Encode the subtree rooted at ``node`` and return the root's log-probabilities.

        Children are visited through :meth:`forward` so their states (and
        outputs) are populated; their losses are not needed here and are dropped.
        """
        for idx in range(node.num_children):
            self.forward(node.childrenList[idx])
        state = self._encode_node(node)
        return self.outputModule.forward(state[1], False)

    def forward(self, node):
        """Encode the subtree rooted at ``node`` and accumulate the classification loss.

        :return: ``(node.state, loss)`` where ``loss`` sums this node's loss and
            all descendants' losses
        """
        loss = self._zeros(1)
        for idx in range(node.num_children):
            _, child_loss = self.forward(node.childrenList[idx])
            loss = loss + child_loss

        state = self._encode_node(node)

        output = self.outputModule.forward(state[1], True)
        node.output = output

        label = torch.tensor(self.labelMap[node.label])
        if self.cudaFlag:
            label = label.to(self.device)

        loss = loss + self.criterion(output.reshape(-1, self.num_classes), label.reshape(-1))
        return node.state, loss

    def nodeForward(self, x, child_c, child_h):
        """One Child-Sum Tree-LSTM cell step.

        Works for ``x`` shaped ``(in_dim,)`` or ``(1, in_dim)``.

        :param x: node input vector
        :param child_c: (num_children, mem_dim) child memory cells
        :param child_h: (num_children, mem_dim) child hidden states
        :return: ``(c, h)``, each shaped like ``self.ix(x)``
        """
        # h~ = sum of the children's hidden states
        child_h_sum = torch.sum(child_h, 0)

        i = torch.sigmoid(self.ix(x) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(x) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(x) + self.uh(child_h_sum))

        # One forget gate per child.  fh acts row-wise on child_h and the input
        # term broadcasts across the children axis -> (num_children, mem_dim).
        f = torch.sigmoid(self.fh(child_h) + self.fx(x))
        # Gate each child's memory before summing.  Previously the gates
        # themselves were summed and child_c was ignored.
        fc = f * child_c

        c = i * u + torch.sum(fc, 0)
        h = o * torch.tanh(c)
        return c, h

    def getChildStates(self, node):
        """Collect children's states as ``(child_c, child_h)``, each (num_children, mem_dim).

        A leaf gets a single all-zero row for both.
        """
        if node.num_children == 0:
            return self._zeros(1, self.mem_dim), self._zeros(1, self.mem_dim)

        children = [node.childrenList[idx] for idx in range(node.num_children)]
        # reshape tolerates child states stored as (mem_dim,) or (1, mem_dim)
        child_c = torch.stack([child.state[0].reshape(self.mem_dim) for child in children])
        child_h = torch.stack([child.state[1].reshape(self.mem_dim) for child in children])
        if self.cudaFlag:
            child_c, child_h = child_c.to(self.device), child_h.to(self.device)
        return child_c, child_h

In [5]:
class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM encoder (no output/classification module).

    Nodes are expected to expose ``uid`` (index into ``embs``), ``num_children``
    and ``childrenList``.  :meth:`forward` sets ``state = (c, h)`` on every node.
    """

    def __init__(self, cuda, in_dim, mem_dim):
        super(ChildSumTreeLSTM, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim

        # input gate
        self.ix = nn.Linear(self.in_dim, self.mem_dim)
        self.ih = nn.Linear(self.mem_dim, self.mem_dim)

        # per-child forget gate
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)
        self.fx = nn.Linear(self.in_dim, self.mem_dim)

        # candidate update
        self.ux = nn.Linear(self.in_dim, self.mem_dim)
        self.uh = nn.Linear(self.mem_dim, self.mem_dim)

        # output gate
        self.ox = nn.Linear(self.in_dim, self.mem_dim)
        self.oh = nn.Linear(self.mem_dim, self.mem_dim)

    def getParameters(self):
        """
        Get the gate parameters flattened into one tensor.
        Note that this differs from ``parameters()``: only the eight gate
        Linear layers are included.
        :return: 1d tensor
        """
        modules = [self.ix, self.ih, self.fx, self.fh, self.ox, self.oh, self.ux, self.uh]
        return torch.cat([p.view(p.numel()) for m in modules for p in m.parameters()])

    def node_forward(self, inputs, child_c, child_h):
        """
        One Child-Sum Tree-LSTM cell step.
        :param inputs: (1, in_dim)
        :param child_c: (num_children, 1, mem_dim)
        :param child_h: (num_children, 1, mem_dim)
        :return: (tuple)
        c: (1, mem_dim)
        h: (1, mem_dim)
        """
        child_h_sum = torch.sum(torch.squeeze(child_h, 1), 0)

        i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))

        # one forget gate per child; the extra singleton dim lines fx up with child rows
        fx = torch.unsqueeze(self.fx(inputs), 1)
        f = torch.sigmoid(torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0))
        fc = torch.squeeze(f * child_c, 1)

        c = i * u + torch.sum(fc, 0)
        h = o * torch.tanh(c)
        return c, h

    def forward(self, tree, embs, training = False):
        """
        Child-Sum Tree-LSTM forward pass over the subtree rooted at ``tree``.
        :param tree: root node
        :param embs: per-node inputs indexed by ``tree.uid``; each entry is
            expected to be (1, in_dim), e.g. embs of shape (num_nodes, 1, in_dim)
        :param training: passed down to the recursive calls
        :return: ((c, h) of the root, loss). The loss is always zero because this
            module has no output layer; it is kept for interface compatibility.
        """
        loss = torch.zeros(1)  # init zero loss
        if self.cudaFlag:
            loss = loss.cuda()

        for idx in range(tree.num_children):
            _, child_loss = self.forward(tree.childrenList[idx], embs, training)
            loss = loss + child_loss

        child_c, child_h = self.get_child_states(tree)
        # The node input comes from `embs`.  This used to read an undefined global `userVects`.
        tree.state = self.node_forward(embs[tree.uid], child_c, child_h)
        return tree.state, loss

    def get_child_states(self, tree):
        """
        Get c and h of all children.
        :param tree: node whose children have already been visited
        :return: (tuple)
        child_c: (num_children, 1, mem_dim)
        child_h: (num_children, 1, mem_dim)
        A leaf yields a single all-zero row for both.
        """
        # the extra singleton middle dimension mirrors the (1, mem_dim) node states
        if tree.num_children == 0:
            child_c = torch.zeros(1, 1, self.mem_dim)
            child_h = torch.zeros(1, 1, self.mem_dim)
        else:
            children = [tree.childrenList[idx] for idx in range(tree.num_children)]
            child_c = torch.stack([child.state[0].reshape(1, self.mem_dim) for child in children])
            child_h = torch.stack([child.state[1].reshape(1, self.mem_dim) for child in children])
        if self.cudaFlag:
            child_c, child_h = child_c.cuda(), child_h.cuda()
        return child_c, child_h