In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# x is a tensor of shape (b, t, e)
def self_attention(x: torch.Tensor) -> torch.Tensor:
    """Apply basic, parameter-free self-attention to a batch of sequences.

    Each output vector is a weighted average of all input vectors of the same
    sequence, where the weights are the softmax-normalized dot products
    between the input vectors.

    Args:
        x: Tensor of shape (b, t, e) -- batch size, sequence length,
            embedding size.

    Returns:
        Tensor of shape (b, t, e).
    """
    # Raw weights w'_ij = x_i . x_j, shape (b, t, t)
    raw_weights = torch.bmm(x, x.transpose(1, 2))

    # Normalize so that, for every i, the weights over j sum to one
    weights = F.softmax(raw_weights, dim=2)

    # Batched matrix multiply: (b, t, t) x (b, t, e) -> (b, t, e)
    return torch.bmm(weights, x)

In [None]:
def self_attention_2(x: torch.Tensor) -> torch.Tensor:
    """Work-in-progress: self-attention with queries, keys and values.

    Planned computation (not implemented yet):
        q_i = W_q x_i,  k_i = W_k x_i,  v_i = W_v x_i
        w'_ij = q_i^T k_j / sqrt(e)        (TODO: decide what e is)
        w     = softmax(w')
        y_i   = sum_j w_ij v_j

    Until the projections exist, the inputs themselves play the role of
    queries, keys and values, and no scaling is applied.

    Args:
        x: 3-D tensor; batched matrix multiplication rejects anything else.

    Returns:
        Tensor with the same shape as `x`.
    """
    # Pairwise dot products between all vectors of a sequence
    scores = x.bmm(x.transpose(1, 2))

    # Turn each row of scores into a probability distribution
    attn = scores.softmax(dim=2)

    # Weighted sum of the input vectors
    return attn.bmm(x)

# Multi-Headed Self-Attention with Query, Key, and Value

Assume we have a sequence of `e` embedding vectors, each of dimension `d`, stored in a matrix named `x`

`x` :: [e, d]

We have `h` attention heads. Each attention head will:

Use a $W_q$, $W_k$, and $W_v$ weight matrix to calculate Queries, Keys, and Values.

The $W_q$, $W_k$, and $W_v$ weight matrices are each of size [d, d], so projecting one input vector costs about $3d^2$ multiplications

Using `h` attention heads would slow down computation by a factor of `h`, unless each head maps its input to a smaller vector. We can use a projection from $d \rightarrow \frac{d}{h}$ dimensions for each head's $W_q$, $W_k$, and $W_v$ matrices (for example $\frac{d}{4}$ when `h = 4`). We now get q, k, and v vectors with dimensions [d/h].

Now, we can run the normal process.

$w'_{ij} = \frac{q_i^T \cdot k_j}{\sqrt{d/h}}$

$w_{ij} = \mathrm{softmax}_j(w'_{ij})$

$y_i = \sum_j w_{ij} v_j$

At the end, we concatenate the outputs of the `h` heads back into a `d`-dimensional vector and use a [d, d] matrix to unify the heads

In [4]:
## Now a class to implement all this math
##
## `e`: the number of embeddings dimensions
## `heads`: the number of heads to use
##
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three bias-free
    linear layers of size (e, e). Each projection is cut into `heads` chunks
    of size e // heads, attention is computed independently per chunk, and the
    concatenated head outputs are passed through one more (e, e) linear layer.

    Args:
        e: Number of embedding dimensions; must be divisible by `heads`.
        heads: Number of attention heads.
    """

    def __init__(self, e, heads=4):
        super().__init__()

        # e must be divisible by the number of heads
        assert e % heads == 0, f"embedding size {e} is not divisible by {heads} heads"

        self.e, self.heads = e, heads

        # Create the weight matrices
        self.tokeys = nn.Linear(e, e, bias=False)
        self.toqueries = nn.Linear(e, e, bias=False)
        self.tovalues = nn.Linear(e, e, bias=False)

        # Mixes the concatenated head outputs back together
        self.unifyheads = nn.Linear(e, e, bias=False)

    def forward(self, x):
        """Compute multi-head self-attention over `x`.

        Args:
            x: Tensor of shape (b, t, e) -- batch size, sequence length,
                embedding size.

        Returns:
            Tensor of shape (b, t, e).
        """
        b, t, e = x.size()
        h = self.heads

        queries = self.toqueries(x)
        keys    = self.tokeys(x)
        values  = self.tovalues(x)

        # split the keys, queries, and values into h groups of size s

        s = e // h

        keys = keys.view(b, t, h, s)
        queries = queries.view(b, t, h, s)
        values = values.view(b, t, h, s)

        # merge the heads into the batch dimension so that a single bmm call
        # handles every head of every batch element

        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # calculate the raw weights
        # (b*h, t, t)
        dot = torch.bmm(queries, keys.transpose(1, 2))

        # scale the weights by the square root of the per-head dimension:
        # each dot product sums over s terms, not over e
        dot = dot / (s ** (1 / 2))

        # normalize the weights
        dot = F.softmax(dot, dim=2)

        # apply the weights to the values
        out = torch.bmm(dot, values).view(b, h, t, s)

        # transpose the heads back out of the batch dimension and concatenate
        # them along the embedding dimension
        out = out.transpose(1, 2).contiguous().view(b, t, e)

        # unify the heads
        return self.unifyheads(out)



# Transformers

## Definition

Any architecture designed to process a connected set of units—such as the tokens in a sequence or the pixels in an image—where the only interaction between units is through self-attention.

## Approach

Wrap the self attention into a repeatable block

- Transformer Block
  - Self Attention
    - Residual connections go around
  - Layer Norm
  - Feed Forward Layer
    - n* MLP
    - Residual connections go around
  - Layer Norm




In [5]:
class TransformerBlock(nn.Module):
    """One transformer block: self-attention followed by a feed-forward network.

    Each of the two sub-layers is wrapped in a residual connection and
    followed by layer normalization over the embedding dimension.

    Args:
        e: Number of embedding dimensions.
        heads: Number of attention heads handed to `SelfAttention`.
    """

    def __init__(self, e, heads):
        super().__init__()

        self.attention = SelfAttention(e, heads=heads)

        # Two normalizing layers over the embedding dimension
        self.norm1 = nn.LayerNorm(e)
        self.norm2 = nn.LayerNorm(e)

        # Feed-forward network whose hidden layer is 4x wider than the
        # embedding; it projects back to e so the residual sum is well-defined
        self.ff = nn.Sequential(
            nn.Linear(e, 4 * e),
            nn.ReLU(),
            nn.Linear(4 * e, e)
        )

    def forward(self, x):
        """Run the block on `x` and return a tensor of the same shape.

        Args:
            x: Tensor whose last dimension has size `e`; it is passed to the
                attention layer unchanged.
        """
        # Attention sub-layer: residual connection, then layer norm
        attended = self.attention(x)
        x = self.norm1(attended + x)

        # Feed-forward sub-layer: residual connection, then layer norm
        fedforward = self.ff(x)
        return self.norm2(fedforward + x)