In [27]:
import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange, reduce
from jaxtyping import Float, Int
from torch import Tensor

In [24]:
seq_len = 512
def tokenize(raw_text):
  """Turn a batch of raw strings into (current, next) training windows.

  Each character is mapped to its code point with ``ord``. Every text is then
  cut into non-overlapping windows of ``seq_len`` inputs, each paired with the
  same window shifted one position to the right as the prediction target.

  Args:
    raw_text: batched mapping with a ``'text'`` key holding a list of strings
      (the shape ``datasets.Dataset.map(..., batched=True)`` passes in).

  Returns:
    A dict with two parallel lists of windows:
      ``'current'``: input token ids, each window of length ``seq_len``.
      ``'next'``: target token ids, the inputs shifted by one token.
    A trailing remainder too short to fill a whole window is dropped, so a
    text with at most ``seq_len`` characters contributes no windows.
  """
  current_token = []
  next_token = []
  # Walk every text in the batch, not only the first one.
  for text in raw_text['text']:
    # NOTE(review): code points above 255 would index outside the 256-row
    # table built by Embedding -- confirm the corpus stays within that range.
    token = [ord(x) for x in text]
    # A window needs seq_len + 1 tokens (inputs plus the final target), so the
    # last valid start is len(token) - seq_len - 1. Stepping by seq_len keeps
    # the windows disjoint instead of sliding by a single character.
    for start in range(0, len(token) - seq_len, seq_len):
      window = token[start:start + seq_len + 1]
      current_token.append(window[:-1])
      next_token.append(window[1:])
  return {'current': current_token,
          'next': next_token}

In [25]:
# Load the 'train' split of the karpathy/tiny_shakespeare dataset.
raw_text_data = datasets.load_dataset(
  'karpathy/tiny_shakespeare',
  split='train',
)

# Run tokenize in batched mode, drop the raw 'text' column from the result,
# and have rows come back as torch tensors.
char_data = (
  raw_text_data
  .map(tokenize, batched=True, remove_columns=['text'])
  .with_format('torch')
)

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [20]:
class Embedding(nn.Module):
  """Learned token embedding plus learned positional embedding.

  Args:
    d_model: width of every embedding vector.
    seq_len: maximum sequence length; one positional vector is learned per
      position in ``[0, seq_len)``.
    vocab_size: number of distinct token ids the lookup table covers.
      Defaults to 256, the size that was previously hard-coded.
  """
  def __init__(self, d_model: int, seq_len: int, vocab_size: int = 256):
    super().__init__()
    self.embedding_matrix = nn.Parameter(torch.zeros(vocab_size, d_model))
    nn.init.xavier_normal_(self.embedding_matrix)

    self.positional_encoding = nn.Parameter(torch.zeros(seq_len, d_model))
    nn.init.xavier_normal_(self.positional_encoding)

  def forward(self, data: Int[Tensor, "*batch seq_len"]) -> Float[Tensor, "*batch seq_len d_model"]:
    """Look up token vectors and add the matching positional vectors.

    Args:
      data: integer token ids; the last axis is the sequence axis and may be
        shorter than the ``seq_len`` given at construction.

    Returns:
      Tensor with the shape of ``data`` plus a trailing ``d_model`` axis.
    """
    # Slice the positional table to the actual input length so sequences
    # shorter than the configured maximum still broadcast correctly.
    n_tokens = data.shape[-1]
    return self.embedding_matrix[data] + self.positional_encoding[:n_tokens]

In [None]:
class Attention(nn.Module):
  """Multi-head causal self-attention with explicit per-head weight tensors.

  Args:
    n_head: number of attention heads.
    d_model: width of the input and output vectors.
    d_head: width of each head's query/key/value vectors.
    seq_len: configured sequence length. It is stored on the module, but the
      causal mask is sized from the actual input, so shorter inputs also work.
  """
  def __init__(self, n_head: int, d_model: int, d_head: int, seq_len: int):
    super().__init__()
    self.seq_len = seq_len
    self.d_head = d_head

    # NOTE(review): xavier_normal_ on a 3-D tensor treats the trailing axis as
    # a receptive field when computing fan-in/fan-out, so the resulting scale
    # is not the per-head 2-D value -- confirm this is intended.
    self.query_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.query_matrix)

    self.key_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.key_matrix)

    self.value_matrix = nn.Parameter(torch.zeros(n_head, d_head, d_model))
    nn.init.xavier_normal_(self.value_matrix)

    self.output_matrix = nn.Parameter(torch.zeros(n_head, d_model, d_head))
    nn.init.xavier_normal_(self.output_matrix)

  def forward(self, data: Float[Tensor, "*batch seq_len d_model"]) -> Float[Tensor, "*batch seq_len d_model"]:
    """Apply causal multi-head self-attention.

    Args:
      data: input activations; the last two axes are (sequence, d_model) and
        any number of leading batch axes is accepted.

    Returns:
      Tensor of the same shape as ``data``: the per-head attention outputs
      projected back to ``d_model`` and summed over heads.
    """
    # Per-head projections; "..." carries whatever batch axes the input has.
    query = einsum(data, self.query_matrix, "... seq_len d_model, n_head d_head d_model -> ... n_head seq_len d_head")
    key = einsum(data, self.key_matrix, "... seq_len d_model, n_head d_head d_model -> ... n_head seq_len d_head")
    value = einsum(data, self.value_matrix, "... seq_len d_model, n_head d_head d_model -> ... n_head seq_len d_head")

    # Scaled dot-product scores. d_head is a Python int, so take the root with
    # ** 0.5 rather than torch.sqrt, which only accepts tensors.
    attn_scores = einsum(query, key, "... n_head query_len d_head, ... n_head key_len d_head -> ... n_head query_len key_len")
    attn_scores = attn_scores / self.d_head ** 0.5

    # Causal mask: True strictly above the diagonal (key position later than
    # query position). Sized from the input and built on its device; applied
    # out of place so the score tensor is not mutated under autograd.
    n_tokens = data.shape[-2]
    future = torch.ones(n_tokens, n_tokens, dtype=torch.bool, device=data.device).triu(diagonal=1)
    attn_scores = attn_scores.masked_fill(future, float('-inf'))

    # Keep the softmax result: these are the weights used below.
    attn = F.softmax(attn_scores, dim=-1)

    # Weighted sum of values: contract over key_len, keep query_len.
    output_pre = einsum(attn, value, "... n_head query_len key_len, ... n_head key_len d_head -> ... n_head query_len d_head")
    # Project every head back to d_model and sum over heads; output_matrix
    # has no batch axis.
    output = einsum(self.output_matrix, output_pre, "n_head d_model d_head, ... n_head seq_len d_head -> ... seq_len d_model")
    return output

In [38]:
# Scratch check of the causal mask + softmax on random scores.
b = torch.randn((3, 5, 5))
# Index pairs strictly above the diagonal (key position later than query
# position); computed once and reused for both index axes.
mask_idx = torch.triu_indices(5, 5, offset=1)
b[..., mask_idx[0], mask_idx[1]] = float('-inf')
b = F.softmax(b, dim=-1)
x = torch.randn((1, 5, 3))
# Contract over key_len so each query position receives a weighted sum of the
# value rows; the output is indexed by query_len, not key_len.
k = einsum(b, x, "batch query_len key_len, batch key_len d_head -> batch query_len d_head")


tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8301, 0.1699, 0.0000, 0.0000, 0.0000],
         [0.2476, 0.0859, 0.6665, 0.0000, 0.0000],
         [0.5885, 0.2264, 0.0661, 0.1191, 0.0000],
         [0.0641, 0.0822, 0.0706, 0.4376, 0.3455]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3539, 0.6461, 0.0000, 0.0000, 0.0000],
         [0.3997, 0.3641, 0.2363, 0.0000, 0.0000],
         [0.0996, 0.8157, 0.0093, 0.0754, 0.0000],
         [0.1485, 0.0319, 0.7411, 0.0426, 0.0360]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9354, 0.0646, 0.0000, 0.0000, 0.0000],
         [0.1461, 0.2672, 0.5867, 0.0000, 0.0000],
         [0.2009, 0.1178, 0.3215, 0.3598, 0.0000],
         [0.3561, 0.3709, 0.0353, 0.2032, 0.0345]]])