#Understanding Torch computational graphs and LSTMs
This module serves as an exercise in understanding how computational
graphs can be modelled using Lua.

Given an LSTM RNN, we can think of the algorithm as comprising four main stages.

1. The forget gate: $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$
2. The input gate and the candidate cell state:
    $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
    $\tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$
3. Generate new cell state:
    $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$
4. Output decision:
    $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
    $h_t = \tanh(c_t) \odot o_t$
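
To make the four stages concrete, here is a minimal sketch of a single LSTM step on scalars in plain Lua, with no Torch dependency. The helper names (sigmoid, tanh, lstm_step) and the weight values are made up for illustration, and each gate uses separate scalar weights for $h_{t-1}$ and $x_t$ in place of a weight matrix applied to their concatenation.

```lua
-- Scalar LSTM step in plain Lua; all names and weights here are illustrative assumptions.
local function sigmoid(z)
  return 1 / (1 + math.exp(-z))
end

-- Own tanh (math.tanh is not available in every Lua version); this form avoids overflow.
local function tanh(z)
  local e = math.exp(-2 * math.abs(z))
  local t = (1 - e) / (1 + e)
  if z < 0 then return -t end
  return t
end

-- w.f, w.i, w.c, w.o each hold {h = weight on h_{t-1}, x = weight on x_t, b = bias}.
local function lstm_step(w, x_t, c_prev, h_prev)
  local function pre(g) return g.h * h_prev + g.x * x_t + g.b end
  local f_t = sigmoid(pre(w.f))             -- 1. forget gate
  local i_t = sigmoid(pre(w.i))             -- 2. input gate
  local c_tilde = tanh(pre(w.c))            --    candidate cell state
  local c_t = f_t * c_prev + i_t * c_tilde  -- 3. new cell state
  local o_t = sigmoid(pre(w.o))             -- 4. output gate
  local h_t = tanh(c_t) * o_t               --    new hidden state
  return c_t, h_t
end

-- With all gate weights at zero, every gate evaluates to sigmoid(0) = 0.5.
local zero = {h = 0, x = 0, b = 0}
local w = {f = zero, i = zero, c = {h = 0, x = 1, b = 0}, o = zero}
local c_t, h_t = lstm_step(w, 0, 1, 0)
print(c_t, h_t)
```

In this toy setting the forget gate keeps half of the previous cell state and the candidate contributes nothing, because $x_t = 0$.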
                
The model described below implements a computational graph using Lua tables to make the computation more efficient.

#Step 1:
Join the input $x_{t}$, the previous cell state $c_{t-1}$ and the previous hidden state $h_{t-1}$ into a Lua table.
![Step 1: Generate input table](img/c1.png)
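The output preactivation vector, which has $4*rnn\_size$ columns, is then split into four chunks of length $rnn\_size$, starting at columns $1$, $rnn\_size+1$, $2*rnn\_size+1$ and $3*rnn\_size+1$. This is done with the nn module's Narrow layer, applied to a graph node using nngraph's call syntax: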

nn.Narrow(dim, start, len)(preactivations)
e.g. for $pre\_sigmoid\_chunk$ = nn.Narrow(2, 1, 3*$rnn\_size$)(preactivations):
in words, along dimension 2 (the columns) of preactivations, starting at index 1, extract a slice of length 3*rnn\_size.
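
The (start, len) convention can be illustrated without Torch. The sketch below uses a hypothetical helper, narrow, that mimics it on a plain Lua array standing in for one row of preactivations. The value of rnn_size and the array contents are made up so that each chunk is easy to recognise.

```lua
-- Hypothetical helper mirroring the (start, len) convention of nn.Narrow
-- on a plain Lua array (one row of the preactivation matrix).
local function narrow(row, start, len)
  local out = {}
  for k = start, start + len - 1 do
    out[#out + 1] = row[k]
  end
  return out
end

local rnn_size = 2
-- 4 * rnn_size entries; the tens digit marks the chunk each entry belongs to.
local preactivations = {11, 12, 21, 22, 31, 32, 41, 42}

-- The first three chunks go through the sigmoid (the gates).
local pre_sigmoid_chunk = narrow(preactivations, 1, 3 * rnn_size)
-- The last chunk goes through tanh (the input transform).
local in_chunk = narrow(preactivations, 3 * rnn_size + 1, rnn_size)
-- The second chunk of the gate block (the forget gate in the code below).
local forget_chunk = narrow(pre_sigmoid_chunk, rnn_size + 1, rnn_size)

print(table.concat(pre_sigmoid_chunk, " "))
print(table.concat(in_chunk, " "))
print(table.concat(forget_chunk, " "))
```

The real graph narrows the gates after the sigmoid has been applied, but the indices are the same as in this sketch.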

#Step 2:
![Step 2: split up components into their respective outputs before passing through a sigmoid layer](img/c2.png)



#Step 3:
![Step 3: Convert the input x into the current cell state](img/c3.png)



--This module creates the described computational graph
require 'nn'
require 'nngraph'

local LSTM = {}

function LSTM.create(input_size, rnn_size)
--Definition of some of these functions:
--nn.Narrow(dim, start, len) - selects a subvector along 
--dim dimension having len elements starting from start index
    
--nn.CMulTable() - outputs the product of tensors in forwarded table
--nn.CAddTable() - outputs the sum of tensors in forwarded table
  --------------------- input structure ---------------------
   
  ------------------------------------------------------------  
  ----Step 1, as shown in the above computational graphs
  ------------------------------------------------------------
  local inputs = {}
  table.insert(inputs, nn.Identity()())   -- network input
  table.insert(inputs, nn.Identity()())   -- c at time t-1
  table.insert(inputs, nn.Identity()())   -- h at time t-1
  local input = inputs[1]
  local prev_c = inputs[2]
  local prev_h = inputs[3]

  --------------------- preactivations ----------------------
  local i2h = nn.Linear(input_size, 4 * rnn_size)(input)   -- input to hidden
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)    -- hidden to hidden
  local preactivations = nn.CAddTable()({i2h, h2h})        -- i2h + h2h

  
    
  ----Step 2, as shown in the above computational graphs 
  ------------------ non-linear transforms ------------------
  -- gates

  local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)
  local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)

  -- input
  local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)
  local in_transform = nn.Tanh()(in_chunk)

  ---------------------- gate narrows -----------------------
  local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)
  --------------------------------------------------------------
    
  ---- Step 3, as in comp graph above---------------------------
  --------------------- next cell state ---------------------
  local c_forget = nn.CMulTable()({forget_gate, prev_c})  -- previous cell state contribution
  local c_input = nn.CMulTable()({in_gate, in_transform}) -- input contribution
  local next_c = nn.CAddTable()({
    c_forget,
    c_input
  })

  -------------------- next hidden state --------------------
  local c_transform = nn.Tanh()(next_c)
  local next_h = nn.CMulTable()({out_gate, c_transform})

  --------------------- output structure --------------------
  local outputs = {}  -- declared local so it does not leak into the global scope
  table.insert(outputs, next_c)
  table.insert(outputs, next_h)

  -- packs the graph into a convenient module with standard API (:forward(), :backward())
  return nn.gModule(inputs, outputs)
end

return LSTM
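
A minimal usage sketch of the module above. It assumes the code is saved as LSTM.lua somewhere on the Lua package path, and that Torch, nn and nngraph are installed. The sizes (batch_size, input_size, rnn_size) are arbitrary illustrative values. Because the graph narrows along dimension 2, the inputs are given as 2D batches.

```lua
require 'torch'
require 'nn'
require 'nngraph'

-- Assumption: the module above lives in LSTM.lua on the package path.
local LSTM = require 'LSTM'
local unpack = unpack or table.unpack  -- unpack moved to table.unpack in Lua 5.2+

-- Illustrative sizes only.
local batch_size, input_size, rnn_size = 2, 5, 4
local lstm = LSTM.create(input_size, rnn_size)

-- Inputs follow the order of the graph's input table: x, prev_c, prev_h.
local x = torch.randn(batch_size, input_size)
local prev_c = torch.zeros(batch_size, rnn_size)
local prev_h = torch.zeros(batch_size, rnn_size)

-- The gModule returns a table holding next_c and next_h, in that order.
local next_c, next_h = unpack(lstm:forward({x, prev_c, prev_h}))
print(next_c:size())
print(next_h:size())
```

To process a sequence, the returned next_c and next_h would be fed back in as prev_c and prev_h for the next time step.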