<a href="https://colab.research.google.com/github/furio1999/Computer-Vision/blob/main/Attention_model_question_and_answers_application.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This notebook is about a technique that is everywhere in NLP: attention (made popular by [Bahdanau et al. 2015](https://arxiv.org/abs/1409.0473)). Attention is loosely inspired by how humans perceive language (and everything else, really): by selectively focusing on a subset of the information that is available at any given time. To present attention we have selected a task where one needs to pay attention to a whole paragraph to succeed: question answering!

# Preliminaries
In the following cell we are going to


*   Download the SQUAD 2.0 corpus
*   Download the first 100k embedding vectors from `fasttext` Common Crawl
*   Import a few useful libraries (including Moses for tokenization)


In [None]:
# install lightning and sacremoses
!pip install sacremoses pytorch_lightning &> /dev/null
!mkdir -p data/

# Download SQUAD
!wget -nc -O data/squad-2.0.json \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json &> /dev/null

# Download the fastText embeddings (header line + first 100k vectors)
N_VECTORS=100_001
!wget -O - \
    https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz 2> /dev/null |\
    gunzip | head -n $N_VECTORS > data/embeddings.txt

import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Any
import random
import textwrap

from sacremoses import MosesTokenizer
from tqdm import tqdm
import torch
from torch.nn.utils import rnn
from torch.nn import functional as F
import pytorch_lightning as pl

In [None]:
import functools

def cache(method):
    """
    This decorator caches the return value of a method, per instance, so that
    results are not recomputed. The cache key is the method name only, so it is
    meant for methods whose result does not depend on their arguments.
    """
    method_name = method.__name__
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        self._cache = getattr(self, '_cache', {})
        if method_name not in self._cache:
            self._cache[method_name] = method(self, *args, **kwargs)
        return self._cache[method_name]
    return wrapper
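To see what the decorator buys us, here is a small self-contained sketch. The decorator is repeated (with `functools.wraps`) so the cell runs on its own, and `Tokenized` is a hypothetical toy class, not part of the notebook. The counter shows that the wrapped method body runs only once per instance, and that each instance keeps its own cache.

```python
import functools

def cache(method):
    """Per-instance cache keyed only on the method name (arguments are ignored)."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        self._cache = getattr(self, '_cache', {})
        if method.__name__ not in self._cache:
            self._cache[method.__name__] = method(self, *args, **kwargs)
        return self._cache[method.__name__]
    return wrapper

class Tokenized:
    """Hypothetical toy class counting how often tokenization actually runs."""
    def __init__(self, text):
        self.text = text
        self.calls = 0

    @cache
    def tokens(self):
        self.calls += 1
        return self.text.split()

t = Tokenized("the queen owns corgis")
t.tokens()
t.tokens()
print(t.calls)  # 1: the second call is served from the cache
```

Because the key is just the method name, calling a cached method with different arguments would return the first result; this is harmless here because every cached method of `Sample` takes no arguments besides `self`.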

# Example Task and Baseline


## Vocabulary
Since we base our vocabulary on the pretrained vectors, we can build it without even looking at the data: we just take the first field of each line of the vector file, which is already sorted by frequency. Let's take a peek...

In [None]:
!head data/embeddings.txt | cut -d " " -f1-4

2000000 300
, 0.1250 -0.1079 0.0245
the -0.0517 0.0740 -0.0131
. 0.0342 -0.0801 0.1162
and 0.0082 -0.0899 0.0265
to 0.0047 0.0281 -0.0296
of -0.0001 -0.1877 -0.0711
a 0.0876 -0.4959 -0.0499
</s> 0.0731 -0.2430 -0.0353
in -0.0140 -0.2522 0.0715


We can now build our mapping between indices and tokens. In addition to our special tokens for padding (`<PAD>`) and out-of-vocabulary words (`<UNK>`), we also need another one (`<SEP>`) that acts as a delimiter between the question and the context, which are concatenated into a single input, e.g.:

* **Question**: `What was the color of Napoleon's horse?`
* **Context**: `Napoleon's horse was white`
* **Input**: `What was the color of Napoleon's horse? <SEP> Napoleon's horse was white`
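A small sketch of the concatenation, and of how an answer span found in the context has to be shifted once the question and `<SEP>` are prepended (the same bookkeeping `Sample.process` performs below). Whitespace splitting stands in for the Moses tokenizer here; that is an assumption made only to keep the example self-contained.

```python
SEP = '<SEP>'

# Whitespace tokenization: a stand-in for MosesTokenizer
question = "What was the color of Napoleon's horse ?".split()
context = "Napoleon's horse was white".split()
answer_span = (3, 4)  # "white" is context[3:4] (end index excluded)

input_tokens = question + [SEP] + context
shift = len(question) + 1  # question tokens plus the separator
span_in_input = (answer_span[0] + shift, answer_span[1] + shift)
print(span_in_input)                                       # (12, 13)
print(input_tokens[span_in_input[0]:span_in_input[1]])     # ['white']
```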


In [None]:
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
SEP_TOKEN = '<SEP>'

# this will hold the str2int mapping
vocab = {}

# this will hold the vectors
vectors = []

for token in (PAD_TOKEN, UNK_TOKEN, SEP_TOKEN):
    vocab[token] = len(vocab)

PAD_IDX = vocab[PAD_TOKEN]
UNK_IDX = vocab[UNK_TOKEN]
SEP_IDX = vocab[SEP_TOKEN]

# the vectors for the special tokens are initialized randomly
vectors += [torch.randn(300) for _ in vocab]

with open('data/embeddings.txt') as f:
    next(f)  # skip header
    for i, line in tqdm(enumerate(f)):
        token, *vector = line.strip().split(' ')
        vocab[token] = len(vocab)
        vector = torch.tensor([float(c) for c in vector])
        vectors.append(vector)

# we stack the list of vectors into a matrix
vectors = torch.stack(vectors, 0)

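def numericalize(tokens, vocab, max_length=None):
    """Maps tokens to ids (unknown tokens -> UNK_IDX), optionally truncating."""
    ids = [vocab.get(t, UNK_IDX) for t in tokens]
    if max_length is not None:
        ids = ids[:max_length]
    return torch.tensor(ids)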

100000it [00:11, 9040.32it/s]


## Datasets (SQUAD 2.0)

Now let's deal with the datasets. Each instance is a complex object that comprises several things:

* the **question**, e.g. `how many corgis does the Queen of the UK own?`
* the background **context** providing the knowledge needed to answer, e.g. a paragraph 
* the **answer**, which is a string as well as a span in the context$^{a}$: if the answer to the question is `five`, the substring `five` will be contained in the context.$^{b}$

$^{a}$In the original dataset there can be multiple answers, or none at all. We only keep the first one, and discard instances with no answers altogether.

$^{b}$We have no idea how many corgis are owned by the Queen of England. 

In [None]:
tokenizer = MosesTokenizer()

@dataclass
class Sample:
    question: str
    context: str
    answer: Optional[str]

    @cache
    def question_tokens(self) -> List[str]:
        """Tokenize the question"""
        return tokenizer.tokenize(self.question.strip())

    @cache
    def answer_tokens(self) -> List[str]:
        """Tokenize the answer"""
        if self.answer is None:
            return [UNK_TOKEN]
        return tokenizer.tokenize(self.answer.strip())

    @cache
    def context_tokens(self) -> List[str]:
        """Tokenize the context"""
        return tokenizer.tokenize(self.context.strip())

    @cache
    def answer_span_in_context(self) -> Tuple[int, int]:
        """Find the answer span in the tokenized context (end index excluded)"""
        if self.answer is None:
            return (-1, -1)
        answer_tokens = self.answer_tokens()
        context_tokens = self.context_tokens()
        search = 0
        while 0 <= search < len(context_tokens):
            try:
                index = context_tokens.index(answer_tokens[0], search)
                if answer_tokens == context_tokens[index:index+len(answer_tokens)]:
                    return (index, index+len(answer_tokens))
                else:
                    search = index + 1
            except ValueError:
                return (-1, -1)
        return (-1, -1)

    def has_single_word_answer(self):
        """True if the answer occurs in the context and is exactly one token long"""
        span = self.answer_span_in_context()
        return span[1] - span[0] == 1

    @cache
    def process(self):
        """Numericalizes the sample"""
        input_tokens = self.question_tokens() + [SEP_TOKEN] + self.context_tokens()
        output_tokens = self.answer_tokens()
        span = self.answer_span_in_context()
        if span != (-1, -1):
            shift = len(self.question_tokens()) + 1
            span = (span[0] + shift, span[1] + shift)
        input_ids = numericalize(input_tokens, vocab)
        output_ids = numericalize(output_tokens, vocab)
        span = torch.tensor(span)
        return (input_ids, output_ids, span)

Now we can load the data from the official SQUAD 2.0 json. To simplify matters enormously, we only keep instances that satisfy the following constraints:
* the answer is a single token that can be found in the tokenized context
* the answer is in the `fasttext` vocabulary

In [None]:
def load(path: str, keep_single_word_answers=True, remove_oov=True) -> List[Sample]:
    samples = []
    data = json.loads(Path(path).read_text(encoding='utf-8'))
    paragraphs = (p for d in data['data'] for p in d['paragraphs'])
    for paragraph_data in paragraphs:
        context = paragraph_data['context']
        for question_data in paragraph_data['qas']:
            question = question_data['question']
            try:
                answer = question_data['answers'][0]['text']
            except IndexError:
                # there is no possible answer for the question
                continue
            sample = Sample(
                question=question, 
                answer=answer, 
                context=context)
            if keep_single_word_answers and not sample.has_single_word_answer():
                continue
            if remove_oov and sample.answer_tokens()[0] not in vocab:
                continue
            samples.append(sample)
    return samples

samples = load('data/squad-2.0.json')

In [None]:
#@title Explore SQUAD { run: "auto" }
def wrap(text):
    return "\n".join(textwrap.wrap(text))

n_sample = 33 #@param {type:"slider", min:0, max: 10000, step:1}
s = samples[n_sample]

print('Question:')
print(wrap(s.question))

print('\nAnswer:')
print(wrap(s.answer))

print('\nContext:')
print(wrap(s.context))

Question:
How many weeks did their single "Independent Women Part I" stay on
top?

Answer:
eleven

Context:
The remaining band members recorded "Independent Women Part I", which
appeared on the soundtrack to the 2000 film, Charlie's Angels. It
became their best-charting single, topping the U.S. Billboard Hot 100
chart for eleven consecutive weeks. In early 2001, while Destiny's
Child was completing their third album, Beyoncé landed a major role in
the MTV made-for-television film, Carmen: A Hip Hopera, starring
alongside American actor Mekhi Phifer. Set in Philadelphia, the film
is a modern interpretation of the 19th century opera Carmen by French
composer Georges Bizet. When the third album Survivor was released in
May 2001, Luckett and Roberson filed a lawsuit claiming that the songs
were aimed at them. The album debuted at number one on the U.S.
Billboard 200, with first-week sales of 663,000 copies sold. The album
spawned other number-one hits, "Bootylicious" and the title track,
"

Finally we do the usual chore of implementing subclasses of `torch.utils.data.Dataset` and `pl.LightningDataModule`, which will handle the production of batches of samples to be used by the models that we are going to implement. This should all look familiar to you. If it doesn't, please (re)check the previous notebooks!  

In [None]:
class QADataset(torch.utils.data.Dataset):

    def __init__(self, samples):
        self.samples = samples
        
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

def collate_fn(samples, device=None):
    batch_input_ids, batch_output_ids, batch_span_ids = \
        zip(*[s.process() for s in samples])
    batch = {}
    batch['input_ids'] = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=PAD_IDX)
    batch['output_ids'] = rnn.pad_sequence(batch_output_ids, batch_first=True, padding_value=PAD_IDX)
    batch['span_ids'] = rnn.pad_sequence(batch_span_ids, batch_first=True, padding_value=-1)
    batch['samples'] = samples
    if device is not None:
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
    return batch
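As a quick illustration of what `collate_fn` does with `rnn.pad_sequence`, here is a toy batch (the token ids are made up). The spans are always two numbers long, so padding them is a no-op and the call simply stacks them. The lengths recovered with `!= PAD_IDX`, as the models below do, are only correct because no real token is mapped to `PAD_IDX`.

```python
import torch
from torch.nn.utils import rnn

PAD_IDX = 0  # same index as in the vocabulary built above

# Two toy sequences of token ids with different lengths
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
padded = rnn.pad_sequence(seqs, batch_first=True, padding_value=PAD_IDX)
print(padded.tolist())   # [[5, 6, 7], [8, 9, 0]]

# Recover the true lengths by counting non-padding ids
lengths = (padded != PAD_IDX).long().sum(-1)
print(lengths.tolist())  # [3, 2]

# Spans all have length 2, so "padding" them just stacks them into (batch, 2)
spans = rnn.pad_sequence([torch.tensor([3, 4]), torch.tensor([10, 11])],
                         batch_first=True, padding_value=-1)
print(spans.tolist())    # [[3, 4], [10, 11]]
```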

In [None]:
class QADataModule(pl.LightningDataModule):

    def __init__(self, samples, vocab, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.samples = samples
        self.vocab = vocab

    def setup(self, stage=None):
        # Deterministic 80/10/10 split; a local RNG avoids touching the global seed.
        # The split lives in setup() because Lightning calls prepare_data() on a
        # single process, so state assigned there is not shared across processes.
        samples = self.samples[:]
        random.Random(1337).shuffle(samples)
        i = int(len(samples) * 0.8)
        j = int(len(samples) * 0.9)
        self.train_dataset = QADataset(samples[:i])
        self.valid_dataset = QADataset(samples[i:j])
        self.test_dataset = QADataset(samples[j:])

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            pin_memory=True,
            num_workers=0,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            pin_memory=True,
            num_workers=0,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            pin_memory=True,
            num_workers=0,
        )

data = QADataModule(samples, vocab, batch_size=40)
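A minimal, self-contained sketch of an 80/10/10 split in which consecutive slices share their boundaries, so every sample ends up in exactly one split. `split_80_10_10` is a hypothetical helper, not used by the data module; it shuffles with a local `random.Random` so that the global random state is left untouched.

```python
import random

def split_80_10_10(items, seed=1337):
    """Shuffle a copy of `items` and cut it into train/valid/test (80/10/10)."""
    items = list(items)
    random.Random(seed).shuffle(items)  # local RNG: no global side effects
    i = int(len(items) * 0.8)
    j = int(len(items) * 0.9)
    # Slices [:i], [i:j], [j:] touch but never overlap or skip an element
    return items[:i], items[i:j], items[j:]

train, valid, test = split_80_10_10(range(1000))
print(len(train), len(valid), len(test))  # 800 100 100
```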

## Measure (Exact Match)

The performance in Question Answering is assessed with a rather simple measure: the Exact Match (EM). As the name suggests, it is the fraction of samples whose predicted answer exactly matches the ground truth. So, if we use $a_i$ and $\hat a_i$ to denote, respectively, the ground truth and the prediction for the $i$-th sample, and $\mathcal{D}$ is the dataset, the EM is:

\\[ \text{EM}_{\mathcal{D}} = \frac{1}{|\mathcal{D}|} \sum_{i=1}^{|\mathcal{D}|}\begin{cases} 0 & \text{if}\ \ a_i \neq \hat a_i \\ 1 & \text{if}\ \ a_i = \hat a_i \end{cases} \\]
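The formula boils down to a few lines of Python. `exact_match` below is a hypothetical helper that computes EM over a whole list at once (the `ExactMatch` class accumulates the same quantity batch by batch). The comparison is plain string equality, so `Eleven` and `eleven` count as different answers.

```python
def exact_match(predictions, ground_truth):
    """Fraction of samples whose prediction equals the ground truth (0.0 if empty)."""
    pairs = list(zip(predictions, ground_truth))
    if not pairs:
        return 0.0
    return sum(p == g for p, g in pairs) / len(pairs)

preds = ['five', 'Paris', 'Eleven']
gold = ['five', 'London', 'eleven']
print(exact_match(preds, gold))  # 1 hit out of 3 -> 0.333...
```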

We create a simple `ExactMatch` accumulator class that computes the metric incrementally, to be plugged into our `pl.LightningModule`.

In [None]:
class ExactMatch:

    def __init__(self):
        self.reset()

    def reset(self):
        """Resets the accumulators (at the beginning of the epoch)"""
        self.hit = 0
        self.tot = 0

    def value(self):
        """Fetches the current value"""
        if self.tot > 0:
            return self.hit / self.tot
        return 0.

    def update(self, predictions, ground_truth):
        """Updates the accumulator on the basis of predictions and ground truths"""
        for p, g in zip(predictions, ground_truth):
            self.tot += 1
            if p == g:
                self.hit += 1
        return self.value()

## Baseline

To perform QA we use one of the simplest models we can think of: an LSTM that takes as input the concatenation of the question and the context. The output of this model is a probability distribution over the vocabulary, with each word being one possible answer. Formally:

\\[x = q | \texttt{<SEP>}|c\\]
\\[E = \text{Embed}_{\Theta_0}(x)\\]
\\[H = \text{LSTM}_{\Theta_1}(E) \\]
\\[L = \text{Linear}_{\Theta_0^T}(H_{|x|}) \\]
\\[P(a|q,c) = \frac{\text{exp}(L_a)}{\sum_{v \in V}{\text{exp}(L_v)}}\\]



N.B. $\Theta$ are the parameters of the model. 
You may have noticed that the parameters of the $\text{Linear}$ layer are set to the transpose of the embedding matrix. This is not a mistake, but an instance of the [well-known trick of weight tying](https://arxiv.org/abs/1608.05859), useful when the input and output vocabularies are the same. Since the embedding is frozen, the tied output layer is frozen as well, so only the LSTM is trained.
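Here is a toy-sized sketch of weight tying in PyTorch (the sizes are made up; the notebook uses a matrix of roughly 100k x 300). Assigning the embedding's `weight` to the linear layer makes both modules share a single parameter, and since `torch.nn.Linear` computes $hW^T$, the logits are exactly $hE^T$.

```python
import torch

vocab_size, hidden_size = 10, 4  # toy sizes
embedding = torch.nn.Embedding(vocab_size, hidden_size)
lin = torch.nn.Linear(hidden_size, vocab_size, bias=False)
lin.weight = embedding.weight  # tie: both modules now share ONE parameter

h = torch.randn(2, hidden_size)  # e.g. final LSTM states for a batch of 2
logits = lin(h)                  # shape (2, vocab_size)

# nn.Linear computes h @ W^T, and W is the embedding matrix E
assert torch.allclose(logits, h @ embedding.weight.T, atol=1e-6)
print(lin.weight is embedding.weight)  # True
```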

**Question**: In this model the prediction is $\hat a = \text{argmax}_{v\in V} P(v|q,c)$, i.e. the mode of the distribution. Do you think there is any obvious way to improve this prediction rule and increase performance, without retraining anything? Do you think all words in the vocabulary can be a valid answer?

**Answer**: Check the `predict` method below.

The loss function is the vanilla cross-entropy.

\\[ \mathcal{L} = -\text{log}\ P(a|q,c) \\]
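To convince ourselves that `torch.nn.CrossEntropyLoss` (used below) is exactly $-\log P(a|q,c)$, with $P$ given by the softmax of the logits, we can compare it with a manual computation on a toy three-word vocabulary (logits and target are made up).

```python
import torch
from torch.nn import functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # one sample, toy vocabulary of 3 words
target = torch.tensor([0])                  # index of the gold answer a

loss = F.cross_entropy(logits, target)      # what CrossEntropyLoss computes
manual = -torch.log_softmax(logits, dim=-1)[0, target[0]]  # -log P(a|q,c)
print(loss.item(), manual.item())           # the two numbers coincide
```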

Unfortunately `pl.LightningModule` (and Python class syntax, really) does not let us split a class definition across cells, so we comment the code directly in the cell.

In [None]:
class QAModel(pl.LightningModule):

    def __init__(self, pretrained):
        pl.LightningModule.__init__(self)
        # This we will need to convert back to string
        self.inverse_vocab = {i: t for t, i in vocab.items()}
        vocab_size, hidden_size = pretrained.size()
        self.hidden_size = hidden_size
        self.num_layers = 2

        # The embedding layer is initialized from the vectors and kept fixed
        self.embedding = torch.nn.Embedding.from_pretrained(pretrained)

        # Two layers of LSTM
        # We hate non batch-first syntax
        self.lstm = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=False)
        
        # Weight tying for the final layer
        self.lin = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.lin.weight = self.embedding.weight
        
        # Loss is cross-entropy
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
        
        # Measures
        self.em_valid = ExactMatch()
        self.em_test = ExactMatch()

    def forward(self, batch):
        """
        Computes the forward pass, returning unnormalized log probabilities (the logits)
        """
        
        # x: the token ids of q <SEP> c, shape (batch, time)
        input_ids = batch['input_ids']
        
        # Compute lengths of sequences, which we will need
        lengths = (input_ids != PAD_IDX).long().sum(-1)
        lengths = lengths.to('cpu')
        
        # E = Embed(x), shape (batch, time, hidden)
        x = self.embedding(input_ids)

        # The following lines are necessary to handle different lengths:
        # packing lets the LSTM skip the padding of shorter sequences
        x = rnn.pack_padded_sequence(
            x,
            lengths=lengths,
            batch_first=True,
            enforce_sorted=False,)
        
        # H = LSTM(E). We only need H_|x|, the state at the last real token.
        # h has shape (num_layers, batch, hidden); with a packed input it
        # already holds each sequence's state at its true last token (in the
        # original batch order), so we just keep the top layer
        x, (h, c) = self.lstm(x)
        x = h.view(self.num_layers, -1, self.hidden_size)[-1]
        
        # Equivalent alternative: unpack the outputs and pick the last real step
        #x, _ = rnn.pad_packed_sequence(x, batch_first=True)
        #x = torch.stack([x[i, l-1] for i, l in enumerate(lengths)], dim=0)
        
        # L = Linear(H_|x|): the logits over the vocabulary
        logits = self.lin(x)
        
        return logits

    def compute_loss(self, logits, batch):
        """ 
        Computes the cross-entropy loss from the logits
        """
        output_ids = batch['output_ids']
        loss = self.criterion(
            # These reshapes (flattening batch and time)
            # are required by torch
            logits.view(-1, logits.size(-1)),
            output_ids.reshape(-1),)
        return loss

    def predict(self, logits, batch, mask_out_impossible_words=True) -> List[str]:
        """
        Computes a batch of predictions (as list of strings) from logits
        """
        input_ids = batch['input_ids']

        if mask_out_impossible_words:
            # Remember that the answer MUST be a word in the context
            # Then we can mask out all words which are NOT in the context
            # We build a masking tensor with the same shape as the logits,
            # starting from mask[i, j] = float('-inf') everywhere
            mask = torch.full_like(logits, float('-inf'))

            # Then we punch holes in this thick mask, each one corresponding
            # to a word we are not masking out
            for i, seq_ids in enumerate(input_ids):
                # there is a little bug here... can you guess which?
                for idx in set(seq_ids.tolist()):
                    mask[i, idx] = 0.
            logits = logits + mask

        # Taking the argmax (the relative ordering between scores doesn't change with softmax)
        predictions = logits.argmax(-1).tolist()

        # Index 2 string
        predictions = [self.inverse_vocab[i] for i in predictions]
        return predictions

    @torch.no_grad()
    def ask(
        self, 
        question: str, context: str, 
        mask_out_impossible_words: bool = True) -> str:
        """
        Run the full pipeline on a single example, given as strings
        """
        batch = collate_fn(
            [Sample(question=question, context=context, answer=None)],
            device=next(self.parameters()).device
        )
        logits = self(batch)
        predictions = self.predict(logits, batch, mask_out_impossible_words)
        return predictions[0]

    @torch.no_grad()
    def evaluation(
        self, 
        batch: Tuple[torch.Tensor],
        metric: ExactMatch,):
        """
        Evaluates performance on ground truth in terms of both loss (returned)
        and EM (updated)
        """
        logits = self(batch)
        loss = self.compute_loss(logits, batch).item()
        predictions = self.predict(logits, batch)
        gold = [s.answer.strip() for s in batch['samples']]
        metric.update(predictions, gold)
        return loss

    def training_step(
        self, 
        batch: Tuple[torch.Tensor], 
        batch_idx: int
    ) -> torch.Tensor:
        """
        [Required by lightning]
        Computes loss to be used for .backward()
        """
        logits = self(batch)
        loss = self.compute_loss(logits, batch)
        return loss

    def validation_step(
        self,
        batch: Tuple[torch.Tensor],
        batch_idx: int):
        """
        [Required by lightning]
        Evaluates on batch of validation samples
        """
        loss = self.evaluation(batch, self.em_valid)
        self.log('valid_loss', loss, on_step=True, prog_bar=True)

    def test_step(
        self,
        batch: Tuple[torch.Tensor],
        batch_idx: int):
        """
        [Required by lightning]
        Evaluates on batch of test samples
        """
        loss = self.evaluation(batch, self.em_test)
        self.log('test_loss', loss, on_step=True, prog_bar=True)

    def configure_optimizers(self):
        """
        [Required by lightning]
        Initializes the optimizer
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=5e-4)
        return optimizer

    def on_validation_epoch_end(self):
        """
        [lightning]
        Logging and EM reset (validation)
        """
        self.log('valid_em', self.em_valid.value(), on_epoch=True, prog_bar=True)
        print(f'Epoch {self.current_epoch} - valid: {self.em_valid.value() * 100:.1f}% EM')
        self.em_valid.reset()

    def on_test_epoch_end(self):
        """
        [lightning]
        Logging and EM reset (test)
        """
        self.log('test_em', self.em_test.value(), on_epoch=True, prog_bar=True)
        self.em_test.reset()
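The double loop in `predict` can also be written as a single `scatter_` call. `context_mask` below is a hypothetical helper, sketched under the assumption that `input_ids` is a `(batch, time)` tensor of vocabulary ids as produced by `collate_fn`; its `forbidden_ids` argument is also hypothetical and shows how ids that occur in the input but should never be predicted could be masked back out. In the toy batch the second row ends with padding (id 0), and without `forbidden_ids` the padding id wins the argmax.

```python
import torch

def context_mask(logits, input_ids, forbidden_ids=()):
    """
    Additive mask with 0. at every vocabulary id occurring in row i of
    input_ids and -inf elsewhere; ids in forbidden_ids are always -inf.
    """
    mask = torch.full_like(logits, float('-inf'))
    # For every (i, t): mask[i, input_ids[i, t]] = 0.
    mask.scatter_(1, input_ids, 0.)
    for idx in forbidden_ids:
        mask[:, idx] = float('-inf')
    return mask

# Toy batch: vocabulary of 6 ids; id 5 scores highest but never occurs in the input
logits = torch.tensor([[0., 1., 2., 3., 4., 10.],
                       [5., 1., 2., 3., 4., 10.]])
input_ids = torch.tensor([[3, 1, 2],
                          [4, 4, 0]])  # second row ends with padding (id 0)

print((logits + context_mask(logits, input_ids)).argmax(-1).tolist())  # [3, 0]
print((logits + context_mask(logits, input_ids, forbidden_ids=(0,))).argmax(-1).tolist())  # [3, 4]
```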


In [None]:
pl.seed_everything(42)

model = QAModel(vectors)
trainer = pl.Trainer(
    max_epochs=6,
    accelerator=('gpu' if torch.cuda.is_available() else 'cpu'),
    devices=1,
)
trainer.fit(model, data)
trainer.test(model, data.test_dataloader());

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | embedding | Embedding        | 30.0 M
1 | lstm      | LSTM             | 1.4 M 
2 | lin       | Linear           | 30.0 M
3 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
1.4 M     Trainable params
30.0 M    Non-trainable params
31.4 M    Total params
125.783   Total estimated model params size (MB)



Epoch 0 - valid: 0.0% EM



Epoch 0 - valid: 4.9% EM



Epoch 1 - valid: 5.5% EM



Epoch 2 - valid: 4.6% EM



Epoch 3 - valid: 4.8% EM



Epoch 4 - valid: 4.7% EM



Epoch 5 - valid: 4.8% EM


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_em': 0.05451295799821269,
 'test_loss': 9.97176742553711,
 'test_loss_epoch': 9.842949867248535}
--------------------------------------------------------------------------------


In [None]:
model.ask(
    "How many corgis does the Queen of England own?",
    "The queen of England owns several mices, seven one-eyed ponies, three duck-sized corgis and a cranky old raven", 
)

'seven'

# Enter Attention





## Intro

Attention is a ubiquitous technique in modern NLP. It takes many forms, but at its core it is a way to use a **query** (a vector, or several) to compute a weighted combination of another set of **vectors**.

So, say I have a vector representing my friend Joan, and a collection of other vectors representing possible hobbies (chess, hiking, programming). I want a vector that summarizes Joan's interests. One way to do it is to compute relevance scores, one for each possible interest, normalize them into a probability distribution, and use them to take a weighted sum of the hobby vectors:

\\[ Z = \sum_{\text{interest} \in \{\text{chess},\, \text{hike},\, \text{progr}\}} e^{\text{score}(\text{Joan},\text{interest})} \\]
\\[ r_{\text{chess}} = e^{\text{score}(\text{Joan}, \text{chess})} / Z \\]
\\[ r_{\text{hike}} = e^{\text{score}(\text{Joan}, \text{hike})} / Z \\]
\\[ r_{\text{progr}} = e^{\text{score}(\text{Joan}, \text{progr})} / Z \\]
\\[ \text{summary} = \text{chess} \cdot r_{\text{chess}} + \text{hike} \cdot r_{\text{hike}} + \text{progr} \cdot r_{\text{progr}}  \\]

In the formula above, $\text{interest}$ (one of $\text{chess}$, $\text{hike}$, $\text{progr}$) is a vector in $\mathbb{R}^{h}$, while $\text{score}: \mathbb{R}^{h} \times \mathbb{R}^{h} \rightarrow \mathbb{R}$ is a function that takes two vectors and returns a scalar.
The simplest option for the scoring function is the dot product, which is larger when the two vectors point in similar directions, i.e. when their corresponding components tend to share the same sign and have large magnitudes. Let's code it!



In [None]:
Joan = vectors[vocab['Joan']]
chess = vectors[vocab['chess']]
hike = vectors[vocab['hike']]
progr = vectors[vocab['programming']]

# We perform all the steps at once thanks to the power of linear algebra
interests = torch.stack([chess, hike, progr])

# You have surely recognized the softmax function in the formula above :)
# This is equivalent to performing the dot product with each vector separately
r = (interests @ Joan).softmax(-1)

# And this is equivalent to weighted sum
summary = r @ interests

Abstracting away from the Joan example, and following the code of the previous cell, attention can be understood as a function taking as input queries (matrix $Q$), keys (matrix $K$) and values (matrix $V$), in our case corresponding to Joan, the interests, and the interests again. The return value is a matrix with one row per query (and as many columns as $V$), where each row is a 'summary' of $V$ whose weights are computed by scoring $Q$ against $K$. (Often either $K = V$ or $Q = K = V$.)

\\[ \text{attention}(Q, K, V) = \text{softmax}(QK^T)V \\]
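The formula translates almost literally into PyTorch. This is a minimal sketch on random toy matrices (names and sizes are arbitrary): it checks that the output has one row per query and as many columns as $V$, and that with a single query it reduces to the Joan computation of the previous cell.

```python
import torch

def attention(Q, K, V):
    """softmax(Q K^T) V for Q: (n_q, d), K: (n_k, d), V: (n_k, d_v)."""
    weights = (Q @ K.T).softmax(dim=-1)  # (n_q, n_k): one distribution per query
    return weights @ V, weights          # summaries: (n_q, d_v)

torch.manual_seed(0)
Q = torch.randn(2, 5)  # two queries
K = torch.randn(3, 5)  # three keys...
V = torch.randn(3, 7)  # ...and their three values (a different width on purpose)
out, w = attention(Q, K, V)
print(out.shape)       # torch.Size([2, 7])

# With a single query this is exactly the Joan cell: r = (K @ q).softmax(-1); r @ V
r = (K @ Q[0]).softmax(-1)
print(torch.allclose(out[0], r @ V, atol=1e-5))  # True
```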

**Warning**: Even though $K$ and $V$ are almost always identical in practice, it makes sense to keep them conceptually distinct. Suppose you want to retrieve a book from your bookshelf: it makes sense to search by its title (the key) rather than by the full text of the novel (the value), doesn't it?

Modern attention implementations also apply learned linear projections to the matrices, which add flexibility and make it possible to combine matrices whose shapes would otherwise not match:

\\[ \text{attention}(Q, K, V) = \text{softmax}\left((QW_{(Q)})(KW_{(K)})^T\right)(VW_{(V)}) \\]
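A sketch of the projected version with deliberately mismatched widths: $Q$ and $K$ cannot be multiplied directly, but after projecting both into a shared width they can. Random matrices stand in for the learned weights here; in a model they would be trainable layers, like the `q_transform`, `k_transform` and `v_transform` layers used later in this notebook.

```python
import torch

torch.manual_seed(0)
d_q, d_k, d_v, d = 6, 4, 5, 8  # toy widths, deliberately different
Q = torch.randn(2, d_q)
K = torch.randn(3, d_k)
V = torch.randn(3, d_v)

# Stand-ins for learned projections into a shared width d
W_Q = torch.randn(d_q, d)
W_K = torch.randn(d_k, d)
W_V = torch.randn(d_v, d)

try:
    Q @ K.T  # (2, 6) @ (4, 3): the inner dimensions do not match
except RuntimeError as err:
    print('without projections:', err)

scores = (Q @ W_Q) @ (K @ W_K).T      # (2, 3): every query against every key
out = scores.softmax(-1) @ (V @ W_V)  # (2, d)
print(out.shape)                      # torch.Size([2, 8])
```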

One of the most powerful intuitions about attention comes from **machine translation**! In the following image (taken from [Bahdanau et al. 2015](https://arxiv.org/abs/1409.0473)) you see a heatmap of the attention softmax in an English-to-French translation model.

**Spoiler alert:** Machine translation models output one word at a time; at each timestep they use the current decoder state (which summarizes the words generated so far) as the query, and the encoded sentence to translate as keys and values.

![Attention heatmap of an English-to-French translation model](https://jalammar.github.io/images/attention_sentence.png)

In the plot, each French word (rows) attends over the English words (columns), focusing mostly on its counterpart. You can see how attention deviates from the diagonal when the word order differs (*zone économique européenne* / *European Economic Area*) and in the case of many-to-many correspondences (*a été signé* / *was signed*).

## Baseline + attention

Let's now turn to how to incorporate attention into the model we have been using for question answering.

In our case we want to use attention to get a "summary" of the whole input. So, we use the LSTM outputs for all timesteps as $K$ (keys) and $V$ (values). But what should we use as $Q$ (queries)? We want a vector that is a function of the whole input. In a (unidirectional) recurrent network, the only vector that satisfies this requirement is the output at the last timestep, so that is what we will use (with $n$ the input length):

\\[ \text{softmax}\left((H_{n}W_{(Q)})(H_{1:n}W_{(K)})^T\right)(H_{1:n}W_{(V)}) \\]

In practice, this will be the new model:

\\[x = q | \texttt{<SEP>}|c\\]
\\[E = \text{Embed}_{\Theta_0}(x)\\]
\\[H = \text{LSTM}_{\Theta_1}(E) \\]
\\[Q = H_{|x|};\ \ \ \  K = V = H_{1:|x|}\\]
\\[ A=\text{softmax}\left((H_{|x|}W_{(Q)})(H_{1:|x|}W_{(K)})^T\right)(H_{1:|x|}W_{(V)}) \\]
\\[L = \text{Linear}_{\Theta_0^T}(A) \\]
\\[P(a|q,c) = \frac{\text{exp}(L_{a})}{\sum_{v \in V} \text{exp}(L_v)}\\]

What do we stand to gain from this? After all, the last vector of the LSTM is already a summary of the sentence. But with recurrence alone, the repeated application of the LSTM cell has to carry both the question and the candidate answer over possibly hundreds of steps. Attention, instead, has direct access to the output at every timestep, so information from far back in the sequence does not need to survive all those steps, which makes long-range dependencies much easier to exploit.

In the next cell, we will show how to implement this in a batched computation using `torch.einsum`.

In [None]:
class QAModelWithAttention(QAModel):

    def __init__(self, pretrained):
        pl.LightningModule.__init__(self)
        self.inverse_vocab = {i: t for t, i in vocab.items()}
        vocab_size, hidden_size = pretrained.size()
        self.hidden_size = hidden_size
        self.num_layers = 2
        self.embedding = torch.nn.Embedding.from_pretrained(pretrained)
        self.lstm = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=False,
        )
        
        # We initialize the matrices used to get linear transformations of
        # attention "ingredients"
        self.q_transform = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_transform = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_transform = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        
        # No bias, as in the baseline and in the tied Linear formula above
        self.lin = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.lin.weight = self.embedding.weight
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
        self.em_valid = ExactMatch()
        self.em_test = ExactMatch()

    def forward(self, batch):
        input_ids = batch['input_ids']
        
        lengths = (input_ids != PAD_IDX).long().sum(-1)
        lengths = lengths.to('cpu')
        x = self.embedding(input_ids)
        x = rnn.pack_padded_sequence(
            x,
            lengths=lengths,
            batch_first=True,
            enforce_sorted=False,)
        
        sequence, (h, c) = self.lstm(x)
        x = h.view(self.num_layers, -1, self.hidden_size)[-1]
        sequence, _ = rnn.pad_packed_sequence(sequence, batch_first=True)
        
        # We apply the linear transformations
        qw = self.q_transform(x)
        kw = self.k_transform(sequence)
        vw = self.v_transform(sequence)

        # Let's break the following line down. Any axis letter on the left of
        # '->' that appears in BOTH tensors is treated differently depending on
        # whether it also appears to the right of '->':
        #   - APPEARS: it is treated as a batch dimension along which to
        #     iterate, without performing a contraction (dot product)
        #   - DOES NOT APPEAR: it is treated as a contraction dimension,
        #     similarly to the inner dimension of a matrix @ matrix product
        scores = torch.einsum('bh,bth->bt', qw, kw)

        # Equivalently:
        # scores = []
        # batch_size = qw.size(0) # <-- iteration dimension
        # for i in range(batch_size):
        #    scores.append(kw[i] @ qw[i]) # <-- contraction over h
        # scores = torch.stack(scores, dim=0)

        # The scores have the same size as input_ids because there is one
        # score per token. Some of the inputs are padding, however, and must be
        # excluded from the softmax. We overwrite their scores with
        # float('-inf'): since e ** -inf == 0, they contribute nothing to the
        # softmax denominator and receive zero attention weight.
        scores[input_ids == PAD_IDX] = float('-inf')
        scores = scores.softmax(-1)

        # Now we multiply the scores with the values, which is equivalent to
        # doing a weighted sum of vw along the time axis, with scores providing
        # the weights.
        attention_summary = torch.einsum('bt,bth->bh', scores, vw)

        logits = self.lin(attention_summary)
        return logits
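To check the two `einsum` calls in isolation, here is a small standalone sketch on random toy tensors (the sizes and the padding id `PAD = 0` are arbitrary assumptions, not the notebook's `PAD_IDX`). It compares `'bh,bth->bt'` with the explicit loop from the comment and with `torch.bmm`, masks the padded positions, and confirms that `'bt,bth->bh'` is a weighted sum over time. It uses `masked_fill`, an out-of-place equivalent of the indexing assignment used in `forward`.

```python
import torch

torch.manual_seed(0)
PAD = 0  # hypothetical padding id for this toy example only
batch_size, steps, hidden = 2, 4, 3

# Toy token ids: the second sequence ends with two padding tokens.
input_ids = torch.tensor([[5, 6, 7, 8],
                          [5, 6, PAD, PAD]])
qw = torch.randn(batch_size, hidden)          # one query vector per sequence
kw = torch.randn(batch_size, steps, hidden)   # one key per time step
vw = torch.randn(batch_size, steps, hidden)   # one value per time step

# Batched dot products: contract over h, keep b and t.
scores = torch.einsum('bh,bth->bt', qw, kw)

# The same computation as an explicit loop and as a batched matmul.
loop = torch.stack([kw[i] @ qw[i] for i in range(batch_size)], dim=0)
bmm = torch.bmm(kw, qw.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(scores, loop, atol=1e-6)
assert torch.allclose(scores, bmm, atol=1e-6)

# Padding positions get -inf, hence exactly zero weight after the softmax.
scores = scores.masked_fill(input_ids == PAD, float('-inf'))
weights = scores.softmax(-1)
print(weights)

# Weighted sum of the values along the time axis.
summary = torch.einsum('bt,bth->bh', weights, vw)
manual = (weights.unsqueeze(-1) * vw).sum(dim=1)
assert torch.allclose(summary, manual, atol=1e-6)
```

Each row of `weights` sums to 1, and the two padded positions of the second row are exactly 0.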

In [None]:
pl.seed_everything(42)

model_att = QAModelWithAttention(vectors)
trainer = pl.Trainer(
    max_epochs=6,
    gpus=(1 if torch.cuda.is_available() else 0),
)
trainer.fit(model_att, data)
trainer.test(model_att, data.test_dataloader());

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params
-------------------------------------------------
0 | embedding   | Embedding        | 30.0 M
1 | lstm        | LSTM             | 1.4 M 
2 | q_transform | Linear           | 90.0 K
3 | k_transform | Linear           | 90.0 K
4 | v_transform | Linear           | 90.0 K
5 | lin         | Linear           | 30.1 M
6 | criterion   | CrossEntropyLoss | 0     
-------------------------------------------------
1.8 M     Trainable params
30.0 M    Non-trainable params
31.8 M    Total params
127.263   Total estimated model params size (MB)


Epoch 0 - valid: 0.0% EM
Epoch 0 - valid: 11.9% EM
Epoch 1 - valid: 16.1% EM
Epoch 2 - valid: 18.1% EM
Epoch 3 - valid: 19.6% EM
Epoch 4 - valid: 19.0% EM
Epoch 5 - valid: 18.8% EM

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_em': 0.19302949061662197,
 'test_loss': 7.971844673156738,
 'test_loss_epoch': 7.979929447174072}
--------------------------------------------------------------------------------


In [None]:
model_att.ask(
    "How many corgis does the Queen of England own?",
    "The queen of England owns several mices, seven one-eyed ponies, three duck-sized corgis and a cranky old raven")

'three'

# Attention as pointing

Remember the figure showing attention alignment in machine translation?

![Attention alignment in machine translation](https://jalammar.github.io/images/attention_sentence.png)

The attention mechanism provides a way to do (continuous) pointing: it assigns a soft weight to every position of the input. What if we use *this* to predict the answer directly, instead of using it to build a summary vector that then has to be used to predict a word? This use of attention is known, unsurprisingly, as the **pointer** mechanism.

In practice, the model's objective shifts to predicting an index: the position of the answer in the input. (The batching code already provides everything we need for this.)
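Concretely, because the question and the separator come first in x, a token at position p in the context sits at position |q| + 1 + p in x, and a predicted index is turned back into a word by indexing the same concatenated token list (as `predict` below does). This is a minimal sketch with whitespace tokenization and hypothetical helpers `build_input` and `pointer_target`; `SEP` stands in for the notebook's `SEP_TOKEN`, and the notebook's batching code stores its own targets in `span_ids`, possibly computed differently.

```python
from typing import List

SEP = '<SEP>'  # assumption: placeholder for the notebook's SEP_TOKEN


def build_input(question_tokens: List[str], context_tokens: List[str]) -> List[str]:
    """Concatenate question, separator and context: x = q | <SEP> | c."""
    return question_tokens + [SEP] + context_tokens


def pointer_target(question_tokens: List[str], context_index: int) -> int:
    """Shift a token index inside the context to the matching index inside x."""
    return len(question_tokens) + 1 + context_index  # +1 skips the separator


q_tokens = 'how many corgis ?'.split()
c_tokens = 'the queen owns three corgis'.split()
x = build_input(q_tokens, c_tokens)

target = pointer_target(q_tokens, c_tokens.index('three'))
print(target, x[target])  # 8 three
```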

\\[x = q | \texttt{<SEP>}|c\\]
\\[E = \text{Embed}_{\Theta_0}(x)\\]
\\[H = \text{LSTM}_{\Theta_1}(E) \\]
\\[Q = H_{|x|}\ \ \ \  K = V = H_{1:|x|}\\]

\\[L = (QW_{(Q)})(W_{(K)}K^T) \\]

\\[P(i|q,c) = \frac{\exp(L_{i})}{\sum_{j=1}^{|x|} \exp(L_j)}\\]
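The equations map directly onto the layers of the model. Because `nn.Linear` computes `inp @ weight.T`, the matrix W<sub>(Q)</sub> corresponds to `q_transform.weight.T`, while W<sub>(K)</sub> (which multiplies Kᵀ from the left) corresponds to `k_transform.weight` itself. A minimal sketch for a single toy sequence, with random tensors standing in for the LSTM outputs H (sizes are arbitrary assumptions):

```python
import torch

torch.manual_seed(0)
num_steps, hidden = 5, 4

H = torch.randn(num_steps, hidden)  # toy LSTM outputs H_{1:|x|} for one sequence
Q = H[-1:]                          # query: last state H_{|x|}, shape (1, hidden)
K = H                               # keys: all states, shape (|x|, hidden)

q_transform = torch.nn.Linear(hidden, hidden, bias=False)
k_transform = torch.nn.Linear(hidden, hidden, bias=False)

with torch.no_grad():
    # nn.Linear computes inp @ weight.T, hence the two different transposes
    W_Q = q_transform.weight.T
    W_K = k_transform.weight
    # L = (Q W_(Q)) (W_(K) K^T): one logit per input position
    L = (Q @ W_Q) @ (W_K @ K.T)                  # shape (1, |x|)
    L_layers = q_transform(Q) @ k_transform(K).T  # what the model computes
    # P(i | q, c): softmax of the logits over the input positions
    P = L.softmax(dim=-1)

assert torch.allclose(L, L_layers, atol=1e-6)
print(P)
```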

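Nothing extra is needed compared to the previous model. If anything, this model is simpler: it does not use the values, and it has no output projection matrix. (The `v_transform` and `lin` layers are still inherited from `QAModelWithAttention`, which is why they show up in the parameter summary, but the pointer's `forward` never uses them.)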

The cross-entropy loss can then be computed using the index of the ground-truth answer as the label:

\\[ \mathcal{L} = -\text{log}\ P(i|q,c) \\]
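As a sanity check of this loss, the sketch below (toy logits and targets, chosen arbitrarily) verifies that `F.cross_entropy` on masked pointer logits equals the batch mean of -log P(i|q,c) computed by hand, that a position masked with `-inf` causes no NaNs as long as it is not the target, and that targets equal to -1 are skipped with `ignore_index=-1`, mirroring the criterion set in `QAModelWithPointer`.

```python
import torch
import torch.nn.functional as F

# Pointer logits for 2 sequences of length 4; the last position of the
# second sequence is padding and has been masked with -inf.
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                       [1.0, 3.0, 0.2, float('-inf')]])
targets = torch.tensor([0, 1])  # index of the ground-truth answer in each x

loss = F.cross_entropy(logits, targets)

# Same value by hand: mean over the batch of -log P(i* | q, c).
log_p = logits.log_softmax(dim=-1)
manual = -log_p[torch.arange(2), targets].mean()
assert torch.allclose(loss, manual)

# A target of -1 is ignored, so only the first example contributes.
ignored = torch.tensor([0, -1])
loss_ignored = F.cross_entropy(logits, ignored, ignore_index=-1)
assert torch.allclose(loss_ignored, -log_p[0, 0])
print(loss.item(), loss_ignored.item())
```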

Let's code it!

In [None]:
class QAModelWithPointer(QAModelWithAttention):

    def __init__(self, pretrained):
        super().__init__(pretrained)
        # every position in the sequence could be the answer, so no index >= 0
        # can be reserved for padding: -1 marks targets to ignore instead
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, batch):
        input_ids = batch['input_ids']
        
        lengths = (input_ids != PAD_IDX).long().sum(-1)
        lengths = lengths.to('cpu')
        x = self.embedding(input_ids)
        x = rnn.pack_padded_sequence(
            x,
            lengths=lengths,
            batch_first=True,
            enforce_sorted=False)
        
        sequence, (h, c) = self.lstm(x)
        
        x = h.view(self.num_layers, -1, self.hidden_size)[-1]
        sequence, _ = rnn.pad_packed_sequence(sequence, batch_first=True)
        
        q = self.q_transform(x)
        k = self.k_transform(sequence)
        
        attention_logits = torch.einsum('bh,bth->bt', q, k)
        attention_logits[input_ids == PAD_IDX] = float('-inf')
        
        return attention_logits

    def compute_loss(self, logits, batch):
        # the batching code stores the answer indices in span_ids; the first one is the target
        output_ids = batch['span_ids'][:, 0]
        loss = self.criterion(logits, output_ids)
        return loss

    def predict(self, logits, batch):
        input_tokens = [s.question_tokens() + [SEP_TOKEN] + s.context_tokens() for s in batch['samples']]
        predictions = logits.argmax(-1).tolist()
        # the answer is the i-th token of the concatenated sequence
        # question + [SEP] + context, where i is the predicted index
        predictions = [seq_tokens[token_n] for token_n, seq_tokens in zip(predictions, input_tokens)]
        return predictions

In [None]:
pl.seed_everything(42)

model_pointer = QAModelWithPointer(vectors)
trainer = pl.Trainer(
    max_epochs=6,
    gpus=(1 if torch.cuda.is_available() else 0),
)
trainer.fit(model_pointer, data)
trainer.test(model_pointer, data.test_dataloader());

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params
-------------------------------------------------
0 | embedding   | Embedding        | 30.0 M
1 | lstm        | LSTM             | 1.4 M 
2 | q_transform | Linear           | 90.0 K
3 | k_transform | Linear           | 90.0 K
4 | v_transform | Linear           | 90.0 K
5 | lin         | Linear           | 30.1 M
6 | criterion   | CrossEntropyLoss | 0     
-------------------------------------------------
1.8 M     Trainable params
30.0 M    Non-trainable params
31.8 M    Total params
127.263   Total estimated model params size (MB)


Epoch 0 - valid: 0.0% EM
Epoch 0 - valid: 18.6% EM
Epoch 1 - valid: 19.4% EM
Epoch 2 - valid: 20.1% EM
Epoch 3 - valid: 21.0% EM
Epoch 4 - valid: 20.5% EM
Epoch 5 - valid: 22.9% EM

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_em': 0.22832886505808758,
 'test_loss': 2.872730016708374,
 'test_loss_epoch': 2.9124464988708496}
--------------------------------------------------------------------------------


In [None]:
model_pointer.ask(
    "How many corgis does the Queen of England own?",
    "The queen of England owns several mices, seven one-eyed ponies, three duck-sized corgis and a cranky old raven")

'seven'

# Interact with the Models!

In [None]:
#@title Explore the baseline { run: "auto" }

question = "How many corgis does the Queen of England own?" # @param {type: "string"}
context = "The queen of England owns several mices, seven one-eyed ponies, three duck-sized corgis and a cranky old raven" # @param {type: "string"}
answer = model.ask(question=question, context=context)

print('Question:')
print(wrap(question))

print('\nAnswer:')
print(wrap(answer))

print('\nContext:')
print(wrap(context))

Question:
How many corgis does the Queen of England own?

Answer:
seven

Context:
The queen of England owns several mices, seven one-eyed ponies, three
duck-sized corgis and a cranky old raven


In [None]:
#@title Explore the attention-based model { run: "auto" }

question = "How many corgis does the Queen of England own?" # @param {type: "string"}
context = "The queen of England owns several mices, seven one-eyed ponies, three duck-sized corgis and a cranky old raven" # @param {type: "string"}
answer = model_att.ask(question=question, context=context)

print('Question:')
print(wrap(question))

print('\nAnswer:')
print(wrap(answer))

print('\nContext:')
print(wrap(context))

Question:
How many corgis does the Queen of England own?

Answer:
three

Context:
The queen of England owns several mices, seven one-eyed ponies, three
duck-sized corgis and a cranky old raven


In [None]:
#@title Explore the pointer-based model { run: "auto" }

question = "How many corgis does the Queen of England own?" # @param {type: "string"}
context = "The queen of England owns several mices, seven one-eyed ponies, three duck-sized corgis and a cranky old raven" # @param {type: "string"}
answer = model_pointer.ask(question=question, context=context)

print('Question:')
print(wrap(question))

print('\nAnswer:')
print(wrap(answer))

print('\nContext:')
print(wrap(context))

Question:
How many corgis does the Queen of England own?

Answer:
seven

Context:
The queen of England owns several mices, seven one-eyed ponies, three
duck-sized corgis and a cranky old raven
