In [1]:
%run Latex_macros.ipynb


# Inside the LSTM: update equations

An LSTM layer, at time step $\tt$
- Takes input element $\x_\tp$
- Updates long term memory $\c_\tp$
- Updates control state $\h_\tp$
- Optionally outputs $\y_\tp$

according to the equations


$$
\begin{array}{llll}
\c_\tp & = & \remember_\tp \otimes \c_{(\tt-1)} + \save_\tp \otimes \c'_\tp & \text{Long term memory} \\
\h_\tp & = & \focus_\tp \otimes \tanh(\c_\tp) & \text{Short term memory/control}\\
\y_\tp &  = & \h_\tp & \text{Output}
\end{array}
$$

where
$$
\begin{array}{ll}
\remember & \text{is a gate that allows elements of } \c \text{ to be remembered/forgotten} \\
\focus & \text{is a mask that controls movement from long-term to short-term memory}\\
\save & \text{is a gate that controls selective updating of elements of } \c 
\end{array}
$$

A lot of moving parts!

The important thing to remember is that the layer
- Has a matrix $\W$ of weights
- That controls the update of each part

Training, as usual, seeks to discover the optimal (i.e., loss function minimizing) values for $\W$
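To get a feel for how large $\W$ is, the sketch below counts the weights implied by the update equations in this notebook: four blocks (candidate, remember, save, focus), each with an input matrix, a recurrent matrix and a bias vector. The function name `lstm_param_count` and the sizes used are illustrative assumptions, not something defined elsewhere in these notes.

```python
def lstm_param_count(n_features: int, n_units: int) -> int:
    """Count the weights of one LSTM layer, assuming the four-block layout
    (candidate, remember, save, focus) used in this notebook."""
    # Each block has W_x (n_units x n_features), W_h (n_units x n_units), b (n_units)
    per_block = n_units * n_features + n_units * n_units + n_units
    return 4 * per_block


# Hypothetical sizes: 3 input features, 5 memory elements
print(lstm_param_count(n_features=3, n_units=5))
```

With 3 input features and 5 memory elements this gives 180 weights; the count grows quadratically in the number of memory elements because of the recurrent matrices.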

Let's try to understand the whole by examining each piece.

# Memory/States

$\c_\tp$ is the long-term memory.

It is a vector of features that need to be retained throughout the computation
- As each element $\x_\tp$ of input sequence $\x$ is processed
- It records the "concepts" that are important to solve the task
- Might have many elements

$\h_\tp$ is the short-term or control memory.

Very much like a vanilla RNN, its job is to guide the transition from state $\h_\tp$ to state $\h_{(\tt+1)}$.

It is a vector of features that need to be retained *for the immediate future*
- It might have many elements
- Same length as $\c_\tp$

An analogy might help.

Suppose you are driving a car in an unfamiliar city
- Short term memory is a map of the surrounding blocks
- Long term memory is a map of the city plus rules of the road

# Gates/Masks

$\remember, \save, \focus$ are vectors that interact with long term memory $\c_\tp$
- Element-wise
- So have same length as $\c_\tp$

They will be used
- To selectively modify individual elements of $\c_\tp$
- Forget/Reset the value of an element that is no longer relevant
- Decide which individual elements to update
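As a small illustration of what "element-wise" means here, the following sketch applies a mask to a memory vector with NumPy. The values of `c` and `mask` are made up for the example.

```python
import numpy as np

c = np.array([2.0, -1.0, 0.5])     # hypothetical long-term memory
mask = np.array([1.0, 0.0, 0.5])   # hypothetical gate values, each between 0 and 1

# Element-wise (Hadamard) product: element j of the mask scales element j of c
masked = mask * c
print(masked)
```

A mask value of 1 passes the element through unchanged, 0 blocks it, and anything in between scales it down.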

The LSTM literature, beginning with the [classic paper that introduced the LSTM](https://www.bioinf.jku.at/publications/older/2604.pdf), gives these gates different names (the forget gate was added in follow-up work)
- $\remember_\tp \mapsto \mathbf{f}_\tp$
    - $\mathbf{f}$ denotes "Forget" (although it really means "don't forget", i.e., remember!)
- $\save_\tp \mapsto \mathbf{i}_\tp$
    - $\mathbf{i}$ denotes "Input"
- $\focus_\tp \mapsto \mathbf{o}_\tp$
    - $\mathbf{o}$ denotes "Output"

Hopefully the names in our presentation add clarity.

# Output 

$\y_\tp$ is the value (if any) output at step $\tt$, for
- A one-to-many function
- Or a many-to-many function

As written
$$
\y_\tp = \h_\tp
$$
so it has the same length as the memory vectors $\h_\tp$ and $\c_\tp$.

This assumption is purely for simplicity
- You can map $\h_\tp$ through another layer
- That transforms $\h_\tp$ into the appropriate type/shape for output $\y_\tp$

In fact, our equation for the vanilla RNN included this final transformation of $\h$ to $\y$:
$$\y_\tp  =   \W_{hy} \h_\tp  + \b_y$$
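A minimal sketch of such an output layer, assuming NumPy and made-up sizes (4 memory elements mapped to 2 outputs). The names `W_hy` and `b_y` follow the equation above; the random values stand in for trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, n_out = 4, 2                          # hypothetical sizes

h_t = rng.standard_normal(n_units)             # short-term state at step t
W_hy = rng.standard_normal((n_out, n_units))   # output weights (untrained stand-ins)
b_y = np.zeros(n_out)                          # output bias

# Linear map from the state's length to the output's length
y_t = W_hy @ h_t + b_y
print(y_t.shape)
```

The output now has length `n_out` rather than the length of $\h_\tp$; an activation (for example a softmax for classification) could follow, depending on the task.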

# The update process

Let's examine the update equation for each of the parts.

## Update long-term memory

Long-term memory is updated in a two-step process
- Produce a "candidate" updated value for each element of the state
- Decide which of the candidate updated values get applied to the long term memory
    - Successful candidates become part of long term memory
    - Unsuccessful candidates are dropped

The candidate update value vector $\c'_\tp$ is a function of
- The prior short term state $\h_{(t-1)}$
- And the current input $\x_\tp$
- Controlled by parts of the weight matrix $\W$

$$\c'_\tp  = \tanh(\W_{x,c} \x_\tp + \W_{h,c}\h_{(t-1)} + \b_c)$$

This is very much like the RNN state update equation
- Although the RNN equation has $\h$ on both sides of the equation, so is directly recursive in form
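The candidate equation can be sketched in NumPy as follows. The function name `candidate`, the sizes, and the random weights are assumptions made for illustration; in a trained layer the weights would come from optimization.

```python
import numpy as np


def candidate(x_t, h_prev, W_xc, W_hc, b_c):
    """Candidate update c'_t = tanh(W_xc x_t + W_hc h_prev + b_c)."""
    return np.tanh(W_xc @ x_t + W_hc @ h_prev + b_c)


rng = np.random.default_rng(0)
n_features, n_units = 3, 4                      # hypothetical sizes

x_t = rng.standard_normal(n_features)           # current input
h_prev = np.zeros(n_units)                      # prior short-term state
W_xc = rng.standard_normal((n_units, n_features))
W_hc = rng.standard_normal((n_units, n_units))
b_c = np.zeros(n_units)

c_cand = candidate(x_t, h_prev, W_xc, W_hc, b_c)
print(c_cand.shape, np.all(np.abs(c_cand) <= 1.0))
```

The result has one candidate value per memory element, each bounded in magnitude by 1 because of the $\tanh$.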


We now need to decide which elements of $\c_\tp$ to change.

The $\remember$ mask controls how much of the prior value $\c_{(\tt-1)}$ is retained
- When $\remember_{\tp,j} = 0$
    - $\c_{(\tt-1),j}$, the $j^{th}$ element of $\c_{(\tt-1)}$
    - Will be reset to $0$ ("forgotten")
- When $\remember_{\tp,j} = 1$
    - $\c_{(\tt-1),j}$, the $j^{th}$ element of $\c_{(\tt-1)}$
    - Will contribute to the new value $\c_{\tp,j}$


The $\save$ mask controls whether the candidate value $\c'_\tp$ contributes to the new value $\c_\tp$:
- When $\save_{\tp,j} = 1$
    - Candidate value $\c'_{\tp,j}$ will contribute to the new value $\c_{\tp,j}$
- When $\save_{\tp,j} = 0$
    - Candidate value $\c'_{\tp,j}$ will **not** contribute to the new value $\c_{\tp,j}$

Here is the update equation for $\c_\tp$.
- It combines the remember/forget decision for each element
- With the decision on passing through the candidate value for the element

$$\c_\tp = \remember_\tp \otimes \c_{(t-1)} + \save_\tp \otimes \c'_\tp$$

Cases (treating each gate value as a hard $0$ or $1$):
- When $\remember_{\tp,j} = 0$ and $\save_{\tp,j} = 0$
    - $\c_{\tp,j}$ is reset to $0$
- When $\remember_{\tp,j} = 0$ and $\save_{\tp,j} = 1$
    - The candidate replaces the old value: $\c_{\tp,j} = \c'_{\tp,j}$
- When $\remember_{\tp,j} = 1$ and $\save_{\tp,j} = 0$
    - $\c_{\tp,j}$ is carried over unchanged from $\c_{(\tt-1),j}$
- When $\remember_{\tp,j} = 1$ and $\save_{\tp,j} = 1$
    - The old value is incremented by the candidate: $\c_{\tp,j} = \c_{(\tt-1),j} + \c'_{\tp,j}$
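The four cases can be checked numerically. In this sketch each of the four elements exercises one combination of the two gates; the prior memory value (0.8) and candidate value (0.3) are arbitrary numbers chosen for the example.

```python
import numpy as np

c_prev = np.array([0.8, 0.8, 0.8, 0.8])     # prior long-term memory (made-up values)
c_cand = np.array([0.3, 0.3, 0.3, 0.3])     # candidate values (made-up values)

# One element per case: (0,0) reset, (0,1) replace, (1,0) carry over, (1,1) increment
remember = np.array([0.0, 0.0, 1.0, 1.0])
save = np.array([0.0, 1.0, 0.0, 1.0])

c_t = remember * c_prev + save * c_cand
print(c_t)
```

Reading the result left to right: the first element is zeroed, the second takes the candidate value, the third keeps the prior value, and the fourth is the sum of the two. In a trained LSTM the gates take values strictly between 0 and 1, so each element is a blend of these cases.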
    

### The role of $\tanh$ in the candidate value equation

Why pass the candidate value $\c'_\tp$ through $\tanh$?

The $\tanh$ has the important property
- That its range is $(-1, +1)$

So updates to $\c_\tp$ have the flavor of either
- Incrementing existing value $\c_{(\tt-1),j}$ by at most $1$
- Or decrementing existing value $\c_{(\tt-1),j}$ by at most $1$

This makes $\c_\tp$ act like a *counter*.

Recall that when we tried to give intuition as to what each element of the latent state vector of an RNN did:
- we postulated that it acted as a counter
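A toy sketch of the counter behavior, under idealized assumptions: the remember and save gates are pinned at exactly 1, and the candidate values are exactly $\pm 1$ (a real $\tanh$ only approaches these limits). The sequence of candidates is invented for the example.

```python
import numpy as np

c = np.zeros(1)            # one-element long-term memory, starting at 0
remember = np.ones(1)      # idealized: always remember
save = np.ones(1)          # idealized: always save

# Hypothetical saturated candidates: +1 means "count up", -1 means "count down"
for c_cand in [1.0, 1.0, -1.0, 1.0]:
    c = remember * c + save * np.array([c_cand])

print(c)
```

After three increments and one decrement the memory element holds 2, which is what a counter would report.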

## Update short-term memory (control state)

The short-term memory update
- Selectively copies parts of the newly updated long-term memory $\c_\tp$
- To short-term memory

$$
\begin{array}{llll}
\h_\tp & = & \focus_\tp \otimes \tanh(\c_\tp) & \text{Short term memory/control}
\end{array}
$$

The $\focus$ mask selects which elements of long-term memory are immediately relevant for control

The $\tanh$ activation function applied to long-term memory $\c_\tp$
- Squashes each element into the range $(-1, +1)$
- So $\h_\tp$ stays bounded even when elements of $\c_\tp$ grow large
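A short sketch of the short-term memory update, assuming NumPy and invented values for the long-term memory and the focus mask.

```python
import numpy as np

c_t = np.array([3.0, -0.2, 10.0])    # long-term memory; elements may be large
focus = np.array([1.0, 0.5, 0.0])    # hypothetical focus mask

# Squash long-term memory, then let the mask choose what reaches short-term memory
h_t = focus * np.tanh(c_t)
print(h_t)
```

The third element of `c_t` is large, but because its focus value is 0 nothing of it reaches `h_t`; the first element passes through fully, squashed to just under 1.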

## The gate/mask update equations

All of the gates are updated via similar equations, taking
- The prior short term state $\h_{(t-1)}$
- And the current input $\x_\tp$
- Controlled by parts of the weight matrix $\W$

$$
\begin{array}{lllll}
\remember_\tp  & = & \mathbf{f}_\tp &  = & \sigma(\W_{x,f} \x_{(t)} + \W_{h,f}\h_{(t-1)} + \b_f) \\
\save_\tp      & = & \mathbf{i}_\tp &  = & \sigma(\W_{x,i} \x_{(t)} + \W_{h,i}\h_{(t-1)} + \b_i) \\
\focus_\tp     & = & \mathbf{o}_\tp &  = & \sigma(\W_{x,o} \x_{(t)} + \W_{h,o}\h_{(t-1)} + \b_o) \\
\end{array}
$$

Notice the use of the sigmoid activation for each gate/mask:
- This restricts the range of each element to $(0,1)$
- As needed by a gate/mask
- The gate values are "soft" decisions (rather than exclusively either True or False)
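Since the three gates share one functional form, a single helper can compute any of them. The sketch below assumes NumPy; `sigmoid` and `gate` are illustrative helper names, and the random weights stand in for trained ones.

```python
import numpy as np


def sigmoid(z):
    """Logistic function, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def gate(x_t, h_prev, W_x, W_h, b):
    """Generic gate: sigma(W_x x_t + W_h h_prev + b)."""
    return sigmoid(W_x @ x_t + W_h @ h_prev + b)


rng = np.random.default_rng(0)
n_features, n_units = 3, 4                       # hypothetical sizes

x_t = rng.standard_normal(n_features)
h_prev = rng.standard_normal(n_units)
W_xf = rng.standard_normal((n_units, n_features))
W_hf = rng.standard_normal((n_units, n_units))
b_f = np.zeros(n_units)

# The remember gate; save and focus use the same helper with their own weights
remember = gate(x_t, h_prev, W_xf, W_hf, b_f)
print(remember)
```

Each element of the result lies strictly between 0 and 1, which is what makes the gate a soft decision.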

# Conclusion

That was quite a workout.

There were lots of moving parts, but hopefully you can now understand each.

To conclude, here is the full set of update equations

$$
\begin{array}{llll}
\c'_\tp  & = & \tanh(\W_{x,c} \x_\tp + \W_{h,c}\h_{(t-1)} + \b_c) & \text{Candidate update value} \\
\c_\tp & = & \remember_\tp \otimes \c_{(\tt-1)} + \save_\tp \otimes \c'_\tp & \text{Long term memory} \\
\h_\tp & = & \focus_\tp \otimes \tanh(\c_\tp) & \text{Short term memory/control}\\
\y_\tp &  = & \h_\tp & \text{Output} \\
\text{where} \\
\remember_\tp  & = &  \sigma(\W_{x,f} \x_{(t)} + \W_{h,f}\h_{(t-1)} + \b_f) \\
\save_\tp      & = &  \sigma(\W_{x,i} \x_{(t)} + \W_{h,i}\h_{(t-1)} + \b_i) \\
\focus_\tp     & = &  \sigma(\W_{x,o} \x_{(t)} + \W_{h,o}\h_{(t-1)} + \b_o) \\
\end{array}
$$
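Putting the pieces together, the full set of equations can be sketched as a single step function in NumPy. This is a teaching sketch under stated assumptions, not how a production library implements the layer: the names `lstm_step` and `init_params`, the dictionary layout keyed by `c`, `f`, `i`, `o`, and the small random initial weights are all choices made for this example.

```python
import numpy as np


def sigmoid(z):
    """Logistic function, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def init_params(n_features, n_units, seed=0):
    """Random (untrained) weights for the four blocks: c, f, i, o."""
    rng = np.random.default_rng(seed)
    return {
        "W_x": {k: 0.1 * rng.standard_normal((n_units, n_features)) for k in "cfio"},
        "W_h": {k: 0.1 * rng.standard_normal((n_units, n_units)) for k in "cfio"},
        "b": {k: np.zeros(n_units) for k in "cfio"},
    }


def lstm_step(x_t, h_prev, c_prev, params):
    """One time step of the update equations; returns (y_t, h_t, c_t)."""
    W_x, W_h, b = params["W_x"], params["W_h"], params["b"]
    # Pre-activations for candidate (c), remember (f), save (i), focus (o)
    pre = {k: W_x[k] @ x_t + W_h[k] @ h_prev + b[k] for k in "cfio"}

    c_cand = np.tanh(pre["c"])       # candidate update value
    remember = sigmoid(pre["f"])     # remember/forget gate
    save = sigmoid(pre["i"])         # save/input gate
    focus = sigmoid(pre["o"])        # focus/output gate

    c_t = remember * c_prev + save * c_cand   # long-term memory
    h_t = focus * np.tanh(c_t)                # short-term memory/control
    y_t = h_t                                 # output
    return y_t, h_t, c_t


# Run the step over a made-up input sequence of length T
n_features, n_units, T = 3, 4, 5
params = init_params(n_features, n_units)
xs = np.random.default_rng(1).standard_normal((T, n_features))

h = np.zeros(n_units)
c = np.zeros(n_units)
for x_t in xs:
    y, h, c = lstm_step(x_t, h, c, params)

print(h.shape, c.shape)
```

Both states are carried from one step to the next, and the same weights are reused at every step; training would adjust `params` to minimize the loss.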




