In [1]:
from itertools import chain
from typing import Dict
 
import numpy as np
import torch
import torch.optim as optim
 
from allennlp.data.data_loaders import MultiProcessDataLoader
from allennlp.data.samplers import BucketBatchSampler
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.training import GradientDescentTrainer
from allennlp_models.structured_prediction.dataset_readers.universal_dependencies import UniversalDependenciesDatasetReader


In [2]:
from allennlp.common import JsonDict
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
from allennlp.predictors import Predictor
from overrides import overrides
from typing import List

@Predictor.register("universal_pos_predictor")
class UniversalPOSPredictor(Predictor):
    """Predictor that runs a POS tagging model on an already-tokenized sentence.

    The input JSON has a single key, ``"words"``, holding the list of tokens.
    """

    def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)

    def predict(self, words: List[str]) -> JsonDict:
        """Tag one tokenized sentence.

        Args:
            words: The tokens of the sentence, in order.

        Returns:
            The model output for the sentence, as produced by ``predict_json``.
        """
        return self.predict_json({"words": words})

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """Build an unlabeled ``Instance`` from ``{"words": [...]}``."""
        words = json_dict["words"]
        # text_to_instance requires a list of POS tags with the same length as
        # `words`. There are no gold tags at prediction time, so `words` is
        # passed purely as a same-length placeholder.
        instance = self._dataset_reader.text_to_instance(words, words)
        # Remove the placeholder labels again. Leaving them in would make the
        # model compute a loss and update its accuracy metric against bogus
        # tags, and would require every word to be indexable in the label
        # vocabulary (which only works while that namespace has an UNK entry).
        instance.fields.pop("pos_tags", None)
        return instance

In [3]:
class LstmTagger(Model):
    """Sequence tagger: embed the tokens, encode the sequence, score each tag.

    The scoring layer is a single linear projection from the encoder output
    to one logit per entry of the ``'pos'`` vocabulary namespace.
    """

    def __init__(self,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        """
        Args:
            embedder: Turns the indexed ``words`` text field into embeddings.
            encoder: Sequence-to-sequence encoder applied to the embeddings.
            vocab: Vocabulary; the size of its ``'pos'`` namespace fixes the
                number of output tags.
        """
        super().__init__(vocab)
        self.embedder = embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()
        num_tags = vocab.get_vocab_size('pos')
        self.linear = torch.nn.Linear(in_features=encoder_dim,
                                      out_features=num_tags)

        self.accuracy = CategoricalAccuracy()

    def forward(self,
                words: Dict[str, torch.Tensor],
                pos_tags: torch.Tensor = None,
                **args) -> Dict[str, torch.Tensor]:
        """Score every token and, when gold tags are given, compute the loss.

        Args:
            words: Indexed text field for the batch.
            pos_tags: Gold tag indices; omit for pure prediction.
            **args: Any further fields of the instance; ignored.

        Returns:
            A dict with ``"tag_logits"`` and, only when ``pos_tags`` is
            provided, ``"loss"``.
        """
        mask = get_text_field_mask(words)
        tag_logits = self.linear(self.encoder(self.embedder(words), mask))

        # Without gold tags there is nothing to score against.
        if pos_tags is None:
            return {"tag_logits": tag_logits}

        # Update the running accuracy, then compute the masked loss.
        self.accuracy(tag_logits, pos_tags, mask)
        loss = sequence_cross_entropy_with_logits(tag_logits, pos_tags, mask)
        return {"tag_logits": tag_logits, "loss": loss}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy, optionally resetting the counter."""
        accuracy = self.accuracy.get_metric(reset)
        return {"accuracy": accuracy}


In [4]:
# Reader for Universal Dependencies CoNLL-U files, with default settings.
reader = UniversalDependenciesDatasetReader()
# Batches of 32 instances, using the "words" field as the sorting key.
sampler = BucketBatchSampler(batch_size=32, sorting_keys=["words"])

In [5]:
# Locations of the train and dev splits, relative to the notebook.
train_path = 'data/en_ewt-ud-train.conllu'
dev_path = 'data/en_ewt-ud-dev.conllu'

# Both loaders use the same reader and the same batch sampler.
train_data_loader = MultiProcessDataLoader(reader=reader,
                                           data_path=train_path,
                                           batch_sampler=sampler)
dev_data_loader = MultiProcessDataLoader(reader=reader,
                                         data_path=dev_path,
                                         batch_sampler=sampler)

# The vocabulary is built from the train instances followed by the dev ones.
vocab = Vocabulary.from_instances(
    chain.from_iterable(loader.iter_instances()
                        for loader in (train_data_loader, dev_data_loader)))

# Attach the vocabulary so the loaders can index their instances.
train_data_loader.index_with(vocab)
dev_data_loader.index_with(vocab)

loading instances: 0it [00:00, ?it/s]

Your label namespace was 'pos'. We recommend you use a namespace ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by default to your vocabulary.  See documentation for `non_padded_namespaces` parameter in Vocabulary.


loading instances: 0it [00:00, ?it/s]

building vocab: 0it [00:00, ?it/s]

In [6]:
# Model dimensions: token embedding width and LSTM hidden width.
EMBEDDING_SIZE = 128
HIDDEN_SIZE = 128

# One embedding row per entry of the 'tokens' vocabulary namespace.
token_embedding = Embedding(embedding_dim=EMBEDDING_SIZE,
                            num_embeddings=vocab.get_vocab_size('tokens'))
# Route the "tokens" indexer output through that embedding.
word_embeddings = BasicTextFieldEmbedder(dict(tokens=token_embedding))

In [7]:
# LSTM over the embedded tokens, wrapped in AllenNLP's Seq2SeqEncoder adapter.
encoder = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(input_size=EMBEDDING_SIZE,
                  hidden_size=HIDDEN_SIZE,
                  batch_first=True))

In [10]:
model = LstmTagger(word_embeddings, encoder, vocab)
# Use the GPU when one is present, otherwise stay on the CPU, so this cell
# does not crash on a machine without CUDA. The call is the last expression
# of the cell, so the model summary is still displayed.
model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

LstmTagger(
  (embedder): BasicTextFieldEmbedder(
    (token_embedder_tokens): Embedding()
  )
  (encoder): PytorchSeq2SeqWrapper(
    (_module): LSTM(128, 128, batch_first=True)
  )
  (linear): Linear(in_features=128, out_features=19, bias=True)
)

In [9]:
# Adam with its default hyperparameters over all model parameters.
# The optimizer captures the parameters of whichever `model` object exists
# when this cell runs, so re-run it every time the model is rebuilt.
# NOTE(review): the recorded execution counts (this cell In [9], the model
# cell In [10]) suggest the model was recreated after this cell last ran.
optimizer = optim.Adam(model.parameters())

In [11]:
# Train on the training loader and validate on the dev loader after each epoch.
# NOTE(review): patience equals num_epochs, so early stopping can never
# trigger; lower patience if early stopping is wanted.
trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=train_data_loader,
    validation_data_loader=dev_data_loader,
    patience=10,
    num_epochs=10,
    # GPU 0 when CUDA is available; -1 is AllenNLP's convention for the CPU.
    cuda_device=0 if torch.cuda.is_available() else -1)

In [12]:
# Run the full training loop; the returned metrics dict is displayed as the
# cell output.
trainer.train()



  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/392 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

{'best_epoch': 9,
 'peak_worker_0_memory_MB': 5494.640625,
 'peak_gpu_0_memory_MB': 65.056640625,
 'training_duration': '0:00:36.720819',
 'epoch': 9,
 'training_accuracy': 0.9528704450472909,
 'training_loss': 0.3693570998326248,
 'training_worker_0_memory_MB': 5494.640625,
 'training_gpu_0_memory_MB': 65.056640625,
 'validation_accuracy': 0.882853507237156,
 'validation_loss': 0.6837965382470025,
 'best_validation_accuracy': 0.882853507237156,
 'best_validation_loss': 0.6837965382470025}

In [13]:
# Wrap the trained model and the dataset reader so that plain token lists
# can be tagged.
predictor = UniversalPOSPredictor(model, reader)

In [14]:
tokens = ['The', 'dog', 'ate', 'the', 'apple', '.']
# One row of tag scores per token.
logits = predictor.predict(tokens)['tag_logits']
# Highest-scoring tag index for each token.
tag_ids = np.asarray(logits).argmax(axis=-1)

# Map the indices back to tag names through the 'pos' namespace.
[vocab.get_token_from_index(index, namespace='pos') for index in tag_ids]

['DET', 'NOUN', 'VERB', 'DET', 'NOUN', 'PUNCT']

In [15]:
tokens = ['the', 'interview', 'this', 'afternoon', 'will', 'go', 'well', '.']
# One row of tag scores per token.
logits = predictor.predict(tokens)['tag_logits']
# Highest-scoring tag index for each token.
tag_ids = np.asarray(logits).argmax(axis=-1)

# Map the indices back to tag names through the 'pos' namespace.
[vocab.get_token_from_index(index, namespace='pos') for index in tag_ids]

['DET', 'NOUN', 'DET', 'NOUN', 'AUX', 'VERB', 'ADV', 'PUNCT']