In [3]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim
import torch.nn.functional as F

In [None]:
from tqdm import tqdm

## Direct ChildSumTreeLSTM Code

In [3]:
class OutputModule(nn.Module):
    """Linear classification head that returns log-probabilities.

    Args:
        cuda: flag stored as ``self.cudaFlag``; not used by this module itself.
        mem_dim: size of the incoming feature vector (last axis of ``vec``).
        num_classes: number of output classes.
        device: stored as ``self.device``; not used by this module itself.
        dropout: when truthy, dropout is applied to the input before the
            linear layer (only active if ``training`` is true in ``forward``).
    """
    def __init__(self, cuda, mem_dim, num_classes, device, dropout = False):
        super(OutputModule, self).__init__()
        self.cudaFlag = cuda
        self.mem_dim = mem_dim
        self.num_classes = num_classes
        self.dropout = dropout
        self.device = device

        self.l1 = nn.Linear(self.mem_dim, self.num_classes)
        # Normalise over the class axis explicitly. Without `dim`, PyTorch
        # falls back to a deprecated implicit rule (with a warning) that picks
        # axis 0 for 3-D inputs, i.e. it would normalise across the batch.
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, vec, training = False):
        """Return log-softmax class scores for ``vec``.

        Args:
            vec: tensor whose last axis has size ``mem_dim``.
            training: forwarded to ``F.dropout``; ignored when the module was
                built with ``dropout=False``.

        Returns:
            Tensor shaped like ``vec`` with the last axis replaced by
            ``num_classes`` log-probabilities.
        """
        if self.dropout:
            vec = F.dropout(vec, training = training)
        return self.logsoftmax(self.l1(vec))

In [53]:
class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM cell applied recursively over a tree.

    Only ``in_dim`` and ``mem_dim`` are used by the constructor; the remaining
    arguments are accepted and ignored. A 4-way linear layer (``self.fc``) is
    applied to the hidden state of the node passed to ``forward``.
    """
    def __init__(self,cuda, in_dim, mem_dim,userVects,labels,labelMap,criterion,device):
        super().__init__()
        self.in_dim, self.mem_dim = in_dim, mem_dim
        gate_width = 3 * self.mem_dim
        # input / output / update gates are computed jointly and split later
        self.ioux = nn.Linear(self.in_dim, gate_width)
        self.iouh = nn.Linear(self.mem_dim, gate_width)
        # forget gate: one per child
        self.fx = nn.Linear(self.in_dim, self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)
        self.fc = nn.Linear(self.mem_dim, 4)

    def node_forward(self, inputs, child_c, child_h):
        """Combine a node's input with its children's states.

        ``child_c`` and ``child_h`` hold one row per child. Returns the
        ``(c, h)`` pair for the node, each with a leading axis of size 1.
        """
        summed_h = child_h.sum(dim=0, keepdim=True)

        gates = self.ioux(inputs) + self.iouh(summed_h)
        in_gate, out_gate, update = gates.split(gates.size(1) // 3, dim=1)
        in_gate = torch.sigmoid(in_gate)
        out_gate = torch.sigmoid(out_gate)
        update = torch.tanh(update)

        # the input projection is tiled so every child row gets the same term
        forget = torch.sigmoid(
            self.fh(child_h) + self.fx(inputs).repeat(len(child_h), 1)
        )
        retained = forget * child_c

        c = in_gate * update + retained.sum(dim=0, keepdim=True)
        h = out_gate * torch.tanh(c)
        return c, h

    def forward(self, tree, inputs):
        """Encode ``tree`` bottom-up, storing ``state`` on every visited node.

        ``inputs`` is indexed by ``tree.uid``. Returns ``(tree.state, out)``
        where ``out`` is ``self.fc`` applied to the node's hidden state.
        """
        # Recurse first so that every child has `.state` populated.
        for position in range(tree.num_children):
            self.forward(tree.childrenList[position], inputs)

        if tree.num_children != 0:
            # NOTE(review): recursion walks `childrenList` while states are read
            # through `children` (indexed by its own iteration items) -- this
            # presumably is a mapping onto the same nodes; confirm.
            states = [tree.children[key].state for key in tree.children]
            cells, hiddens = zip(*states)
            child_c = torch.cat(cells, dim=0)
            child_h = torch.cat(hiddens, dim=0)
        else:
            # Leaf: a single all-zero pseudo-child matching inputs' dtype/device.
            child_c = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()
            child_h = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()

        tree.state = self.node_forward(inputs[tree.uid], child_c, child_h)
        return tree.state, self.fc(tree.state[1])

In [4]:
class treeEncoder(nn.Module):
    """Child-Sum Tree-LSTM encoder with a per-node 4-way classifier.

    Nodes are expected to expose ``uid``, ``num_children`` and ``childrenList``;
    ``forward`` additionally reads ``label``. Every visited node receives a
    ``state`` attribute (``(c, h)``), and ``forward`` also stores ``output``.

    Args:
        cuda: stored as ``self.cudaFlag`` and passed to the output head.
        in_dim: size of a user vector.
        mem_dim: size of the cell / hidden state.
        userVects: container indexed by ``node.uid`` returning a tensor.
        labels: stored as ``self.labels``; not used inside this class.
        labelMap: mapping from ``node.label`` to a class index.
        criterion: loss callable applied as ``criterion(log_probs, target)``.
        device: device on which states, losses and labels are created.
    """
    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,device):
        super(treeEncoder, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.device = device
        self.labels = labels
        self.labelMap = labelMap
        self.criterion = criterion

        self.ix = nn.Linear(self.in_dim,self.mem_dim)
        self.ih = nn.Linear(self.mem_dim,self.mem_dim)

        self.fx = nn.Linear(self.in_dim,self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

        self.ux = nn.Linear(self.in_dim,self.mem_dim)
        self.uh = nn.Linear(self.mem_dim,self.mem_dim)

        self.ox = nn.Linear(self.in_dim,self.mem_dim)
        self.oh = nn.Linear(self.mem_dim,self.mem_dim)

        self.userVects = userVects
        self.outputModule = OutputModule(self.cudaFlag,mem_dim,4,self.device,dropout=False)

    def _encode(self, node):
        """Compute ``node.state`` bottom-up without touching labels or loss."""
        for idx in range(node.num_children):
            self._encode(node.childrenList[idx])
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.userVects[node.uid].to(self.device),child_c,child_h)
        return node.state

    def predict(self,node):
        """Return the log-probabilities for ``node`` (inference path).

        The subtree is encoded without computing a loss, so descendants do not
        need a ``label``; only ``state`` is written on the visited nodes.
        """
        state = self._encode(node)
        return self.outputModule(state[1], False)

    def forward(self,node):
        """Encode the subtree rooted at ``node`` and accumulate its loss.

        Returns:
            ``(node.state, loss)`` where ``loss`` is a 1-element tensor holding
            the sum of the criterion over ``node`` and all of its descendants.
        """
        loss = torch.zeros(1, device=self.device)

        for idx in range(node.num_children):
            _, child_loss = self.forward(node.childrenList[idx])
            loss = loss + child_loss
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.userVects[node.uid].to(self.device),child_c,child_h)

        output = self.outputModule(node.state[1], True)
        node.output = output

        # Build the target on the same device as the prediction; a CPU label
        # against a CUDA output makes the criterion raise.
        label = torch.tensor(self.labelMap[node.label], device=output.device)

        loss = loss + self.criterion(output.reshape(-1, output.size(-1)), label.reshape(-1))
        return node.state,loss

    def nodeForward(self, inputs, child_c, child_h):
        """Child-Sum Tree-LSTM update for one node.

        ``child_c`` / ``child_h`` hold one row per child. Returns ``(c, h)``.
        """
        child_h_sum = torch.sum(child_h, dim=0, keepdim=True)

        i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))

        # one forget gate per child; the input term is tiled across children
        f = torch.sigmoid(
            self.fh(child_h) +
            self.fx(inputs).repeat(len(child_h), 1)
        )
        fc = torch.mul(f, child_c)

        c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
        h = torch.mul(o, torch.tanh(c))
        return c, h

    def getChildStates(self,node):
        """Stack the children's ``(c, h)`` states, one row per child.

        A leaf yields a single all-zero row for each of ``c`` and ``h``.
        """
        if node.num_children == 0:
            child_c = torch.zeros(1, self.mem_dim, device=self.device)
            child_h = torch.zeros(1, self.mem_dim, device=self.device)
            return child_c, child_h

        children = [node.childrenList[idx] for idx in range(node.num_children)]
        # Concatenate instead of writing into an uninitialised buffer: same
        # values, no in-place autograd writes, dtype/device follow the states.
        child_c = torch.cat([child.state[0].reshape(1, self.mem_dim) for child in children], dim=0)
        child_h = torch.cat([child.state[1].reshape(1, self.mem_dim) for child in children], dim=0)
        return child_c, child_h

## Temporal Tree LSTM

In [9]:
# Scratch list of three 2-element integer tensors.
a = [torch.tensor(pair) for pair in ([1, 2], [3, 4], [2, 5])]

In [50]:
class lstmTreeEncoder(nn.Module):
    """Sequence model over a list of trees: tree encoder -> LSTM -> 4-way linear.

    Each tree's root is encoded with ``treeEncoder``; the resulting root hidden
    states form the input sequence of ``topLevelLSTM`` and the last LSTM output
    is classified by ``fc``.

    Args:
        cuda, in_dim, mem_dim, userVects, labels, labelMap, criterion, device:
            forwarded unchanged to ``treeEncoder``.
        max_trees: only the first ``max_trees`` trees of the input list are
            used (``None`` keeps all of them). Defaults to 20, the cut-off that
            used to be hard-coded.
    """
    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,device,max_trees=20):
        super(lstmTreeEncoder, self).__init__()
        self.device = device
        self.hidden_size = mem_dim
        self.max_trees = max_trees

        self.treeEnc = treeEncoder(cuda,in_dim,mem_dim,userVects,labels,labelMap,criterion,device)
        self.topLevelLSTM = nn.LSTM(mem_dim,self.hidden_size,batch_first=False)
        self.fc = nn.Linear(self.hidden_size, 4)

    def forward(self,listOfIncTrees):
        """Classify a sequence of trees.

        Args:
            listOfIncTrees: sequence of objects exposing a ``root`` node.

        Returns:
            Output of ``fc`` on the last LSTM step.

        Raises:
            ValueError: if no tree is left after truncation.
        """
        trees = listOfIncTrees[:self.max_trees]
        if len(trees) == 0:
            raise ValueError("listOfIncTrees must contain at least one tree")

        # treeEnc returns ((c, h), loss); only the root hidden state h is kept.
        inp = torch.stack([self.treeEnc(tree.root)[0][1] for tree in trees])

        # No explicit (h0, c0): nn.LSTM starts from zero states by default,
        # which is exactly what the hand-built zero tensors provided.
        out, _ = self.topLevelLSTM(inp)

        return self.fc(out[-1])

## Temporal Decay Tree Model

In [None]:
import math

In [29]:
class decayTreeEncoder(nn.Module):
    """Child-Sum Tree-LSTM encoder whose forget gate is scaled by a time decay.

    Nodes are expected to expose ``uid``, ``time_stamp``, ``num_children`` and
    ``childrenList``; ``forward`` additionally reads ``label``. Every visited
    node receives ``state`` (``(c, h)``) and, in ``forward``, ``output``.

    The decay for a node is derived from ``start - node.time_stamp`` where
    ``start`` is the ``time_stamp`` of the node passed to ``forward`` /
    ``predict`` (also kept in ``self.startTime``).
    Assumes ``time_stamp`` differences are plain numbers accepted by
    ``math.exp`` -- TODO confirm the unit against the data loader.

    Args:
        cuda: stored as ``self.cudaFlag`` and passed to the output head.
        in_dim: size of a user vector.
        mem_dim: size of the cell / hidden state.
        userVects: container indexed by ``node.uid`` returning a tensor.
        labels: stored as ``self.labels``; not used inside this class.
        labelMap: mapping from ``node.label`` to a class index.
        criterion: loss callable applied as ``criterion(log_probs, target)``.
        device: device on which states, losses and labels are created.
    """

    # Divisor applied to exp(timediff) before it is turned into the decay factor.
    DECAY_TIME_SCALE = 60

    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,device):
        super(decayTreeEncoder, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.device = device
        self.labels = labels
        self.labelMap = labelMap
        self.criterion = criterion

        self.ix = nn.Linear(self.in_dim,self.mem_dim)
        self.ih = nn.Linear(self.mem_dim,self.mem_dim)

        self.fx = nn.Linear(self.in_dim,self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

        self.ux = nn.Linear(self.in_dim,self.mem_dim)
        self.uh = nn.Linear(self.mem_dim,self.mem_dim)

        self.ox = nn.Linear(self.in_dim,self.mem_dim)
        self.oh = nn.Linear(self.mem_dim,self.mem_dim)

        self.userVects = userVects
        self.outputModule = OutputModule(self.cudaFlag,mem_dim,4,self.device,dropout=False)

    def _encode(self, node, start_time):
        """Compute ``node.state`` bottom-up without touching labels or loss."""
        for idx in range(node.num_children):
            self._encode(node.childrenList[idx], start_time)
        child_c, child_h = self.getChildStates(node)
        timediff = start_time - node.time_stamp
        node.state = self.nodeForward(self.userVects[node.uid].to(self.device),child_c,child_h,timediff)
        return node.state

    def _forward(self, node, start_time):
        """Recursive worker for ``forward``; ``start_time`` stays fixed."""
        loss = torch.zeros(1, device=self.device)

        for idx in range(node.num_children):
            _, child_loss = self._forward(node.childrenList[idx], start_time)
            loss = loss + child_loss
        child_c, child_h = self.getChildStates(node)
        timediff = start_time - node.time_stamp
        node.state = self.nodeForward(self.userVects[node.uid].to(self.device),child_c,child_h,timediff)

        output = self.outputModule(node.state[1], True)
        node.output = output

        # Target must live on the same device as the prediction.
        label = torch.tensor(self.labelMap[node.label], device=output.device)

        loss = loss + self.criterion(output.reshape(-1, output.size(-1)), label.reshape(-1))
        return node.state, loss

    def predict(self,node):
        """Return the log-probabilities for ``node`` (inference path).

        The subtree is encoded without computing a loss, so descendants do not
        need a ``label``.
        """
        self.startTime = node.time_stamp
        state = self._encode(node, self.startTime)
        return self.outputModule(state[1], False)

    def forward(self,node):
        """Encode the subtree rooted at ``node`` and accumulate its loss.

        The reference time is taken once from ``node`` and handed down the
        recursion. Previously every recursive call overwrote
        ``self.startTime``, so a node's time difference was measured against
        whichever descendant had been visited last.

        Returns:
            ``(node.state, loss)`` where ``loss`` is a 1-element tensor holding
            the sum of the criterion over ``node`` and all of its descendants.
        """
        self.startTime = node.time_stamp
        return self._forward(node, self.startTime)

    def nodeForward(self, inputs, child_c, child_h,timediff):
        """Tree-LSTM update for one node with a time-decayed forget gate.

        The forget-gate pre-activation is multiplied by
        ``exp(-exp(timediff) / DECAY_TIME_SCALE)``. Returns ``(c, h)``.
        """
        child_h_sum = torch.sum(child_h, dim=0, keepdim=True)

        i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))

        try:
            decayF = math.exp(timediff)
        except OverflowError:
            # Very large differences saturate: the factor below becomes 0.
            decayF = float('inf')
        decay = math.exp(-decayF / self.DECAY_TIME_SCALE)

        f = torch.sigmoid(
            (self.fh(child_h) +
            self.fx(inputs).repeat(len(child_h), 1)) * decay
        )
        fc = torch.mul(f, child_c)

        c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
        h = torch.mul(o, torch.tanh(c))
        return c, h

    def getChildStates(self,node):
        """Stack the children's ``(c, h)`` states, one row per child.

        A leaf yields a single all-zero row for each of ``c`` and ``h``.
        """
        if node.num_children == 0:
            child_c = torch.zeros(1, self.mem_dim, device=self.device)
            child_h = torch.zeros(1, self.mem_dim, device=self.device)
            return child_c, child_h

        children = [node.childrenList[idx] for idx in range(node.num_children)]
        # Concatenate instead of writing into an uninitialised buffer.
        child_c = torch.cat([child.state[0].reshape(1, self.mem_dim) for child in children], dim=0)
        child_h = torch.cat([child.state[1].reshape(1, self.mem_dim) for child in children], dim=0)
        return child_c, child_h

In [None]:
class temporalDecayTreeEncoder(nn.Module):
    """Sequence model over trees: decay tree encoder -> LSTM -> 4-way linear.

    Each tree's root is encoded with ``decayTreeEncoder``; the resulting root
    hidden states form the input sequence of ``topLevelLSTM`` and the last LSTM
    output is classified by ``fc``.

    Args:
        cuda, in_dim, mem_dim, userVects, labels, labelMap, criterion, device:
            forwarded unchanged to ``decayTreeEncoder``.
        max_trees: only the first ``max_trees`` trees of the input list are
            used (``None`` keeps all of them). Defaults to 20, the cut-off that
            used to be hard-coded.
    """
    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,device,max_trees=20):
        super(temporalDecayTreeEncoder, self).__init__()
        self.device = device
        self.hidden_size = mem_dim
        self.max_trees = max_trees

        self.treeEnc = decayTreeEncoder(cuda,in_dim,mem_dim,userVects,labels,labelMap,criterion,device)
        self.topLevelLSTM = nn.LSTM(mem_dim,self.hidden_size,batch_first=False)
        self.fc = nn.Linear(self.hidden_size, 4)

    def forward(self,listOfIncTrees):
        """Classify a sequence of trees.

        Args:
            listOfIncTrees: sequence of objects exposing a ``root`` node.

        Returns:
            Output of ``fc`` on the last LSTM step.

        Raises:
            ValueError: if no tree is left after truncation.
        """
        trees = listOfIncTrees[:self.max_trees]
        if len(trees) == 0:
            raise ValueError("listOfIncTrees must contain at least one tree")

        # treeEnc returns ((c, h), loss); only the root hidden state h is kept.
        inp = torch.stack([self.treeEnc(tree.root)[0][1] for tree in trees])

        # No explicit (h0, c0): nn.LSTM starts from zero states by default,
        # which is exactly what the hand-built zero tensors provided.
        out, _ = self.topLevelLSTM(inp)

        return self.fc(out[-1])

## Tree Encoder + Text Model

In [18]:
class treeText(nn.Module):
    """Joint classifier over a tree encoding and a text encoding.

    The root hidden state from ``treeEncoder`` and the output of the ``RecArch``
    text encoder are flattened, concatenated and fed to a 4-way linear layer.
    When ``textEncState`` is given it is loaded into the text encoder;
    otherwise the text encoder keeps its freshly initialised weights.
    """
    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,device,vocabSize,textEncState=None):
        super().__init__()

        self.treeEnc = treeEncoder(cuda,in_dim,mem_dim,userVects,labels,labelMap,criterion,device)
        self.textEnc = RecArch(vocabSize, 256, 50, 4, 1, 'gru',device)

        # Optional warm start of the text encoder from a saved state dict.
        if textEncState is not None:
            self.textEnc.load_state_dict(textEncState)

        # 50 extra input features are reserved for the flattened text vector.
        self.fc = nn.Linear(mem_dim+50,4)

    def forward(self,tree,text):
        """Return the ``fc`` scores for one ``(tree, text)`` pair."""
        # treeEnc yields ((c, h), loss); index [0][1] selects the hidden state.
        tree_result = self.treeEnc(tree)
        tree_features = tree_result[0][1].reshape(-1)

        text_features = self.textEnc(text).reshape(-1)

        joint = torch.cat((tree_features, text_features))
        return self.fc(joint)

NameError: name 'nn' is not defined

In [None]:
class decayTreeText(nn.Module):
    """Joint classifier over a decay-tree encoding and a text encoding.

    The root hidden state from ``decayTreeEncoder`` and the output of the
    ``RecArch`` text encoder are flattened, concatenated and fed to a 4-way
    linear layer.

    Args:
        cuda, in_dim, mem_dim, userVects, labels, labelMap, criterion, device:
            forwarded unchanged to ``decayTreeEncoder``.
        vocabSize: vocabulary size handed to ``RecArch``. Defaults to 3370,
            the value that used to be hard-coded.
        textEncState: optional state dict for the text encoder. When ``None``
            the weights are read from ``checkpointPath`` instead.
        checkpointPath: checkpoint file whose ``'state_dict'`` entry is loaded
            when ``textEncState`` is ``None``.
    """
    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,device,
                 vocabSize=3370,textEncState=None,
                 checkpointPath='../baselines/gruTextEnc_twit16.pth'):
        super(decayTreeText,self).__init__()
        if textEncState is None:
            # SECURITY NOTE: torch.load unpickles the file; only point this at
            # checkpoints you trust.
            # map_location lets a checkpoint saved on GPU load on a CPU host.
            checkpoint = torch.load(checkpointPath, map_location=device)
            textEncState = checkpoint['state_dict']

        self.treeEnc = decayTreeEncoder(cuda,in_dim,mem_dim,userVects,labels,labelMap,criterion,device)

        self.textEnc = RecArch(vocabSize, 256, 50, 4, 1, 'gru',device)
        self.textEnc.load_state_dict(textEncState)

        # 50 extra input features are reserved for the flattened text vector.
        self.fc = nn.Linear(mem_dim+50,4)

    def forward(self,tree,text):
        """Return the ``fc`` scores for one ``(tree, text)`` pair."""
        # treeEnc yields ((c, h), loss); index [0][1] selects the hidden state.
        treeVec = self.treeEnc(tree)
        treeVec = treeVec[0][1].reshape(-1)

        textVec = self.textEnc(text)
        textVec = textVec.reshape(-1)

        combVec = torch.cat((treeVec,textVec))
        out = self.fc(combVec)
        return out