In [1]:
from torch import nn
from torch.nn import functional as F
from pathlib import Path

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1337)

<torch._C.Generator at 0x1047f1eb0>

In [2]:
with Path('../input.txt').open('r', encoding='utf-8') as f:
    text = f.read()

vocab = sorted(list(set(text)))
stoi = { ch: i for i, ch in enumerate(vocab) }
itos = { i: ch for i, ch in enumerate(vocab) }
encode = lambda x: [stoi[s] for s in x]
decode = lambda x: [itos[s] for s in x]

print("".join(decode(encode("hii there"))))

hii there


In [3]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[n:]
val_data = data[:n]

In [4]:
batch_size = 4
block_size = 8
n_embd = 12

eval_iters = 200
max_iters = 4000

vocab_size = len(vocab)

In [5]:
# torch.randint

In [6]:
def get_batch(split, batch_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [7]:
# B, T, C = 4, 8, 6
# x = torch.randn(B, T, C)
# tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros(T, T)
# wei = wei.masked_fill(tril == 0, float('-inf'))
# wei = F.softmax(wei, dim=-1)
# wei @ x

In [8]:
# class 

In [9]:
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # self.position_embedding_table = nn.Embedding(torch.arange(T))

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        logits = self.lm_head(tok_emb) # (B, T, vocab_size)
        if targets is None:
            losses = None
        else:
            # targets.shape when cross_entropy turns them
            # to OHE --> (B, T, vocab_size).
            # cross_entropy needs shape (minibatch, C).
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            losses = F.cross_entropy(logits, targets)
            # F.softmax(logits, dim=-1)
        return logits, losses

    def generate(self, idx, max_new_tokens=1000):
        for _ in range(max_new_tokens):
            # idx_cond = idx[:-batch_size:]
            logits, _ = self(idx)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

In [10]:
m = LanguageModel(vocab_size, n_embd)
m = m.to(device)

# xb, yb = get_batch('train', batch_size)
# logits, losses = m(xb, yb)
# logits, losses

In [11]:
_in = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(_in)
print("".join(decode(out[0].tolist())))


&JVyYSUXyll' IrpWnUkZnSJwTTWcELMoSbvE& X?GGnka33:PgqkWK.Yq;:JPP.FiwkHjTMFzbEHuxsbuwZweX?dOJGUGp
wnY'mOwVWpeMsnDaXEOscLfJt-:Bpf&kWZpNKseO3
bx&JwdYw-,x:zAd IRFbTZqQ
bZeuuZthlXNzdXkU&Q;Y:PFi.v&BZG
,&BAigqrT.c:Pq,kWfzetn3XVyX-YBfHkUTk&PvdTcSe.n'f,FJp?ARiOuQeUPXBxsPviq3GmHUr'rknJaWm&rIaIAnSL$Pq?Vgb
RZni3cLvFbdn-tLhGpg
Aa!GMdg  te.T?bTiW.XzRYiPlai3vs&
$K.jx$n$? QWOfPmRJgrIAcmFp;b&b$sLxm:u3q3XAWrZEv IjjQ-y&wUCMC'erIPlwUu3n rOchXEsIYd qusPp&UdLMCKsdXXB$XkFMR,$a$yAMjWxd EX
,EJC,$? VV$RUtZNC?KpFUWFjR$zl$dammFwpkTEX
L.JYKexPaT
kBAkxzEzS
iQ jjkrVGn$v&yQVu
&lKuzJK.JkWnqXilxaBaaPgJ;b$l3&uS.IWnMff!Gmmrhv3
cMvqtNEem3h'JYy
Hfy?ZICpzqL.r:CwagxLaLj-?Aer:I?uw&TxUTOBK.JAO-pFW'DhEs
l'mUgbExLGh,
oiIHw'-OWfyzI3MG JOwOBEpSgmy:kPqEd Nj3WePrFa?VykT3nuT:tKypJ.KfPFBUm3i$'FjF$c-rIekCuqp ATl$;
yAVrwmRpxwiXelf&boD
ytPvL!JD rvA&ANDFJhfFkT3Ra
xpP$;lXzZIZtpPtsPIEPlUX'pH f?MLpaxUfy$AOvo3'
:fBfRw&mqHQ,k.ttRFWFv'Ne:&xVHyFAqVAOq!Fv3EzteT'xhJnFMZYxwzGZqrpf&ZKnER,rITIHZ'WdXEJAX tO'EuIFYLg:iO$ mHPv
;LjdLkxXV$bCEgeg!'$JcUPaSBZ

In [None]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        loss_list = torch.zeros(eval_iters)
        for _ in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, losses = model(X, Y)
    model.train()
    return out

In [12]:
m = LanguageModel(vocab_size, n_embd)
m = m.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

for iter_ in range(max_iters):
    if iter_ % eval_iters == 0:
        out = estimate_loss(m, batch_size)
        print(f"train loss {out['train']} val loss {out['val']}")
    xb, yb = get_batch('train', batch_size)
    logits, losses = m(xb, yb)
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
print(losses.item())

2.5596256256103516


In [13]:
_in = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(_in)
print("".join(decode(out[0].tolist())))


RNutheuc.
PERUCCAOzPk'mas t l ses ostO:
Fe k hy itueriinsiniw, le ed t n, fre m
Pvese spe dilithe;:
Yh payo mourthezJ
olucy Lf q:

Rooua'ororhlly thowuan ferhe tes S
PTRUClir heticowhesldavage hy, osonaar whe, gant he a.
Sh iorerewhawhe?
Ne llUJMiniouthay, d, thend ut wesmal assthoupolest th can oworghe govlthy t ftT:
Weu t w n blios ha wadorie chensthio hint mese-wither iy t,

Qt thak
-pake wiare t, o we, inoSouniolraintadais ilo fe hierhoWhent m fithie san.
cyowet t kc


ASTO: tr hanouris.

ANIfou norulr wt cikul t ut fuiggce w oslusc, tou ar as, ft pt cisir more handanorrtoft cee, won IO:
Sat and,

Wes, kbely n thyawemy.!
GANAorer'sou asglLECSMhot athin tu,
Rje be t mane t ouporofu t tin athy Bme chlos piclla ifedrtow,
YPNIORTNThe be hareallerlo fonshonk.


TIHqAbenofrir ha
anthest othaveal bendhe nl w
Totakemy IO:
y githladegs
chend t unere cspen AOILulu an al minu mes st y ty ty,
HIOangg y but wo gas Axrineshong.
BlensUCBmane halanmerg ty.
Whene me cl bla, s y t maasTon tod itheo