# Train Middle Earth Text Gen

## Clone the Repo

In [None]:
# Fetch the project repository (IPython shell escape; runs `git clone` in the
# current working directory).
!git clone https://github.com/brotSchimmelt/MiddleEarthTextGen.git

import os
# Make the fresh checkout the working directory for the rest of the notebook.
# NOTE: not idempotent - re-running this cell after the chdir clones a nested
# copy inside the checkout and then descends into it.
os.chdir("MiddleEarthTextGen")

## Load the Text

In [None]:
from google.colab import files

# Open Colab's upload widget; the returned mapping is keyed by uploaded file name.
upload = files.upload()
if not upload:
    # Nothing was uploaded (e.g. the dialog was cancelled): fail with an
    # actionable message instead of a bare IndexError from indexing an empty list.
    raise RuntimeError("No file was uploaded; re-run this cell and choose a text file.")
# Use the first uploaded file as the training corpus.
file_name = next(iter(upload))

In [None]:
# Read the whole corpus into memory as a single UTF-8 string.
with open(file_name, 'r', encoding='utf-8') as f:
    text = f.read()

# The vocabulary is every distinct character of the corpus, in sorted order.
vocab = sorted(set(text))
vocab_size = len(vocab)

# Quick summary: corpus size, the full character set, and a sample of the text.
print(f'{file_name} is {len(text):,} characters long with a vocabulary size of {vocab_size}.\n')
print(f'Vocabulary: {repr("".join(vocab))}\n')
print(f'First 500 character sequence:\n{text[:500]}')

## Tokenize the Vocabulary

In [None]:
# create a simple character tokenizer: each character's id is its index in vocab
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = dict(enumerate(vocab))


def encode(s):
    """Encoder: take a string, output a list of integers.

    Raises KeyError for any character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string.

    Raises KeyError for any id that is not in the vocabulary.
    """
    return ''.join(itos[i] for i in l)

In [None]:
# Sanity check: encoding and then decoding must reproduce the input exactly.
# Every character of the sample must exist in the corpus vocabulary, otherwise
# encode raises KeyError before the assert is evaluated.
test_string = "hello there"
assert decode(encode(test_string)) == test_string

In [None]:
import torch
# Encode the entire corpus into one 1-D tensor of integer token ids
# (one id per character of `text`).
data = torch.tensor(encode(text), dtype=torch.long)

## Train / Validation Split

In [None]:
# 0.9 / 0.1 split: the first 90% of the token stream is used for training,
# the remaining tail is held out for validation.
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

## Define Context Length and Batch Size

In [None]:
# Number of sequences sampled per batch.
batch_size = 4
# Number of tokens in each input sequence (and in its shifted target sequence).
context_length = 8

In [None]:
def get_batch(split):
    """Sample a random batch of input/target sequences.

    Args:
        split: 'train' draws from ``train_data``; any other value draws from
            ``val_data``.

    Returns:
        A tuple ``(x, y)`` of tensors, each of shape
        (batch_size, context_length). ``y`` is ``x`` shifted one position to
        the right in the source data, i.e. the next-token targets.
    """
    source = train_data if split == 'train' else val_data
    # Random start offsets; the exclusive upper bound leaves room for the
    # one-token-shifted target window.
    starts = torch.randint(len(source) - context_length, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        stop = start + context_length
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch: xb holds the inputs, yb the next-token targets.
xb, yb = get_batch('train')

## Bigram Model

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """Bigram language model backed by a single embedding lookup.

    The table has shape (vocab_size, vocab_size): row ``i`` holds the logits
    over the vocabulary for the token that follows token ``i``.
    """

    def __init__(self, vocab_size):
        """Create the (vocab_size x vocab_size) logit lookup table."""
        super().__init__()
        self.token_embedding_table = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=vocab_size,
        )

    def forward(self, idx, targets=None):
        """Return next-token logits for every position of ``idx``.

        Args:
            idx: (B,T) tensor of integer token ids.
            targets: (B,T) tensor of integer token ids; accepted but not used
                by this implementation.

        Returns:
            Logits of shape (B,T,C), where C is the vocabulary size.
        """
        return self.token_embedding_table(idx)

# Smoke test: push the sampled batch through the untrained model.
# yb is passed along, but forward() currently ignores its targets argument.
bigram_model = BigramLanguageModel(vocab_size)
out = bigram_model(xb, yb)
# Expected shape: (batch_size, context_length, vocab_size).
print(out.shape)