In [1]:
import jax
import optax
import flax
from flax.training.train_state import TrainState

from functools import partial
from tqdm.auto import tqdm
from datasets import TextDataset
import models
import training

In [2]:
rng_key = jax.random.key(0)
context_len = 8
batch_size = 32

dataset = TextDataset(data_path="shakespeare.txt")
model = models.TransormerLM(
    vocab_size=len(dataset.tokenizer.vocab),
    max_context_len=context_len,
    embedding_dim=64,
    head_size = 16,
    n_heads = 4,
    n_layers = 2,
)

In [3]:
optimization_step = jax.jit(partial(training.optimization_step, loss_fn=training.logit_prediction_loss))
get_batch = jax.jit(partial(dataset.get_batch, batch_size=batch_size, context_len=context_len))
generate_token = jax.jit(partial(model.apply, method=model.generate_token))
def generate_text(params, prompt: str, length=500, rng_key=jax.random.key(0)):
    context = dataset.tokenizer.encode(prompt)
    print("\033[94m", dataset.tokenizer.decode(context), "\033[0m", end="")
    for sub_rng in jax.random.split(rng_key, length):
        next_token, context = generate_token(params, context, sub_rng)
        print(dataset.tokenizer.decode(next_token[None]), end="")

losses = []
train_state = TrainState.create(
    apply_fn=model.apply,
    params=model.init(rng_key, dataset.sample(context_len, rng_key)),
    tx=optax.adam(3e-4),
)
for epoch_rng_key in tqdm(jax.random.split(rng_key, 10)):
    for batch_rng_key in tqdm(jax.random.split(epoch_rng_key, 10000), leave=False):
        x, y = get_batch(rng_key=batch_rng_key)
        train_state, loss_value = optimization_step(train_state, x, y)
        losses.append(loss_value)
    print(f"Loss: {sum(losses) / len(losses)}\nGeneration test: ")
    generate_text(train_state.params, prompt="To be or", rng_key=epoch_rng_key)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.873110055923462
Generation test: 
[94m To be or [0mcucte:

I ous, wy tomor praus oawe, au, Baouho shed n sas! utenNg erosl os h! ncoshk tir touhy k sovisor o se fid s; wecaeowarolcl it anbyoe
Nd cangome wche, se Id bans I:
Y idonoe Vghit ongh ile p
RLIE
:
Ns'y:
Ty,
Tas? I d bsredtovel we o me ifleubls t Igiury ly inopid Itonoteuls, t tis.

MLIOjachoh deheamano minpeiver jhororosiu tin o tor is.


Wluur ot thhon fongy thoita.
Fls Tetat,
I lone, iumo bond,
Teorowram
Kmesio I rilt'Kon t tiyiof k, S donereorocmu ytacyos
Nufhhe soungor Hgthee Fhee nU

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.728705883026123
Generation test: 
[94m To be or [0m rhive wiu! Mi?


M:
E.

RO: fhdto; owan ed se yisovo
TE heeios;
Wathis am honey nafrme gek!,
CAd IEUSS:
This nf, us hou tent kou jheace; nel:
If oises.

NEd shemy chenive foy hiowi by,
I tapasu gore?

POr
WA Koffchecir, sher ththid procolume thiuvo nevhe gasin o Tariy mory, ta.
TUStrst at terpruur:
Holas mer bohit?

F ATI:
Le-TAve at:
Doret, wiworisu yand mighe osvon an sone
Thit.

ITI: thosll de doit tar:
Thiwd.
:
Phomols dhan:
I de-d:
Thosh th saro bpterto
NYdbt; Psotonks desve.

Ag rld cthek

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.6633145809173584
Generation test: 
[94m To be or [0msd a cosice, yit noh tins'f howin o frint fhind ankatiud, ne?

Ddamcit pesomt megiro bet micoh that.

IQUO Thachlal amas ate Idrin thin etal iwy'g lel thif Ccans thet.

SANEl Se ELWAmd d cteun thas skrlar berr't.

P, IRIUR BIOCOSwf Ieoed lunl't hhor; boig riRds thiim amgisat sal.

Rpsessk ourlas dorees gvirr Tor seired cleleu b,
Rd II hohel irin Igor Had Igivilit bheet son iny thhain:
Nir eshap hhad lol, mld isan p if:
Tof ukdtoue Napcek'y ohe bt metre pmsche gowe wessn bwand'b atee
Tas. Gte br!

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.626544713973999
Generation test: 
[94m To be or [0m't beuk thour athe:
Ekend fam houry te iread bof, olan I ndorrdrr lhy I uitim oung pyhe tinrils.

A Sareve anert nirs
O E

Votord,
MAn, ghor inasny dy, nakt saf norann lyyare wit;
Vel, nrdeny araskr firev.

Or; anar peeor hik ias:
wom yineviy land ghowdr ad whimdarse.

ALOBRHHOFUCLEAW
Ther maf har,
Fd son,
Ddaret me anir Ine.

OBwe Lo ighit cyip morg Hid net ivedfeitw habr' seturl inspse ye patig br.

Toulat Kheet ns henend itunt
rcas the;
Asrd cas se; Cites tranou rind:
Af it haf af sarestasug,

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.6018950939178467
Generation test: 
[94m To be or [0m,
Gexteas, Whild sho aly ta fum hasxco?

MAREEOTAT:
QETIUSIDUKAS:
uunat mit amilt sor livee yanetr awae tils den choly.

K it oriron Ide anvot we asr blas wout dak sessl fuup, u'de tinm the I The psus gole forblr ters ireicas p horekou-mothechs mib rannch il;
TINaunr Inewar ne hane necpuyeerr we rhes is riricy e iwen Wot un; olar mdatt phonve,
whou sorlr:
Shical min, o orimo se athirned wid, Id netherdis, ugan Gnart found thilt saf tht,
Molo ave lopeif jony or! flbdowr:
Heow fodt iwy, lonis erg 

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.5832927227020264
Generation test: 
[94m To be or [0mr tomd the Same whe serr dh, POC:
Whad heted bondd delll wole Marl
Yow'r lan:
Tand woltetrar, Btuot ometu:
Frn haeand, cltu the natatwe le th blethot, morcosl kausr!
BASIO
HTTA
Pyen oner liethras, arng ilthit, f dhegsiters akd he, seohe o w igilel--wetle; s vorsyres mha sh iw ROBA
MA chohe thim inond;
rin wey horploar do thetsrerd theld shar:
Yerald theur t athy he der, bis,
Get fotu simdagor'g asr,
Hor sin s linl bth b couvey datung igt; lag hom thisend arar loucanntrin marree Lou fopksar ro sh

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.5682687759399414
Generation test: 
[94m To be or [0m.

:
Lin has ceolidd tage
E ees
Yimemy tho pe nidr haged I
LUNORCITI:
Shol as thig hot lo brth, yrrpeviu eladee, toasat eee ikee thorcise-ndeghe lortism he chive bif.

Mrloos cwe ivot olo
Ton, Chdirsla'eun moe mafte de suveltu ytafas, voprime yo no rigbk onsis hos asss su quhho isy ith the du ne thoromicree fundtheof, sil boun theet ore leren pare g ealfee no whe, hin im sat, mabr thet we
YUIUNDUCUONINUORR:
OROWHO
SALIDOI:
Eeonangise oul, I! ingn lte mat been DISW:
Whera Lht sesus ilas pdid bil:

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.555424690246582
Generation test: 
[94m To be or [0mlmove thy toulamit, E'ld athes ite she the weaut?

BIUSAME:
A-
Fou phounsou mor sthond eous,
Chd tut. A?

By Cosreftlel'y
Mald.

Aromim's we ydo
Anll tus wapd to the pepef wolerthist, adcon his.

IRZTAR:
Hoveremitof purdaco hicer rost horn wou te:
Delred, se saaven vebande hehe.


KEThinlmos moon
Whae?

:
Yas thamre' the ous ond yelk'the this tof holaton the neum'y re llo horf aes wis lour Esar Ism sars af whol iw the yer g aror hitu faor we'sarcue lycee'd the thapred akey tho the hio whijagve p

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.54449725151062
Generation test: 
[94m To be or [0mr selet lhito
DIOnd,
I ro poeng; he gharomy, yas at turpveleet
Whived som fesal otesle,
Pokt:
Dotd hepee ptu, then ne yoy youiny
wow ofr o Bur, Rrith ethet:
Wheg wow pey shlin: That hee lild ane:
Whet mon hoh der thimlar'd the brar spe ate micos the
The thoweese theawslitr
wen thoven at izon far as row au sef qnhe herto tur ade.
Themethef mich;
GAt ley ane llres thed fhohr it filt apm blieste E'R thesecatit the
E a theste.

Wiflt me niss bd kysuund, vis
Ditaslcw.

YLENIDIN ILINHLOD:
Th goune deo

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 2.5349130630493164
Generation test: 
[94m To be or [0me:
Jorsas do dy has Yot tetpash anlte sure wegaf, fersislotes the nemerad rave was lol Um asthe thalrile the benes sran qucrhome, on the for lonssh
Ining aret ae thhoo hot met on O KINONTA:
wit Ovorele, ve mifkat that ow awror tow pwt.

Mcr hiilthithe wor yout ott ofkk to he? Lal thofmt ye hoth;
Rergoung thargof,
Kof tisoly elrs tu moveas kes, thin'd tas, tums the thopilolo apvan wor sid bur.

Wher liprad toelo?
Thoy silas is!


Thalgord os the con mo mow nen whins owgot lo sonsnn and od euld al