In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Literal, Tuple

## Data Loading

In [3]:
raw_text = open("./marquis-170.txt").read()
# remove all japanese characters from text

to_remove = ['\xa0', '\u3000','い','う','え','が','こ','さ','し','す','そ','た','て','で','と','ど','な','の','は','や','よ','ら','り','れ','を','ア','ィ','ェ','エ','オ','カ','グ','シ','ジ','ス','ッ','ト','ド','ナ','ニ','ネ','ピ','ペ','マ','ム','メ','ュ','リ','ル','レ','ン','ヴ','ー','一','似','分','女','弱','彼','指','服','真','神','祟','私','肥','覚','許','部',]

# print(raw_text)

filtered_text = ''.join([c for c in raw_text if c not in to_remove])
with open("filtered_text.txt", "w") as f:
    f.write(filtered_text)


In [4]:
chars = sorted(list(set(filtered_text)))
char_to_idx = {ch: idx for idx, ch in enumerate(chars)}
idx_to_char = {idx: ch for idx, ch in enumerate(chars)}
encode = lambda chars: [char_to_idx[ch] for ch in chars]
decode = lambda idxs: "".join([idx_to_char[idx] for idx in idxs]) if type(idxs) is list else idx_to_char[idxs]
decode(encode("cunny"))

'cunny'

In [4]:
encoded_text = torch.tensor(encode(filtered_text), dtype=torch.long)

# not practical to duplicate the data. batches created on the fly with sampled indices
train_cutoff = int(0.8 * len(encoded_text)); val_cutoff = int(0.9 * len(encoded_text))
train = encoded_text[:train_cutoff]; val = encoded_text[train_cutoff:val_cutoff]; test = encoded_text[val_cutoff:]

In [5]:
torch.manual_seed(14)
batch_size = 4; block_size = 8

def get_batch(split: Literal["train", "val", "test"]):
  datasets = {"train": train, "val": val, "test": test}
  dataset = datasets[split]
  batch_idxs = torch.randint(dataset.shape[0] - (block_size + 1), (batch_size,))

  X_batch = torch.stack([dataset[batch_idx: batch_idx + block_size] for batch_idx in batch_idxs])
  y_batch = torch.stack([dataset[batch_idx + 1: batch_idx + block_size + 1] for batch_idx in batch_idxs])

  return X_batch, y_batch

X_batch, y_batch = get_batch("train")
X_batch.shape, y_batch.shape

(torch.Size([4, 8]), torch.Size([4, 8]))

## model

In [21]:
batch_size = 64
block_size = 256
vocab_size = len(chars)
max_iters = 5000
eval_interval = 100; eval_iters = 200
learning_rate = 3e-4
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_dim = 384
n_head = 6
qk_dim = embedding_dim // 2
v_dim = embedding_dim // n_head
n_layer = 6
dropout_p = 0.2

In [22]:
class Head(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    # self.qk_dim = qk_dim; self.v_dim = v_dim
    self.query_function = nn.Linear(embedding_dim, qk_dim, bias=False) # bias false -> just matmul linear function
    self.key_function = nn.Linear(embedding_dim, qk_dim, bias=False)
    self.value_function = nn.Linear(embedding_dim, v_dim, bias=False)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, X: Tensor):
    # B: batch dim, T: time dim, E: token/pos embed dim, QK: query key dim, V: value dim
    queries = self.query_function(X) # (B, T, E) @ (E, QK) -> (B, T, QK)
    keys = self.key_function(X)      # (B, T, E) @ (E, QK) -> (B, T, QK)
    values = self.value_function(X)  # (B, T, E) @ (E, V) -> (B, T, V)

    qk_dot_products = queries @ keys.transpose(1, 2) # (B, T, H) @ (B, H, T) -> (B, T, T) - given a query, calculates all qk dp's
    qk_dp_scaled = qk_dot_products * (qk_dim ** -0.5) # each increase in qk_dim is adding another random variable to qk_dp value
    causal_mask = ~torch.tril(torch.ones((X.shape[1], X.shape[1]), dtype=bool))
    qk_dps_masked = torch.masked_fill(qk_dp_scaled, causal_mask, -torch.inf) # mask is broadcasted
    qk_dpsm_softmaxed = qk_dps_masked.softmax(dim=2) # (B, T, T)
    qk_dpsms_dropout = self.dropout(qk_dpsm_softmaxed)

    aggregated_information = qk_dpsms_dropout @ values # (B, T, T) @ (B, T, V) -> (B, T, V)
    # ---- sloppy mistake made. got it right first time with values instead of x. but during re-implementation, disconnected from intuition and paid too much attention to dimension values. keep the interpretation as first priority. 
    # new_embeddings = X + aggregated_information # (B, T, E) + (B, T, E)
    # return new_embeddings

    # to work with multi-headed attention, can't aggregate information to embedding vector immediately
    return aggregated_information
  
class MultiHeadAttention(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.heads = nn.ModuleList([Head() for _ in range(n_head)]) # register params w. optimizer
    self.proj = nn.Linear(embedding_dim, embedding_dim)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, X: Tensor):
    sa_results = torch.cat([head(X) for head in self.heads], dim=2) # lowercase fucked me
    results = self.dropout(self.proj(sa_results)) # this seems very unecessary. the ffwd already should do this well enough (?)
    return results


In [23]:
class FeedForward(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.net = nn.Sequential(nn.Linear(embedding_dim, embedding_dim * 4), # this takes on the form of the 3b1b video now. 
                             nn.ReLU(),
                             nn.Linear(embedding_dim * 4, embedding_dim),
                             nn.Dropout(dropout_p))

  def forward(self, X: Tensor):
    return self.net(X)
  
class Block(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.sa = MultiHeadAttention()
    self.ffwd = FeedForward()
    self.ln1 = nn.LayerNorm(embedding_dim)
    self.ln2 = nn.LayerNorm(embedding_dim)

  def forward(self, X: Tensor):
    sa_results = X + self.sa(self.ln1(X))
    ffwd_results = X + self.ffwd(self.ln2(sa_results))
    return ffwd_results

## model

In [24]:
class BigramLanguageModel(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)
    self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
    self.position_embedding_table = nn.Embedding(block_size, embedding_dim)
    self.blocks = nn.Sequential(*[Block() for _ in range(n_layer)])
    self.ln = nn.LayerNorm(embedding_dim)
    self.lm_head = nn.Linear(embedding_dim, vocab_size)

  def forward(self, X: Tensor, y: Tensor = None) -> Tuple[Tensor, Tensor]:
    token_embeddings: Tensor = self.token_embedding_table(X) # (batch, block, embed)
    positional_embeddings: Tensor = self.position_embedding_table(torch.arange(X.shape[1])) # (block, embed)
    embeddings = token_embeddings + positional_embeddings
    blocks_results = self.blocks(embeddings)
    ln_results = self.ln(blocks_results)
    logits: Tensor = self.lm_head(ln_results) # (batch, block, embed) @ (embed, vocab) -> (batch, block, vocab)

    if y is None:
      return logits, None

    logits_stable = logits - logits.max(dim=2, keepdim=True)[0]
    counts = logits_stable.exp()
    softmaxed = counts / counts.sum(dim=2, keepdim=True)
    # normally you wouldn't store intermediate probabilites, for training memory efficiency

    B, T, C = softmaxed.shape
    probs_target = softmaxed[torch.arange(B).repeat_interleave(T), 
                             torch.arange(T).repeat(B), 
                             y.view(-1)]

    anll = -probs_target.log().mean()
    
    return logits, anll
  
  def generate(self, context: Tensor, max_tokens: int):
    for iter in range(max_tokens):
      logits, _ = self(context[:, -block_size:])
      probs = F.softmax(logits[:, -1, :], dim=-1) # recalculates increasingly many logits that are uneeded. i don't like this. 
      sampled_idx = torch.multinomial(probs.view(-1), num_samples=1).reshape((1, 1))
      context = torch.cat((context, sampled_idx), dim=1)

    return context
  
bigram = BigramLanguageModel()
logits, loss = bigram(X_batch, y_batch)
logits.shape, loss

decode(bigram.generate(torch.zeros((1, 1), dtype=torch.long), max_tokens=100)[0].tolist())

'\n½b0-AGD⅔39dJnn9Rkr※G※UP;-8RéS-1pTk%U%N’C“B※ū!%r/♡D―*‘=e/U-’agép3DW.d"“egw>ūQHTW\nx5？y*–*？;（i>tQ6、’q～é'

In [25]:
optimizer = torch.optim.AdamW(bigram.parameters(), lr=1e-3)

In [26]:
for iter in range(1):
  Xb, Yb = get_batch("train")
  logits, loss = bigram(Xb, Yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

print(loss)

tensor(4.8314, grad_fn=<NegBackward0>)


In [29]:
print(decode(bigram.generate(torch.zeros((1, 1), dtype=torch.long), max_tokens=40)[0].tolist()))


  = ~ G  te b  
oiz  』   *    e  』  ro#e
