In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# books = ["tiny_shakespeare.txt", 
#          "dracula.txt",
#          "blake.txt",
#          "pickwick.txt", "twist.txt", "hard times.txt", "dorrit.txt",
#          "decline1.txt",
#          "vanity.txt",
#          "folly.txt",
#          "white company.txt",
#          "heights.txt",
#          "secret agent.txt",
#          "nonsense.txt",
#          "Middlemarch.txt", "brother jacob.txt", "mill on the floss.txt", "the lifted veil.txt",
#          "alice.txt", "hunting of the snark.txt", "Through the looking glass.txt", "a tangled.txt", "bruno.txt",
#          "jude.txt", "mayor of castle.txt", "return of the native.txt", "Tess of the.txt", "mayor of castle.txt", "adam bede.txt",
#          "Northanger Abbey.txt", "mansfield.txt", "emma.txt", "sense and.txt",
#          "treasure island.txt", "kidnapped.txt"]

books = ["tiny_shakespeare.txt"]

In [3]:
lines = []
for book in books:
    with open(book, 'r', encoding='utf-8') as f:
        lines += f.readlines()

len(lines)

40000

In [4]:
%run tokenizer.py

In [24]:
merges = load_merges("v2_600.model")

tokenizer = Tokenizer(merges)

vocab_size = len(tokenizer.vocab)
vocab_size

600

In [25]:
data = "".join(lines)
print(len(data))
data = tokenizer.encode(data[:700000])

data_length = len(data)

1115394


In [26]:
import random

content_length = 256
chunk_size = content_length * 10

test, dev, train = [], [], []

i = 0
offset = 0

print(offset, data_length, offset < (data_length - chunk_size))

while offset < data_length - chunk_size:
    offset = i * chunk_size
    i += 1
    chunk = data[offset:offset+chunk_size]
    match random.randint(0, 10):
        case 0:
            test += chunk
        case 1:
            dev += chunk
        case _:
            train += chunk

data_length, len(train), len(dev), len(test)

0 419580 True


(419580, 345340, 30720, 43520)

In [11]:
def get_batch(data, batch_length=5, batch_size=5):
    # generate a small batch of data of inputs x and targets y
    ix = torch.randint(len(data) - batch_length, (batch_size,))
    b = torch.stack([torch.tensor(data[i:i+batch_length]) for i in ix])
    return b

In [27]:
get_batch(train, 15, 5)

tensor([[103, 583, 258, 437, 359, 536, 332, 104, 285, 262, 114,  46,  10,  10,
          72],
        [ 97, 314, 280, 100,  46,  10,  10,  81,  85,  69,  69,  78,  32,  69,
          76],
        [441, 119, 409, 115, 271, 118, 279, 371, 437,  97,  32, 501, 101, 316,
         103],
        [ 73,  78,  67,  69,  58,  10,  83, 267, 108,  32, 117, 382, 261, 256,
         109],
        [273, 428, 259, 110, 410, 286, 258, 115,  99, 282, 309,  39, 258, 261,
         256]])

In [13]:
%run attention.py

import torch.optim as optim

In [28]:
epochs = 22
training_runs = 800
batch_size = 96
context_length = 36
learning_rate = .1
embedding_dimensions = 32
num_heads = 4
head_size = embedding_dimensions // num_heads

print(head_size)
# our embedding_dimensions are still 'small' so we mutliply the size our our feed forward network to make up
multiplier = 4
model = FFMultiHeadAttention(vocab_size, embedding_dimensions, context_length, num_heads, head_size, multiplier)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

print(sum(p.numel() for p in model.parameters()), ' parameters')

lmbda = lambda epoch: 0.98

m_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

for ep in range(epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

8
65240  parameters
ep 0 tensor(4.2011, grad_fn=<DivBackward0>) [0.098]
ep 2 tensor(3.6958, grad_fn=<DivBackward0>) [0.0941192]
ep 4 tensor(3.6136, grad_fn=<DivBackward0>) [0.09039207968]
ep 6 tensor(3.5663, grad_fn=<DivBackward0>) [0.086812553324672]
ep 8 tensor(3.5125, grad_fn=<DivBackward0>) [0.08337477621301498]
ep 10 tensor(3.4788, grad_fn=<DivBackward0>) [0.08007313507497958]
ep 12 tensor(3.4524, grad_fn=<DivBackward0>) [0.07690223892601039]
ep 14 tensor(3.4323, grad_fn=<DivBackward0>) [0.07385691026454037]
ep 16 tensor(3.4190, grad_fn=<DivBackward0>) [0.07093217661806457]
ep 18 tensor(3.4042, grad_fn=<DivBackward0>) [0.06812326242398921]
ep 20 tensor(3.3979, grad_fn=<DivBackward0>) [0.06542558123199924]


In [29]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 time ity liays .,

Li and ble theme wilouburseverve powerriest acref alhe he wiech
Rce:
Stear, 'nam
And sakenatself be pid you beyour dus peals:
For so be suLigly littpothe tion; and futhou cefen my ey. You so I youegedes'?

And mersir, to-will yet is than 
RLorbrove the to n ccke agecesou conroushath for theceerter
Thouenomy shstak alto ast bid which oprias ly, love him

WARARLE:
Ste redvancedence gie, I'll him , landdy time fore there hea, Gthul,
Well herte solover this do to f? Be reciou sderves the hierfof ge
 t other hinde, kpal, then?
Tell to the very fou imy s; and what merer make aplay e
What se the sonlace ge; why:
Usome r gdra sent y fioubace?

Hay,susts:
Noe'ence and et nto more, are thy supon fsty bothy tachue ae thpation
Sereth me nd, ruets foopere.
Wellow iseged how and arcknes's in CH!
And hip thoundererving thought era. mydius.

Rhaxer's Oflew have de umUEsoms art.
S; the than mond ste topares thto his hers,
Where o me shall sorhear, for ternquhe soice:
No e y acould ,

In [30]:
e_epochs = 20

In [31]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.3946, grad_fn=<DivBackward0>) [0.06283472821521206]
ep 2 tensor(3.3869, grad_fn=<DivBackward0>) [0.06034647297788966]
ep 4 tensor(3.3814, grad_fn=<DivBackward0>) [0.05795675264796523]
ep 6 tensor(3.3699, grad_fn=<DivBackward0>) [0.055661665243105805]
ep 8 tensor(3.3635, grad_fn=<DivBackward0>) [0.053457463299478813]
ep 10 tensor(3.3573, grad_fn=<DivBackward0>) [0.05134054775281945]
ep 12 tensor(3.3498, grad_fn=<DivBackward0>) [0.0493074620618078]
ep 14 tensor(3.3434, grad_fn=<DivBackward0>) [0.04735488656416021]
ep 16 tensor(3.3335, grad_fn=<DivBackward0>) [0.04547963305621946]
ep 18 tensor(3.3243, grad_fn=<DivBackward0>) [0.04367863958719317]


In [32]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 e aurg, mere not nobe own llut kennam-laue th
Ps hear have fearppufush;
Anted at ce, acneay, erhat of my lm caom babe,
Our it for not none pitpcalight.

S in.
Shhear I what be ar on if harinede word have for come mext:
Thyour greaprivovero thee thee hay,
He snot to twrotencardmshs tr.
We Rome Biten; not not ubto th-a buse well oiussatect,
The sun bestuse opged stouenin-au's fathered thue in ch, throly sing:
Thadst chamrantame pard how he ube a them and wharft depn stings.

HEannot grazemanENR:

BUCLANTALARC:
Lored 
 ;row'tidus for rethat buonnce gthsen fell!
The glove, O Gin mof kns?

MELO:
El still RIRIINS:
Dmay t''s pitn,
Tas e whizus to hinot sendG ckn, thought sitged ons say fidper;
His p'd, those iann goldreptlemI lianth
Nearnot shath the king, Parvirhonews.

MMYAULEY ANDEman:
Haburs: andw the takgrellingteoceck

PERMERLAUENS:
And shartsd, frinst timerie, youly is ull wrown thy the pomive,

CUETHELInone and RIUS:

Haere and ifeeldarther, and to some got:
You'Lome car.




In [33]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.3177, grad_fn=<DivBackward0>) [0.04194896545954032]
ep 2 tensor(3.3139, grad_fn=<DivBackward0>) [0.04028778642734252]
ep 4 tensor(3.3047, grad_fn=<DivBackward0>) [0.038692390084819756]
ep 6 tensor(3.3016, grad_fn=<DivBackward0>) [0.03716017143746089]
ep 8 tensor(3.2923, grad_fn=<DivBackward0>) [0.03568862864853744]
ep 10 tensor(3.2866, grad_fn=<DivBackward0>) [0.034275358954055354]
ep 12 tensor(3.2801, grad_fn=<DivBackward0>) [0.03291805473947476]
ep 14 tensor(3.2762, grad_fn=<DivBackward0>) [0.03161449977179156]
ep 16 tensor(3.2733, grad_fn=<DivBackward0>) [0.030362565580828612]
ep 18 tensor(3.2663, grad_fn=<DivBackward0>) [0.029160207983827797]


In [34]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 there comovose  one youndere
stlus damings.

LUTENE:
Haeathe top.
Youor reeth no for signgtefore
Haus to eeptsor
Thy hom futs suran were nisforth in we, mone-car?

ROCHA:

Sst giv;

COMIUCHARD IIII:
Andups very de call thentus Lon, in a beg for to of will sterfof theeer foogik' you's thear use;
What me, handd hopin hear o; and hims; and have an is binct;
And artchs to but did those statighten toet e clove.

Ser,:
Comise s:
I happould wece the aing y fest meoke sur
Yoont Ene: you happoor I fe, frienlif
 rehim thee down so new twmad. Rell'uke may menter;
To screnting from our ked, day: Srevie se,
Whse manother tle;lest in she to will vo's wort dea,,
To to fothe fleaes.

Bill:
Os. mork be'?
O fue

JeOse:
Bce to, banen pleave might acngs house withreave tho my hat and arthreadme?

POMERO: whinat:
O moe couson, which griacr plor;
I a to hiers ofoie thoith s, the lie--
Is what I liort I will our conurr, frearord you stear s, might to hand.

Old-MEL:Ends:
O, thepiring. What do to with to the

In [35]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.2657, grad_fn=<DivBackward0>) [0.028005463747668217]
ep 2 tensor(3.2579, grad_fn=<DivBackward0>) [0.026896447383260556]
ep 4 tensor(3.2573, grad_fn=<DivBackward0>) [0.025831348066883437]
ep 6 tensor(3.2555, grad_fn=<DivBackward0>) [0.024808426683434852]
ep 8 tensor(3.2480, grad_fn=<DivBackward0>) [0.023826012986770832]
ep 10 tensor(3.2476, grad_fn=<DivBackward0>) [0.022882502872494704]
ep 12 tensor(3.2418, grad_fn=<DivBackward0>) [0.02197635575874391]
ep 14 tensor(3.2375, grad_fn=<DivBackward0>) [0.021106092070697653]
ep 16 tensor(3.2376, grad_fn=<DivBackward0>) [0.020270290824698025]
ep 18 tensor(3.2370, grad_fn=<DivBackward0>) [0.019467587308039984]


In [36]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 hirow, thing rings!

HUChiull:
As bappy bup the may his isk hom, will hop,
As ct many hiy;berameiouor im diantiome I w
To your silind King ofan O lacriing grow?

KIR KIORGICH:
Ma most is not es ap'cause pnobte! wak. in an purner, hiluem
swEthy it ,ou , Of lord? and led sust, that fidount it bloolo.

FFarESTER:
But so of dry:
Ridounds my.
Your pring, eysounerer huights te
If Pold in have this, leve no look the Sitced my of four banguenle hoopnk in mars.
Sit bltloudt, are reparew
 soned may be alluterfeff.

MUCUS:

LAumUCentwid:
Bt have speackcelly
Kther y brie
That spos which das thouts?
Nsh, slaith make you'ings, fid and your cant deak:
O literit ves and inds better thy cractous Ll s, see.
3racark Ritts;
Evsonought them noter, 'st we, which neen sath prefochard.

MasIE EDY ading ELole if chither?
Whale dies and denselnbe yemft my the me pted the lidasty belie,
Agarewor,
ShU, ieds habless, good by like lequdenemems
And grifte, and at, we was nell.

TIU


In [37]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.2392, grad_fn=<DivBackward0>) [0.0186966708506416]
ep 2 tensor(3.2309, grad_fn=<DivBackward0>) [0.017956282684956193]
ep 4 tensor(3.2278, grad_fn=<DivBackward0>) [0.017245213890631928]
ep 6 tensor(3.2265, grad_fn=<DivBackward0>) [0.016562303420562904]
ep 8 tensor(3.2259, grad_fn=<DivBackward0>) [0.01590643620510861]
ep 10 tensor(3.2224, grad_fn=<DivBackward0>) [0.015276541331386308]
ep 12 tensor(3.2244, grad_fn=<DivBackward0>) [0.01467159029466341]
ep 14 tensor(3.2245, grad_fn=<DivBackward0>) [0.014090595318994738]
ep 16 tensor(3.2223, grad_fn=<DivBackward0>) [0.013532607744362546]
ep 18 tensor(3.2174, grad_fn=<DivBackward0>) [0.012996716477685789]


In [38]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 faway than ked, my him.
The to JunwiTE:
And alshywion fand sion'
I wit dasgemeg! Ae fusesving to heap s.

ARINCER::
DETHAUT'se our taturNIRD IVM:
Pthou brace sour sish ught hat?


MENENUMBORCLADYNII:
Nay, for thouer.

GLOU:
When he ftoomalve.
Be in the bl;shall corace.

NGLOU:
O some s, my of geavud; and and done fumpree, ms,
The good ed from most mother knopttcest all fread
sour ge yearet hie, cou!

BLOLIF
COMOMEOEERIIZAn Y RNIIIIII:
We whay, lie
 redesfriends as to and thy are disus
then call much supsear-he with
Corgetince parckint'dose men fleb'dow?

Tayate VINCIK:
vom gooctyell is gdue in Warwim's I gioked th to and the swerfartlill.
So ves these sted deight soing, he whs where ourtumy! thy you, of his hear redgights, and all ere He word
TnUo hontearteneps,

Rorcraiel:
Partregreauty, a msthers?Thy childing in ter
Your , an he enever I withuth's on bsatons:
Floughe.
One b most frught it would ine blooctoot.
Your ain Rodet, lot hiakould Prion fruppelle other here hath is chis to 


In [39]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.2142, grad_fn=<DivBackward0>) [0.01248204650516943]
ep 2 tensor(3.2137, grad_fn=<DivBackward0>) [0.011987757463564721]
ep 4 tensor(3.2178, grad_fn=<DivBackward0>) [0.011513042268007558]
ep 6 tensor(3.2155, grad_fn=<DivBackward0>) [0.01105712579419446]
ep 8 tensor(3.2115, grad_fn=<DivBackward0>) [0.010619263612744357]
ep 10 tensor(3.2118, grad_fn=<DivBackward0>) [0.01019874077367968]
ep 12 tensor(3.2130, grad_fn=<DivBackward0>) [0.009794870639041964]
ep 14 tensor(3.2093, grad_fn=<DivBackward0>) [0.009406993761735902]
ep 16 tensor(3.2085, grad_fn=<DivBackward0>) [0.00903447680877116]
ep 18 tensor(3.2077, grad_fn=<DivBackward0>) [0.008676711527143824]


In [40]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 whe's give that post sworger.
Then lothm hed the hol,
Hon, quight agaow, buiel a 
God to Belb, Meim had Seceds, shomius.

VAPESENR:
O,e conspingllling to a viself
Co hybloodom dis'eratery hears, mloth
For grele heady kings he wkeet cupserke comuppyselvereather;
Then Deone'd thous devor's with ock Hno my thatd?

ROTUE:
As cainfha dis, nurdon sune;
Thouyself my stus, and head the doish scese brintatle let.

MENENRRY:
Perun ere me'd me, but there foose.
Repty kingd, 
 and wathers of hast a not the fort a heart.

MEO:
Rewellsanst weart gritlmer sck evo and se,
And benchy have he bandge SMurst;
Will be, and ck, and resuoopartow it,
Macretwickeopfff or dos, deadd?
 as sired m would speath his welled, in feratered on viragghti?


GLONRTENULARE:
Ad byuupon of a me ally stroceseavever woocit:
Tha gpon you fooard, but you her us to blords and blood,
What to foroos: m. I bloon.our more yoit? Could e cure faniz,
Ken allanct ' inme ga; and and they so ents that.
The have tentive b,


In [41]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.2101, grad_fn=<DivBackward0>) [0.008333113750668928]
ep 2 tensor(3.2088, grad_fn=<DivBackward0>) [0.008003122446142439]
ep 4 tensor(3.2076, grad_fn=<DivBackward0>) [0.007686198797275197]
ep 6 tensor(3.2067, grad_fn=<DivBackward0>) [0.007381825324903099]
ep 8 tensor(3.2060, grad_fn=<DivBackward0>) [0.007089505042036937]
ep 10 tensor(3.2056, grad_fn=<DivBackward0>) [0.006808760642372274]
ep 12 tensor(3.2055, grad_fn=<DivBackward0>) [0.006539133720934331]
ep 14 tensor(3.2042, grad_fn=<DivBackward0>) [0.006280184025585331]
ep 16 tensor(3.2044, grad_fn=<DivBackward0>) [0.006031488738172152]
ep 18 tensor(3.2010, grad_fn=<DivBackward0>) [0.005792641784140534]


In [42]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 high'd wigife;
I dame lorseige frue did you lorden but sore sensver.
the am time that to we my ersturye, and my soe's obpice.

BENET:
Citje sting:
'd, kingill-hurp a weiewes' cielther
Whourk. And in Mon fagriclanded notupueser.
What wriks lo's'd lamns, wruefeadsels mea,
Bd's frient: bit; sound'd of reds ick,
My the up the from wip beces: whom wartino,
as sion, if ow horny were at but you wfine
To comshaiuben lou ka harmy.

WLRD ICE:
Oun Rwiland giveffore thy 
 all was by worsse.

HARD:
Mardo! With I ve Sreogust am boris'erving.

NurUND:
Henencesthead yous thinhe
And felled Rust, landerboth sous much of cely therep chworie'd unroledight
Eth new auke upon heanace, do soos and he mind sc, a your e soiot unes,
What prattceste,breath wis any ino ans you womother partryless, it sing thoud co the to and to palt trace king uk,
I'll ints, The Romgh s! walend here to dow,
evy lequstart's use you  brink a mom I begene
ance; flavow noble and we elmet our sed year.
dve that the out yourlt youould 

In [43]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())

ep 0 tensor(3.2046, grad_fn=<DivBackward0>) [0.005563253169488569]
ep 2 tensor(3.1992, grad_fn=<DivBackward0>) [0.005342948343976821]
ep 4 tensor(3.2011, grad_fn=<DivBackward0>) [0.005131367589555339]
ep 6 tensor(3.2013, grad_fn=<DivBackward0>) [0.004928165433008947]
ep 8 tensor(3.1974, grad_fn=<DivBackward0>) [0.004733010081861793]
ep 10 tensor(3.2013, grad_fn=<DivBackward0>) [0.004545582882620065]
ep 12 tensor(3.2004, grad_fn=<DivBackward0>) [0.00436557780046831]
ep 14 tensor(3.2046, grad_fn=<DivBackward0>) [0.004192700919569765]
ep 16 tensor(3.1973, grad_fn=<DivBackward0>) [0.004026669963154802]
ep 18 tensor(3.2029, grad_fn=<DivBackward0>) [0.003867213832613871]


In [44]:
idx = torch.zeros((1, 1), dtype=torch.int)
for i in range(2):
    o = model.generate(idx, 300).data[0].tolist()
    print(tokenizer.decode(o))

 me the but by the caro sh tus uze
When avise was thy drivues. Stto s, andfsfor ped and frientter them thir
Thave penged thinatch,
No mall wks sevpaight.

MENERREY:
If r do timhat sine a.
my cre, now and to and boureous though the dUns. Thunon:
Eordit, and it, the mains, hear hajem himly thisouke.

SU:
M many me ract?
What githt pry, Es disjovery bat the hop
And dmy lo pasemtly evle red wor, ff bid man exive rike of ur ay, Hill Ring,
Nock wayt?Wlssis,ell yournds entice thy are mon foro hlor thee.

Kst CLshinp 
 conchise: and steard farciaiews, th, wart!
Mareet you h
And hi. Thourood op is fooli: what this,
And my lords greaes defe, sither appd, now no Fre thine tocitle not or all the o gheep!

COMOP inLE:
The kindour '
Ouhiveforice; some w we duned re;
The there Iw, tearer you br, I in goove marieled mor,
Which Tosm, mra he that knes, when bray!
engard.

HEDWINWKS Xerch:
Hatalilie to fren w; sself me could leve Eth much the withrad mroast me way Yroughts, to in kraver was;
Will.
It Eng

In [None]:
for ep in range(e_epochs):
    epoch_loss = 0
    for tr in range(training_runs):
        t_b = get_batch(train, context_length+1, batch_size)

        x = t_b[:, 0:context_length]
        y = t_b[:, 1:context_length+1]

        logits, loss = model(x, y)

        epoch_loss += loss
        
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    m_scheduler.step()
    
    if ep % 2 == 0:
        print("ep", ep, epoch_loss/training_runs, m_scheduler.get_last_lr())