## Loading Utils

In [None]:
!pip install --upgrade nltk sentencepiece svgling torch tqdm

from copy import deepcopy
import json
import math
import random

import matplotlib.pyplot as plt
import numpy as np
import sentencepiece
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import tqdm.notebook

import nltk
from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader

import svgling
svgling.disable_nltk_png()

In [2]:
assert torch.cuda.is_available()
device = torch.device("cuda")
print("Using device:", device)

Using device: cuda


## Data Loading and Processing

In [3]:
%%bash
if [ ! -e parsing-data.zip ]; then
  wget --quiet https://storage.googleapis.com/cs288-parsing-project/parsing-data.zip
fi
rm -rf train dev test EVALB/
unzip parsing-data.zip

Archive:  parsing-data.zip
  inflating: train                   
  inflating: dev                     
  inflating: test                    
   creating: EVALB/
  inflating: EVALB/.DS_Store         
   creating: EVALB/bug/
  inflating: EVALB/bug/bug.gld       
  inflating: EVALB/bug/bug.rsl-new   
  inflating: EVALB/bug/bug.rsl-old   
  inflating: EVALB/bug/bug.tst       
  inflating: EVALB/COLLINS.prm       
  inflating: EVALB/evalb.c           
  inflating: EVALB/LICENSE           
  inflating: EVALB/Makefile          
  inflating: EVALB/new.prm           
  inflating: EVALB/nk.prm            
  inflating: EVALB/README            
   creating: EVALB/sample/
  inflating: EVALB/sample/sample.gld  
  inflating: EVALB/sample/sample.prm  
  inflating: EVALB/sample/sample.rsl  
  inflating: EVALB/sample/sample.tst  
  inflating: EVALB/tgrep_proc.prl    


In [4]:
!head -n 2 train

(TOP (S (PP (IN In) (NP (NP (DT an) (NNP Oct.) (CD 19) (NN review)) (PP (IN of) (NP (`` ``) (NP (DT The) (NN Misanthrope)) ('' '') (PP (IN at) (NP (NP (NNP Chicago) (POS 's)) (NNP Goodman) (NNP Theatre))))) (PRN (-LRB- -LRB-) (`` ``) (S (NP (VBN Revitalized) (NNS Classics)) (VP (VBP Take) (NP (DT the) (NN Stage)) (PP (IN in) (NP (NNP Windy) (NNP City))))) (, ,) ('' '') (NP (NN Leisure) (CC &) (NNS Arts)) (-RRB- -RRB-)))) (, ,) (NP (NP (NP (DT the) (NN role)) (PP (IN of) (NP (NNP Celimene)))) (, ,) (VP (VBN played) (PP (IN by) (NP (NNP Kim) (NNP Cattrall)))) (, ,)) (VP (VBD was) (VP (ADVP (RB mistakenly)) (VBN attributed) (PP (TO to) (NP (NNP Christina) (NNP Haag))))) (. .)))
(TOP (S (NP (NNP Ms.) (NNP Haag)) (VP (VBZ plays) (NP (NNP Elianti))) (. .)))


In [5]:
READER = BracketParseCorpusReader('.', ['train', 'dev', 'test'])

In [6]:
READER.sents('train')[1]

['Ms.', 'Haag', 'plays', 'Elianti', '.']

In [7]:
with open('sentences.txt', 'w') as f:
  for sent in READER.sents('train'):
    f.write(' '.join(sent) + '\n')

In [8]:
!head -n 2 sentences.txt

In an Oct. 19 review of `` The Misanthrope '' at Chicago 's Goodman Theatre -LRB- `` Revitalized Classics Take the Stage in Windy City , '' Leisure & Arts -RRB- , the role of Celimene , played by Kim Cattrall , was mistakenly attributed to Christina Haag .
Ms. Haag plays Elianti .


In [9]:
args = {
    "pad_id": 0,
    "bos_id": 1,
    "eos_id": 2,
    "unk_id": 3,
    "input": "sentences.txt",
    "vocab_size": 16000,
    "model_prefix": "ptb",
}
combined_args = " ".join(
    "--{}={}".format(key, value) for key, value in args.items())
sentencepiece.SentencePieceTrainer.Train(combined_args)

In [10]:
!head -n 10 ptb.vocab

<pad>	0
<s>	0
</s>	0
<unk>	0
s	-2.85521
▁,	-3.27876
▁the	-3.38611
▁.	-3.51129
▁	-3.69929
▁to	-4.02176


In [11]:
VOCAB = sentencepiece.SentencePieceProcessor()
VOCAB.Load("ptb.model")

True

In [12]:
PAD_ID = VOCAB.PieceToId("<pad>")
BOS_ID = VOCAB.PieceToId("<s>")
EOS_ID = VOCAB.PieceToId("</s>")
UNK_ID = VOCAB.PieceToId("<unk>")

## Part-of-Speech Tagging: Task Setup

In [13]:
def encode_sentence(sent):

  ids = []
  is_word_end = []
  for word in sent:
    word_ids = VOCAB.EncodeAsIds(word)
    ids.extend(word_ids)
    is_word_end.extend([False] * (len(word_ids)-1) + [True])
  return ids, is_word_end


In [14]:
print("Vocabulary size:", VOCAB.GetPieceSize())
print()

for sent in READER.sents('train')[:2]:
  indices, is_word_end = encode_sentence(sent)
  pieces = [VOCAB.IdToPiece(index) for index in indices]
  print(sent)
  print(pieces)
  print(VOCAB.DecodePieces(pieces))
  print(indices)
  print(VOCAB.DecodeIds(indices))
  print()

Vocabulary size: 16000

['In', 'an', 'Oct.', '19', 'review', 'of', '``', 'The', 'Misanthrope', "''", 'at', 'Chicago', "'s", 'Goodman', 'Theatre', '-LRB-', '``', 'Revitalized', 'Classics', 'Take', 'the', 'Stage', 'in', 'Windy', 'City', ',', "''", 'Leisure', '&', 'Arts', '-RRB-', ',', 'the', 'role', 'of', 'Celimene', ',', 'played', 'by', 'Kim', 'Cattrall', ',', 'was', 'mistakenly', 'attributed', 'to', 'Christina', 'Haag', '.']
['▁In', '▁an', '▁Oct', '.', '▁19', '▁review', '▁of', '▁``', '▁The', '▁Mi', 's', 'anthrop', 'e', "▁''", '▁at', '▁Chicago', "▁'", 's', '▁Good', 'man', '▁The', 'at', 're', '▁-', 'L', 'RB', '-', '▁``', '▁Rev', 'ital', 'ized', '▁Classic', 's', '▁Take', '▁the', '▁St', 'age', '▁in', '▁Wind', 'y', '▁City', '▁,', "▁''", '▁L', 'eisure', '▁', '&', '▁Art', 's', '▁-', 'R', 'RB', '-', '▁,', '▁the', '▁role', '▁of', '▁Cel', 'imene', '▁,', '▁play', 'ed', '▁by', '▁Kim', '▁Ca', 't', 't', 'rall', '▁,', '▁was', '▁mistaken', 'ly', '▁attribute', 'd', '▁to', '▁Christin', 'a', '▁Haag', '▁.

In [15]:
READER.tagged_sents('train')[1]

[('Ms.', 'NNP'),
 ('Haag', 'NNP'),
 ('plays', 'VBZ'),
 ('Elianti', 'NNP'),
 ('.', '.')]

In [16]:
nltk.download('tagsets')
nltk.help.upenn_tagset()

$: dollar
    $ -$ --$ A$ C$ HK$ M$ NZ$ S$ U.S.$ US$
'': closing quotation mark
    ' ''
(: opening parenthesis
    ( [ {
): closing parenthesis
    ) ] }
,: comma
    ,
--: dash
    --
.: sentence terminator
    . ! ?
:: colon or ellipsis
    : ; ...
CC: conjunction, coordinating
    & 'n and both but either et for less minus neither nor or plus so
    therefore times v. versus vs. whether yet
CD: numeral, cardinal
    mid-1890 nine-thirty forty-two one-tenth ten million 0.5 one forty-
    seven 1987 twenty '79 zero two 78-degrees eighty-four IX '60s .025
    fifteen 271,124 dozen quintillion DM2,000 ...
DT: determiner
    all an another any both del each either every half la many much nary
    neither no some such that the them these this those
EX: existential there
    there
FW: foreign word
    gemeinschaft hund ich jeux habeas Haementeria Herr K'ang-si vous
    lutihaw alai je jour objets salutaris fille quibusdam pas trop Monte
    terram fiche oui corporis ...
IN: preposition or

[nltk_data] Downloading package tagsets to /root/nltk_data...
[nltk_data]   Unzipping help/tagsets.zip.


In [17]:
def get_pos_vocab():
  all_pos = set()
  for sent in READER.tagged_sents('train'):
    for word, pos in sent:
      all_pos.add(pos)
  return sorted(all_pos)

PARTS_OF_SPEECH = get_pos_vocab()
print(PARTS_OF_SPEECH)

['#', '$', "''", ',', '-LRB-', '-RRB-', '.', ':', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '``']


In [18]:
class POSTaggingDataset(torch.utils.data.Dataset):
  def __init__(self, split):
    assert split in ('train', 'dev', 'test')
    self.sents = READER.tagged_sents(split)
    if split == 'train':
      # To speed up training, we only train on short sentences.
      self.sents = [sent for sent in self.sents if len(sent) <= 40]

  def __len__(self):
    return len(self.sents)

  def __getitem__(self, index):
    sent = self.sents[index]
    ids, is_word_end = encode_sentence([word for word, pos in sent])
    ids = [BOS_ID] + ids + [EOS_ID]
    is_word_end = [False] + is_word_end + [False]
    ids = torch.tensor(ids)
    is_word_end = torch.tensor(is_word_end)
    labels = torch.full_like(ids, -1)
    labels[is_word_end] = torch.tensor(
        [PARTS_OF_SPEECH.index(pos) for word, pos in sent])
    return {'ids': ids, 'labels': labels}

  @staticmethod
  def collate(batch):
    ids = pad_sequence(
        [item['ids'] for item in batch],
        batch_first=True, padding_value=PAD_ID)
    labels = pad_sequence(
        [item['labels'] for item in batch],
        batch_first=True, padding_value=-1)
    return {'ids': ids.to(device), 'labels': labels.to(device)}

In [19]:
dataset_for_inspection = POSTaggingDataset('train')
datum = dataset_for_inspection[0]
datum

{'ids': tensor([   1,  126,    4,   14, 9343,  711,    4, 4388, 8356,    7,    2]),
 'labels': tensor([-1, -1, -1, 20, 20, -1, 39, -1, 20,  6, -1])}

In [20]:
for i, (piece_id, label) in enumerate(zip(datum['ids'].tolist(),
                                          datum['labels'].tolist())):
  print('{:2d} {: <5} {}'.format(
      i, "-" if label == -1 else PARTS_OF_SPEECH[label],
      VOCAB.IdToPiece(piece_id)))

 0 -     <s>
 1 -     ▁M
 2 -     s
 3 NNP   .
 4 NNP   ▁Haag
 5 -     ▁play
 6 VBZ   s
 7 -     ▁Eli
 8 NNP   anti
 9 .     ▁.
10 -     </s>


In [21]:
data_loader_for_inspection = torch.utils.data.DataLoader(
    dataset_for_inspection, batch_size=2, shuffle=True,
    collate_fn=dataset_for_inspection.collate)
next(iter(data_loader_for_inspection))

{'ids': tensor([[    1,    77,  1321,   148,   214,    41,    49,    46,    17,    11,
           1143,    24,    17,    11,    98,    20,     4,    84,    64,   943,
              8,   213,   115,     7,     2,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [    1,  1386,  1697,    21,    20,     4,   690,    31,   304,    19,
           5922,     8, 11195,     4,  2650,     4,   978,     8,  6136,    10,
              6, 10073,   862,     5,    63,  2390,     4,   231,   482,     4,
              9,   831,  1451,    17,    18,   420,     4,    82,    29,  1194,
             53,    11,    73,  6705,  3941,     9, 13148,    17,    18,  3368,
            139,  3566,    13,    16,   358,     7,     2]], device='cuda:0'),
 'labels': tensor([[-1,  8, -1, 14, 22, 35, -1, -1, 27

## Baseline POS Tagging Model

In [22]:
class POSTaggingModel(nn.Module):
  def encode(self, batch):
    raise NotImplementedError()

  def compute_loss(self, batch):
    logits = self.encode(batch)
    logits = logits.reshape((-1,logits.shape[-1]))
    labels = batch['labels'].reshape((-1,))
    res = F.cross_entropy(logits, labels, ignore_index=-1, reduction='mean')
    return res

  def get_validation_metric(self, batch_size=8):
    dataset = POSTaggingDataset('dev')
    data_loader = torch.utils.data.DataLoader(
      dataset, batch_size=batch_size, collate_fn=dataset.collate)
    self.eval()
    correct = 0
    total = 0
    with torch.no_grad():
      for batch in data_loader:
        mask = (batch['labels'] != -1)
        predicted_labels = self.encode(batch).argmax(-1)
        predicted_labels = predicted_labels[mask]
        gold_labels = batch['labels'][mask]
        correct += (predicted_labels == gold_labels).sum().item()
        total += gold_labels.shape[0]
    return correct / total

In [23]:
def train(model, num_epochs, batch_size, model_file,
          learning_rate=8e-4, dataset_cls=POSTaggingDataset):
  """Train the model and save its best checkpoint.

  Model performance across epochs is evaluated on the validation set. The best
  checkpoint obtained during training will be stored on disk and loaded back
  into the model at the end of training.
  """
  dataset = dataset_cls('train')
  data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate)
  optimizer = torch.optim.Adam(
      model.parameters(),
      lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)
  scheduler = torch.optim.lr_scheduler.OneCycleLR(
      optimizer,
      learning_rate,
      epochs=num_epochs,
      steps_per_epoch=len(data_loader),
      pct_start=0.02,  # Warm up for 2% of the total training time
      )
  best_metric = 0.0
  for epoch in tqdm.notebook.trange(num_epochs, desc="training", unit="epoch"):
    with tqdm.notebook.tqdm(
        data_loader,
        desc="epoch {}".format(epoch + 1),
        unit="batch",
        total=len(data_loader)) as batch_iterator:
      model.train()
      total_loss = 0.0
      for i, batch in enumerate(batch_iterator, start=1):
        optimizer.zero_grad()
        loss = model.compute_loss(batch)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
        batch_iterator.set_postfix(mean_loss=total_loss / i)
      validation_metric = model.get_validation_metric()
      batch_iterator.set_postfix(
          mean_loss=total_loss / i,
          validation_metric=validation_metric)
      if validation_metric > best_metric:
        print(
            "Obtained a new best validation metric of {:.3f}, saving model "
            "checkpoint to {}...".format(validation_metric, model_file))
        torch.save(model.state_dict(), model_file)
        best_metric = validation_metric
  print("Reloading best model checkpoint from {}...".format(model_file))
  model.load_state_dict(torch.load(model_file))

Baseline Model: A classifier where each word is assigned its most-frequent tag from the training data.

In [24]:
class BaselineModel(POSTaggingModel):
  def __init__(self):
    super().__init__()
    self.lookup = nn.Embedding(VOCAB.GetPieceSize(), len(PARTS_OF_SPEECH))

  def encode(self, batch):
    ids = batch['ids']
    return self.lookup(ids)

In [25]:
baseline_model = BaselineModel().to(device)
train(baseline_model, num_epochs=5, batch_size=64,
      model_file="baseline_model.pt", learning_rate=0.1)

training:   0%|          | 0/5 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/575 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.870, saving model checkpoint to baseline_model.pt...


epoch 2:   0%|          | 0/575 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.873, saving model checkpoint to baseline_model.pt...


epoch 3:   0%|          | 0/575 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.875, saving model checkpoint to baseline_model.pt...


epoch 4:   0%|          | 0/575 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.875, saving model checkpoint to baseline_model.pt...


epoch 5:   0%|          | 0/575 [00:00<?, ?batch/s]

Reloading best model checkpoint from baseline_model.pt...


In [26]:
def predict_tags(tagging_model, split, limit=None):
  assert split in ('dev', 'test')
  sents = READER.sents(split)
  dataset = POSTaggingDataset(split)
  data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=8, shuffle=False, collate_fn=dataset.collate)
  tagging_model.eval()
  pred_tagged_sents = []
  with torch.no_grad():
    for batch in data_loader:
      mask = (batch['labels'] != -1)
      predicted_labels = tagging_model.encode(batch).argmax(-1)
      for i in range(batch['ids'].shape[0]):
        example_predicted_tags = [
            PARTS_OF_SPEECH[label] for label in predicted_labels[i][mask[i]]]
        sent = sents[len(pred_tagged_sents)]
        assert len(sent) == len(example_predicted_tags)
        pred_tagged_sents.append(list(zip(sent, example_predicted_tags)))
        if limit is not None and len(pred_tagged_sents) >= limit:
          return pred_tagged_sents
  return pred_tagged_sents

In [27]:
predict_tags(baseline_model, 'dev', limit=1)

[[('Influential', 'JJ'),
  ('members', 'NNS'),
  ('of', 'IN'),
  ('the', 'DT'),
  ('House', 'NNP'),
  ('Ways', 'NNS'),
  ('and', 'CC'),
  ('Means', 'NNS'),
  ('Committee', 'NNP'),
  ('introduced', 'VBD'),
  ('legislation', 'NNP'),
  ('that', 'IN'),
  ('would', 'MD'),
  ('restrict', 'VB'),
  ('how', 'WRB'),
  ('the', 'DT'),
  ('new', 'JJ'),
  ('savings-and-loan', 'NNP'),
  ('bailout', 'NN'),
  ('agency', 'NN'),
  ('can', 'MD'),
  ('raise', 'VB'),
  ('capital', 'NN'),
  (',', ','),
  ('creating', 'VBG'),
  ('another', 'DT'),
  ('potential', 'JJ'),
  ('obstacle', 'NN'),
  ('to', 'TO'),
  ('the', 'DT'),
  ('government', 'NN'),
  ("'s", 'NNS'),
  ('sale', 'NN'),
  ('of', 'IN'),
  ('sick', 'JJ'),
  ('thrifts', 'NNS'),
  ('.', '.')]]

## Transformer POS Tagging Model

In [28]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model = 256, n_head = 4, d_qkv = 32, dropout=0.1, **kwargs):
    super().__init__()
    self.d_model = d_model
    self.n_head = n_head
    self.d_qkv = d_qkv

    self.w_q = nn.Parameter(torch.Tensor(n_head, d_model, d_qkv))
    self.w_k = nn.Parameter(torch.Tensor(n_head, d_model, d_qkv))
    self.w_v = nn.Parameter(torch.Tensor(n_head, d_model, d_qkv))
    self.w_o = nn.Parameter(torch.Tensor(n_head, d_qkv, d_model))

    nn.init.xavier_normal_(self.w_q)
    nn.init.xavier_normal_(self.w_k)
    nn.init.xavier_normal_(self.w_v)
    nn.init.xavier_normal_(self.w_o)

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    """
    Args:
      x: the input to the layer, a tensor of shape [batch size, length, d_model]
    """

    batch_size = x.size(0)
    seq_len = x.size(1)

    q = torch.einsum('blm,hmq->bhlq', x, self.w_q)
    k = torch.einsum('blm,hmq->bhlq', x, self.w_k)
    v = torch.einsum('blm,hmq->bhlq', x, self.w_v)

    attention_scores = torch.einsum('bhiq, bhjq->bhij', q, k) / math.sqrt(self.d_qkv) # B x N_H x L x L
    mask = mask.unsqueeze(1)
    mask = mask.unsqueeze(1)
    attention_scores = attention_scores.masked_fill(mask == 0, -1e10)
    attention_weights = F.softmax(attention_scores, dim = -1)
    attention_weights = self.dropout(attention_weights)
    attention_weights = attention_weights.reshape(batch_size * self.n_head, attention_weights.size(2), attention_weights.size(3))
    v = v.reshape(batch_size * self.n_head, v.size(2), v.size(3))
    z_n = attention_weights.bmm(v) # B * N_H x L x d_qkv
    z_n = z_n.reshape(batch_size, self.n_head, z_n.size(1), z_n.size(2))
    z = torch.einsum('bhlq, hqm->blm', z_n, self.w_o)
    z = self.dropout(z)
    return z

In [29]:
class PositionwiseFeedForward(nn.Module):
  def __init__(self, d_model, d_ff, dropout=0.1):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model),
        nn.Dropout(dropout),
    )

  def forward(self, x):
    return self.net(x)

In [30]:
class TransformerEncoder(nn.Module):
  def __init__(self, d_model = 256, d_ff = 1024, n_layers = 4, n_head = 4, d_qkv = 32, dropout = 0.1):
    super().__init__()
    self.n_layers = n_layers
    self.dropout = nn.Dropout(dropout)
    self.multi_head_attentions = nn.ModuleList([MultiHeadAttention(d_model, n_head, d_qkv, dropout) for i in range(n_layers)])
    self.ff_layers = nn.ModuleList([PositionwiseFeedForward(d_model, d_ff, dropout) for i in range(n_layers)])
    self.ln_1 = nn.ModuleList([nn.LayerNorm(d_model) for i in range(n_layers)])
    self.ln_2 = nn.ModuleList([nn.LayerNorm(d_model) for i in range(n_layers)])

  def forward(self, x, mask):
    for i in range(self.n_layers):
      x = x + self.multi_head_attentions[i](x, mask)
      x = self.ln_1[i](x)
      x = x + self.ff_layers[i](x)
      x = self.ln_2[i](x)

    return x

In [31]:
class AddPositionalEncoding(nn.Module):
  def __init__(self, d_model=256, input_dropout=0.1, timing_dropout=0.1,
               max_len=512):
    super().__init__()
    self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model))
    nn.init.normal_(self.timing_table)
    self.input_dropout = nn.Dropout(input_dropout)
    self.timing_dropout = nn.Dropout(timing_dropout)

  def forward(self, x):
    """
    Args:
      x: A tensor of shape [batch size, length, d_model]
    """
    x = self.input_dropout(x)
    timing = self.timing_table[None, :x.shape[1], :]
    timing = self.timing_dropout(timing)
    return x + timing

In [32]:
class TransformerPOSTaggingModel(POSTaggingModel):
  def __init__(self):
    super().__init__()
    d_model = 256
    self.add_timing = AddPositionalEncoding(d_model)
    self.encoder = TransformerEncoder(d_model)
    self.final_linear = nn.Linear(d_model, 16000)
    self.final_ln = nn.LayerNorm(d_model)

  def encode(self, batch):
    """
    Args:
      batch: an input batch as a dictionary; the key 'ids' holds the vocab ids
        of the subword tokens in a tensor of size [batch_size, sequence_length]
    """
    x = batch['ids']
    mask = (x != PAD_ID)
    embeds = F.embedding(x, self.final_linear.weight)
    embeds = self.add_timing(embeds)
    encoder_output = self.final_ln(self.encoder(embeds, mask))
    output = self.final_linear(encoder_output)
    output = F.log_softmax(output, dim = -1)
    return output

In [33]:
num_epochs = 8
batch_size = 16

tagging_model = TransformerPOSTaggingModel().to(device)
train(tagging_model, num_epochs, batch_size, "tagging_model.pt")

training:   0%|          | 0/8 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/2298 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.882, saving model checkpoint to tagging_model.pt...


epoch 2:   0%|          | 0/2298 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.931, saving model checkpoint to tagging_model.pt...


epoch 3:   0%|          | 0/2298 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.944, saving model checkpoint to tagging_model.pt...


epoch 4:   0%|          | 0/2298 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.952, saving model checkpoint to tagging_model.pt...


epoch 5:   0%|          | 0/2298 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.958, saving model checkpoint to tagging_model.pt...


epoch 6:   0%|          | 0/2298 [00:00<?, ?batch/s]

Obtained a new best validation metric of 0.960, saving model checkpoint to tagging_model.pt...


epoch 7:   0%|          | 0/2298 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/2298 [00:00<?, ?batch/s]

Reloading best model checkpoint from tagging_model.pt...


In [34]:
predict_tags(tagging_model, 'dev', limit=1)

[[('Influential', 'JJ'),
  ('members', 'NNS'),
  ('of', 'IN'),
  ('the', 'DT'),
  ('House', 'NNP'),
  ('Ways', 'NNPS'),
  ('and', 'CC'),
  ('Means', 'NNP'),
  ('Committee', 'NNP'),
  ('introduced', 'VBD'),
  ('legislation', 'NN'),
  ('that', 'WDT'),
  ('would', 'MD'),
  ('restrict', 'VB'),
  ('how', 'WRB'),
  ('the', 'DT'),
  ('new', 'JJ'),
  ('savings-and-loan', 'JJ'),
  ('bailout', 'NN'),
  ('agency', 'NN'),
  ('can', 'MD'),
  ('raise', 'VB'),
  ('capital', 'NN'),
  (',', ','),
  ('creating', 'VBG'),
  ('another', 'DT'),
  ('potential', 'JJ'),
  ('obstacle', 'NN'),
  ('to', 'TO'),
  ('the', 'DT'),
  ('government', 'NN'),
  ("'s", 'POS'),
  ('sale', 'NN'),
  ('of', 'IN'),
  ('sick', 'JJ'),
  ('thrifts', 'NNS'),
  ('.', '.')]]