Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
49 lines (37 sloc) 1.71 KB
from typing import Dict, Optional

import torch
import torch.nn as nn
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from allennlp.modules.text_field_embedders import TextFieldEmbedder
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import SpanBasedF1Measure
class NerLstm(Model):
    """Token-level tagger: embed -> Seq2SeqEncoder -> linear projection to labels.

    The projection's output size is the size of the ``labels`` vocabulary
    namespace. A ``SpanBasedF1Measure`` over that same namespace (IOB1
    encoding) is updated whenever gold labels are passed to ``forward``.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder) -> None:
        """
        Parameters
        ----------
        vocab : Vocabulary
            Must contain a ``labels`` namespace; its size sets the number of
            output classes and it is used by the F1 metric.
        embedder : TextFieldEmbedder
            Embeds the ``tokens`` field passed to ``forward``.
        encoder : Seq2SeqEncoder
            Contextualises the embedded tokens; its ``get_output_dim()`` is the
            classifier's input size.
        """
        # Model is a torch.nn.Module: it must be initialised before any
        # submodule is assigned as an attribute, or torch raises AttributeError.
        super().__init__(vocab)
        self._embedder = embedder
        self._encoder = encoder
        # One logit per label in the 'labels' namespace (same namespace as the metric).
        self._classifier = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                           out_features=vocab.get_vocab_size('labels'))
        self._f1 = SpanBasedF1Measure(vocab, 'labels', 'IOB1')

    def get_metrics(self, reset: bool = True) -> Dict[str, float]:
        """Return the metrics accumulated by the span-based F1 measure.

        Parameters
        ----------
        reset : bool
            Passed through to ``SpanBasedF1Measure.get_metric``; when true the
            accumulated counts are cleared after reading.
        """
        return self._f1.get_metric(reset)

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Compute per-token label logits and, if gold labels are given, the loss.

        Parameters
        ----------
        tokens : Dict[str, torch.Tensor]
            Indexed ``tokens`` text field; also used to derive the padding mask.
        label : torch.Tensor, optional
            Gold label ids — presumably (batch, seq_len) aligned with the tokens;
            TODO confirm against the dataset reader.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``'logits'`` always; ``'loss'`` only when ``label`` is provided.
        """
        mask = get_text_field_mask(tokens)
        embedded = self._embedder(tokens)
        encoded = self._encoder(embedded, mask)
        classified = self._classifier(encoded)

        output: Dict[str, torch.Tensor] = {'logits': classified}
        if label is not None:
            # Update the running span F1, then compute the masked cross-entropy
            # so padded positions do not contribute to the loss.
            self._f1(classified, label, mask)
            output['loss'] = sequence_cross_entropy_with_logits(classified, label, mask)
        return output
You can’t perform that action at this time.