In [21]:
import torch
import torch.nn.functional as F

In [2]:
bigram_set = {}
words = open('names.txt', 'r').read().splitlines()

In [None]:
for w in words:
    characters = ['.']+list(w)+['.']
    for char1, char2 in zip(characters, characters[1:]):
        b_index = (char1, char2)
        bigram_set[b_index] = bigram_set.get(b_index, 0)+1
bigram_set
        

In [None]:
N = torch.zeros((27, 27), dtype=torch.int32)
all_single_chars = sorted(list(set(''.join(words))))
char_to_index = {s:i+1 for i,s in enumerate(all_single_chars)}
char_to_index['.'] = 0
index_to_char = {i:s for s,i in char_to_index.items()}; index_to_char


In [None]:
for w in words:
    characters = ['.']+list(w)+['.']
    for char1, char2 in zip(characters, characters[1:]):
        indx_1 = char_to_index[char1]
        indx_2 = char_to_index[char2]
        N[indx_1, indx_2]+=1
N

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(N)

In [None]:

plt.figure(figsize=(16,16))
plt.imshow(N, cmap="Blues")
for i in range(27):
    for j in range(27):
        display_char = index_to_char[i]+index_to_char[j]
        plt.text(j, i ,display_char, ha="center", va="bottom", color="gray")
        plt.text(j, i, N[i, j].item(),  ha="center", va="top", color="gray")
plt.axis('off')

In [None]:
N[0].shape

In [13]:
p_of_N = N[0].float()
p_of_N = p_of_N / p_of_N.sum(); p_of_N

tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan])

In [14]:
g = torch.Generator().manual_seed(2147483647)
p = torch.rand(3, generator=g) ; p = p / p.sum()
p


tensor([0.6064, 0.3033, 0.0903])

In [15]:
torch.multinomial(p, num_samples=100, replacement=True, generator=g)

tensor([1, 1, 2, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 2, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1,
        0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 1, 0,
        0, 1, 1, 1])

In [None]:
P = N.float()
# P = P/ P.sum(1, keepdim=True)
P /= P.sum(1, keepdim=True) 
P

In [None]:
g = torch.Generator().manual_seed(2147483647)
for i in range(10):
    indx = 0
    cur_name = ""
    while True:
        # p = N[indx].float()
        # p = p / p.sum()
        p = P[indx]
        indx = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        cur_name+= index_to_char[indx]
        if indx == 0: break
    print(cur_name)

create the training set for all bigrams

In [33]:

x_tensor, y_tensor = [], []

for w in words[:1]:
    characters = ['.']+list(w)+['.']
    for char1, char2 in zip(characters, characters[1:]):
        indx_1 = char_to_index[char1]
        indx_2 = char_to_index[char2]
        x_tensor.append(indx_1)
        y_tensor.append(indx_2)

inputs = torch.tensor(x_tensor  )
labels = torch.tensor(y_tensor)
inputs, labels       

(tensor([ 0,  5, 13, 13,  1]), tensor([ 5, 13, 13,  1,  0]))

In [46]:
input_encoded = F.one_hot(inputs, num_classes = 27).float()
weights = torch.randn((27,27)) # random from normal distribution

log_counts = input_encoded @ weights
log_counts # interpretated as log count, the count add as the distribution of different characeters.
counts = log_counts.exp() # count distribution
prob = counts / counts.sum(1, keepdims =True)

tensor([[0.2881, 4.5867, 1.7456, 1.9147, 3.4133, 1.7690, 0.0894, 0.6723, 0.3196,
         0.3732, 5.7343, 0.2648, 0.4554, 1.2624, 0.5854, 0.8660, 0.9247, 7.4040,
         2.7512, 0.6131, 1.7173, 1.4257, 0.2730, 1.6587, 3.2454, 0.9583, 2.8068],
        [1.1050, 0.8517, 3.2948, 0.6502, 4.8913, 4.8815, 2.3345, 0.6285, 0.5486,
         0.2825, 0.5729, 0.7040, 1.0087, 1.5162, 0.5349, 2.5622, 0.4564, 2.9916,
         0.5973, 0.9385, 5.4576, 3.1681, 0.5101, 0.8404, 1.6139, 0.5408, 0.3278],
        [1.3293, 0.3635, 2.0222, 0.3700, 1.7566, 0.7293, 1.0571, 5.1339, 9.5884,
         0.5652, 5.5594, 1.8958, 1.0199, 2.1193, 0.0503, 0.7690, 3.6789, 0.6606,
         0.1719, 0.6703, 3.1813, 1.7148, 2.6905, 1.0157, 0.9759, 0.2621, 0.0760],
        [1.3293, 0.3635, 2.0222, 0.3700, 1.7566, 0.7293, 1.0571, 5.1339, 9.5884,
         0.5652, 5.5594, 1.8958, 1.0199, 2.1193, 0.0503, 0.7690, 3.6789, 0.6606,
         0.1719, 0.6703, 3.1813, 1.7148, 2.6905, 1.0157, 0.9759, 0.2621, 0.0760],
        [2.9409, 8.5500,