In [1]:
import tensorflow as tf
import tensorflow_probability as tfp

In [2]:
words = open('names.txt', 'r').read().splitlines()

In [3]:
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [4]:
print(f'min = {min(len(w) for w in words)}, max = {max(len(w) for w in words)}')

min = 2, max = 15


In [5]:
alphabet = sorted(list(set(''.join(words))))

stoi = {s: i for i, s in enumerate(('.', *alphabet))}
itos = {i: s for s, i in stoi.items()}
num_classes = len(stoi)

In [6]:
xs, ys = [], []

for w in words:
    chars = ['.'] + list(w) + ['.']
    for c_curr, c_next in zip(chars, chars[1:]):
        #print(c_curr, c_next)
        xs.append(stoi[c_curr])
        ys.append(stoi[c_next])

xs = tf.convert_to_tensor(xs)
ys = tf.convert_to_tensor(ys)


In [7]:
xenc = tf.one_hot(xs, num_classes)
yidx = tf.stack([tf.range(ys.shape[0]), ys], axis=1)

In [8]:
W = tf.Variable(tf.random.normal((num_classes, num_classes), seed=1), name='W', trainable=True)

In [9]:
learning_rate = 20
reg_strength = 0.1

for _ in range(100):

    with tf.GradientTape() as tape:
        logits = xenc @ W
        counts = tf.math.exp(logits)
        probs = counts / tf.math.reduce_sum(counts, axis=1, keepdims=True)
        loss = tf.math.reduce_mean(-tf.math.log(tf.gather_nd(probs, yidx))) + reg_strength * tf.math.reduce_mean(W**2)

    print(loss.numpy())

    dl_dw = tape.gradient(loss, W)
    W.assign_sub(learning_rate*dl_dw)


3.930746
3.7550907
3.6061234
3.479786
3.3742385
3.287305
3.2157376
3.1558714
3.1046226
3.059876
3.0203128
2.9851034
2.9536793
2.9256084
2.9005327
2.8781376
2.8581357
2.8402605
2.8242633
2.8099148
2.7970026
2.7853389
2.7747555
2.765109
2.7562754
2.7481503
2.740647
2.7336912
2.7272213
2.7211847
2.7155385
2.7102437
2.7052684
2.700584
2.6961665
2.6919944
2.6880481
2.6843116
2.6807685
2.6774068
2.6742122
2.6711748
2.6682842
2.665531
2.6629066
2.6604028
2.658013
2.65573
2.6535475
2.6514597
2.6494613
2.6475472
2.6457129
2.6439536
2.6422656
2.6406448
2.6390882
2.6375916
2.636152
2.6347673
2.633434
2.6321497
2.630912
2.6297188
2.6285677
2.6274567
2.6263838
2.6253479
2.6243463
2.6233783
2.622442
2.621536
2.6206596
2.6198103
2.6189883
2.6181917
2.6174192
2.6166708
2.6159446
2.6152406
2.6145573
2.6138935
2.6132495
2.612624
2.6120164
2.6114256
2.6108518
2.610294
2.6097517
2.6092243
2.608711
2.6082115
2.6077259
2.6072526
2.6067927
2.6063442
2.6059077
2.6054823
2.6050684
2.6046648


In [10]:
dist = tfp.distributions.Multinomial(1, probs=[0.2, 0.5, 0.3])
dist.sample()

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 1., 0.], dtype=float32)>

In [11]:
# Generate in serial

n_gen = 10
for _ in range(n_gen):
    ienc = tf.one_hot([0], num_classes)
    out = []
    while True:
        logits = ienc @ W
        counts = tf.math.exp(logits)
        probs = counts / tf.math.reduce_sum(counts, axis=1, keepdims=True)

        ienc = tfp.distributions.Multinomial(1, probs=probs).sample()
        i = tf.math.argmax(tf.squeeze(ienc))

        out.append(itos[i.numpy()])

        if i == 0:
            break

    print(''.join(out))


monnta.
eyaya.
evkeur.
te.
ke.
re.
anrerofrimavy.
oned.
jbriacxbbqgidyubin.
kyramiele.


In [15]:
# Generate in parallel, has a bias towards generating shorter examples

n_parallel = 5
n_gen = 20
gen_count = 0

ienc = tf.one_hot([0]*n_parallel, num_classes)
out = [[] for _ in range(n_parallel)]

while gen_count < n_gen:
    logits = ienc @ W
    counts = tf.math.exp(logits)
    probs = counts / tf.math.reduce_sum(counts, axis=1, keepdims=True)

    ienc = tfp.distributions.Multinomial(1, probs=probs).sample()
    iz =  tf.argmax(ienc, axis=1)

    for o, i in zip(out, iz):
        o.append(itos[i.numpy()])

        if i == 0:
            print(''.join(o))
            o.clear()
            gen_count += 1
            


truan.
seame.
tzbzjhan.
zit.
a.
awbela.
kat.
i.
cqxz.
ceryphiwlinylleum.
s.
lieyin.
erama.
sheebekanclelsesadera.
kgwh.
ay.
w.
evbety.
uta.
kqxa.
