In [1]:
!pip3 install torch
!pip3 install matplotlib



In [591]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [592]:
names_list = open('names.txt','r').read().split('\n')

In [593]:
print(names_list[:10])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']


In [594]:
char_to_int = {}
int_to_char = {}

char_to_int['.'] = 0
int_to_char[0] = '.'

for char_num in range(ord('a'),ord('z')+1):
    integer_representation = char_num-ord('a')+1
    char_to_int[chr(char_num)] = integer_representation
    int_to_char[integer_representation] = chr(char_num)


In [595]:
import random
random.seed(42)
random.shuffle(names_list)

In [596]:
#hyperparameters

embed_size = 4

state_size = 10

hidden_size = 25

vocab_size = 28

In [597]:
import math
char_vector_size = 2
input_contexts , input_labels = [[] for _ in range(5)], [[] for _ in range(5)]

for name in names_list:
    log_padded_length = math.ceil(math.log(len(name))/math.log(2))
    padded_length = int(2**log_padded_length)
    context = [char_to_int[char] for char in name]  + [0] + [vocab_size-1]*(padded_length-len(name)) 
    input_contexts[log_padded_length].append(context)
#   instead of using labels, we can compare the result to the character vectors to find a likely match

In [598]:
word = input_contexts[2][1]

In [599]:
# parameters
all_char_embeds = torch.randn((vocab_size,embed_size))

all_char_embeds[vocab_size-1] = 10**10

w_eh = torch.randn((embed_size, hidden_size)) * 0.5

w_sh = torch.randn((state_size, hidden_size)) * 0.5
w_so = torch.randn((state_size, embed_size)) *0.5

w_hs = torch.randn((hidden_size, state_size)) *0.5

b_h = torch.zeros((hidden_size,))
b_s = torch.zeros((state_size,))
b_o = torch.zeros((embed_size,))

init_state = torch.randn((state_size,)) * 0.5

network_params = [init_state,w_eh,w_sh,w_so,w_hs,b_h,b_s,b_o,all_char_embeds]

for p in network_params:
    p.requires_grad = True

In [600]:
def process_state_and_embed(states , embeds, hiddens, outputs, idx):
    hiddens[idx] = 2*torch.sigmoid((states[idx] @ w_sh) + (embeds[idx] @ w_eh) + b_h) -1
    states[idx+1] = hiddens[idx] @ w_hs + b_s
    outputs[idx] = states[idx] @ w_so + b_o

In [601]:
batch_size = 32

for i in range(len(input_contexts)):
    input_contexts[i] = torch.tensor(input_contexts[i])

In [617]:
loss_tot = 0

loops = 50000
check_every = 500

# training loop
for i in range(loops):
    log_idx = random.randint(1,4)
#     log_idx = 4
    batch_indices = torch.randint(0,input_contexts[log_idx].shape[0],(batch_size,))

    batch = input_contexts[log_idx][batch_indices]

    word_length = len(batch[0])

    embeds = torch.transpose(all_char_embeds[batch],0,1)

    states = [init_state] + [torch.zeros((batch_size,state_size,)) for _ in range(word_length-1)]

    hiddens = [torch.zeros((batch_size,embed_size,)) for _ in range(word_length-1)]

    outputs = torch.zeros((word_length,batch_size,embed_size))

    for j in range(word_length-1):
        process_state_and_embed(states,embeds ,hiddens,outputs,j)

    outputs[-1] = states[-1] @ w_so + b_o
    
    distances = torch.cdist(outputs,all_char_embeds[:-1]).sum(2,keepdim=True)
        
    filtered = torch.where(embeds != 1e10, ((outputs-embeds)**2 / distances), 0)
    
    loss = filtered.sum()
    
    loss.backward()
    
    for p in network_params:
        p.data -= 1e-4*p.grad
        p.grad = None
        
    loss_tot+= loss/ (batch_size*word_length)
    
    if i % check_every==0 and i>0:
        print(f"{i} of {loops}: " , loss_tot / (check_every))
        loss_tot = 0
#     print('o',outputs)
#     print('e',embeds)
#     print('wl',word_length)

500 of 50000:  tensor(0.0020, grad_fn=<DivBackward0>)
1000 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
1500 of 50000:  tensor(0.0020, grad_fn=<DivBackward0>)
2000 of 50000:  tensor(0.0020, grad_fn=<DivBackward0>)
2500 of 50000:  tensor(0.0020, grad_fn=<DivBackward0>)
3000 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
3500 of 50000:  tensor(0.0020, grad_fn=<DivBackward0>)
4000 of 50000:  tensor(0.0019, grad_fn=<DivBackward0>)
4500 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
5000 of 50000:  tensor(0.0019, grad_fn=<DivBackward0>)
5500 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
6000 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
6500 of 50000:  tensor(0.0017, grad_fn=<DivBackward0>)
7000 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
7500 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
8000 of 50000:  tensor(0.0016, grad_fn=<DivBackward0>)
8500 of 50000:  tensor(0.0017, grad_fn=<DivBackward0>)
9000 of 50000:  tensor(0.0018, grad_fn=<DivBackward0>)
9500 of 500

In [622]:
def predict(word):
    word_length = len(word)

    embeds = all_char_embeds[word]

    states = [init_state]+[torch.zeros((state_size,)) for _ in range(word_length-1)]

    hiddens = [torch.zeros((embed_size,)) for _ in range(word_length-1)]

    outputs = torch.zeros((word_length,embed_size))

    for i in range(word_length-1):
        process_state_and_embed(states,embeds ,hiddens,outputs,i)

    return states[-1] @ w_so + b_o


def get_next(text):
    word = [char_to_int[c] for c in text] +[0]

    pred = predict(word)

    pred = pred.view(1,embed_size)
    dists = torch.cdist(pred,all_char_embeds[:-1])
    dists = dists.view(-1)
    dists = torch.exp(-1000*dists)
    dists/=dists.sum()
    sampled = (torch.multinomial(dists,1)).item()
    return int_to_char[sampled]

In [None]:
start_st = 'j'

for _ in range(100):
    st = start_st
    curr = 'a'
    while curr!='.':
        curr = get_next(st)
        st+=curr
    print(st)


jd.
jo.


In [582]:
((all_char_embeds[char_to_int['u']]-all_char_embeds[char_to_int['d']])**2).sum()

tensor(0.0244, grad_fn=<SumBackward0>)

In [608]:
all_char_embeds


tensor([[-2.9664e-01, -6.0726e-02,  3.1658e-01, -7.4424e-02],
        [-2.4544e-01, -7.9831e-02,  2.8781e-01, -7.6645e-02],
        [-1.7123e-02, -7.6159e-02,  5.4805e-01,  2.1402e-01],
        [ 8.9724e-02, -8.4789e-04,  4.8525e-02, -1.9236e-01],
        [-2.0644e-01, -3.7802e-02,  4.2947e-01,  9.9446e-02],
        [-2.7876e-01, -9.1170e-02,  2.9630e-01, -7.1531e-02],
        [ 1.0480e+00, -2.4380e-01, -1.5023e+00, -8.5290e-01],
        [-2.3502e-01,  3.8689e-01,  8.7651e-01, -2.7298e-01],
        [-4.1980e-01, -6.3887e-02,  2.0647e-01, -8.4796e-02],
        [-2.9296e-01, -1.0149e-01,  3.0774e-01, -6.4426e-02],
        [ 4.1933e-02, -2.1295e-01,  1.5177e-01,  2.2886e-02],
        [-1.4921e-01, -7.5653e-02,  2.3434e-01, -3.9551e-01],
        [-2.7345e-01, -1.1771e-01,  3.0814e-01, -8.4988e-02],
        [-2.2740e-01, -1.3531e-01,  1.3570e-01, -1.5739e-01],
        [-3.0604e-01, -8.3162e-02,  3.2461e-01, -7.4180e-02],
        [-2.3206e-01, -1.6435e-01,  3.6302e-01, -2.4898e-02],
        