In [1]:
import torch
import torch.nn.functional as F

In [2]:
data=open('Names.txt','r').read().splitlines()

In [3]:
print(f'Total Names: {len(data)}')

Total Names: 3668


In [4]:
vocab=sorted(list(set(''.join(data))))

In [5]:
stoi={s:i+1 for i,s in enumerate(vocab)}
stoi['.']=0
itos={i:s for s,i in stoi.items()}
print(len(stoi) , len(itos))

53 53


In [6]:
vocab_size=len(stoi)

In [7]:
block_size=3
def build_dataset(dataset):
    x,y=[],[]
    context=[0]*block_size
    for name in dataset:
        for ch in name+'.':
            x.append(context)
            y.append(stoi[ch])
            context=context[1:]+[stoi[ch]]
    x = torch.tensor(x)
    y = torch.tensor(y)
    print(x.shape, y.shape)
    return x, y

In [8]:
n1=int(len(data)*0.8)
n2=int(len(data)*0.9)

Xtrain , Ytrain = build_dataset(data[:n1])
Xval   , Yval = build_dataset(data[n1:n2])
Xtest , Ytest = build_dataset(data[n2:])

torch.Size([21589, 3]) torch.Size([21589])
torch.Size([2769, 3]) torch.Size([2769])
torch.Size([2767, 3]) torch.Size([2767])


#Parameters

In [35]:
n_embd = 10
n_hidden = 200

g = torch.Generator().manual_seed(2147483642)
C  = torch.randn((vocab_size, n_embd),            generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1


bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

parameters = [C, W1, W2, b2,b1, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

17783


#shapes

In [10]:
print(f'shape of w1: {W1.shape}')
print(f'shape of b1: {b1.shape}')
print(f'shape of w2: {W2.shape}')
print(f'shape of b2: {b2.shape}')
print(f'shape of bngain: {bngain.shape}')
print(f'shape of bnbias: {bnbias.shape}')

shape of w1: torch.Size([30, 200])
shape of b1: torch.Size([200])
shape of w2: torch.Size([200, 53])
shape of b2: torch.Size([53])
shape of bngain: torch.Size([1, 200])
shape of bnbias: torch.Size([1, 200])


In [11]:
def common(value,t , vname):
    common=torch.all(value==t.grad).item()
    close=torch.allclose(value ,t.grad)
    maxdiff = (value - t.grad).abs().max().item()
    print(f'{vname:15s} all common:{ str(common):5s} | exactly common :{ str(close):5s} | maximum difference:{ maxdiff}')

Forwardpass


In [103]:
batch_size = 32
n = batch_size 
lossi = []
iterations = 200000
for i in range(iterations):
    ix = torch.randint(0, Xtrain.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtrain[ix], Ytrain[ix]
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    
    
    # Linear layer 1
    hprebn = embcat @ W1 + b1
    
    # BatchNorm layer
    bnmeani = 1/n*hprebn.sum(0, keepdim=True)
    bndiff = hprebn - bnmeani
    bndiff2 = bndiff**2
    bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True)
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = bndiff * bnvar_inv
    hpreact = bngain * bnraw + bnbias
    
    # Non-linearity
    h = torch.tanh(hpreact) # hidden layer
    
    # Linear layer 2
    logits = h @ W2 + b2 # output layer
    
    # cross entropy loss (same as F.cross_entropy(logits, Yb))
    logit_maxes = logits.max(1, keepdim=True).values
    norm_logits = logits - logit_maxes # subtract max for numerical stability
    counts = norm_logits.exp()
    counts_sum = counts.sum(1, keepdims=True)
    counts_sum_inv = counts_sum**-1
    probs = counts * counts_sum_inv
    logprobs = probs.log()
    loss = -logprobs[range(n), Yb].mean()
    lossi.append(loss)

    # PyTorch backward pass
    for p in parameters:
        p.grad = None
    for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
        norm_logits, logit_maxes, logits, h, hpreact, bnraw,bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,embcat, emb]:
         t.retain_grad()
    loss.backward()

#     lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
#     for p in parameters:
#         p.data += -lr * p.grad

#     if i%1000==0:
#         print(lossi[-1])
    break

backward pass

In [105]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
dprobs = (1.0 / probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
dcounts += torch.ones_like(counts) * dcounts_sum
dnorm_logits = counts * dcounts
dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
dhpreact = (1-h**2)*dh
dbnraw = bngain*dhpreact
dbngain = (bnraw*dhpreact).sum(0 , keepdim=True)
dbnbias = dhpreact.sum(0 , keepdim=True)
dbnvar_inv = (bndiff*dbnraw).sum(0 , keepdim=True)
dbndiff = bnvar_inv*dbnraw
dbnvar = (-0.5*(bnvar + 1e-5)**-1.5)*dbnvar_inv
dbndiff2 = 1/(n-1)*torch.ones_like(bndiff2)*dbnvar
dbndiff += 2*bndiff*dbndiff2
dhprebn = dbndiff.clone()
dbnmeani= -dbndiff.clone().sum(0 , keepdim=True)
dhprebn += 1/n*torch.ones_like(hprebn)*dbnmeani
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
demb = dembcat.view(emb.shape)
dC=torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]


common(dlogprobs , logprobs, 'dlogprobs')
common(dprobs , probs , 'dprobs')
common(dcounts_sum_inv , counts_sum_inv , 'dcounts_sum_inv')
common(dcounts,counts,'dcounts')
common(dcounts_sum , counts_sum , 'dcounts_sum')
common(dnorm_logits , norm_logits , 'dnorm_logits' )
common(dlogits , logits , 'dlogits' )
common(dlogit_maxes , logit_maxes , 'dlogit_maxes' )
common(dh, h , 'dh')
common(dW2 , W2 , 'dW2')
common(db2, b2 , 'db2')
common(dhpreact , hpreact , 'dhpreact')
common(dbnraw , bnraw , 'dbnraw')
common(dbngain , bngain , 'dbngain')
common(dbnbias , bnbias , 'dbnbias')
common( dbnvar_inv , bnvar_inv , 'dbnvar_inv')
common(dbndiff , bndiff , 'dbndiff')
common(dbnvar , bnvar , 'dbnvar')
common(dbndiff2 , bndiff2 , 'dbndiff2')
common(dhprebn , hprebn , 'dhprebn')
common(dbnmeani , bnmeani , 'dbnmeani')
common(dembcat , embcat , 'dembcat')
common(dembcat , embcat , 'dembcat')
common(db1 , b1 , 'db1')
common(demb , emb , 'demb')
common(dC , C , 'dC')

dlogprobs       all common:True  | exactly common :True  | maximum difference:0.0
dprobs          all common:True  | exactly common :True  | maximum difference:0.0
dcounts_sum_inv all common:True  | exactly common :True  | maximum difference:0.0
dcounts         all common:True  | exactly common :True  | maximum difference:0.0
dcounts_sum     all common:True  | exactly common :True  | maximum difference:0.0
dnorm_logits    all common:True  | exactly common :True  | maximum difference:0.0
dlogits         all common:True  | exactly common :True  | maximum difference:0.0
dlogit_maxes    all common:True  | exactly common :True  | maximum difference:0.0
dh              all common:True  | exactly common :True  | maximum difference:0.0
dW2             all common:True  | exactly common :True  | maximum difference:0.0
db2             all common:True  | exactly common :True  | maximum difference:0.0
dhpreact        all common:True  | exactly common :True  | maximum difference:0.0
dbnraw          

In [41]:
batch_size = 32
n = batch_size 
lossi = []
iterations = 200000
for i in range(iterations):
    ix = torch.randint(0, Xtrain.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtrain[ix], Ytrain[ix]
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    
    
    # Linear layer 1
    hprebn = embcat @ W1 + b1
    
    # BatchNorm layer
    bnmeani = 1/n*hprebn.sum(0, keepdim=True)
    bndiff = hprebn - bnmeani
    bndiff2 = bndiff**2
    bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True)
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = bndiff * bnvar_inv
    hpreact = bngain * bnraw + bnbias
    
    # Non-linearity
    h = torch.tanh(hpreact) # hidden layer
    
    # Linear layer 2
    logits = h @ W2 + b2 # output layer
    
    # cross entropy loss (same as F.cross_entropy(logits, Yb))
    logit_maxes = logits.max(1, keepdim=True).values
    norm_logits = logits - logit_maxes # subtract max for numerical stability
    counts = norm_logits.exp()
    counts_sum = counts.sum(1, keepdims=True)
    counts_sum_inv = counts_sum**-1
    probs = counts * counts_sum_inv
    logprobs = probs.log()
    loss = -logprobs[range(n), Yb].mean()
    lossi.append(loss)

    # PyTorch backward pass
    for p in parameters:
        p.grad = None
    for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
        norm_logits, logit_maxes, logits, h, hpreact, bnraw,bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,embcat, emb]:
         t.retain_grad()
#     dlogprobs = torch.zeros_like(logprobs)
#     dlogprobs[range(n), Yb] = -1.0/n
#     dprobs = (1.0 / probs) * dlogprobs
#     dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
#     dcounts = counts_sum_inv * dprobs
#     dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
#     dcounts += torch.ones_like(counts) * dcounts_sum
#     dnorm_logits = counts * dcounts
#     dlogits = dnorm_logits.clone()
#     dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
#     dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
#     dh = dlogits @ W2.T
#     dW2 = h.T @ dlogits
#     db2 = dlogits.sum(0)
#     dhpreact = (1-h**2)*dh
#     dbnraw = bngain*dhpreact
#     dbngain = (bnraw*dhpreact).sum(0 , keepdim=True)
#     dbnbias = dhpreact.sum(0 , keepdim=True)
#     dbnvar_inv = (bndiff*dbnraw).sum(0 , keepdim=True)
#     dbndiff = bnvar_inv*dbnraw
#     dbnvar = (-0.5*(bnvar + 1e-5)**-1.5)*dbnvar_inv
#     dbndiff2 = 1/(n-1)*torch.ones_like(bndiff2)*dbnvar
#     dbndiff += 2*bndiff*dbndiff2
#     dhprebn = dbndiff.clone()
#     dbnmeani= -dbndiff.clone().sum(0 , keepdim=True)
#     dhprebn += 1/n*torch.ones_like(hprebn)*dbnmeani
#     dembcat = dhprebn @ W1.T
#     dW1 = embcat.T @ dhprebn
#     db1 = dhprebn.sum(0)
#     demb = dembcat.view(emb.shape)
#     dC=torch.zeros_like(C)
#     for k in range(Xb.shape[0]):
#         for j in range(Xb.shape[1]):
#             ix = Xb[k,j]
#             dC[ix] += demb[k,j]

#     d_param=[dC, dW1, dW2, db2, db1, dbngain, dbnbias]
#     for p , dp in zip(parameters , d_param):
#         p.data += -0.1*dp
    loss.backward()
    for p in parameters:
             p.data += -1*p.grad

In [42]:
lossi[-1]

tensor(2.4347, grad_fn=<NegBackward0>)

In [43]:
g=torch.Generator().manual_seed(123456789)

for i in range(20):
    out=[]
    context=[0]*block_size
    while True:
        emb = C[torch.tensor([context])]
        embcat = emb.view(emb.shape[0] , -1)
        hpreact = embcat @ W1 + b1
        # hpreact = bngain * (hpreact - bnmeani) * (bnvar + 1e-5)**-0.5 + bnbias
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2 
        probs= F.softmax(logits , dim = 1)
        ix = torch.multinomial(probs , num_samples=1 , generator=g).item()
        context = context[1:]+[ix]
        out.append(ix)
        if ix==0:
            break
    print(''.join(itos[i] for i in out))
        
        

AAKKltPDGPewvaImznJGHqigBRCjPRGKIOQAn.
HKHQuNDCMQRLCIsflBOdjle.
AKQCOKQFixkatLGFlJGQuRKGQCOBExgeLLOQQFixkam.
NPFicDIgggfoGLAxtE.
KAQuJQEqflMLQCBaffeLFleaQCJakfiKRGBCMKjoQuigon.
GKFPLCjou.
LACQHKaotLGGQCKjoBKitBBBjonagdQHKi.
CAMFann.
FPBFujNCartPEnjar.
AKQCDBMcjMCasMFinALGLQKitGCIKDlDCQPigk.
FPCjowtNALLMPjaqr.
EaQor.
GAFCOOzlvaQrigaInkmadrHBujPAnflCNourEgon.
HCQInjo.
AGQBMjon.
AKQCIaz.
LGMQNicRAnhPHatMBMotGKKOQQJacqDCLCLGKjhjBGu.
AInFlaiqFinKDDAKQGisHCBenKQuJGQuLGDBCHNewclARGLPBatOAIlQBagcRIlQBigheHGHunjPCasEm.
EsQF ajcEnfoDLGEsKHHckilEnQPapplJttEBQuBGElQCDaveQNan.
HInQBLjoAOAMRKCQCHatKACQBitQPKattPEnjo.
