In [166]:
from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F

# 0. Check cuda availability
if torch.cuda.is_available(): 
    dev = "cuda:0" 
else: 
    dev = "cpu" 
device = torch.device(dev) 

# 1. Load the data from local file called 'names.txt'
with open('names.txt', 'r') as f:
    names = [line.strip() for line in f]
len(names), names[:5]

(32033, ['emma', 'olivia', 'ava', 'isabella', 'sophia'])

In [167]:
# 2. encode the char into a list of integers
symbols = sorted(list(set(''.join(names))))
char_to_int = {s:i+1 for i,s in enumerate(symbols)}
int_to_char = {i+1:s for i,s in enumerate(symbols)}
char_to_int['.'] = 0
int_to_char[0] = '.'
# char_to_int, int_to_char

In [168]:
# hyperparameters
block_size = 5
embedding_size = 3
hidden_size = 500
minibatch_size = 32

In [169]:
# 3. generate mapping from previous characters to next character
def build_dataset(names_set):
    input_char, output_char = [],[]
    for word in names_set[:]:
        input_word ='.'*block_size+word
        output_word = word + '.'
        for i in range(len(output_word)):
            input_char.append(list(input_word[i:i+block_size])) 
            output_char.append(output_word[i]) 
            
    # encode mapping into integers
    for i in range(len(input_char)):
        input_char[i] = [char_to_int[s] for s in input_char[i]]
        output_char[i] = char_to_int[output_char[i]]
    X = torch.tensor(input_char, device=device)
    Y = torch.tensor(output_char, device=device)
    return X, Y


# split dataset into train, validation, and testing
import random
random.shuffle(names)
train_size = int(len(names)*0.7)
val_size = int(len(names)*0.2)
Inputs_train, Labels_train = build_dataset( names[:train_size])
Inputs_val, Labels_val = build_dataset( names[train_size:train_size+val_size])
Inputs_test, Labels_test = build_dataset( names[train_size+val_size:])
Inputs_train.shape, Inputs_val.shape, Inputs_test.shape, Labels_train.shape, Labels_val.shape, Labels_test.shape

(torch.Size([159590, 4]),
 torch.Size([45693, 4]),
 torch.Size([22863, 4]),
 torch.Size([159590]),
 torch.Size([45693]),
 torch.Size([22863]))

In [170]:
# one-hot encoding

# F.one_hot(xx) @ E = E[xx]

# input_encoded = F.one_hot(input_char, len(char_to_int)).float()
# input_encoded.shape, input_encoded.device, labels.shape, labels.device

In [171]:
# embedding layer
E = torch.rand((len(char_to_int), embedding_size), device=device, requires_grad=True)

# Hidden layer
W_hidden = torch.rand((embedding_size*block_size, hidden_size), device=device, requires_grad=True)
b_hidden = torch.rand(hidden_size, device=device, requires_grad=True)

# Output layer
W_out = torch.rand((hidden_size, len(char_to_int)), device=device, requires_grad=True)
b_out = torch.rand( len(char_to_int), device=device, requires_grad=True)

params = [E, W_hidden, b_hidden, W_out, b_out]

E.shape,W_hidden.shape, b_hidden.shape, W_out.shape, b_out.shape,sum([p.numel() for p in params])

(torch.Size([27, 3]),
 torch.Size([12, 800]),
 torch.Size([800]),
 torch.Size([800, 27]),
 torch.Size([27]),
 32108)

In [172]:
learning_rate_exp = torch.linspace(-3, 0, 1000)
learning_rate = 10**learning_rate_exp
# learning_rate

In [173]:
lr_records = []
loss_records = []

for i in range(20000):
    # construct minibatch
    index_this_batch = torch.randint(0, len(Inputs_train), (minibatch_size,))
    # forward pass
    embed = E[Inputs_train[index_this_batch]]
    hid = torch.tanh(embed.view(-1, embedding_size*block_size) @ W_hidden + b_hidden)
    log_counts = hid @ W_out + b_out
    loss = F.cross_entropy(log_counts, Labels_train[index_this_batch])
    print(loss.item())
    # backward pass
    for p in params:
        p.grad = None
    loss.backward()
    
    # gradient descent
    learning_rate = .1 if i < 10000 else .01
    for p in params:
        p.data -= learning_rate * p.grad
    
    # track lr performance'''
    # lr_records.append(lr)
    loss_records.append(loss.item())
        
    

16.133930206298828
12.601572036743164
17.97163200378418
22.64801788330078
26.264354705810547
21.901031494140625
15.443758964538574
17.339534759521484
19.198116302490234
19.347976684570312
26.83294105529785
31.55406379699707
30.390151977539062
25.844655990600586
25.535795211791992
29.252870559692383
24.619384765625
26.719785690307617
21.550922393798828
14.000421524047852
13.72242546081543
16.811447143554688
15.547049522399902
19.884777069091797
27.885347366333008
33.06177520751953
27.162757873535156
21.6138973236084
15.060019493103027
11.727212905883789
11.535717964172363
11.971663475036621
14.565890312194824
13.336759567260742
13.43139934539795
10.435136795043945
8.49830150604248
7.043607234954834
6.798463344573975
8.568638801574707
6.846274375915527
6.2304277420043945
4.742542743682861
5.809988498687744
5.78674840927124
6.999007225036621
5.79750394821167
6.529432773590088
5.139833450317383
6.8268280029296875
9.39358139038086
8.259943962097168
7.543330192565918
5.428611755371094
7.0603

In [174]:
# evaluate the model
embed = E[Inputs_test]
hid = torch.tanh(embed.view(-1, embedding_size*block_size) @ W_hidden + b_hidden)
log_counts = hid @ W_out + b_out
loss = F.cross_entropy(log_counts, Labels_test)
loss

tensor(2.2716, device='cuda:0', grad_fn=<NllLossBackward0>)

In [175]:
# plt.plot(learning_rate_exp, loss_records)

In [176]:
# generate a name
for _ in range(20):
    out = []
    context = [0]*block_size
    while True:
        embed = E[torch.tensor(context, device=device)]
        hid = torch.tanh(embed.view(1,-1) @ W_hidden + b_hidden)
        log_counts = hid @ W_out + b_out
        probs = F.softmax(log_counts, dim=1)
        index = torch.multinomial(probs, num_samples= 1).item()
        context = context[1:] + [index]
        # print(context)
        out.append(index)
        if index == 0:
            break
    print("out: "+''.join([int_to_char[i] for i in out]))


out: rmeroy.
out: rarac.
out: zalra.
out: emek.
out: araswee.
out: detd.
out: quinno.
out: aydyn.
out: karla.
out: sad.
out: jailan.
out: mahus.
out: dennyla.
out: moryyenkton.
out: fande.
out: kan.
out: aliid.
out: faite.
out: aray.
out: virlen.
