In [1]:
%cd ../

/Users/eloidieme/dev/python-projects/rnn


In [2]:
import numpy as np
import torch
import torch.nn.functional as F
import pickle
import time
import random
import matplotlib.pyplot as plt

In [3]:
# Read the whole training corpus into memory as a single string.
book_fname = "./data/goblet_book.txt"
# The text contains non-ASCII characters, so the encoding is pinned instead of
# depending on the platform default (assumes the file is UTF-8 encoded).
with open(book_fname, 'r', encoding='utf-8') as book:
    book_data = book.read()
len(book_data)

925005

In [4]:
def split_text(text, train_frac=0.8, val_frac=0.1):
    """Cut ``text`` into contiguous train, validation and test pieces.

    The first ``train_frac`` of the sequence becomes the training part, the
    following ``val_frac`` the validation part, and whatever is left the test
    part. Both cut points are truncated to integers.

    Returns
    -------
    tuple
        ``(train_data, val_data, test_data)`` slices of ``text``.
    """
    n_total = len(text)
    first_cut = int(n_total * train_frac)
    second_cut = first_cut + int(n_total * val_frac)
    return text[:first_cut], text[first_cut:second_cut], text[second_cut:]

In [5]:
# Split the corpus on whitespace and lay the characters of every word out as
# one row of a rectangular array, right-padding short words with spaces.
word_list = book_data.split()
max_len = max(len(word) for word in word_list)
chars = np.array([list(word.ljust(max_len)) for word in word_list])

In [6]:
# Vocabulary: the distinct characters of the word array, followed by newline
# and tab, which str.split() removes and which therefore never occur in `chars`.
unique_chars = [*np.unique(chars), '\n', '\t']
K = len(unique_chars)  # dimensionality of the input and output vectors

In [7]:
# Two-way lookup between characters and their position in the vocabulary.
char_to_ind = {char: idx for idx, char in enumerate(unique_chars)}
ind_to_char = dict(enumerate(unique_chars))

In [8]:
# Global hyper-parameters. Several training cells further down reassign
# `eta`, `epsilon` and `seq_length`, so the values seen by a cell depend on
# which cells were executed before it.
m = 100  # dimensionality of the hidden state
eta = 0.1  # learning rate
seq_length = 25  # length of input sequences used during training
epsilon = 1e-8  # for AdaGrad

In [9]:
# Model parameters of the vanilla RNN: biases start at zero, weight matrices
# are drawn from a zero-mean Gaussian with standard deviation `sig`.
sig = 0.01
RNN = dict(
    b=torch.zeros((m, 1), dtype=torch.double),                # hidden bias
    c=torch.zeros((K, 1), dtype=torch.double),                # output bias
    U=torch.normal(0.0, sig, (m, K), dtype=torch.double),     # input  -> hidden
    W=torch.normal(0.0, sig, (m, m), dtype=torch.double),     # hidden -> hidden
    V=torch.normal(0.0, sig, (K, m), dtype=torch.double),     # hidden -> output
)

In [10]:
def encode_char(char):
    """Return the one-hot encoding of ``char`` as a plain list of ``K`` ints.

    The hot position is looked up in the module-level ``char_to_ind``; a
    character missing from the vocabulary raises ``KeyError``.
    """
    encoding = [0] * K
    position = char_to_ind[char]
    encoding[position] = 1
    return encoding

In [11]:
def synthetize_seq(rnn, h0, x0, n, T = 1):
    """Sample ``n`` characters from the RNN, feeding each sample back as input.

    Parameters
    ----------
    rnn : dict
        Parameter tensors stored under the keys 'W', 'U', 'V', 'b' and 'c'.
    h0 : torch.Tensor
        Initial hidden state, used as an (m, 1) column.
    x0 : torch.Tensor
        One-hot encoding of the first input character; reshaped to (K, 1).
    n : int
        Number of characters to generate. ``n <= 0`` yields an empty sample.
    T : float, optional
        Softmax temperature: the logits are divided by ``T`` before they are
        normalised.

    Returns
    -------
    Y : torch.Tensor
        Integer one-hot matrix of shape (K, n), one column per sampled character.
    s : str
        The sampled characters, decoded with the module-level ``ind_to_char``.
    """
    # Take the vocabulary size from the model rather than from a notebook global.
    vocab_size = rnn['c'].shape[0]
    ht = h0
    xt = x0.reshape((vocab_size, 1))
    indexes = []
    for _ in range(n):
        at = torch.mm(rnn['W'], ht) + torch.mm(rnn['U'], xt) + rnn['b']
        ht = torch.tanh(at)
        ot = torch.mm(rnn['V'], ht) + rnn['c']
        pt = F.softmax(ot/T, dim=0)

        # Inverse-CDF sampling: pick the first class whose cumulative
        # probability exceeds a uniform draw.
        cp = torch.cumsum(pt, dim=0)
        a = torch.rand(1)
        hits = torch.where(cp - a > 0)[0]
        # If nothing exceeds the draw (cumulative sum rounded just below it),
        # fall back to the last class instead of raising IndexError.
        ii = hits[0].item() if hits.numel() > 0 else vocab_size - 1
        indexes.append(ii)

        # The sampled character becomes the next input.
        xt = torch.zeros_like(rnn['c'])
        xt[ii, 0] = 1

    # Build the one-hot output directly; this also gives a well-formed (K, 0)
    # result when nothing was sampled.
    Y = torch.zeros((vocab_size, len(indexes)), dtype=torch.int64)
    if indexes:
        Y[torch.tensor(indexes), torch.arange(len(indexes))] = 1

    s = ''.join(ind_to_char[idx] for idx in indexes)
    return Y, s

In [12]:
def encode_string(chars):
    """One-hot encode a sequence of characters.

    Each character is encoded with ``encode_char``; the encodings are stacked
    and transposed so that every column of the returned double tensor is the
    one-hot vector of one character.
    """
    rows = [encode_char(ch) for ch in chars]
    return torch.tensor(rows, dtype=torch.double).t()

In [13]:
def forward(rnn, X, hprev):
    """Run the RNN over one sequence and record every intermediate value.

    All sizes are read from the arguments (not from notebook globals), so the
    function works for any sequence length, including ones that differ from the
    current value of ``seq_length``.

    Parameters
    ----------
    rnn : dict
        Parameter tensors under the keys 'W', 'U', 'V', 'b' and 'c'.
    X : torch.Tensor
        One-hot inputs, one column per time step: shape (K, T).
    hprev : torch.Tensor
        Hidden state preceding the first step, shape (m, 1). It is cloned and
        never modified.

    Returns
    -------
    A : torch.Tensor
        Pre-activations, shape (m, T).
    H : torch.Tensor
        Hidden states ``tanh(A)``, shape (m, T).
    P : torch.Tensor
        Softmax output probabilities, shape (K, T).
    ht : torch.Tensor
        Hidden state after the last step, shape (m, 1).
    """
    n_steps = X.shape[1]
    hidden_size = hprev.shape[0]
    vocab_size = rnn['c'].shape[0]
    dtype = rnn['W'].dtype

    ht = hprev.clone()
    P = torch.zeros((vocab_size, n_steps), dtype=dtype)
    A = torch.zeros((hidden_size, n_steps), dtype=dtype)
    H = torch.zeros((hidden_size, n_steps), dtype=dtype)
    for i in range(n_steps):
        xt = X[:, i].reshape((-1, 1))
        at = torch.mm(rnn['W'], ht) + torch.mm(rnn['U'], xt) + rnn['b']
        ht = torch.tanh(at)
        ot = torch.mm(rnn['V'], ht) + rnn['c']
        pt = F.softmax(ot, dim=0)

        H[:, i] = ht.squeeze(1)
        P[:, i] = pt.squeeze(1)
        A[:, i] = at.squeeze(1)

    return A, H, P, ht

In [14]:
def compute_loss(Y, P):
    """Total cross-entropy between one-hot targets ``Y`` and probabilities ``P``.

    The loss is summed over every entry (all classes and all time steps), not
    averaged, and returned as a Python float.
    """
    summed_nll = -torch.sum(Y * torch.log(P))
    return summed_nll.item()

In [15]:
def evaluate_model(rnn, val_data):
    """Compute the per-character perplexity of the model on ``val_data``.

    The text is processed in consecutive windows of ``seq_length`` characters
    (module-level value); the hidden state is carried from one window to the
    next. Targets are the inputs shifted by one character.

    Parameters
    ----------
    rnn : dict
        Model parameters, as expected by ``forward``.
    val_data : str
        Held-out text; it must contain more than ``seq_length`` characters.

    Returns
    -------
    float
        ``exp`` of the mean negative log-likelihood per character.

    Raises
    ------
    ValueError
        If ``val_data`` is too short to form a single window.
    """
    total_loss = 0.0
    total_characters = 0
    hprev = torch.zeros((m, 1), dtype=torch.double)
    for i in range(0, len(val_data) - seq_length, seq_length):
        X_chars = val_data[i:i + seq_length]
        Y_chars = val_data[i + 1:i + seq_length + 1]
        X_val = encode_string(X_chars)
        Y_val = encode_string(Y_chars)
        _, _, P, hprev = forward(rnn, X_val, hprev)
        # compute_loss already sums over the seq_length characters of the
        # window, so it is accumulated as-is and normalised once at the end.
        total_loss += compute_loss(Y_val, P)
        total_characters += seq_length
    if total_characters == 0:
        raise ValueError("val_data must be longer than seq_length characters")
    average_loss = total_loss / total_characters
    perplexity = torch.exp(torch.tensor(average_loss))
    return perplexity.item()

In [16]:
def backward(rnn, X, Y, A, H, P, hprev):
    """Back-propagation through time for one sequence.

    Computes the gradients of the summed cross-entropy loss with respect to all
    parameters. The sequence length is taken from ``A`` rather than from the
    notebook-level ``seq_length``, so the function matches whatever ``forward``
    produced.

    Parameters
    ----------
    rnn : dict
        Parameter tensors under the keys 'W', 'U', 'V', 'b' and 'c'.
    X, Y : torch.Tensor
        One-hot inputs and targets, shape (K, T).
    A, H, P : torch.Tensor
        Pre-activations (m, T), hidden states (m, T) and output probabilities
        (K, T) returned by the forward pass.
    hprev : torch.Tensor
        Hidden state that preceded the first step, shape (m, 1).

    Returns
    -------
    grads : dict
        Raw gradients keyed like ``rnn``.
    grads_clamped : dict
        The same gradients clipped element-wise to [-5, 5].
    """
    n_steps = A.shape[1]
    dA = torch.zeros_like(A)

    # Gradient of the loss w.r.t. the logits of a softmax + cross-entropy layer.
    G = -(Y - P)
    dV = torch.matmul(G, H.t())

    # Walk backwards in time; every step but the last also receives the
    # gradient flowing back from the following step through W.
    for i in range(n_steps - 1, -1, -1):
        dht = torch.matmul(G[:, i], rnn['V'])
        if i < n_steps - 1:
            dht = dht + torch.matmul(dA[:, i + 1], rnn['W'])
        dA[:, i] = (1 - torch.pow(torch.tanh(A[:, i]), 2)) * dht

    # Hidden states shifted by one step: column i holds the state that fed step i.
    Hd = torch.cat((hprev, H), dim=1)[:, :n_steps]
    dW = torch.matmul(dA, Hd.t())
    dU = torch.matmul(dA, X.t())
    dc = G.sum(1).reshape((-1, 1))
    db = dA.sum(1).reshape((-1, 1))
    grads = {'U': dU, 'W': dW, 'V': dV, 'c': dc, 'b': db}
    grads_clamped = {k: torch.clamp(v, min=-5.0, max=5.0) for (k,v) in grads.items()}
    return grads, grads_clamped

In [None]:
# AdaGrad training: sweep the book in consecutive windows of seq_length
# characters, carrying the hidden state from one window to the next.
e, step, epoch = 0, 0, 0
n_epochs = 2
smooth_loss = 0
seq_length = 25
losses = []
hprev = torch.zeros((m, 1), dtype=torch.double)

# Running sum of squared (clipped) gradients, one tensor per parameter. It is
# kept in the parameters' own dtype so that small contributions are not
# rounded away once the sums have grown large.
ms = {k: torch.zeros_like(v) for k, v in RNN.items()}

while epoch < n_epochs:
    # Targets are the inputs shifted one character to the right.
    X_chars = book_data[e:e+seq_length]
    Y_chars = book_data[e+1:e+seq_length+1]
    X_train = encode_string(X_chars)
    Y_train = encode_string(Y_chars)

    A_train, H_train, P_train, ht = forward(RNN, X_train, hprev)
    loss = compute_loss(Y_train, P_train)
    grads, grads_clamped = backward(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

    # AdaGrad step; the parameters are updated in place.
    for k in ms.keys():
        ms[k] += grads_clamped[k]**2
        RNN[k] -= (eta/torch.sqrt(ms[k] + epsilon))*grads_clamped[k]

    # Exponential moving average of the loss, seeded with the first value.
    if step == 0:
        smooth_loss = loss
    else:
        smooth_loss = 0.999*smooth_loss + 0.001*loss

    losses.append(smooth_loss)

    if step % 1000 == 0:
        print(f"Step: {step}")
        print(f"\t * Smooth loss: {smooth_loss:.4f}")
    if step % 5000 == 0:
        _, s_syn = synthetize_seq(RNN, hprev, X_train[:, 0], 200, 0.6)
        print("-" * 100)
        print(f"Synthetized sequence: \n{s_syn}")
        print("-" * 100)
    if step % 100000 == 0 and step > 0:
        _, s_lsyn = synthetize_seq(RNN, hprev, X_train[:, 0], 1000, 0.6)
        print("-" * 100)
        print(f"Long synthetized sequence: \n{s_lsyn}")
        print("-" * 100)

    step += 1
    e += seq_length
    # A window needs seq_length inputs plus one more character for the last
    # target; start a new epoch as soon as the next window would not fit.
    if e + seq_length + 1 > len(book_data):
        e = 0
        epoch += 1
        hprev = torch.zeros((m, 1), dtype=torch.double)
    else:
        hprev = ht

# Persist the trained parameters under a timestamped file name.
with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Smoothed training loss of the most recent run, one point per update step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Steps')
ax.set_ylabel('Smooth loss')
ax.set_title(f'Training - eta: {eta} - seq_length: {seq_length} - m: {m} - n_epochs: {n_epochs}')
ax.grid(True)
plt.show()

In [22]:
# Load a previously saved parameter dictionary from disk.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only open pickles that were produced by this notebook / a trusted source.
# NOTE(review): presumably this model was trained on a different corpus, so its
# vocabulary size must match the current K and ind_to_char — confirm before sampling.
with open('rnn_eminem.pickle', 'rb') as handle:
    test_rnn = pickle.load(handle)

In [25]:
# Generate 1000 characters from the loaded model, starting from a single
# space and a zero hidden state, with softmax temperature 0.8.
first_char = " "
x_input = encode_string(first_char)
Y_t, s_t = synthetize_seq(test_rnn, torch.zeros((m, 1), dtype=torch.double), x_input[:, 0], n=1000, T=0.8)
print(first_char + s_t)

 than your nybody feels like his but we're gonna get back now the forth, alwusins of beated a stusider percence to I don't carely steppin your look and shit for me, time, a mon't flaillers, at? 'cause I'm just fuckin' Suche yoa, Mather why I'm doing this world offster I can man the all race and just a little ass of my ling off than it when
Bitch it, botta's too listed
As attracks your wants up for up
You can I have ans with you face you, Ind start the shit and my on light will going and lone in the breamply gonna knees in the layin' a mic on stop Hole]
Sleet up
One of the little gonna creany punk, I ond the feels we won't been the licked and not knock you)?
And he wentrice to surrely fear? (Pearrough to stop time Jomn underst
Someterhoming like the hearthought me are gonna left the cheah my shit 'em laind here
In do
I've like I five and is ballity grow you say what's linderelle
But hit the stapping to a parond in the mirs, so here so much and away ald he pastte an off and still I ain't

#### Adam Optimizer

In [None]:
# Adam training: same sequential sweep over the book as the AdaGrad run, with
# bias-corrected first and second moment estimates of the clipped gradients.
e, step, epoch = 0, 0, 0
n_epochs = 10
smooth_loss = 0
seq_length = 25
losses = []
hprev = torch.zeros((m, 1), dtype=torch.double)

eta = 0.0005
beta_1, beta_2, epsilon = 0.9, 0.999, 1e-8

# First (ms) and second (vs) moment estimates, one tensor per parameter and in
# the parameters' own dtype.
ms = {k: torch.zeros_like(v) for k, v in RNN.items()}
vs = {k: torch.zeros_like(v) for k, v in RNN.items()}

while epoch < n_epochs:
    # Targets are the inputs shifted one character to the right.
    X_chars = book_data[e:e+seq_length]
    Y_chars = book_data[e+1:e+seq_length+1]
    X_train = encode_string(X_chars)
    Y_train = encode_string(Y_chars)

    A_train, H_train, P_train, ht = forward(RNN, X_train, hprev)
    loss = compute_loss(Y_train, P_train)
    grads, grads_clamped = backward(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

    for k in ms.keys():
        ms[k] = beta_1*ms[k] + (1 - beta_1)*grads_clamped[k]
        vs[k] = beta_2*vs[k] + (1 - beta_2)*(grads_clamped[k]**2)
        # Bias correction for the zero-initialised moments (step is 0-based).
        m_hat = ms[k]/(1 - beta_1**(step+1))
        v_hat = vs[k]/(1 - beta_2**(step+1))
        # Standard Adam update: epsilon is added to the square root, not inside it.
        RNN[k] -= eta*m_hat/(torch.sqrt(v_hat) + epsilon)

    # Exponential moving average of the loss, seeded with the first value.
    if step == 0:
        smooth_loss = loss
    else:
        smooth_loss = 0.999*smooth_loss + 0.001*loss

    losses.append(smooth_loss)

    if step % 1000 == 0:
        print(f"Step: {step}")
        print(f"\t * Smooth loss: {smooth_loss:.4f}")
    if step % 5000 == 0:
        _, s_syn = synthetize_seq(RNN, hprev, X_train[:, 0], 200)
        print("-" * 100)
        print(f"Synthetized sequence: \n{s_syn}")
        print("-" * 100)
    if step % 100000 == 0 and step > 0:
        _, s_lsyn = synthetize_seq(RNN, hprev, X_train[:, 0], 1000)
        print("-" * 100)
        print(f"Long synthetized sequence: \n{s_lsyn}")
        print("-" * 100)

    step += 1
    e += seq_length
    # A window needs seq_length inputs plus one more character for the last
    # target; start a new epoch as soon as the next window would not fit.
    if e + seq_length + 1 > len(book_data):
        e = 0
        epoch += 1
        hprev = torch.zeros((m, 1), dtype=torch.double)
    else:
        hprev = ht

# Persist the trained parameters under a timestamped file name.
with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Random locations

In [20]:
def split_into_chunks(s, L):
    """Split the sequence ``s`` into ``L`` contiguous, near-equal pieces.

    Every piece gets ``len(s) // L`` items; the first ``len(s) % L`` pieces get
    one extra item each, so piece lengths differ by at most one. Pieces are
    returned in their original order.
    """
    base, extra = divmod(len(s), L)
    pieces = []
    start = 0
    for piece_no in range(L):
        stop = start + base + (1 if piece_no < extra else 0)
        pieces.append(s[start:stop])
        start = stop
    return pieces

In [None]:
# AdaGrad training over randomly ordered chunks: every epoch the book is cut
# into a random number of contiguous chunks, the chunks are shuffled, and each
# one is swept sequentially starting from a zero hidden state.
# NOTE(review): `eta` and `epsilon` are not set in this cell; they keep whatever
# value the most recently executed cell assigned (the Adam cell above sets
# eta = 0.0005) — confirm the intended learning rate.
step, epoch = 0, 0
n_epochs = 10
seq_length = 25
smooth_loss = 0
losses = []

# Running sum of squared (clipped) gradients, one tensor per parameter. It is
# kept in the parameters' own dtype so that small contributions are not
# rounded away once the sums have grown large.
ms = {k: torch.zeros_like(v) for k, v in RNN.items()}

while epoch < n_epochs:
    print(f"Epoch {epoch+1}/{n_epochs}")
    L = random.randint(30, 170)
    print(f"\t * No. chunks: {L}")
    chunks = split_into_chunks(book_data, L)
    random.shuffle(chunks)
    for idx, chunk in enumerate(chunks):
        print(f"-> Reached chunk {idx+1}")
        e = 0
        # Chunks are not adjacent in the text, so each starts from a fresh state.
        hprev = torch.zeros((m, 1), dtype=torch.double)
        while e < (len(chunk) - seq_length):
            # Targets are the inputs shifted one character to the right.
            X_chars = chunk[e:e+seq_length]
            Y_chars = chunk[e+1:e+seq_length+1]
            X_train = encode_string(X_chars)
            Y_train = encode_string(Y_chars)

            A_train, H_train, P_train, ht = forward(RNN, X_train, hprev)
            loss = compute_loss(Y_train, P_train)
            grads, grads_clamped = backward(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

            # AdaGrad step; the parameters are updated in place.
            for k in ms.keys():
                ms[k] += grads_clamped[k]**2
                RNN[k] -= (eta/torch.sqrt(ms[k] + epsilon))*grads_clamped[k]

            # Exponential moving average of the loss, seeded with the first value.
            if step == 0:
                smooth_loss = loss
            else:
                smooth_loss = 0.999*smooth_loss + 0.001*loss

            losses.append(smooth_loss)

            e += seq_length
            hprev = ht

            if step % 1000 == 0:
                print(f"Step: {step}")
                print(f"\t * Smooth loss: {smooth_loss:.4f}")
            if step % 5000 == 0:
                _, s_syn = synthetize_seq(RNN, hprev, X_train[:, 0], 200)
                print("-" * 100)
                print(f"Synthetized sequence: \n{s_syn}")
                print("-" * 100)
            if step % 100000 == 0 and step > 0:
                _, s_lsyn = synthetize_seq(RNN, hprev, X_train[:, 0], 1000)
                print("-" * 100)
                print(f"Long synthetized sequence: \n{s_lsyn}")
                print("-" * 100)
            step += 1

    epoch += 1

# Persist the trained parameters under a timestamped file name.
with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Mini-batch training

In [17]:
def forward_batch(rnn, X, hprev):
    """Forward pass of the RNN over a batch of equally long sequences.

    Parameters
    ----------
    rnn : dict
        Parameter tensors under the keys 'W', 'U', 'V', 'b' and 'c'.
    X : torch.Tensor
        One-hot inputs of shape (K, seq_length, batch_size).
    hprev : torch.Tensor
        Initial hidden states, one column per sequence: (m, batch_size).

    Returns
    -------
    A, H : torch.Tensor
        Pre-activations and hidden states, shape (m, seq_length, batch_size).
    P : torch.Tensor
        Softmax probabilities, shape (K, seq_length, batch_size).
    ht : torch.Tensor
        Hidden states after the last step, shape (m, batch_size).
    """
    n_classes, n_steps, n_seqs = X.shape
    n_hidden = hprev.shape[0]

    A = torch.zeros((n_hidden, n_steps, n_seqs), dtype=torch.double)
    H = torch.zeros((n_hidden, n_steps, n_seqs), dtype=torch.double)
    P = torch.zeros((n_classes, n_steps, n_seqs), dtype=torch.double)

    # The biases are single columns; expand them across the batch dimension.
    hidden_bias = rnn['b'].expand(n_hidden, n_seqs)
    output_bias = rnn['c'].expand(n_classes, n_seqs)

    state = hprev.clone()
    for step_idx in range(n_steps):
        inputs = X[:, step_idx, :]  # all sequences at this time step
        pre_act = torch.mm(rnn['W'], state) + torch.mm(rnn['U'], inputs) + hidden_bias
        state = torch.tanh(pre_act)
        logits = torch.mm(rnn['V'], state) + output_bias

        A[:, step_idx, :] = pre_act
        H[:, step_idx, :] = state
        P[:, step_idx, :] = F.softmax(logits, dim=0)

    return A, H, P, state

In [18]:
def compute_loss_batch(Y, P):
    """Cross-entropy of a batch, summed over classes and steps, averaged over sequences.

    ``Y`` and ``P`` hold the one-hot targets and the predicted probabilities;
    the batch size is read from the third axis of ``Y``. Returns a Python float.
    """
    n_seqs = Y.shape[2]
    summed_nll = -torch.sum(Y * torch.log(P))
    return summed_nll.item() / n_seqs

In [19]:
def backward_batch(rnn, X, Y, A, H, P, hprev):
    """Back-propagation through time for a batch of sequences.

    Returns the gradients of the loss computed by ``compute_loss_batch`` (summed
    over classes and time steps, averaged over the batch). The hidden size and
    sequence length are read from ``A`` — like ``forward_batch`` does from its
    inputs — instead of from notebook globals.

    Parameters
    ----------
    rnn : dict
        Parameter tensors under the keys 'W', 'U', 'V', 'b' and 'c'.
    X, Y : torch.Tensor
        One-hot inputs and targets, shape (K, seq_length, batch_size).
    A, H, P : torch.Tensor
        Outputs of ``forward_batch``: (m, T, B), (m, T, B) and (K, T, B).
    hprev : torch.Tensor
        Hidden states that preceded the first step, shape (m, batch_size).

    Returns
    -------
    grads : dict
        Batch-averaged gradients keyed like ``rnn``.
    grads_clamped : dict
        The same gradients clipped element-wise to [-5, 5].
    """
    n_hidden, n_steps, _ = A.shape
    dA = torch.zeros_like(A)

    # Per-sequence gradient w.r.t. the logits of softmax + cross-entropy.
    G = -(Y - P)
    dV = torch.bmm(G.permute(2, 0, 1), H.permute(2, 1, 0)).mean(dim=0)

    # Walk backwards in time; every step but the last also receives the
    # gradient flowing back from the following step through W.
    for i in range(n_steps - 1, -1, -1):
        dht = torch.matmul(rnn['V'].t(), G[:, i, :])
        if i < n_steps - 1:
            dht = dht + torch.matmul(rnn['W'].t(), dA[:, i + 1, :])
        dA[:, i, :] = (1 - torch.pow(torch.tanh(A[:, i, :]), 2)) * dht

    # Hidden states shifted by one step: slot i holds the state that fed step i.
    Hd = torch.cat((hprev.reshape((n_hidden, 1, -1)), H[:, :-1, :]), dim=1)
    dW = torch.matmul(dA.permute(2, 0, 1), Hd.permute(2, 1, 0)).mean(dim=0)
    dU = torch.matmul(dA.permute(2, 0, 1), X.permute(2, 1, 0)).mean(dim=0)
    dc = G.sum(1).mean(dim=1).reshape((-1, 1))
    db = dA.sum(1).mean(dim=1).reshape((-1, 1))
    grads = {'U': dU, 'W': dW, 'V': dV, 'c': dc, 'b': db}
    grads_clamped = {k: torch.clamp(v, min=-5.0, max=5.0) for (k,v) in grads.items()}
    return grads, grads_clamped

In [None]:
# Mini-batch AdaGrad training: every step processes batch_size consecutive
# windows of seq_length characters, stacked along the batch axis.
# NOTE(review): column b of one batch is followed in the text by column b+1 of
# the same batch, while column b of the next batch starts
# batch_size*seq_length characters later. The state carried over in `hprev`
# therefore does not come from the text immediately preceding each sequence.
e, step, epoch = 0, 0, 0
n_epochs = 500
smooth_loss = 0
batch_size = 32
seq_length = 150
eta = 0.1
losses = []
hprev = torch.zeros((m, batch_size), dtype=torch.double)

# Running sum of squared (clipped) gradients, one tensor per parameter. It is
# kept in the parameters' own dtype so that small contributions are not
# rounded away once the sums have grown large.
ms = {k: torch.zeros_like(v) for k, v in RNN.items()}

while epoch < n_epochs:
    X_batch = []
    Y_batch = []
    for b in range(batch_size):
        start_index = e + b * seq_length
        # Targets are the inputs shifted one character to the right.
        X_chars = book_data[start_index:(start_index + seq_length)]
        Y_chars = book_data[(start_index + 1):(start_index + seq_length + 1)]
        X_batch.append(encode_string(X_chars))
        Y_batch.append(encode_string(Y_chars))

    X_train = torch.stack(X_batch, dim=2)  # shape: (K, seq_length, n_batch)
    Y_train = torch.stack(Y_batch, dim=2)  # shape: (K, seq_length, n_batch)

    A_train, H_train, P_train, hts = forward_batch(RNN, X_train, hprev)
    loss = compute_loss_batch(Y_train, P_train)
    grads, grads_clamped = backward_batch(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

    # AdaGrad step; the parameters are updated in place.
    for k in ms.keys():
        ms[k] += grads_clamped[k]**2
        RNN[k] -= (eta/torch.sqrt(ms[k] + epsilon)) * grads_clamped[k]

    # Exponential moving average of the loss, seeded with the first value.
    if step == 0:
        smooth_loss = loss
    else:
        smooth_loss = 0.999*smooth_loss + 0.001*loss
    losses.append(smooth_loss)

    if step % 1000 == 0:
        print(f"Step: {step}")
        print(f"\t * Smooth loss: {smooth_loss:.4f}")
    if step % 5000 == 0:
        _, s_syn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 200, 0.6)
        print("-" * 100)
        print(f"Synthetized sequence: \n{s_syn}")
        print("-" * 100)
    if step % 100000 == 0 and step > 0:
        _, s_lsyn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 1000, 0.6)
        print("-" * 100)
        print(f"Long synthetized sequence: \n{s_lsyn}")
        print("-" * 100)

    step += 1
    e += batch_size * seq_length
    # A batch needs batch_size*seq_length inputs plus one more character for
    # the very last target; otherwise the final target sequence would be one
    # character short and torch.stack would fail. Wrap to a new epoch instead.
    if e + batch_size * seq_length + 1 > len(book_data):
        e = 0
        epoch += 1
        hprev = torch.zeros((m, batch_size), dtype=torch.double)
    else:
        hprev = hts

# Persist the trained parameters under a timestamped file name.
with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)

### Batches + Adam

In [None]:
# Mini-batch Adam training: every step processes batch_size consecutive
# windows of seq_length characters, stacked along the batch axis.
# NOTE(review): column b of one batch is followed in the text by column b+1 of
# the same batch, while column b of the next batch starts
# batch_size*seq_length characters later. The state carried over in `hprev`
# therefore does not come from the text immediately preceding each sequence.
e, step, epoch = 0, 0, 0
n_epochs = 100
smooth_loss = 0
batch_size = 32
seq_length = 100
losses = []
hprev = torch.zeros((m, batch_size), dtype=torch.double)

eta = 0.005
beta_1, beta_2, epsilon = 0.9, 0.999, 1e-8

# First (ms) and second (vs) moment estimates, one tensor per parameter and in
# the parameters' own dtype.
ms = {k: torch.zeros_like(v) for k, v in RNN.items()}
vs = {k: torch.zeros_like(v) for k, v in RNN.items()}

while epoch < n_epochs:
    X_batch = []
    Y_batch = []
    for b in range(batch_size):
        start_index = e + b * seq_length
        # Targets are the inputs shifted one character to the right.
        X_chars = book_data[start_index:(start_index + seq_length)]
        Y_chars = book_data[(start_index + 1):(start_index + seq_length + 1)]
        X_batch.append(encode_string(X_chars))
        Y_batch.append(encode_string(Y_chars))

    X_train = torch.stack(X_batch, dim=2)  # shape: (K, seq_length, n_batch)
    Y_train = torch.stack(Y_batch, dim=2)  # shape: (K, seq_length, n_batch)

    A_train, H_train, P_train, hts = forward_batch(RNN, X_train, hprev)
    loss = compute_loss_batch(Y_train, P_train)
    grads, grads_clamped = backward_batch(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

    for k in ms.keys():
        ms[k] = beta_1*ms[k] + (1 - beta_1)*grads_clamped[k]
        vs[k] = beta_2*vs[k] + (1 - beta_2)*(grads_clamped[k]**2)
        # Bias correction for the zero-initialised moments (step is 0-based).
        m_hat = ms[k]/(1 - beta_1**(step+1))
        v_hat = vs[k]/(1 - beta_2**(step+1))
        # Standard Adam update: epsilon is added to the square root, not inside it.
        RNN[k] -= eta*m_hat/(torch.sqrt(v_hat) + epsilon)

    # Exponential moving average of the loss, seeded with the first value.
    if step == 0:
        smooth_loss = loss
    else:
        smooth_loss = 0.999*smooth_loss + 0.001*loss
    losses.append(smooth_loss)

    if step % 1000 == 0:
        print(f"Step: {step}")
        print(f"\t * Smooth loss: {smooth_loss:.4f}")
    if step % 5000 == 0:
        _, s_syn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 200, 0.6)
        print("-" * 100)
        print(f"Synthetized sequence: \n{s_syn}")
        print("-" * 100)
    if step % 100000 == 0 and step > 0:
        _, s_lsyn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 1000, 0.6)
        print("-" * 100)
        print(f"Long synthetized sequence: \n{s_lsyn}")
        print("-" * 100)

    step += 1
    e += batch_size * seq_length
    # A batch needs batch_size*seq_length inputs plus one more character for
    # the very last target; otherwise the final target sequence would be one
    # character short and torch.stack would fail. Wrap to a new epoch instead.
    if e + batch_size * seq_length + 1 > len(book_data):
        e = 0
        epoch += 1
        hprev = torch.zeros((m, batch_size), dtype=torch.double)
    else:
        hprev = hts

# Persist the trained parameters under a timestamped file name.
with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)

### Combine chunks and batches

In [None]:
# AdaGrad training on shuffled chunks, processed in mini-batches: every epoch
# the book is cut into a random number of contiguous chunks, the chunks are
# shuffled, and each one is swept in batches of batch_size consecutive windows.
n_epochs = 60
batch_size = 10
eta = 0.1
epsilon = 1e-8
seq_length = 25
smooth_loss = 0
losses = []

# Running sum of squared (clipped) gradients, one tensor per parameter. It is
# kept in the parameters' own dtype so that small contributions are not
# rounded away once the sums have grown large.
ms = {k: torch.zeros_like(v) for k, v in RNN.items()}

step, epoch = 0, 0
hprev = torch.zeros((m, batch_size), dtype=torch.double)

while epoch < n_epochs:
    L = random.randint(10, 20)
    chunks = split_into_chunks(book_data, L)
    random.shuffle(chunks)

    print(f"Epoch {epoch+1}/{n_epochs} with {L} chunks")

    for idx, chunk in enumerate(chunks):
        print(f"Processing chunk {idx+1}/{L}")
        e = 0

        # Strict '<' leaves room for the extra character the last target needs.
        while e < (len(chunk) - batch_size * seq_length):
            X_batch = []
            Y_batch = []

            for b in range(batch_size):
                start_index = e + b * seq_length
                # Targets are the inputs shifted one character to the right.
                X_chars = chunk[start_index:(start_index + seq_length)]
                Y_chars = chunk[(start_index + 1):(start_index + seq_length + 1)]
                X_batch.append(encode_string(X_chars))
                Y_batch.append(encode_string(Y_chars))

            X_train = torch.stack(X_batch, dim=2)
            Y_train = torch.stack(Y_batch, dim=2)

            A_train, H_train, P_train, hts = forward_batch(RNN, X_train, hprev)
            loss = compute_loss_batch(Y_train, P_train)
            grads, grads_clamped = backward_batch(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

            # AdaGrad step; the parameters are updated in place.
            for k in ms.keys():
                ms[k] += grads_clamped[k]**2
                RNN[k] -= (eta/torch.sqrt(ms[k] + epsilon)) * grads_clamped[k]

            # Exponential moving average of the loss, seeded with the first value.
            if step == 0:
                smooth_loss = loss
            else:
                smooth_loss = 0.999*smooth_loss + 0.001*loss
            losses.append(smooth_loss)

            if step % 1000 == 0:
                print(f"Step: {step}")
                print(f"\t * Smooth loss: {smooth_loss:.4f}")
            if step % 5000 == 0:
                _, s_syn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 200, 0.6)
                print("-" * 100)
                print(f"Synthetized sequence: \n{s_syn}")
                print("-" * 100)
            if step % 100000 == 0 and step > 0:
                _, s_lsyn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 1000, 0.6)
                print("-" * 100)
                print(f"Long synthetized sequence: \n{s_lsyn}")
                print("-" * 100)

            e += batch_size * seq_length
            step += 1
            hprev = hts

        # Reset hidden state for new chunk. The loop above only exits once the
        # chunk is exhausted, so no extra condition is needed here.
        hprev = torch.zeros((m, batch_size), dtype=torch.double)

    epoch += 1

# Save trained RNN
with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)

### Combine chunks, batches and Adam

In [21]:
# Adam training on shuffled chunks, processed in mini-batches: every epoch the
# book is cut into a random number of contiguous chunks, the chunks are
# shuffled, and each one is swept in batches of batch_size consecutive windows.
n_epochs = 120
batch_size = 16
eta = 0.001
smooth_loss = 0
seq_length = 75
losses = []

# Inclusive bounds for the number of chunks drawn each epoch.
min_L, max_L = 120, 150

beta_1, beta_2, epsilon = 0.9, 0.999, 1e-8

# First (ms) and second (vs) moment estimates, one tensor per parameter and in
# the parameters' own dtype.
ms = {k: torch.zeros_like(v) for k, v in RNN.items()}
vs = {k: torch.zeros_like(v) for k, v in RNN.items()}

step, epoch = 0, 0
hprev = torch.zeros((m, batch_size), dtype=torch.double)

while epoch < n_epochs:
    # random.randint includes both end points, so max_L is already reachable.
    L = random.randint(min_L, max_L)
    chunks = split_into_chunks(book_data, L)
    random.shuffle(chunks)

    print(f"Epoch {epoch+1}/{n_epochs} with {L} chunks")

    for chunk in chunks:
        e = 0

        # Strict '<' leaves room for the extra character the last target needs.
        while e < (len(chunk) - batch_size * seq_length):
            X_batch = []
            Y_batch = []

            for b in range(batch_size):
                start_index = e + b * seq_length
                # Targets are the inputs shifted one character to the right.
                X_chars = chunk[start_index:(start_index + seq_length)]
                Y_chars = chunk[(start_index + 1):(start_index + seq_length + 1)]
                X_batch.append(encode_string(X_chars))
                Y_batch.append(encode_string(Y_chars))

            X_train = torch.stack(X_batch, dim=2)
            Y_train = torch.stack(Y_batch, dim=2)

            A_train, H_train, P_train, hts = forward_batch(RNN, X_train, hprev)
            loss = compute_loss_batch(Y_train, P_train)
            grads, grads_clamped = backward_batch(RNN, X_train, Y_train, A_train, H_train, P_train, hprev)

            for k in ms.keys():
                ms[k] = beta_1*ms[k] + (1 - beta_1)*grads_clamped[k]
                vs[k] = beta_2*vs[k] + (1 - beta_2)*(grads_clamped[k]**2)
                # Bias correction for the zero-initialised moments (step is 0-based).
                m_hat = ms[k]/(1 - beta_1**(step+1))
                v_hat = vs[k]/(1 - beta_2**(step+1))
                # Standard Adam update: epsilon is added to the square root, not inside it.
                RNN[k] -= eta*m_hat/(torch.sqrt(v_hat) + epsilon)

            # Exponential moving average of the loss, seeded with the first value.
            if step == 0:
                smooth_loss = loss
            else:
                smooth_loss = 0.999*smooth_loss + 0.001*loss
            losses.append(smooth_loss)

            if step % 1000 == 0:
                print(f"Step: {step}")
                print(f"\t * Smooth loss: {smooth_loss:.4f}")
            if step % 5000 == 0:
                _, s_syn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 200, 0.6)
                print("-" * 100)
                print(f"Synthetized sequence: \n{s_syn}")
                print("-" * 100)
            if step % 100000 == 0 and step > 0:
                _, s_lsyn = synthetize_seq(RNN, hprev[:, 0:1], X_train[:, 0, 0], 1000, 0.6)
                print("-" * 100)
                print(f"Long synthetized sequence: \n{s_lsyn}")
                print("-" * 100)

            e += batch_size * seq_length
            step += 1
            hprev = hts

        # Reset hidden state for new chunk. The loop above only exits once the
        # chunk is exhausted, so no extra condition is needed here.
        hprev = torch.zeros((m, batch_size), dtype=torch.double)

    epoch += 1

# Save trained RNN
with open(f'rnn_{time.time()}.pickle', 'wb') as handle:
    pickle.dump(RNN, handle, protocol=pickle.HIGHEST_PROTOCOL)

Epoch 1/120 with 223 chunks
Step: 0
	 * Smooth loss: 352.5300
----------------------------------------------------------------------------------------------------
Synthetized sequence: 
M`0ä*a′l’,IVj[”éqáIx4Ns—w6“′t88"′r]0k”ö,UPGs2′&àfsöqKhmçhAHpNtQn	kç956zL""	MéßCçéqN0gnV`%
DtvAá.üsr”(–©}%Hr”`Xпzn&és,oW7WqIs9Pbáplà’"}â…]3—`‘M}OiF0vIé)“ÅzVt′óv}W ”XU	 "ZcVàWl'hJéaO3äb `	dGHrFmws"7ßpZ	
----------------------------------------------------------------------------------------------------
Epoch 2/120 with 223 chunks
Step: 1000
	 * Smooth loss: 276.7788
Epoch 3/120 with 223 chunks
Step: 2000
	 * Smooth loss: 213.4509
Epoch 4/120 with 223 chunks
Epoch 5/120 with 223 chunks
Step: 3000
	 * Smooth loss: 179.0625
Epoch 6/120 with 223 chunks
Step: 4000
	 * Smooth loss: 161.2244
Epoch 7/120 with 223 chunks
Epoch 8/120 with 223 chunks
Step: 5000
	 * Smooth loss: 151.5699
----------------------------------------------------------------------------------------------------
Synthetized sequence: 
here, i