# ModernNLP: #2
* Discussing text restoration by [Sommerschield et al.](https://www.aclweb.org/anthology/D19-1668/).
* Experimenting with a vanilla RNN encoder in Pytorch.
* Performing text classification to predict the next character.
* Instead of Ancient Greek text, we will use Plato in English. 

> Authored by John Pavlopoulos & Vasiliki Kougia

> Modified by Yongchao Wu

In [1]:
import nltk; nltk.download('punkt')
from urllib.request import urlopen
from nltk.tokenize import sent_tokenize
import random; random.seed(42)
import numpy as np
from math import ceil, floor
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
from torch.autograd import Variable

[nltk_data] Downloading package punkt to /home/chao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Download and pre-process the data

In [2]:
# This paper's dataset takes too long to download; use Plato in English.
data = urlopen("http://www.gutenberg.org/cache/epub/1497/pg1497.txt").read().decode("utf8")
data = data[760:-19110] # cut editorial notes and licences

In [3]:
# tokenise the text, and remove any noise
sentences = sent_tokenize(data)
sentences = [s.strip().lower() for s in sentences]
np.random.shuffle(sentences)

# The vocabulary will comprise characters
all_letters = list(set(" ".join(sentences)))
print(all_letters)

['-', '!', '*', 'h', 'z', 'b', 'x', '?', 'd', '\n', 'm', '"', '2', 'v', '6', 'f', '7', '.', '(', '5', 'n', '3', '1', 'j', '4', '8', '0', ';', '+', 'w', 'o', 'k', 'r', 'l', ')', ' ', '/', 'g', 'y', '=', 'i', 'c', "'", 't', '9', 'a', 'p', 'u', 'q', ',', 's', 'e', '\r', ':']


In [4]:
print (sentences[np.random.randint(len(sentences))])

(heraclitus said that the sun was extinguished every
evening and relighted every morning.)


### Define device to run on a local GPU

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Build the dataset

In [6]:
inputs, targets = [], []
maxlen = 128
for s in sentences:
  if len(s)<10: 
    continue
  txt = s[-maxlen:]
  r = np.random.randint(low=5, high=min(maxlen, len(txt)))
  inputs.append(txt[:r])
  targets.append(txt[r])

V = list(set("".join(sentences)))
targets_v = list(set(targets))
# Split to train, val and test
inputs_train, targets_train = inputs[:5000], targets[:5000] 
inputs_val, targets_val = inputs[5000:5500], targets[5000:5500]
inputs_test, targets_test = inputs[5500:], targets[5500:]

In [7]:
def input_encode(text, V, maxlen):
  x = np.zeros(maxlen, dtype=int)
  # Assign an index to each input character
  for i, char in enumerate(text):
    if i<maxlen:
      x[i] = V.index(char) + 1 # Index 0 is used for padding
  return x

def output_encode(char, target_v):
  # The output is the index of the ground truth character
  o = target_v.index(char)
  return o

In [8]:
batch_size = 16

# Encode input and output data of train, val and test
encoded_inputs_train = [input_encode(s, V, maxlen) for s in inputs_train]
lengths_train = [min(len(s), maxlen) for s in inputs_train]
encoded_targets_train = [output_encode(t, targets_v) for t in targets_train]

encoded_inputs_val = [input_encode(s, V, maxlen) for s in inputs_val]
lengths_val = [min(len(s), maxlen) for s in inputs_val]
encoded_targets_val = [output_encode(t, targets_v) for t in targets_val]

encoded_inputs_test = [input_encode(s, V, maxlen) for s in inputs_test]
lengths_test = [min(len(s), maxlen) for s in inputs_test]
encoded_targets_test = [output_encode(t, targets_v) for t in targets_test]

In [9]:
def generator(inputs, lengths, targets, batch_size):
  while True:
    # Loop over all instances
    d = list(zip(inputs, lengths, targets))
    random.shuffle(d)
    inputs, lengths, targets = zip(*d)
    for i in range(0, len(inputs), batch_size):
      x_inputs, x_lengths, y_targets = list(), list(), list()
      # Loop over the images in the batch and yield their instances
      for j in range(i, min(len(inputs), i + batch_size)):
        x_inputs.append(inputs[j])
        x_lengths.append(lengths[j])
        y_targets.append(targets[j])

      yield torch.LongTensor(x_inputs).to(device), torch.LongTensor(x_lengths), torch.tensor(y_targets).to(device)

In [10]:
train_generator = generator(encoded_inputs_train, lengths_train, encoded_targets_train, batch_size)
val_generator = generator(encoded_inputs_val, lengths_val, encoded_targets_val, batch_size)

## Build the model
* RNN_Model
* RNN_Encoder
* RNN

In [20]:
class RNN_Model(nn.Module):
    def __init__(self, encoder, decoder):
        super(RNN_Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x, seq_lengths):

        x, lengths  = self.encoder(x, seq_lengths)
        x = self.decoder(x, lengths)
        return x

In [21]:
class RNN_Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size=200, hidden_size=128,
                embedding_tensor=None, padding_index=0, num_layers=1, 
                dropout=0, Max_leng=maxlen):
        super(RNN_Encoder, self).__init__()      
        self.hidden = hidden_size
        self.dropout = dropout
        self.num_layers = num_layers


        # Define the layers in our architecture
        self.embedding_layer = nn.Embedding(vocab_size, embed_size, 
                      padding_idx=padding_index, _weight=embedding_tensor)
        self.drop_en = nn.Dropout(self.dropout)
        self.rnn = nn.GRU(input_size=embed_size, 
                      hidden_size=self.hidden, 
                      num_layers=self.num_layers, 
                      batch_first=True, 
                      bidirectional=True)
        self.attn = nn.Linear(self.hidden*2, self.hidden*2)
        
        
    def forward(self, x, seq_lengths):
        # Pass the input through the embedding layer
        text_embed = self.embedding_layer(x)
        # Apply dropout
        x_embed = self.drop_en(text_embed)

        # Pass the inputs to the GRU
        packed_input = pack_padded_sequence(x_embed, seq_lengths, batch_first=True, 
                                        enforce_sorted=False)
        packed_output, ht = self.rnn(packed_input)
        # Get the hidden states of all time steps
 
        out_rnn, lengths = pad_packed_sequence(packed_output, batch_first=True)
        
        #Calculate attention weights, add it to the output vector
        attn_weights = F.softmax(self.attn(out_rnn), dim=1)
        out_rnn = attn_weights * out_rnn
        out_rnn = self.drop_en(out_rnn)

        

        return out_rnn, lengths  

In [22]:
# improve decoding by using RNN decoder
class RNN_Decoder(nn.Module):
    def __init__(self, num_output, hidden_size=128, dropout=0):
        super(RNN_Decoder, self).__init__()      
       

        self.num_output = num_output
        self.rnn = nn.GRU(input_size=hidden_size*2, hidden_size = hidden_size,
                         batch_first=True)
        self.out = nn.Linear(hidden_size, num_output)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.dropout = dropout
        self.drop_de = nn.Dropout(self.dropout)
    
    def forward(self, x, seq_lengths):
        
        # add mask to change attention score to zero outside bound of real sentence length
        mask = torch.arange(x.shape[1])[None, :] > seq_lengths[:, None]
        x[mask] = 0
        
        output, ht = self.rnn(x)
        
        row_indices = torch.arange(0, x.size(0)).long()
        col_indices = seq_lengths - 1
        last_hidden_state = output[row_indices, col_indices, :]
        
        last_hidden_state = self.drop_de(last_hidden_state)

        output = self.softmax(self.out(last_hidden_state))

        return output

In [23]:

encoder = RNN_Encoder(vocab_size=len(V)+1, dropout=0.2)
decoder = RNN_Decoder(num_output=len(targets_v),dropout=0.2)
model = RNN_Model(encoder, decoder)
model = model.to(device)

In [24]:
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score

In [25]:
# Define optimizer and loss
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Train and validate at the epoch's end, keep the best (based on val f1)
epochs, highest_val_f1 = 20, 0

for idx in tqdm(range(epochs), desc="Epoch"):
  epoch = idx+1
  #Switch to train mode
  model.train()
  for batch in tqdm(range(ceil(len(inputs_train)/batch_size)), desc="Iteration"):
    input_t, lengths_t, target_t = next(train_generator)
    output = model(input_t,lengths_t)
    loss = criterion(output,target_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    
  #Switch to eval mode
  model.eval()
  val_loss = []
  val_targets = []
  val_outputs = []
  for i in range(ceil(len(inputs_val)/batch_size)):
    input_t, lengths_t, target_t = next(val_generator)
    output = model(input_t,lengths_t)
    val_outputs.append(torch.argmax(output, dim=1))
    val_targets.append(target_t)
    val_loss.append(criterion(output,target_t).cpu().detach().numpy())
  val_outputs = torch.cat(val_outputs)
  val_targets = torch.cat(val_targets)        
  f1 = f1_score(val_targets.cpu().numpy(), val_outputs.cpu().detach().numpy(), 
                average="macro")
  print(f"EPOCH: {epoch} val loss: {sum(val_loss)/len(val_loss):.4f}, val f1: {f1:.3f}")
  if f1 > highest_val_f1:
    print("Save model....")
    torch.save({'model_state_dict': model.state_dict()}, "pytorch_model.bin")
    highest_val_f1 = f1

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 1 val loss: 2.3880, val f1: 0.119
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 2 val loss: 2.1659, val f1: 0.210
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 3 val loss: 2.0651, val f1: 0.270
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 4 val loss: 2.0240, val f1: 0.277
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 5 val loss: 2.0326, val f1: 0.285
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 6 val loss: 1.9427, val f1: 0.301
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 7 val loss: 1.9649, val f1: 0.293


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 8 val loss: 1.8806, val f1: 0.318
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 9 val loss: 1.9571, val f1: 0.314


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 10 val loss: 1.9877, val f1: 0.309


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 11 val loss: 2.0231, val f1: 0.314


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 12 val loss: 1.9979, val f1: 0.319
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 13 val loss: 2.0888, val f1: 0.313


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 14 val loss: 2.0287, val f1: 0.340
Save model....


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 15 val loss: 2.1156, val f1: 0.308


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 16 val loss: 2.1611, val f1: 0.308


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 17 val loss: 2.1852, val f1: 0.314


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 18 val loss: 2.3062, val f1: 0.326


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 19 val loss: 2.2863, val f1: 0.323


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=313.0, style=ProgressStyle(description_wi…


EPOCH: 20 val loss: 2.3750, val f1: 0.306



In [26]:
checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
encoder_e = RNN_Encoder(vocab_size=len(V)+1, dropout=0.2)
decoder_e = RNN_Decoder(num_output=len(targets_v))
model_e = RNN_Model(encoder, decoder)

model_e.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [27]:
model_e.eval()
x=11
prompt = inputs_test[x]
text = prompt[:10]
for i in range(50):
  encoded_text = np.expand_dims(input_encode(text, V, maxlen), 0)
  # Get the character with the largest probability as the next character
  predicted = targets_v[model_e(torch.LongTensor(encoded_text).to(device), torch.LongTensor([len(text)])).argmax()][0]
  print(f"{text} --> {predicted}")
  # Add the predicted character to the input
  text = text+predicted

certainly --> .
certainly. --> .
certainly.. --> .
certainly... --> .
certainly.... --> .
certainly..... --> .
certainly...... --> .
certainly....... --> .
certainly........ --> .
certainly......... --> .
certainly.......... --> d
certainly..........d --> l
certainly..........dl --> y
certainly..........dly --> .
certainly..........dly. --> .
certainly..........dly.. --> .
certainly..........dly... --> .
certainly..........dly.... --> .
certainly..........dly..... --> d
certainly..........dly.....d --> i
certainly..........dly.....di --> n
certainly..........dly.....din --> g
certainly..........dly.....ding -->  
certainly..........dly.....ding  --> a
certainly..........dly.....ding a --> n
certainly..........dly.....ding an --> d
certainly..........dly.....ding and -->  
certainly..........dly.....ding and  --> a
certainly..........dly.....ding and a --> r
certainly..........dly.....ding and ar --> e
certainly..........dly.....ding and are -->  
certainly..........dly.....ding and are