## Loading Data

In [2]:
# ! wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

## Training the Bigram Model

In [1]:
import torch

split_size = 0.9
batch_size = 32
block_size = 8
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
eval_iters = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
from data import CharData

data = CharData(text, params={
    'split_size': split_size,
    'batch_size': batch_size,
    'block_size': block_size,
    'device': device,
})

data loader successfully initiated.


In [4]:
from bigram import BigramLanguageModel

model = BigramLanguageModel(data.get_vocab_size())
model.to(device)

BigramLanguageModel(
  (token_embedding_table): Embedding(65, 65)
)

In [5]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [6]:
from utils import estimate_loss

for iter in range(max_iters):

    # check the loss once every eval_iters intervals pass
    if not iter % eval_iters:
        losses = estimate_loss(model, data, eval_iters)
        print(f"iter {iter} - train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    # get a batch of data
    X, Y = data.get_batch('train')

    # evaluate the loss
    logits, loss = model(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

iter 0 - train loss 4.6271, val loss 4.6232
iter 200 - train loss 3.1157, val loss 3.1271
iter 400 - train loss 2.6663, val loss 2.6797
iter 600 - train loss 2.5422, val loss 2.5654
iter 800 - train loss 2.5045, val loss 2.5340
iter 1000 - train loss 2.4906, val loss 2.5161
iter 1200 - train loss 2.4873, val loss 2.4990
iter 1400 - train loss 2.4794, val loss 2.5013
iter 1600 - train loss 2.4607, val loss 2.4841
iter 1800 - train loss 2.4661, val loss 2.4905
iter 2000 - train loss 2.4726, val loss 2.4883
iter 2200 - train loss 2.4717, val loss 2.5022
iter 2400 - train loss 2.4553, val loss 2.4918
iter 2600 - train loss 2.4718, val loss 2.4948
iter 2800 - train loss 2.4657, val loss 2.4966


In [8]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(data.decode(model.generate(context, max_new_tokens=500)[0].tolist()))


M:
O:
My mys cemar:
Wer peer cobearoflifos s.
TIZA RYghe s
HEDYg he
Lounrd
Ma hounnonghallveparisotrm o to Coswind ade histhert.
AThecas, o nt ave;
AMabarne k hof spetour malam owowieand illee ousoe cr it h. s pts.
MANRKE:
gmy my ste plke aimiagho wesisit wrer Canses howh cce, tuas s llltetod my be listhasw co I ave: f t CENCl medurmerur mayan.
St thy au o warme in.zenoupom.

QUCHomeeteate te Hectr.

Ton pprtit y ngs she we an w kend'shoonondio pr ber:
Wack fe h h,
's ait le,
TEDERYO: tifo donof
