In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
import torch
import torch.nn as nn
import torch.nn.functional as F

# Load the corpus fetched by the `!wget` line in the first cell.
with open('input.txt', 'r', encoding='utf-8') as file:
  text = file.read()

# ---- Hyperparameters ---------------------------------------------------------
context_length = 256          # maximum number of tokens the model attends over
block_size = context_length   # alias; sizes Head's causal-mask buffer
num_heads = 6                 # attention heads per transformer block
batch_size = 64               # sequences per training batch
drop_prob = 0.2               # dropout probability used throughout the model
embedding_dim = 384           # width of token/position embeddings and the residual stream
lr = 3e-4                     # AdamW learning rate
# Combined width of all attention heads. Derived from embedding_dim (rather than
# repeating the literal 384) because the concatenated head outputs and the
# final logits layer are sized against embedding_dim.
head_size = embedding_dim
num_blocks = 6                # number of stacked transformer blocks
iters = 5000                  # training iterations

if head_size % num_heads != 0:
  # Each head gets head_size // num_heads channels; a remainder would be dropped.
  raise ValueError(
      f"head_size ({head_size}) must be divisible by num_heads ({num_heads})")

# ---- Character-level vocabulary ----------------------------------------------
# Sorted so the character <-> id mapping is deterministic across runs.
vocab = sorted(set(text))
vocab_size = len(vocab)
char_to_int = {c: i for i, c in enumerate(vocab)}
int_to_char = {i: c for i, c in enumerate(vocab)}

def encode_string(data_str):
  """Convert text into a list of integer token ids using ``char_to_int``.

  A character that is not in the vocabulary raises ``KeyError``.
  """
  return [char_to_int[ch] for ch in data_str]

def decode_tokens(data_ints):
  """Inverse of ``encode_string``: map integer token ids back to a string."""
  return ''.join(int_to_char[tok] for tok in data_ints)

# Encode the full corpus once; the first 90% trains, the remaining 10% is held out.
dataset = torch.tensor(encode_string(text), dtype=torch.long)
n = int(0.9 * len(dataset))  # index of the first held-out token
train_set, val_set = dataset[:n], dataset[n:]

def batch_loader(dataset, device='cuda', *, batch=None, seq_len=None):
  """Sample a random batch of (input, target) windows from ``dataset``.

  Each target row is its input window shifted one position to the right, so
  ``y[b, t]`` is the token that follows ``x[b, t]`` in ``dataset``.

  Args:
    dataset: 1-D tensor of token ids (e.g. ``train_set`` or ``val_set``).
    device: where the returned tensors are placed; defaults to ``'cuda'`` as
      in the rest of this notebook.
    batch: number of windows; defaults to the global ``batch_size``.
    seq_len: window length; defaults to the global ``context_length``.

  Returns:
    ``(x, y)``: two ``torch.long`` tensors of shape ``(batch, seq_len)``.
  """
  if batch is None:
    batch = batch_size
  if seq_len is None:
    seq_len = context_length
  # Largest start is len - seq_len - 1, so the shifted target window still fits.
  starts = torch.randint(len(dataset) - seq_len, (batch,)).tolist()
  # Stack integer slices directly instead of copying through float32 buffers,
  # which is extra work and would lose precision for ids >= 2**24.
  x = torch.stack([dataset[s:s + seq_len] for s in starts])
  y = torch.stack([dataset[s + 1:s + 1 + seq_len] for s in starts])
  return x.to(device).long(), y.to(device).long()

class Head(nn.Module):
  """A single causal self-attention head.

  Projects the input to keys, queries and values of width ``head_size``. Each
  position attends only to itself and earlier positions.
  """

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(embedding_dim, head_size, bias = False)
    self.query = nn.Linear(embedding_dim, head_size, bias = False)
    self.value = nn.Linear(embedding_dim, head_size, bias = False)
    # Scaled dot-product attention: scores are divided by sqrt(head_size).
    self.scale_factor = head_size ** 0.5
    self.dropout = nn.Dropout(drop_prob)
    # Lower-triangular ones: entry (i, j) is 1 iff j <= i, i.e. key j is visible
    # to query i. Registered as a buffer so it follows the module across devices.
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    """Apply causal attention.

    Args:
      x: tensor of shape (B, T, embedding_dim) with T <= block_size.

    Returns:
      Tensor of shape (B, T, head_size).
    """
    T = x.shape[1]
    k = self.key(x)    # (B, T, HS)
    q = self.query(x)  # (B, T, HS)
    v = self.value(x)  # (B, T, HS)
    # Row i holds query i's scores against every key j:
    # (B, T, HS) @ (B, HS, T) -> (B, T, T).
    # NOTE: an earlier version computed k @ q^T (roles swapped); checkpoints
    # trained with it have their key/query weights interchanged.
    weights = q @ k.transpose(-2, -1) / self.scale_factor
    # Hide future keys (j > i) from query i.
    weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    weights = F.softmax(weights, dim = -1)
    weights = self.dropout(weights)
    return weights @ v  # (B, T, T) @ (B, T, HS) -> (B, T, HS)

class MultiheadedAttention(nn.Module):
  """Several causal attention heads run in parallel, concatenated, then projected.

  ``head_size`` is the combined width of all heads; each individual head gets
  ``head_size // n_heads`` channels.
  """

  def __init__(self, n_heads, head_size):
    super().__init__()
    per_head = head_size // n_heads
    self.heads = nn.ModuleList([Head(per_head) for _ in range(n_heads)])
    # Input width is the concatenated head outputs. It equals embedding_dim
    # only when head_size == embedding_dim and n_heads divides it evenly.
    self.proj = nn.Linear(n_heads * per_head, embedding_dim)
    self.dropout = nn.Dropout(drop_prob)

  def forward(self, x):
    # Use this module's own heads (not the global num_heads) and concatenate
    # once along the channel axis instead of growing the tensor head by head.
    out = torch.cat([head(x) for head in self.heads], dim = -1)
    return self.dropout(self.proj(out))

class Feedforward(nn.Module):
  """Per-position MLP: widen to 4x embedding_dim, ReLU, project back, dropout."""

  def __init__(self):
    super().__init__()
    hidden_dim = 4 * embedding_dim
    expand = nn.Linear(embedding_dim, hidden_dim)
    self.network = nn.Sequential(expand, nn.ReLU())
    self.proj = nn.Linear(hidden_dim, embedding_dim)
    self.dropout = nn.Dropout(drop_prob)

  def forward(self, x):
    hidden = self.network(x)
    return self.dropout(self.proj(hidden))

class TransformerBlock(nn.Module):
  """Pre-LayerNorm block: attention ("communication") then MLP ("computation").

  Each sub-layer reads a LayerNorm-ed copy of the input and its output is added
  back onto the residual stream.
  """

  def __init__(self):
    super().__init__()
    self.comm = MultiheadedAttention(num_heads, head_size)
    self.comp = Feedforward()
    self.first_ln = nn.LayerNorm(embedding_dim)
    self.second_ln = nn.LayerNorm(embedding_dim)

  def forward(self, x):
    attended = x + self.comm(self.first_ln(x))
    return attended + self.comp(self.second_ln(attended))

class BigramModel(nn.Module):
  """Character-level causal transformer language model.

  The name is historical: predictions condition on up to ``context_length``
  previous tokens, not just the previous one.
  """

  def __init__(self):
    super().__init__()
    self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
    self.pos_embeddings = nn.Embedding(context_length, embedding_dim)
    # Indices 0..num_blocks-1 are transformer blocks; the last entry is a final
    # LayerNorm applied before the logits layer.
    self.transformer_blocks = nn.Sequential(
        *[TransformerBlock() for _ in range(num_blocks)],
        nn.LayerNorm(embedding_dim),
    )
    # The final LayerNorm emits embedding_dim features, so that is the input width.
    self.to_logits = nn.Linear(embedding_dim, vocab_size)

  def forward(self, batch, targets = None):
    """Compute next-token logits and, optionally, the cross-entropy loss.

    Args:
      batch: (B, T) tensor of token ids, T <= context_length.
      targets: optional (B, T) tensor of next-token ids.

    Returns:
      ``(logits, loss)``. Without targets: logits is (B, T, vocab_size) and loss
      is None. With targets: logits is flattened to (B * T, vocab_size) and loss
      is the mean cross-entropy.
    """
    B, T = batch.shape
    token_embeddings = self.embedding_table(batch)
    # Positions live on the input's device rather than a hard-coded 'cuda'.
    positions = torch.arange(T, device = batch.device)
    x = token_embeddings + self.pos_embeddings(positions)
    logits = self.to_logits(self.transformer_blocks(x))
    loss = None
    if targets is not None:
      # Flatten using the actual batch shape (not the global batch_size /
      # context_length) so batches of any size or length work.
      logits = logits.view(B * T, logits.size(-1))
      targets = targets.reshape(B * T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  @torch.no_grad()
  def generate(self, max_seq_len, context):
    """Autoregressively sample ``max_seq_len`` new tokens after ``context``.

    Dropout is disabled while sampling; the module's previous train/eval mode
    is restored afterwards.

    Args:
      max_seq_len: number of tokens to append.
      context: (B, T0) tensor of starting token ids.

    Returns:
      (B, T0 + max_seq_len) tensor: the context followed by the sampled tokens.
    """
    was_training = self.training
    self.eval()
    try:
      for _ in range(max_seq_len):
        # Feed only the last context_length tokens (the positional table size),
        # but keep the full history in `context` so nothing is discarded.
        window = context[:, -context_length:]
        logits, _ = self(window)
        probs = F.softmax(logits[:, -1, :], dim = -1)
        next_token = torch.multinomial(probs, num_samples = 1)
        context = torch.cat((context, next_token), dim = 1)
    finally:
      self.train(was_training)
    return context

# Build the model on the GPU and optimise all of its parameters with AdamW.
# Module.to() returns the module itself, so `model` and `decoder` are one object.
model = decoder = BigramModel().to('cuda')
params = model.parameters()
optimizer = torch.optim.AdamW(params, lr=lr)

In [None]:
eval_interval = 500  # steps between loss reports
eval_iters = 20      # batches averaged for each loss estimate


@torch.no_grad()
def estimate_loss():
  """Return the mean loss over ``eval_iters`` random batches of each split."""
  model.eval()  # disable dropout for a stable estimate
  losses = {}
  for split_name, split in (('train', train_set), ('val', val_set)):
    total = 0.0
    for _ in range(eval_iters):
      xb, yb = batch_loader(split)
      _, split_loss = model(xb, yb)
      total += split_loss.item()
    losses[split_name] = total / eval_iters
  model.train()
  return losses


model.train()
for i in range(iters):
  # Periodically report train/val loss; val_set is otherwise never used.
  if i % eval_interval == 0 or i == iters - 1:
    losses = estimate_loss()
    print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  x, y = batch_loader(train_set)
  logits, loss = model(x, y)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

In [None]:
# Sample 5000 tokens starting from token id 0. Switch to eval mode so dropout is
# off, and skip autograd bookkeeping since no gradients are needed.
model.eval()
with torch.no_grad():
  starting_token = torch.zeros((1, 1), dtype=torch.long, device = 'cuda')
  generated = model.generate(5000, starting_token)
print(decode_tokens(generated[0].tolist()))