In [86]:
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

# Transformer Model
Paper: "Attention Is All You Need" (Vaswani et al., 2017) - https://arxiv.org/pdf/1706.03762.pdf

In [68]:
# Seed the global RNG first so the random prompts and the freshly initialised
# layers in the cells below are identical on every Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Toy hyper-parameters shared by the cells below (batch of 5, vocab of 30).
vocab = 30         # number of distinct token ids
embs = 12          # embedding width per token
batch_size = 5     # prompts per batch
window = 16        # tokens per prompt
hidden_nodes = 48  # output width of the Q/K/V projections

In [63]:
# Draw a random batch of token ids, then look them up in a freshly
# initialised (untrained) embedding table.
prompts = torch.randint(low=0, high=vocab, size=(batch_size, window))
emb_prompts = nn.Embedding(num_embeddings=vocab, embedding_dim=embs)(prompts)

# Cell output: the shapes of the raw ids and of their embeddings.
(prompts.shape, emb_prompts.shape)

(torch.Size([5, 16]), torch.Size([5, 16, 12]))

In [64]:
## Thank you! https://jalammar.github.io/illustrated-transformer/
# Sinusoidal positional encoding. Columns come in (sin, cos) pairs that share
# one frequency; for pair number i:
#   column 2i   (even):  sin(pos / 10000**(2i/embs))
#   column 2i+1 (odd):   cos(pos / 10000**(2i/embs))
pos_idxs = torch.arange(0, window).view((-1, 1))  # (window, 1) column of positions
emb_idxs = torch.arange(0, embs).view((1, -1))    # (1, embs) row of column indices

# "2i" for every column: round the column index down to the nearest even
# number so both members of a pair get the same exponent. Using the raw
# column index here would give the sin and cos of a pair different
# frequencies and push the exponent up to ~2 instead of ~1.
pair_exponents = emb_idxs - emb_idxs % 2

angles = pos_idxs / (10000**(pair_exponents / embs))

pos_enc = torch.zeros_like(angles)
pos_enc[:, 0::2] = torch.sin(angles[:, 0::2])
pos_enc[:, 1::2] = torch.cos(angles[:, 1::2])

## Plot the positional embeddings
# plt.pcolormesh(pos_enc, cmap='viridis')
# plt.xlabel('Embedding Dimensions')
# plt.xlim((0, embs))
# plt.ylim((window,0))
# plt.ylabel('Char Position')
# plt.colorbar()
# plt.show()

pos_enc.shape

class PositionalEncoding(nn.Module):
    """Precomputed sinusoidal positional-encoding table.

    Columns come in (sin, cos) pairs that share one frequency; for pair
    number ``i`` and position ``pos``::

        pos_enc[pos, 2i]     = sin(pos / 10000**(2i / d_model))
        pos_enc[pos, 2i + 1] = cos(pos / 10000**(2i / d_model))

    Args:
        d_model: number of columns in the table.
        window: number of positions (rows) in the table.

    Attributes:
        pos_enc: float tensor of shape ``(window, d_model)``, registered as a
            buffer so it follows the module across ``.to(...)`` calls and is
            included in the state dict without being a trainable parameter.

    This class only builds the table; it defines no ``forward``. How the
    table is combined with token embeddings is left to the caller.
    """

    def __init__(self, d_model, window):
        super(PositionalEncoding, self).__init__()

        pos_idxs = torch.arange(0, window).view(-1, 1)   # (window, 1)
        emb_idxs = torch.arange(0, d_model).view(1, -1)  # (1, d_model)

        # "2i" for every column: the column index rounded down to the nearest
        # even number, so both members of a (sin, cos) pair share an exponent.
        pair_exponents = emb_idxs - emb_idxs % 2

        # Width comes from the d_model argument, not from a notebook global.
        angles = pos_idxs / (10000 ** (pair_exponents / d_model))

        pos_enc = torch.zeros_like(angles)
        pos_enc[:, 0::2] = torch.sin(angles[:, 0::2])
        pos_enc[:, 1::2] = torch.cos(angles[:, 1::2])

        # Keep the result; without this the table is discarded when
        # __init__ returns.
        self.register_buffer("pos_enc", pos_enc)

torch.Size([16, 12])

In [65]:
# Give every prompt in the batch its own copy of the (window, embs) table.
# The result gets a new name so `pos_enc` is left untouched and this cell can
# be re-run safely.
pos_enc_batch = pos_enc.unsqueeze(0).repeat(batch_size, 1, 1)

# Concatenate along the feature axis, so each token's vector is
# [embedding | positional encoding].
embedded_prompt = torch.cat((emb_prompts, pos_enc_batch), dim=2)
emb_prompts.shape, pos_enc_batch.shape, embedded_prompt.shape

(torch.Size([5, 16, 12]), torch.Size([5, 16, 12]), torch.Size([5, 16, 24]))

In [91]:
# Scaled Dot-Product Attention: softmax(Q @ K^T / sqrt(d_k)) @ V

print("prompt: " + str(embedded_prompt.shape))

# Project the prompt into queries, values and keys. The input width is read
# from the tensor instead of being hard-coded.
d_in = embedded_prompt.shape[-1]
Q = nn.Linear(d_in, hidden_nodes)(embedded_prompt)
V = nn.Linear(d_in, hidden_nodes)(embedded_prompt)
K = nn.Linear(d_in, hidden_nodes)(embedded_prompt)

# Scaling uses the width of the key vectors.
d_k = K.shape[-1]

print("Q: " + str(Q.shape))
print("V: " + str(V.shape))
print("K: " + str(K.shape))
print("K.transpose(-2, -1): " + str(K.transpose(-2, -1).shape))

# score[b, i, j] is the dot product of query i with key j in batch item b.
score = torch.matmul(Q, K.transpose(-2, -1))

print("score: " + str(score.shape))

gain = score / math.sqrt(d_k)

print("gain: " + str(gain.shape))

# Normalise over the key axis (last dim) so each query's weights sum to 1.
softmax = F.softmax(gain, dim=-1)

print("softmax: " + str(softmax.shape))

# The attention weights mix the value vectors, not the keys.
out = softmax @ V

print("out: " + str(out.shape))

prompt: torch.Size([5, 16, 24])
Q: torch.Size([5, 16, 48])
V: torch.Size([5, 16, 48])
K: torch.Size([5, 16, 48])
K.transpose(-2, -1): torch.Size([5, 48, 16])
score: torch.Size([5, 16, 16])
gain: torch.Size([5, 16, 16])
softmax: torch.Size([5, 16, 16])
out: torch.Size([5, 16, 48])


In [None]:
# Multi-Head Attention
class MultiHeadedAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        """Create the query, value and key projection layers.

        Args:
            heads: accepted but not used by this constructor.
            d_model: input width of the three projections.
            dropout: accepted but not used by this constructor.
        """
        super(MultiHeadedAttention, self).__init__()
        # Input width of the projections; previously read before being set.
        self.input = d_model
        self.hidden = hidden_nodes

        # Assign the layers to self so nn.Module registers them as
        # sub-modules. As plain locals they were discarded when __init__
        # returned and their weights never appeared in .parameters().
        self.Wq = nn.Linear(self.input, self.hidden)
        self.Wv = nn.Linear(self.input, self.hidden)
        self.Wk = nn.Linear(self.input, self.hidden)
        
    def forward(self, q, v, k):
        
        
        q = Wq(x)
        v = Wv(x)
        k = Wk(x)

        d_k = emb_prompts.shape[2]