# Backpropagation Through Time
:label:sec_bptt
So far we have repeatedly alluded to things like
exploding gradients,
vanishing gradients,
and the need to
detach the gradient for RNNs.
For instance, in :numref:sec_rnn_scratch
we invoked the detach
function on the sequence.
None of this was really fully
explained, in the interest of being able to build a model quickly and
to see how it works.
In this section,
we will delve a bit more deeply
into the details of backpropagation for sequence models and why (and how) the mathematics works.
We encountered some of the effects of gradient explosion when we first
implemented RNNs (:numref:sec_rnn_scratch
).
In
particular,
if you solved the exercises,
you would
have seen that gradient clipping is vital to ensure proper
convergence.
To provide a better understanding of this issue, this
section will review how gradients are computed for sequence models.
Note
that there is nothing conceptually new in how it works. After all, we are still merely applying the chain rule to compute gradients. Nonetheless, it is
worthwhile reviewing backpropagation (:numref:sec_backprop
) again.
We have described forward and backward propagations
and computational graphs
in MLPs in :numref:sec_backprop
.
Forward propagation in an RNN is relatively
straightforward.
Backpropagation through time is actually a specific
application of backpropagation
in RNNs :cite:Werbos.1990
.
It
requires us to expand the
computational graph of an RNN
one time step at a time to
obtain the dependencies
among model variables and parameters.
Then,
based on the chain rule,
we apply backpropagation to compute and
store gradients.
Since sequences can be rather long, the dependency can be rather lengthy.
For instance, for a sequence of 1000 characters,
the first token could potentially have significant influence on the token at the final position.
Tracking this dependency exactly is not really computationally feasible
(it takes too long and requires too much memory): it would take over 1000 matrix products before we arrive at that very elusive gradient.
This is a process fraught with computational and statistical uncertainty.
In the following we will elucidate what happens
and how to address this in practice.
## Analysis of Gradients in RNNs
:label:subsec_bptt_analysis
We start with a simplified model of how an RNN works. This model ignores details about the specifics of the hidden state and how it is updated. The mathematical notation here does not explicitly distinguish scalars, vectors, and matrices as it used to do. These details are immaterial to the analysis and would only serve to clutter the notation in this subsection.
In this simplified model,
we denote by $h_t$, $x_t$, and $o_t$ the hidden state, the input, and the output, respectively, at time step $t$. Recall our discussions in :numref:subsec_rnn_w_hidden_states
that the input and the hidden state
can be concatenated to
be multiplied by one weight variable in the hidden layer.
Thus, we use $w_h$ and $w_o$ to indicate the weights of the hidden layer and the output layer, respectively. As a result, the hidden states and outputs at each time step are
$$\begin{aligned}h_t &= f(x_t, h_{t-1}, w_h),\\o_t &= g(h_t, w_o),\end{aligned}$$
:eqlabel:eq_bptt_ht_ot
where $f$ and $g$ are transformations of the hidden layer and the output layer, respectively. Hence, we have a chain of values $\{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\}$ that depend on each other via recurrent computation. The forward propagation is fairly straightforward: all we need is to loop through the $(x_t, h_t, o_t)$ triples one time step at a time. The discrepancy between the output $o_t$ and the desired label $y_t$ is then evaluated by an objective function across all the $T$ time steps as

$$L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_h, w_o) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t).$$
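To make the recursion concrete, here is a minimal sketch of the forward loop for this simplified model; the particular choices of `f`, `g`, and the squared-error loss are hypothetical placeholders, not anything prescribed above.

```python
import numpy as np

def f(x, h, w_h):        # hypothetical hidden-state update f(x_t, h_{t-1}, w_h)
    return np.tanh(w_h * (x + h))

def g(h, w_o):           # hypothetical output mapping g(h_t, w_o)
    return w_o * h

def objective(xs, ys, w_h, w_o):
    """Loop through the (x_t, h_t, o_t) triples one time step at a time."""
    h, total = 0.0, 0.0
    for x_t, y_t in zip(xs, ys):
        h = f(x_t, h, w_h)          # h_t = f(x_t, h_{t-1}, w_h)
        o = g(h, w_o)               # o_t = g(h_t, w_o)
        total += (o - y_t) ** 2     # l(y_t, o_t), squared error for illustration
    return total / len(xs)          # L = (1/T) * sum_t l(y_t, o_t)
```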
For backpropagation, matters are a bit trickier, especially when we compute the gradients with regard to the parameter $w_h$ of the objective function $L$. To be specific, by the chain rule,
$$\begin{aligned}\frac{\partial L}{\partial w_h}  & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_h}  \\ & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t}  \frac{\partial h_t}{\partial w_h}.\end{aligned}$$
:eqlabel:eq_bptt_partial_L_wh
The first and the second factors of the
product in :eqref:eq_bptt_partial_L_wh
are easy to compute.
The third factor $\partial h_t/\partial w_h$ is where things get tricky, since we need to recurrently compute the effect of the parameter $w_h$ on $h_t$. According to the recurrent computation in :eqref:eq_bptt_ht_ot, $h_t$ depends on both $h_{t-1}$ and $w_h$, where the computation of $h_{t-1}$ also depends on $w_h$. Thus, using the chain rule yields

$$\frac{\partial h_t}{\partial w_h}= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.$$
:eqlabel:eq_bptt_partial_ht_wh_recur

To derive the above gradient, assume that we have three sequences $\{a_{t}\},\{b_{t}\},\{c_{t}\}$ satisfying $a_{0}=0$ and $a_{t}=b_{t}+c_{t}a_{t-1}$ for $t=1, 2,\ldots$. Then for $t\geq 1$, it is easy to show

$$a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.$$
:eqlabel:eq_bptt_at

By substituting $a_t$, $b_t$, and $c_t$ according to

$$\begin{aligned}a_t &= \frac{\partial h_t}{\partial w_h},\\
b_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}, \\
c_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}},\end{aligned}$$

the gradient computation in :eqref:eq_bptt_partial_ht_wh_recur satisfies $a_{t}=b_{t}+c_{t}a_{t-1}$. Thus, per :eqref:eq_bptt_at, we can remove the recurrent computation in :eqref:eq_bptt_partial_ht_wh_recur with

$$\frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_h)}{\partial w_h}.$$
:eqlabel:eq_bptt_partial_ht_wh_gen
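The unrolled expression is easy to sanity-check numerically via the auxiliary recursion $a_t = b_t + c_t a_{t-1}$; a small sketch with made-up sequences:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 8
b = rng.normal(size=T + 1)    # b_1, ..., b_T (index 0 unused)
c = rng.normal(size=T + 1)    # c_1, ..., c_T (index 0 unused)

# Recursion: a_0 = 0, a_t = b_t + c_t * a_{t-1}
a = 0.0
for t in range(1, T + 1):
    a = b[t] + c[t] * a

# Closed form: a_T = b_T + sum_{i=1}^{T-1} (prod_{j=i+1}^{T} c_j) * b_i
closed = b[T] + sum(np.prod(c[i + 1:T + 1]) * b[i] for i in range(1, T))
print(np.isclose(a, closed))  # True
```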
While we can use the chain rule to compute $\partial h_t/\partial w_h$ recursively, this chain can get very long whenever $t$ is large. Let us discuss a number of strategies for dealing with this problem.
Obviously,
we can just compute the full sum in
:eqref:eq_bptt_partial_ht_wh_gen
.
However,
this is very slow and gradients can blow up,
since subtle changes in the initial conditions can potentially affect the outcome a lot.
That is, we could see things similar to the butterfly effect where minimal changes in the initial conditions lead to disproportionate changes in the outcome.
This is actually quite undesirable in terms of the model that we want to estimate.
After all, we are looking for robust estimators that generalize well. Hence this strategy is almost never used in practice.
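A toy scalar recursion makes this butterfly effect concrete: when the factor $\partial f/\partial h$ has magnitude above one, a tiny perturbation of the initial state is amplified geometrically along the unrolled product (the numbers below are purely illustrative):

```python
def run(h0, w_h=1.1, T=100):
    # Toy linear recursion h_t = w_h * h_{t-1}, so dh_T/dh_0 = w_h ** T.
    h = h0
    for _ in range(T):
        h = w_h * h
    return h

# A 1e-6 perturbation of h_0 is amplified by roughly 1.1 ** 100 ≈ 1.4e4.
print(run(1.0 + 1e-6) - run(1.0))
```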
Alternatively,
we can truncate the sum in
:eqref:eq_bptt_partial_ht_wh_gen
after $\tau$ steps. This is what we have been discussing so far, such as when we detached the gradients in :numref:sec_rnn_scratch.
This leads to an approximation of the true gradient, simply by terminating the sum at $\partial h_{t-\tau}/\partial w_h$. In practice this works quite well. It is what is commonly referred to as truncated backpropagation through time :cite:Jaeger.2002.
One of the consequences of this is that the model focuses primarily on short-term influence rather than long-term consequences. This is actually desirable, since it biases the estimate towards simpler and more stable models.
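What truncation does to the sum in :eqref:eq_bptt_partial_ht_wh_gen can be sketched with made-up scalar stand-ins for $\partial f/\partial w_h$ and $\partial f/\partial h$; terminating the sum $\tau$ steps into the past drops only terms that are weighted by long products of factors smaller than one:

```python
import numpy as np

rng = np.random.default_rng(1)
t, tau = 50, 5
b = rng.normal(size=t + 1)               # stand-ins for df(x_i, h_{i-1}, w_h)/dw_h
c = rng.uniform(0.2, 0.8, size=t + 1)    # stand-ins for df(x_j, h_{j-1}, w_h)/dh_{j-1}

def grad(first_i):
    """Sum in eq_bptt_partial_ht_wh_gen, keeping only terms with i >= first_i."""
    total = b[t]
    for i in range(max(first_i, 1), t):
        total += np.prod(c[i + 1:t + 1]) * b[i]
    return total

full, truncated = grad(1), grad(t - tau)
# The truncated value approximates the full sum because the dropped terms are
# weighted by products of many factors below one.
print(full, truncated)
```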
Last, we can replace $\partial h_t/\partial w_h$ by a random variable which is correct in expectation but truncates the sequence. This is achieved by using a sequence of $\xi_t$ with predefined $0 \leq \pi_t \leq 1$, where $P(\xi_t = 0) = 1-\pi_t$ and $P(\xi_t = \pi_t^{-1}) = \pi_t$, thus $E[\xi_t] = 1$. We use this to replace the gradient $\partial h_t/\partial w_h$ in :eqref:eq_bptt_partial_ht_wh_recur with

$$z_t= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.$$

It follows from the definition of $\xi_t$ that $E[z_t] = \partial h_t/\partial w_h$. Whenever $\xi_t = 0$ the recurrent computation terminates at that time step. This leads to a weighted sum of sequences of varying lengths, where long sequences are rare but appropriately overweighted :cite:Tallec.Ollivier.2017.
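Simulating $\xi_t$ shows why this estimator is unbiased; a minimal check that $E[\xi_t] = 1$, so multiplying the recursive term by $\xi_t$ leaves the gradient correct in expectation:

```python
import numpy as np

rng = np.random.default_rng(2)
pi_t = 0.2                                     # predefined probability of keeping the term
xi = np.where(rng.random(1_000_000) < pi_t, 1.0 / pi_t, 0.0)
print(xi.mean())                               # ≈ 1.0, hence E[z_t] = dh_t/dw_h
```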
:numref:fig_truncated_bptt
illustrates the three strategies when analyzing the first few characters of The Time Machine book using backpropagation through time for RNNs:
- The first row is the randomized truncation that partitions the text into segments of varying lengths.
- The second row is the regular truncation that breaks the text into subsequences of the same length. This is what we have been doing in RNN experiments.
- The third row is the full backpropagation through time that leads to a computationally infeasible expression.
Unfortunately, while appealing in theory, randomized truncation does not work much better than regular truncation, most likely due to a number of factors. First, the effect of an observation after a number of backpropagation steps into the past is quite sufficient to capture dependencies in practice. Second, the increased variance counteracts the fact that the gradient is more accurate with more steps. Third, we actually want models that have only a short range of interactions. Hence, regularly truncated backpropagation through time has a slight regularizing effect that can be desirable.
## Backpropagation Through Time in Detail

After discussing the general principle,
let us discuss backpropagation through time in detail.
Different from the analysis in
:numref:subsec_bptt_analysis
,
in the following
we will show
how to compute
the gradients of the objective function
with respect to all the decomposed model parameters.
To keep things simple, we consider
an RNN without bias parameters,
whose
activation function
in the hidden layer
uses the identity mapping ($\phi(x)=x$).
For time step $t$, let the single example input and the label be $\mathbf{x}_t \in \mathbb{R}^d$ and $y_t$, respectively. The hidden state $\mathbf{h}_t \in \mathbb{R}^h$ and the output $\mathbf{o}_t \in \mathbb{R}^q$ are computed as

$$\begin{aligned}\mathbf{h}_t &= \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1},\\
\mathbf{o}_t &= \mathbf{W}_{qh} \mathbf{h}_{t},\end{aligned}$$

where $\mathbf{W}_{hx} \in \mathbb{R}^{h \times d}$, $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$, and $\mathbf{W}_{qh} \in \mathbb{R}^{q \times h}$ are the weight parameters. Denote by $l(\mathbf{o}_t, y_t)$ the loss at time step $t$. Our objective function, the loss over $T$ time steps from the beginning of the sequence, is thus

$$L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t).$$
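For concreteness, here is a minimal forward pass for this bias-free linear RNN; the dimensions $d$, $h$, $q$, $T$ and the squared-error loss are illustrative assumptions only (PyTorch is used merely as one possible framework):

```python
import torch

d, num_hiddens, q, T = 4, 3, 2, 5                 # illustrative sizes
W_hx = torch.randn(num_hiddens, d)                # shape (h, d)
W_hh = torch.randn(num_hiddens, num_hiddens)      # shape (h, h)
W_qh = torch.randn(q, num_hiddens)                # shape (q, h)

h_t = torch.zeros(num_hiddens, 1)                 # h_0
L = 0.0
for t in range(T):
    x_t, y_t = torch.randn(d, 1), torch.randn(q, 1)
    h_t = W_hx @ x_t + W_hh @ h_t                 # h_t = W_hx x_t + W_hh h_{t-1}
    o_t = W_qh @ h_t                              # o_t = W_qh h_t
    L = L + 0.5 * ((o_t - y_t) ** 2).sum()        # l(o_t, y_t), squared error for illustration
L = L / T                                         # L = (1/T) sum_t l(o_t, y_t)
```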
In order to visualize the dependencies among
model variables and parameters during computation
of the RNN,
we can draw a computational graph for the model,
as shown in :numref:fig_rnn_bptt
.
For example, the computation of the hidden state of time step 3, $\mathbf{h}_3$, depends on the model parameters $\mathbf{W}_{hx}$ and $\mathbf{W}_{hh}$, the hidden state of the last time step $\mathbf{h}_2$, and the input of the current time step $\mathbf{x}_3$.
As just mentioned, the model parameters in :numref:fig_rnn_bptt
are $\mathbf{W}_{hx}$, $\mathbf{W}_{hh}$, and $\mathbf{W}_{qh}$.
Generally,
training this model
requires
gradient computation with respect to these parameters
$\partial L/\partial \mathbf{W}_{hx}$, $\partial L/\partial \mathbf{W}_{hh}$, and $\partial L/\partial \mathbf{W}_{qh}$.
According to the dependencies in :numref:fig_rnn_bptt
,
we can traverse
in the opposite direction of the arrows
to calculate and store the gradients in turn.
To flexibly express the multiplication
of matrices, vectors, and scalars of different shapes
in the chain rule,
we continue to use
the $\text{prod}$ operator described in :numref:sec_backprop
.
First of all,
differentiating the objective function
with respect to the model output
at any time step $t$ is fairly straightforward:

$$\frac{\partial L}{\partial \mathbf{o}_t} =  \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q.$$
:eqlabel:eq_bptt_partial_L_ot
Now, we can calculate the gradient of the objective function
with respect to
the parameter $\mathbf{W}_{qh}$
in the output layer:
$\partial L/\partial \mathbf{W}_{qh} \in \mathbb{R}^{q \times h}$. Based on :numref:fig_rnn_bptt,
the objective function $L$ depends on $\mathbf{W}_{qh}$ via $\mathbf{o}_1, \ldots, \mathbf{o}_T$. Using the chain rule yields

$$\frac{\partial L}{\partial \mathbf{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top,$$

where $\partial L/\partial \mathbf{o}_t$ is given by :eqref:eq_bptt_partial_L_ot.
Next, as shown in :numref:fig_rnn_bptt
,
at the final time step $T$, the objective function $L$ depends on the hidden state $\mathbf{h}_T$ only via $\mathbf{o}_T$. Therefore, we can easily find the gradient $\partial L/\partial \mathbf{h}_T \in \mathbb{R}^h$ using the chain rule:

$$\frac{\partial L}{\partial \mathbf{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}.$$
:eqlabel:eq_bptt_partial_L_hT_final_step
It gets trickier for any time step $t < T$, where the objective function $L$ depends on $\mathbf{h}_t$ via $\mathbf{h}_{t+1}$ and $\mathbf{o}_t$. According to the chain rule, the gradient of the hidden state $\partial L/\partial \mathbf{h}_t \in \mathbb{R}^h$ at any time step $t < T$ can be recurrently computed as

$$\frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}.$$
:eqlabel:eq_bptt_partial_L_ht_recur
For analysis,
expanding the recurrent computation
for any time step $1 \leq t \leq T$ gives

$$\frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_{hh}^\top\right)}^{T-i} \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}.$$
:eqlabel:eq_bptt_partial_L_ht
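Both the recurrence :eqref:eq_bptt_partial_L_ht_recur and the expanded form :eqref:eq_bptt_partial_L_ht can be checked against each other numerically; the matrices and the $\partial L/\partial \mathbf{o}_t$ vectors below are arbitrary placeholders:

```python
import numpy as np

rng = np.random.default_rng(3)
h_dim, q_dim, T, t = 3, 2, 6, 2
W_hh = rng.normal(size=(h_dim, h_dim)) * 0.5   # kept small so the powers stay tame
W_qh = rng.normal(size=(q_dim, h_dim))
dL_do = [rng.normal(size=(q_dim, 1)) for _ in range(T + 1)]  # dL/do_1 ... dL/do_T (index 0 unused)

# Recurrence: dL/dh_T = W_qh^T dL/do_T, then dL/dh_s = W_hh^T dL/dh_{s+1} + W_qh^T dL/do_s
dL_dh = W_qh.T @ dL_do[T]
for s in range(T - 1, t - 1, -1):
    dL_dh = W_hh.T @ dL_dh + W_qh.T @ dL_do[s]

# Expanded form: sum_{i=t}^{T} (W_hh^T)^(T-i) W_qh^T dL/do_{T+t-i}
expanded = sum(np.linalg.matrix_power(W_hh.T, T - i) @ W_qh.T @ dL_do[T + t - i]
               for i in range(t, T + 1))
print(np.allclose(dL_dh, expanded))            # True
```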
We can see from :eqref:eq_bptt_partial_L_ht
that
this simple linear example already
exhibits some key problems of long sequence models: it involves potentially very large powers of $\mathbf{W}_{hh}^\top$. In it, eigenvalues smaller than 1 vanish and eigenvalues larger than 1 diverge. This is numerically unstable, which manifests itself in the form of vanishing and exploding gradients. One way to address this is to truncate the time steps at a computationally convenient size, as discussed in :numref:subsec_bptt_analysis.
In practice, this truncation is effected by detaching the gradient after a given number of time steps.
Later on
we will see how more sophisticated sequence models such as long short-term memory can alleviate this further.
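In PyTorch, for example, this detaching amounts to calling `detach()` on the hidden state at each truncation boundary, so that `backward` only traverses the most recent window; the window length and the random inputs below are arbitrary illustrations:

```python
import torch

W_hh = torch.randn(3, 3, requires_grad=True)
h = torch.zeros(3, 1)
for t in range(100):
    if t % 35 == 0:
        h = h.detach()                       # cut the graph: no gradient flows past this step
    h = torch.tanh(W_hh @ h + torch.randn(3, 1))
h.sum().backward()                           # backpropagates only through the last window
print(W_hh.grad.shape)                       # torch.Size([3, 3])
```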
Finally,
:numref:fig_rnn_bptt
shows that
the objective function $L$ depends on the model parameters $\mathbf{W}_{hx}$ and $\mathbf{W}_{hh}$ in the hidden layer via the hidden states $\mathbf{h}_1, \ldots, \mathbf{h}_T$. To compute the gradients with respect to these parameters, $\partial L / \partial \mathbf{W}_{hx} \in \mathbb{R}^{h \times d}$ and $\partial L / \partial \mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$, we apply the chain rule, which gives

$$\begin{aligned}
\frac{\partial L}{\partial \mathbf{W}_{hx}}
&= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}\right)
= \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top,\\
\frac{\partial L}{\partial \mathbf{W}_{hh}}
&= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}\right)
= \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top,
\end{aligned}$$
where $\partial L/\partial \mathbf{h}_t$, which is recurrently computed by :eqref:eq_bptt_partial_L_hT_final_step and :eqref:eq_bptt_partial_L_ht_recur, is the key quantity that affects the numerical stability.
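To tie the pieces together, the analytic gradients above can be checked against automatic differentiation. The following self-contained sketch (PyTorch, with illustrative shapes and a squared-error loss that the text does not prescribe) computes $\partial L/\partial \mathbf{W}_{qh}$, $\partial L/\partial \mathbf{W}_{hx}$, and $\partial L/\partial \mathbf{W}_{hh}$ by hand and compares them with autograd:

```python
import torch

torch.manual_seed(0)
d, n_h, q, T = 4, 3, 2, 5
W_hx = torch.randn(n_h, d, requires_grad=True)
W_hh = torch.randn(n_h, n_h, requires_grad=True)
W_qh = torch.randn(q, n_h, requires_grad=True)
xs = [torch.randn(d, 1) for _ in range(T)]
ys = [torch.randn(q, 1) for _ in range(T)]

# Forward pass, storing every h_t and o_t for reuse during backpropagation.
hs, os = [torch.zeros(n_h, 1)], []
for x_t in xs:
    hs.append(W_hx @ x_t + W_hh @ hs[-1])
    os.append(W_qh @ hs[-1])
L = sum(0.5 * ((o_t - y_t) ** 2).sum() for o_t, y_t in zip(os, ys)) / T
L.backward()

with torch.no_grad():
    dL_do = [(o_t - y_t) / T for o_t, y_t in zip(os, ys)]        # dL/do_t for squared error
    dW_qh = sum(g @ h_t.T for g, h_t in zip(dL_do, hs[1:]))      # sum_t dL/do_t h_t^T
    dL_dh = [None] * (T + 1)
    dL_dh[T] = W_qh.T @ dL_do[-1]                                # final time step
    for t in range(T - 1, 0, -1):                                # recurrence for t < T
        dL_dh[t] = W_hh.T @ dL_dh[t + 1] + W_qh.T @ dL_do[t - 1]
    dW_hx = sum(dL_dh[t] @ xs[t - 1].T for t in range(1, T + 1)) # sum_t dL/dh_t x_t^T
    dW_hh = sum(dL_dh[t] @ hs[t - 1].T for t in range(1, T + 1)) # sum_t dL/dh_t h_{t-1}^T

print(torch.allclose(dW_qh, W_qh.grad), torch.allclose(dW_hx, W_hx.grad),
      torch.allclose(dW_hh, W_hh.grad))                          # expect: True True True
```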
Since backpropagation through time
is the application of backpropagation in RNNs,
as we have explained in :numref:sec_backprop
,
training RNNs
alternates forward propagation with
backpropagation through time.
Besides,
backpropagation through time
computes and stores the above gradients
in turn.
Specifically,
stored intermediate values
are reused
to avoid duplicate calculations,
such as storing
$\partial L/\partial \mathbf{h}_t$
to be used in the computation of both $\partial L / \partial \mathbf{W}_{hx}$ and $\partial L / \partial \mathbf{W}_{hh}$.
## Summary

- Backpropagation through time is merely an application of backpropagation to sequence models with a hidden state.
- Truncation is needed for computational convenience and numerical stability, such as regular truncation and randomized truncation.
- High powers of matrices can lead to divergent or vanishing eigenvalues. This manifests itself in the form of exploding or vanishing gradients.
- For efficient computation, intermediate values are cached during backpropagation through time.
## Exercises

- Assume that we have a symmetric matrix $\mathbf{M} \in \mathbb{R}^{n \times n}$ with eigenvalues $\lambda_i$ whose corresponding eigenvectors are $\mathbf{v}_i$ ($i = 1, \ldots, n$). Without loss of generality, assume that they are ordered such that $|\lambda_i| \geq |\lambda_{i+1}|$.
  - Show that $\mathbf{M}^k$ has eigenvalues $\lambda_i^k$.
  - Prove that for a random vector $\mathbf{x} \in \mathbb{R}^n$, with high probability $\mathbf{M}^k \mathbf{x}$ will be very much aligned with the eigenvector $\mathbf{v}_1$ of $\mathbf{M}$. Formalize this statement.
  - What does the above result mean for gradients in RNNs?
- Besides gradient clipping, can you think of any other methods to cope with gradient explosion in recurrent neural networks?