In [1]:
require 'nn'
require 'nngraph'
Plot = require "itorch.Plot"
lstm = {}
include('utils.lua')
include('LSTM.lua')

In [61]:
-- Model hyper-parameters (kept global so later notebook cells can read them)
emb_dim = 2          -- embedding size; plot_emb draws exactly these two columns
sum_rep_dim = 20     -- size of the LSTM hidden representation
class_hidden = 40    -- width of the classifier's hidden layer
class_types = 2      -- number of output classes
seq_len = 5          -- number of inputs per generated sequence
num_layers = 2       -- number of stacked LSTM layers

-- Build the model in forward order: embeddings -> LSTM -> classifier.
emb = nn.LookupTable(10, emb_dim) -- one embedding row per input index 1..10

local lstm_options = {
    num_layers = num_layers,
    input_dim = emb_dim,
    hidden_dim = sum_rep_dim,
}
lstm_net = lstm.LSTM(lstm_options)

-- Classifier: linear -> ReLU -> linear -> log-softmax over the classes
classifier = nn.Sequential()
    :add(nn.Linear(sum_rep_dim, class_hidden))
    :add(nn.ReLU())
    :add(nn.Linear(class_hidden, class_types))
    :add(nn.LogSoftMax())

-- Negative log-likelihood, used as the training loss
criterion = nn.ClassNLLCriterion()

-- If we don't use the "optim" package, the rest is not needed
-- (SGD is written by hand in the training cell).
-- The container exists only to collect every sub-module's parameters at once.
local all_modules = nn.Parallel()
for _, part in ipairs({emb, lstm_net, classifier}) do
    all_modules:add(part)
end
params, grad_params = all_modules:getParameters()

In [62]:
-- Utility functions
--- Scatter-plot the two embedding dimensions of the ten embedding rows.
-- Rows 1,3,5,7,9 of `emb.weight` are drawn in blue with the label 'even',
-- rows 2,4,6,8,10 in red with the label 'odd'.
-- @param emb  module whose `weight` field holds the embedding rows
-- @param plot existing Plot to reuse (its `_data` is cleared first), or nil
--             to create a new one
-- @return the Plot that was drawn
function plot_emb(emb, plot)
    local evens = emb.weight:index(1, torch.LongTensor{1, 3, 5, 7, 9})
    local odds = emb.weight:index(1, torch.LongTensor{2, 4, 6, 8, 10})

    local fig = plot
    if fig ~= nil then
        fig._data = {} -- drop the series drawn previously before redrawing
    else
        fig = Plot()
    end

    -- column 1 is the x coordinate, column 2 the y coordinate
    fig:circle(odds:select(2, 1), odds:select(2, 2), 'red', 'odd')
    fig:circle(evens:select(2, 1), evens:select(2, 2), 'blue', 'even')
    fig:title('Embeddings of numbers 0 - 9')
    fig:legend(true)
    fig:redraw()
    return fig
end
--- Generate one random sequence for the running-sum parity task.
-- @param len number of time steps to generate
-- @return inputs  tensor of length `len` holding floored draws from
--                 torch.uniform(1, 11); the digit fed in is the value minus 1
-- @return outputs tensor of length `len`; entry i is 1 when the running sum
--                 of digits up to step i is even and 2 when it is odd
function gen_data(len)
    -- Draw the raw values first, then floor the whole tensor in place.
    local inputs = torch.Tensor(len)
    for pos = 1, len do
        inputs[pos] = torch.uniform(1, 11)
    end
    inputs:floor()

    -- Walk the sequence once more to label each step with the sum's parity.
    local outputs = torch.Tensor(len)
    local running_sum = 0
    for pos = 1, len do
        running_sum = running_sum + inputs[pos] - 1
        outputs[pos] = 1 + running_sum % 2
    end

    return inputs, outputs
end
--- Print one line per time step: the digit fed in, the running sum so far,
-- and the parity label taken from `outputs`.
-- @param inputs  1-D sequence of indices (digit + 1); must support
--                `:size(1)` and numeric indexing
-- @param outputs per-step labels; 2 is printed as 'Odd', any other value
--                as 'Even'
function display(inputs, outputs)
    local running_sum = 0
    for i = 1, inputs:size(1) do
        local digit = inputs[i] - 1
        running_sum = running_sum + digit
        -- declared local so the label does not leak into the global environment
        local out_text = (outputs[i] == 2) and 'Odd' or 'Even'
        print(string.format('input %d, sum %d, output %s', digit, running_sum, out_text))
    end
end

In [71]:
-- Plot the embeddings as they are before training. No existing Plot is
-- passed, so plot_emb creates a new one; the result is bound to a throwaway
-- local (presumably so the notebook does not echo the returned object).
local _ = plot_emb(emb)

In [72]:
-- SGD hyper-params (this basic version only needs learning rate and number of iterations)
local learning_rate = 0.01
-- each "epoch" below trains on a single freshly generated sequence of length seq_len
local epochs = 7500
-- Parallelization is not helpful here: run single-threaded and remember the
-- previous thread count so it can be restored once training is done
local num_threads = torch.getnumthreads()
torch.setnumthreads(1)
-- do many iterations
--   \[T]/  Praise the Sun!
--    |_|
--    \ \
for i = 1,epochs do
    -- Reset the network: clear the LSTM state and the accumulated gradients
    -- of the LSTM and the embeddings (the classifier is reset per time step below)
    lstm_net:forget()
    lstm_net:zeroGradParameters()
    emb:zeroGradParameters()
    local input_seq, output_seq = gen_data(seq_len) -- get inputs and outputs
    local emb_seq = emb:forward(input_seq)          -- get inputs' embeddings
    -- the LSTM is fed the transposed embeddings
    -- (presumably emb_dim x seq_len -- confirm against LSTM.lua)
    local sum_rep_seq = lstm_net:forward(emb_seq:t()) -- get representations from LSTM
    -- buffer collecting the gradient w.r.t. the LSTM output for every time step
    local grad_sum_rep_seq = torch.zeros(sum_rep_dim, num_layers, seq_len) -- for later
    -- NOTE(review): total_loss is accumulated below but never read in this cell
    local total_loss = 0
    for t = 1,seq_len do -- for each sum (in the sequence)
        classifier:zeroGradParameters() -- reset the classifier
        -- get output log distribution, only need hidden state
        -- NOTE(review): the representation is read from slot 1 of the second
        -- dimension, but its gradient is stored in slot num_layers further down;
        -- whether that is intended depends on the output layout of lstm.LSTM
        -- (defined in LSTM.lua) -- confirm.
        local rep = sum_rep_seq[{{},1,t}]
        local net_output = classifier:forward(rep) 
        local loss = criterion:forward(net_output, output_seq[t]) -- compute loss
        total_loss = total_loss + loss
        local grad_out = criterion:backward(net_output, output_seq[t]) -- gradients of the loss
        local grad_sum_rep = classifier:backward(rep, grad_out) -- gradients of the representation
        grad_sum_rep_seq[{{},num_layers,t}] = grad_sum_rep -- we get the gradient one at a time, so keep it here for later
        -- the classifier takes an SGD step after every time step, whereas the
        -- LSTM and the embeddings are updated once per sequence (after this loop)
        classifier:updateParameters(learning_rate) -- update the classifier's parameters
    end
    -- compute embedding gradients (using all of the representation gradients)
    local grad_emb_seq = lstm_net:backward(emb_seq:t(), grad_sum_rep_seq)
    -- compute input gradients
    -- (the return value is ignored here, because the inputs are just indices,
    -- but we need to call it to accumulate the embedding gradients according to the inputs)
    emb:backward(input_seq, grad_emb_seq:t())
    -- update parameters
    lstm_net:updateParameters(learning_rate)
    emb:updateParameters(learning_rate)
end
-- Set num threads back (not reached if the loop above raises an error)
torch.setnumthreads(num_threads)

In [75]:
-- Testing: run the model on one freshly generated sequence and print,
-- for every step, the input digit, the true running sum and the model's label.
lstm_net:forget() -- reset LSTM (go back to step 0)
local test_inputs = gen_data(seq_len) -- the generated labels are not needed here
local test_embeddings = emb:forward(test_inputs)
local test_reps = lstm_net:forward(test_embeddings:t())
local predictions = torch.Tensor(seq_len) -- model's outputs, one class index per step
for step = 1, seq_len do
    -- same slice of the LSTM output that the training loop feeds to the classifier
    local step_rep = test_reps[{{}, 1, step}]
    local log_probs = classifier:forward(step_rep)
    -- the position of the largest log-probability is the predicted class
    local _, best_class = torch.max(log_probs, 1)
    predictions[step] = best_class
end
-- Show the result: is it perfect?
-- Yes: cool, can we make the sequence longer?
-- No: tweak some hyper-parameters
display(test_inputs, predictions)

input 7, sum 7, output Odd	
input 8, sum 15, output Odd	
input 0, sum 15, output Odd	
input 1, sum 16, output Even	
input 9, sum 25, output Odd	


In [76]:
-- Plot the current embeddings in a new Plot (no existing Plot is passed in).
local _ = plot_emb(emb)

In [60]:
-- Show the embedding matrix; row k is the embedding looked up for input index k
emb.weight

 1.2402  2.3508
-2.8068  1.5669
 1.2212  2.3530
-2.9475  1.6077
 1.2472  2.3113
-3.0214  1.6653
 1.2684  2.3388
 3.5085 -1.1818
 1.3694  2.2624
 3.5260 -1.1935
[torch.DoubleTensor of size 10x2]

