# Self-Attention & Multi-Head Attention

### Loading Libraries

In [1]:
# Move the kernel's working directory two levels up from where it started.
# NOTE(review): presumably this lands on the project root so that the `src.*`
# imports and the relative "imgs/..." path below resolve — confirm.
# NOTE(review): not idempotent — re-running this cell climbs two more levels.
%cd ../..

/Users/joaquinromero/Desktop


In [None]:
# Standard Library
import math
import os
import shutil
from itertools import cycle
from pathlib import Path

# Numerical Computing
import numpy as np
import torch

# Data Manipulation
import pandas as pd

# Persistence
import joblib

# Data Visualization
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go

# Progress Bars
from tqdm.autonotebook import tqdm

# IPython
from IPython.display import HTML, Markdown, display

# NOTE(review): later cells call `tsensor.explain` / `tsensor.astviz`, but
# `tsensor` is not imported anywhere in the visible notebook. Add
# `import tsensor` here once the dependency is confirmed for this environment.

In [None]:
# Custom Libraries
from src.utils import plotting_utils
from src.utils import ts_utils_updated
from src.forecasting.ml_forecasting import calculate_metrics

from src.forecasting.ml_forecasting import (
    MissingValueConfig,
    calculate_metrics,
)

In [None]:
# %load_ext autoreload

# %autoreload 2

In [None]:
# Enable tqdm's progress-bar integration for pandas.
tqdm.pandas()

# Seed every random number generator the notebook draws from. The tensors in
# the attention walkthrough come from `torch.randn`, which NumPy's seed does
# not control, so PyTorch is seeded as well to make the outputs reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the white Plotly theme for every figure created below.
pio.templates.default = "plotly_white"

In [None]:
# Folder for this chapter's figures; created on first run, reused afterwards.
Path("imgs/chapter_14").mkdir(parents=True, exist_ok=True)

# Both data folders live under the same dataset directory on the Desktop.
_dataset_root = Path.home().joinpath("Desktop", "data", "london_smart_meters")

preprocessed = _dataset_root / "preprocessed"

output = _dataset_root / "output"

In [None]:
def format_plot(fig, legends=None, xlabel="Time", ylabel="Value", title="", font_size=15):
    """Apply the notebook's common layout to a Plotly figure.

    Args:
        fig: Figure to restyle. It is updated through its own
            ``for_each_trace`` / ``update_layout`` methods.
        legends: Optional sequence of trace names. Names are assigned to the
            traces in order and recycled when there are more traces than names.
        xlabel: Title of the x axis.
        ylabel: Title of the y axis.
        title: Figure title, centred horizontally.
        font_size: Font size used for legend entries, axis titles and tick labels.

    Returns:
        The same ``fig`` object that was passed in.
    """
    if legends:
        names = cycle(legends)
        fig.for_each_trace(lambda trace: trace.update(name=next(names)))

    # Fonts are set through the nested `title.font` attribute. The flat
    # `titlefont` shortcut is deprecated and rejected by newer Plotly releases.
    fig.update_layout(
        autosize=False,
        width=900,
        height=500,
        title=dict(
            text=title,
            x=0.5,
            xanchor="center",
            yanchor="top",
            font=dict(size=20),
        ),
        legend_title=None,
        legend=dict(
            font=dict(size=font_size),
            orientation="h",
            yanchor="bottom",
            y=1.0,
            xanchor="right",
            x=1,
        ),
        yaxis=dict(
            title=dict(text=ylabel, font=dict(size=font_size)),
            tickfont=dict(size=font_size),
        ),
        xaxis=dict(
            title=dict(text=xlabel, font=dict(size=font_size)),
            tickfont=dict(size=font_size),
        ),
    )
    return fig

In [None]:
# Sizes used throughout the walkthrough:
#   embedding_dim - width of each row of the input `sentence` tensor
#   attn_dim      - width of the query/key/value projections
#   seq_len       - number of rows (positions) in `sentence`
embedding_dim, attn_dim, seq_len = 512, 64, 10

### Embedding Representation of a Sentence

In [None]:
# Stand-in for an embedded sentence: a random tensor with seq_len rows and
# embedding_dim columns.
sentence = torch.randn(seq_len, embedding_dim)
# NOTE(review): tsensor.explain presumably renders the statements run inside
# the block with their tensor dimensions — confirm against the tsensor docs.
with tsensor.explain(fontsize=20, dimfontsize=12):
    sentence

## Self Attention

#### `Step 1`: Define three learnable matrices, one each for query, key, and value

In [None]:
# One random (embedding_dim, attn_dim) projection matrix per role. The
# generator is consumed left to right, so the draws happen in the order
# query, key, value.
w_q, w_k, w_v = (torch.randn(embedding_dim, attn_dim) for _ in range(3))

with tsensor.explain(fontsize=20, dimfontsize=12):
    w_q
    w_k
    w_v

#### `Step 2`: Convert input embedding into attention dimensions using `W_q, W_k, W_v`

In [None]:
# Project the sentence into query, key and value space. Each product is
# (seq_len, embedding_dim) @ (embedding_dim, attn_dim) -> (seq_len, attn_dim).
with tsensor.explain(fontsize=20, dimfontsize=12):
    q = sentence @ w_q
    k = sentence @ w_k
    v = sentence @ w_v

#### `Step 3`: Calculate the attention weights between all the query and key pairs

In [None]:
# Scaled dot-product attention divides the raw scores by sqrt(attn_dim).
# `scaling` holds the reciprocal, so the scores are multiplied by it below.
scaling = 1/math.sqrt(attn_dim)

# Scores compare every query row with every key row:
# (seq_len, attn_dim) @ (attn_dim, seq_len) -> (seq_len, seq_len).
with tsensor.explain(fontsize=20, dimfontsize=12):
    attn_scores = q @ k.t()
# Softmax over the last axis turns each query's scores into weights summing to 1.
attn_weights = torch.softmax(attn_scores * scaling, dim=-1)

#### `Step 4`: Use the attention weights to combine the values

In [None]:
# Combine the value rows using the attention weights:
# (seq_len, seq_len) @ (seq_len, attn_dim) -> (seq_len, attn_dim).
with tsensor.explain(fontsize=20, dimfontsize=12):
    attn_output = attn_weights @ v

#### Putting it All Together

In [None]:
# Show the input tensor and the three projection matrices again, as a recap of
# everything that feeds the attention computation.
with tsensor.explain(fontsize=20, dimfontsize=12):
    sentence
    w_q
    w_k
    w_v

In [None]:
# Visualise the whole single-head pipeline as one expression. The string is a
# simplified form: it contains only the matrix products and leaves out the
# scaling and softmax steps.
# NOTE(review): astviz presumably draws the expression's syntax tree — confirm.
tsensor.astviz("attn_output = ((sentence @ w_q) @ (sentence @ w_k).t()) @ (sentence @ w_v)")

## Multi-Headed Self Attention

In [None]:
# Number of attention heads; attn_dim is integer-divided by this value further
# down to obtain the per-head width.
n_heads = 8

#### `Step 1`: Define three learnable matrices, one each for query, key and value

In [None]:
# Fresh random (embedding_dim, attn_dim) projections for the multi-head
# walkthrough; these replace the matrices of the same name defined earlier.
# The generator is consumed left to right: query, key, value.
w_q, w_k, w_v = (torch.randn(embedding_dim, attn_dim) for _ in range(3))

with tsensor.explain(fontsize=20, dimfontsize=12):
    w_q
    w_k
    w_v

#### `Step 2`: Convert input embedding into attention dimensions using `W_q, W_k, W_v`

In [None]:
# Project the sentence with the multi-head weight matrices. Each result is
# (seq_len, attn_dim) and overwrites the q, k, v from the single-head section.
with tsensor.explain(fontsize=20, dimfontsize=12):
    q = sentence @ w_q
    k = sentence @ w_k
    v = sentence @ w_v

#### `Step 3`: Reshape q, k and v into sub_dim for each head in the multi-headed attention

In [None]:
# Every head receives an equal slice of the attention dimension, so the split
# must be exact; otherwise the reshape below fails with a less obvious error.
assert attn_dim % n_heads == 0, "attn_dim must be divisible by n_heads"
sub_dim = attn_dim//n_heads

# (seq_len, attn_dim) -> (seq_len, n_heads, sub_dim) -> (n_heads, seq_len, sub_dim).
# NOTE: q, k and v are overwritten here, so re-run the projection cell before
# re-running this one.
with tsensor.explain(fontsize=20, dimfontsize=12):
    q = q.reshape(seq_len, n_heads, sub_dim).permute(1,0,2)
shape_str = " x ".join(str(i) for i in q.size())
# Rendered with IPython's Markdown class; the previous `md` name was never defined.
display(Markdown(f"""
-----
### q, k, v dimensions: n_heads, seq_len, sub_dim : {shape_str}
"""))
# Same transformation for k and v
k = k.reshape(seq_len, n_heads, sub_dim).permute(1,0,2)
v = v.reshape(seq_len, n_heads, sub_dim).permute(1,0,2)

##### `q, k, v` dimensions: n_heads, seq_len, sub_dim : `8 x 10 x 8`

#### `Step 4`: Calculate the attention weights between all the query and key pairs for each head