# Practice: BiLSTM for PoS Tagging
_This notebook is based on an [open-source implementation](https://github.com/bentrevett/pytorch-pos-tagging) of PoS Tagging in PyTorch._

## Introduction

In this series we'll be building a machine learning model that produces an output for every element in an input sequence, using PyTorch and torchtext. Specifically, we will be inputting a sequence of text and the model will output a part-of-speech (PoS) tag for each token in the input text. This can also be used for named entity recognition (NER), where the output for each token will be what type of entity, if any, the token is.

In this notebook, we'll be implementing a multi-layer bi-directional LSTM (BiLSTM) to predict PoS tags using the Universal Dependencies English Web Treebank (UDPOS) dataset.

## Preparing the data

Let's use torchtext to load the data. Each item of the UDPOS dataset consists of three lists: a tokenized sentence and two different sets of tags, [universal dependency (UD) tags](https://universaldependencies.org/u/pos/) and [Penn Treebank (PTB) tags](https://www.sketchengine.eu/penn-treebank-tagset/). We'll only train our model on the UD tags.

In [None]:
from torchtext.datasets import UDPOS


train_data = list(UDPOS(split="train"))
print(f"Number of training examples: {len(train_data)}")
print()

text, ud_tags, ptb_tags = train_data[0]
print("-------------------------------")
print("Word\t\tUD Tag\tPTB tag")
print("-------------------------------")
for word, ud_tag, ptb_tag in zip(text, ud_tags, ptb_tags):
    print(f"{word:<8}\t{ud_tag}\t{ptb_tag}")

Just like before, we need to build vocabularies for both words and tags.

In [None]:
from collections import Counter

from torchtext.vocab import vocab as Vocab


word_counts = Counter()
tag_counts = Counter()
for text, tags, _ in train_data:
    word_counts.update(word.lower() for word in text)
    tag_counts.update(tags)

word_vocab = Vocab(word_counts, min_freq=2)
tag_vocab = Vocab(tag_counts)

print("----------------------------------")
print("Tag\t\tCount\tPercentage")
print("----------------------------------")
total = sum(tag_counts.values())
for tag, count in tag_counts.most_common():
    print(f"{tag:<8}\t{count}\t{100 * count / total:.1f}%")

Our word vocabulary has to handle unknown tokens (note how we set `min_freq` to 2 for it) to simulate real-world conditions. The tag vocabulary, on the other hand, doesn't have this problem, as we deal with a strictly finite set of tags.

In [None]:
word_vocab.insert_token("<unk>", index=0)
word_vocab.set_default_index(0)
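To make the effect of `min_freq` and the default index concrete, here is a plain-Python sketch on a made-up toy corpus. It only imitates the behaviour (rare words are dropped, anything missing maps to index 0) and is not torchtext's implementation. A useful side effect of `min_freq=2` is that words seen only once in training also become `<unk>`, so the model actually encounters `<unk>` during training and learns an embedding for it.

```python
from collections import Counter

# Toy corpus (hypothetical data, for illustration only).
sentences = [["the", "cat", "sat"], ["the", "dog", "sat"], ["a", "cat", "ran"]]
counts = Counter(word for sent in sentences for word in sent)

min_freq = 2
# Index 0 is reserved for <unk>, mirroring insert_token("<unk>", index=0).
itos = ["<unk>"] + [w for w, c in counts.items() if c >= min_freq]
stoi = {w: i for i, w in enumerate(itos)}


def encode(tokens):
    # Rare or unseen words fall back to index 0, like set_default_index(0).
    return [stoi.get(t, 0) for t in tokens]


print(itos)  # ['<unk>', 'the', 'cat', 'sat']
print(encode(["the", "dog", "sat"]))  # [1, 0, 3]
```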

Both vocabularies also need a padding token to handle padding.

In [None]:
word_vocab.append_token("<pad>")
tag_vocab.append_token("<pad>")

Now that we have our vocabularies, we are ready to start building our model! However, before we get there, let's use one more trick. The torchtext library provides a range of pretrained word embeddings for us to use. Let's use [GloVe](https://nlp.stanford.edu/projects/glove/) to initialize our embeddings!

We can load the model as follows:

In [None]:
from torchtext.vocab import GloVe


glove = GloVe(name="6B", dim=100)

And extract word vectors as follows:

In [None]:
glove["word"].shape, glove["word"][:5]

Using this dict-like interface, we can extract vectors for all the words the GloVe model saw during its training. What happens to out-of-vocabulary words? Let's try it out:

In [None]:
glove["some-non-existent-word"][:5]

torchtext's GloVe returns a vector of zeros. This is the default behaviour, which we can change using the `unk_init` parameter:

In [None]:
import torch


glove = GloVe(name="6B", dim=100, unk_init=torch.Tensor.normal_)

Now the in-vocabulary words still yield the same vectors as before:

In [None]:
glove["word"][:5]

But out-of-vocabulary words now yield normally distributed vectors:

In [None]:
glove["some-non-existent-word"][:5]

It is important to note that GloVe doesn't store these *unknown* vectors: it simply calls our `unk_init` function every time an unknown word is looked up. This means that if we fetch a vector for the same out-of-vocabulary word once more, we get a completely different vector:

In [None]:
glove["some-non-existent-word"][:5]

How is this useful? Well, we don't want to fetch vectors from GloVe all the time. We actually want our model to learn more task-specific embeddings during training, so we still want to create an embedding layer. However, it might be a good idea to initialize the embeddings of known words with something more meaningful than random noise, and this is exactly where GloVe comes into play! Let's generate a matrix with embeddings for all words in our vocabulary (words that are out-of-vocabulary for GloVe simply get random vectors) and use it to initialize our embedding layer later on.

In [None]:
word_vectors = glove.get_vecs_by_tokens(word_vocab.lookup_tokens(range(len(word_vocab))))
word_vectors.shape
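Conceptually, `get_vecs_by_tokens` with our `unk_init` builds one row per vocabulary entry: the pretrained vector if the word is known, otherwise a freshly drawn random vector. The sketch below reproduces that idea with a tiny hypothetical lookup table standing in for GloVe (`build_matrix` and the toy vectors are ours, for illustration only). Once a random row is written into the matrix it stays fixed, which sidesteps the "different vector on every lookup" behaviour shown above.

```python
import torch

torch.manual_seed(0)
dim = 4
# Hypothetical "pretrained" table standing in for GloVe.
pretrained = {"the": torch.ones(dim), "cat": torch.full((dim,), 2.0)}
vocab = ["<unk>", "the", "cat", "zyzzyva", "<pad>"]


def build_matrix(tokens, table, dim):
    rows = []
    for tok in tokens:
        if tok in table:
            # Known word: copy the pretrained vector.
            rows.append(table[tok])
        else:
            # Unknown word: same idea as unk_init=torch.Tensor.normal_.
            rows.append(torch.empty(dim).normal_())
    # Row i of the matrix corresponds to vocabulary index i.
    return torch.stack(rows)


matrix = build_matrix(vocab, pretrained, dim)
print(matrix.shape)  # torch.Size([5, 4])
```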

Now that we are finished tinkering with the data, the last bit is to create our `DataLoader`, for which we once more need to write a custom `collate_fn`.

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def collate_batch(batch):
    text_list, tags_list = [], []
    for text, tags, _ in batch:
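        # YOUR CODE HERE
        # Convert text (lowercased, to match word_vocab) and tags into lists of
        # token indices and cast them into torch tensors. Store the tensors in
        # text_list and tags_list.
        pass

    # YOUR CODE HERE
    # Pad sequences with the pad_sequence function. Keep the default
    # batch_first=False so the results are [seq_len, batch_size], and pass
    # padding_value explicitly (word_vocab["<pad>"] for texts, tag_vocab["<pad>"]
    # for tags): the default of 0 is <unk> in word_vocab and a real tag in tag_vocab.
    # texts_padded = pad_sequence(...)
    # tags_padded = pad_sequence(...)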

    return texts_padded, tags_padded


batch_size = 128
train_dataloader = DataLoader(train_data, batch_size, shuffle=True, collate_fn=collate_batch)
text, tags = next(iter(train_dataloader))
text.shape, tags.shape
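As a reminder of how `pad_sequence` lays out its result, here is a toy example with a made-up pad index of 7. With the default `batch_first=False`, sequences become columns of a `[max_len, batch_size]` tensor, and shorter ones are filled with `padding_value` (which defaults to 0 if you don't pass it).

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 7  # hypothetical pad index, for illustration only
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Each sequence becomes one column; the shorter one is padded at the end.
padded = pad_sequence(seqs, padding_value=PAD_IDX)
print(padded)
# tensor([[1, 4],
#         [2, 5],
#         [3, 7]])
```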

Let's also create a `DataLoader` for a validation dataset to evaluate our model during training.

In [None]:
val_data = list(UDPOS(split="valid"))
val_dataloader = DataLoader(val_data, batch_size, collate_fn=collate_batch)

## Building the Model

Next up, we define our model - a multi-layer bi-directional LSTM. The image below shows a simplified version of the model with only one LSTM layer and omitting the LSTM's cell state for clarity.

![](https://github.com/girafe-ai/ml-mipt/blob/21f_advanced/week1_04_transformer_n_pos_tagging/assets/pos-bidirectional-lstm.png?raw=1)

The model takes in a sequence of tokens, $X = \{x_1, x_2, \ldots, x_T\}$, and passes them through an embedding layer, $e$, to get the token embeddings, $e(X) = \{e(x_1), e(x_2), \ldots, e(x_T)\}$.

These embeddings are processed - one per time-step - by the forward and backward LSTMs. The forward LSTM processes the sequence from left-to-right, whilst the backward LSTM processes the sequence right-to-left, i.e. the first input to the forward LSTM is $x_1$ and the first input to the backward LSTM is $x_T$. 

The LSTMs also take in the hidden, $h$, and cell, $c$, states from the previous time-step:

$$h^{\rightarrow}_t = \text{LSTM}^{\rightarrow}(e(x^{\rightarrow}_t), h^{\rightarrow}_{t-1}, c^{\rightarrow}_{t-1})$$
$$h^{\leftarrow}_t=\text{LSTM}^{\leftarrow}(e(x^{\leftarrow}_t), h^{\leftarrow}_{t-1}, c^{\leftarrow}_{t-1})$$

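In a multi-layer LSTM, the hidden states that one layer produces at every time-step form the input sequence of the next layer (with dropout applied in between, if enabled); each layer keeps its own hidden and cell states.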

The initial hidden and cell states, $h_0$ and $c_0$, for each direction and layer are initialized to a tensor full of zeros.

We then concatenate the forward and backward hidden states from the final layer of the LSTM, $H = \{h_1, h_2, \ldots, h_T\}$, where $h_1 = [h^{\rightarrow}_1;h^{\leftarrow}_T]$, $h_2 = [h^{\rightarrow}_2;h^{\leftarrow}_{T-1}]$, etc. (in general, $h_t = [h^{\rightarrow}_t;h^{\leftarrow}_{T-t+1}]$), and pass them through a linear layer, $f$, which predicts which tag applies to each token, $\hat{y}_t = f(h_t)$.
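We can check this layout directly on a small, randomly initialized bi-directional `nn.LSTM` (the sizes below are arbitrary). In the `outputs` tensor, the forward half at the last position equals the forward LSTM's final hidden state, while the backward half at the *first* position equals the backward LSTM's final state, because the backward pass ends at $x_1$. This is exactly the $h_1 = [h^{\rightarrow}_1;h^{\leftarrow}_T]$ pairing above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, emb_dim, hid_dim = 5, 2, 3, 4
lstm = nn.LSTM(emb_dim, hid_dim, bidirectional=True)
x = torch.randn(seq_len, batch_size, emb_dim)
outputs, (h_n, c_n) = lstm(x)

print(outputs.shape)  # torch.Size([5, 2, 8]): [seq_len, batch, 2 * hid_dim]
print(h_n.shape)      # torch.Size([2, 2, 4]): [num_directions, batch, hid_dim]

# Forward half at the last position == forward final state.
print(torch.allclose(outputs[-1, :, :hid_dim], h_n[0]))  # True
# Backward half at the first position == backward final state.
print(torch.allclose(outputs[0, :, hid_dim:], h_n[1]))  # True
```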

When training the model, we compare our predicted tags, $\hat{Y}$, against the actual tags, $Y$, to calculate a loss, compute the gradients w.r.t. that loss, and then update our parameters.

We implement the model detailed above in the `BiLSTMPOSTagger` class.

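`nn.Embedding` is an embedding layer, and its input dimension should be the size of the input (text) vocabulary. Here we build it with `nn.Embedding.from_pretrained`, so both the vocabulary size and the embedding dimension are taken from the shape of `word_vectors`. We also tell it the index of the padding token so it does not update the padding token's embedding entry.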

`nn.LSTM` is the LSTM. We apply dropout as regularization between the layers, if we are using more than one.

`nn.Linear` defines the linear layer that makes predictions from the LSTM outputs. Its input size is twice the hidden dimension because the LSTM is bi-directional, and its output dimension should be the size of the tag vocabulary.

We also define a dropout layer with `nn.Dropout`, which we use in the `forward` method to apply dropout to the embeddings and the outputs of the final layer of the LSTM.
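The claim about `padding_idx` is easy to verify on a toy embedding built with `from_pretrained` (random 4×3 vectors and pad index 3, both made up for illustration). After a backward pass, the padding row receives a zero gradient while the other rows used in the input do not. Note that with `from_pretrained` the padding row keeps its pretrained value: it is excluded from updates, not zeroed.

```python
import torch
import torch.nn as nn

vectors = torch.randn(4, 3)  # hypothetical 4-word vocabulary, 3-dim vectors
PAD_IDX = 3
emb = nn.Embedding.from_pretrained(vectors, freeze=False, padding_idx=PAD_IDX)

tokens = torch.tensor([[0, 1, PAD_IDX]])
emb(tokens).sum().backward()

print(emb.weight.grad[PAD_IDX])  # tensor([0., 0., 0.]): pad row gets no gradient
print(emb.weight.grad[0])        # tensor([1., 1., 1.]): used once in the sum
```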

In [None]:
import torch.nn as nn


class BiLSTMPOSTagger(nn.Module):
    def __init__(self, word_vectors, hidden_dim, n_tags, n_layers, dropout, pad_idx):
        super().__init__()

        # This is how we can init our embeddings with pretrained word vectors.
        # The freeze=False parameter tells torch that we still want to update
        # the embeddings during training.
        self.embedding = nn.Embedding.from_pretrained(
            word_vectors, freeze=False, padding_idx=pad_idx
        )

        # YOUR CODE HERE
        # Define LSTM, linear and dropout layers.

    def forward(self, text):
        # text has a shape of [seq_len, batch_size]

        # YOUR CODE HERE
        # Compute an embedding and apply dropout.
        # embedded = ...
        # embedded should have a shape of [seq_len, batch_size, emb_dim]

        # YOUR CODE HERE
        # Compute the LSTM output values.
        # outputs = ...
        # Note that nn.LSTM returns two things:
        # 1) outputs: the concatenated forward and backward hidden states of
        #    the final layer at every time-step
        # 2) (hidden, cell): the hidden and cell states of every layer and
        #    direction at the final time-step
        # outputs should have a shape of [seq_len, batch_size, 2 * hidden_dim]

        # YOUR CODE HERE
        # Use the fc layer to make a prediction of what the tag should be.
        # Don't forget to apply dropout to outputs first.
        # predictions should have a shape of [seq_len, batch_size, n_tags]

        return predictions

## Training the Model

Next, we instantiate the model. The embedding dimension has to match that of the GloVe embeddings we loaded earlier; since the embedding layer is built directly from `word_vectors`, this happens automatically.

The rest of the hyperparameters have been chosen as sensible defaults, though there may be a combination that performs better on this model and dataset.

In [None]:
model = BiLSTMPOSTagger(
    word_vectors,
    hidden_dim=128,
    n_tags=len(tag_vocab),
    n_layers=2,
    dropout=0.25,
    pad_idx=word_vocab["<pad>"],
)

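We initialize the weights of the LSTM and linear layers from a simple Normal distribution. Embedding layers are skipped on purpose: `model.apply` calls the function on every submodule, so a blanket initialization would overwrite the GloVe vectors we have just loaded. Again, there may be a better initialization scheme for this model and dataset.

In [None]:
def init_weights(m):
    # Embedding layers keep their own initialization (e.g. the GloVe vectors).
    if isinstance(m, nn.Embedding):
        return
    # model.apply already visits every submodule, so each module only
    # initializes its own parameters (recurse=False).
    for name, param in m.named_parameters(recurse=False):
        nn.init.normal_(param.data, mean=0, std=0.1)


model.apply(init_weights);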

Next, a small function that tells us how many trainable parameters our model has, which is useful for comparing different models.

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"The model has {count_parameters(model):,} trainable parameters")
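To get a feel for where these parameters come from, here is a worked count for the LSTM part alone. Each `nn.LSTM` layer and direction has four gates, with an input-to-hidden matrix of size `4h × in`, a hidden-to-hidden matrix of size `4h × h` and two bias vectors of size `4h`; layers after the first receive the `2h`-dimensional bi-directional output of the previous layer. The helper `lstm_params` is ours, written for illustration, and it is checked against PyTorch's own count for the sizes used here (100-dimensional inputs, `hidden_dim=128`, two layers).

```python
import torch.nn as nn


def lstm_params(input_dim, hid_dim, n_layers, bidirectional=True):
    n_dirs = 2 if bidirectional else 1
    total = 0
    for layer in range(n_layers):
        # Layers after the first receive the concatenated outputs of all directions.
        in_dim = input_dim if layer == 0 else n_dirs * hid_dim
        # W_ih (4h x in) + W_hh (4h x h) + b_ih (4h) + b_hh (4h), per direction.
        per_dir = 4 * hid_dim * (in_dim + hid_dim) + 2 * 4 * hid_dim
        total += n_dirs * per_dir
    return total


lstm = nn.LSTM(100, 128, num_layers=2, bidirectional=True)
print(lstm_params(100, 128, 2))                   # 630784
print(sum(p.numel() for p in lstm.parameters()))  # 630784
```

The remaining parameters are the linear layer and the embedding matrix, which has one 100-dimensional row per vocabulary word and is likely to dominate the total.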

Now we define our optimizer. We use Adam with the default learning rate.

In [None]:
optimizer = torch.optim.Adam(model.parameters())

Next, we define our loss function, cross-entropy loss.

Even though we have no `<unk>` tokens within our tag vocab, we still have `<pad>` tokens. This is because all sentences within a batch need to be the same size. However, we don't want to calculate the loss when the target is a `<pad>` token as we aren't training our model to recognize padding tokens.

We handle this by setting the `ignore_index` in our loss function to the index of the padding token in our tag vocabulary.

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=tag_vocab["<pad>"])
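A quick sanity check of what `ignore_index` does, on random toy logits where class 2 plays the role of `<pad>` (a made-up index; like in our tag vocabulary, the pad token is itself a valid class). The loss with `ignore_index` equals the ordinary mean cross-entropy computed only over the non-pad positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD_IDX = 2
logits = torch.randn(4, 3)  # 4 tokens, 3 tag classes
targets = torch.tensor([0, 1, PAD_IDX, 1])

loss_ignore = nn.CrossEntropyLoss(ignore_index=PAD_IDX)(logits, targets)

# Same thing by hand: drop the pad positions, then take the usual mean loss.
mask = targets != PAD_IDX
loss_masked = nn.CrossEntropyLoss()(logits[mask], targets[mask])

print(torch.allclose(loss_ignore, loss_masked))  # True
```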

We then place our model on our GPU, if we have one.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

We will be using the loss value between our predicted and actual tags to train the network, but ideally we'd like a more interpretable way to see how well our model is doing - accuracy.

The issue is that we don't want to calculate accuracy over the `<pad>` tokens as we aren't interested in predicting them.

The function below first turns the tag scores into predicted tags with `argmax`, then compares them with the labels, counting only the positions whose target is not a `<pad>` token. Dividing the number of correct predictions by the number of non-pad tokens gives our accuracy over the batch.

In [None]:
@torch.no_grad()
def accuracy(pred, target, pad_idx=tag_vocab["<pad>"]):
    pred = pred.argmax(dim=1)
    correct = (pred == target) & (target != pad_idx)
    return correct.sum() / torch.count_nonzero(target != pad_idx)
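Here is the same computation on a hand-made batch of four tokens, one of which is padding. To keep the snippet self-contained, `masked_accuracy` is a copy of `accuracy` that takes `pad_idx` explicitly, and the pad index 2 is arbitrary. Two of the three real tokens are predicted correctly, so the result is 2/3; had the pad position been counted, we would wrongly get 3/4.

```python
import torch


def masked_accuracy(pred, target, pad_idx):
    # Same logic as accuracy() above, with pad_idx passed explicitly.
    pred = pred.argmax(dim=1)
    correct = (pred == target) & (target != pad_idx)
    return correct.sum() / torch.count_nonzero(target != pad_idx)


PAD_IDX = 2
logits = torch.tensor([[2.0, 0.1, 0.0],   # predicts 0
                       [0.1, 2.0, 0.0],   # predicts 1
                       [0.0, 0.1, 2.0],   # predicts 2 (a pad position)
                       [2.0, 0.1, 0.0]])  # predicts 0
targets = torch.tensor([0, 0, PAD_IDX, 0])

print(masked_accuracy(logits, targets, PAD_IDX))  # tensor(0.6667)
```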

Once again, we will use TensorBoard to track our training progress.

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs

Next is the training loop.

We first set the model to `train` mode to turn on dropout/batch-norm/etc. (if used). Then we iterate over our dataloader, which returns batches of examples.

For each batch:
- move the text and tags to the device the model is on
- insert the batch of text into the model to get predictions
- flatten the predictions to `[seq_len * batch_size, n_tags]` and the tags to `[seq_len * batch_size]`, since `nn.CrossEntropyLoss` and our `accuracy` function expect the tag scores in the second dimension
- calculate the loss and accuracy between the predicted tags and actual tags
- zero the gradients left over from the last gradient calculation
- call `backward` to calculate the gradients of the parameters w.r.t. the loss
- take an optimizer `step` to update the parameters
- add to the running total of loss and accuracy
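Flattening is one way to satisfy `nn.CrossEntropyLoss`; the loss also accepts extra dimensions as long as the class dimension comes second. The sketch below, on random toy tensors with one padded position, shows that both layouts give the same loss. We still flatten in this notebook because `accuracy` takes the `argmax` over `dim=1` of a 2-D tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, n_tags, PAD_IDX = 5, 3, 4, 3
pred = torch.randn(seq_len, batch_size, n_tags)            # model output layout
tags = torch.randint(0, PAD_IDX, (seq_len, batch_size))   # real tags 0..2
tags[-1, 0] = PAD_IDX                                     # one padded position
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Option 1: flatten to [seq_len * batch_size, n_tags] and [seq_len * batch_size].
loss_flat = criterion(pred.view(-1, n_tags), tags.view(-1))

# Option 2: put the class dimension second: [batch_size, n_tags, seq_len].
loss_perm = criterion(pred.permute(1, 2, 0), tags.permute(1, 0))

print(torch.allclose(loss_flat, loss_perm))  # True
```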

In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm, trange


writer = SummaryWriter()
n_epochs = 15
global_step = 0  # for writer
for epoch in trange(n_epochs, desc="Epochs"):
    train_loss = 0
    train_acc = 0
    model.train()
    for text, tags in tqdm(train_dataloader, desc="Train", leave=False):
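        # YOUR CODE HERE
        # Move text and tags to device, use model to get predictions and
        # compute loss using criterion. Flatten pred to
        # [seq_len * batch_size, n_tags] and tags to [seq_len * batch_size]
        # first: criterion and accuracy() expect this shape.
        # After you've computed loss, zero gradients, run backprop and
        # update model with optimizer.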

        train_loss += loss.item()
        train_acc += accuracy(pred, tags).item()
        writer.add_scalar("Training/loss", loss.item(), global_step)
        global_step += 1

    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)
    writer.add_scalar("Evaluation/train_loss", train_loss, epoch)
    writer.add_scalar("Evaluation/train_acc", train_acc, epoch)

    val_loss = 0
    val_acc = 0
    model.eval()
    with torch.no_grad():
        for text, tags in tqdm(val_dataloader, desc="Val", leave=False):
            # YOUR CODE HERE
            # Once again, move the batch to device, compute model predictions
            # and loss (with the same flattening as above), but don't try to
            # update model parameters with it. Just use it for model evaluation.

            val_loss += loss.item()
            val_acc += accuracy(pred, tags).item()

    val_loss /= len(val_dataloader)
    val_acc /= len(val_dataloader)
    writer.add_scalar("Evaluation/val_loss", val_loss, epoch)
    writer.add_scalar("Evaluation/val_acc", val_acc, epoch)

## Inference

We should see validation accuracy around 90%, which looks pretty good. Let's see our model tag some actual sentences!

We define a `tag_sentence` function that will:
- put the model into evaluation mode
- tokenize the sentence if it is not a list
- lowercase the tokens
- numericalize the tokens using the vocabulary
- convert the numericalized tokens into a tensor and add a batch dimension
- feed the tensor into the model
- get the predictions over the sentence
- convert the predictions into readable tags

In [None]:
from nltk.tokenize import WordPunctTokenizer


tokenizer = WordPunctTokenizer()


@torch.no_grad()
def tag_sentence(sent, model):
    if isinstance(sent, str):
        tokens = tokenizer.tokenize(sent.lower())
    else:
        tokens = [token.lower() for token in sent]

    encoded = [word_vocab[token] for token in tokens]
    encoded = torch.tensor(encoded, device=device).unsqueeze(1)

    model.eval()
    pred = model(encoded).squeeze(1)
    pred = pred.argmax(dim=1)
    tag_itos = tag_vocab.get_itos()
    pred_tags = [tag_itos[t.item()] for t in pred]
    return tokens, pred_tags
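`tag_sentence` relies on NLTK's `WordPunctTokenizer`, which, to our knowledge, splits text with the regular expression `\w+|[^\w\s]+`. The sketch below reproduces that pattern with Python's `re` module (an assumption standing in for NLTK, not its code) to show how raw text is split. Note that a contraction like "don't" becomes three tokens, which may not match how the treebank itself tokenizes text, so such words may be tagged less reliably.

```python
import re

# Assumption: the same pattern nltk's WordPunctTokenizer uses.
pattern = re.compile(r"\w+|[^\w\s]+")

print(pattern.findall("north korea at 1pm tomorrow."))
# ['north', 'korea', 'at', '1pm', 'tomorrow', '.']
print(pattern.findall("i don't know."))
# ['i', 'don', "'", 't', 'know', '.']
```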

We'll get an already tokenized example from the validation set and test our model's performance.

In [None]:
example_idx = 5
example_sent, example_tags, _ = val_data[example_idx]
tokens, pred = tag_sentence(example_sent, model)

print("-----------------------------------------------------")
print("Pred. Tag\tActual Tag\tCorrect?\tToken")
print("-----------------------------------------------------")
for token, pred_tag, actual_tag in zip(tokens, pred, example_tags):
    correct = "✔" if pred_tag == actual_tag else "✘"
    print(f"{pred_tag}\t\t{actual_tag}\t\t{correct}\t\t{token}")

We can then check how well it did!

Let's now make up our own sentence and see how well the model does.

In [None]:
sent = "The Queen will deliver a speech about the conflict in North Korea at 1pm tomorrow."
tokens, tags = tag_sentence(sent, model)

print("-------------------------")
print("Pred. Tag\tToken")
print("-------------------------")
for token, tag in zip(tokens, tags):
    print(f"{tag}\t\t{token}")

We've now seen how to implement PoS tagging with PyTorch and torchtext! 

The BiLSTM isn't a state-of-the-art model, in terms of performance, but is a strong baseline for PoS tasks and is a good tool to have in your arsenal.

## Going deeper
What if we could combine word-level and char-level approaches?

![title](https://i.postimg.cc/tT9hsBfj/ive-put-an-rnn-in-your-rnn-so-you-can-train-an-rnn-on-every-step-of-your-rnn-training-loop.jpg)


Actually, we can. Let's use an LSTM to generate an embedding for every word at the character level.
![title](https://guillaumegenthial.github.io/assets/char_representation.png)
*Image source: https://guillaumegenthial.github.io/sequence-tagging-with-tensorflow.html*

![title](https://guillaumegenthial.github.io/assets/bi-lstm.png)
*Image source: https://guillaumegenthial.github.io/sequence-tagging-with-tensorflow.html*

To do that, we need to make a few adjustments to the code we've written above.

First of all, we need a new vocabulary for the characters. This may sound redundant, but our dataset contains quite a lot of non-letter characters, such as brackets, dashes and so on. We even need an `<unk>` token because, apparently, the validation dataset contains the `/` character, whilst the training set doesn't.

In [None]:
char_counts = Counter()
for text, _, _ in train_data:
    for word in text:
        char_counts.update(c for c in word)

char_vocab = Vocab(char_counts)
char_vocab.insert_token("<unk>", 0)
char_vocab.set_default_index(0)

char_vocab.append_token("<bos>")
char_vocab.append_token("<eos>")
char_vocab.append_token("<pad>")
print(f"Unique tokens in char vocabulary: {len(char_vocab)}")

Note that we added the `<bos>` and `<eos>` tokens to the vocabulary. They let the model distinguish being in the middle of a word from reaching its boundary, in particular its last character, where the hidden state has to summarize the whole word.

Once we're done collecting the vocabulary, we need to modify our `collate_fn`. We will use the old `collate_batch` to preprocess texts and tags for us and handle the chars separately. The biggest challenge with chars is that we need to deal with differences in word lengths and in sentence lengths simultaneously, so `pad_sequence` alone won't cut it. For this, I haven't found a better solution than pre-computing the resulting tensor size and creating the tensor beforehand. After that, we make a second pass through our batch and fill the padded tensor with values. The code looks like this:

In [None]:
def collate_batch_with_chars(batch):
    texts_padded, tags_padded = collate_batch(batch)
    max_text_len = texts_padded.shape[0]
    max_word_len = 0
    for text, _, _ in batch:
        for word in text:
            if len(word) > max_word_len:
                max_word_len = len(word)

    max_word_len += 2  # for <bos> and <eos>
    chars_padded = torch.full(
        (max_word_len, max_text_len, len(batch)), fill_value=char_vocab["<pad>"]
    )
    for k, (text, _, _) in enumerate(batch):
        for j, word in enumerate(text):
            chars_padded[: len(word) + 2, j, k] = torch.tensor(
                [char_vocab["<bos>"]] + [char_vocab[c] for c in word] + [char_vocab["<eos>"]]
            )

    return texts_padded, tags_padded, chars_padded


train_dataloader = DataLoader(
    train_data, batch_size, shuffle=True, collate_fn=collate_batch_with_chars
)
val_dataloader = DataLoader(val_data, batch_size, collate_fn=collate_batch_with_chars)

text, tags, chars = next(iter(train_dataloader))
text.shape, tags.shape, chars.shape
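One possible alternative, sketched here on toy data, lets `pad_sequence` handle the word-length dimension: flatten all words of the batch, pad them to a common length in one call, and then copy each sentence's block of columns into a `[max_word_len, max_text_len, batch_size]` tensor. `pad_chars` is a hypothetical helper that expects each word as a tensor of char ids already wrapped in `<bos>`/`<eos>`; the ids below are made up.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_chars(batch_words, pad_idx):
    """Pad char ids into a [max_word_len, max_text_len, batch_size] tensor.

    batch_words: list (one entry per sentence) of lists of 1-D LongTensors,
    each holding one word's char ids (already wrapped in <bos>/<eos>).
    """
    max_text_len = max(len(words) for words in batch_words)
    # Pad all words of the batch at once: [max_word_len, total_number_of_words].
    flat = [word for words in batch_words for word in words]
    flat_padded = pad_sequence(flat, padding_value=pad_idx)
    max_word_len = flat_padded.shape[0]

    chars_padded = torch.full(
        (max_word_len, max_text_len, len(batch_words)), pad_idx, dtype=torch.long
    )
    # Copy each sentence's block of columns into its slot along the batch axis.
    start = 0
    for k, words in enumerate(batch_words):
        n_words = len(words)
        chars_padded[:, :n_words, k] = flat_padded[:, start:start + n_words]
        start += n_words
    return chars_padded


# Toy batch with hypothetical ids: 1 = <bos>, 2 = <eos>, 0 = <pad>.
batch_words = [
    [torch.tensor([1, 5, 2]), torch.tensor([1, 6, 7, 2])],
    [torch.tensor([1, 8, 2])],
]
chars = pad_chars(batch_words, pad_idx=0)
print(chars.shape)              # torch.Size([4, 2, 2])
print(chars[:, 0, 1].tolist())  # [1, 8, 2, 0]
```

This still loops over sentences, but no longer over individual words.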

Now that we can load our data, we need to define our model. It closely resembles `BiLSTMPOSTagger`. The main difference is that instead of just one `Embedding` (and `LSTM`) layer we now have two: one for chars and one for words. To keep the comparison fair, we initialize the word embedding with the same GloVe vectors. The `forward` method also gets a little more complicated: we now need to process the chars and concatenate the resulting hidden states to the word embeddings. But that's about it. Nothing too hard. Here it goes:

In [None]:
class BiLSTMPOSTaggerWithChars(nn.Module):
    def __init__(
        self,
        n_chars,
        char_emb_dim,
        char_hid_dim,
        word_vectors,
        word_hid_dim,
        n_tags,
        n_layers,
        dropout,
    ):
        super().__init__()

        self.char_embedding = nn.Embedding(n_chars, char_emb_dim, padding_idx=char_vocab["<pad>"])
        self.char_lstm = nn.LSTM(char_emb_dim, char_hid_dim, bidirectional=True)

        self.word_embedding = nn.Embedding.from_pretrained(
            word_vectors, freeze=False, padding_idx=word_vocab["<pad>"]
        )
        self.word_lstm = nn.LSTM(
            word_vectors.shape[1] + 2 * char_hid_dim,
            word_hid_dim,
            n_layers,
            bidirectional=True,
            dropout=dropout,
        )

        self.fc = nn.Linear(2 * word_hid_dim, n_tags)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text, chars):
        # chars has a shape of [word_len, text_len, batch_size]
        # text has a shape of [text_len, batch_size]
        word_len, text_len, batch_size = chars.shape

        # Integrate the text length into batch_size to compute
        # all words with LSTM simultaneously.
        chars = chars.view(word_len, -1)

        chars_embedded = self.char_embedding(chars)
        chars_embedded = self.dropout(chars_embedded)

        # chars_embedded now has a shape of [word_len, text_len * batch_size, char_emb_dim]

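        # We take only the final hidden state of each direction:
        # chars_hid has a shape of [2, text_len * batch_size, char_hid_dim]
        # (forward direction first, then backward).
        _, (chars_hid, _) = self.char_lstm(chars_embedded)

        # Concatenate both directions along the feature axis to get a
        # [text_len * batch_size, 2 * char_hid_dim] tensor (a plain view would
        # mix up the states of different words), then "unwrap" our texts back:
        chars_hid = torch.cat([chars_hid[0], chars_hid[1]], dim=1)
        chars_hid = chars_hid.view(text_len, batch_size, -1)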

        # Now we compute an embedding for the whole word, concatenate the
        # character-based embedding and apply dropout.
        word_embedded = self.word_embedding(text)
        word_embedded = torch.cat([word_embedded, chars_hid], dim=2)
        word_embedded = self.dropout(word_embedded)

        # Now that we have our word_embedded tensor, we are ready to compute
        # the LSTM outputs for it in order to predict POS tags from them.
        outputs, _ = self.word_lstm(word_embedded)

        # Now we apply the dropout and compute our final prediction.
        outputs = self.dropout(outputs)
        predictions = self.fc(outputs)

        return predictions
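The trickiest line in `forward` is turning the character LSTM's final hidden state into one vector per word. For a single-layer bi-directional `nn.LSTM`, that state has shape `[2, n_words, char_hid_dim]`, directions first. The sketch below uses a small random LSTM to confirm the shape and then a deterministic `arange` stand-in with the same layout to show where values end up: concatenating the two directions keeps each word's states together, while reshaping the raw tensor with `view` would glue together forward states of different words.

```python
import torch
import torch.nn as nn

n_words, char_emb_dim, char_hid_dim = 3, 5, 4
char_lstm = nn.LSTM(char_emb_dim, char_hid_dim, bidirectional=True)
_, (h_n, _) = char_lstm(torch.randn(6, n_words, char_emb_dim))
print(h_n.shape)  # torch.Size([2, 3, 4]): [num_directions, n_words, char_hid_dim]

# Deterministic stand-in with the same layout, to see where values end up.
h = torch.arange(2 * n_words * char_hid_dim).view(2, n_words, char_hid_dim)

per_word = torch.cat([h[0], h[1]], dim=1)  # [n_words, 2 * char_hid_dim]
print(per_word[0].tolist())                # [0, 1, 2, 3, 12, 13, 14, 15]: word 0, both directions

naive = h.view(n_words, -1)                # same shape, wrong grouping
print(naive[0].tolist())                   # [0, 1, 2, 3, 4, 5, 6, 7]: forward states of words 0 and 1
```

An equivalent formulation is `h.permute(1, 0, 2).reshape(n_words, -1)`.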

Let's create our model, apply weights init and compute the number of parameters.

In [None]:
model = BiLSTMPOSTaggerWithChars(
    len(char_vocab),
    char_emb_dim=32,
    char_hid_dim=32,
    word_vectors=word_vectors,
    word_hid_dim=128,
    n_tags=len(tag_vocab),
    n_layers=2,
    dropout=0.25,
).to(device)
model.apply(init_weights);

In [None]:
print(f"The model has {count_parameters(model):,} trainable parameters")

The model has grown a little, but not much, as we set all the char-related dimensions to be small.

Finally, let's create the optimizer and train our model!

In [None]:
optimizer = torch.optim.Adam(model.parameters())

Because we did a good job of hiding the implementation details in our model, the training procedure barely differs from the one we used to train the previous model. The only real differences are the third tensor we receive from the dataloaders and the extra `chars` argument we pass to the model. That's it!

In [None]:
writer = SummaryWriter()
n_epochs = 15
global_step = 0  # for writer
for epoch in trange(n_epochs, desc="Epochs"):
    train_loss = 0
    train_acc = 0
    model.train()
    for text, tags, chars in tqdm(train_dataloader, desc="Train", leave=False):
        text, tags, chars = text.to(device), tags.to(device), chars.to(device)
        pred = model(text, chars)
        pred = pred.view(-1, pred.shape[-1])
        tags = tags.view(-1)

        loss = criterion(pred, tags)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += accuracy(pred, tags).item()
        writer.add_scalar("Training/loss", loss.item(), global_step)
        global_step += 1

    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)
    writer.add_scalar("Evaluation/train_loss", train_loss, epoch)
    writer.add_scalar("Evaluation/train_acc", train_acc, epoch)

    val_loss = 0
    val_acc = 0
    model.eval()
    with torch.no_grad():
        for text, tags, chars in tqdm(val_dataloader, desc="Val", leave=False):
            text, tags, chars = text.to(device), tags.to(device), chars.to(device)
            pred = model(text, chars)
            pred = pred.view(-1, pred.shape[-1])
            tags = tags.view(-1)

            val_loss += criterion(pred, tags).item()
            val_acc += accuracy(pred, tags).item()

    val_loss /= len(val_dataloader)
    val_acc /= len(val_dataloader)
    writer.add_scalar("Evaluation/val_loss", val_loss, epoch)
    writer.add_scalar("Evaluation/val_acc", val_acc, epoch)

We can observe the training progress in the same TensorBoard. The plots will probably look very similar to the previous run, as the change we introduced is small and doesn't fundamentally change the model. Still, it is cool that we can do something this complicated with so little code and effort!