## Code based on [Andrej Karpathy's nanoGPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY)

In [16]:
import torch
import torch.nn.functional as F

## Read data

In [2]:
# Read the whole corpus into memory; the context manager guarantees the file handle is closed.
with open("data/tinyshakespeare.txt", "r", encoding="utf-8") as data_file:
    data_str = data_file.read()

In [3]:
# print() (rather than a bare expression) renders the newlines instead of showing "\n".
print(data_str[0:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [4]:
# Corpus size in characters (last expression is displayed as the cell output).
len(data_str)

1115394


In [14]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
vocab = sorted(set(data_str))
vocab_len = len(vocab)
print(vocab_len, vocab)

65 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [6]:
# Character <-> integer-id lookup tables built from the sorted vocabulary.
c2i = {c: i for i, c in enumerate(vocab)}
i2c = {i: c for i, c in enumerate(vocab)}


def encode(text):
    """Map a string to the list of integer ids of its characters.

    Raises KeyError for any character that is not in ``vocab``.
    """
    return [c2i[c] for c in text]


def decode(ints):
    """Map an iterable of integer ids back to the string they encode.

    Raises KeyError for any id that is not in ``i2c``.
    """
    return "".join([i2c[i] for i in ints])

In [7]:
# Characters -> integer ids for a short sample string.
encode(text="hello")

[46, 43, 50, 50, 53]

In [8]:
# Round-trip sanity check: decoding an encoding must give back the original text.
assert decode(encode("hello")) == "hello", "encode/decode round-trip failed"
decode(encode("hello"))

'hello'

# PyTorch train/validation tensor data

In [9]:
# Encode the whole corpus as one 1-D tensor of token ids.
# int64 (torch.long) is required downstream: nn.Embedding rejects uint8 indices and
# F.cross_entropy needs int64 class-index targets.
data = torch.tensor(encode(data_str), dtype=torch.long)

In [12]:
# 90% / 10% train/validation split taken along the text (contiguous slices, no shuffling).
n = int(0.9 * len(data))
train_data, valid_data = data[:n], data[n:]

print(len(train_data), len(valid_data))

1003854 111540


In [None]:
# Seed torch's global RNG so model initialisation is reproducible across Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

# Length of each training chunk. NOTE(review): whether 9 means an 8-char context
# plus 1 shifted target or a 9-char context is not established here — TODO confirm.
CONTEXT_LEN = 9

# Model 1: Baseline 1 — only look at the current character
We use `Embedding(vocab_len, vocab_len)` as a lookup table: **the row for each character holds the (unnormalized) logits of the next character**, i.e. a bigram model.

In [18]:
class Baseline(torch.nn.Module):
    """Bigram baseline: predicts the next character from the current character only.

    A single ``Embedding(vocab_size, vocab_size)`` table is used as a lookup:
    row ``i`` holds the unnormalized logits for the character following id ``i``.
    """

    def __init__(self, vocab_size=None):
        """Create the lookup table.

        Args:
            vocab_size: number of distinct token ids. When omitted, falls back to
                the notebook-level ``vocab_len`` (the original zero-argument behaviour).
        """
        super().__init__()
        if vocab_size is None:
            vocab_size = vocab_len
        self.table = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, in_idxs, target_idxs=None):
        """Look up next-token logits and optionally compute the cross-entropy loss.

        Args:
            in_idxs: integer tensor of token ids, shape [BS, SEQ_LEN].
            target_idxs: optional integer tensor, same shape as ``in_idxs``, holding
                the id to predict at each position.

        Returns:
            ``logits`` of shape [BS, SEQ_LEN, VOC_LEN] when ``target_idxs`` is None;
            otherwise ``(logits, ce_loss)`` with ``logits`` flattened to
            [BS*SEQ_LEN, VOC_LEN] and ``ce_loss`` a scalar tensor.
        """
        # Embedding lookups and cross_entropy class targets need int64 indices;
        # casting lets compact dtypes such as uint8 through (no-op for int64).
        logits = self.table(in_idxs.long())

        if target_idxs is None:
            return logits

        # F.cross_entropy expects (N, C) logits and (N,) targets, so merge batch and time.
        # reshape (unlike view) also works on non-contiguous tensors, e.g. sliced/transposed batches.
        BS, SEQ_LEN, VOC_LEN = logits.shape
        logits = logits.reshape(BS * SEQ_LEN, VOC_LEN)
        targets = target_idxs.long().reshape(BS * SEQ_LEN)

        ce_loss = F.cross_entropy(logits, targets)
        return logits, ce_loss
        
# Instantiate the bigram baseline; its parameters are handed to the AdamW optimizer below.
m: Baseline = Baseline()

In [19]:
# AdamW over every parameter of the baseline model, learning rate 1e-3.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=0.001)

# Model 2: Baseline 2 — mean of the embeddings of the current and past characters


# Model 3: Transformer
