In [None]:
# download the Large IMDB Movie Review Dataset
# the task is binary classification: positive or negative review

! wget http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz
! tar -xzf aclImdb_v1.tar.gz

In [None]:
import torch
import math
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
from torch import optim
import os
from collections import namedtuple

In [None]:
# Seed both random number generators used in this notebook (numpy and
# pytorch) with the same value so that runs are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# let's set some parameters

# Location of the train / test splits inside the extracted archive
train_path = "aclImdb/train/"
test_path = "aclImdb/test/"

batch_size = 100       # number of examples per batch
max_len = 300          # examples are truncated to this many tokens
embedding_size = 300   # dimensionality of the word embeddings
min_count = 2          # frequency threshold passed to the input vocabulary

# Use the GPU when one is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load the dataset
- The dataset contains 25,000 training reviews and 25,000 test reviews.

In [None]:
# One example of the dataset: a running integer id, the whitespace-split
# tokens of the review, and its label (1 = positive, 0 = negative).
Sentence = namedtuple('Sentence', ['index', 'tokens', 'label'])


def read_imdb_movie_dataset(dataset_path):
    """
    Read one split of the dataset from disk.

    :param dataset_path: Directory containing a `pos` and a `neg`
                         sub-directory, each holding one review per file.

    :return: A list of `Sentence` tuples. Reviews from `pos` (label 1) come
             first, followed by reviews from `neg` (label 0); `index` is the
             position of the example in the returned list and `tokens` is
             the review text split on whitespace.
    """
    sentences = []

    for subdir, label in (("pos", 1), ("neg", 0)):
        split_dir = os.path.join(dataset_path, subdir)

        # `os.listdir` returns entries in arbitrary order; sorting makes the
        # index assigned to each review the same on every run and machine.
        for filename in sorted(os.listdir(split_dir)):
            file_path = os.path.join(split_dir, filename)

            # the context manager closes the file handle right away instead
            # of leaving tens of thousands of handles to the garbage collector
            with open(file_path, 'r', encoding="ISO-8859-1") as review_file:
                review_text = review_file.read()

            sentences.append(
                Sentence(len(sentences), review_text.split(), label)
            )

    return sentences

In [None]:
train_examples = read_imdb_movie_dataset(train_path)
test_examples = read_imdb_movie_dataset(test_path)

# sanity check: report the size of each split
print(len(train_examples))
print(len(test_examples))

## Mapping our words to unique identifiers: the Vocabulary object
- We will create an object to manage a mapping between words (or more generally tokens) and unique indices. 
- There are a few special symbols that we will be adding to handle special cases.
  - The first key special case is the `UNK` token, which will represent all tokens that we do not have in our vocabulary. This is needed as we will build our vocabulary only using the training examples, and during validation or testing (or if we deploy our model in production) we may encounter new words that also need to be represented somehow.
  - The `PAD` token, which we will use to create even-sized batches of sentences of different length (more on this below). 
  - The beginning-of-sentence or `BOS` token, which we may use to denote the beginning of a sentence in some special cases
  - The end-of-sentence or `EOS` token, which as in the previous case is useful for certain tasks.
  

In [None]:
# Define the strings of the special tokens we will need
UNK = '<UNK>'  # stands in for tokens that have no index of their own
PAD = '<PAD>'  # filler used to bring the sequences of a batch to equal length
BOS = '<BOS>'  # beginning-of-sentence marker
EOS = '<EOS>'  # end-of-sentence marker


class VocabItem:
    """
    A single vocabulary entry: the token itself, how many times it has
    been seen, and the index (`hash`) it was assigned in the vocabulary.
    """

    def __init__(self, string, hash=None):
        """
        :param string: The token this entry stands for.

        :param hash: The index of the token in the vocabulary, or `None`
                     when no index has been assigned yet.
        """
        self.hash = hash
        self.count = 0
        self.string = string

    def __str__(self):
        """
        Human-readable form, e.g. `VocabItem(cat)`.
        """
        return 'VocabItem(' + str(self.string) + ')'

    def __repr__(self):
        """
        Same text as `__str__`, so entries print nicely inside lists.
        """
        return str(self)


class Vocab:
    """
    A mapping between tokens and integer indices.

    Typical use: feed tokens with `add_tokens()`, call `finish()` exactly
    once to freeze the mapping, then convert with `tokens2indices()` and
    `indices2tokens()`.

    Index layout after `finish()`: the `UNK` entry at index 0 (unless
    `no_unk` is set), then the kept tokens by decreasing frequency, then
    the optional `BOS`, `EOS` and `PAD` entries, in that order.
    """

    def __init__(
        self,
        min_count=0,
        no_unk=False,
        add_padding=False,
        add_bos=False,
        add_eos=False,
        unk=None):

        """
        :param min_count: Frequency threshold. Tokens whose count is less
                          than or equal to this value do not get an index
                          of their own; their counts are added to `UNK`.
                          Only applied when `no_unk` is False and `unk`
                          is None.

        :param add_padding: If we should add the special `PAD` token.

        :param add_bos: If we should add the special `BOS` token.

        :param add_eos: If we should add the special `EOS` token.

        :param no_unk: If we should not add the `UNK` token to our Vocab.

        :param unk: A string with the unknown token, in case our
                    sentences have already been processed for this,
                    or `None` to use our default `UNK` token.
        """

        self.no_unk = no_unk
        self.vocab_items = []      # VocabItem objects in order of first appearance
        self.vocab_hash = {}       # token -> position in `vocab_items`
        self.word_count = 0        # total number of tokens seen by `add_tokens`
        self.special_tokens = []
        self.min_count = min_count
        self.add_padding = add_padding
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.unk = unk

        # VocabItem objects of the special tokens; set by `finish()`
        self.UNK = None
        self.PAD = None
        self.BOS = None
        self.EOS = None

        # the final mapping, filled in by `finish()`
        self.index2token = []
        self.token2index = {}

        self.finished = False

    def add_tokens(self, tokens):
        """
        Count every token of the given iterable.

        :param tokens: An iterable of hashable tokens.

        Raises a RuntimeError if `finish()` has already been called.
        """
        if self.finished:
            raise RuntimeError('Vocabulary is finished')

        for token in tokens:
            if token not in self.vocab_hash:
                self.vocab_hash[token] = len(self.vocab_items)
                self.vocab_items.append(VocabItem(token))

            self.vocab_items[self.vocab_hash[token]].count += 1
            self.word_count += 1

    def finish(self):
        """
        Freeze the vocabulary and build the token <-> index mapping.

        Prints the number of occurrences folded into `UNK` (when an `UNK`
        token is used) and the final vocabulary size.

        Raises a RuntimeError if the vocabulary was already finished: a
        second call would append every entry to the mapping again.
        """
        if self.finished:
            raise RuntimeError('Vocabulary is finished')

        token2index = self.token2index
        index2token = self.index2token

        tmp = []

        if not self.no_unk:

            # we add/handle the special `UNK` token
            # and set it to have index 0 in our mapping
            if self.unk:
                self.UNK = VocabItem(self.unk, hash=0)

                # the provided unknown token may never have been seen,
                # in which case its count simply stays at 0
                unk_position = self.vocab_hash.get(self.unk)
                if unk_position is not None:
                    self.UNK.count = self.vocab_items[unk_position].count

                index2token.append(self.UNK)
                self.special_tokens.append(self.UNK)

                for token in self.vocab_items:
                    if token.string != self.unk:
                        tmp.append(token)

            else:
                self.UNK = VocabItem(UNK, hash=0)
                index2token.append(self.UNK)
                self.special_tokens.append(self.UNK)

                # tokens at or below the threshold are merged into `UNK`
                for token in self.vocab_items:
                    if token.count <= self.min_count:
                        self.UNK.count += token.count
                    else:
                        tmp.append(token)
        else:
            for token in self.vocab_items:
                tmp.append(token)

        # we sort our vocab. items by frequency
        # so for the same corpus, the indices of our words
        # are always the same (the sort is stable, so ties keep
        # their order of first appearance)
        tmp.sort(key=lambda token: token.count, reverse=True)

        # we always add our additional special tokens
        # at the end of our mapping
        if self.add_bos:
            self.BOS = VocabItem(BOS)
            tmp.append(self.BOS)
            self.special_tokens.append(self.BOS)

        if self.add_eos:
            self.EOS = VocabItem(EOS)
            tmp.append(self.EOS)
            self.special_tokens.append(self.EOS)

        if self.add_padding:
            self.PAD = VocabItem(PAD)
            tmp.append(self.PAD)
            self.special_tokens.append(self.PAD)

        index2token += tmp

        # we record the final index of every entry, both in the
        # `token2index` dictionary and in the entry's own `hash`
        for i, token in enumerate(self.index2token):
            token2index[token.string] = i
            token.hash = i

        self.index2token = index2token
        self.token2index = token2index

        if not self.no_unk:
            print('Unknown vocab size:', self.UNK.count)

        print('Vocab size: %d' % len(self))

        self.finished = True

    def __getitem__(self, i):
        """Return the VocabItem stored at index `i`."""
        return self.index2token[i]

    def __len__(self):
        """Number of entries in the mapping, special tokens included."""
        return len(self.index2token)

    def __iter__(self):
        """Iterate over the VocabItem objects in index order."""
        return iter(self.index2token)

    def __contains__(self, key):
        """True if the token `key` has an index of its own."""
        return key in self.token2index

    def tokens2indices(self, tokens, add_bos=False, add_eos=False):
        """
        Returns a list of mapping indices for the given tokens, defaulting
        to our special `UNK` token whenever we find an unseen term.

        :param tokens: An iterable of tokens (the sentence must already
                       be tokenized).

        :param add_bos: If we should add the `BOS` at the beginning.
                        Requires a vocabulary built with `add_bos=True`.

        :param add_eos: If we should add the `EOS` at the end.
                        Requires a vocabulary built with `add_eos=True`.

        :return: A list of ints, with the index of each given token.

        When the vocabulary was built with `no_unk=True` there is no
        fallback, and an unseen token raises a KeyError.
        """
        string_seq = []
        if add_bos:
            string_seq.append(self.BOS.hash)
        for token in tokens:
            if self.no_unk:
                string_seq.append(self.token2index[token])
            else:
                string_seq.append(self.token2index.get(token, self.UNK.hash))
        if add_eos:
            string_seq.append(self.EOS.hash)
        return string_seq

    def indices2tokens(self, indices, ignore_ids=()):
        """
        Returns a list of tokens by mapping back every index to our
        vocabulary.

        :param indices: An iterable of ints.

        :param ignore_ids: An iterable with indices to ignore, meaning
                           that we will not look for them in our mapping.

        :return: A list with the token stored for each index that was
                 not ignored.

        Will raise an IndexError whenever we pass an index that we
        do not have in our mapping, except when it is in `ignore_ids`.

        """
        tokens = []
        for idx in indices:
            if idx in ignore_ids:
                continue
            tokens.append(self.index2token[idx].string)

        return tokens

- Now we can instantiate our vocabulary objects and add the data.
- We will use one vocabulary for the input data (the sentences), and another vocabulary object for the output data, the class labels. In this way our code is generic and should work out-of-the-box for any number of output labels.

In [None]:
# Input vocabulary: applying a frequency threshold means rare words are
# mapped to the `UNK` special token; the `PAD` special token is added too,
# because it is needed later to build batches.
src_vocab = Vocab(min_count=min_count, no_unk=False, add_padding=True)

# Output vocabulary: every class is known in advance, so neither the
# `UNK` nor the `PAD` token is needed.
tgt_vocab = Vocab(min_count=0, no_unk=True, add_padding=False)

In [None]:
# Count the (truncated) tokens and the label of every training example,
# then freeze both vocabularies.
for example in train_examples:
    src_vocab.add_tokens(example.tokens[:max_len])
    tgt_vocab.add_tokens((example.label,))

for vocab in (src_vocab, tgt_vocab):
    vocab.finish()

In [None]:
# quick check: map a short tokenized sentence to its vocabulary indices
src_vocab.tokens2indices(['the', 'movie', 'was', 'bad'])

In [None]:
# Bundle the input (`src`) and output (`tgt`) vocabularies together
Vocabs = namedtuple('Vocabs', 'src tgt')
vocabs = Vocabs(src=src_vocab, tgt=tgt_vocab)

## Representing words using dense vectors: Word Embeddings
- One of the major breakthroughs in NLP with deep models came after the conception of word embeddings, which changed the way in which we represent each word in our machine learning models.
- We start by simply assigning an initially random vector to each word in our vocabulary. These vectors are stacked together into a big matrix, usually referred to as the *embedding* matrix. After we have built our vocabulary, all we have to do is to create a big tensor of shape (`vocab_size`, `embedding_size`).
- In theory, whenever we need to obtain the vector for a given word, we could build a one-hot vector of our word and multiply this vector by our *embedding* matrix. Since all but one value in this one-hot vector are zeroes, the result of this product will correspond exactly to the row of the matrix that represents our word.
- Our *embeddings* will be treated as parameters of our model and are trained with it. This is possible because the *embedding* mechanism has a well-defined derivative, so we are allowed to use backpropagation to train these vectors.
- Note that in practice, however, the one-hot-based behavior can be achieved by simply selecting row vectors from our *embedding* matrix, given our indices.

In [None]:
# One trainable row per vocabulary entry; the row that belongs to the
# `PAD` token is registered as the padding index.
embeddings = nn.Embedding(
    num_embeddings=len(src_vocab),
    embedding_dim=embedding_size,
    padding_idx=src_vocab.PAD.hash,
)

In [None]:
# the embedding matrix has shape (vocab_size, embedding_size)
print(embeddings.weight.shape)

## The Batch objects
 - To easily access all the data in a batch, let's create a special Batch object that will give us access to all the information we may require during training.
 - Let's begin creating a more friendly object that contains a numeric representation of our inputs and outputs.
- By default we will use numpy objects, but we will also add a function to translate the contents of the object to PyTorch.
- We will create this object to be generic enough so we can use it with tasks other than classification, too. 
 - This object will work like a dictionary, but it will also allow us to access each component using an attribute with the same name.
  The main principle is that this dictionary-like batch will hold `numpy` objects as values, and that after calling the `to_torch_()` function, they will be turned into `pytorch` objects and moved to the corresponding provided device. In this way, we know that all our elements inside the batch object are in the right place.
 - We will combine our `Batch` object with a `BatchTuple` object that will hold data relevant to a specific input of the model.

In [None]:
class Batch(dict):
    """
    A dictionary whose entries can also be read and written as attributes
    (`batch.src` is the same as `batch['src']`).

    Values are expected to be `numpy` arrays, `BatchTuple` objects or
    anything else; `to_torch_()` converts the first two kinds to `pytorch`
    on the requested device and leaves other values (e.g. plain lists)
    untouched.
    """

    def __init__(self, *args, **kwargs):
        super(Batch, self).__init__(*args, **kwargs)
        # Using the dict itself as the instance `__dict__` is what makes
        # every key available as an attribute. As a consequence the
        # `_is_torch` flag below is stored as a regular dict entry.
        self.__dict__ = self
        self._is_torch = False

    def to_torch_(self, device):
        """
        Convert the contents of the batch to `pytorch`, in place, and move
        them to `device`. Calling it again simply moves the contents to
        the newly given device.

        :param device: The `torch.device` to move the data to.
        """
        # iterate over a snapshot, since values are replaced in the loop
        for key, value in list(self.items()):
            if isinstance(value, BatchTuple):
                # BatchTuple objects convert their own fields in place
                value.to_torch_(device)
            elif isinstance(value, np.ndarray):
                # we move `numpy` objects to `pytorch`
                self[key] = torch.from_numpy(value).to(device)
            elif torch.is_tensor(value):
                # already converted by an earlier call: only move it
                self[key] = value.to(device)

        self._is_torch = True


class BatchTuple(object):
    """
    The data that belongs to one input (or output) of the model.
    """

    def __init__(self, sequences, lengths, sublengths, masks):
        """
        :param sequences: The index data (converted to a long tensor).

        :param lengths: Lengths of the sequences, or `None`.

        :param sublengths: Optional second level of lengths, or `None`.

        :param masks: Optional masks (converted to a float tensor),
                      or `None`.
        """
        self.sequences = sequences
        self.lengths = lengths
        self.sublengths = sublengths
        self.masks = masks
        self._is_torch = False

    @staticmethod
    def _as_tensor(value, device, dtype):
        """Return `value` as a tensor of `dtype` on `device` (None stays None)."""
        if value is None:
            return None
        if torch.is_tensor(value):
            # already a tensor (repeated call): move it, do not re-wrap it
            return value.to(device=device, dtype=dtype)
        return torch.tensor(value, device=device, dtype=dtype)

    def to_torch_(self, device):
        """
        Convert every field that is not `None` to a `pytorch` tensor on
        `device`, in place. Fields that are already tensors are only moved.

        :param device: The `torch.device` to move the data to.
        """
        self.sequences = self._as_tensor(self.sequences, device, torch.long)
        self.lengths = self._as_tensor(self.lengths, device, torch.long)
        self.sublengths = self._as_tensor(
            self.sublengths, device, torch.long
        )
        self.masks = self._as_tensor(self.masks, device, torch.float)

        self._is_torch = True

### The padding function
- Let's suppose we have these two sentences to build a batch:
  - the dog barks $\rightarrow [1, 2 ,3]$
  - the cat likes to sleep $\rightarrow [1, 4, 5, 6, 7]$
  
  In order to put these two examples in a batch Tensor, we will need to *pad* the shortest sentence to have the same length of the longest one. 
  - the dog barks $\rightarrow [1, 2 ,3, 0 , 0]$
  - the cat likes to sleep $\rightarrow [1, 4, 5, 6, 7]$
  
  Finally, our batch Tensor will look like this: 
  - $\begin{bmatrix}1 & 2 & 3 & 0 & 0 \\ 1 & 4 & 5 & 6 & 7\end{bmatrix}$
  
  where its first dimension represents the size of the batch, and its second dimension has the length of the longest sentence in our batch.

In [None]:
def pad_list(
    sequences,
    dim0_pad=None,
    dim1_pad=None,
    align_right=False,
    pad_value=0
):
    """
    Receives a list of lists and returns a padded 2d ndarray,
    and an array with the original lengths.

    sequences: a list of lists. len(sequences) = M, and N is the max
               length of any of the lists contained in sequences.
               e.g.: [[2,45,3,23,54], [12,4,2,2], [4], [45, 12]]

    dim0_pad: number of rows of the output; defaults to M.

    dim1_pad: number of columns of the output; defaults to N.

    align_right: if True each sequence is written at the end of its row,
                 so the padding ends up on the left.

    pad_value: the value used to fill the positions without data.

    Returns:
       - out: a numpy ndarray of dimension (dim0_pad, dim1_pad) with
              the padded sequences
       - lengths: a numpy ndarray of ints containing the lengths of each
                  element in sequences

    An empty `sequences` list gives an empty (0, 0) array and an empty
    array of lengths.
    """

    sequences = [np.asarray(sublist) for sublist in sequences]

    if not dim0_pad:
        dim0_pad = len(sequences)

    if not dim1_pad:
        # `default` keeps an empty batch from crashing `max()`
        dim1_pad = max((len(seq) for seq in sequences), default=0)

    out = np.full(shape=(dim0_pad, dim1_pad), fill_value=pad_value)

    lengths = []
    for i in range(len(sequences)):
        data_length = len(sequences[i])
        lengths.append(data_length)
        # column at which the data of this row starts
        offset = dim1_pad - data_length if align_right else 0
        np.put(out[i], range(offset, offset + data_length), sequences[i])

    # an explicit integer dtype keeps the empty case from turning into floats
    lengths = np.array(lengths, dtype=int)

    return out, lengths

## The BatchBuilder object
- On top of the Batch object we create our own SequenceClassificationBatchBuilder, which will be in charge of transforming raw input examples into a collection of our Batch objects that our model can handle.
- This object will do all the heavy-lifting, turning our string examples into our batch objects, which PyTorch can later handle.
- We will combine this object with the `DataLoader` util from `pytorch`, passing it as the [`collate_fn` parameter](https://pytorch.org/docs/stable/data.html#working-with-collate-fn), which allows us to provide a custom function to build each batch. In our case, this is achieved by implementing the `__call__` function in the `BatchBuilder` object, which will essentially turn the [instance into a function](https://docs.python.org/3/reference/datamodel.html#emulating-callable-objects).

In [None]:
class SequenceClassificationBatchBuilder(object):
    """
    Callable that turns a list of examples into a `Batch`.

    The `DataLoader` calls its `collate_fn` with a single argument (the
    list of examples), so everything else needed to build a batch is
    handed over once, through the constructor.
    """

    def __init__(self, vocabs, max_len=None):
        """
        :param vocabs: Object with a `src` and a `tgt` vocabulary.

        :param max_len: Examples are truncated to this many tokens;
                        `None` keeps them whole.
        """
        self.max_len = max_len
        self.vocabs = vocabs

    def __call__(self, examples):
        """
        Build one batch; this is the function the `DataLoader` invokes.

        :param examples: A list of examples with `index`, `tokens` and
                         `label` attributes.

        :return: A `Batch` with the keys `indices` (list of ints), `src`
                 (padded token indices and their lengths) and `tgt`
                 (label indices).
        """
        source_vocab = self.vocabs.src
        target_vocab = self.vocabs.tgt

        ids_batch = []
        src_examples = []
        tgt_examples = []

        for example in examples:
            ids_batch.append(int(example.index))
            src_examples.append(
                source_vocab.tokens2indices(example.tokens[: self.max_len])
            )
            tgt_examples.append(target_vocab.token2index[example.label])

        # bring all the sentences of the batch to the same length
        src_padded, src_lengths = pad_list(
            src_examples, pad_value=source_vocab.PAD.hash
        )

        return Batch(
            indices=ids_batch,
            src=BatchTuple(src_padded, src_lengths, None, None),
            tgt=BatchTuple(tgt_examples, None, None, None),
        )

Let's instantiate our `batch_builder`, feed it into the `DataLoader` object alongside the training and test examples, and inspect a single batch of examples.

In [None]:
batch_builder = SequenceClassificationBatchBuilder(vocabs, max_len=max_len)

# options shared by the training and the test loader
loader_options = dict(
    batch_size=batch_size,
    num_workers=0,
    collate_fn=batch_builder,
)

# only the training examples are shuffled
train_batches = DataLoader(train_examples, shuffle=True, **loader_options)
test_batches = DataLoader(test_examples, shuffle=False, **loader_options)

In [None]:
# get an iterator over the training batches so we can pull a single one
train_batches_iter = iter(train_batches)

In [None]:
# fetch the next batch from the iterator (built by our `batch_builder`)
train_batch = next(train_batches_iter)

In [None]:
# the padded token indices of the batch (still `numpy`: `to_torch_` has not
# been called on this batch)
train_batch.src.sequences

## The Pytorch Model
### The LSTM
![An unrolled RNN.](http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png)
- The LSTM is a special kind of Recurrent Neural Network that will process sequence data and return a vector for each input in our sequence. In the example, given a sequence of inputs $X=x_1, \ldots , x_t$, the LSTM will give us a sequence of $t$ vectors also called hidden states $H= h_1, \ldots, h_t$.
- The LSTM is a complex beast, in this tutorial we will be skipping details on how exactly it works. For more details, visit http://pytorch.org/docs/master/nn.html#lstm
- If we think of our input sequence as our word vectors for a given sentence, we can think of the output as a kind of enriched or contextualized version of the input, which will contain not only information about the word each vector represents, but also about its previous words.
- In PyTorch, LSTMs will return both the set of output vectors $H$ but also some additional output that we will not pay attention to.
- Because we need a fixed-size vector to classify our sentences, we will have to use some kind of pooling function over our hidden states to achieve this. 

In [None]:
def mean_pooling(batch_hidden_states, batch_lengths):
    '''
    Average the hidden states of each sequence over the time dimension.

    The sum runs over all `seq_len` positions, so the padded positions
    must hold zeroes for the result to be the mean of the real positions.

    :param batch_hidden_states: torch.Tensor(batch_size, seq_len, hidden_size)
    :param batch_lengths: torch.Tensor(batch_size) with the real length
                          of each sequence
    :return: torch.Tensor(batch_size, hidden_size)
    '''
    # (batch_size, 1), placed on the same device as the hidden states
    # (and not just on the default CUDA device) so the division is valid
    batch_lengths = batch_lengths.float().unsqueeze(1)
    batch_lengths = batch_lengths.to(batch_hidden_states.device)

    pooled_batch = torch.sum(batch_hidden_states, 1)
    # broadcasting divides every feature of a row by that row's length
    pooled_batch = pooled_batch / batch_lengths

    return pooled_batch


def max_pooling(batch_hidden_states, batch_lengths=None):
    '''
    Take the maximum of each hidden feature over the time dimension.

    :param batch_hidden_states: torch.Tensor(batch_size, seq_len, hidden_size)
    :param batch_lengths: optional torch.Tensor(batch_size) with the real
                          length of each sequence. When given, positions
                          past each length are excluded from the maximum;
                          otherwise every position takes part, padding
                          included (a zero-filled padding position then
                          wins whenever all real values are negative).
                          Lengths must be at least 1, as a sequence with
                          no real position would pool to -inf.
    :return: torch.Tensor(batch_size, hidden_size)
    '''
    if batch_lengths is not None:
        device = batch_hidden_states.device
        positions = torch.arange(batch_hidden_states.size(1), device=device)
        # (batch_size, seq_len): True where the position is padding
        padding_mask = (
            positions.unsqueeze(0) >= batch_lengths.to(device).unsqueeze(1)
        )
        batch_hidden_states = batch_hidden_states.masked_fill(
            padding_mask.unsqueeze(-1), float('-inf')
        )

    pooled_batch, _ = torch.max(batch_hidden_states, 1)
    return pooled_batch

- The next key util functions are related to the fact that we are using batches of sentences to train.
- To make the training efficient, Pytorch asks us to sort the examples in our batch by sequence length and build a special object.
- We will use the function `pack_padded_sequence()` to build this special `PackedSequence` object given our sorted padded batch and the lengths of each sentence on it
- Conversely, we will use the `pad_packed_sequence()` function to turn the output of the `nn.LSTM`, a `PackedSequence` object, into a regular Pytorch tensor. This tensor will have zeroes in all padding positions, so we can later directly use our pooling functions.

In [None]:
def pack_rnn_input(embedded_sequence_batch, sequence_lengths):
    """
    Prepares the special `PackedSequence` object that can be
    efficiently processed by the `nn.LSTM`.

    The sequences are reordered by decreasing length before packing.

    :param embedded_sequence_batch: torch.Tensor whose first dimension is
                                    the batch and whose second dimension
                                    is the (padded) sequence length

    :param sequence_lengths: torch.Tensor(batch_size) with the real
                             length of each sequence

    :return:
      - `PackedSequence` object containing our padded batch
      - indices to sort back our sentences to their original order
    """

    # the lengths are handed to `pack_padded_sequence` as a CPU tensor
    lengths = sequence_lengths.cpu()

    sorted_sequence_lengths, idx_sort = torch.sort(lengths, descending=True)
    # the inverse permutation: position of each original example
    # inside the sorted batch
    idx_unsort = torch.argsort(idx_sort)

    # the index tensors must live on the same device as the tensors they
    # index (not merely on the default CUDA device)
    batch_device = embedded_sequence_batch.device
    idx_sort = idx_sort.to(batch_device)
    idx_unsort = idx_unsort.to(batch_device)

    embedded_sequence_batch = embedded_sequence_batch.index_select(
        0, idx_sort
    )

    # Handling padding in Recurrent Networks
    packed_rnn_input = nn.utils.rnn.pack_padded_sequence(
        embedded_sequence_batch,
        sorted_sequence_lengths,
        batch_first=True
    )

    return packed_rnn_input, idx_unsort

  
def unpack_rnn_output(packed_rnn_output, indices):
    """
    Turn the `PackedSequence` returned by `nn.LSTM` back into a regular
    padded tensor, with the examples put back in their original order.

    :param packed_rnn_output: `PackedSequence` object

    :param indices: LongTensor with, for each output row, the row of the
                    padded (sorted) batch to take

    :return:
      - Padded tensor, batch dimension first

    """
    # the second value returned are the lengths, which we do not need
    padded_output = nn.utils.rnn.pad_packed_sequence(
        packed_rnn_output, batch_first=True
    )[0]

    return padded_output.index_select(0, indices)

- To build the model, we extend the `nn.Module`

In [None]:
class BiLSTM(nn.Module):
    """
    Sentence classifier: embeddings -> (bi)LSTM -> pooling over time ->
    linear output layer.
    """

    def __init__(
        self,
        embeddings,
        hidden_size,
        num_labels,
        input_dropout=0,
        output_dropout=0,
        bidirectional=True,
        num_layers=2,
        pooling='mean'
    ):
        """
        :param embeddings: `nn.Embedding` module used to embed the input.

        :param hidden_size: Size of the LSTM hidden state (per direction).

        :param num_labels: Number of output classes.

        :param input_dropout: Dropout probability applied to the
                              embedded input.

        :param output_dropout: Dropout probability applied to the pooled
                               sentence vector, before the output layer.

        :param bidirectional: If the LSTM should be bidirectional.

        :param num_layers: Number of stacked LSTM layers.

        :param pooling: Either 'mean' or 'max'; how the hidden states of
                        a sentence are reduced to a single vector.
        """

        super(BiLSTM, self).__init__()

        self.embeddings = embeddings
        self.pooling = pooling

        self.input_dropout = nn.Dropout(input_dropout)
        self.output_dropout = nn.Dropout(output_dropout)

        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.num_labels = num_labels

        self.hidden_size = hidden_size

        self.input_size = self.embeddings.embedding_dim

        self.lstm = nn.LSTM(
            self.input_size,
            hidden_size,
            bidirectional=bidirectional,
            num_layers=num_layers,
            batch_first=True
        )

        # a bidirectional LSTM concatenates the states of both directions
        self.total_hidden_size = self.hidden_size
        if self.bidirectional:
            self.total_hidden_size += self.hidden_size

        self.output_layer = nn.Linear(
            self.total_hidden_size,
            self.num_labels)

        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, src_batch, tgt_batch=None):
        """
        :param src_batch: Object with the attributes `sequences` (padded
                          token indices, batch first) and `lengths`.

        :param tgt_batch: Optional object whose `sequences` attribute
                          holds the gold label indices.

        :return: A tuple `(loss, predictions, logits)`; `loss` is `None`
                 when no `tgt_batch` is given.

        Raises NotImplementedError for an unknown `pooling` value.
        """

        src_sequences = src_batch.sequences
        src_lengths = src_batch.lengths

        embedded_sequence_batch = self.embeddings(src_sequences)
        embedded_sequence_batch = self.input_dropout(
            embedded_sequence_batch
        )

        packed_rnn_input, indices = pack_rnn_input(
            embedded_sequence_batch, src_lengths
        )

        # the second output (final hidden and cell states) is not needed
        rnn_packed_output, _ = self.lstm(packed_rnn_input)
        encoded_sequence_batch = unpack_rnn_output(
            rnn_packed_output, indices
        )

        if self.pooling == "mean":
            # batch_size, hidden_x_dirs
            pooled_batch = mean_pooling(encoded_sequence_batch,
                                        src_lengths)

        elif self.pooling == "max":
            # batch_size, hidden_x_dirs
            pooled_batch = max_pooling(encoded_sequence_batch)
        else:
            raise NotImplementedError

        # the output dropout was configured in the constructor but never
        # applied; it is a no-op in eval mode and when its probability is 0
        pooled_batch = self.output_dropout(pooled_batch)

        logits = self.output_layer(pooled_batch)
        _, predictions = logits.max(1)

        if tgt_batch is not None:
            targets = tgt_batch.sequences
            loss = self.loss_function(logits, targets)
        else:
            loss = None

        return loss, predictions, logits

### Instancing our model
- Let's define the hyperparameters of our model

In [None]:
epochs = 10               # number of passes over the training batches
hidden_size = 300         # passed to the model as `hidden_size`
log_interval = 10         # NOTE(review): not referenced anywhere in the visible cells
num_labels = 2            # positive or negative review
input_dropout = 0.5       # passed to the model as `input_dropout`
output_dropout = 0.5      # passed to the model as `output_dropout`
bidirectional = True      # passed to the model as `bidirectional`
num_layers = 2            # passed to the model as `num_layers`
pooling = 'mean'          # the model accepts 'mean' or 'max'
lr = 0.001                # learning rate handed to the optimizer
gradient_clipping = 0.25  # max norm handed to `clip_grad_norm_` during training

In [None]:
# Build the classifier from the hyperparameters above and place it on the
# selected device (`nn.Module.to` returns the module itself).
model = BiLSTM(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_labels=num_labels,
    input_dropout=input_dropout,
    output_dropout=output_dropout,
    bidirectional=bidirectional,
    num_layers=num_layers,
    pooling=pooling,
).to(device)

print(model)

In [None]:
# Adam over all the parameters of the model
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
def run_epoch(model, batches, device, optimizer=None, gradient_clipping=None):
    """
    Run the model once over all the given batches.

    :param model: The model; called as `model(src_batch, tgt_batch=...)`
                  and expected to return `(loss, predictions, logits)`.

    :param batches: A sized iterable of `Batch` objects.

    :param device: The device the batches are moved to.

    :param optimizer: When given, the model is put in training mode and
                      its parameters are updated after every batch. When
                      `None`, the model is put in evaluation mode and no
                      gradients are computed.

    :param gradient_clipping: Maximum gradient norm used while training,
                              or `None` to skip clipping.

    :return: A tuple `(mean loss per batch, accuracy in percent)`.
    """
    training = optimizer is not None

    if training:
        model.train()
    else:
        model.eval()

    epoch_correct = 0
    epoch_total = 0
    epoch_loss = 0

    # gradients are only tracked when we are going to use them
    with torch.set_grad_enabled(training):
        for batch in batches:

            batch.to_torch_(device)

            src_batch = batch.src
            tgt_batch = batch.tgt

            if training:
                # gradients accumulate across `backward()` calls, so they
                # must be cleared before processing each batch
                optimizer.zero_grad()

            loss, predictions, _ = model(src_batch, tgt_batch=tgt_batch)

            if training:
                loss.backward()

                if gradient_clipping is not None:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        gradient_clipping)

                optimizer.step()

            correct = (predictions == tgt_batch.sequences).long().sum()
            epoch_correct += correct.item()
            epoch_total += tgt_batch.sequences.size(0)
            epoch_loss += loss.item()

    return epoch_loss / len(batches), 100 * epoch_correct / epoch_total


for epoch in range(epochs):

    train_loss, accuracy = run_epoch(
        model,
        train_batches,
        device,
        optimizer=optimizer,
        gradient_clipping=gradient_clipping,
    )

    print('Epoch {}'.format(epoch))
    print('Train Loss: {}'.format(train_loss))
    print('Train Accuracy: {}'.format(accuracy))

    # no optimizer: evaluation only
    test_loss, test_accuracy = run_epoch(model, test_batches, device)

    print('\n---------------------')
    print('Test Loss: {}'.format(test_loss))
    print('Test Accuracy: {}'.format(test_accuracy))
    print('---------------------\n')
