In [1]:
!pip install --quiet pytorch-crf

In [2]:
import os
import random
import numpy as np

import functools

import torch
import torch.nn as nn

from torchtext import datasets
from torchtext.data import Field
from torchtext.data import BucketIterator

from torchcrf import CRF


SEED = 241

In [3]:
def seed_everything(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)

  if torch.cuda.is_available(): 
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [4]:
TEXT = Field(lower=True,
             use_vocab=True,
             sequential=True,
             batch_first=True,
             include_lengths=True)

LABEL = Field(lower=True,
              use_vocab=True,
              sequential=True,
              unk_token = None,
              batch_first=True)

In [5]:
fields = [('text', TEXT), ('tags', LABEL)]

train_data, valid_data, test_data = datasets.UDPOS.splits(fields)

en-ud-v2.zip: 100%|██████████| 688k/688k [00:00<00:00, 34.6MB/s]

downloading en-ud-v2.zip
extracting





In [6]:
TEXT.build_vocab(train_data,
                 max_size=25000,
                 vectors='glove.6B.100d')
LABEL.build_vocab(train_data)

.vector_cache/glove.6B.zip: 862MB [06:29, 2.22MB/s]                           
100%|█████████▉| 398570/400000 [00:15<00:00, 24772.01it/s]

In [8]:
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')


train_iterator = BucketIterator.splits((train_data,), batch_size=batch_size, device=device)[0]
valid_iterator = BucketIterator.splits((valid_data,), batch_size=batch_size, device=device)[0]
test_iterator = BucketIterator.splits((test_data,), batch_size=batch_size, device=device)[0]

In [9]:
class BiLSTM_CRF_Tagger(nn.Module):

  def __init__(self, vocab_size, emb_size, hidden_size, n_layers, dropout, num_tags, pad_idx):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
    self.rnn = nn.LSTM(emb_size,
                       hidden_size, 
                       num_layers=n_layers,
                       dropout=0.3, 
                       bidirectional=True, 
                       batch_first=True)
    
    self.dropout = nn.Dropout(dropout)
    self.hidden2tag = nn.Linear(emb_size, num_tags)
    self.crf = CRF(num_tags, batch_first=True)

  def _generate_mask(self, text_lens):
    bs = text_lens.size(0)
    max_seq_len = torch.max(text_lens).item()
    mask = torch.ByteTensor(bs, max_seq_len).fill_(0)
    for i in range(bs):
      mask[i, :text_lens[i]] = 1
    return mask

  def forward(self, text, text_lens, tags=None):
    text_embed = self.embedding(text)

    text_packed = nn.utils.rnn.pack_padded_sequence(text_embed, text_lens, batch_first=True, enforce_sorted=False)
    rnn_outputs, (last_hidden, cell_state) = self.rnn(text_packed)
    text_unpacked, lens_unpacked = nn.utils.rnn.pad_packed_sequence(text_packed, batch_first=True)
    last_hidden = last_hidden.permute(1, 0, 2)

    emission = self.hidden2tag(text_unpacked)
    mask = self._generate_mask(text_lens).to(device)

    if tags is not None:
      loss = -self.crf.forward(torch.log_softmax(emission, dim=2), tags, mask, reduction='mean')
      return loss
    else:
      prediction = self.crf.decode(emission, mask)
      return prediction

In [10]:
VOCAB_SIZE = len(TEXT.vocab)
EMB_SIZE = 100
HIDDEN_SIZE = 128
N_LAYERS = 2
DROPOUT = 0.3
NUM_TAGS = len(LABEL.vocab)
PAD_IDX = LABEL.vocab.stoi['<pad>']


model = BiLSTM_CRF_Tagger(VOCAB_SIZE, EMB_SIZE, HIDDEN_SIZE, N_LAYERS, DROPOUT, NUM_TAGS, PAD_IDX)
optimizer = torch.optim.Adam(model.parameters())

In [11]:
model.embedding.weight.data.copy_(TEXT.vocab.vectors)

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0382, -0.2449,  0.7281,  ..., -0.1459,  0.8278,  0.2706],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2634,  0.0742, -0.1081,  ..., -0.2977, -0.5655,  0.5218],
        [ 0.4244,  0.6004, -0.1528,  ...,  0.2536, -0.4969,  0.8964]])

In [12]:
model = model.to(device)

In [13]:
from tqdm import tqdm


for i in range(10):
  model.train()
  error = 0.

  for batch in tqdm(train_iterator):
    optimizer.zero_grad()

    text, lens = batch.text
    lens = lens.cpu()
    tags = batch.tags
    
    loss = model(text, lens, tags)
    loss.backward()

    optimizer.step()
    error += loss.detach().cpu().numpy()
  print('train error', error / len(train_iterator))

  error = 0.
  model.eval()
  with torch.no_grad():
    for batch in tqdm(valid_iterator):
      text, lens = batch.text
      lens = lens.cpu()
      tags = batch.tags
      
      loss = model(text, lens, tags)
      error += loss.detach().cpu().numpy()
    print('valid error', error / len(valid_iterator))


  0%|          | 0/392 [00:00<?, ?it/s][A
  0%|          | 1/392 [00:00<01:52,  3.48it/s][A
  1%|          | 3/392 [00:00<01:25,  4.55it/s][A
  1%|▏         | 5/392 [00:00<01:06,  5.81it/s][A
  2%|▏         | 7/392 [00:00<00:54,  7.01it/s][A
  2%|▏         | 9/392 [00:00<00:46,  8.24it/s][A
  3%|▎         | 11/392 [00:00<00:40,  9.52it/s][A
  3%|▎         | 13/392 [00:01<00:37, 10.00it/s][A
  4%|▍         | 15/392 [00:01<00:34, 10.88it/s][A
  4%|▍         | 17/392 [00:01<00:32, 11.69it/s][A
  5%|▍         | 19/392 [00:01<00:29, 12.84it/s][A
  5%|▌         | 21/392 [00:01<00:30, 12.10it/s][A
  6%|▌         | 23/392 [00:01<00:28, 12.86it/s][A
  6%|▋         | 25/392 [00:01<00:25, 14.17it/s][A
  7%|▋         | 28/392 [00:02<00:23, 15.68it/s][A
  8%|▊         | 30/392 [00:02<00:22, 16.38it/s][A
  8%|▊         | 32/392 [00:02<00:22, 16.06it/s][A
  9%|▊         | 34/392 [00:02<00:27, 13.14it/s][A
  9%|▉         | 36/392 [00:02<00:25, 13.74it/s][A
 10%|▉         | 38/392 [

train error 19.54022534526124



 16%|█▌        | 10/63 [00:00<00:01, 46.19it/s][A
 24%|██▍       | 15/63 [00:00<00:01, 45.62it/s][A
 30%|███       | 19/63 [00:00<00:01, 43.29it/s][A
 38%|███▊      | 24/63 [00:00<00:00, 43.92it/s][A
 46%|████▌     | 29/63 [00:00<00:00, 44.99it/s][A
 54%|█████▍    | 34/63 [00:00<00:00, 45.75it/s][A
 63%|██████▎   | 40/63 [00:00<00:00, 47.90it/s][A
 73%|███████▎  | 46/63 [00:00<00:00, 48.12it/s][A
 81%|████████  | 51/63 [00:01<00:00, 47.35it/s][A
 89%|████████▉ | 56/63 [00:01<00:00, 46.73it/s][A
100%|██████████| 63/63 [00:01<00:00, 46.51it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  1%|          | 2/392 [00:00<00:27, 14.09it/s][A

valid error 7.441745409889827



  1%|          | 4/392 [00:00<00:25, 15.00it/s][A
  2%|▏         | 6/392 [00:00<00:25, 15.06it/s][A
  2%|▏         | 8/392 [00:00<00:25, 15.20it/s][A
  3%|▎         | 10/392 [00:00<00:24, 15.65it/s][A
  3%|▎         | 12/392 [00:00<00:23, 16.36it/s][A
  4%|▎         | 14/392 [00:00<00:23, 16.34it/s][A
  4%|▍         | 16/392 [00:01<00:24, 15.63it/s][A
  5%|▍         | 18/392 [00:01<00:23, 15.92it/s][A
  5%|▌         | 20/392 [00:01<00:22, 16.45it/s][A
  6%|▌         | 22/392 [00:01<00:24, 14.81it/s][A
  6%|▌         | 24/392 [00:01<00:27, 13.45it/s][A
  7%|▋         | 26/392 [00:01<00:25, 14.22it/s][A
  7%|▋         | 28/392 [00:01<00:24, 14.92it/s][A
  8%|▊         | 30/392 [00:01<00:24, 14.53it/s][A
  8%|▊         | 32/392 [00:02<00:24, 14.97it/s][A
  9%|▊         | 34/392 [00:02<00:23, 15.06it/s][A
  9%|▉         | 36/392 [00:02<00:24, 14.79it/s][A
 10%|▉         | 38/392 [00:02<00:22, 15.58it/s][A
 10%|█         | 40/392 [00:02<00:21, 16.30it/s][A
 11%|█        

train error 6.3897150219703205



 16%|█▌        | 10/63 [00:00<00:01, 44.00it/s][A
 24%|██▍       | 15/63 [00:00<00:01, 45.12it/s][A
 33%|███▎      | 21/63 [00:00<00:00, 47.55it/s][A
 41%|████▏     | 26/63 [00:00<00:00, 47.63it/s][A
 51%|█████     | 32/63 [00:00<00:00, 48.44it/s][A
 60%|██████    | 38/63 [00:00<00:00, 49.59it/s][A
 68%|██████▊   | 43/63 [00:00<00:00, 49.38it/s][A
 76%|███████▌  | 48/63 [00:01<00:00, 46.22it/s][A
 84%|████████▍ | 53/63 [00:01<00:00, 44.62it/s][A
100%|██████████| 63/63 [00:01<00:00, 47.25it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  1%|          | 2/392 [00:00<00:26, 14.72it/s][A

valid error 5.289716913586571



  1%|          | 4/392 [00:00<00:25, 15.31it/s][A
  2%|▏         | 6/392 [00:00<00:25, 14.91it/s][A
  2%|▏         | 8/392 [00:00<00:26, 14.73it/s][A
  3%|▎         | 10/392 [00:00<00:24, 15.70it/s][A
  3%|▎         | 12/392 [00:00<00:22, 16.64it/s][A
  4%|▎         | 14/392 [00:00<00:23, 16.30it/s][A
  4%|▍         | 16/392 [00:01<00:24, 15.47it/s][A
  5%|▍         | 18/392 [00:01<00:23, 16.06it/s][A
  5%|▌         | 20/392 [00:01<00:25, 14.78it/s][A
  6%|▌         | 22/392 [00:01<00:24, 15.08it/s][A
  6%|▌         | 24/392 [00:01<00:23, 15.66it/s][A
  7%|▋         | 26/392 [00:01<00:22, 16.29it/s][A
  7%|▋         | 28/392 [00:01<00:21, 16.65it/s][A
  8%|▊         | 30/392 [00:01<00:20, 17.47it/s][A
  8%|▊         | 32/392 [00:02<00:21, 16.47it/s][A
  9%|▊         | 34/392 [00:02<00:21, 16.94it/s][A
  9%|▉         | 36/392 [00:02<00:22, 16.16it/s][A
 10%|▉         | 38/392 [00:02<00:22, 15.86it/s][A
 10%|█         | 40/392 [00:02<00:23, 15.07it/s][A
 11%|█        

train error 4.406452140637806



 16%|█▌        | 10/63 [00:00<00:01, 47.83it/s][A
 24%|██▍       | 15/63 [00:00<00:01, 47.50it/s][A
 32%|███▏      | 20/63 [00:00<00:00, 47.16it/s][A
 40%|███▉      | 25/63 [00:00<00:00, 46.21it/s][A
 49%|████▉     | 31/63 [00:00<00:00, 49.23it/s][A
 57%|█████▋    | 36/63 [00:00<00:00, 46.38it/s][A
 65%|██████▌   | 41/63 [00:00<00:00, 46.72it/s][A
 73%|███████▎  | 46/63 [00:00<00:00, 47.66it/s][A
 83%|████████▎ | 52/63 [00:01<00:00, 50.53it/s][A
 90%|█████████ | 57/63 [00:01<00:00, 49.68it/s][A
100%|██████████| 63/63 [00:01<00:00, 48.34it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  1%|          | 2/392 [00:00<00:28, 13.83it/s][A

valid error 4.669826886010548



  1%|          | 4/392 [00:00<00:26, 14.85it/s][A
  2%|▏         | 6/392 [00:00<00:26, 14.70it/s][A
  2%|▏         | 8/392 [00:00<00:27, 14.04it/s][A
  3%|▎         | 10/392 [00:00<00:25, 14.75it/s][A
  3%|▎         | 12/392 [00:00<00:26, 14.23it/s][A
  4%|▎         | 14/392 [00:00<00:26, 14.05it/s][A
  4%|▍         | 16/392 [00:01<00:27, 13.55it/s][A
  5%|▍         | 18/392 [00:01<00:26, 14.16it/s][A
  5%|▌         | 20/392 [00:01<00:25, 14.41it/s][A
  6%|▌         | 23/392 [00:01<00:23, 15.55it/s][A
  6%|▋         | 25/392 [00:01<00:24, 14.82it/s][A
  7%|▋         | 27/392 [00:01<00:24, 14.95it/s][A
  7%|▋         | 29/392 [00:01<00:23, 15.49it/s][A
  8%|▊         | 31/392 [00:02<00:22, 15.93it/s][A
  8%|▊         | 33/392 [00:02<00:25, 14.18it/s][A
  9%|▉         | 35/392 [00:02<00:23, 14.95it/s][A
  9%|▉         | 37/392 [00:02<00:23, 15.35it/s][A
 10%|▉         | 39/392 [00:02<00:24, 14.46it/s][A
 10%|█         | 41/392 [00:02<00:24, 14.34it/s][A
 11%|█        

train error 3.5502855440183563



 16%|█▌        | 10/63 [00:00<00:01, 43.43it/s][A
 25%|██▌       | 16/63 [00:00<00:01, 45.71it/s][A
 33%|███▎      | 21/63 [00:00<00:00, 46.48it/s][A
 41%|████▏     | 26/63 [00:00<00:00, 47.26it/s][A
 51%|█████     | 32/63 [00:00<00:00, 48.25it/s][A
 59%|█████▊    | 37/63 [00:00<00:00, 47.28it/s][A
 68%|██████▊   | 43/63 [00:00<00:00, 48.95it/s][A
 76%|███████▌  | 48/63 [00:01<00:00, 47.25it/s][A
 86%|████████▌ | 54/63 [00:01<00:00, 48.00it/s][A
100%|██████████| 63/63 [00:01<00:00, 47.61it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  1%|          | 2/392 [00:00<00:26, 14.59it/s][A

valid error 4.3419364217727905



  1%|          | 4/392 [00:00<00:30, 12.87it/s][A
  2%|▏         | 6/392 [00:00<00:27, 14.12it/s][A
  2%|▏         | 8/392 [00:00<00:25, 15.24it/s][A
  3%|▎         | 10/392 [00:00<00:24, 15.87it/s][A
  3%|▎         | 12/392 [00:00<00:24, 15.31it/s][A
  4%|▎         | 14/392 [00:00<00:23, 16.33it/s][A
  4%|▍         | 16/392 [00:01<00:23, 15.97it/s][A
  5%|▍         | 18/392 [00:01<00:22, 16.36it/s][A
  5%|▌         | 20/392 [00:01<00:26, 13.81it/s][A
  6%|▌         | 22/392 [00:01<00:25, 14.39it/s][A
  6%|▌         | 24/392 [00:01<00:24, 14.89it/s][A
  7%|▋         | 26/392 [00:01<00:25, 14.30it/s][A
  7%|▋         | 28/392 [00:01<00:25, 14.13it/s][A
  8%|▊         | 30/392 [00:02<00:23, 15.46it/s][A
  8%|▊         | 32/392 [00:02<00:23, 15.30it/s][A
  9%|▊         | 34/392 [00:02<00:23, 14.96it/s][A
  9%|▉         | 36/392 [00:02<00:23, 14.86it/s][A
 10%|▉         | 38/392 [00:02<00:23, 14.85it/s][A
 10%|█         | 40/392 [00:02<00:23, 15.23it/s][A
 11%|█        

train error 3.061753275747202



 16%|█▌        | 10/63 [00:00<00:01, 46.43it/s][A
 24%|██▍       | 15/63 [00:00<00:01, 46.32it/s][A
 32%|███▏      | 20/63 [00:00<00:00, 45.39it/s][A
 40%|███▉      | 25/63 [00:00<00:00, 45.55it/s][A
 48%|████▊     | 30/63 [00:00<00:00, 45.77it/s][A
 54%|█████▍    | 34/63 [00:00<00:00, 43.86it/s][A
 63%|██████▎   | 40/63 [00:00<00:00, 46.17it/s][A
 73%|███████▎  | 46/63 [00:00<00:00, 48.00it/s][A
 83%|████████▎ | 52/63 [00:01<00:00, 49.76it/s][A
 92%|█████████▏| 58/63 [00:01<00:00, 50.47it/s][A
100%|██████████| 63/63 [00:01<00:00, 47.36it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  1%|          | 2/392 [00:00<00:26, 14.69it/s][A

valid error 4.170085922120109



  1%|          | 4/392 [00:00<00:27, 13.89it/s][A
  2%|▏         | 6/392 [00:00<00:26, 14.30it/s][A
  2%|▏         | 8/392 [00:00<00:27, 13.99it/s][A
  3%|▎         | 10/392 [00:00<00:26, 14.35it/s][A
  3%|▎         | 12/392 [00:00<00:26, 14.36it/s][A
  4%|▎         | 14/392 [00:01<00:27, 13.87it/s][A
  4%|▍         | 16/392 [00:01<00:28, 13.35it/s][A
  5%|▍         | 18/392 [00:01<00:27, 13.71it/s][A
  5%|▌         | 20/392 [00:01<00:26, 13.87it/s][A
  6%|▌         | 22/392 [00:01<00:25, 14.48it/s][A
  6%|▌         | 24/392 [00:01<00:25, 14.21it/s][A
  7%|▋         | 26/392 [00:01<00:24, 15.11it/s][A
  7%|▋         | 28/392 [00:01<00:23, 15.29it/s][A
  8%|▊         | 30/392 [00:02<00:22, 16.30it/s][A
  8%|▊         | 32/392 [00:02<00:23, 15.56it/s][A
  9%|▊         | 34/392 [00:02<00:23, 15.12it/s][A
  9%|▉         | 36/392 [00:02<00:26, 13.30it/s][A
 10%|▉         | 38/392 [00:02<00:24, 14.27it/s][A
 10%|█         | 40/392 [00:02<00:24, 14.51it/s][A
 11%|█        

train error 2.752590011577217



 16%|█▌        | 10/63 [00:00<00:01, 42.71it/s][A
 24%|██▍       | 15/63 [00:00<00:01, 43.81it/s][A
 32%|███▏      | 20/63 [00:00<00:00, 44.14it/s][A
 40%|███▉      | 25/63 [00:00<00:00, 43.91it/s][A
 49%|████▉     | 31/63 [00:00<00:00, 47.08it/s][A
 57%|█████▋    | 36/63 [00:00<00:00, 44.02it/s][A
 67%|██████▋   | 42/63 [00:00<00:00, 45.86it/s][A
 76%|███████▌  | 48/63 [00:01<00:00, 47.01it/s][A
 86%|████████▌ | 54/63 [00:01<00:00, 50.25it/s][A
100%|██████████| 63/63 [00:01<00:00, 47.16it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  0%|          | 1/392 [00:00<00:39,  9.98it/s][A

valid error 4.077122907789927



  1%|          | 3/392 [00:00<00:33, 11.46it/s][A
  1%|▏         | 5/392 [00:00<00:31, 12.10it/s][A
  2%|▏         | 7/392 [00:00<00:30, 12.47it/s][A
  2%|▏         | 9/392 [00:00<00:28, 13.65it/s][A
  3%|▎         | 11/392 [00:00<00:27, 13.96it/s][A
  3%|▎         | 13/392 [00:00<00:27, 13.88it/s][A
  4%|▍         | 15/392 [00:01<00:27, 13.56it/s][A
  4%|▍         | 17/392 [00:01<00:29, 12.62it/s][A
  5%|▍         | 19/392 [00:01<00:29, 12.52it/s][A
  5%|▌         | 21/392 [00:01<00:27, 13.56it/s][A
  6%|▌         | 23/392 [00:01<00:24, 15.01it/s][A
  6%|▋         | 25/392 [00:01<00:23, 15.65it/s][A
  7%|▋         | 27/392 [00:01<00:22, 16.44it/s][A
  7%|▋         | 29/392 [00:01<00:23, 15.61it/s][A
  8%|▊         | 31/392 [00:02<00:24, 15.00it/s][A
  8%|▊         | 33/392 [00:02<00:23, 15.57it/s][A
  9%|▉         | 35/392 [00:02<00:23, 15.31it/s][A
  9%|▉         | 37/392 [00:02<00:22, 15.49it/s][A
 10%|▉         | 39/392 [00:02<00:24, 14.22it/s][A
 10%|█         

train error 2.546892732685926



 17%|█▋        | 11/63 [00:00<00:01, 49.08it/s][A
 27%|██▋       | 17/63 [00:00<00:00, 50.02it/s][A
 35%|███▍      | 22/63 [00:00<00:00, 48.23it/s][A
 44%|████▍     | 28/63 [00:00<00:00, 49.00it/s][A
 54%|█████▍    | 34/63 [00:00<00:00, 51.23it/s][A
 62%|██████▏   | 39/63 [00:00<00:00, 47.93it/s][A
 70%|██████▉   | 44/63 [00:00<00:00, 47.79it/s][A
 78%|███████▊  | 49/63 [00:01<00:00, 48.39it/s][A
 86%|████████▌ | 54/63 [00:01<00:00, 44.28it/s][A
100%|██████████| 63/63 [00:01<00:00, 48.07it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  0%|          | 1/392 [00:00<00:49,  7.86it/s][A

valid error 4.008267024206737



  1%|          | 3/392 [00:00<00:44,  8.69it/s][A
  1%|▏         | 5/392 [00:00<00:41,  9.34it/s][A
  2%|▏         | 7/392 [00:00<00:37, 10.38it/s][A
  2%|▏         | 9/392 [00:00<00:32, 11.78it/s][A
  3%|▎         | 11/392 [00:00<00:29, 12.90it/s][A
  3%|▎         | 13/392 [00:00<00:27, 13.56it/s][A
  4%|▍         | 15/392 [00:01<00:28, 13.36it/s][A
  4%|▍         | 17/392 [00:01<00:32, 11.55it/s][A
  5%|▍         | 19/392 [00:01<00:29, 12.75it/s][A
  5%|▌         | 21/392 [00:01<00:28, 13.24it/s][A
  6%|▌         | 23/392 [00:01<00:26, 14.13it/s][A
  7%|▋         | 26/392 [00:01<00:23, 15.59it/s][A
  7%|▋         | 28/392 [00:02<00:22, 15.90it/s][A
  8%|▊         | 30/392 [00:02<00:21, 16.75it/s][A
  8%|▊         | 32/392 [00:02<00:22, 16.19it/s][A
  9%|▊         | 34/392 [00:02<00:21, 16.66it/s][A
  9%|▉         | 36/392 [00:02<00:22, 16.16it/s][A
 10%|▉         | 38/392 [00:02<00:24, 14.34it/s][A
 10%|█         | 40/392 [00:02<00:23, 15.07it/s][A
 11%|█         

train error 2.403647307230502



 17%|█▋        | 11/63 [00:00<00:00, 52.82it/s][A
 24%|██▍       | 15/63 [00:00<00:00, 48.12it/s][A
 32%|███▏      | 20/63 [00:00<00:00, 46.99it/s][A
 41%|████▏     | 26/63 [00:00<00:00, 48.32it/s][A
 51%|█████     | 32/63 [00:00<00:00, 49.62it/s][A
 60%|██████    | 38/63 [00:00<00:00, 50.74it/s][A
 68%|██████▊   | 43/63 [00:00<00:00, 49.81it/s][A
 76%|███████▌  | 48/63 [00:00<00:00, 48.69it/s][A
 84%|████████▍ | 53/63 [00:01<00:00, 46.51it/s][A
 92%|█████████▏| 58/63 [00:01<00:00, 46.67it/s][A
100%|██████████| 63/63 [00:01<00:00, 47.75it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A

valid error 3.9841508789668008



  0%|          | 1/392 [00:00<01:21,  4.78it/s][A
  1%|          | 3/392 [00:00<01:04,  6.04it/s][A
  1%|▏         | 5/392 [00:00<00:52,  7.39it/s][A
  2%|▏         | 7/392 [00:00<00:44,  8.64it/s][A
  2%|▏         | 9/392 [00:00<00:38,  9.95it/s][A
  3%|▎         | 11/392 [00:00<00:32, 11.58it/s][A
  3%|▎         | 13/392 [00:00<00:31, 11.98it/s][A
  4%|▍         | 15/392 [00:01<00:29, 12.88it/s][A
  4%|▍         | 17/392 [00:01<00:28, 13.32it/s][A
  5%|▍         | 19/392 [00:01<00:26, 14.03it/s][A
  5%|▌         | 21/392 [00:01<00:27, 13.39it/s][A
  6%|▌         | 23/392 [00:01<00:26, 13.97it/s][A
  6%|▋         | 25/392 [00:01<00:25, 14.54it/s][A
  7%|▋         | 27/392 [00:01<00:26, 13.64it/s][A
  7%|▋         | 29/392 [00:02<00:25, 14.12it/s][A
  8%|▊         | 31/392 [00:02<00:24, 14.70it/s][A
  8%|▊         | 33/392 [00:02<00:24, 14.71it/s][A
  9%|▉         | 35/392 [00:02<00:22, 15.78it/s][A
  9%|▉         | 37/392 [00:02<00:23, 15.01it/s][A
 10%|▉         |

train error 2.3009074956786875



 17%|█▋        | 11/63 [00:00<00:01, 47.19it/s][A
 25%|██▌       | 16/63 [00:00<00:01, 46.91it/s][A
 33%|███▎      | 21/63 [00:00<00:00, 46.64it/s][A
 41%|████▏     | 26/63 [00:00<00:00, 46.98it/s][A
 51%|█████     | 32/63 [00:00<00:00, 48.68it/s][A
 59%|█████▊    | 37/63 [00:00<00:00, 47.94it/s][A
 67%|██████▋   | 42/63 [00:00<00:00, 47.90it/s][A
 75%|███████▍  | 47/63 [00:01<00:00, 43.51it/s][A
 83%|████████▎ | 52/63 [00:01<00:00, 43.09it/s][A
 90%|█████████ | 57/63 [00:01<00:00, 41.76it/s][A
100%|██████████| 63/63 [00:01<00:00, 45.11it/s]

  0%|          | 0/392 [00:00<?, ?it/s][A
  1%|          | 2/392 [00:00<00:31, 12.50it/s][A

valid error 3.9766852855682373



  1%|          | 4/392 [00:00<00:28, 13.76it/s][A
  2%|▏         | 6/392 [00:00<00:27, 14.13it/s][A
  2%|▏         | 8/392 [00:00<00:27, 14.17it/s][A
  3%|▎         | 10/392 [00:00<00:26, 14.40it/s][A
  3%|▎         | 12/392 [00:00<00:24, 15.25it/s][A
  4%|▎         | 14/392 [00:00<00:24, 15.68it/s][A
  4%|▍         | 16/392 [00:01<00:23, 15.75it/s][A
  5%|▍         | 18/392 [00:01<00:24, 15.27it/s][A
  5%|▌         | 20/392 [00:01<00:24, 15.32it/s][A
  6%|▌         | 22/392 [00:01<00:24, 14.97it/s][A
  6%|▌         | 24/392 [00:01<00:23, 15.43it/s][A
  7%|▋         | 26/392 [00:01<00:23, 15.83it/s][A
  7%|▋         | 28/392 [00:01<00:22, 16.53it/s][A
  8%|▊         | 30/392 [00:01<00:21, 16.96it/s][A
  8%|▊         | 32/392 [00:02<00:20, 17.39it/s][A
  9%|▊         | 34/392 [00:02<00:23, 15.55it/s][A
  9%|▉         | 36/392 [00:02<00:22, 16.17it/s][A
 10%|▉         | 38/392 [00:02<00:21, 16.52it/s][A
 10%|█         | 40/392 [00:02<00:23, 14.67it/s][A
 11%|█        

train error 2.22788618900338



 16%|█▌        | 10/63 [00:00<00:01, 44.27it/s][A
 24%|██▍       | 15/63 [00:00<00:01, 45.33it/s][A
 32%|███▏      | 20/63 [00:00<00:00, 46.62it/s][A
 41%|████▏     | 26/63 [00:00<00:00, 48.97it/s][A
 49%|████▉     | 31/63 [00:00<00:00, 48.48it/s][A
 57%|█████▋    | 36/63 [00:00<00:00, 47.56it/s][A
 65%|██████▌   | 41/63 [00:00<00:00, 46.18it/s][A
 73%|███████▎  | 46/63 [00:00<00:00, 46.45it/s][A
 83%|████████▎ | 52/63 [00:01<00:00, 47.39it/s][A
 90%|█████████ | 57/63 [00:01<00:00, 47.46it/s][A
100%|██████████| 63/63 [00:01<00:00, 47.58it/s]

valid error 3.963253762986925





In [14]:
def calculate_accuracy(y_true, y_pred):
  assert y_true.shape == y_pred.shape
  assert len(y_true.shape) == 1
  y_true = y_true[1:]
  y_pred = y_pred[1:]
  return (y_true == y_pred).sum() / y_true.shape[0]

y_true_test = np.array([-1, 1, 2, 3])
y_pred_test = np.array([-1, 1, 0, 3])

calculate_accuracy(y_true_test, y_pred_test)

0.6666666666666666

In [20]:
total_true_labels = []
total_pred_labels = []

for index in range(len(test_data.examples)):

  text = test_data.examples[index].text
  true_labels = test_data.examples[index].tags

  with torch.no_grad():
    tokens = text
    ids = [TEXT.vocab.stoi[token] for token in tokens]
    ids_tensor = torch.tensor([ids], device=device)
    lens = torch.tensor([len(ids)])
    prediction = model(ids_tensor, lens)
    
  print('\t'.join(tokens))
  print('\t'.join(true_labels))
  print('\t'.join([LABEL.vocab.itos[p] for p in prediction[0]]))
  print('='*20)

  total_true_labels.extend(np.array([LABEL.vocab.itos[p] for p in prediction[0]]))
  total_pred_labels.extend(np.array(true_labels))

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
pairing	up	of	harry	and	ginny	is	cool
noun	noun	adp	propn	cconj	propn	aux	adj
verb	adp	adp	propn	cconj	noun	aux	adj
it	s	awesome	-	better	than	book	5
pron	aux	adj	punct	adj	adp	noun	num
pron	aux	adj	punct	adj	sconj	verb	num
an	orphaned	,	two	-	month	old	african	elephant	named	olly	received	an	extremely	uplifting	christmas	present	this	year	:	an	airplane	ride	just	for	him	,	courtesy	of	the	international	fund	for	animal	welfare	and	their	friends	'	the	bateleurs	'	.
det	verb	punct	num	punct	noun	adj	adj	noun	verb	propn	verb	det	adv	adj	propn	noun	det	noun	punct	det	noun	noun	adv	adp	pron	punct	noun	adp	det	propn	propn	adp	propn	propn	cconj	pron	noun	punct	det	propn	punct	punct
det	noun	punct	num	punct	noun	adj	adj	noun	verb	part	verb	det	adv	propn	propn	verb	det	noun	punct	det	noun	verb	adv	adp	pron	punct	noun	adp	det	adj	noun	adp	noun	noun	cconj	pron	noun	punct	det	noun	punct	punct
animal	news	center	webmas

In [22]:
calculate_accuracy(np.array(total_true_labels), np.array(total_pred_labels))

0.902414727446605

In [21]:
from sklearn.metrics import classification_report


print(classification_report(np.array(total_true_labels), np.array(total_pred_labels)))

              precision    recall  f1-score   support

         adj       0.86      0.87      0.86      1688
         adp       0.96      0.92      0.94      2119
         adv       0.85      0.89      0.87      1164
         aux       0.97      0.91      0.94      1599
       cconj       0.99      0.99      0.99       734
         det       0.99      0.98      0.98      1908
        intj       0.78      0.95      0.85        98
        noun       0.89      0.84      0.87      4380
         num       0.62      0.94      0.75       354
        part       0.96      0.90      0.93       670
        pron       0.96      0.98      0.97      2129
       propn       0.73      0.76      0.74      1981
       punct       0.99      1.00      0.99      3100
       sconj       0.74      0.80      0.77       357
         sym       0.78      0.99      0.87        73
        verb       0.89      0.89      0.89      2651
           x       0.19      0.29      0.23        92

    accuracy              