In [1]:
!pip3 install torch
!pip3 install matplotlib



In [673]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [674]:
names_list = open('names.txt','r').read().split('\n')

In [675]:
print(names_list[:10])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']


In [676]:
char_to_int = {}
int_to_char = {}

char_to_int['.'] = 0
int_to_char[0] = '.'

for char_num in range(ord('a'),ord('z')+1):
    integer_representation = char_num-ord('a')+1
    char_to_int[chr(char_num)] = integer_representation
    int_to_char[integer_representation] = chr(char_num)


In [677]:
import random
random.seed(42)
random.shuffle(names_list)

In [678]:
#hyperparameters

embed_size = 4

state_size = 10

hidden_size = 25

vocab_size = 28

In [679]:
import math
char_vector_size = 2
input_contexts , input_labels = [[] for _ in range(5)], [[] for _ in range(5)]

for name in names_list:
    log_padded_length = math.ceil(math.log(len(name))/math.log(2))
    padded_length = int(2**log_padded_length)
    context = [char_to_int[char] for char in name]  + [0] + [vocab_size-1]*(padded_length-len(name)) 
    input_contexts[log_padded_length].append(context)
#   instead of using labels, we can compare the result to the character vectors to find a likely match

In [680]:
word = input_contexts[2][1]

In [681]:
# parameters
all_char_embeds = torch.randn((vocab_size,embed_size))

all_char_embeds[vocab_size-1] = 10**10

w_eh = torch.randn((embed_size, hidden_size)) * 0.5

w_sh = torch.randn((state_size, hidden_size)) * 0.5
w_so = torch.randn((state_size, embed_size)) *0.5

w_hs = torch.randn((hidden_size, state_size)) *0.5

b_h = torch.zeros((hidden_size,))
b_s = torch.zeros((state_size,))
b_o = torch.zeros((embed_size,))

init_state = torch.randn((state_size,)) * 0.5

network_params = [init_state,w_eh,w_sh,w_so,w_hs,b_h,b_s,b_o,all_char_embeds]

for p in network_params:
    p.requires_grad = True

In [682]:
def process_state_and_embed(states , embeds, hiddens, outputs, idx):
    hiddens[idx] = 2*torch.sigmoid((states[idx] @ w_sh) + (embeds[idx] @ w_eh) + b_h) -1
    states[idx+1] = hiddens[idx] @ w_hs + b_s
    outputs[idx] = states[idx] @ w_so + b_o

In [683]:
batch_size = 32

for i in range(len(input_contexts)):
    input_contexts[i] = torch.tensor(input_contexts[i])

In [707]:
loss_tot = 0

loops = 50000
check_every = 500

# training loop
for i in range(loops):
    log_idx = random.randint(1,4)
#     log_idx = 4
    batch_indices = torch.randint(0,input_contexts[log_idx].shape[0],(batch_size,))

    batch = input_contexts[log_idx][batch_indices]

    word_length = len(batch[0])

    embeds = torch.transpose(all_char_embeds[batch],0,1)

    states = [init_state] + [torch.zeros((batch_size,state_size,)) for _ in range(word_length-1)]

    hiddens = [torch.zeros((batch_size,embed_size,)) for _ in range(word_length-1)]

    outputs = torch.zeros((word_length,batch_size,embed_size))

    for j in range(word_length-1):
        process_state_and_embed(states,embeds ,hiddens,outputs,j)

    outputs[-1] = states[-1] @ w_so + b_o
    
    distances = (torch.cdist(outputs,all_char_embeds[:-1]) ** 2).sum(2,keepdim=True)#is actually distances squared
        
    filtered = torch.where(embeds != 1e10, ((outputs-embeds)**2 / distances), 0)
    
    loss = filtered.sum()
    
    loss.backward()
    
    for p in network_params:
        p.data -= 2e-4*p.grad
        p.grad = None
        
    loss_tot+= loss/ (batch_size*word_length)
    
    if i % check_every==0 and i>0:
        print(f"{i} of {loops}: " , loss_tot / (check_every))
        loss_tot = 0
#     print('o',outputs)
#     print('e',embeds)
#     print('wl',word_length)

500 of 50000:  tensor(0.0100, grad_fn=<DivBackward0>)
1000 of 50000:  tensor(0.0102, grad_fn=<DivBackward0>)
1500 of 50000:  tensor(0.0099, grad_fn=<DivBackward0>)
2000 of 50000:  tensor(0.0102, grad_fn=<DivBackward0>)
2500 of 50000:  tensor(0.0098, grad_fn=<DivBackward0>)
3000 of 50000:  tensor(0.0099, grad_fn=<DivBackward0>)
3500 of 50000:  tensor(0.0099, grad_fn=<DivBackward0>)
4000 of 50000:  tensor(0.0097, grad_fn=<DivBackward0>)
4500 of 50000:  tensor(0.0097, grad_fn=<DivBackward0>)
5000 of 50000:  tensor(0.0096, grad_fn=<DivBackward0>)
5500 of 50000:  tensor(0.0098, grad_fn=<DivBackward0>)
6000 of 50000:  tensor(0.0099, grad_fn=<DivBackward0>)
6500 of 50000:  tensor(0.0098, grad_fn=<DivBackward0>)
7000 of 50000:  tensor(0.0095, grad_fn=<DivBackward0>)
7500 of 50000:  tensor(0.0092, grad_fn=<DivBackward0>)
8000 of 50000:  tensor(0.0095, grad_fn=<DivBackward0>)
8500 of 50000:  tensor(0.0093, grad_fn=<DivBackward0>)
9000 of 50000:  tensor(0.0096, grad_fn=<DivBackward0>)
9500 of 500

In [692]:
def predict(word):
    word_length = len(word)

    embeds = all_char_embeds[word]

    states = [init_state]+[torch.zeros((state_size,)) for _ in range(word_length-1)]

    hiddens = [torch.zeros((embed_size,)) for _ in range(word_length-1)]

    outputs = torch.zeros((word_length,embed_size))

    for i in range(word_length-1):
        process_state_and_embed(states,embeds ,hiddens,outputs,i)

    return states[-1] @ w_so + b_o


def get_next(text):
    word = [char_to_int[c] for c in text] +[0]

    pred = predict(word)

    pred = pred.view(1,embed_size)
    dists = torch.cdist(pred,all_char_embeds[:-1])
    dists = dists.view(-1)
    dists = torch.exp(-5*dists)
    dists/=dists.sum()
    sampled = (torch.multinomial(dists,1)).item()
    return int_to_char[sampled]

In [708]:
start_st = ''

for _ in range(100):
    st = start_st
    curr = 'a'
    while curr!='.':
        curr = get_next(st)
        st+=curr
    print(st)


vne.
en.
mny.
teyyaelynaser.
deterny.
mrl.
ras.
mienrltroll.
ie.
sd.
errynil.
milntta.
iasllnt.
moienadtlnlaana.
cni.
mylltoinralalisllnlh.
sodl.
eeleeaeray.
naeayel.
ieeeairte.
sa.
yaloeeeeieareers.
aaini.
ta.
o.
te.
lnryrilealesny.
mahtttul.
a.
.
tenelloiyiieoi.
cennnnienynerirt.
sa.
hoteaalaaeairterierirlteilyylteraai.
tidniha.
jrailh.
tier.
laena.
raiea.
mnoea.
tooetaeilialiimto.
mn.
vilaenlirieenoetaaaeeevet.
aol.
oedhvelrnle.
sooleueielln.
dlteoamasteaooi.
mle.
menasityieaionaeltiry.
sn.
mio.
teiarnelimslneleosayto.
sorinil.
dteaelryio.
toy.
neliyeoereelaoni.
u.
a.
naal.
tiaonahnarulnranoomarta.
serlomniaalsoemdla.
teovvti.
mmersd.
oairlaonnaehiaaoiten.
seet.
tryleen.
aaieoiaayoaelau.
menteai.
coye.
teoeiaa.
rilenrtod.
rs.
tymr.
cneronoaaaede.
lroonayrneove.
a.
tolitaaatataaealse.
lidliedaynr.
neanaa.
vot.
l.
cynerl.
yoasalnninel.
tni.
mi.
tninlea.
deivtusl.
a.
roar.
ale.
teoi.
aymn.
eon.
etyealeaanm.
yeneiaeahian.
nmelo.
.
mynalas.
oeieoiryailylaaeearea.
mrrnieylnoliorreoolmeael

In [713]:
((all_char_embeds[char_to_int['m']]-all_char_embeds[char_to_int['j']])**2).sum()

tensor(2.6450, grad_fn=<SumBackward0>)

In [650]:
all_char_embeds.data [:-1] -= torch.mean(all_char_embeds[:-1],0)
all_char_embeds.data [:-1] /= torch.std(all_char_embeds[:-1],0)


In [649]:
torch.mean(all_char_embeds[:-1],0)

tensor([-0.8800,  0.0131,  1.5467, -0.6799], grad_fn=<MeanBackward1>)

In [714]:
all_char_embeds

tensor([[ 4.5012e-01, -2.4450e-02,  3.9393e-01, -1.2016e-01],
        [ 4.5529e-01, -1.7121e-03,  3.3084e-01, -1.9254e-01],
        [ 9.5880e-01,  1.1683e+00,  4.5491e-01, -2.1356e-01],
        [-2.2185e-01,  2.5530e-01, -2.6817e-01, -7.5573e-01],
        [ 8.6706e-01,  1.7468e-01,  4.4619e-01, -1.9205e-02],
        [ 4.6325e-01, -2.6519e-02,  3.8783e-01, -1.5155e-01],
        [-1.5893e+00,  5.1317e-01, -2.1757e+00,  8.5860e-02],
        [-9.4863e-01, -5.7473e-01, -1.8380e+00, -1.5821e+00],
        [ 4.0778e-01, -2.8972e-01,  7.4186e-01,  4.2629e-02],
        [ 4.4010e-01, -2.5068e-02,  4.3573e-01, -1.1926e-01],
        [ 7.2413e-01,  1.6931e-01, -1.0461e+00, -1.2130e+00],
        [-5.6255e-01, -1.2972e-01, -7.5269e-01, -7.6305e-01],
        [ 4.5623e-01,  1.9689e-02,  3.4889e-01, -7.5090e-02],
        [ 1.5951e-01,  1.3316e-01,  2.0722e-02, -1.2357e-01],
        [ 4.3361e-01, -4.6162e-02,  3.9439e-01, -9.1754e-02],
        [ 5.9844e-01, -1.4042e-01,  4.0913e-01, -1.2199e-01],
        