In [4]:
import torch
import matplotlib.pyplot as plt
import numpy as np


**Introduction to QKV Self-Attention Mechanism in Encoders: The Detective Analogy**  

Imagine a detective solving a complex case with scattered clues. To understand the story, the detective systematically evaluates how relevant each clue is and how it connects to the others, much like the Query-Key-Value (QKV) self-attention mechanism in a transformer encoder.

In this analogy:

* Queries (Q) are the detective's questions or focal points.
* Keys (K) represent the unique features of each clue.
* Values (V) are the actual pieces of evidence.


The self-attention mechanism works in three steps:
* Scoring Relevance: The detective compares queries to keys to score each clue's relevance.
* Normalizing Scores: These scores are normalized using a softmax function, akin to the detective assigning a probability to each clue.
* Aggregating Information: The detective gathers information from the most relevant clues, weighted by their scores, to piece together the story. 


This mechanism lets the encoder capture relationships between the positions of the input, just as a detective uncovers connections in a case, which makes it a core building block of modern natural language processing models.
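The three steps of the analogy can be sketched in a few lines of PyTorch. This is a toy illustration only: the query and keys are random, and the three "clues" and their values are made up for the example.

```python
import torch

# Toy illustration with made-up numbers (hypothetical data, not from the exercise).
torch.manual_seed(0)
d = 4                                          # dimension of queries and keys
query = torch.randn(d)                         # the detective's question
keys = torch.randn(3, d)                       # features of 3 clues
values = torch.tensor([[1.0], [2.0], [3.0]])   # evidence carried by each clue

# 1. Scoring relevance: dot product of the query with every key, scaled by sqrt(d)
scores = keys @ query / d ** 0.5               # shape (3,)
# 2. Normalizing: softmax turns the scores into non-negative weights summing to 1
weights = torch.softmax(scores, dim=0)         # shape (3,)
# 3. Aggregating: weighted sum of the values
summary = weights @ values                     # shape (1,)
print(weights, summary)
```

Because the weights are non-negative and sum to one, `summary` is always a weighted average of the clue values, so it lies between the smallest and largest value.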

# Attention Mechanisms
This exercise is based on Lukas Heinrich's lecture about Transformers in the course "Modern Deep Learning in Physics" @ TUM. 
In this exercise we will write an attention mechanism to solve the following problem:

"Given a length 10 sequence of integers, are there more '2s' than '4s' ?"

This could of course be handled easily by a fully connected network, but we will force the network to learn it by placing attention on the right values. The strategy is:

* embed each integer into a high-dimensional vector (using `torch.nn.Embedding`)
* once we have those embeddings, compute how much attention to place on each vector by comparing a query with "key" vectors computed from the embeddings
* compute the "answer" to our query by weighting the individual values by their attention weights


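$$
o_{ik} = \sum_j \mathrm{softmax}_j\left(\frac{q_{i}\cdot k_{j}}{\sqrt{d}}\right) v_{jk}
$$

Here the softmax is taken over the key index $j$, $d$ is the dimension of the queries and keys, and $o_{ik}$ is component $k$ of the attention output for query $i$.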
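To make the indices in the formula concrete, here is a minimal sketch that evaluates it once with explicit loops over $i$ and $j$ and once in the matrix form $\mathrm{softmax}(QK^\top/\sqrt{d})\,V$. The output is written as $o_{ik}$ to keep it apart from the input values $v_{jk}$; the function names and shapes are illustrative choices, not part of the exercise.

```python
import torch

def attention_loops(q, k, v):
    """Evaluate o_ik = sum_j softmax_j(q_i . k_j / sqrt(d)) v_jk with explicit loops.

    q: (n_q, d) queries, k: (n_k, d) keys, v: (n_k, d_v) values.
    """
    d = q.shape[1]
    out = torch.zeros(q.shape[0], v.shape[1])
    for i in range(q.shape[0]):
        # scaled scores of query i against every key j
        scores = torch.stack([q[i] @ k[j] / d ** 0.5 for j in range(k.shape[0])])
        weights = torch.softmax(scores, dim=0)   # normalize over the key index j
        for j in range(k.shape[0]):
            out[i] += weights[j] * v[j]          # sum over j, all components k at once
    return out

def attention_matrix(q, k, v):
    # the same formula in matrix form: softmax(Q K^T / sqrt(d)) V
    d = q.shape[1]
    return torch.softmax(q @ k.T / d ** 0.5, dim=-1) @ v

torch.manual_seed(0)
q, k, v = torch.randn(2, 5), torch.randn(10, 5), torch.randn(10, 3)
print(torch.allclose(attention_loops(q, k, v), attention_matrix(q, k, v), atol=1e-6))
```

The loop version is only for reading the formula; the matrix version is what you would use in practice.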


## Preparation
Before we bring all of this to life in a class, we go through each step in order to understand it properly.

We will implement the encoder part of QKV self-attention by replicating the procedure shown below:


<img src="Grafiken/text1.svg" alt="Attention">

### Generating Data
Write a data-generating function that produces a batch of N length-10 sequences of random integers between 0 and 9, together with a binary label indicating whether each sequence has more 2s than 4s. The function should return `(X,y)`, where `X` has shape `(N,10)` and `y` has shape `(N,1)`. Since the labels are later passed to `binary_cross_entropy`, `y` should be a float tensor.

You can use `torch.randint()` to create the batch.

```python
def make_batch(N):
    ...
```
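One possible way to fill in the stub, as a sketch: counting with boolean masks is a choice made here, and `.float().unsqueeze(1)` is there so that `y` has the float dtype and `(N,1)` shape that `binary_cross_entropy` expects.

```python
import torch

def make_batch(N):
    # N sequences of length 10 with integers drawn uniformly from 0..9
    X = torch.randint(0, 10, (N, 10))
    # label: 1.0 if the sequence contains more 2s than 4s, else 0.0
    y = ((X == 2).sum(dim=1) > (X == 4).sum(dim=1)).float().unsqueeze(1)
    return X, y

X, y = make_batch(5)
print(X.shape, y.shape)
```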

### Embedding the Integers
Deep learning works well in higher dimensions, so we'll embed the 10 possible integers into a vector space using `torch.nn.Embedding`.


* Verify that a module like `torch.nn.Embedding(10,embed_dim)` achieves this
* Take a random tensor of integers of shape `(N,M)` and pass it through an embedding module
* Does the output shape make sense?
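A quick shape check, as a sketch; `embed_dim = 5` is an assumed value chosen to match the model further below.

```python
import torch

embed_dim = 5                                # assumed value, matches AttentionModel below
embed = torch.nn.Embedding(10, embed_dim)    # one learnable vector per integer 0..9

x = torch.randint(0, 10, (3, 10))            # (N, M) = (3, 10) integer tensor
e = embed(x)
print(e.shape)                               # torch.Size([3, 10, 5])
```

Each integer is replaced by its own `embed_dim`-dimensional vector (a row of `embed.weight`), so a trailing axis of size `embed_dim` is appended to the input shape.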

**Alternatively:**


One can also embed the integers into a vector space by one-hot encoding them and sending the result through a `torch.nn.Linear(one_hot_dim, embed_dim)` layer.
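A minimal sketch of the one-hot alternative, assuming `one_hot_dim = 10` and `embed_dim = 5`. It also checks why the two approaches are equivalent: multiplying a one-hot vector by the weight matrix just selects one column, so the result is an embedding lookup plus the layer's bias.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
one_hot_dim, embed_dim = 10, 5
lin = torch.nn.Linear(one_hot_dim, embed_dim)

x = torch.randint(0, 10, (3, 10))                      # (N, M) integers
x_oh = F.one_hot(x, num_classes=one_hot_dim).float()   # (N, M, 10); Linear needs floats
e = lin(x_oh)                                          # (N, M, embed_dim)

# Equivalent lookup: row x of weight^T (i.e. column x of weight) plus the bias
lookup = lin.weight.T[x] + lin.bias
print(e.shape, torch.allclose(e, lookup, atol=1e-6))
```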



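### Extracting Keys and Values

Once the data is embedded, we can extract keys and values by a linear projection.

* Using two linear layers such as `torch.nn.Linear(embed_dim,att_dim)`, extract keys and values from the output of the previous step (in the module further below, the values are instead projected to a single number per position with `torch.nn.Linear(embed_dim,1)`)
* Verify that this works from a shape perspective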
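A shape check for the two projections, as a sketch with assumed sizes `N=4`, `M=10`, `embed_dim=att_dim=5`. `torch.nn.Linear` acts on the last axis only, so the batch and position axes pass through unchanged.

```python
import torch

torch.manual_seed(0)
N, M, embed_dim, att_dim = 4, 10, 5, 5
embed = torch.nn.Embedding(10, embed_dim)
WK = torch.nn.Linear(embed_dim, att_dim)     # key projection
WV = torch.nn.Linear(embed_dim, att_dim)     # value projection

e = embed(torch.randint(0, 10, (N, M)))      # (N, M, embed_dim)
keys = WK(e)                                 # (N, M, att_dim)
vals = WV(e)                                 # (N, M, att_dim)
print(e.shape, keys.shape, vals.shape)
```

With `torch.nn.Linear(embed_dim,1)` for the values, as in `AttentionModel`, `vals` would have shape `(N, M, 1)` instead.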

## Computing Attention
<img src="Grafiken/text1.svg" alt="Attention">
Implement the attention formula from above in a batched manner, such that an input set of sequences of shape `(N,10)` yields an output set of attention-weighted values of shape `(N,1)`.

* It's easiest to use `torch.einsum`, which implements the Einstein summation convention you may be familiar with from special relativity
* e.g. a "batched" dot product is performed using `einsum('bik,bjk->bij')`, where `b` is the batch index, `i` and `j` are position indices and `k` indexes the vector components
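A small sketch of the batched dot product with random tensors (shapes chosen arbitrarily), compared against the equivalent batched matrix product `torch.bmm`:

```python
import torch

torch.manual_seed(0)
A = torch.randn(2, 3, 4)   # (b, i, k)
B = torch.randn(2, 6, 4)   # (b, j, k)

# out[b, i, j] = sum_k A[b, i, k] * B[b, j, k]; k is absent from the output, so it is summed
out = torch.einsum('bik,bjk->bij', A, B)   # (2, 3, 6)

# the same result via a batched matrix product with B transposed in its last two axes
ref = torch.bmm(A, B.transpose(1, 2))
print(out.shape, torch.allclose(out, ref, atol=1e-5))
```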


--> **Some hints**
* Keep in mind that the $d$ in the $\sqrt{d}$ inside the softmax is your `att_dim`
* Initialize your query randomly with shape `(1,att_dim)`
* query and keys: einsum --> `'ik,bjk->bij'`
* for the softmax use `dim=-1`
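Putting the hints together, here is a sketch of the batched computation with a single query. The keys and values are random stand-ins for `WK(embed(x))` and `WV(embed(x))`, and the aggregation string `'bij,bjk->bik'` is one possible choice, not prescribed by the exercise.

```python
import torch

torch.manual_seed(0)
N, M, att_dim = 4, 10, 5
query = torch.randn(1, att_dim)      # one query shared by all sequences
keys = torch.randn(N, M, att_dim)    # stand-in for WK(embed(x))
vals = torch.randn(N, M, 1)          # stand-in for WV(embed(x)), one value per position

scores = torch.einsum('ik,bjk->bij', query, keys) / att_dim ** 0.5   # (N, 1, M)
att = torch.softmax(scores, dim=-1)                                # normalize over the M positions
out = torch.einsum('bij,bjk->bik', att, vals)                      # (N, 1, 1) weighted sum of values
out = out.reshape(N, 1)                                            # (N, 1), as required
print(att.shape, out.shape)
```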

# Integrate into a Module

Complete the following torch Module:

To feed the attention-weighted value into `self.nn`, make sure it has shape `torch.Size([batch, 1])`, since the first layer of `self.nn` is `torch.nn.Linear(1,200)`.

* For the `forward(x)` function, have a look at the graph in 1.2 again and follow along
```python
import torch

class AttentionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = 5
        self.att_dim = 5
        self.embed = torch.nn.Embedding(10,self.embed_dim)
        
        #one query
        self.query  = torch.nn.Parameter(torch.randn(1,self.att_dim))
        
        #used to compute keys
        self.WK = torch.nn.Linear(self.embed_dim,self.att_dim)
        
        #used to compute values (a single number per position)
        self.WV = torch.nn.Linear(self.embed_dim,1)
        
        #final decision based on attention-weighted value
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(1,200),
            torch.nn.ReLU(),
            torch.nn.Linear(200,1),
            torch.nn.Sigmoid(),
        )

    def attention(self,x):
        # compute attention
        # x: embedded input (N, 10, embed_dim) -> attention weights (N, 1, 10)
        ...
    
    def values(self,x):
        # compute values
        # x: embedded input (N, 10, embed_dim) -> values (N, 10, 1)
        ...
                
    def forward(self,x):
        # compute final classification using attention, values and final NN
        # x: integer sequences (N, 10) -> probabilities (N, 1)
        ...
```
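For reference, here is one possible completion of the stubs, as a sketch. The return shapes follow what the plotting function below indexes (`attention(...)[:,0,:]` and `values(...)[:,:,0]`); the einsum strings are choices made here, and this is not claimed to be the lecture's reference solution.

```python
import torch

class AttentionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = 5
        self.att_dim = 5
        self.embed = torch.nn.Embedding(10, self.embed_dim)
        self.query = torch.nn.Parameter(torch.randn(1, self.att_dim))   # one learned query
        self.WK = torch.nn.Linear(self.embed_dim, self.att_dim)          # keys
        self.WV = torch.nn.Linear(self.embed_dim, 1)                     # one value per position
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(1, 200),
            torch.nn.ReLU(),
            torch.nn.Linear(200, 1),
            torch.nn.Sigmoid(),
        )

    def attention(self, x):
        # x: embedded input (N, 10, embed_dim) -> attention weights (N, 1, 10)
        keys = self.WK(x)
        scores = torch.einsum('ik,bjk->bij', self.query, keys) / self.att_dim ** 0.5
        return torch.softmax(scores, dim=-1)

    def values(self, x):
        # x: embedded input (N, 10, embed_dim) -> values (N, 10, 1)
        return self.WV(x)

    def forward(self, x):
        # x: integer sequences (N, 10) -> probabilities (N, 1)
        e = self.embed(x)
        weighted = torch.einsum('bij,bjk->bik', self.attention(e), self.values(e))  # (N, 1, 1)
        return self.nn(weighted.reshape(-1, 1))

model = AttentionModel()
print(model(torch.randint(0, 10, (3, 10))).shape)
```

Together with a `make_batch` implementation, this class is shaped to fit the plotting and training cells below; how quickly it learns the task is not something this sketch guarantees.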

## Predefine Plot Function
Use the following function later on to visualize your training.
Just execute the following cell.

In [None]:
def plot(model,N,traj):
    x,y = make_batch(N)
    f,axarr = plt.subplots(1,3)
    f.set_size_inches(10,2)

    # left panel: attention weights of the single query over the 10 positions, shape (N, 10)
    at = model.attention(model.embed(x))[:,0,:].detach().numpy()
    axarr[0].imshow(at)

    # middle panel: values, shown only where the attention weight exceeds 0.1
    vals = model.values(model.embed(x))[:,:,0].detach().numpy()
    masked = np.where(at > 0.1, vals, np.nan)
    axarr[1].imshow(masked,vmin = -1, vmax = 1)

    # write the input integers onto both panels, highlighting 2s and 4s in red
    for i,xx in enumerate(x):
        for j,xxx in enumerate(xx):
            color = 'r' if int(xxx) in (2,4) else 'w'
            axarr[0].text(j,i,str(int(xxx)), c = color)
            axarr[1].text(j,i,str(int(xxx)), c = color)

    # right panel: training loss trajectory
    axarr[2].plot(traj)
    f.tight_layout()


## Train the Model

In [None]:
def train():
    model = AttentionModel()
    opt = torch.optim.Adam(model.parameters(),lr = 1e-4)

    traj = []
    for i in range(5001):
        x,y = make_batch(100)
        p = model(x)                  # predicted probabilities, shape (100, 1)
        # y must be a float tensor of shape (100, 1) for binary cross-entropy
        loss = torch.nn.functional.binary_cross_entropy(p,y)
        loss.backward()
        traj.append(loss.item())
        if i % 500 == 0:
            plot(model,5,traj)
            # plt.savefig('attention_{}.png'.format(str(i).zfill(6)))
            plt.show()
            print(i,loss.item())
        opt.step()
        opt.zero_grad()
    return traj


In [None]:
training = train()