Tutorial: Vanilla RNNs from scratch (not really)

Alfan Farizki Wicaksono, Fasilkom UI

In [None]:
import torch
import pandas as pd
import numpy as np

from collections import Counter
from torch import nn, optim
from torch.utils.data import DataLoader

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a list of documents.

    All documents are concatenated into one token stream. Sample ``i`` is the
    pair ``(tokens[i:i + n], tokens[i + 1:i + n + 1])`` of word indexes, i.e.
    an input window and the same window shifted one token ahead
    (``n == sequence_length``).

    Args:
        sequence_length: Number of tokens in every input/target window.
        documents: List of strings; each one is tokenized by splitting on
            single spaces.
    """

    def __init__(
        self,
        sequence_length,
        documents,  # list of strings
    ):
        self.sequence_length = sequence_length
        # Keep the corpus on the instance so load_words() reads the argument
        # instead of whatever global happens to be named ``documents``.
        self.documents = documents
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Vocabulary lookups in both directions; index 0 is the most frequent word.
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        # The whole corpus as a flat list of vocabulary indexes.
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Return the tokens of all documents, in order, as one flat list.

        Every document is followed by a single space before the text is split
        on ``' '``, so the list ends with one empty-string token. That token is
        kept so the vocabulary size and the number of samples stay unchanged.
        """
        text = "".join(doc + " " for doc in self.documents)
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct tokens ordered from most to least frequent.

        Ties keep first-occurrence order because the sort is stable.
        """
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of windows that still have a full shifted target.

        Clamped at 0 so a corpus shorter than one window yields an empty
        dataset instead of an invalid negative length.
        """
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(input_window, target_window)`` as two 1-D index tensors."""
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )

In [None]:
class Linear(nn.Module):
    """Bias-free linear map ``y = x @ W`` with a learnable weight matrix.

    ``W`` has shape ``(n_inputs, n_outputs)`` and is drawn uniformly from
    ``[-0.1, 0.1]`` at construction time.
    """

    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.n_inputs, self.n_outputs = n_inputs, n_outputs
        # Allocate uninitialized storage; init_weights() fills it right away.
        self.W = nn.Parameter(torch.empty(n_inputs, n_outputs))
        self.init_weights()

    def init_weights(self):
        """Re-draw every parameter of this module from U(-0.1, 0.1)."""
        for weight in self.parameters():
            nn.init.uniform_(weight, a=-0.1, b=0.1)

    def forward(self, x):
        """Multiply ``x`` (last dimension ``n_inputs``) by the weight matrix."""
        return torch.matmul(x, self.W)

class RNNCell(nn.Module):
    """One vanilla RNN step: ``h_new = tanh(hidden @ W_h + input @ W_x)``.

    Both projections are bias-free ``Linear`` layers that map into the
    ``n_hiddens``-dimensional state space.
    """

    def __init__(self, n_inputs, n_hiddens):
        super().__init__()
        # Hidden-to-hidden (recurrent) weights.
        self.h = Linear(n_hiddens, n_hiddens)
        # Input-to-hidden weights.
        self.x = Linear(n_inputs, n_hiddens)

    def forward(self, input, hidden):
        """Combine the previous state with the current input into a new state."""
        recurrent_part = self.h(hidden)
        input_part = self.x(input)
        return (recurrent_part + input_part).tanh()

class RNN(nn.Module):
    """Unrolls a recurrent cell over the time axis of a batch of sequences.

    Args:
        n_inputs: Size of each per-step input vector.
        n_hiddens: Size of the hidden state.
        cell: Factory called as ``cell(n_inputs, n_hiddens)`` to build the
            step function (e.g. the ``RNNCell`` class).
    """

    def __init__(self, n_inputs, n_hiddens, cell):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens
        self.cell = cell(n_inputs, n_hiddens)

    def forward(self, inputs, prev_hidden_state):
        """Run the cell over every time step.

        ``inputs`` is ``[batch_size, sequence_length, embedding_size]``; the
        embedding axis plays the same role as the channel axis in a CNN.

        Returns:
            A pair ``(outputs, hidden_state)``: the hidden states of all steps
            stacked along dim 1, and the state after the final step only.
        """
        hidden_state = prev_hidden_state
        per_step_states = []
        # Walk the sequence axis; the state at step i summarizes steps 0..i.
        for step_input in inputs.unbind(dim=1):
            hidden_state = self.cell(step_input, hidden_state)
            per_step_states.append(hidden_state)

        return torch.stack(per_step_states, dim=1), hidden_state

class Model(nn.Module):
    """Word-level language model: embedding -> vanilla RNN -> vocabulary logits.

    Args:
        dataset: Object exposing ``uniq_words``; its length fixes the
            vocabulary size.
        rnn_size: Width of the recurrent hidden state (default 16).
        embedding_dim: Width of the word embeddings (default 16).
    """

    def __init__(self, dataset, rnn_size=16, embedding_dim=16):
        super().__init__()
        self.rnn_size = rnn_size
        self.embedding_dim = embedding_dim
        self.n_vocab = len(dataset.uniq_words)
        # Conceptually every word is one-hot encoded over the vocabulary.
        # Multiplying a one-hot vector by the (n_vocab x embedding_dim)
        # embedding matrix just selects one row, which is what nn.Embedding
        # does by index lookup. The matrix is a model parameter, so it is
        # trained by gradient descent together with the RNN weights.
        self.embedding = nn.Embedding(
            num_embeddings=self.n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.rnn = RNN(
            n_inputs=self.embedding_dim,
            n_hiddens=self.rnn_size,
            cell=RNNCell,
        )
        # Projects every hidden state onto one logit per vocabulary word.
        self.fc = Linear(self.rnn_size, self.n_vocab)

    def forward(self, x, prev_state):
        """Map word indexes ``x`` to next-word logits.

        Returns:
            ``(logits, state)`` where ``logits`` holds one score vector per
            input position and ``state`` is the hidden state after the last
            position.
        """
        embed = self.embedding(x)
        output, state = self.rnn(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        """Return an all-zero hidden state of shape ``(batch_size, rnn_size)``.

        The tensor is created with the dtype and device of the embedding
        weights so it stays usable after the model has been moved or cast.
        """
        return self.embedding.weight.new_zeros(batch_size, self.rnn_size)

def train(dataset, model, batch_size, max_epochs=400):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    The hidden state is carried from one batch to the next within an epoch
    (stateful training) and reset to zeros at the start of every epoch. One
    progress dict is printed per batch.

    Args:
        dataset: Map-style dataset yielding ``(input, target)`` index tensors.
        model: Module providing ``init_state(batch_size)`` and a forward pass
            ``model(x, state) -> (logits, state)`` with logits shaped
            ``[batch, sequence, vocab]``.
        batch_size: Number of samples per batch.
        max_epochs: Number of passes over the dataset.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        h_state = model.init_state(batch_size)  # initial hidden state: all zeros

        for batch, (x, y) in enumerate(dataloader):
            # The final batch is smaller when len(dataset) is not a multiple
            # of batch_size; trim the carried state (batch on dim 0) to match
            # so the logits and targets keep the same batch dimension.
            if h_state.shape[0] != x.shape[0]:
                h_state = h_state[:x.shape[0]]

            # Clear old gradients first so nothing stale leaks into this step.
            optimizer.zero_grad()

            y_pred, h_state = model(x, h_state)
            # CrossEntropyLoss wants the class axis second: [batch, vocab, sequence].
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            # Detach the state from the current graph but keep its values, so
            # the next batch continues from it (stateful) without
            # backpropagating into this batch's graph.
            h_state = h_state.detach()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

def predict(dataset, model, text, next_words=20):
    """Greedily extend ``text`` by ``next_words`` words.

    Args:
        dataset: Object exposing ``word_to_index`` and ``index_to_word``.
        model: Module providing ``init_state(batch_size)`` and a forward pass
            ``model(x, state) -> (logits, state)``.
        text: Prompt; split on single spaces into words.
        next_words: Number of words to generate.

    Returns:
        The prompt words followed by the generated words, as a list.

    Raises:
        KeyError: If a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    # A single prompt is decoded, so the hidden state has a batch of one.
    h_state = model.init_state(1)
    # Tokens the model has not consumed yet: initially the whole prompt.
    pending = list(words)

    # Inference only: no graph is needed, and without no_grad() the carried
    # state would keep every previous step's graph alive.
    with torch.no_grad():
        for _ in range(next_words):
            x = torch.tensor([[dataset.word_to_index[w] for w in pending]])
            y_pred, h_state = model(x, h_state)

            last_word_logits = y_pred[0, -1]

            # Greedy choice. Softmax is monotonic, so the argmax of the logits
            # is the most probable word. Sampling from the softmax
            # distribution is a common alternative; greedy decoding is not
            # necessarily the best strategy.
            word_index = int(torch.argmax(last_word_logits))

            next_word = dataset.index_to_word[word_index]
            words.append(next_word)
            # The state already summarizes everything fed so far, so only the
            # new word is fed on the next step (no token is replayed).
            pending = [next_word]

    return words

In [None]:
# Toy corpus of five short sentences used throughout the demo.
documents = [
    "saya pergi ke depok",
    "di depok makan sayuran",
    "dan buah nangka yang segar",
    "angin bertiup kencang",
    "tanda hujan akan turun di jalan margonda",
]

# Show every (input, target) batch produced with windows of 2 tokens and
# 2 windows per batch; each target is its input shifted one token ahead.
dataset = Dataset(sequence_length=2, documents=documents)
dataloader = DataLoader(dataset, batch_size=2)
for xs, ys in dataloader:
    print(xs, ys, sep="\n")

tensor([[2, 3],
        [3, 4]])
tensor([[3, 4],
        [4, 0]])
tensor([[4, 0],
        [0, 1]])
tensor([[0, 1],
        [1, 0]])
tensor([[1, 0],
        [0, 5]])
tensor([[0, 5],
        [5, 6]])
tensor([[5, 6],
        [6, 7]])
tensor([[6, 7],
        [7, 8]])
tensor([[7, 8],
        [8, 9]])
tensor([[ 8,  9],
        [ 9, 10]])
tensor([[ 9, 10],
        [10, 11]])
tensor([[10, 11],
        [11, 12]])
tensor([[11, 12],
        [12, 13]])
tensor([[12, 13],
        [13, 14]])
tensor([[13, 14],
        [14, 15]])
tensor([[14, 15],
        [15, 16]])
tensor([[15, 16],
        [16, 17]])
tensor([[16, 17],
        [17, 18]])
tensor([[17, 18],
        [18,  1]])
tensor([[18,  1],
        [ 1, 19]])
tensor([[ 1, 19],
        [19, 20]])
tensor([[19, 20],
        [20, 21]])


In [None]:
# Rebuild the dataset (2-token windows), create a fresh model sized to its
# vocabulary, and train it with 2 windows per batch.
dataset = Dataset(sequence_length=2, documents=documents)
model = Model(dataset=dataset)

train(dataset, model, batch_size=2)

{'epoch': 0, 'batch': 0, 'loss': 3.0953125953674316}
{'epoch': 0, 'batch': 1, 'loss': 3.1119327545166016}
{'epoch': 0, 'batch': 2, 'loss': 3.0793190002441406}
{'epoch': 0, 'batch': 3, 'loss': 3.1518442630767822}
{'epoch': 0, 'batch': 4, 'loss': 3.085254669189453}
{'epoch': 0, 'batch': 5, 'loss': 3.0681371688842773}
{'epoch': 0, 'batch': 6, 'loss': 3.121950626373291}
{'epoch': 0, 'batch': 7, 'loss': 3.1206963062286377}
{'epoch': 0, 'batch': 8, 'loss': 3.074373722076416}
{'epoch': 0, 'batch': 9, 'loss': 3.0855484008789062}
{'epoch': 0, 'batch': 10, 'loss': 3.0858993530273438}
{'epoch': 1, 'batch': 0, 'loss': 3.0560905933380127}
{'epoch': 1, 'batch': 1, 'loss': 3.0790281295776367}
{'epoch': 1, 'batch': 2, 'loss': 3.0630311965942383}
{'epoch': 1, 'batch': 3, 'loss': 3.1225218772888184}
{'epoch': 1, 'batch': 4, 'loss': 3.0529046058654785}
{'epoch': 1, 'batch': 5, 'loss': 3.036722183227539}
{'epoch': 1, 'batch': 6, 'loss': 3.0947537422180176}
{'epoch': 1, 'batch': 7, 'loss': 3.09389638900756

In [None]:
# Continue the prompt "saya pergi" with 20 generated words; as the last
# expression of the cell, the returned word list is displayed as its output.
predict(dataset, model, "saya pergi", next_words=20)

['saya',
 'pergi',
 'ke',
 'depok',
 'di',
 'depok',
 'makan',
 'sayuran',
 'dan',
 'buah',
 'nangka',
 'yang',
 'segar',
 'angin',
 'bertiup',
 'kencang',
 'tanda',
 'hujan',
 'akan',
 'turun',
 'di',
 'depok']