In [1]:
# Change to the parent directory so relative paths such as ./data/... resolve.
%cd ../

/Users/eloidieme/dev/python-projects/lstm


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import numpy as np
import torch
import torch.nn.functional as F
import pickle
import time
import random
import matplotlib.pyplot as plt

In [3]:
book_fname = "./data/goblet_book.txt"
# Decode explicitly as UTF-8: the text contains non-ASCII characters (e.g. 'ü',
# '•'), and open()'s default encoding is platform-dependent.
with open(book_fname, 'r', encoding='utf-8') as book:
    book_data = book.read()
len(book_data)

1107542

In [4]:
def split_text(text, train_frac=0.8, val_frac=0.1):
    """Split a sequence into contiguous train / validation / test slices.

    The first ``int(len(text) * train_frac)`` items form the training split,
    the next ``int(len(text) * val_frac)`` items the validation split, and
    whatever remains is the test split.

    Returns:
        (train_data, val_data, test_data), each a slice of ``text``.
    """
    n_items = len(text)
    cut_train = int(n_items * train_frac)
    cut_val = cut_train + int(n_items * val_frac)
    return text[:cut_train], text[cut_train:cut_val], text[cut_val:]

In [5]:
# Split the book into whitespace-free words, then right-pad each word's list of
# characters with spaces so the ragged lists form one rectangular array.
word_list = book_data.split()
max_len = max(len(w) for w in word_list)
chars = np.array([[*w] + [' '] * (max_len - len(w)) for w in word_list])

In [6]:
# Vocabulary: every distinct character of the padded word array (np.unique
# returns them sorted; ' ' is present through the padding), followed by '\n'
# and '\t', which str.split() removed from the words.
unique_chars = [str(ch) for ch in np.unique(chars)]
unique_chars.append('\n')
unique_chars.append('\t')
# str.split() strips *all* whitespace (e.g. '\r', '\x0b', '\xa0' too). Append
# any character of the text that is still missing, so encode_char never raises
# KeyError on it; sorted for a deterministic index assignment. For a text
# without such characters the vocabulary is unchanged.
unique_chars.extend(sorted(set(book_data) - set(unique_chars)))
K = len(unique_chars)  # dimensionality of the input and output vectors

In [7]:
# Bidirectional lookup between characters and their one-hot positions.
char_to_ind = {ch: pos for pos, ch in enumerate(unique_chars)}
ind_to_char = dict(enumerate(unique_chars))

In [8]:
m = 100  # dimensionality of the hidden state
eta = 0.1  # learning rate
seq_length = 25  # length of input sequences used during training
epsilon = 1e-8  # for AdaGrad

# Seed every RNG the notebook touches (weight init uses torch.normal, sampling
# uses torch.rand) so Restart & Run All reproduces the same weights and samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [9]:
sig = 0.01
# Recurrent weights, one (m, m) matrix per gate (forget, input, output,
# candidate), stacked row-wise into Wall of shape (4m, m).
Wf, Wi, Wo, Wc = (
    torch.normal(0.0, sig, (m, m), dtype=torch.double, requires_grad=True)
    for _gate in range(4)
)
Wlist = [Wf, Wi, Wo, Wc]
Wall = torch.cat(Wlist, dim=0)

In [10]:
# Input weights, one (m, K) matrix per gate (forget, input, output, candidate),
# stacked row-wise into Uall of shape (4m, K).
Uf, Ui, Uo, Uc = (
    torch.normal(0.0, sig, (m, K), dtype=torch.double, requires_grad=True)
    for _gate in range(4)
)
Ulist = [Uf, Ui, Uo, Uc]
Uall = torch.cat(Ulist, dim=0)

In [11]:
# Output layer: logits = V h + c. FC packs [V | c] column-wise, so FC[:, :-1]
# is V and FC[:, -1:] is the bias column c.
FClist = [
    torch.normal(0.0, sig, (K, m), dtype=torch.double, requires_grad=True),
    torch.zeros((K, 1), dtype=torch.double, requires_grad=True),
]
V, c = FClist
FC = torch.cat(FClist, dim=1)

In [12]:
# Model parameters: stacked recurrent weights, stacked input weights, and the
# packed output layer [V | c].
LSTM = dict(Wall=Wall, Uall=Uall, FC=FC)

In [13]:
# Gate selectors: E_k is (m, 4m) with an identity in its k-th (m, m) column
# block and zeros elsewhere, so E_k @ a extracts rows (k-1)*m:k*m of the stacked
# pre-activation a (E1..E4: forget, input, output, candidate).
E1, E2, E3, E4 = (
    torch.cat(
        [torch.eye(m, dtype=torch.double) if col == row else torch.zeros((m, m), dtype=torch.double)
         for col in range(4)],
        dim=1,
    )
    for row in range(4)
)

In [14]:
def encode_char(char):
    """Return the one-hot encoding of ``char`` as a length-K list of ints.

    Raises KeyError if ``char`` is not in ``char_to_ind``.
    """
    hot = char_to_ind[char]
    return [int(pos == hot) for pos in range(K)]

In [15]:
def synthetize_seq(lstm, h0, c0, x0, n, T = 1):
    """Sample ``n`` characters from the LSTM, feeding each sample back as input.

    Args:
        lstm: dict with 'Wall' (4m, m) and 'Uall' (4m, K), the gate weights
            stacked in forget / input / output / candidate row order, and
            'FC' = [V | c] of shape (K, m + 1).
        h0, c0: initial hidden and cell state, shape (m, 1).
        x0: one-hot encoding of the first input character (K entries).
        n: number of characters to generate.
        T: softmax temperature; T < 1 sharpens, T > 1 flattens the distribution.

    Returns:
        (Y, s): Y is a (K, n) int64 one-hot tensor of the sampled characters
        (one column per step), s the decoded string.
    """
    n_hidden = h0.shape[0]
    n_out = lstm['FC'].shape[0]
    indexes = []
    # Pure inference: don't record an autograd graph over the parameters.
    with torch.no_grad():
        ht, ct = h0.clone(), c0.clone()
        xt = x0.clone().reshape((n_out, 1))
        for _ in range(n):
            at = torch.mm(lstm['Wall'], ht) + torch.mm(lstm['Uall'], xt)
            ft = torch.sigmoid(at[:n_hidden])
            it = torch.sigmoid(at[n_hidden:2 * n_hidden])
            ot = torch.sigmoid(at[2 * n_hidden:3 * n_hidden])
            ctilde = torch.tanh(at[3 * n_hidden:])
            ct = ft * ct + it * ctilde
            ht = ot * torch.tanh(ct)
            out = torch.mm(lstm['FC'][:, :-1], ht) + lstm['FC'][:, -1:]
            pt = F.softmax(out / T, dim=0)
            # Inverse-CDF sampling: first index whose cumulative probability
            # exceeds a uniform draw. If rounding leaves the total just below
            # the draw, no index qualifies; fall back to the last class.
            cp = torch.cumsum(pt, dim=0)
            a = torch.rand(1)
            hits = torch.where(cp - a > 0)[0]
            ii = hits[0].item() if hits.numel() > 0 else n_out - 1
            indexes.append(ii)
            xt = torch.zeros((n_out, 1), dtype=torch.double)
            xt[ii, 0] = 1

    Y = torch.zeros((n_out, n), dtype=torch.long)
    if indexes:
        Y[torch.tensor(indexes), torch.arange(n)] = 1
    s = ''.join(ind_to_char[i] for i in indexes)
    return Y, s

In [16]:
def encode_string(chars):
    """One-hot encode a string (or sequence of characters) column by column.

    Args:
        chars: characters to encode; each must be a key of ``char_to_ind``.

    Returns:
        A (K, len(chars)) double tensor whose column j is the one-hot vector of
        ``chars[j]``. An empty input gives a (K, 0) tensor.

    Raises:
        KeyError: if a character is not in the vocabulary.
    """
    n_chars = len(chars)
    M = torch.zeros((K, n_chars), dtype=torch.double)
    if n_chars:
        rows = torch.tensor([char_to_ind[ch] for ch in chars], dtype=torch.long)
        M[rows, torch.arange(n_chars)] = 1.0
    return M

In [17]:
def forward(lstm, X, hprev, cprev):
    """Run the LSTM over the columns of X and return the per-step softmax outputs.

    Args:
        lstm: dict with 'Wall' (4m, m) and 'Uall' (4m, K), the gate weights
            stacked in forget / input / output / candidate row order, and
            'FC' = [V | c] of shape (K, m + 1).
        X: (K, T) one-hot inputs, one column per time step.
        hprev, cprev: (m, 1) initial hidden and cell state.

    Returns:
        (P, ht, ct): P is (K, T), column t being the predicted distribution of
        the next character; ht, ct are the final hidden and cell states. P stays
        connected to the autograd graph of ``lstm``'s tensors.
    """
    n_hidden = hprev.shape[0]
    V, c = lstm['FC'][:, :-1], lstm['FC'][:, -1:]
    ht = hprev.clone()
    ct = cprev.clone()
    probs = []
    # The number of steps comes from X itself, not the global `seq_length`, so
    # shorter (e.g. final) chunks and other window sizes work too.
    for i in range(X.shape[1]):
        xt = X[:, i:i + 1]
        at = torch.mm(lstm['Wall'], ht) + torch.mm(lstm['Uall'], xt)
        # Rows of `at` are the stacked [forget; input; output; candidate] blocks.
        ft = torch.sigmoid(at[:n_hidden])
        it = torch.sigmoid(at[n_hidden:2 * n_hidden])
        ot = torch.sigmoid(at[2 * n_hidden:3 * n_hidden])
        ctilde = torch.tanh(at[3 * n_hidden:])
        ct = ft * ct + it * ctilde
        ht = ot * torch.tanh(ct)
        probs.append(F.softmax(torch.mm(V, ht) + c, dim=0))

    if probs:
        P = torch.cat(probs, dim=1)
    else:
        P = torch.zeros((lstm['FC'].shape[0], 0), dtype=lstm['FC'].dtype)
    return P, ht, ct

In [18]:
def compute_loss(Y, P):
    """Cross-entropy of predictions P against one-hot targets Y, summed over all entries.

    Returns:
        (cross_entropy, loss): the loss as a differentiable scalar tensor and
        as a Python float.
    """
    cross_entropy = -(Y * P.log()).sum()
    return cross_entropy, cross_entropy.item()

In [19]:
def evaluate_model(rnn, val_data):
    """Return the per-character perplexity of the LSTM on ``val_data``.

    The text is scanned in consecutive, non-overlapping windows of
    ``seq_length`` characters (targets shifted by one). The hidden and cell
    state start at zero and are carried from window to window.

    Args:
        rnn: LSTM parameter dict as consumed by ``forward`` (the parameter name
            is kept for backward compatibility).
        val_data: text to evaluate.

    Returns:
        exp(mean per-character cross-entropy), as a Python float.

    Raises:
        ValueError: if ``val_data`` has fewer than ``seq_length + 1`` characters,
            i.e. not a single complete window.
    """
    total_loss = 0.0
    total_characters = 0
    hprev = torch.zeros((m, 1), dtype=torch.double)
    cprev = torch.zeros((m, 1), dtype=torch.double)
    with torch.no_grad():
        for i in range(0, len(val_data) - seq_length, seq_length):
            X_val = encode_string(val_data[i:i + seq_length])
            Y_val = encode_string(val_data[i + 1:i + seq_length + 1])
            P, hprev, cprev = forward(rnn, X_val, hprev, cprev)
            # compute_loss already sums the cross-entropy over the window.
            _, loss = compute_loss(Y_val, P)
            total_loss += loss
            total_characters += seq_length
    if total_characters == 0:
        raise ValueError("val_data must contain at least seq_length + 1 characters")
    average_loss = total_loss / total_characters
    # Keep double precision; torch.tensor(float) alone would be float32.
    return torch.exp(torch.tensor(average_loss, dtype=torch.double)).item()

In [20]:
e, step, epoch = 0, 0, 0
n_epochs = 2
smooth_loss = 0
seq_length = 25
losses = []
hprev = torch.zeros((m, 1), dtype=torch.double)
cprev = torch.zeros((m, 1), dtype=torch.double)

# Train the stacked matrices themselves as leaf tensors. `Wall`, `Uall` and `FC`
# are produced by torch.cat, i.e. they are non-leaf tensors. Updating them in
# place outside torch.no_grad() appends another node to their autograd history
# on every step: memory and time grow without bound, which is also why
# retain_graph=True was needed. Detaching once here drops that history and keeps
# the current values, i.e. whatever training has already happened.
LSTM = {k: v.detach().clone().requires_grad_(True) for k, v in LSTM.items()}
param_keys = ['Wall', 'Uall', 'FC']

# AdaGrad accumulators. Clipping and AdaGrad are element-wise, so one
# accumulator per stacked matrix is equivalent to one per gate block. It also
# covers the bias column of FC; the old per-block slice `[:, -1:0]` was empty,
# so the bias was never updated.
ms = {k: torch.zeros_like(LSTM[k]) for k in param_keys}

while epoch < n_epochs:
    X_chars = book_data[e:e+seq_length]
    Y_chars = book_data[e+1:e+seq_length+1]
    X_train = encode_string(X_chars)
    Y_train = encode_string(Y_chars)

    P_train, ht, ct = forward(LSTM, X_train, hprev, cprev)
    cross_entropy, loss = compute_loss(Y_train, P_train)
    grads = torch.autograd.grad(cross_entropy, [LSTM[k] for k in param_keys])

    with torch.no_grad():
        for k, grad in zip(param_keys, grads):
            grad = torch.clamp(grad, -5, 5)
            ms[k] += grad**2
            LSTM[k] -= (eta / torch.sqrt(ms[k] + epsilon)) * grad

    if step == 0:
        smooth_loss = loss
    else:
        smooth_loss = 0.999*smooth_loss + 0.001*loss

    losses.append(smooth_loss)

    if step % 1000 == 0:
        print(f"Step: {step}")
        print(f"\t * Smooth loss: {smooth_loss}")
    # Sampling needs no gradients.
    with torch.no_grad():
        if step % 5000 == 0:
            _, s_syn = synthetize_seq(LSTM, hprev, cprev, X_train[:, 0], 200, 0.6)
            print("-" * 100)
            print(f"Synthetized sequence: \n{s_syn}")
            print("-" * 100)
        if step % 100000 == 0 and step > 0:
            _, s_lsyn = synthetize_seq(LSTM, hprev, cprev, X_train[:, 0], 1000, 0.6)
            print("-" * 100)
            print(f"Long synthetized sequence: \n{s_lsyn}")
            print("-" * 100)

    step += 1
    e += seq_length
    # The target slice reaches one character past the input window, so a window
    # starting at `e` needs e + seq_length + 1 characters; wrap before it runs short.
    if e + seq_length + 1 > len(book_data):
        e = 0
        epoch += 1
        hprev = torch.zeros((m, 1), dtype=torch.double)
        cprev = torch.zeros((m, 1), dtype=torch.double)
    else:
        hprev = ht.detach()
        cprev = ct.detach()

with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(LSTM, handle, protocol=pickle.HIGHEST_PROTOCOL)

Step: 0
	 * Smooth loss: 109.54975949941455
----------------------------------------------------------------------------------------------------
Synthetized sequence: 
}EHü9/g'7fQCm(p;/UL';aüOInniqbbVI1YgyYJGB,R;bV0YpRü!:'kT?vgC-Ds(SoU_e0aa!j:cfdPS}M"ureH(?:Q,?D7VMzü'7yh;vZgyI.9QrP}}2otwBzaB6TxT!CwrYRCGY,j:m0H•gyOS_üi7E^Z}^.pXue-B:ANNNWl}jKH7k6K_3wü^}6!;m/r:gr•nTEQC
----------------------------------------------------------------------------------------------------
Step: 1000
	 * Smooth loss: 74.8556932532987


KeyboardInterrupt: 

In [None]:
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(
    xlabel='Steps',
    ylabel='Smooth loss',
    title=f'Training - eta: {eta} - seq_length: {seq_length} - m: {m} - n_epochs: {n_epochs}',
)
ax.grid(True)
plt.show()

In [None]:
# NOTE(review): pickle.load can execute arbitrary code, so only open files you
# created yourself. This file name is not one the training cells here write
# (they write rnn_<timestamp>.pickle); confirm where it comes from.
with open('rnn_eminem.pickle', 'rb') as fh:
    test_rnn = pickle.load(fh)

In [21]:
# Generate 1000 characters at temperature 0.8, seeded with a single space and
# zero hidden/cell state.
first_char = " "
x_input = encode_string(first_char)
h_init = torch.zeros((m, 1), dtype=torch.double)
c_init = torch.zeros((m, 1), dtype=torch.double)
Y_t, s_t = synthetize_seq(LSTM, h_init, c_init, x_input[:, 0], 1000, 0.8)
print(first_char + s_t)

 He on to his the cartar the his that coras een at he rolleconted whas has abuiss weer Harry Harrent had his creor scho budiut theel the inguped wish he had has youlr of sicinging in he rouse tham Sare whoy be all had wearly, had ine Wirine haire wark ther he had hes had sore, a musped his.  The more in therrud ikey theare tera him he corees Oof had had as the had that at Domlesbem had he, bant the had as for had beanred as the douplly thear on had and thet to hid upelime the been sif there housl thile in thit ult De his the would his pfot could solleche theing herd and se lill htowers, lamid thear as his the thear had, sit he weriry on woult he his mom of acus yee vloo haok and if sher for his his had been the ticas his seen sisieter him semored en that Harry been then the  onbon.
	It sarpest hig been his canibed Dust amould bous, voomed hit bed the would ave he has earing allouss his Rilly had lane shain of mkit theted it surdive his was a woretort him the the wat had his Worevor so 

#### Adam Optimizer

In [None]:
# Adam variant of the AdaGrad training loop, for this notebook's LSTM. The
# previous version referenced `RNN`, `backward` and a 3-argument `forward` from
# a vanilla-RNN notebook, none of which exist here. Training starts from the
# current LSTM weights; re-run the parameter cells first for a fresh model.
e, step, epoch = 0, 0, 0
n_epochs = 10
smooth_loss = 0
seq_length = 25
losses = []
hprev = torch.zeros((m, 1), dtype=torch.double)
cprev = torch.zeros((m, 1), dtype=torch.double)

eta = 0.0005
beta_1, beta_2, epsilon = 0.9, 0.999, 1e-8

# Leaf copies of the stacked parameters, so gradients can be taken directly and
# in-place updates under no_grad do not build an ever-growing autograd graph.
LSTM = {k: v.detach().clone().requires_grad_(True) for k, v in LSTM.items()}
param_keys = ['Wall', 'Uall', 'FC']
# First/second moment estimates, in the parameters' dtype (double).
ms = {k: torch.zeros_like(LSTM[k]) for k in param_keys}
vs = {k: torch.zeros_like(LSTM[k]) for k in param_keys}

while epoch < n_epochs:
    X_chars = book_data[e:e+seq_length]
    Y_chars = book_data[e+1:e+seq_length+1]
    X_train = encode_string(X_chars)
    Y_train = encode_string(Y_chars)

    P_train, ht, ct = forward(LSTM, X_train, hprev, cprev)
    cross_entropy, loss = compute_loss(Y_train, P_train)
    grads = torch.autograd.grad(cross_entropy, [LSTM[k] for k in param_keys])

    with torch.no_grad():
        for k, grad in zip(param_keys, grads):
            grad = torch.clamp(grad, -5.0, 5.0)
            ms[k] = beta_1*ms[k] + (1 - beta_1)*grad
            vs[k] = beta_2*vs[k] + (1 - beta_2)*(grad**2)
            m_hat = ms[k]/(1 - beta_1**(step+1))
            v_hat = vs[k]/(1 - beta_2**(step+1))
            # Adam adds epsilon after the square root.
            LSTM[k] -= eta * m_hat / (torch.sqrt(v_hat) + epsilon)

    if step == 0:
        smooth_loss = loss
    else:
        smooth_loss = 0.999*smooth_loss + 0.001*loss

    losses.append(smooth_loss)

    if step % 1000 == 0:
        print(f"Step: {step}")
        print(f"\t * Smooth loss: {smooth_loss:.4f}")
    with torch.no_grad():
        if step % 5000 == 0:
            _, s_syn = synthetize_seq(LSTM, hprev, cprev, X_train[:, 0], 200)
            print("-" * 100)
            print(f"Synthetized sequence: \n{s_syn}")
            print("-" * 100)
        if step % 100000 == 0 and step > 0:
            _, s_lsyn = synthetize_seq(LSTM, hprev, cprev, X_train[:, 0], 1000)
            print("-" * 100)
            print(f"Long synthetized sequence: \n{s_lsyn}")
            print("-" * 100)

    step += 1
    e += seq_length
    # A window at `e` needs e + seq_length + 1 characters (targets are shifted by one).
    if e + seq_length + 1 > len(book_data):
        e = 0
        epoch += 1
        hprev = torch.zeros((m, 1), dtype=torch.double)
        cprev = torch.zeros((m, 1), dtype=torch.double)
    else:
        hprev = ht.detach()
        cprev = ct.detach()

with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(LSTM, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Mini-batch training (vanilla RNN)

In [None]:
def forward_batch(rnn, X, hprev):
    """Forward pass of a vanilla tanh RNN over a batch of sequences.

    Args:
        rnn: dict with 'W' (m, m), 'U' (m, K), 'V' (K, m), 'b' (m, 1), 'c' (K, 1).
        X: (K, T, B) one-hot inputs; X[:, t, :] is time step t for every sequence.
        hprev: (m, B) initial hidden states.

    Returns:
        (A, H, P, ht): pre-activations A and hidden states H, both (m, T, B);
        softmax outputs P (K, T, B); and the last hidden state ht (m, B).
    """
    n_classes, n_steps, n_batch = X.shape
    n_hidden = hprev.shape[0]
    W, U, V = rnn['W'], rnn['U'], rnn['V']
    b = rnn['b'].expand(n_hidden, n_batch)
    c = rnn['c'].expand(n_classes, n_batch)

    A = torch.zeros((n_hidden, n_steps, n_batch), dtype=torch.double)
    H = torch.zeros_like(A)
    P = torch.zeros((n_classes, n_steps, n_batch), dtype=torch.double)

    h = hprev.clone()
    for t in range(n_steps):
        a = torch.mm(W, h) + torch.mm(U, X[:, t, :]) + b
        h = torch.tanh(a)
        A[:, t, :] = a
        H[:, t, :] = h
        P[:, t, :] = F.softmax(torch.mm(V, h) + c, dim=0)
    return A, H, P, h

In [None]:
def compute_loss_batch(Y, P):
    """Cross-entropy summed over classes and time steps, averaged over the batch.

    Y and P are indexed (class, time, batch); returns a Python float.
    """
    n_batch = Y.shape[2]
    total = -(Y * torch.log(P)).sum()
    return total.item() / n_batch

In [None]:
def backward_batch(rnn, X, Y, A, H, P, hprev):
    """Backpropagation through time for the mini-batch vanilla RNN.

    Shapes follow ``forward_batch``: X, Y and P are (K, T, B); A and H are
    (m, T, B); hprev is (m, B). The gradients are those of the cross-entropy
    summed over time steps and averaged over the batch (the quantity
    ``compute_loss_batch`` reports). T and m are read from ``A``, not from the
    globals ``seq_length`` / ``m``, so any window length works.

    Returns:
        (grads, grads_clamped): dicts keyed 'U', 'W', 'V', 'c', 'b'; the second
        has every entry clipped element-wise to [-5, 5].
    """
    n_hidden, n_steps, _ = A.shape
    dA = torch.zeros_like(A)

    # Softmax + cross-entropy gradient w.r.t. the output logits.
    G = -(Y - P)
    dV = torch.bmm(G.permute(2, 0, 1), H.permute(2, 1, 0)).mean(dim=0)
    dhtau = torch.matmul(G[:, -1, :].t(), rnn['V']).t()
    dA[:, -1, :] = (1 - torch.pow(torch.tanh(A[:, -1, :]), 2)) * dhtau

    for i in range(n_steps - 2, -1, -1):
        # dL/dh_t gets contributions from the output at t and from a_{t+1}.
        dht = torch.matmul(G[:, i, :].t(), rnn['V']).t() + torch.matmul(dA[:, i + 1, :].t(), rnn['W']).t()
        dA[:, i] = (1 - torch.pow(torch.tanh(A[:, i]), 2)) * dht

    # Hidden state feeding each step: hprev, then h_0 .. h_{T-2}.
    Hd = torch.cat((hprev.reshape((n_hidden, 1, -1)), H[:, :-1, :]), dim=1)
    dW = torch.matmul(dA.permute(2, 0, 1), Hd.permute(2, 1, 0)).mean(dim=0)
    dU = torch.matmul(dA.permute(2, 0, 1), X.permute(2, 1, 0)).mean(dim=0)
    dc = G.sum(1).mean(dim=1).reshape((-1, 1))
    db = dA.sum(1).mean(dim=1).reshape((-1, 1))
    grads = {'U': dU, 'W': dW, 'V': dV, 'c': dc, 'b': db}
    grads_clamped = {k: torch.clamp(v, min=-5.0, max=5.0) for (k, v) in grads.items()}
    return grads, grads_clamped

In [None]:
# Mini-batch AdaGrad training of a *vanilla* RNN (parameters 'W', 'U', 'V', 'b',
# 'c'), using forward_batch / backward_batch. No earlier cell defines such a
# model, so initialise one (same scheme as the LSTM weights) unless an `RNN`
# already exists in the namespace.
if 'RNN' not in globals():
    RNN = {
        'b': torch.zeros((m, 1), dtype=torch.double),
        'c': torch.zeros((K, 1), dtype=torch.double),
        'U': torch.normal(0.0, sig, (m, K), dtype=torch.double),
        'V': torch.normal(0.0, sig, (K, m), dtype=torch.double),
        'W': torch.normal(0.0, sig, (m, m), dtype=torch.double),
    }


def _sample_rnn(rnn, h0, x0, n, T=1):
    """Sample ``n`` characters from the vanilla RNN and return them as a string.

    ``synthetize_seq`` in this notebook expects the LSTM parameter dict (and a
    cell state), so it cannot sample from ``RNN``.
    """
    h = h0.clone()
    x = x0.clone().reshape((-1, 1))
    sampled = []
    for _ in range(n):
        h = torch.tanh(torch.mm(rnn['W'], h) + torch.mm(rnn['U'], x) + rnn['b'])
        p = F.softmax((torch.mm(rnn['V'], h) + rnn['c']) / T, dim=0)
        # Inverse-CDF sampling; fall back to the last class on rounding shortfall.
        hits = torch.where(torch.cumsum(p, dim=0) - torch.rand(1) > 0)[0]
        idx = hits[0].item() if hits.numel() > 0 else p.shape[0] - 1
        sampled.append(ind_to_char[idx])
        x = torch.zeros_like(x)
        x[idx, 0] = 1
    return ''.join(sampled)


e, step, epoch = 0, 0, 0
n_epochs = 500
smooth_loss = 0
batch_size = 32
seq_length = 150
eta = 0.1
losses = []
hprev = torch.zeros((m, batch_size), dtype=torch.double)

# AdaGrad accumulators, in the parameters' own dtype (double).
ms = {k: torch.zeros_like(RNN[k]) for k in ('b', 'c', 'U', 'V', 'W')}

# Characters consumed per step: batch_size consecutive windows of seq_length.
window_span = batch_size * seq_length

while epoch < n_epochs:
    X_batch = []
    Y_batch = []
    for b in range(batch_size):
        start_index = e + b * seq_length
        X_chars = book_data[start_index:(start_index + seq_length)]
        Y_chars = book_data[(start_index + 1):(start_index + seq_length + 1)]
        X_batch.append(encode_string(X_chars))
        Y_batch.append(encode_string(Y_chars))

    X_train = torch.stack(X_batch, dim=2)  # shape: (K, seq_length, n_batch)
    Y_train = torch.stack(Y_batch, dim=2)  # shape: (K, seq_length, n_batch)

    A_train, H_train, P_train, hts = forward_batch(RNN, X_train, hprev)
    loss = compute_loss_batch(Y_train, P_train)
    grads, grads_clamped = backward_batch(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

    for k in ms.keys():
        ms[k] += grads_clamped[k]**2
        RNN[k] -= (eta/torch.sqrt(ms[k] + epsilon)) * grads_clamped[k]

    if step == 0:
        smooth_loss = loss
    else:
        smooth_loss = 0.999*smooth_loss + 0.001*loss
    losses.append(smooth_loss)

    if step % 1000 == 0:
        print(f"Step: {step}")
        print(f"\t * Smooth loss: {smooth_loss:.4f}")
    if step % 5000 == 0:
        s_syn = _sample_rnn(RNN, hprev[:, 0:1], X_train[:, 0, 0], 200, 0.6)
        print("-" * 100)
        print(f"Synthetized sequence: \n{s_syn}")
        print("-" * 100)
    if step % 100000 == 0 and step > 0:
        s_lsyn = _sample_rnn(RNN, hprev[:, 0:1], X_train[:, 0, 0], 1000, 0.6)
        print("-" * 100)
        print(f"Long synthetized sequence: \n{s_lsyn}")
        print("-" * 100)

    # NOTE(review): window b of the next step starts at e + window_span +
    # b * seq_length, which does not continue window b of this step, so the
    # carried `hprev` column comes from unrelated text. Consider laying the
    # book out as batch_size contiguous streams instead.
    step += 1
    e += window_span
    # The last window's targets end at e + window_span (inclusive); wrap before
    # any target slice would come up short.
    if e + window_span + 1 > len(book_data):
        e = 0
        epoch += 1
        hprev = torch.zeros((m, batch_size), dtype=torch.double)
    else:
        hprev = hts

with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)