In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns

# Seed torch's global RNG so the randomly initialised layers created below
# (embeddings, linear projections) and dropout are repeatable across runs.
torch.manual_seed(42)

<torch._C.Generator at 0x1f7b8d9b910>

In [2]:
# Model hyperparameters shared by every cell below.
VOCAB_SIZE = 3  # number of distinct token ids; depends on the task
HEADS = 8  # number of attention heads
EMBED_DIM = 512  # width of each token embedding
# Each head projects EMBED_DIM down to HEAD_DIM. Floor division would silently
# drop dimensions if the split were not exact, so fail loudly instead.
assert EMBED_DIM % HEADS == 0, "EMBED_DIM must be divisible by HEADS"
HEAD_DIM = EMBED_DIM // HEADS  # per-head width (512 // 8 == 64)
D_PROB = 0.1  # dropout probability

In [3]:
# Two independent lookup tables mapping token ids to EMBED_DIM-wide vectors.
# They are created in this order so the seeded initialisation stays the same.
input_emb = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
output_emb = nn.Embedding(VOCAB_SIZE, EMBED_DIM)

In [4]:
input_emb

Embedding(3, 512)

In [5]:
# Toy batch: 2 sequences of 3 token ids each (every id is < VOCAB_SIZE).
test = torch.tensor(
    [
        [0, 1, 2],
        [2, 1, 0],
    ],
    dtype=torch.long,
)
# input_emb(test)

In [6]:
# output_emb(test)

In [7]:
# Embed the toy batch: (2, 3) token ids -> (2, 3, EMBED_DIM) float vectors.
# `do` is the input fed to every attention / block experiment below.
do = input_emb(test)

## Head

In [8]:
# Prototype of a single masked self-attention head, written out step by step.

# Bias-free projections from the embedding width down to the per-head width.
key = nn.Linear(in_features=EMBED_DIM, out_features=HEAD_DIM, bias=False)
query = nn.Linear(in_features=EMBED_DIM, out_features=HEAD_DIM, bias=False)
value = nn.Linear(in_features=EMBED_DIM, out_features=HEAD_DIM, bias=False)

dropout = nn.Dropout(D_PROB)

# `do` is (batch, tokens, EMBED_DIM); t is the number of tokens per sequence.
_, t, _ = do.size()
k = key(do)
q = query(do)
v = value(do)

# Scaled dot-product attention divides by sqrt(d_k), where d_k is the width of
# the key/query vectors (HEAD_DIM) -- not the embedding width of the input.
d_k = k.size(-1)
qk = (q @ k.transpose(-1, -2)) / d_k ** 0.5  # (batch, t, t) attention scores

# Causal mask: lower-triangular with the default diagonal=0, so position i sees
# positions <= i only. It is sized by the sequence length, so it does not
# depend on HEAD_DIM.
# register_buffer is an nn.Module method and only applies once this logic lives
# inside a module class; a plain tensor is enough for this prototype.
mask = torch.tril(torch.ones(t, t))
qk = qk.masked_fill(mask == 0, float('-inf'))

# Softmax over the last axis (the keys): for every query position the weights
# over the positions it may attend to sum to 1. The masked -inf scores become 0.
qk = F.softmax(qk, dim=-1)
attn = qk @ v  # (batch, t, HEAD_DIM)

# dropout occurs at the end of the sublayer, before the residual connection
# and layer normalisation are applied
attn = dropout(attn)

In [9]:
from model import Head

In [10]:
# Instantiate the Head class imported from model.py; it is only constructed
# here, not called, in the visible cells.
# NOTE(review): the third positional argument (True) presumably enables the
# causal mask -- confirm against model.Head's signature.
head = Head(HEAD_DIM, EMBED_DIM, True)

## MultiHeadAttention

In [11]:
from model import MultiHeadAttention

In [12]:
# Self-attention: the same embedded batch is passed as all three inputs.
# NOTE(review): the meaning/order of the three arguments (query/key/value) is
# defined in model.MultiHeadAttention -- confirm there.
mha = MultiHeadAttention(HEADS, HEAD_DIM, EMBED_DIM, mask=True)
multi = mha(do, do, do)

In [13]:
# Sanity check on the attention output. The recorded result is (False, True):
# the values differ from the input embeddings while the shape is unchanged.
torch.all(multi == do).item(), multi.size() == do.size() # expect (False, True)

(False, True)

## FeedForward

In [14]:
from model import FeedForward

In [15]:
# Build the FeedForward module from model.py and apply it to the attention
# output; the cell displays the resulting shape.
ff = FeedForward(EMBED_DIM)
ff(multi).size()

torch.Size([2, 3, 512])

In [16]:
# Shape of the attention output, for comparison with the FeedForward output
# shape displayed by the previous cell.
multi.size()

torch.Size([2, 3, 512])

## Encoder Block

In [17]:
from model import EncoderBlock

In [18]:
# Build one EncoderBlock from model.py and run the embedded batch through it.
eb = EncoderBlock(HEADS, EMBED_DIM)
enc = eb(do)

In [19]:
# Display the shape of the encoder block's output.
enc.size()

torch.Size([2, 3, 512])

## Decoder Block

In [20]:
from model import DecoderBlock

In [22]:
# Build one DecoderBlock from model.py and call it with the embedded batch and
# the encoder block's output; the cell displays the resulting shape.
# NOTE(review): presumably the arguments are (decoder input, encoder output) --
# confirm against model.DecoderBlock.
db = DecoderBlock(HEADS, EMBED_DIM)
db(do, enc).size()

torch.Size([2, 3, 512])

## Positional Encodings