In [1]:
import numpy as np

def softmax(x):
    e= np.exp(x)
    return e/np.sum(e)

class RNN:
    def __init__(self, hidden_size, vocab_size, seq_length, lr):
        self.hidden_size= hidden_size
        self.vocab_size= vocab_size
        self.seq_length= seq_length
        self.lr= lr
        
        #Good initialization based on incoming connections
        self.W= np.random.uniform(-np.sqrt(1./vocab_size),np.sqrt(1./vocab_size),(hidden_size,vocab_size))
        self.U= np.random.uniform(-np.sqrt(1./hidden_size),np.sqrt(1./hidden_size),(hidden_size,hidden_size))
        self.V= np.random.uniform(-np.sqrt(1./hidden_size),np.sqrt(1./hidden_size),(vocab_size,hidden_size))
        self.b= np.zeros((hidden_size,1))
        self.c= np.zeros((vocab_size,1))
    
    def forward(self, inputs, hprev):
        xs, hs, os, yhat= {}, {}, {}, {}
        hs[-1]= np.copy(hprev)
        for t in range(len(inputs)):
            xs[t]= np.zeros((self.vocab_size,1))
            xs[t][inputs[t]]= 1.0
            hs[t]= np.tanh(np.dot(self.U,hs[t-1]) + np.dot(self.W,xs[t]) + self.b)
            os[t]= np.dot(self.V,hs[t]) + self.c
            yhat[t]= softmax(os[t])
        return xs, hs, yhat
    
    def loss(self, yhat, targets):
        return sum(-np.log(yhat[t][targets[t]]) for t in range(self.seq_length))
    
    def backward(self, xs, hs, yhat, targets):
        dW, dU, dV= np.zeros_like(self.W),np.zeros_like(self.U), np.zeros_like(self.V)
        db, dc= np.zeros_like(self.b), np.zeros_like(self.c)
        dhnext= np.zeros_like(hs[0])
        for t in reversed(range(self.seq_length)):
          dy= np.copy(yhat[t])
          dy[targets[t]]-= 1
          dV+= np.dot(dy, hs[t].T)
          dc+= dy
          dh= np.dot(self.V.T, dy) + dhnext
          dhraw= (1 - hs[t] * hs[t]) * dh
          db+= dhraw
          dU+= np.dot(dhraw, hs[t-1].T)
          dW+= np.dot(dhraw, xs[t].T)
          dhnext= np.dot(self.U.T, dhraw)
        return dV, dc, dU, dW, db
    
    def update(self, dV, dc, dU, dW, db):
        for param, dparam in zip([self.V, self.c, self.U, self.W, self.b],[dV, dc, dU, dW, db]):
            param-= self.lr * dparam

In [6]:
import re
files= ['sss_01_01.txt','sss_01_02.txt','sss_01_03.txt','sss_01_04.txt','sss_01_05.txt']
data= ''
for file in files:
    f= open("sss/"+file, 'r',encoding="utf8")
    data+= f.read()
data= data.lower()
data= re.sub("[.\“”:?!;,()-]","\n",data)
data= re.sub("(\s[‘’])|([‘’]\s)","\n",data)
data= re.sub("[‘’]s","",data)
data= re.sub("[0-9][A-Za-z]+","",data)
words= data.splitlines()
words= [w.split(' ') for w in words]
words= sum(words,[])
words= [w.strip() for w in words]
words= [w for w in words if len(w) >= 1]
word_set= list(set(words))
word_set.sort()
w2i= {}
for i,w in enumerate(word_set):
    w2i[w]= i;

N= len(words)
modv= len(word_set)
print(N)
print(modv)

8918
2048


  data= re.sub("(\s[‘’])|([‘’]\s)","\n",data)


In [7]:
hidden_size= 100
seq_length= 25
lr= 0.01
r= RNN(hidden_size, modv, seq_length, lr)
epochs= 1

for e in range(epochs):
    hprev= np.zeros((hidden_size,1))
    for p in range(N - seq_length - 1):
        inputs= [w2i[w] for w in words[p:p+seq_length]]
        targets= [w2i[w] for w in words[p+1:p+seq_length+1]]
        xs, hs, yhat= r.forward(inputs, hprev)
        l= r.loss(yhat,targets)
        print(l)
        dV, dc, dU, dW, db= r.backward(xs, hs, yhat, targets)
        r.update(dV, dc, dU, dW, db)

[190.65884307]
[190.07457728]
[189.51694553]
[188.94128059]
[188.36098334]
[187.78631123]
[187.17382248]
[186.38741556]
[185.52472979]
[184.52812402]
[183.02430681]
[180.39467164]
[176.01999651]
[168.40193216]
[157.72247947]
[146.78602494]
[134.8618895]
[125.32433871]
[119.84418628]
[115.85886276]
[113.43626258]
[111.52210429]
[110.22708433]
[103.58369852]
[101.45634399]
[100.91556784]
[97.1303636]
[96.23566967]
[97.27570917]
[99.18979097]
[100.63582158]
[103.50527903]
[105.75363832]
[99.82577581]
[102.10705097]
[104.39476261]
[99.20585184]
[101.62850472]
[97.42638918]
[99.3224612]
[94.55653936]
[90.30911728]
[93.89489023]
[92.27266514]
[96.09464646]
[91.44019315]
[96.80201772]
[101.60205423]
[97.11332031]
[98.57210876]
[102.51835802]
[102.84650469]
[105.68003092]
[106.30482882]
[109.32938493]
[107.02927778]
[108.30063587]
[107.69308664]
[110.44778336]
[106.43124582]
[107.25006723]
[101.29565823]
[103.82431024]
[103.83562932]
[101.05026757]
[102.33110542]
[103.92679407]
[100.22320361]
