### Checklist for submission

It is extremely important to make sure that:

1. Everything runs as expected (no bugs when running cells);
2. The output from each cell corresponds to its code (don't change any cell's contents without rerunning it afterwards);
3. All outputs are present (don't delete any of the outputs);
4. Fill in all the places that say `# YOUR CODE HERE`, or "**Your answer:** (fill in here)".
5. Never copy/paste any notebook cells. Inserting new cells is allowed, but it should not be necessary.
6. The notebook contains some hidden metadata which is important during our grading process. **Make sure not to corrupt any of this metadata!** The metadata may for example be corrupted if you copy/paste any notebook cells, or if you perform an unsuccessful git merge / git pull. It may also be pruned completely if using Google Colab, so watch out for this. Searching for "nbgrader" when opening the notebook in a text editor should take you to the important metadata entries.
7. Although we will try our very best to avoid this, it may happen that bugs are found after an assignment is released, and that we will push an updated version of the assignment to GitHub. If this happens, it is important that you update to the new version, while making sure the notebook metadata is properly updated as well. The safest way to make sure nothing gets messed up is to start from scratch on a clean updated version of the notebook, copy/pasting your code from the cells of the previous version into the cells of the new version.
8. If you need to have multiple parallel versions of this notebook, make sure not to move them to another directory.
9. Although not forced to work exclusively in the course `conda` environment, you need to make sure that the notebook will run in that environment, i.e. that you have not added any additional dependencies.

**FOR HA1, HA2, HA3 ONLY:** Failing to meet any of these requirements might lead to either a subtraction of POEs (at best) or a request for resubmission (at worst).

We advise you to perform the following steps before submission to ensure that requirements 1, 2, and 3 are always met: **Restart the kernel** (in the menubar, select Kernel$\rightarrow$Restart) and then **run all cells** (in the menubar, select Cell$\rightarrow$Run All). This might require a bit of time, so plan ahead for this (and possibly use Google Cloud's GPU in HA1 and HA2 for this step). Finally, press the "Save and Checkpoint" button before handing in, to make sure that all your changes are saved to this .ipynb file.

### Fill in name of notebook file
This might seem silly, but the version check below needs to know the filename of the current notebook, which is not trivial to find out programmatically.

You might want to have several parallel versions of the notebook, and it is fine to rename the notebook as long as it stays in the same directory. **However**, if you do rename it, you also need to update its own filename below:

In [19]:
nb_fname = "HA2-Part2_EllaGuiladi_EmmaRydholm.ipynb"

### Fill in group number and member names (use NAME2 and GROUP only for HA1, HA2 and HA3):

In [20]:
NAME1 = "Ella Guiladi" 
NAME2 = "Emma Rydholm"
GROUP = "7"

### Check Python version

In [21]:
from platform import python_version_tuple
assert python_version_tuple()[:2] == ('3','7'), "You are not running Python 3.7. Make sure to run Python through the course Conda environment."

### Check that notebook server has access to all required resources, and that notebook has not moved

In [22]:
import os
nb_dirname = os.path.abspath('')
assignment_name = os.path.basename(nb_dirname)
assert assignment_name in ['IHA1', 'IHA2', 'HA1', 'HA2', 'HA3'], \
    '[ERROR] The notebook appears to have been moved from its original directory'

### Verify correct nb_fname

In [23]:
from IPython.display import display, HTML
try:
    display(HTML(r'<script>if("{nb_fname}" != IPython.notebook.notebook_name) {{ alert("You have filled in nb_fname = \"{nb_fname}\", but this does not seem to match the notebook filename \"" + IPython.notebook.notebook_name + "\"."); }}</script>'.format(nb_fname=nb_fname)))
except NameError:
    assert False, 'Make sure to fill in the nb_fname variable above!'

### Verify that your notebook is up-to-date and not corrupted in any way

In [24]:
import sys
sys.path.append('..')
from ha_utils import check_notebook_uptodate_and_not_corrupted
check_notebook_uptodate_and_not_corrupted(nb_dirname, nb_fname)

Matching current notebook against the following URL:
http://raw.githubusercontent.com/JulianoLagana/deep-machine-learning/master/home-assignments/HA2/HA2-Part2.ipynb
[SUCCESS] No major notebook mismatch found when comparing to latest GitHub version. (There might be minor updates, but even that is the case, submitting your work based on this notebook version would be acceptable.)


# HA2:  Part 2 - Transformers and self-attention
$$
\renewcommand{\vec}[1]{#1}
\def\x{\vec{x}}
\def\y{\vec{y}}
\def\dim{d}
\def\w{W}
\def\wu{Z}
\def\R{\mathbb{R}}
\def\linMap{W}
% Query, key and val
\def\q{\vec{q}}
\def\k{\vec{k}}
\def\v{\vec{v}}
\def\Wq{\linMap_Q}
\def\Wk{\linMap_K}
\def\Wv{\linMap_V}
$$
*You should have completed part 1 before starting with this one*

In this part we will take a closer look at the transformer architecture and the self-attention operation.
We will start with basic self-attention and gradually construct an actual self-attention module.
Finally we will construct a complete transformer and test it on an actual problem.

The focus is on a conceptual understanding of the transformer but you will have to implement a few key elements of a transformer. Along the way we will try to give some best practices for constructing a more complex network architecture.

Let's start by importing the modules we are going to need:

In [25]:
import torch
import torch.nn as torch_nn
import torch.nn.functional as F

# 1. Basic self-attention

The keystone of the transformer architecture, self-attention, is a sequence-to-sequence operation which transforms a sequence of input vectors $\x_1, \dots, \x_t$ into output vectors $\y_1, \dots, \y_t$.
Remember that all vectors have the same dimension $\dim$, i.e. $\x_i, \y_i \in \R^{\dim}, \forall i = 1, \dots, t$.

## Weighted average
The actual transformation is a simple weighted average
$$
\y_i = \sum_{j} \x_j \w_{ji}.
$$

In an actual transformer, weighted averages are computed often and for long sequences. Therefore, the implementation must be fast in order for training to be even possible.
With high-level frameworks such as `pytorch`, the key to fast code is often to reduce loops and instead express computations as matrix operations.
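
As a rough illustration of the point above, the sketch below computes the same weighted average twice: once with explicit Python loops and once as a single batched matrix product. The helper names `weighted_avg_loop` and `weighted_avg_bmm` are made up for this example, and the shapes are assumed to follow the `(batch_size, dim, seq_len)` convention used in this notebook.

```python
import torch

def weighted_avg_loop(x, weights):
    """Reference version with explicit loops (slow, for illustration only)."""
    batch_size, dim, seq_len = x.shape
    y = torch.zeros_like(x)
    for b in range(batch_size):
        for i in range(seq_len):
            for j in range(seq_len):
                # y_i = sum_j x_j * w_ji
                y[b, :, i] += x[b, :, j] * weights[b, j, i]
    return y

def weighted_avg_bmm(x, weights):
    """The same computation expressed as one batched matrix product."""
    return torch.bmm(x, weights)

torch.manual_seed(0)
x = torch.rand(4, 3, 5)                               # (batch_size, dim, seq_len)
weights = torch.softmax(torch.rand(4, 5, 5), dim=1)   # each column sums to one
y_loop = weighted_avg_loop(x, weights)
y_bmm = weighted_avg_bmm(x, weights)
print(torch.allclose(y_loop, y_bmm, atol=1e-6))
```

Both versions should agree up to floating-point error; the batched version simply hands the three nested loops over to an optimized matrix routine.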

**(2 POE)** Complete the function snippet below to implement a simple weighted average.

To pass this part of the assignment your implementation only has to be correct, not efficient, but to get the first POE, you must implement it with just a single for loop. For the second POE, do it without any loops at all.

*Hint*: Take a look at how `torch.bmm` is used later in the implementation

In [26]:
def weighted_avg(x, weights):
    """Weighted average
    Calculates a weighted average of a batch of sequences of vectors.
    
    Args:
        x (torch.Tensor): Shape (batch_size, dim, seq_len)
        weights (torch.Tensor): Shape (batch_size, seq_len, seq_len)
    
    Returns:
        y (torch.Tensor): Shape (batch_size, dim, seq_len)
        
    """
    #weights = F.softmax(weights, dim=2)
    y = torch.bmm(x,weights)
    
    return y

Make sure to test your implementation with the unit tests below.
The tests cover:

1. Dimensionality
2. Uniform weights $\w_{ji} = \frac{1}{t}$ should produce $y_i = \frac{1}{t} \sum_{j} x_j,\, \forall i = 1, \dots, t$
 (i.e., every $y_i$ is an average of the input sequence).
3. A specific numerical example with batch size = 2, $t = 2,\, \dim=1$.

In [27]:
def test_weighted_avg(function):
    """
    Args:
        function: Implementation to test
    """
    # Testing dimension of averaged tensor.
    batch_size, dim, seq_len = 5, 2, 3
    x = torch.rand(batch_size, dim, seq_len)
    weights = torch.rand(batch_size, seq_len, seq_len)
    y = function(x, weights)
    assert y.shape == (batch_size, dim, seq_len), "Dimension error: expected y to have shape {}, got {}.".format(
        (batch_size, dim, seq_len), tuple(y.shape))
    
    # Testing that uniform weights give the average of x.
    batch_size, dim, seq_len = 5, 2, 3
    x = torch.rand(batch_size, dim, seq_len)
    weights = torch.ones((batch_size, seq_len, seq_len)).float() / seq_len
    y = function(x, weights)
    assert all(torch.allclose(y_b.mean(1), y_b[:, 0]) for y_b in y),\
        "Numerical error: With uniform weights, expected y_i = y_j forall i, j (within each batch)."
    assert all(torch.allclose(y_b.mean(1), x_b.mean(1)) for (x_b, y_b) in zip(x, y)),\
        "Numerical error: With uniform weights, expected y_i = (1/t) sum_j x_j, for all i"
    
    # Actual numerical example.
    x = torch.tensor([4, 1]).reshape((1, 1, 2)).float()
    unnorm_weights = torch.arange(1, 5).reshape((1, 2, 2)).float()
    scale = unnorm_weights.sum(1).reshape((1, 1, 2))
    weights = unnorm_weights / scale

    y = function(x, weights)
    y_true = torch.tensor([7/4, 2]).reshape(1, 1, 2).float()
    assert torch.allclose(y, y_true), "Numerical error, expected: {}, got {}".format(y_true, y)
    
    print("Test passed.")

test_weighted_avg(function=weighted_avg)

Test passed.


## Defining weights through the dot product
A simple way to define $\w_{ji}$ is with the dot product

$$
\wu_{ji} = \x_j^T \x_i.
$$
which maps a pair of input vectors to a scalar, $\R^{\dim} \times \R^{\dim} \to \R$ (note that the result can be negative).
We then use a softmax to obtain normalised $\w_{ji} \in (0, 1]$:

$$
\w_{ji} = \frac{ e^{\wu_{ji}} }{ \sum_{j'} e^{\wu_{j'i}} }.
$$
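
A minimal sketch of these two steps for a single sequence, assuming the vectors $\x_1, \dots, \x_t$ are stored as the columns of a matrix `x` (the variable names are chosen for this example only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, seq_len = 4, 3
x = torch.randn(dim, seq_len)   # columns are the vectors x_1, ..., x_t

z = x.T @ x                     # z[j, i] = x_j^T x_i, off-diagonal entries may be negative
w = F.softmax(z, dim=0)         # normalise over j, i.e. down each column

print(z)
print(w.sum(dim=0))             # every column of w sums to one
```

The softmax is taken over the first index so that, for a fixed output position $i$, the weights over all input positions $j$ form a probability distribution.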

**(1 POE)**
What is the difference between these weights and the weights in ordinary networks, e.g. a CNN?

**Your answer:** 

These weights are not learnable; instead, they are calculated using only the input vectors $\x_i$. These weights also sum to 1, because we apply the softmax function, which is not the case for the weights in an ordinary network. Lastly, these weights depend on each other so that words with similar embeddings receive larger weights, which is not the case for the weights in an ordinary network.

**(2 POE)** 
The dot product is essential for calculating the weights. As we progress, we will make slight modifications to the inputs but we will still base it around a function which calculates a softmax-normalized dot product.
Therefore, you need to complete the implementation below:

Again, this function will be evaluated often and for long sequences in the transformer block. For POEs, implement it without using for loops.

In [28]:
def normalized_dot_product(v_1, v_2):
    """Normalized dot products between all pairs of vectors in a sequence
    Takes two batches of sequences of vectors as input.
    Sequences in the batch are processed independently.
    The normalization is done with a softmax function along the columns of the weight matrices.
    
    Args:
        v_1 (torch.Tensor): Shape (batch_size, dim, seq_len)
        v_2 (torch.Tensor): Shape (batch_size, dim, seq_len)

    Returns:
        norm_dot_prod (torch.Tensor): Shape (batch_size, seq_len, seq_len)
    """
    
    #v_1_transpose (torch.Tensor): Shape (batch_size, seq_len, dim)
    v_1_transpose = torch.transpose(v_1, 1, 2)
    
    dot_prod = torch.bmm(v_1_transpose, v_2)
    
    
    norm_dot_prod = F.softmax(dot_prod, dim=1)
    
    
    return norm_dot_prod

Make sure to test your implementation with the unit tests below.
The tests cover:

1. Dimensionality
2. Normalized in the correct dimension
3. A specific numerical example

In [29]:
import numpy as np

def test_normalized_dot_product(function):
    """
    Args:
        function: Implementation to test
    """
    
    batch_size, dim, seq_len = 5, 2, 3
    v_1 = torch.rand(batch_size, dim, seq_len)
    v_2 = torch.rand(batch_size, dim, seq_len)
    weights = function(v_1, v_2)
    
    # Testing dimension of weights.
    assert weights.shape == (batch_size, seq_len, seq_len),\
    "Dimension error: expected weights to have shape {}, got {}.".format(
        (batch_size, seq_len, seq_len), tuple(weights.shape))
    
    # Testing weights non-negative
    # (Boolean tensors can be reduced to a single boolean)
    assert not (weights < 0.0).any() ,\
    "Value error: expected weights to be non-negative."
    
    # Testing weights smaller than one
    assert (weights < 1.0).all() ,\
    "Value error: expected weights to be smaller than one."
    
    assert torch.allclose(weights.sum(1), torch.ones((batch_size, seq_len))),\
        "ValueError: expected columns (dim 1) to sum to 1.0"
    
    # Actual numerical example
    v_1 = torch.tensor([[1, 2], [-1, 1]]).float().reshape((1, 2, 2))
    v_2 = torch.tensor([[1, 0], [1, -1]]).float().reshape((1, 2, 2))
    e = np.exp(1)
    true_weights = torch.tensor([
        [1 / (e**3 + 1), e**2 / (e**2 + 1)],
        [e**3 / (e**3 + 1), 1 / (e**2 + 1)]
    ]).reshape((1, 2, 2))
    weights =  function(v_1, v_2)
    assert torch.allclose(true_weights, weights),\
    "Numerical error: expected {}, got {}.".format(true_weights, weights)
    
    print("Test passed.")   
    
test_normalized_dot_product(function=normalized_dot_product)

Test passed.


That's it, we now have the building blocks needed for basic self-attention:

In [30]:
def basic_self_attention(x):
    """Basic self-attention
    Transforms a batch of sequences of vectors.
    
    Args:
        x (torch.Tensor): Shape (batch_size, dim, seq_len)
    
    Returns:
        y (torch.Tensor): Shape (batch_size, dim, seq_len)
    """
    weights = normalized_dot_product(x, x)
    return weighted_avg(x, weights)

# 2. A self-attention module
As you saw in the video lectures, self-attention is rarely used in the basic form we have created above.
Let's do the modifications needed to construct an actual transformer.

We will wrap it in a proper `torch.nn` module to create a building block that we can use in a network.
Creating your own module is actually not that common. Frameworks like `pytorch` are built to be *modular*, and we can often create very specific networks by combining standard modules. That is a good thing, since it lets us express interesting models in a high-level interface and, as a bonus, build them from well-tested and efficient parts.
With that said, you might find yourself in a situation (perhaps already in the project) where no off-the-shelf module suits your need and you have to create one yourself. View this latter part as an example/inspiration of how to construct a non-trivial custom module.

## Queries, keys and values
The self-attention is extended with three linear mappings $\Wq, \Wk, \Wv \in \R^{\dim \times \dim}$ .
These give us learnable parameters and make self-attention more flexible.
The three matrices map the input $\x_i$ into a query, key and value respectively:

\begin{align}
    \q_i = \Wq \x_i \\
    \k_i = \Wk \x_i \\
    \v_i = \Wv \x_i
\end{align}

First, we modify the self-attention by redefining the unnormalized weights (while reusing the notation):

\begin{align}
    \wu_{ji} = \q_j^T \k_i \Big{/} \sqrt{\dim}
\end{align}
The normalized weights are still obtained by applying the softmax function.

**(1 POE)** Explain why we scale the dot product with the factor $1 / \sqrt{\dim}$.

**Your answer:** 

The self-attention weights are defined through the inner product of query and key, divided by the square root of the dimension. The inner product is a sum over $\dim$ terms, so the larger the embedding dimension, the larger its typical magnitude: if the components have roughly unit variance, the variance of the inner product grows linearly with $\dim$. Large values push the softmax into regions where its gradients are very small. Dividing by the square root of the dimension balances the variance, which stabilizes the gradients during training and leads to better learning.
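
The effect of the $1 / \sqrt{\dim}$ factor can be checked numerically. The sketch below is only an illustration under a simplifying assumption: queries and keys are drawn with independent standard normal components, which is not guaranteed to hold inside a trained network.

```python
import torch

torch.manual_seed(0)
num_samples = 10_000
variances = {}

for dim in (4, 64, 512):
    # Queries and keys with independent, unit-variance components (an assumption).
    q = torch.randn(num_samples, dim)
    k = torch.randn(num_samples, dim)
    dots = (q * k).sum(dim=1)      # one dot product per row
    scaled = dots / dim ** 0.5     # the same dot products, scaled by 1 / sqrt(dim)
    variances[dim] = (dots.var().item(), scaled.var().item())
    print(f"dim={dim:4d}  var(q.k)={variances[dim][0]:8.2f}  "
          f"var(q.k / sqrt(dim))={variances[dim][1]:.2f}")
```

Under this assumption the unscaled variance grows roughly in proportion to the dimension, while the scaled variance stays close to one, so the inputs to the softmax keep a similar range regardless of the embedding size.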

Finally, the weighted average is modified and is now based on the values $\v_j$, instead of on $\x_j$ directly:
$$
\y_i = \sum_{j} \v_j \w_{ji}.
$$

We can reuse our dot product calculation by simply *wrapping* it in a function that takes queries and keys as the arguments:

In [31]:
def query_key_weights(queries, keys):
    """Weights from query-key dot product.
    Softmax-normalised dot product weights
    Calculates weights for a batch of sequences of vectors.
    
    Args:
        queries (torch.Tensor): Shape (batch_size, dim, seq_len)
        keys (torch.Tensor): Shape (batch_size, dim, seq_len)
    
    Returns:
        weights (torch.Tensor): Shape (batch_size, seq_len, seq_len)
    """
    # The embedding dimension is axis 1 of a (batch_size, dim, seq_len) tensor.
    dim = queries.shape[1]
    queries = queries / (dim ** (1/4))
    keys    = keys / (dim ** (1/4))
    return normalized_dot_product(queries, keys)

## Multi-head self-attention

The model should be able to find different patterns in the input sequence, which is why we use multiple heads.

Now, we'll create the actual self-attention function, which includes multiple heads.
For implementation simplicity and efficiency we will do a version called *narrow* self-attention, where the input vector is split into parts and each attention head is applied to just one part of the vector.
Imagine that we have $\dim = 64$ and four heads; then each head would operate on a vector with dimension $64 / 4 = 16$.
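
As a small sketch of this split (with illustrative variable names), a single reshape is enough to cut every 64-dimensional vector into four consecutive 16-dimensional parts, one per head:

```python
import torch

batch_size, seq_length, dim, heads = 2, 5, 64, 4
part_dim = dim // heads  # 64 / 4 = 16

x = torch.arange(batch_size * seq_length * dim, dtype=torch.float32)
x = x.reshape(batch_size, seq_length, dim)

# Split the last dimension into (heads, part_dim).
x_heads = x.reshape(batch_size, seq_length, heads, part_dim)

# Head h sees the slice [h * part_dim, (h + 1) * part_dim) of every vector.
h = 2
same = torch.equal(x_heads[:, :, h, :], x[:, :, h * part_dim:(h + 1) * part_dim])
print(x_heads.shape, same)
```

The divisibility of `dim` by `heads` is what makes this reshape possible.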

## Constructing the module
Below is an implementation of our self-attention module. We try to show you what a typical custom module looks like. Part of that is full vectorization (i.e. no loops). The result is a lot of manipulation of the shapes and dimension order of intermediate tensors. It is not very readable and it is quite difficult to wrap your head around, but since you are likely to use and modify other people's code (in the project or at some later time), it is good that you get exposed to it now.

In [32]:
class SelfAttention(torch_nn.Module):
    def __init__(self, dim, heads):
        """(Narrow) Self-attention module

        Args:
            dim (int): The full embedding dimension of the input vectors
            heads (int): The number of heads in the multi-head attention.
        """
        super().__init__()
        if not dim % heads == 0:
            raise ValueError(
                "The embedding dim. must be divisible by the number of heads for the vectorization to work."
            )
        self.dim = dim
        self.heads = heads
        part_dim = dim // heads
        # Linear maps for q, k and v
        self.Wq = torch_nn.Linear(part_dim, part_dim, bias=False)
        self.Wk = torch_nn.Linear(part_dim, part_dim, bias=False)
        self.Wv = torch_nn.Linear(part_dim, part_dim, bias=False)
        # Linear mapping to return to the original dimension
        self.WO = torch_nn.Linear(heads * part_dim, dim)

    def forward(self, x):
        """Multi-headed self attention

        Each head operates on a part of the embedding, i.e. we have q, k and v with shape
        (batch_size, seq_length, heads, dim / heads)
        
        Args:
            x (Tensor): Input with shape (batch_size, seq_length, dim)
        """
        batch_size, seq_length, dim = x.shape
        part_dim = dim // self.heads
        x = x.reshape(batch_size, seq_length, self.heads, part_dim)
        
        keys = self.Wk(x)
        queries = self.Wq(x)
        values = self.Wv(x)
        
        keys = self._restructure_tensor(keys, batch_size, seq_length, part_dim)
        queries = self._restructure_tensor(queries, batch_size, seq_length, part_dim)
        values = self._restructure_tensor(values, batch_size, seq_length, part_dim)

        weights = query_key_weights(queries, keys)

        y_tilde = weighted_avg(values, weights)
        y_tilde = (
            y_tilde.transpose(2, 1)
            .reshape(batch_size, self.heads, seq_length, part_dim)
            .transpose(1, 2)
            .contiguous()
            .reshape(batch_size, seq_length, part_dim * self.heads)
        )
        return self.WO(y_tilde)

    def _restructure_tensor(self, x, batch_size, seq_length, part_dim):
        """ Reshaping q, k and v tensors

        For efficient vectorisation we stack the different heads in the batch_size dimension.
        Think of it as temporarily expanding the batch_size with every head.
        """
        return (
            x.transpose(1, 2)
            .contiguous()
            .reshape(batch_size * self.heads, seq_length, part_dim)
            .transpose(2, 1)
        )

# The transformer block

The majority of the implementation complexity is actually in the `SelfAttention` module. The transformer block is rather straightforward; it is just like the one described in the video lectures:

In [33]:
class TransformerBlock(torch_nn.Module):
    """Transformer block"""

    def __init__(self, dim, heads):
        super().__init__()

        self.self_attention = SelfAttention(dim, heads)

        self.normalization_1 = torch_nn.LayerNorm(dim)
        self.normalization_2 = torch_nn.LayerNorm(dim)
        
        # The size of the hidden layer is a hyper-parameter,
        # but the consensus is that it should at least be larger than the input/output size
        self.feed_forward = torch_nn.Sequential(
            torch_nn.Linear(dim, 4 * dim),
            torch_nn.ReLU(),
            torch_nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        y = self.self_attention(x)
        # Note how the residual (skip) connections are implemented as simple addition.
        x = self.normalization_1(x + y)
        fed_forward = self.feed_forward(x)
        return self.normalization_2(fed_forward + x)

Now we are done with the general module. To create an actual transformer, we must still choose an actual problem so that we can specify the input, embedding and output.
Let's do that.

# 3. IMDB Classification

Transformers are often very complex models. Whenever you see impressive transformer results, they are likely produced by a transformer with many millions, if not billions, of parameters. We don't really have that computational budget for a part of a home assignment. Instead, we will show a classification task that is reasonable but still not a toy example: classification of IMDB reviews. Even this small example takes a considerable time to train.

The purpose is to build on the computer labs and to give you some inspiration for how to solve a general problem with `pytorch`. It will show you how to install additional python libraries (useful for the project) and give some advice on how to construct a training/validation loop. We do not expect you to modify the code, and **you don't even have to run it** if you feel that your cloud credits are starting to run low. However, you should read and understand the code; it will help you answer the questions at the end.

## The data

The IMDB data is provided by an external python module called `torchtext`.
You can add it to the dml conda environment with:
```
conda install -c pytorch torchtext
```
Make sure that you have activated the dml environment before you run it.

Processing text data can be tedious and error prone. For prototyping it is nice to use some third-party library which has done most of the work for you. You do not really need to focus on the data processing here, since it will be different for every task.


In [34]:
from torchtext import data, datasets

TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
LABEL = data.Field(sequential=False)

def get_loaders(vocabulary_len, batch_size, device, split_ratio=0.8):
    """Load the IMDB data"""
    tdata, _ = datasets.IMDB.splits(TEXT, LABEL)
    train, test = tdata.split(split_ratio)

    TEXT.build_vocab(
        # We have to leave space for two special tokens.
        train, max_size=vocabulary_len - 2
    )
    LABEL.build_vocab(train)

    train_loader, test_loader = data.BucketIterator.splits(
        (train, test), batch_size=batch_size, device=device
    )
    return train_loader, test_loader

def view_example_text(index):
    """Helper function to look at a sample.
    
    The dataset is quite slow to load. 
    """
    train, test = datasets.IMDB.splits(TEXT, LABEL)
    sample_text = train[index].text
    sample_label = train[index].label
    # Simply print the list of words, separated by space.
    print(" ".join(sample_text))
    print("\nLabel: ", sample_label)

view_example_text(118)

here is one the entire family will enjoy... even those who consider themselves too old for fairy tales. shelley duvall outdid herself with this unique, imaginative take on nearly all of the popular fairy tales of childhood. the scripts offer new twists on the age-old fables we grew up on and they feature a handful of stars in each episode. "cinderella" is no exception to duvall's standard and in my opinion it's one of the top five of the series, highlighted by jennifer beals (remember her from "flashdance"--and she's still in hollywood today making a movie here and there) in the title role, jean stapleton as the fairy godmother with a southern accent and eve arden as the embodiment of wicked stepmotherhood. edie mcclurg ("ferris bueller's day off") and jane alden make for a hilarious duo as the stepsisters. matthew broderick is an affable prince henry. you'll all keep coming back for this one!

Label:  pos


## The transformer
We will create a simple transformer that takes as input text in the form of a python list of words and which outputs a probability vector over the two classes "pos" and "neg" (technically, the output will be log-probabilities, i.e. the output of a log-softmax).
We make the simplest (and less memory efficient) version of position embedding as described in the video lectures.
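
The position embedding can be sketched in isolation as follows. This is a minimal example with made-up sizes and token indices: one learned vector per position is added to the token embedding, so the same word gets a different representation depending on where it occurs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_tokens, max_seq_length, dim = 100, 16, 8
token_emb = nn.Embedding(num_tokens, dim)
pos_emb = nn.Embedding(max_seq_length, dim)

# A batch of two "sentences" containing the same words in a different order.
x = torch.tensor([[5, 7, 9], [9, 7, 5]])
batch_size, seq_length = x.shape

tokens = token_emb(x)                            # (batch_size, seq_length, dim)
positions = pos_emb(torch.arange(seq_length))    # (seq_length, dim)
embedded = tokens + positions[None, :, :]        # broadcast over the batch

print(embedded.shape)
# Word 5 at position 0 versus word 5 at position 2.
print(torch.allclose(embedded[0, 0], embedded[1, 2]))
# Word 7 at position 1 in both sentences.
print(torch.allclose(embedded[0, 1], embedded[1, 1]))
```

Storing one vector for every position up to the maximum sequence length is what makes this variant simple but memory-hungry.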

In [35]:
class Transformer(torch_nn.Module):
    def __init__(self, dim, heads, depth, seq_length, num_tokens, num_classes, device):
        super().__init__()
        self.device = device

        self.num_tokens = num_tokens
        
        self.pos_emb = torch_nn.Embedding(seq_length, dim)
        self.token_emb = torch_nn.Embedding(num_tokens, dim)

        transformer_blocks = []
        for _ in range(depth):
            transformer_blocks.append(TransformerBlock(dim=dim, heads=heads))

        # The Sequential wrapper is convenient when you want to repeat similar blocks.
        # A downside is that it is harder to access intermediate values for debugging.
        self.transformer_blocks = torch_nn.Sequential(*transformer_blocks)

        # The last part is problem specific. Here we want to map our transformer embeddings
        # to a probability distribution.
        # We will use a linear layer to produce logits (the input to a log-softmax function).
        self.output_map = torch_nn.Linear(dim, num_classes)

    def forward(self, x):
        """Transformer forward method

        Args:
            x Tensor(batch_size, seq_length): Word indices representing sequence of words.
        Returns:
            Tensor(batch_size, num_classes): Log-probabilities (output of the log-softmax)
        """
        tokens = self.token_emb(x)
        batch_size, seq_length, dim = tokens.size()

        # Note that we create a completely new tensor which must be moved to the proper device.
        # This is why we must store the device in self.device.
        pos = torch.arange(seq_length, device=self.device)
        pos = self.pos_emb(pos)[None, :, :].expand(batch_size, seq_length, dim)

        x = tokens + pos
        x = self.transformer_blocks(x)

        x = self.output_map(x.mean(dim=1))
        return F.log_softmax(x, dim=1)

Now, for the training/validation loop. This can be written in many ways, but based on common mistakes in HA1, some hints might be in order:

- Separate your code into smaller pieces, i.e. functions. It makes it easier to find bugs and easier to reuse code.
- Use separate functions to calculate metrics. If you want to calculate, say accuracy, during both training and validation, don't copy the code. Write one function and make sure that it works, then reuse it.
- Adding measurements to a running metric can be tricky. Below is a solution that is a bit overkill, but that is okay, since it is hard to use it incorrectly.
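
To illustrate why running metrics can be tricky: the `AccumulatingMetric` class used in this notebook averages the values it is given, which are batch means. If the last batch is smaller than the others, this differs slightly from the mean over all samples. The class below is a hypothetical variant, not part of the assignment code, that weights each value by the number of samples behind it.

```python
class WeightedAccumulatingMetric:
    """Running average that weights each value by its number of samples.

    Hypothetical variant for illustration; not part of the assignment code.
    """

    def __init__(self):
        self.total = 0.0
        self.num_samples = 0

    def add(self, value, n=1):
        # `value` is assumed to be a mean over `n` samples (e.g. a batch mean).
        self.total += value * n
        self.num_samples += n

    def avg(self):
        return self.total / self.num_samples


metric = WeightedAccumulatingMetric()
metric.add(1.0, n=6)   # a full batch where every prediction was correct
metric.add(0.0, n=2)   # a smaller last batch where every prediction was wrong
print(metric.avg())    # 0.75, whereas the plain mean of the two batch means is 0.5
```

With many batches of equal size the two approaches give nearly the same number, so the simpler class is a reasonable choice here.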

Note 1: the code below can be modified so that you can play around with it.

Note 2: timing this on Azure, a single epoch took ~5 min. Feel free to reduce the number of epochs or just study the code.

In [36]:
from time import time

def train_epoch(model, train_loader, optimizer, scheduler, max_seq_len):
    """Train epoch"""
    train_loss = AccumulatingMetric()
    train_acc = AccumulatingMetric()
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        input_, label = batch.text[0], batch.label - 1

        input_ = _truncate_input(input_, max_seq_len)
        pred = model(input_)
        loss = F.nll_loss(pred, label)
        loss.backward()
        train_loss.add(loss.item())

        train_acc.add(accuracy(pred, label))

        # Gradient clipping is a way to ensure that the gradient norm never exceeds a chosen threshold (disabled here).
        # torch_nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
    return train_loss.avg(), train_acc.avg()


def validate_epoch(model, val_loader, max_seq_len):
    val_loss = AccumulatingMetric()
    val_acc = AccumulatingMetric()
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            input_, label = batch.text[0], batch.label - 1

            input_ = _truncate_input(input_, max_seq_len)
            pred = model(input_)
            val_loss.add(F.nll_loss(pred, label).item())

            val_acc.add(accuracy(pred, label))

    return val_loss.avg(), val_acc.avg()  # TODO: loss


def accuracy(pred, label):
    hard_pred = pred.argmax(1)
    return (hard_pred == label).float().mean().item()


def _truncate_input(input_, max_seq_len):
    if input_.size(1) > max_seq_len:
        input_ = input_[:, :max_seq_len]
    return input_


class AccumulatingMetric:
    """Accumulate samples of a metric and automatically keep track of the number of samples."""
    def __init__(self):
        self.metric = 0.0
        self.counter = 0

    def add(self, value):
        self.metric += value
        self.counter += 1
        
    def avg(self):
        return self.metric / self.counter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_tokens = 50_000
max_length = 512
embedding_size = 128
num_heads = 8
num_classes = 2
depth = 6

model = Transformer(
    dim=embedding_size,
    heads=num_heads,
    depth=depth,
    seq_length=max_length,
    num_tokens=num_tokens,
    num_classes=num_classes,
    device=device)

model.to(device)

lr = 1e-4
lr_warmup = 1e4
num_epochs = 5
batch_size = 6

train_loader, test_loader = get_loaders(num_tokens, batch_size, device)

optimizer = torch.optim.Adam(lr=lr, params=model.parameters())
# A scheduler is a principled way of controlling (often decreasing) the learning rate as time progresses.
# Read more: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda i: min(i / (lr_warmup / batch_size), 1.0)
)

print("Starting training")
for epoch in range(1, num_epochs + 1):
    start = time()
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, max_length)
    val_loss, val_acc = validate_epoch(model, test_loader, max_length)
    end = time()
    print(
        "Epoch: {}/{}: time: {:.1f}, train loss: {:.3f}, train acc: {:.3f}, val. loss {:.3f}, val. acc: {:.3f}".format(
            epoch, num_epochs, end - start, train_loss, train_acc, val_loss, val_acc
        )
    )
print("You have now trained a transformer!")

# I'm adding a ''# YOUR CODE HERE' tag so that the code above is not hidden when the assignment is generated.
# YOUR CODE HERE

Starting training
Epoch: 1/5: time: 305.8, train loss: 0.698, train acc: 0.514, val. loss 0.681, val. acc: 0.513
Epoch: 2/5: time: 305.5, train loss: 0.629, train acc: 0.639, val. loss 0.594, val. acc: 0.725
Epoch: 3/5: time: 305.7, train loss: 0.536, train acc: 0.735, val. loss 0.501, val. acc: 0.762
Epoch: 4/5: time: 306.2, train loss: 0.472, train acc: 0.777, val. loss 0.484, val. acc: 0.794
Epoch: 5/5: time: 305.1, train loss: 0.419, train acc: 0.812, val. loss 0.474, val. acc: 0.804
You have now trained a transformer!


## Transformers and RNN

Now that you have gotten a practical feel for the transformer, it is time to reflect on some of its important properties:

**(2 POE)** What are the significant differences between a transformer and an RNN?
In particular, how do the differences make it easier to train a transformer, compared to an RNN?

**Your answer:** 

- Both RNNs and transformers are designed to handle sequential data. One difference between them, however, is that a transformer does not require the sequential data to be processed in order, unlike an RNN. If the input is a sentence, for example, the transformer does not need to process the beginning before the end, since transformers are non-sequential (they process the whole sentence at once rather than word by word). This allows the computations to be parallelized, which decreases the training time.


- Transformers also have the attention mechanism, which plain RNNs lack. The attention mechanism does not suffer from short-term memory, i.e. it can access words from earlier in the sequence. RNNs have a shorter window to reference from, so when the input text gets longer they cannot access words from early in the sequence. LSTMs and GRUs have a larger capacity to capture long-term dependencies, but they still fail when the input sequence is too long.


- A transformer also has access to all of the hidden states of the entire encoding part, in contrast to RNNs.


- The main reason why it is easier to train a transformer than an RNN is that transformers avoid recurrence, which enables parallelization of the computations, and that they handle long-term dependencies better. The underlying differences are that transformers are non-sequential, that they use self-attention, and that both multi-head attention and positional embeddings provide information about the relationships between words. Together, these characteristics make a transformer easier to train than an RNN.

**(2 POE)** Self-attention maps sets to sets. It is an important part of what makes transformers so useful and general. Explain what this property means.

Ironically, this property is actually a bit of an issue when we want to process text (or any NLP problem).
Why, and how do we try to fix it?

**Your answer:** 

This property means that self-attention does not require the sequential data to be processed in order and therefore does not take the order of the inputs into account: if the inputs are permuted, the outputs are permuted in the same way but are otherwise unchanged. Transformers do not have recurrence, unlike for example RNNs.

In an NLP problem this becomes an issue when the input is a sentence with several words, since the order of the words will not be taken into account by the transformer. This problem is addressed with positional encoding, where information about the positions of the words is included in the input embeddings.
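
The set-to-set property can be verified on a small example. The sketch below uses a simplified single-sequence version of basic self-attention (the function name `attention_single_sequence` and the fixed permutation are chosen for this example only). It checks that permuting the inputs permutes the outputs in the same way, and that adding position-dependent vectors breaks this symmetry.

```python
import torch
import torch.nn.functional as F

def attention_single_sequence(x):
    """Basic self-attention for one sequence x of shape (dim, seq_len)."""
    weights = F.softmax(x.T @ x, dim=0)   # w[j, i], normalised over j
    return x @ weights                    # y_i = sum_j x_j w_ji

torch.manual_seed(0)
dim, seq_len = 4, 6
x = torch.randn(dim, seq_len)
perm = torch.tensor([2, 0, 1, 5, 3, 4])   # a fixed reordering of the sequence

y = attention_single_sequence(x)
y_perm = attention_single_sequence(x[:, perm])

# Without position information: reordered input gives identically reordered output.
equivariant = torch.allclose(y[:, perm], y_perm, atol=1e-6)

# With position vectors added after the reordering, the symmetry is broken.
pos = torch.randn(dim, seq_len)
y_pos = attention_single_sequence(x + pos)
y_pos_perm = attention_single_sequence(x[:, perm] + pos)
equivariant_with_pos = torch.allclose(y_pos[:, perm], y_pos_perm, atol=1e-6)

print(equivariant, equivariant_with_pos)
```

Here random vectors stand in for the learned position embeddings; the point is only that the output now depends on where each word sits in the sequence.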

# Wrapping up

The transformer architecture has become incredibly popular and has produced truly amazing results.
You should now have a good insight into how they can be implemented in `pytorch`. If you are interested, here are some more resources:

- https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
- https://github.com/huggingface/transformers
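
As a pointer for further exploration, the sketch below builds a stack of encoder blocks from the built-in `torch.nn` modules instead of the custom classes in this notebook. The sizes mirror the ones used above, but this is only an assumed starting point: the built-in layer differs in details from our blocks (for example, it applies dropout by default).

```python
import torch
import torch.nn as nn

dim, heads, depth = 128, 8, 6
seq_length, batch_size = 10, 2

layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=4 * dim)
encoder = nn.TransformerEncoder(layer, num_layers=depth)

# By default the expected layout is (seq_length, batch_size, dim).
x = torch.rand(seq_length, batch_size, dim)
y = encoder(x)
print(y.shape)
```

Token and position embeddings, as well as the final classification layer, would still have to be added around this encoder.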