# Stacked Bi-directional LSTMs in Gluon

Starting with a plain RNN, we'll look at various modifications and build up to the complete Stacked Bi-directional LSTM model.

* Base: __RNN__
* Modification: __Stacked RNN__
* Modification: __Bi-directional RNN__
* Modification: __LSTM__
* Combined: __Stacked Bi-directional LSTM__

# Base: RNN

### Implicit initial hidden state

In [74]:
import mxnet as mx

In [75]:
sequence_length = 4
batch_size = 5
channels = 3

inputs = mx.nd.random.uniform(shape=(sequence_length, batch_size, channels))
first_input = inputs[0]
first_input


[[ 0.98734874  0.91809195  0.75928247]
 [ 0.57968092  0.36454463  0.45589843]
 [ 0.50106317  0.64235175  0.37638915]
 [ 0.25402522  0.36491182  0.55107915]
 [ 0.26090449  0.13190325  0.49597031]]
<NDArray 5x3 @cpu(0)>

In [76]:
hid_layers = 1
hid_units = 6

rnn = mx.gluon.rnn.RNN(hidden_size=hid_units, num_layers=hid_layers, layout='TNC')

In [77]:
# lazy initialize weights
rnn.initialize()

In [78]:
# no initial state provided, so the hidden state is initialized to zeros of the appropriate shape
outputs = rnn(inputs)

In [79]:
# for a plain RNN the output is the hidden state itself; outputs holds it for every time step.
outputs.shape

(4, 5, 6)

In [80]:
final_output = outputs[-1]
final_output


[[ 0.          0.          0.07570447  0.03122622  0.02412307  0.        ]
 [ 0.          0.          0.07036699  0.03546758  0.02842021  0.01387072]
 [ 0.          0.          0.07198478  0.03106744  0.02606469  0.02974632]
 [ 0.          0.          0.06010557  0.01828784  0.02100974  0.04091441]
 [ 0.          0.          0.02801176  0.00882845  0.01097421  0.02552567]]
<NDArray 5x6 @cpu(0)>

### Explicit initial hidden state

In [81]:
hid_init = mx.nd.random.uniform(shape=(hid_layers, batch_size, hid_units))

In [82]:
# when an initial state is passed in, an (outputs, states) tuple is returned
outputs, hid_states = rnn(inputs, hid_init)

In [83]:
outputs.shape

(4, 5, 6)

In [84]:
final_output = outputs[-1]
final_output


[[ 0.          0.          0.07571341  0.03124048  0.02413538  0.        ]
 [ 0.          0.          0.07037674  0.03546442  0.02842471  0.01386217]
 [ 0.          0.          0.07199639  0.03109403  0.0260729   0.02971762]
 [ 0.          0.          0.0601346   0.01827844  0.02102314  0.04088894]
 [ 0.          0.          0.02804597  0.00882176  0.01098775  0.02550704]]
<NDArray 5x6 @cpu(0)>

In [85]:
# a plain RNN carries a single state tensor between time steps (an LSTM carries two)
len(hid_states)

1

In [86]:
# the returned state only covers the last time step
hid_states[0].shape

(1, 5, 6)

In [87]:
# same as final_output
hid_states[0]


[[[ 0.          0.          0.07571341  0.03124048  0.02413538  0.        ]
  [ 0.          0.          0.07037674  0.03546442  0.02842471  0.01386217]
  [ 0.          0.          0.07199639  0.03109403  0.0260729   0.02971762]
  [ 0.          0.          0.0601346   0.01827844  0.02102314  0.04088894]
  [ 0.          0.          0.02804597  0.00882176  0.01098775  0.02550704]]]
<NDArray 1x5x6 @cpu(0)>

# Modification: Stacked RNN

In [88]:
hid_layers = 2

In [89]:
stack_rnn = mx.gluon.rnn.RNN(hidden_size=hid_units, num_layers=hid_layers, layout='TNC')
stack_rnn.initialize()

In [90]:
hid_init = mx.nd.random.uniform(shape=(hid_layers, batch_size, hid_units))
outputs, hid_states = stack_rnn(inputs, hid_init)

In [91]:
# output shape is unchanged by the number of layers: still one output per time step, taken from the top layer
outputs.shape

(4, 5, 6)

In [92]:
final_output = outputs[-1]
final_output


[[ 0.00406747  0.          0.          0.00646659  0.          0.00679768]
 [ 0.00386951  0.          0.          0.00600511  0.00066995  0.0054298 ]
 [ 0.0037198   0.          0.          0.00581633  0.00157896  0.00428085]
 [ 0.00373613  0.          0.          0.00632023  0.00210966  0.00340568]
 [ 0.00151994  0.          0.          0.00328808  0.00111398  0.00152676]]
<NDArray 5x6 @cpu(0)>

In [93]:
# still a single state tensor, since this is a plain RNN
len(hid_states)

1

In [94]:
# but it now holds one hidden state per layer (last time step only)
hid_states[0].shape

(2, 5, 6)

In [95]:
# the last layer's state equals final_output; the first layer's state is not part of the output
hid_states[0]


[[[ 0.05187661  0.03314793  0.08078527  0.08762819  0.02698768  0.        ]
  [ 0.0393471   0.02712335  0.08269582  0.07542091  0.01029002  0.        ]
  [ 0.02794424  0.02142053  0.08560424  0.06404411  0.          0.        ]
  [ 0.00873652  0.01335483  0.09287432  0.04383021  0.          0.        ]
  [ 0.          0.00535677  0.04472532  0.01648927  0.          0.        ]]

 [[ 0.00406747  0.          0.          0.00646659  0.          0.00679768]
  [ 0.00386951  0.          0.          0.00600511  0.00066995  0.0054298 ]
  [ 0.0037198   0.          0.          0.00581633  0.00157896  0.00428085]
  [ 0.00373613  0.          0.          0.00632023  0.00210966  0.00340568]
  [ 0.00151994  0.          0.          0.00328808  0.00111398  0.00152676]]]
<NDArray 2x5x6 @cpu(0)>

# Modification: Bi-directional RNN

In [96]:
hid_layers = 1
bidirectional = True

In [97]:
bidir_rnn = mx.gluon.rnn.RNN(hidden_size=hid_units, num_layers=hid_layers, layout='TNC', bidirectional=bidirectional)
bidir_rnn.initialize()

In [98]:
# now need hid_layers * 2 initial hidden states: one set for the forward RNN and one for the backward RNN
hid_init = mx.nd.random.uniform(shape=(hid_layers * 2, batch_size, hid_units))
outputs, hid_states = bidir_rnn(inputs, hid_init)

In [99]:
# hid_units * 2 channels
# 6 from forward rnn, 6 from backward rnn, concatenated to give 12
outputs.shape

(4, 5, 12)

In [100]:
final_output = outputs[-1]
final_output


[[ 0.08956115  0.          0.          0.03769168  0.          0.04953455
   0.          0.04680495  0.02324563  0.11417244  0.          0.06207506]
 [ 0.08177032  0.          0.          0.04138822  0.          0.02983724
   0.          0.05591435  0.          0.07984196  0.04715843  0.04717504]
 [ 0.07705043  0.          0.          0.040458    0.          0.02450354
   0.          0.05141705  0.          0.09351053  0.02774166  0.04651833]
 [ 0.0664442   0.          0.          0.03642423  0.          0.00982516
   0.          0.10003141  0.          0.12314523  0.          0.01361769]
 [ 0.02545916  0.          0.          0.01715774  0.          0.          0.
   0.07882698  0.02267237  0.0498731   0.01533458  0.01271129]]
<NDArray 5x12 @cpu(0)>

In [101]:
# from forward rnn
final_output[:,:6]


[[ 0.08956115  0.          0.          0.03769168  0.          0.04953455]
 [ 0.08177032  0.          0.          0.04138822  0.          0.02983724]
 [ 0.07705043  0.          0.          0.040458    0.          0.02450354]
 [ 0.0664442   0.          0.          0.03642423  0.          0.00982516]
 [ 0.02545916  0.          0.          0.01715774  0.          0.        ]]
<NDArray 5x6 @cpu(0)>

In [102]:
# still a single state tensor, since this is a plain (non-LSTM) RNN
len(hid_states)

1

In [103]:
# forward RNN hidden state, then backward RNN hidden state,
# BUT from different time steps! The forward state is from the last time step, the backward state from the first.
# This is what you want when feeding a decoder: at the last time step the backward RNN has only seen one element of the sequence.
hid_states[0]


[[[ 0.08956115  0.          0.          0.03769168  0.          0.04953455]
  [ 0.08177032  0.          0.          0.04138822  0.          0.02983724]
  [ 0.07705043  0.          0.          0.040458    0.          0.02450354]
  [ 0.0664442   0.          0.          0.03642423  0.          0.00982516]
  [ 0.02545916  0.          0.          0.01715774  0.          0.        ]]

 [[ 0.          0.06210476  0.          0.1196948   0.03029597  0.        ]
  [ 0.          0.0277768   0.          0.05905489  0.02285744  0.00437873]
  [ 0.          0.04278467  0.          0.07068352  0.01208621  0.        ]
  [ 0.          0.04355423  0.          0.03600511  0.02986493  0.        ]
  [ 0.          0.02898492  0.          0.02716501  0.02252049  0.        ]]]
<NDArray 2x5x6 @cpu(0)>

In [104]:
# same as the first 6 channels of the output at the last time step
hid_states[0][0]


[[ 0.08956115  0.          0.          0.03769168  0.          0.04953455]
 [ 0.08177032  0.          0.          0.04138822  0.          0.02983724]
 [ 0.07705043  0.          0.          0.040458    0.          0.02450354]
 [ 0.0664442   0.          0.          0.03642423  0.          0.00982516]
 [ 0.02545916  0.          0.          0.01715774  0.          0.        ]]
<NDArray 5x6 @cpu(0)>

In [105]:
first_output = outputs[0]
first_output


[[ 0.06676939  0.          0.04957572  0.          0.          0.          0.
   0.06210476  0.          0.1196948   0.03029597  0.        ]
 [ 0.          0.04321212  0.04632973  0.          0.01752034  0.01783034
   0.          0.0277768   0.          0.05905489  0.02285744  0.00437873]
 [ 0.          0.0462874   0.03006371  0.          0.03545707  0.          0.
   0.04278467  0.          0.07068352  0.01208621  0.        ]
 [ 0.02601724  0.00984341  0.01533633  0.          0.          0.          0.
   0.04355423  0.          0.03600511  0.02986493  0.        ]
 [ 0.01376808  0.00530647  0.          0.          0.          0.01757525
   0.          0.02898492  0.          0.02716501  0.02252049  0.        ]]
<NDArray 5x12 @cpu(0)>

In [106]:
# from backward rnn
first_output[:,6:]


[[ 0.          0.06210476  0.          0.1196948   0.03029597  0.        ]
 [ 0.          0.0277768   0.          0.05905489  0.02285744  0.00437873]
 [ 0.          0.04278467  0.          0.07068352  0.01208621  0.        ]
 [ 0.          0.04355423  0.          0.03600511  0.02986493  0.        ]
 [ 0.          0.02898492  0.          0.02716501  0.02252049  0.        ]]
<NDArray 5x6 @cpu(0)>

In [107]:
# same as the last 6 channels of the output at the first time step
hid_states[0][1]


[[ 0.          0.06210476  0.          0.1196948   0.03029597  0.        ]
 [ 0.          0.0277768   0.          0.05905489  0.02285744  0.00437873]
 [ 0.          0.04278467  0.          0.07068352  0.01208621  0.        ]
 [ 0.          0.04355423  0.          0.03600511  0.02986493  0.        ]
 [ 0.          0.02898492  0.          0.02716501  0.02252049  0.        ]]
<NDArray 5x6 @cpu(0)>

# Modification: LSTM

In [108]:
hid_layers = 1

In [109]:
lstm = mx.gluon.rnn.LSTM(hidden_size=hid_units, num_layers=hid_layers, layout='TNC')
lstm.initialize()

In [110]:
hid_init_h = mx.nd.random.uniform(shape=(hid_layers, batch_size, hid_units))
hid_init_c = mx.nd.random.uniform(shape=(hid_layers, batch_size, hid_units))
hid_init = [hid_init_h, hid_init_c]
outputs, hid_states = lstm(inputs, hid_init)

In [111]:
# output shape same as for the plain RNN
outputs.shape

(4, 5, 6)

In [112]:
final_output = outputs[-1]
final_output


[[ 0.01253454 -0.02345554  0.02795083  0.02464588  0.00788901  0.02751796]
 [ 0.01594986 -0.0184415   0.03759725  0.01181254  0.0063988   0.02784991]
 [ 0.0094468  -0.00300365  0.03994624 -0.00055172  0.03623314  0.01056249]
 [ 0.01638241  0.00139618  0.04743286  0.01258996  0.02631953  0.02926661]
 [ 0.00316927 -0.01528073  0.04393305  0.00166157  0.01046809  0.01166981]]
<NDArray 5x6 @cpu(0)>

In [113]:
# now there are two states: hidden state and cell memory
len(hid_states)

2

In [114]:
# hidden state (bottom line in diagram)
hid_states[0].shape

(1, 5, 6)

In [115]:
# cell memory (top line in diagram)
hid_states[1].shape

(1, 5, 6)

In [116]:
# same as final_output, since this LSTM is uni-directional and not stacked
hid_states[0]


[[[ 0.01253454 -0.02345554  0.02795083  0.02464588  0.00788901  0.02751796]
  [ 0.01594986 -0.0184415   0.03759725  0.01181254  0.0063988   0.02784991]
  [ 0.0094468  -0.00300365  0.03994624 -0.00055172  0.03623314  0.01056249]
  [ 0.01638241  0.00139618  0.04743286  0.01258996  0.02631953  0.02926661]
  [ 0.00316927 -0.01528073  0.04393305  0.00166157  0.01046809  0.01166981]]]
<NDArray 1x5x6 @cpu(0)>

# Combined: Stacked Bi-directional LSTM

In [117]:
hid_layers = 2
bidirectional = True

In [118]:
stack_bidir_lstm = mx.gluon.rnn.LSTM(hidden_size=hid_units, num_layers=hid_layers, layout='TNC', bidirectional=bidirectional)
stack_bidir_lstm.initialize()

In [119]:
# 2 * hid_layers (since bi-directional)
hid_init_h = mx.nd.random.uniform(shape=(2*hid_layers, batch_size, hid_units))
hid_init_c = mx.nd.random.uniform(shape=(2*hid_layers, batch_size, hid_units))
hid_init = [hid_init_h, hid_init_c]
outputs, hid_states = stack_bidir_lstm(inputs, hid_init)

In [120]:
# 2 * hid_units = 12 channels since bi-directional
outputs.shape

(4, 5, 12)

In [121]:
final_output = outputs[-1]
final_output


[[ 0.02836313  0.00606763  0.02310923  0.00844923  0.03136069  0.00835426
   0.00834884  0.22160307  0.16497645  0.17498061  0.17497027  0.16968904]
 [ 0.02018425 -0.00205736  0.02368148  0.01517256  0.02936829  0.01105657
  -0.00114159  0.02814094  0.06628538  0.18898237  0.02543253  0.16276605]
 [ 0.02180842  0.02542746  0.04244835 -0.00741896  0.03391297  0.00343686
   0.18560167  0.15380849  0.1863319   0.14486022  0.11376306  0.03071362]
 [ 0.01190888  0.00480322  0.01914669  0.004719    0.01296413 -0.00054099
   0.05539425  0.09306861  0.21310844  0.12720783  0.08961899  0.15656401]
 [ 0.01545303  0.00941906  0.02581433  0.01284648  0.00813204  0.00568986
   0.19248244  0.0284084   0.17106406  0.02503861  0.10314985  0.01844846]]
<NDArray 5x12 @cpu(0)>

In [122]:
# channels from forward rnn in last step of last layer
final_output[:,:6]


[[ 0.02836313  0.00606763  0.02310923  0.00844923  0.03136069  0.00835426]
 [ 0.02018425 -0.00205736  0.02368148  0.01517256  0.02936829  0.01105657]
 [ 0.02180842  0.02542746  0.04244835 -0.00741896  0.03391297  0.00343686]
 [ 0.01190888  0.00480322  0.01914669  0.004719    0.01296413 -0.00054099]
 [ 0.01545303  0.00941906  0.02581433  0.01284648  0.00813204  0.00568986]]
<NDArray 5x6 @cpu(0)>

In [123]:
# channels from backward rnn in last step of last layer
final_output[:,6:]


[[ 0.00834884  0.22160307  0.16497645  0.17498061  0.17497027  0.16968904]
 [-0.00114159  0.02814094  0.06628538  0.18898237  0.02543253  0.16276605]
 [ 0.18560167  0.15380849  0.1863319   0.14486022  0.11376306  0.03071362]
 [ 0.05539425  0.09306861  0.21310844  0.12720783  0.08961899  0.15656401]
 [ 0.19248244  0.0284084   0.17106406  0.02503861  0.10314985  0.01844846]]
<NDArray 5x6 @cpu(0)>

In [124]:
len(hid_states)

2

In [125]:
# hidden state
hid_states[0].shape

(4, 5, 6)

In [126]:
# cell memory
hid_states[1].shape

(4, 5, 6)

In [127]:
# ordered by layer, then by direction within each layer, e.g.
# [ L1_forward,
#   L1_backward,
#   L2_forward,
#   L2_backward ]
hid_states[0]


[[[  1.03355674e-02   1.60647300e-03   1.32582430e-02   1.38146300e-02
    -1.35828340e-02  -1.12173790e-02]
  [  2.33832765e-02   1.31784236e-05   2.16201320e-02   2.42266501e-03
    -9.80209676e-04  -1.19874710e-02]
  [  1.59390830e-02   2.46383925e-03   5.02551533e-03   1.26088131e-02
     4.33453452e-03  -2.34422293e-02]
  [  1.38091547e-02   1.20086502e-02   4.43106666e-02   7.60523602e-03
    -7.71074602e-03  -2.75092013e-02]
  [  3.55932228e-02   8.30993708e-03   2.85629816e-02  -1.24422682e-03
     5.16857579e-03   6.64406968e-03]]

 [[  9.75320395e-03   3.27899903e-02   3.42854112e-02   9.73990746e-03
    -1.82624068e-02  -4.52016518e-02]
  [  2.99503710e-02   4.22814377e-02   1.59972627e-02   3.10577378e-02
     3.01205181e-03   3.97029473e-03]
  [  2.96767391e-02   4.26794328e-02   1.73485801e-02   7.67373433e-03
    -1.50409192e-02  -7.97832385e-03]
  [  1.15076527e-02   4.02115956e-02   2.62511540e-02   3.11808214e-02
     1.00480448e-02  -1.79561898e-02]
  [  2.81295236e

In [128]:
# take the last two rows: the top layer's forward and backward states (bi-directional)
hid_last = hid_states[0][-2:,:]

In [129]:
# the first row of the pair is the forward state
hid_last_forward = hid_last[0]

In [130]:
# same as the first 6 channels of the last time step's output
hid_last_forward


[[ 0.02836313  0.00606763  0.02310923  0.00844923  0.03136069  0.00835426]
 [ 0.02018425 -0.00205736  0.02368148  0.01517256  0.02936829  0.01105657]
 [ 0.02180842  0.02542746  0.04244835 -0.00741896  0.03391297  0.00343686]
 [ 0.01190888  0.00480322  0.01914669  0.004719    0.01296413 -0.00054099]
 [ 0.01545303  0.00941906  0.02581433  0.01284648  0.00813204  0.00568986]]
<NDArray 5x6 @cpu(0)>

In [131]:
first_output = outputs[0]

In [132]:
# last 6 channels (backward RNN) of the first time step's output
first_output[:,6:]


[[-0.00148457  0.0318089   0.01821657  0.0233146   0.02311079  0.02432076]
 [ 0.00166964  0.00372249  0.00529031  0.03002147  0.00336746  0.02585452]
 [ 0.02263732  0.01611603  0.02180129  0.02280513  0.01358894  0.00735507]
 [ 0.00859078  0.01099376  0.0230337   0.02421864  0.00913094  0.03071127]
 [ 0.0328386  -0.00055734  0.01982282  0.01173063  0.00941289  0.01053296]]
<NDArray 5x6 @cpu(0)>

In [133]:
# the second row of the pair is the backward state
hid_last_backward = hid_last[1]

In [134]:
hid_last_backward


[[-0.00148457  0.0318089   0.01821657  0.0233146   0.02311079  0.02432076]
 [ 0.00166964  0.00372249  0.00529031  0.03002147  0.00336746  0.02585452]
 [ 0.02263732  0.01611603  0.02180129  0.02280513  0.01358894  0.00735507]
 [ 0.00859078  0.01099376  0.0230337   0.02421864  0.00913094  0.03071127]
 [ 0.0328386  -0.00055734  0.01982282  0.01173063  0.00941289  0.01053296]]
<NDArray 5x6 @cpu(0)>