In [1]:
%run Latex_macros.ipynb

# Encoder/Decoder architecture

Two RNNs
- Encoder: takes input sequence $\x$
- Decoder: creates output sequence $\hat\y$

RNNs process sequences using the "loop architecture"

Consider the task of
- constructing the *next* element $\hat\y_\tp$ of sequence $\y$
- conditioned on some input sequence $\x = \x_{(1)} \dots \x_\tp$

$$
\pr{\hat\y_\tp | \x_{(1)} \dots \x_\tp}
$$

## RNN Loop architecture
- Uses a "latent state" that is updated with each element of the sequence and is then used to predict the output

$$
\begin{array}{ll}
\pr{\h_\tp | \x_\tp, \h_{(\tt-1)} } & \text{latent variable } \h_\tp \text{ encodes } [ \x_{(1)} \dots \x_\tp ]\\
\pr{\hat\y_\tp | \h_\tp }              & \text{prediction contingent on latent variable} \\
\end{array}
$$
    
<br>
<div>
    <center><strong>Loop with latent state</strong></center>
    <img src="images/RNN_arch_loop.png" width=70%>
</div>
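
The loop can be sketched in a few lines of Python. This is a deterministic toy under stated assumptions, not the notebook's model: the states are scalars, the weights `w_x`, `w_h`, `w_y` are made-up constants, and `rnn_step`, `predict` and `run_loop` are hypothetical names. It only illustrates that step $\tt$ needs the latent state from step $\tt-1$, so the positions must be processed in order.

```python
import math

def rnn_step(x_t, h_prev, w_x=0.5, w_h=0.9):
    """Update the latent state from the current input and the previous state."""
    return math.tanh(w_x * x_t + w_h * h_prev)

def predict(h_t, w_y=2.0):
    """The prediction depends only on the current latent state."""
    return w_y * h_t

def run_loop(xs, h0=0.0):
    """Process the sequence one element at a time, carrying the latent state."""
    h = h0
    ys = []
    for x_t in xs:              # sequential: step t needs h from step t-1
        h = rnn_step(x_t, h)
        ys.append(predict(h))
    return ys, h

ys, h_final = run_loop([1.0, -0.5, 0.25])

# The first two outputs depend only on the first two inputs.
ys_prefix, _ = run_loop([1.0, -0.5])
```

Here `ys_prefix` equals `ys[:2]`: the latent state at step $\tt$ summarizes only the prefix of the input seen so far.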


# Original Encoder/Decoder architecture

<table>
    <tr>
        <th><center>RNN Encoder/Decoder without Attention<br>Bottleneck</center></th>
    </tr>
    <tr>
        <td><img src="images/RNN_layer_API_Encoder_Decoder_1.png" width=80%></td>
    </tr>
  
</table>

Critique
- bottleneck
    - *all* information about input $\x$ passes through the output of the Encoder (red line)
    - and must be carried over to every iteration of the Decoder loop (red box)
    
- loop architecture for Encoder and Decoder
    - dependency: horizontal line carrying latent state across time

# Cross-Attention: removing the bottleneck

We removed the bottleneck via *Cross Attention*
- Decoder has *direct access* to **all** outputs (i.e., latent states) of the Encoder
    - each Encoder output is a proxy for a prefix of the input
    
The pink box is the sequence of Encoder outputs
$$
\bar\h_{(1:\bar T)}
$$



<table>
    <tr>
        <th><center>RNN Encoder/Decoder with Cross Attention</center></th>
    </tr>
    <tr>
        <td><img src="images/RNN_layer_API_Encoder_Decoder_Attention_1.png"
             width=80%></td>
    </tr>
   
</table>
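
A minimal sketch of the idea, assuming dot-product scoring followed by a softmax (the text above does not fix a scoring function, so this choice is an assumption). The names `cross_attention`, `decoder_state` and `encoder_outputs` are hypothetical. The Decoder's current state acts as a query, and the result is a weighted average over **all** Encoder outputs rather than a single bottleneck vector.

```python
import math

def softmax(scores):
    """Turn raw scores into non-negative weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def cross_attention(query, encoder_outputs):
    """Weighted average of all Encoder outputs, weighted by similarity to the query."""
    weights = softmax([dot(query, h) for h in encoder_outputs])
    dim = len(encoder_outputs[0])
    context = [sum(w * h[i] for w, h in zip(weights, encoder_outputs))
               for i in range(dim)]
    return context, weights

# Three Encoder outputs (one per input position) and one Decoder state.
encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_state = [0.0, 2.0]
context, weights = cross_attention(decoder_state, encoder_outputs)
```

With these toy values the second and third Encoder outputs receive equal weight, and the first receives less, because its dot product with the query is smaller.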

# Encoder Self-Attention: removing the Encoder loop

There is an alternative to the loop architecture for processing sequences: the "direct function" approach
- Taking a **sequence** $\x_{(1 \dots \tt)}$ as input
- Outputting $\hat\y_\tp$

<br>
<div>
    <center><strong>Direct function</strong></center>
    <img src="images/RNN_arch_parallel.png" width=50%>
</div>

Can output *all* elements of sequence $\hat\y$ *simultaneously*
- each output position is independent of previous output
- only dependent on input

We removed the "loop" architecture of the Encoder by using  the direct function approach
- the mechanism enabling each position of the Encoder output to *attend* to the entire sequence $\x$ is called *Self-Attention*
    - Notice: no dependency arrow between circles in the Encoder
- Encoder output is a direct function of **all** positions in the input
    - all Encoder output positions can be computed *in parallel*

The blue box represents the *entire* input sequence
$$
\x_{(1:\bar T)}
$$
We no longer refer to the Encoder output as a Latent state
- no more loop!

<table>
    <tr>
        <th><center>RNN Encoder/Decoder with Cross Attention/Encoder Self Attention</center></th>
    </tr>
    <tr>
        <td><img src="images/RNN_layer_API_Encoder_Decoder_Attention_Encoder_Self_Attention.png"
             width=80%></td>
    </tr>
   
</table>

Observe that
- by removing the looping architecture from the Encoder
- the Encoder is no longer called an RNN
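
The direct-function view of the Encoder can be sketched as follows. This is a simplified illustration under stated assumptions: queries, keys and values are all the raw inputs (no learned projections), scoring is a plain dot product, and `attend` and `self_attention` are hypothetical names. The point is only that each output position is computed from the whole input, with no state passed from one position to the next.

```python
import math

def softmax(scores):
    """Turn raw scores into non-negative weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Average the values, weighted by the dot-product similarity of query to each key."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]

def self_attention(xs):
    """Each output position is a function of the whole input sequence."""
    return [attend(x_t, xs, xs) for x_t in xs]

xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs = self_attention(xs)

# Any single position can be computed on its own, in any order.
only_first = attend(xs[0], xs, xs)

# Unmasked: changing the LAST input changes the FIRST output.
changed = self_attention([[1.0, 0.0], [0.0, 1.0], [0.0, 3.0]])
```

Because the list comprehension in `self_attention` has no dependency between iterations, the positions could be evaluated in parallel. `changed[0]` differs from `outputs[0]`, which shows that an unmasked position sees inputs that come after it.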

# Masked Self Attention

With *unmasked* Self Attention
- Encoder output $\bar\h_\tp$ at position $\tt$ 
- is a function of **all** inputs $\x_{(1: \bar T)}$
    - including positions after $\tt$
    
This is useful, for example, when the meaning of a word depends on its *entire* context.
- as in our motivating example

For certain tasks (not so for our motivating example), full visibility of all inputs is not permissible
- "looking into the future"
    - e.g., predict stock return based only on **past** information

In this case, we use *masked* Self Attention
- we use a mask to hide inputs from position $\tt$ onwards so that
- output $\bar\h_\tp$ at position $\tt$ 
- is a function  only of **preceding** inputs $\x_{(1 : \tt-1)}$

We will see the use of masking in the next section.
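
A mask can be represented as a table of booleans, one row per output position. The sketch below is illustrative only; `visibility_mask` and its `strict` flag are hypothetical names. With `strict=True` it follows the convention used above, where position $\tt$ sees only positions before $\tt$. The `strict=False` variant also lets a position see itself, which is a common choice when the sequence being attended to has been shifted by one position.

```python
def visibility_mask(length, strict=True):
    """mask[t][s] is True when position s is visible to position t (0-based)."""
    if strict:
        return [[s < t for s in range(length)] for t in range(length)]
    return [[s <= t for s in range(length)] for t in range(length)]

# Print the strict mask: "x" marks a visible position, "." a hidden one.
for row in visibility_mask(4):
    print("".join("x" if visible else "." for visible in row))
# ....
# x...
# xx..
# xxx.
```

Note that under the strict convention the first row is entirely hidden: the first position has nothing to attend to unless something (such as a start marker) is placed before it.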

# Causal Masked Self Attention: removing the Decoder loop

Finally, we remove the loop architecture for the Decoder as well, using
a different "flavor" of Self-Attention
- Masked Self-Attention.



<table>
    <tr>
        <th><center>Encoder/Decoder with Cross Attention and Self Attention (Encoder/Decoder)</center></th>
    </tr>
    <tr>
        <td><img src="images/RNN_layer_API_Encoder_Decoder_Attention_All_Self_Attention.png"
             width=80%></td>
    </tr>
   
</table>

The grey box represents the *entire* output sequence
$$
\hat\y_{(1:T)}
$$

From this diagram, it appears that
- the Encoder/Decoder can produce output $\hat\y_\tp$
- while attending to outputs *that have not yet been generated* at the start of step $\tt$
$$\hat\y_{(\tt : T)}$$
- "looking into the future"

That is, it is computing
$$\prc{\hat\y_\tp}{\hat\y_{(1:T)} }$$
What is going on?

## Teacher forcing at training time

An explanation of this strange behavior
is that the behavior of the model is *different*
- at training time
- versus at test/inference time

*Teacher Forcing* alters the training behavior in order to improve
the ability of a model to learn.

Let's examine  [Teacher Forcing](Teacher_Forcing.ipynb) in depth.
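
As a quick preview, here is a minimal sketch of how Teacher Forcing is commonly set up at training time: the Decoder's input at each step is the *true* previous target element rather than the model's own previous prediction. The `START` marker, the function name `teacher_forcing_pairs` and the example tokens are hypothetical and chosen only for illustration.

```python
START = "<start>"  # hypothetical start-of-sequence marker

def teacher_forcing_pairs(target):
    """Pair each Decoder input (the true previous element) with the element to predict."""
    decoder_inputs = [START] + target[:-1]   # targets shifted right by one position
    return list(zip(decoder_inputs, target))

pairs = teacher_forcing_pairs(["le", "chat", "dort"])
# [('<start>', 'le'), ('le', 'chat'), ('chat', 'dort')]
```

Because the whole target sequence is known during training, all of these pairs are available at once, which is why the entire output sequence appears in the diagram.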



## Masked attention

Hopefully it is clear that, regardless of whether we are computing $\hat\y_\tp$
- at training time
- at inference time

the computation should depend only on positions $1:\tt-1$ of the output.
- can't peek into the future

To enforce this
- we *mask* the outputs
- so that only positions $1:\tt-1$ are visible when generating output position $\tt$

The general mechanism of hiding some inputs is called
- **Masked Self-Attention**

The specific masking of only future positions is called
- **Causal Masked Self-Attention** or **Causal Self-Attention**
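
The effect of the causal mask can be checked with a small sketch. It makes several assumptions that are not taken from the text: dot-product scoring without learned projections, a mask implemented by setting hidden scores to negative infinity before the softmax, and a Decoder input that has been shifted right by one position behind a start vector. Under that shift, letting input position $\tt$ attend to input positions up to and including $\tt$ corresponds to seeing only outputs $1:\tt-1$. The name `causal_self_attention` is hypothetical.

```python
import math

NEG_INF = float("-inf")

def softmax(scores):
    """Softmax; a score of -inf becomes a weight of exactly 0."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def causal_self_attention(inputs):
    """Position t attends only to positions s <= t of the (right-shifted) input."""
    outputs = []
    for t, query in enumerate(inputs):
        scores = [
            sum(q * k for q, k in zip(query, key)) if s <= t else NEG_INF
            for s, key in enumerate(inputs)
        ]
        weights = softmax(scores)
        dim = len(query)
        outputs.append([sum(w * v[i] for w, v in zip(weights, inputs))
                        for i in range(dim)])
    return outputs

start = [0.0, 0.0]                               # assumed start vector
shifted_a = [start, [1.0, 0.0], [0.0, 1.0]]
shifted_b = [start, [1.0, 0.0], [5.0, -3.0]]     # differs only in the last position

out_a = causal_self_attention(shifted_a)
out_b = causal_self_attention(shifted_b)

# Earlier positions are unaffected by a change in a later position.
unchanged = out_a[:2] == out_b[:2]
```

Here `unchanged` is `True`: all positions are computed in one pass, yet no position depends on anything that comes after it.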


