In [87]:
import torch as t
from torch import nn
import plotly.express as px
from IPython.display import display
import pandas as pd
import numpy as np
import copy
from fancy_einsum import einsum

from einops import rearrange, reduce, repeat

import utils
import cnn_modules as cm

In [2]:
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (batch, seq_len, embed_size)
    K: shape (batch, seq_len, embed_size)
    V: shape (batch, seq_len, embed_size)

    Return: shape (batch, seq_len, embed_size)
    '''
    scores = Q @ t.transpose(K, -2, -1)
    scores /= Q.shape[-1] ** 0.5
    scores = t.softmax(scores, dim=-1)
    Z = einsum('B Seq Score, B Seq Emb -> B Seq Emb', scores, V)
    return Z
    

In [None]:
def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Should return the results of masked self-attention.

    See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.

    Q: shape (batch, seq_len, embed_size)
    K: shape (batch, seq_len, embed_size)
    V: shape (batch, seq_len, embed_size)

    Return: shape (batch, seq_len, embed_size)
    '''
    batch_size, seq_len, embed_size = Q.shape
    scores = Q @ t.transpose(K, -2, -1)
    scores /= Q.shape[-1] ** 0.5
    
    # create lower-left triangle of ones, including the diagonal
    mask = t.tril(t.ones(seq_len, seq_len), diagonal=0)
    # fill with close-to-neg-inf values
    scores = scores.masked_fill(mask==0, -1e9)
    
    scores = t.softmax(scores, dim=-1)
    Z = einsum('B Seq Score, B Seq Emb -> B Seq Emb', scores, V)
    return Z
    

In [85]:
def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
    '''
    Implements multihead masked attention on the matrices Q, K and V.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)

    returns: shape (batch, seq, nheads*headsize)
    '''
    Q = rearrange(Q, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
    K = rearrange(K, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
    V = rearrange(V, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)

    batch_size, seq_len, nheads, headsize = Q.shape
    scores = einsum('B Qseq nheads headsize, B Kseq nheads headsize -> B nheads Qseq Kseq', Q, K)
    scores /= Q.shape[-1] ** 0.5

    # create lower-left triangle of ones, including the diagonal
    mask = t.tril(t.ones(seq_len, seq_len), diagonal=0)
    # fill with close-to-neg-inf values where mask==0
    scores = scores.masked_fill(mask==0, -1e9)

    scores = t.softmax(scores, dim=-1)
    Z = einsum('B nheads Qseq Kseq, B Qseq nheads headsize -> B Qseq nheads headsize', scores, V)
    Z = rearrange(Z, 'B Qseq nheads headsize -> B Qseq (nheads headsize)')
    return Z

In [80]:
T = t.randn(2, 10, 4)

In [84]:
multihead_masked_attention(T, T, T, 2)

tensor([[[ 1.8941, -0.4399,  0.2117,  2.8757],
         [ 0.6109, -0.6800, -0.3416, -2.5958],
         [-0.6437,  3.2413, -0.4565,  0.2686],
         [ 0.0507,  0.1856,  0.2923,  1.2261],
         [-0.6602, -0.6553,  0.2737, -1.5608],
         [-0.0852,  1.3946,  0.0421,  0.4518],
         [ 0.5860, -2.2522, -1.8987, -0.4523],
         [-0.4501,  0.5583,  0.0848,  1.6969],
         [-0.0303,  0.1286, -0.8724,  1.1841],
         [ 0.6496,  0.1996, -0.5088, -0.4013]],

        [[-0.0852,  0.0814,  0.2323,  1.2971],
         [ 0.5716,  0.2464,  0.3229,  0.2581],
         [ 0.8881,  1.0325, -1.0889, -2.5788],
         [ 1.8355, -0.1379, -0.8073,  0.6323],
         [ 0.0549, -0.2533, -1.2703,  0.9032],
         [ 0.4287, -0.3284, -1.0406,  0.1846],
         [ 0.9691, -0.3803,  0.2693,  0.8820],
         [-0.0184,  1.3782, -0.5900, -0.8977],
         [ 0.6414,  0.9656,  0.2465,  1.4997],
         [ 0.2221, -1.4293, -0.3854,  0.0404]]])

In [None]:
class MultiheadMaskedAttention(nn.Module):
    W_QKV: nn.Linear
    W_O: nn.Linear

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.query_size = int(hidden_size / num_heads)
        self.lin = cm.Linear(hidden_size, hidden_size)
        self.qkv = nn.ModuleList([copy.deepcopy(self.lin) for _ in range(3)])
        self.ff = cm.Linear(hidden_size, hidden_size)

    def multihead_masked_attention(self, Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
        '''
        Implements multihead masked attention on the matrices Q, K and V.

        Q: shape (batch, seq, nheads*headsize)
        K: shape (batch, seq, nheads*headsize)
        V: shape (batch, seq, nheads*headsize)

        returns: shape (batch, seq, nheads*headsize)
        '''
        Q = rearrange(Q, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
        K = rearrange(K, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)
        V = rearrange(V, 'B S (nheads headsize) -> B S nheads headsize', nheads = num_heads)

        batch_size, seq_len, nheads, headsize = Q.shape
        scores = einsum('B Qseq nheads headsize, B Kseq nheads headsize -> B nheads Qseq Kseq', Q, K)
        scores /= Q.shape[-1] ** 0.5

        # create lower-left triangle of ones, including the diagonal
        mask = t.tril(t.ones(seq_len, seq_len), diagonal=0)
        # fill with close-to-neg-inf values where mask==0
        scores = scores.masked_fill(mask==0, -1e9)

        scores = t.softmax(scores, dim=-1)
        Z = einsum('B nheads Qseq Kseq, B Qseq nheads headsize -> B Qseq nheads headsize', scores, V)
        Z = rearrange(Z, 'B Qseq nheads headsize -> B Qseq (nheads headsize)')
        return Z

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        x = t.cat([x, x, x], dim=-1) # x reshaped to [BATCH_SIZE x SEQ_LEN x 3 * EMB_DIM]

        # x projected to [BATCH_SIZE x SEQ_LEN x HEADS * QUERY_SIZE], then reshaped to 
        # [BATCH_SIZE x SEQ_LEN x HEADS x QUERY_SIZE], and finally permuted to 
        # [BATCH_SIZE x HEADS x SEQ_LEN x QUERY_SIZE] for all q, k, v
        # TODO:

        z = self.multihead_masked_attention(q, k, v, num_heads=self.num_heads)
        
        # z made contiguous in memory and transformed from [BATCH_SIZE x HEADS x SEQ_LEN x QUERY_SIZE]
        # to [BATCH_SIZE x SEQ_LEN x HEADS * QUERY_SIZE]
        # TODO:

        return self.ff(z)