In [1]:
!pip3 install torch
!pip3 install matplotlib

Collecting torch
  Downloading torch-2.0.1-cp311-none-macosx_11_0_arm64.whl (55.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.8/55.8 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hCollecting filelock
  Using cached filelock-3.12.2-py3-none-any.whl (10 kB)
Collecting typing-extensions
  Using cached typing_extensions-4.7.1-py3-none-any.whl (33 kB)
Collecting sympy
  Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
Collecting networkx
  Using cached networkx-3.1-py3-none-any.whl (2.1 MB)
Collecting mpmath>=0.19
  Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
Installing collected packages: mpmath, typing-extensions, sympy, networkx, filelock, torch
Successfully installed filelock-3.12.2 mpmath-1.3.0 networkx-3.1 sympy-1.12 torch-2.0.1 typing-extensions-4.7.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;3

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
names_list = open('names.txt','r').read().split('\n')

In [4]:
print(names_list[:10])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']


In [5]:
char_to_int = {}
int_to_char = {}

char_to_int['.'] = 0
int_to_char[0] = '.'

for char_num in range(ord('a'),ord('z')+1):
    integer_representation = char_num-ord('a')+1
    char_to_int[chr(char_num)] = integer_representation
    int_to_char[integer_representation] = chr(char_num)


In [6]:
import random
random.seed(42)
random.shuffle(names_list)

In [30]:
#hyperparameters

embed_size = 6

state_size = 15

hidden_size = 30

vocab_size = 28

In [83]:
import math
char_vector_size = 2
input_contexts , input_labels = [[] for _ in range(5)], [[] for _ in range(5)]

for name in names_list:
    log_padded_length = math.ceil(math.log(len(name))/math.log(2))
    padded_length = int(2**log_padded_length)
    context = [char_to_int[char] for char in name]  + [0] + [vocab_size-1]*(padded_length-len(name)) 
    input_contexts[log_padded_length].append(context)
#   instead of using labels, we can compare the result to the character vectors to find a likely match

In [84]:
freqs = [0 for i in range(vocab_size-1)]

for log_len in range(1,5):
    contexts = input_contexts[log_len]
    for c in contexts:
        for int_char in c:
            if int_char<vocab_size-1:
                freqs[int_char] += 1
                
freqs = torch.tensor(freqs+[0],dtype=torch.float64)
freqs/=freqs.sum()

freqs

tensor([0.1404, 0.1485, 0.0116, 0.0155, 0.0241, 0.0895, 0.0040, 0.0084, 0.0334,
        0.0776, 0.0127, 0.0221, 0.0612, 0.0291, 0.0803, 0.0348, 0.0045, 0.0012,
        0.0557, 0.0355, 0.0244, 0.0137, 0.0113, 0.0041, 0.0031, 0.0428, 0.0105,
        0.0000], dtype=torch.float64)

In [85]:
word = input_contexts[2][1]

In [86]:
# parameters
all_char_embeds = torch.randn((vocab_size,embed_size))

all_char_embeds[vocab_size-1] = 10**10

w_eh = torch.randn((embed_size, hidden_size)) * 0.5

w_sh = torch.randn((state_size, hidden_size)) * 0.5
w_so = torch.randn((state_size, embed_size)) *0.5

w_hs = torch.randn((hidden_size, state_size)) *0.5

b_h = torch.zeros((hidden_size,))
b_s = torch.zeros((state_size,))
b_o = torch.zeros((embed_size,))

init_state = torch.randn((state_size,)) * 0.5

network_params = [init_state,w_eh,w_sh,w_so,w_hs,b_h,b_s,b_o,all_char_embeds]

for p in network_params:
    p.requires_grad = True

In [87]:
def process_state_and_embed(states , embeds, hiddens, outputs, idx):
    hiddens[idx] = 2*torch.sigmoid((states[idx] @ w_sh) + (embeds[idx] @ w_eh) + b_h) -1
    states[idx+1] = hiddens[idx] @ w_hs + b_s
    outputs[idx] = states[idx] @ w_so + b_o

In [88]:
batch_size = 32

for i in range(len(input_contexts)):
    input_contexts[i] = torch.tensor(input_contexts[i])

In [97]:
loss_tot = 0

loops = 100000
check_every = 500

# training loop
for i in range(loops):
    log_idx = random.randint(1,4)
#     log_idx = 4
    batch_indices = torch.randint(0,input_contexts[log_idx].shape[0],(batch_size,))

    batch = input_contexts[log_idx][batch_indices]

    word_length = len(batch[0])

    embeds = torch.transpose(all_char_embeds[batch],0,1)
    
    factors = torch.where(freqs > 0,1/(freqs+1e-18),0)
    
    factors[:-1] /= factors[:-1].mean()
    
    adjustment_factors = torch.transpose(factors[batch],0,1)
    
    adjustment_factors = adjustment_factors.view((word_length,batch_size,1))
    
    adjustment_factors[:-1] /= adjustment_factors[:-1].mean()
#     print(adjustment_factors , embeds.shape)

    states = [init_state] + [torch.zeros((batch_size,state_size,)) for _ in range(word_length-1)]

    hiddens = [torch.zeros((batch_size,embed_size,)) for _ in range(word_length-1)]

    outputs = torch.zeros((word_length,batch_size,embed_size))

    for j in range(word_length-1):
        process_state_and_embed(states,embeds ,hiddens,outputs,j)

    outputs[-1] = states[-1] @ w_so + b_o
    
    distances = (torch.cdist(outputs,all_char_embeds[:-1]) ** 2).sum(2,keepdim=True)#is actually distances squared
        
    filtered = torch.where(embeds != 1e10, ((outputs-embeds)**2 / distances), 0)
    
    ones = torch.where(embeds != 1e10, 1, 0)
    
    loss = (filtered*adjustment_factors).sum()
    
    loss.backward()
    
    for p in network_params:
        p.data -= 1e-3*p.grad
        p.grad = None
        
    loss_tot+= loss/ adjustment_factors.sum()
    if i % check_every==0 and i>0:
        print(f"{i} of {loops}: " , loss_tot / (check_every))
        loss_tot = 0
#     print('o',outputs)
#     print('e',embeds)
#     print('wl',word_length)

500 of 100000:  tensor(0.0152, dtype=torch.float64, grad_fn=<DivBackward0>)
1000 of 100000:  tensor(0.0152, dtype=torch.float64, grad_fn=<DivBackward0>)
1500 of 100000:  tensor(0.0151, dtype=torch.float64, grad_fn=<DivBackward0>)
2000 of 100000:  tensor(0.0149, dtype=torch.float64, grad_fn=<DivBackward0>)
2500 of 100000:  tensor(0.0151, dtype=torch.float64, grad_fn=<DivBackward0>)
3000 of 100000:  tensor(0.0153, dtype=torch.float64, grad_fn=<DivBackward0>)
3500 of 100000:  tensor(0.0151, dtype=torch.float64, grad_fn=<DivBackward0>)
4000 of 100000:  tensor(0.0144, dtype=torch.float64, grad_fn=<DivBackward0>)
4500 of 100000:  tensor(0.0150, dtype=torch.float64, grad_fn=<DivBackward0>)
5000 of 100000:  tensor(0.0150, dtype=torch.float64, grad_fn=<DivBackward0>)
5500 of 100000:  tensor(0.0151, dtype=torch.float64, grad_fn=<DivBackward0>)
6000 of 100000:  tensor(0.0152, dtype=torch.float64, grad_fn=<DivBackward0>)
6500 of 100000:  tensor(0.0148, dtype=torch.float64, grad_fn=<DivBackward0>)


53500 of 100000:  tensor(0.0141, dtype=torch.float64, grad_fn=<DivBackward0>)
54000 of 100000:  tensor(0.0141, dtype=torch.float64, grad_fn=<DivBackward0>)
54500 of 100000:  tensor(0.0136, dtype=torch.float64, grad_fn=<DivBackward0>)
55000 of 100000:  tensor(0.0138, dtype=torch.float64, grad_fn=<DivBackward0>)
55500 of 100000:  tensor(0.0137, dtype=torch.float64, grad_fn=<DivBackward0>)
56000 of 100000:  tensor(0.0139, dtype=torch.float64, grad_fn=<DivBackward0>)
56500 of 100000:  tensor(0.0138, dtype=torch.float64, grad_fn=<DivBackward0>)
57000 of 100000:  tensor(0.0138, dtype=torch.float64, grad_fn=<DivBackward0>)
57500 of 100000:  tensor(0.0140, dtype=torch.float64, grad_fn=<DivBackward0>)
58000 of 100000:  tensor(0.0139, dtype=torch.float64, grad_fn=<DivBackward0>)
58500 of 100000:  tensor(0.0138, dtype=torch.float64, grad_fn=<DivBackward0>)
59000 of 100000:  tensor(0.0139, dtype=torch.float64, grad_fn=<DivBackward0>)
59500 of 100000:  tensor(0.0138, dtype=torch.float64, grad_fn=<D

In [95]:
def predict(word):
    word_length = len(word)

    embeds = all_char_embeds[word]

    states = [init_state]+[torch.zeros((state_size,)) for _ in range(word_length-1)]

    hiddens = [torch.zeros((embed_size,)) for _ in range(word_length-1)]

    outputs = torch.zeros((word_length,embed_size))

    for i in range(word_length-1):
        process_state_and_embed(states,embeds ,hiddens,outputs,i)

    return states[-1] @ w_so + b_o


def get_next(text):
    word = [char_to_int[c] for c in text] +[0]

    pred = predict(word)

    pred = pred.view(1,embed_size)
    dists = torch.cdist(pred,all_char_embeds[:-1])
    dists = dists.view(-1)
    dists = torch.exp(-dists)
    dists/=dists.sum()
    sampled = (torch.multinomial(dists,1)).item()
    return int_to_char[sampled]

In [96]:
start_st = 'j'

for _ in range(100):
    st = start_st
    curr = 'a'
    while curr!='.':
        curr = get_next(st)
        st+=curr
    print(st)


jklvqgtooy.
jfhdafcan.
jdslrewosme.
jweduevre.
jzokyz.
jopcparcxrnjor.
jacglhrowt.
jarwgye.
jcusaufwln.
jhewpblmsv.
jfyjtcbdkcmhg.
jpeyxtlwpfdrvn.
jlburdrcax.
jzvxgmehua.
jkazrezzil.
jfxluimfin.
jpwqojdlyru.
jjavihbuxv.
jmhtephwhy.
jxeqoye.
jzgkrhzbxlgot.
jigpjydwbe.
jltmoewlbzqi.
jledhmjya.
jdapumwmjmn.
jejpgurxlqxtioh.
jjismqyaj.
jfogjxzwpa.
jubeuftgftn.
jussoyv.
jwwryo.
jigjpnrjhk.
jigdrtrbcq.
jfnoibun.
jiocjintao.
juhghyit.
jkvawscuxb.
jpswie.
jetporeby.
jqnygvim.
jpepqrbaak.
jjbhbiwzqon.
jumdvstguh.
jriggupyu.
jdpdnvjvtfa.
jqgngbtwkxqo.
jkpkextgvdl.
jjiqmuqoly.
jvkhnhej.
jsfshljfqbv.
jovoltyoo.
jtnedkbuws.
jknxrvpri.
jkhjetnlwdw.
jxruyo.
jdsvotmos.
jxjnpjkvmcadphh.
jewqivuwttat.
josgpknaly.
jqffpeikou.
jhhdfxruhw.
jwcrorcyu.
jolftsgaz.
jhztym.
jvlsfxrtiln.
jjhldnknpiil.
jasmfjhdlzn.
jqmhbsdkzis.
jvtvrzbwak.
jynrleah.
jcetqppkpe.
jtocofzben.
jcnhigvae.
jshiqnomoe.
jofbapquxx.
jtpgjpojko.
jozttwlvgpun.
jbuicxbton.
jxnywxmzgr.
jvosquhzxwr.
jytiwmrpkkluqvo.
jthvaulol.
jtmyojmmot.
jupq

In [64]:
((all_char_embeds[char_to_int['a']]-all_char_embeds[char_to_int['e']])**2).sum()

tensor(7.9190e-05, grad_fn=<SumBackward0>)

In [65]:
all_char_embeds

tensor([[ 4.1404e-02, -3.1502e-01,  3.8437e-01,  2.2421e-01, -4.5370e-02,
          1.1276e-01],
        [ 4.5207e-02, -3.1540e-01,  3.7890e-01,  2.3027e-01, -3.9489e-02,
          1.1035e-01],
        [-1.9714e-01, -1.2698e-01,  2.5729e-01,  3.4796e-01, -5.0878e-03,
          5.1886e-03],
        [-1.1072e-02, -3.5110e-01,  3.7590e-01,  2.3563e-01, -7.8479e-02,
          8.7188e-02],
        [ 3.8594e-02, -2.9808e-01,  4.2518e-01,  2.1967e-01, -2.6467e-02,
          9.5366e-02],
        [ 4.3529e-02, -3.1508e-01,  3.8704e-01,  2.2813e-01, -4.0856e-02,
          1.0848e-01],
        [ 7.6546e-01,  3.9753e-01, -3.2336e+00, -2.1769e-01,  1.5096e-01,
          2.4524e+00],
        [ 1.4561e-01, -3.7959e-02, -3.3408e-01,  2.4587e-01, -2.6998e-01,
          9.5122e-01],
        [ 2.9950e-02, -3.1685e-01,  4.2429e-01,  2.0591e-01, -2.3509e-02,
          7.8233e-02],
        [ 4.0571e-02, -3.1500e-01,  3.9491e-01,  2.2562e-01, -3.9184e-02,
          1.0424e-01],
        [ 7.5062e-02, -3.5887e