In [251]:
import torch
import torch.nn.functional as F

# 0. Check cuda availability
if torch.cuda.is_available(): 
    dev = "cuda:0" 
else: 
    dev = "cpu" 
device = torch.device(dev) 

# 1. Load the data from local file called 'names.txt'
with open('names.txt', 'r') as f:
    names = [line.strip() for line in f]
len(names), names[:5]

(32033, ['emma', 'olivia', 'ava', 'isabella', 'sophia'])

In [252]:
# 2. encode the char into a list of integers
symbols = sorted(list(set(''.join(names))))
char_to_int = {s:i+1 for i,s in enumerate(symbols)}
char_to_int['.'] = 0
char_to_int

{'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 '.': 0}

In [253]:
# 3. generate mapping from previous character to next character
input_char, output_char = [],[]
for word in names[:]:
    input_char.append(char_to_int['.'])
    for i in range(len(word)):
        input_char.append(char_to_int[word[i]]) 
        output_char.append(char_to_int[word[i]]) 
    output_char.append(char_to_int['.'])  
input_char, output_char

([0,
  5,
  13,
  13,
  1,
  0,
  15,
  12,
  9,
  22,
  9,
  1,
  0,
  1,
  22,
  1,
  0,
  9,
  19,
  1,
  2,
  5,
  12,
  12,
  1,
  0,
  19,
  15,
  16,
  8,
  9,
  1,
  0,
  3,
  8,
  1,
  18,
  12,
  15,
  20,
  20,
  5,
  0,
  13,
  9,
  1,
  0,
  1,
  13,
  5,
  12,
  9,
  1,
  0,
  8,
  1,
  18,
  16,
  5,
  18,
  0,
  5,
  22,
  5,
  12,
  25,
  14,
  0,
  1,
  2,
  9,
  7,
  1,
  9,
  12,
  0,
  5,
  13,
  9,
  12,
  25,
  0,
  5,
  12,
  9,
  26,
  1,
  2,
  5,
  20,
  8,
  0,
  13,
  9,
  12,
  1,
  0,
  5,
  12,
  12,
  1,
  0,
  1,
  22,
  5,
  18,
  25,
  0,
  19,
  15,
  6,
  9,
  1,
  0,
  3,
  1,
  13,
  9,
  12,
  1,
  0,
  1,
  18,
  9,
  1,
  0,
  19,
  3,
  1,
  18,
  12,
  5,
  20,
  20,
  0,
  22,
  9,
  3,
  20,
  15,
  18,
  9,
  1,
  0,
  13,
  1,
  4,
  9,
  19,
  15,
  14,
  0,
  12,
  21,
  14,
  1,
  0,
  7,
  18,
  1,
  3,
  5,
  0,
  3,
  8,
  12,
  15,
  5,
  0,
  16,
  5,
  14,
  5,
  12,
  15,
  16,
  5,
  0,
  12,
  1,
  25,
  12,
  1,
  0,
  18,
 

In [254]:
# one-hot encoding
inputs_encoded = F.one_hot(torch.tensor(input_char, device = device), len(char_to_int)).float()
labels = torch.tensor(output_char, device = device)
inputs_encoded.device, inputs_encoded.dtype, inputs_encoded.shape, labels.device, labels.dtype, labels.shape


(device(type='cuda', index=0),
 torch.float32,
 torch.Size([228146, 27]),
 device(type='cuda', index=0),
 torch.int64,
 torch.Size([228146]))

In [255]:
# initilize the tensor with 27x27 random numbers from normal distribution
weights = torch.randn(27, 27, dtype=torch.float32, requires_grad=True, device = device)
weights.dtype, weights.shape, weights.device

(torch.float32, torch.Size([27, 27]), device(type='cuda', index=0))

In [266]:
for i in range(10000):
    '''forward pass'''
    # pass in inputs, calculate the likelihood of each character
    log_counts = inputs_encoded @ weights
    counts = torch.exp(log_counts)
    probs = counts / counts.sum(1, keepdim=True) 
    
    # calculate the loss using negative log likelihood
    loss = -probs[torch.arange(len(inputs_encoded)), labels].log().mean()
    print(loss.item())
    '''backward pass'''
    weights.grad = None
    loss.backward()
    
    '''update weights'''
    weights.data -= 50 * weights.grad
    




2.459644317626953
2.4596216678619385
2.4595987796783447
2.459576368331909
2.459554433822632
2.4595324993133545
2.459510326385498
2.4594886302948
2.4594674110412598
2.4594459533691406
2.4594242572784424
2.4594032764434814
2.4593825340270996
2.4593615531921387
2.459340810775757
2.459320545196533
2.4593000411987305
2.459280014038086
2.4592597484588623
2.4592397212982178
2.4592196941375732
2.459200382232666
2.4591808319091797
2.4591610431671143
2.459141969680786
2.459122657775879
2.45910382270813
2.459084987640381
2.459066152572632
2.459047317504883
2.459028959274292
2.459010362625122
2.4589922428131104
2.4589738845825195
2.458956003189087
2.4589381217956543
2.4589202404022217
2.458902597427368
2.4588849544525146
2.458867311477661
2.458850145339966
2.4588329792022705
2.458815813064575
2.45879864692688
2.4587819576263428
2.4587652683258057
2.4587485790252686
2.4587318897247314
2.4587154388427734
2.4586989879608154
2.4586827754974365
2.4586665630340576
2.4586503505706787
2.458634614944458
2.