In [134]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

In [135]:
words = open('names.txt', 'r').read().splitlines()

In [136]:
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [137]:
print(f'min = {min(len(w) for w in words)}, max = {max(len(w) for w in words)}')

min = 2, max = 15


In [138]:
alphabet = sorted(list(set(''.join(words))))

stoi = {s: i for i, s in enumerate(('.', *alphabet))}
itos = {i: s for s, i in stoi.items()}
num_classes = len(stoi)

In [139]:
xs, ys = [], []

for w in words:
    chars = ['.'] + list(w) + ['.']
    for c_curr, c_next in zip(chars, chars[1:]):
        #print(c_curr, c_next)
        xs.append(stoi[c_curr])
        ys.append(stoi[c_next])

xs = tf.convert_to_tensor(xs)
ys = tf.convert_to_tensor(ys)


In [140]:
xenc = tf.one_hot(xs, num_classes)
yidx = tf.stack([tf.range(ys.shape[0]), ys], axis=1)

In [141]:
W = tf.Variable(tf.random.normal((num_classes, num_classes), seed=1), name='W', trainable=True)

In [142]:
learning_rate = 20
reg_strength = 0.1

for _ in range(100):

    with tf.GradientTape() as tape:
        logits = xenc @ W
        counts = tf.math.exp(logits)
        probs = counts / tf.math.reduce_sum(counts, axis=1, keepdims=True)
        loss = tf.math.reduce_mean(-tf.math.log(tf.gather_nd(probs, yidx))) + reg_strength * tf.math.reduce_mean(W**2)

    print(loss.numpy())

    dl_dw = tape.gradient(loss, W)
    W.assign_sub(learning_rate*dl_dw)


3.9211085
3.7454646
3.599487
3.4770389
3.3746512
3.2885842
3.2153761
3.1525435
3.0984082
3.0516932
3.0113091
2.976297
2.945819
2.9191573
2.8957086
2.8749702
2.8565285
2.8400426
2.8252325
2.8118653
2.7997484
2.78872
2.778644
2.7694056
2.7609065
2.7530627
2.7458024
2.7390623
2.7327886
2.7269332
2.7214553
2.7163174
2.7114887
2.7069404
2.7026472
2.6985881
2.6947432
2.6910946
2.6876273
2.684328
2.681184
2.678184
2.6753182
2.6725774
2.6699538
2.66744
2.6650286
2.662714
2.6604905
2.6583529
2.6562965
2.6543171
2.65241
2.6505725
2.6488001
2.6470907
2.64544
2.6438465
2.6423068
2.6408186
2.63938
2.6379886
2.6366422
2.6353393
2.6340775
2.6328564
2.6316726
2.6305258
2.6294146
2.628337
2.6272924
2.6262789
2.625296
2.6243422
2.6234164
2.622518
2.6216452
2.620798
2.6199749
2.6191754
2.6183987
2.6176436
2.61691
2.6161964
2.6155028
2.6148288
2.6141727
2.6135345
2.6129136
2.6123092
2.6117215
2.6111488
2.6105917
2.6100495
2.6095214
2.6090074
2.6085064
2.608018
2.6075425
2.6070793


In [143]:
dist = tfp.distributions.Multinomial(1, probs=[0.2, 0.5, 0.3])
dist.sample()

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 1., 0.], dtype=float32)>

In [144]:
# Generate in serial

n_gen = 10
for _ in range(n_gen):
    ienc = tf.one_hot([0], num_classes)
    out = []
    while True:
        logits = ienc @ W
        counts = tf.math.exp(logits)
        probs = counts / tf.math.reduce_sum(counts, axis=1, keepdims=True)

        ienc = tfp.distributions.Multinomial(1, probs=probs).sample()
        i = tf.math.argmax(tf.squeeze(ienc))

        out.append(itos[i.numpy()])

        if i == 0:
            break

    print(''.join(out))


cmimaya.
za.
el.
d.
rbreka.
mzadeja.
kejantyn.
jajoneixindailinda.
laulaesjgh.
dggowuumaqapeeianiyayn.


In [147]:
# Generate in parallel, has a bias towards generating shorter examples

n_parallel = 10
n_gen = 10
gen_count = 0

ienc = tf.one_hot([0]*n_parallel, num_classes)
out = [[] for _ in range(n_parallel)]

while gen_count < n_gen:
    logits = ienc @ W
    counts = tf.math.exp(logits)
    probs = counts / tf.math.reduce_sum(counts, axis=1, keepdims=True)

    ienc = tfp.distributions.Multinomial(1, probs=probs).sample()
    iz =  tf.argmax(ienc, axis=1)

    for o, i in zip(out, iz):
        if i == 0:
            print(''.join(o))
            o.clear()
            gen_count += 1
        else:
            o.append(itos[i.numpy()])


m
pn
e
man
lama
xili
jonnav
ie
gie
drte
trjetare
kessgonch
ben
ommir
elon
kanl
dozala
manr
caayguorealbrary
deq
cerikkwdetelecen
kanabrynilere
mamriazrykusai
malfx
ieleeri
dige
ensn
yn
ylvibolltizah
loman
dbeliyy
jryaritlmfon
s
arcaxxy
jbeleewopsar
seolor
zmararikcus
srslaamgburosobory
asianynni
kxgribeda
ze
brh
kerlan
os
aszxx
k
dwvdqle
we
dechistonn
kejacbolarame
ralesrbch
a
ekahamaheohiha
in
tailla
ke
losusgh
autxi
lezafwtoninnn
cariebetian
cayavariaamquacimy
a
friou
rorma
canikah
eleisiylia
mixthyai
vharai
ta
ziarn
tkenyadimindjayvielydenebera
fall
may
ma
r
kg
maidrize
llionkamausala
ana
lenuhay
zavaorblio
bxli
a
s
xgmasel
jionikila
mone
cssonil
dan
jasiktie
mi
imadaroxh
midj
sapiyobo
jbxm
gvkeniy
nen
nauderelann
jrfxelaxxuma
lainiligshliele
ina
