In [88]:
import torch 

with open('mini_shakespeare.txt', 'r', encoding='utf-8') as file:
    file_content = file.read()

alphabet = sorted(list(set(file_content)))

stoi = {char: i for i,char in enumerate(alphabet)}
itos = {i:char for i, char in enumerate(alphabet)}

encode = lambda text: [stoi[char] for char in text]
decode = lambda encoded_text: [itos[char] for char in encoded_text]

encoded_data = torch.tensor(encode(file_content), dtype=torch.long)
cutoff = int(len(encoded_data)*0.9)
train = encoded_data[:cutoff]
val = encoded_data[cutoff:]

In [89]:
torch.manual_seed(1337)    
block_size = 8
batch_size = 4

def get_batch(split):
    data = train if split == "train" else val
    batch_indicies = torch.randint(len(data) - block_size, (batch_size,1))
    x = torch.stack([data[i:i+block_size] for i in batch_indicies])
    y = torch.stack([data[i+1:i+block_size+1] for i in batch_indicies]) 
    return x, y

x, y = get_batch('train')
x,y

(tensor([[24, 43, 58,  5, 57,  1, 46, 43],
         [44, 53, 56,  1, 58, 46, 39, 58],
         [52, 58,  1, 58, 46, 39, 58,  1],
         [25, 17, 27, 10,  0, 21,  1, 54]]),
 tensor([[43, 58,  5, 57,  1, 46, 43, 39],
         [53, 56,  1, 58, 46, 39, 58,  1],
         [58,  1, 58, 46, 39, 58,  1, 46],
         [17, 27, 10,  0, 21,  1, 54, 39]]))

In [95]:
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)    
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embeding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, y=None):
        ## gets the logits and loss
        ## inference
        logits = self.token_embeding_table(x) ## (batch, block, vocab_size)
        loss = None
        ## loss
        if y != None:
           B, T, C = logits.shape
           changed_view = logits.view(B*T, C) 
           y = y.view(B*T)
           loss = F.cross_entropy(changed_view, y)

        return logits, loss
    
    def generate(self, context, max_new_tokens =100):

        for i in range(max_new_tokens):
            logits, _ = self(context) ## (B, T, C)
            logits = logits[:, -1, :] ## (B, C) last token
            probs = F.softmax(logits, dim=1)
            next_token = torch.multinomial(probs, num_samples=1) # gets next token
            context = torch.concat((context, next_token), dim=1)

        return  context



    
model = BigramLanguageModel(len(alphabet))
logits, loss = model(x,y)
print("".join(decode(model.generate(torch.zeros((1,1), dtype=torch.long))[0].tolist())))


SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [101]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 32

for i in range(10000):
    x,y = get_batch(train)
    logits, loss = model(x,y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss)

tensor(2.3311, grad_fn=<NllLossBackward0>)


In [105]:
print("".join(decode(model.generate(torch.zeros((1,1), dtype=torch.long))[0].tolist())))


LL:
Fanese br,

's:



A:
cape lave IO:
Wif thourant
O:
MIfo cown ame?
RIfith'ios! co werendenke g a
