In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import numpy as np

In [2]:
vocab_size = 50
embedding_dim = 30
batch_size = 3
sequence_length = 5

In [31]:
np_input = np.random.randint(low=0, high=vocab_size, size=(batch_size, sequence_length))
np_input[1][-1] = 0
np_input[2][-1] = 0
np_input[2][-2] = 0

inputs = Variable(torch.from_numpy(np_input))
inputs

Variable containing:
 44  23  25  11  21
  8  14  42  10   0
 28  28  35   0   0
[torch.LongTensor of size 3x5]

In [34]:
torch.sum(inputs, 1)

Variable containing:
 124
  74
  91
[torch.LongTensor of size 3x1]

In [4]:
embedding = nn.Embedding(vocab_size, embedding_dim)

In [5]:
B = embedding(inputs)
print(B.size())
s_b = torch.gt(inputs, 0).float()
print(s_b)

torch.Size([3, 5, 30])
Variable containing:
 1  1  1  1  1
 1  1  1  1  0
 1  1  1  0  0
[torch.FloatTensor of size 3x5]



In [6]:
V = [Variable(torch.zeros(batch_size, embedding_dim))]
print(V[0].size())
s = [Variable(torch.zeros(batch_size, 1))]
print(s)

torch.Size([3, 30])
[Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3x1]
]


In [7]:
B = embedding(inputs)
print(B.size())
s_b = torch.gt(inputs, 0).float()
print(s_b)

def readBuffer(alpha, B, s_b, pop=False):
    batch_size = B.size(0)
    sequence_len = B.size(1)
    embedding_dim = B.size(2)
    
    cumsum = Variable(torch.zeros(batch_size).float())
    vector = Variable(torch.zeros(batch_size, embedding_dim).float())
    batch_zeros = Variable(torch.zeros(batch_size))
    
    # May be more efficient not to loop like this... especially when buffer is empty
    for i in range(sequence_len):
        
        weights = torch.min(s_b[:, i], torch.max(batch_zeros, alpha - cumsum))  
        vector = torch.add(vector, torch.mul(weights.unsqueeze(1).expand_as(B[:, i]), B[:, i]))
        cumsum = torch.add(cumsum, weights)
        if pop:
            s_b[:, i] = torch.add(s_b[:, i], torch.mul(weights, -1))
            
            
        print(i)
        print(weights)
        
        if batch_size <= torch.sum(torch.ge(cumsum, alpha)).data[0]:
            break
            
    return vector, s_b

alpha_shift = Variable(torch.rand(batch_size))
print(alpha_shift)
vector, s_b = readBuffer(alpha_shift, B, s_b, pop=True)
print(s_b)
vector, s_b = readBuffer(alpha_shift, B, s_b, pop=True)
print(s_b)

torch.Size([3, 5, 30])
Variable containing:
 1  1  1  1  1
 1  1  1  1  0
 1  1  1  0  0
[torch.FloatTensor of size 3x5]

Variable containing:
 0.5315
 0.8538
 0.7823
[torch.FloatTensor of size 3]

0
Variable containing:
 0.5315
 0.8538
 0.7823
[torch.FloatTensor of size 3]

Variable containing:
 0.4685  1.0000  1.0000  1.0000  1.0000
 0.1462  1.0000  1.0000  1.0000  0.0000
 0.2177  1.0000  1.0000  0.0000  0.0000
[torch.FloatTensor of size 3x5]

0
Variable containing:
 0.4685
 0.1462
 0.2177
[torch.FloatTensor of size 3]

1
Variable containing:
 0.0630
 0.7075
 0.5646
[torch.FloatTensor of size 3]

Variable containing:
 0.0000  0.9370  1.0000  1.0000  1.0000
 0.0000  0.2925  1.0000  1.0000  0.0000
 0.0000  0.4354  1.0000  0.0000  0.0000
[torch.FloatTensor of size 3x5]



In [10]:
def pushStack(vector, alpha, V, s):
    V = [vector] + V
    s = [alpha] + s
    
    return V, s
    
V, s = pushStack(Variable(torch.ones(batch_size, embedding_dim)), Variable(torch.FloatTensor(batch_size).fill_(1)), V, s)

In [11]:
V = [Variable(torch.zeros(batch_size, embedding_dim))]
s = [Variable(torch.zeros(batch_size))]
alpha_push = Variable(torch.rand(batch_size))
# print(alpha_push)
alpha_ones = Variable(torch.ones(batch_size))
V, s = pushStack(Variable(torch.rand(batch_size, embedding_dim)), alpha_ones, V, s)
V, s = pushStack(Variable(torch.rand(batch_size, embedding_dim)), alpha_ones, V, s)
V, s = pushStack(Variable(torch.rand(batch_size, embedding_dim)), alpha_ones, V, s)
V, s = pushStack(Variable(torch.rand(batch_size, embedding_dim)), alpha_push, V, s)
# print(s)

def readStack(alpha, V, s, pop=False):
    batch_size = V[0].size(0)
    embedding_dim = V[0].size(1)
    
    vector1 = Variable(torch.zeros(batch_size, embedding_dim).float())
    vector2 = Variable(torch.zeros(batch_size, embedding_dim).float())
    
    cumsum = Variable(torch.zeros(batch_size).float())
    batch_zeros = Variable(torch.zeros(batch_size))
    batch_ones = Variable(torch.ones(batch_size))
    batch_alpha = Variable(torch.ones(batch_size)) if alpha is None else alpha
    
    # May be more efficient not to loop like this... especially when stack is empty
    for i in range(len(V)):
        weights1 = torch.min(s[i], torch.max(batch_zeros, batch_alpha - cumsum))
        weights2 = torch.min(s[i], torch.max(batch_zeros, batch_ones + batch_alpha - cumsum)) - weights1
        cumsum = torch.add(cumsum, weights1+weights2)
        
        vector1 = torch.add(vector1, torch.mul(weights1.unsqueeze(1).expand_as(V[0]), V[i]))
        vector2 = torch.add(vector2, torch.mul(weights2.unsqueeze(1).expand_as(V[0]), V[i]))
        
        if pop:
            s[i] = torch.add(s[i], torch.mul(weights1+weights2, -1))
        
        if batch_size == torch.sum(torch.ge(cumsum, 2)).data[0]:
            break

    return s

# print("s", s)
# alpha_reduce = Variable(torch.ones(batch_size))
r = readStack(None, V, s, pop=False)
print(r)
readStack(None, V, s, pop=True)
# print("s", s)

[Variable containing:
 0.0902
 0.0367
 0.2763
[torch.FloatTensor of size 3]
, Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]
, Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]
, Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]
, Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
]


[Variable containing:
  0
  0
  0
 [torch.FloatTensor of size 3], Variable containing:
  0
  0
  0
 [torch.FloatTensor of size 3], Variable containing:
  0.0902
  0.0367
  0.2763
 [torch.FloatTensor of size 3], Variable containing:
  1
  1
  1
 [torch.FloatTensor of size 3], Variable containing:
  0
  0
  0
 [torch.FloatTensor of size 3]]

In [30]:
torch.sum(torch.cat([x.unsqueeze(1) for x in s], 1), 1)

Variable containing:
 1.0902
 1.0367
 1.2763
[torch.FloatTensor of size 3x1]

In [32]:
def pop(alpha, B, s_b):
    batch_size = B.size(0)
    sequence_len = B.size(1)
    embedding_dim = B.size(2)
    
    cumsum = Variable(torch.zeros(batch_size).float())
    read = Variable(torch.zeros(batch_size, embedding_dim).float())
    batch_zeros = Variable(torch.zeros(batch_size))
    batch_alphas = Variable(torch.ones(batch_size).fill_(alpha))
    
    # May be more efficient not to loop like this... especially when buffer is empty
    for i in range(sequence_len):
        weights = torch.min(s_b[:, i], torch.max(batch_zeros, batch_alphas - cumsum))
        read = torch.add(read, torch.mul(weights.unsqueeze(1).expand_as(B[:, i]), B[:, i]))
        cumsum += s_b[:, i]
        s_b[:, i] = torch.add(s_b[:, i], torch.mul(weights, -1))
        
        if batch_size <= torch.sum(torch.ge(cumsum, alpha)):
            break
            
    return read, s_b

pop(1, B, s_b)

(Variable containing:
 
 Columns 0 to 12 
     0     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0     0
 
 Columns 13 to 25 
     0     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0     0
 
 Columns 26 to 29 
     0     0     0     0
     0     0     0     0
     0     0     0     0
 [torch.FloatTensor of size 3x30], Variable containing:
  0.0000  1.0000  1.0000  1.0000  1.0000
  0.0000  0.8942  1.0000  1.0000  0.0000
  0.0000  1.0000  1.0000  0.0000  0.0000
 [torch.FloatTensor of size 3x5])

In [14]:
cumsum = Variable(torch.zeros(batch_size).float())
read = Variable(torch.zeros(batch_size, embedding_dim).float())

batch_zeros = Variable(torch.zeros(batch_size))
batch_ones = Variable(torch.ones(batch_size))
for i in range(len(s_b[0])):
    weights = torch.min(s_b[:, i], torch.max(batch_zeros, 1 - cumsum))    
    read = torch.add(read, torch.mul(weights.unsqueeze(1).expand_as(B[:, i]), B[:, i]))
    cumsum += s_b[:, i]
    if batch_size <= torch.sum(torch.ge(cumsum, 1)):
        break
    
print(read)
print(cumsum)

Variable containing:

Columns 0 to 12 
    0     0     0     0     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 13 to 25 
    0     0     0     0     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     0     0     0     0     0

Columns 26 to 29 
    0     0     0     0
    0     0     0     0
    0     0     0     0
[torch.FloatTensor of size 3x30]

Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]



In [15]:
# Peek, Pop

In [16]:
controller = nn.LSTMCell(input_size=embedding_dim*3, hidden_size=embedding_dim)

In [None]:
W = nn.Linear(in_features=embedding_dim, out_features=2)

F.softmax()

In [69]:
def treelstm(c_left, c_right, gates, use_dropout=False, training=None):
    hidden_dim = c_left.size()[1]

    assert gates.size()[1] == hidden_dim * 5, "Need to have 5 gates."

    def slice_gate(gate_data, i):
        return gate_data[:, i * hidden_dim:(i + 1) * hidden_dim]

    # Compute and slice gate values
    i_gate, fl_gate, fr_gate, o_gate, cell_inp = \
        [slice_gate(gates, i) for i in range(5)]

    # Apply nonlinearities
    i_gate = F.sigmoid(i_gate)
    fl_gate = F.sigmoid(fl_gate)
    fr_gate = F.sigmoid(fr_gate)
    o_gate = F.sigmoid(o_gate)
    cell_inp = F.tanh(cell_inp)

    # Compute new cell and hidden value
    i_val = i_gate * cell_inp
    dropout_rate = 0.1
    if use_dropout:
        i_val = F.dropout(i_val, dropout_rate, training=training)
    c_t = fl_gate * c_left + fr_gate * c_right + i_val
    h_t = o_gate * F.tanh(c_t)

    return (c_t, h_t)

In [82]:
class SoftStack(nn.Module):
    def __init__(self, embedding_dim, hidden_size, vocab_size):
        super(SoftStack, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        
        # Reduce function for semantic composition.
        self.tree_left = nn.Linear(in_features=hidden_size, out_features=5*hidden_size)
        self.tree_right = nn.Linear(in_features=hidden_size, out_features=5*hidden_size)
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.controller = nn.LSTMCell(input_size=hidden_size*3, hidden_size=hidden_size)
        self.alpha_projector = nn.Linear(in_features=hidden_size, out_features=2)
        
    def readBuffer(self, alpha, B, s_b, pop=False):
        batch_size = B.size(0)
        sequence_len = B.size(1)

        cumsum = Variable(torch.zeros(batch_size).float())
        vector = Variable(torch.zeros(batch_size, self.embedding_dim).float())
        batch_zeros = Variable(torch.zeros(batch_size))
        batch_alpha = Variable(torch.ones(batch_size)) if alpha is None else alpha
        
        # May be more efficient not to loop like this... especially when buffer is empty
        for i in range(sequence_len):

            weights = torch.min(s_b[:, i], torch.max(batch_zeros, batch_alpha - cumsum))  
            vector = torch.add(vector, torch.mul(weights.unsqueeze(1).expand_as(B[:, i]), B[:, i]))
            cumsum = torch.add(cumsum, weights)
            if pop:
                s_b[:, i] = torch.add(s_b[:, i], torch.mul(weights, -1))

            if batch_size <= torch.sum(torch.ge(cumsum, batch_alpha)).data[0]:
                break

        return vector, s_b
    
    def pushStack(self, vector, alpha, V, s):
        V = [vector] + V
        s = [alpha] + s

        return V, s
    
    def readStack(self, alpha, V, s, pop=False):
        batch_size = V[0].size(0)
        embedding_dim = V[0].size(1)
    
        vector1 = Variable(torch.zeros(batch_size, embedding_dim).float())
        vector2 = Variable(torch.zeros(batch_size, embedding_dim).float())

        cumsum = Variable(torch.zeros(batch_size).float())
        batch_zeros = Variable(torch.zeros(batch_size))
        batch_ones = Variable(torch.ones(batch_size))
        batch_alpha = Variable(torch.ones(batch_size)) if alpha is None else alpha

        # May be more efficient not to loop like this... especially when stack is empty
        for i in range(len(V)):
            weights1 = torch.min(s[i], torch.max(batch_zeros, batch_alpha - cumsum))
            weights2 = torch.min(s[i], torch.max(batch_zeros, batch_ones + batch_alpha - cumsum)) - weights1
            cumsum = torch.add(cumsum, weights1+weights2)

            vector1 = torch.add(vector1, torch.mul(weights1.unsqueeze(1).expand_as(V[0]), V[i]))
            vector2 = torch.add(vector2, torch.mul(weights2.unsqueeze(1).expand_as(V[0]), V[i]))

            if pop:
                s[i] = torch.add(s[i], torch.mul(weights1+weights2, -1))

            if batch_size == torch.sum(torch.ge(cumsum, 2)).data[0]:
                break

        return vector1, vector2, s
    
    def init_controller(self, batch_size):
        # h_t, c_t
        state = (Variable(torch.zeros(batch_size, self.hidden_size)), \
                Variable(torch.zeros(batch_size, self.hidden_size)))
        return state
    
    def run_tree(self, v1, v2, use_dropout=False):
        gates = self.tree_left(v1)
        gates += self.tree_right(v2)
        c_t, h_t = treelstm(c_left=v1, c_right=v2, gates=gates, use_dropout=use_dropout, training=self.training)
        return (h_t, c_t)
    
    def forward(self, x, timesteps=None):
        batch_size = x.size(0)
        seq_len = x.size(1)
        if timesteps is None:
            time = 2*seq_len
        
        # initialize buffer
        B = embedding(x)
        s_b = torch.gt(x, 0).float()
        
        # initialize stack
        V = [Variable(torch.zeros(batch_size, self.hidden_size))]
        s = [Variable(torch.zeros(batch_size))]
        controller_state = self.init_controller(batch_size)
        
        # LSTM controller timesteps
        for t in range(time):
            # Read from stack and buffer
            x_b, _ = self.readBuffer(alpha=None, B=B, s_b=s_b, pop=False)
            x_1, x_2, _ = self.readStack(alpha=None, V=V, s=s, pop=False)
            x_t = torch.cat([x_b, x_1, x_2], 1)
            
            # Get alphas from controller
            controller_state = self.controller(x_t, controller_state)
            hidden_state = controller_state[0]
            alphas = F.softmax(self.alpha_projector(hidden_state))
            alpha_r = alphas[:, 0]
            alpha_s = alphas[:, 1]
            
            # Read from stack, reduce (treelstm), and push onto stack
            stack_reduce1, stack_reduce2, s = self.readStack(alpha_r, V, s, pop=True)
            tree_state = self.run_tree(stack_reduce1, stack_reduce2)
            hidden_state_tree = tree_state[0]
            V, s = self.pushStack(vector=hidden_state_tree, alpha=alpha_r, V=V, s=s)
            
            # Shift from buffer and push onto stack
            buffer_shift, s_b = self.readBuffer(alpha_s, B, s_b, pop=True)
            V, s = self.pushStack(vector=buffer_shift, alpha=alpha_s, V=V, s=s)
            
        # Pop from top of stack with strength 1 as final sentence representation
        x_1, x_2, _ = self.readStack(alpha=None, V=V, s=s, pop=False)
        return x_1

vocab_size = 50
embedding_dim = 30
batch_size = 3
hidden_size = embedding_dim

np_input = np.random.randint(low=1, high=vocab_size, size=(batch_size, sequence_length))
np_input[1][-1] = 0
np_input[2][-1] = 0
np_input[2][-2] = 0

inputs = Variable(torch.from_numpy(np_input))
print(inputs)

softstack = SoftStack(embedding_dim=embedding_dim, hidden_size=hidden_size, vocab_size=vocab_size)
softstack(inputs)

Variable containing:
 10  34  15  48  48
 14  33  17  27   0
 47  46  34   0   0
[torch.LongTensor of size 3x5]



Variable containing:

Columns 0 to 9 
-0.1191  0.2419 -0.2742  0.0171  0.0240  0.1359 -0.1157  0.1221 -0.1110  0.2958
 0.0070 -0.0128  0.0006 -0.0003  0.0163  0.0213  0.0085 -0.0048 -0.0063 -0.0138
-0.0014 -0.0132  0.0042 -0.0092  0.0117  0.0218  0.0094 -0.0085 -0.0016 -0.0114

Columns 10 to 19 
-0.2905  0.1654 -0.0387 -0.0140  0.4230  0.0085 -0.0100 -0.1814  0.2243  0.2183
-0.0083 -0.0173 -0.0118 -0.0449 -0.0134  0.0237 -0.0354 -0.0014  0.0327  0.0068
-0.0030 -0.0184 -0.0126 -0.0435 -0.0081  0.0242 -0.0378  0.0001  0.0313  0.0107

Columns 20 to 29 
 0.3681 -0.0540  0.2215 -0.0649 -0.2742  0.2598  0.1382 -0.1867 -0.0145 -0.2273
 0.0278  0.0274 -0.0227 -0.0240  0.0204  0.0375  0.0269  0.0150  0.0259 -0.0322
 0.0277  0.0336 -0.0286 -0.0267  0.0189  0.0393  0.0250  0.0125  0.0283 -0.0362
[torch.FloatTensor of size 3x30]