In [1]:
whole_text = open('../data/tinyshakespeare.txt', 'r').read()
lines = whole_text.splitlines()
len(lines)


40000

In [2]:
vocab = sorted(set(''.join(whole_text)))
vocab_size = len(vocab)
vocab_size, ''.join(vocab)

(65, "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

In [3]:
# tokenizer
stoi = {c:i for i,c in enumerate(vocab)}
itos = {v:k for k,v in stoi.items()}
encode = lambda str: [stoi[c] for c in str]
decode = lambda ints: ''.join([itos[i] for i in ints])


In [4]:
import torch

In [5]:
data = torch.tensor(encode(whole_text))
data.shape, data.dtype

(torch.Size([1115394]), torch.int64)

In [6]:
n = int(data.shape[0] * .9)
train_data = data[:n]
val_data = data[n:]
train_data.shape, val_data.shape

(torch.Size([1003854]), torch.Size([111540]))

In [7]:
block_size = 8
train_data[:block_size+1]


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [8]:
# time (T) apparently (or T as in Token?)
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(len(x),):
  context = x[:t+1]
  print(f'{context} -> {y[t]}')

# this is apparently called T (time) dimension? or Token? 
# i think from the BTC acronym we'll see more of later

tensor([18]) -> 47
tensor([18, 47]) -> 56
tensor([18, 47, 56]) -> 57
tensor([18, 47, 56, 57]) -> 58
tensor([18, 47, 56, 57, 58]) -> 1
tensor([18, 47, 56, 57, 58,  1]) -> 15
tensor([18, 47, 56, 57, 58,  1, 15]) -> 47
tensor([18, 47, 56, 57, 58,  1, 15, 47]) -> 58


In [30]:
torch.manual_seed(1337)
# batch (B)
block_size = 8
batch_size = 4

def get_batch(split):
  data = train_data if split == 'train' else val_data
  # get offset starts for all batches
  ix = torch.randint(0, data.shape[0]-block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x,y

xb,yb = get_batch('train')
xb,xb.shape, yb, yb.shape

(tensor([[24, 43, 58,  5, 57,  1, 46, 43],
         [44, 53, 56,  1, 58, 46, 39, 58],
         [52, 58,  1, 58, 46, 39, 58,  1],
         [25, 17, 27, 10,  0, 21,  1, 54]]),
 torch.Size([4, 8]),
 tensor([[43, 58,  5, 57,  1, 46, 43, 39],
         [53, 56,  1, 58, 46, 39, 58,  1],
         [58,  1, 58, 46, 39, 58,  1, 46],
         [17, 27, 10,  0, 21,  1, 54, 39]]),
 torch.Size([4, 8]))

In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)

class Bigram(nn.Module):
  def __init__(self, vocab_size):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
  
  def forward(self, idx, targets=None):
    logits = self.token_embedding_table(idx)
    B,T,C = logits.shape
    logits = logits.view(B*T, C)
    if targets == None:
      return logits, None
    targets = targets.view(B*T)
    loss = F.cross_entropy(logits, targets)
    return logits, loss
  
  def generate(self, idx, max_tokens=10):
    for _ in range(max_tokens):
      logits, loss = self(idx)
      print(f'{logits.shape=}')
      probs = F.softmax(logits[-1:], 0)
      print(f'{probs.shape=}')
      hit = torch.multinomial(probs, num_samples=1)
      print(f'{hit.shape=}')
      idx = torch.cat((idx, hit))
    return idx


m = Bigram(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss.shape)

infe = m.generate(torch.zeros(1,1, dtype=torch.long), max_tokens=70)
decode([i.item() for i in infe])

torch.Size([32, 65])
torch.Size([])
logits.shape=torch.Size([1, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([2, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([3, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([4, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([5, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([6, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([7, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([8, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([9, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size([10, 65])
probs.shape=torch.Size([1, 65])
hit.shape=torch.Size([1, 1])
logits.shape=torch.Size

'\nZbOJwTTWcELMoSavZ& C?nGnkE33:FgqkWKY&q;:JP!gFiwk\njgMTzbEHux3bLjLweX?DO'