In [95]:
import numpy as np
import torch
import torch.nn.functional as F

In [96]:
words = ['ab', 'bc', 'ca']

In [97]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}
itos

{1: 'a', 2: 'b', 3: 'c', 0: '.'}

In [98]:
bichars = [x+y for x in stoi.keys() for y in stoi.keys()]
bstoi = {s:i for i, s in enumerate(bichars)}
# del bstoi['..']
bitos = {i:s for s, i in bstoi.items()}
bitos

{0: 'aa',
 1: 'ab',
 2: 'ac',
 3: 'a.',
 4: 'ba',
 5: 'bb',
 6: 'bc',
 7: 'b.',
 8: 'ca',
 9: 'cb',
 10: 'cc',
 11: 'c.',
 12: '.a',
 13: '.b',
 14: '.c',
 15: '..'}

In [99]:
xs, ys = [], []

for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = bstoi[ch1+ch2]
        ix3 = stoi[ch3]
        xs.append(ix1)
        ys.append(ix3)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()

xs,ys

(tensor([12,  1, 13,  6, 14,  8]), tensor([2, 0, 3, 0, 1, 0]))

In [83]:
xenc = F.one_hot(xs, num_classes=16).float()
xenc

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

In [84]:
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((16, 4), generator=g, requires_grad=True)

In [85]:
for k in range(5):
    xenc = F.one_hot(xs, num_classes=16).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)

    loss = -probs[torch.arange(probs.shape[0]), ys].log().mean()
    print(loss.item())


    W.grad = None
    loss.backward()

    W.data += -0.1*W.grad

1.7877308130264282
1.7718868255615234
1.756158471107483
1.740545630455017
1.7250484228134155


In [88]:
# sampling from neural net

for i in range(5):
    out = []
    ix = 15
    while True:

        xenc = F.one_hot(torch.tensor([ix]), num_classes=16).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdim=True)

        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:
            break
    print(''.join(out))


.
bcccccccbcccccccccbcca.
.
bcbacccccccccbcccccccccccccccca.
b.


##### Alternative approach
I tried to represent each bigram as a concatenation of one-hot encoding of each letter. So for e.g. 'ab' has 'a' [0 ,1, 0, 0] and 'b' [0, 0, 1, 0] so it becomes [0 ,1, 0, 0, 0, 0, 1, 0]. Not sure if this representation can be fed into NN

In [91]:
xs, ys = [], []

for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        xs.append([ix1, ix2])
        ys.append(ix3)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()


In [92]:
xenc = F.one_hot(xs, num_classes=len(chars)+1).float()
xenc
reshaped_tensor = xenc.view(xenc.size(0), -1)
reshaped_tensor

tensor([[1., 0., 0., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1., 0., 0.]])

In [93]:
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((8, 4), generator=g, requires_grad=True)
W


tensor([[ 1.5674, -0.2373, -0.0274, -1.1008],
        [ 0.2859, -0.0296, -1.5471,  0.6049],
        [ 0.0791,  0.9046, -0.4713,  0.7868],
        [-0.3284, -0.4330,  1.3729,  2.9334],
        [ 1.5618, -1.6261,  0.6772, -0.8404],
        [ 0.9849, -0.1484, -1.4795,  0.4483],
        [-0.0707,  2.4968,  2.4448, -0.6701],
        [-1.2199,  0.3031, -1.0725,  0.7276]], requires_grad=True)

In [94]:
reshaped_tensor @ W

tensor([[ 2.5522, -0.3857, -1.5069, -0.6525],
        [ 0.2152,  2.4671,  0.8977, -0.0652],
        [ 1.4966,  2.2595,  2.4174, -1.7708],
        [-1.1407,  1.2078, -1.5438,  1.5144],
        [ 0.3475,  0.0659, -1.0999, -0.3732],
        [ 0.6564, -0.5813, -0.1066,  3.3817]], grad_fn=<MmBackward0>)