In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim
import torch.nn.functional as F

# Tree Encoders

In [None]:
class OutputModule(nn.Module):
    """Linear classification head that maps a hidden vector to log-probabilities.

    Args:
        cuda: when truthy, the linear layer is moved to ``device`` on construction.
        mem_dim: size of the last dimension of the incoming vector.
        num_classes: number of output classes.
        dropout: when truthy, ``F.dropout`` is applied to the input before the
            linear layer (only active when ``forward`` is called with
            ``training=True``).
        device: target device used when ``cuda`` is truthy. Defaults to
            ``'cuda:1'``, the value that used to be hard-coded here.
    """

    def __init__(self, cuda, mem_dim, num_classes, dropout = False, device = 'cuda:1'):
        super().__init__()
        self.cudaFlag = cuda
        self.mem_dim = mem_dim
        self.num_classes = num_classes
        self.dropout = dropout

        self.l1 = nn.Linear(self.mem_dim, self.num_classes)
        # Normalise over the class axis explicitly. Omitting ``dim`` is
        # deprecated, emits a warning on every call, and silently picks axis 0
        # for 3-D inputs instead of the class axis.
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        if self.cudaFlag:
            self.l1 = self.l1.to(device)

    def forward(self, vec, training = False):
        """Return log-probabilities over the classes for ``vec``.

        Args:
            vec: tensor whose last dimension has size ``mem_dim``.
            training: forwarded to ``F.dropout``; ignored when the module was
                built with ``dropout`` disabled.

        Returns:
            Tensor shaped like ``vec`` with the last dimension replaced by
            ``num_classes``, holding log-softmax scores over that dimension.
        """
        if self.dropout:
            vec = F.dropout(vec, training = training)
        return self.logsoftmax(self.l1(vec))

In [1]:
class treeEncoder(nn.Module):
    """Child-Sum Tree-LSTM encoder with a per-node classification loss.

    Every tree node is expected to expose ``uid``, ``label``, ``num_children``
    and ``childrenList``; the encoder writes ``state`` (a ``(c, h)`` tuple) and
    ``output`` (log-probabilities) back onto the nodes it visits.

    Args:
        cuda: flag forwarded to ``OutputModule`` (which moves its layer to the
            GPU when set). Tensors created by this class follow whatever
            device the encoder's own parameters live on.
        in_dim: size of the last dimension of a user vector.
        mem_dim: size of the LSTM cell / hidden state.
        userVects: container indexed by ``node.uid`` that yields the input
            tensor for a node (last dimension ``in_dim``).
        labels: stored on the instance; not read by any method of this class.
        labelMap: container indexed by ``node.label``; the value is turned into
            the target tensor for ``criterion`` (presumably an integer class
            index -- confirm against the caller).
        criterion: callable ``criterion(log_probs, target)`` returning the
            per-node loss.
        num_classes: number of output classes. Defaults to 4, the value that
            used to be hard-coded.
    """

    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,num_classes=4):
        # The original called ``super(fixedTreeEncoder, self)``, a name that
        # does not exist in this file, so instantiation raised NameError.
        super().__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.labels = labels
        self.labelMap = labelMap
        self.criterion = criterion
        self.num_classes = num_classes

        # input gate
        self.ix = nn.Linear(self.in_dim,self.mem_dim)
        self.ih = nn.Linear(self.mem_dim,self.mem_dim)

        # forget gate (one gate per child)
        self.fx = nn.Linear(self.in_dim,self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

        # candidate update
        self.ux = nn.Linear(self.in_dim,self.mem_dim)
        self.uh = nn.Linear(self.mem_dim,self.mem_dim)

        # output gate
        self.ox = nn.Linear(self.in_dim,self.mem_dim)
        self.oh = nn.Linear(self.mem_dim,self.mem_dim)

        self.userVects = userVects
        self.outputModule = OutputModule(self.cudaFlag,mem_dim,self.num_classes,dropout=False)

    def _device(self):
        """Device of the encoder's parameters; new tensors are created there."""
        return self.ix.weight.device

    def predict(self,node):
        """Encode the subtree rooted at ``node`` and return its log-probabilities.

        Unlike ``forward`` this never reads ``node.label`` and computes no
        loss, so it also works on unlabeled trees. Descendants get their
        ``state`` and ``output`` attributes set; the root gets ``state`` only
        and its output is returned.
        """
        for idx in range(node.num_children):
            child = node.childrenList[idx]
            child.output = self.predict(child)
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.userVects[node.uid],child_c,child_h)

        return self.outputModule(node.state[1], False)

    def forward(self,node):
        """Encode the subtree rooted at ``node`` and accumulate the loss.

        Returns:
            ``(node.state, loss)`` where ``loss`` is a 1-element tensor holding
            the sum of ``criterion`` over ``node`` and all of its descendants.
        """
        device = self._device()
        loss = torch.zeros(1, device=device)

        # Post-order traversal: children must have their state before the parent.
        for idx in range(node.num_children):
            _, child_loss = self.forward(node.childrenList[idx])
            loss = loss + child_loss
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.userVects[node.uid],child_c,child_h)

        output = self.outputModule(node.state[1], True)
        node.output = output

        label = torch.tensor(self.labelMap[node.label], device=device)
        loss = loss + self.criterion(output.reshape(-1,self.num_classes), label.reshape(-1))

        return node.state, loss

    def nodeForward(self,x,child_c,child_h):
        """Apply one Child-Sum Tree-LSTM cell.

        Args:
            x: input vector of the node, last dimension ``in_dim``.
            child_c: ``(num_children, mem_dim)`` cell states of the children.
            child_h: ``(num_children, mem_dim)`` hidden states of the children.

        Returns:
            ``(c, h)`` for the node; both have the same shape as ``self.ix(x)``.
        """
        # h^~_j = sum of child hidden states
        child_h_sum = torch.sum(child_h,0)

        i = torch.sigmoid(self.ix(x) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(x)+self.oh(child_h_sum))
        u = torch.tanh(self.ux(x)+self.uh(child_h_sum))

        # One forget gate per child, shape (num_children, mem_dim); the input
        # term broadcasts over the child axis.
        f = torch.sigmoid(self.fh(child_h) + self.fx(x))

        # c_j = i*u + sum_k f_jk * c_k. The original summed the gates
        # themselves and never used child_c, so children's memory was lost.
        c = i*u + torch.sum(f*child_c,0)
        h = o*torch.tanh(c)

        return c,h

    def getChildStates(self,node):
        """Collect the children's states as two ``(K, mem_dim)`` tensors.

        A leaf yields a single all-zero row for both tensors, so the cell
        update reduces to ``c = i*u``.
        """
        if node.num_children==0:
            device = self._device()
            child_c = torch.zeros(1,self.mem_dim, device=device)
            child_h = torch.zeros(1,self.mem_dim, device=device)
            return child_c, child_h

        # Stack instead of writing into an uninitialised buffer; ``reshape``
        # also accepts states shaped (1, mem_dim).
        children = [node.childrenList[idx] for idx in range(node.num_children)]
        child_c = torch.stack([child.state[0].reshape(self.mem_dim) for child in children])
        child_h = torch.stack([child.state[1].reshape(self.mem_dim) for child in children])
        return child_c, child_h

NameError: name 'nn' is not defined

# Temporal Decay Tree Encoder

## TODO: multiply a temporal decay factor into the forget gates in the `nodeForward` method

### TODO: implement both linear (direct-proportion) decay and exponential decay

In [None]:
class treeEncoder(nn.Module):
    """Child-Sum Tree-LSTM encoder with a per-node classification loss.

    NOTE(review): this cell re-declares ``treeEncoder`` and therefore shadows
    the definition from the earlier "Tree Encoders" cell. It sits under the
    "Temporal Decay Tree Encoder" heading, but no decay factor is applied to
    the forget gates yet -- that is still TODO. Consider renaming the class
    once the decay is implemented.

    Every tree node is expected to expose ``uid``, ``label``, ``num_children``
    and ``childrenList``; the encoder writes ``state`` (a ``(c, h)`` tuple) and
    ``output`` (log-probabilities) back onto the nodes it visits.

    Args:
        cuda: flag forwarded to ``OutputModule`` (which moves its layer to the
            GPU when set). Tensors created by this class follow whatever
            device the encoder's own parameters live on.
        in_dim: size of the last dimension of a user vector.
        mem_dim: size of the LSTM cell / hidden state.
        userVects: container indexed by ``node.uid`` that yields the input
            tensor for a node (last dimension ``in_dim``).
        labels: stored on the instance; not read by any method of this class.
        labelMap: container indexed by ``node.label``; the value is turned into
            the target tensor for ``criterion`` (presumably an integer class
            index -- confirm against the caller).
        criterion: callable ``criterion(log_probs, target)`` returning the
            per-node loss.
        num_classes: number of output classes. Defaults to 4, the value that
            used to be hard-coded.
    """

    def __init__(self, cuda,in_dim, mem_dim,userVects,labels,labelMap,criterion,num_classes=4):
        # ``super(fixedTreeEncoder, self)`` referenced a class that is not
        # defined in this file and raised NameError on instantiation.
        super().__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.labels = labels
        self.labelMap = labelMap
        self.criterion = criterion
        self.num_classes = num_classes

        # input gate
        self.ix = nn.Linear(self.in_dim,self.mem_dim)
        self.ih = nn.Linear(self.mem_dim,self.mem_dim)

        # forget gate (one gate per child)
        self.fx = nn.Linear(self.in_dim,self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)

        # candidate update
        self.ux = nn.Linear(self.in_dim,self.mem_dim)
        self.uh = nn.Linear(self.mem_dim,self.mem_dim)

        # output gate
        self.ox = nn.Linear(self.in_dim,self.mem_dim)
        self.oh = nn.Linear(self.mem_dim,self.mem_dim)

        self.userVects = userVects
        self.outputModule = OutputModule(self.cudaFlag,mem_dim,self.num_classes,dropout=False)

    def _device(self):
        """Device of the encoder's parameters; new tensors are created there."""
        return self.ix.weight.device

    def predict(self,node):
        """Encode the subtree rooted at ``node`` and return its log-probabilities.

        No label is read and no loss is computed, so unlabeled trees are
        accepted. Descendants get ``state`` and ``output`` set; the root gets
        ``state`` only and its output is the return value.
        """
        for idx in range(node.num_children):
            child = node.childrenList[idx]
            child.output = self.predict(child)
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.userVects[node.uid],child_c,child_h)

        return self.outputModule(node.state[1], False)

    def forward(self,node):
        """Encode the subtree rooted at ``node`` and accumulate the loss.

        Returns:
            ``(node.state, loss)`` where ``loss`` is a 1-element tensor holding
            the sum of ``criterion`` over ``node`` and all of its descendants.
        """
        device = self._device()
        loss = torch.zeros(1, device=device)

        # Children first: the parent's cell needs their (c, h) states.
        for idx in range(node.num_children):
            _, child_loss = self.forward(node.childrenList[idx])
            loss = loss + child_loss
        child_c, child_h = self.getChildStates(node)
        node.state = self.nodeForward(self.userVects[node.uid],child_c,child_h)

        output = self.outputModule(node.state[1], True)
        node.output = output

        label = torch.tensor(self.labelMap[node.label], device=device)
        loss = loss + self.criterion(output.reshape(-1,self.num_classes), label.reshape(-1))

        return node.state, loss

    def nodeForward(self,x,child_c,child_h):
        """Apply one Child-Sum Tree-LSTM cell.

        Args:
            x: input vector of the node, last dimension ``in_dim``.
            child_c: ``(num_children, mem_dim)`` cell states of the children.
            child_h: ``(num_children, mem_dim)`` hidden states of the children.

        Returns:
            ``(c, h)`` for the node; both have the same shape as ``self.ix(x)``.
        """
        # h^~_j = sum of child hidden states
        child_h_sum = torch.sum(child_h,0)

        i = torch.sigmoid(self.ix(x) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(x)+self.oh(child_h_sum))
        u = torch.tanh(self.ux(x)+self.uh(child_h_sum))

        # Per-child forget gates, shape (num_children, mem_dim).
        # TODO: the temporal decay factor belongs here, scaling ``f``.
        f = torch.sigmoid(self.fh(child_h) + self.fx(x))

        # c_j = i*u + sum_k f_jk * c_k. Previously child_c was ignored and the
        # raw gate values were summed instead.
        c = i*u + torch.sum(f*child_c,0)
        h = o*torch.tanh(c)

        return c,h

    def getChildStates(self,node):
        """Collect the children's states as two ``(K, mem_dim)`` tensors.

        A leaf yields a single all-zero row for both tensors.
        """
        if node.num_children==0:
            device = self._device()
            child_c = torch.zeros(1,self.mem_dim, device=device)
            child_h = torch.zeros(1,self.mem_dim, device=device)
            return child_c, child_h

        # Build the tensors with ``stack`` rather than filling an
        # uninitialised buffer in place; ``reshape`` tolerates (1, mem_dim).
        children = [node.childrenList[idx] for idx in range(node.num_children)]
        child_c = torch.stack([child.state[0].reshape(self.mem_dim) for child in children])
        child_h = torch.stack([child.state[1].reshape(self.mem_dim) for child in children])
        return child_c, child_h

# LSTM Top Level Encoder