### init

In [6]:
import numpy as np
import torch
import torch.nn.functional as F
import seaborn as sns
from utils.data_util import HackerNewsBigrams
from torch.utils.data import DataLoader

sns.set_theme()
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

g = torch.Generator(device=device).manual_seed(0)

Using mps device


### import training data

In [2]:
batch_size = 1024
training_data = HackerNewsBigrams(train=True)
test_data = HackerNewsBigrams(train=False)
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

W = torch.randn((28, 28), generator=g, requires_grad=True, device=device)

### gradient descent

In [3]:
for epoch in range(10):
    epoch_losses = []
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)

        # forward pass
        logits = X @ W # (16, 28) @ (28, 28) = (16, 28)
        counts = logits.exp() # (16, 28)
        probs = counts / counts.sum(dim=1, keepdims=True) # (16, 28)
        batch_losses = -probs[torch.arange(probs.shape[0]), y].log() # (16, 1)
        batch_loss = batch_losses.mean()       
        epoch_losses.append(batch_loss.item())
        
        # backward pass
        W.grad = None
        batch_loss.backward()

        # update weights
        W.data -= 1e-1 * W.grad

    avg_epoch_loss = np.mean(epoch_losses)
    print(f'epoch: {epoch+1}, average training loss: {avg_epoch_loss.item():.4f}')

epoch: 1, average training loss: 2.5857
epoch: 2, average training loss: 2.4433
epoch: 3, average training loss: 2.4295
epoch: 4, average training loss: 2.4242
epoch: 5, average training loss: 2.4214
epoch: 6, average training loss: 2.4198
epoch: 7, average training loss: 2.4187
epoch: 8, average training loss: 2.4179
epoch: 9, average training loss: 2.4174
epoch: 10, average training loss: 2.4170


### sample from model

In [9]:
text = ''
n = 0
nll = 0.0

ix = training_data.ctoi['<>']

while True:
    pix = ix
    X = F.one_hot(torch.tensor(pix, device=device), num_classes=28).float()
    logits = X @ W
    counts = logits.exp()
    probs = counts / counts.sum(dim=0, keepdims=True)
    ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()

    if ix==training_data.ctoi['<>']:
        break
    
    text += training_data.itoc[ix]
    n += 1

    # calculate loss
    prob = probs[ix]
    logprob = torch.log(prob)
    nll -= logprob

print(text)
print(f'average negative log-likelihood: {nll/n}')

yofoustuldystofeatroielasn oy ithesic actve isn redit anksther al thaso ay a fug erecons lydintoremyeabl find it ffommomifou ss ca t n ude arg gly t wf ple ttorff utatusskits bu l ago wnt fon ang hitht tenciseas thof thikerts tediliot owottha diontho pret  thtpouthhtot ovea at ttsesid owo m tansutomel cthe a ane thtonttofit is ns d aill soul d ituthicurayonsikingono  t te avie a bemis porndmisi se th g l g ty insseioweclohigld ingzbary ang cticonunve t tif tcry ast igseabo splus cofanondes itelonithat t tretoppapenio mou lleany ie dittisnss angret opre omab h tes if lintmpringornowabendstningreritwesqurs chili pliskhrewebissouthrobyepok bemounieatouathadyea hethis saiti lle i withi havemaf ar ds wheve t thaledins f ld andingl hait athionw lldoredel ywimisen asillye d arics andhassthe surimbesme bme s re amex sont mppar s sct congurinatr conkecare f osiemprearongolcla fo oully soefin siogor o wit core d iatecunk wintidofins mey prchelepra iondnsis tthprd ie isusthe s is cle ururearemeva