In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
torch.manual_seed(1)

B=512
C=128
H=64
T=8

with open('../names.txt') as fp:
    lines = fp.read()

chars = sorted(set(lines))
V = vocab_size = len(chars)
ctoi = dict(zip(chars, range(vocab_size)))
itoc = dict(zip(range(vocab_size), chars))
data_ints = torch.tensor([ctoi[c] for c in lines ])

def get_batch():
    batch = []
    for i in range(B):
        idx = torch.randint(low=0,high=len(data_ints)-T-1,size=(1,))
        batch.append(data_ints[idx:idx+T+1])

    X = torch.vstack(batch)
    y = X[:,1:].clone()
    X = X[:,:-1]
    return X,y

In [14]:
class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, vocab_size):
        super().__init__()
        C=input_size
        H=hidden_size
        P = torch.nn.Parameter
        self.W_xh = P(torch.randn(C, H) * np.sqrt(1/C))
        self.W_hh = P(torch.randn(H, H) * np.sqrt(1/H))
        self.b_h = P(torch.zeros(H))

    def forward(self, x,h=None):
        B, T, C = x.shape
        if h is None:
            h = torch.zeros(H)
        logits = torch.zeros(T,B,vocab_size) # T,B,V
        hts = []
        for i in range(T):
            h = torch.tanh(x[:,i,:]@self.W_xh + h@self.W_hh + self.b_h) # B,H
            hts.append(h)
        return torch.stack(hts).transpose(0,1).contiguous(), h # B,T,V        

In [19]:
rnn = nn.RNN(input_size=C, hidden_size=H,bias=True,batch_first=True,bidirectional=True)
# rnn = MyRNN(C,H,vocab_size)
emb = torch.randn(vocab_size, C)
W_hy = torch.randn(H, vocab_size) * np.sqrt(1/H)
b_y = torch.zeros(vocab_size)

params = [emb, W_hy, b_y] + list(rnn.parameters())
for p in params:
    p.requires_grad = True

In [1]:
lr=.01
for i in range(1000):
    X,y = get_batch()
    o = emb[X]  # B,T,C
    o2, h = rnn(o)
    logits = o2@W_hy+b_y
    # logits = forward(X)

    
    loss = F.cross_entropy(logits.view(B*T,V), y.view(B*T))
    if i%100==0:
        print(f'{loss:.4f}')

    for param in params:
        param.grad = None
        
    loss.backward()

    for param in params:
        param.data -= lr * param.grad

In [48]:
# 2.4847 myRnn
# 2.5317 torch.RNN

In [57]:
prompt = lines[50:50+T]

next_tokens = []
for x in range(100):
    [itoc[i] for i in next_tokens]
    X_test = torch.tensor([ctoi[c] for c in prompt ])
    X_test=X_test.unsqueeze(0)
    o = emb[X_test]  # 1,T,C
    o2, h = rnn(o)
    logits = o2@W_hy + b_y
    next_word_logits = logits[:,-1]
    probas = F.softmax(next_word_logits, 1)
    next_token = torch.multinomial(probas[0], 1)[0]
    next_tokens.append(next_token.item())
    X_test = X_test.roll(-1)
    X_test[0,-1] = next_token
print(''.join([itoc[i] for i in next_tokens]))

nmzzmr
nmsceanll
grrsmnmlrzrrernnlillrremfnlwlrwzl
ylnrl
nmysnmtnn
lellnkwnlt
rllrxf
almvd
rsirlnyer


In [58]:
# logits.shape