Commit 8197178: Add LIT integration example

mhagiwara committed Mar 16, 2021
1 parent a6a63c4 commit 8197178

Showing 3 changed files with 91 additions and 4 deletions.
87 changes: 87 additions & 0 deletions examples/interpret/run_lit.py
@@ -0,0 +1,87 @@
import numpy as np

from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types

from examples.sentiment.sst_classifier import LstmClassifier
from examples.sentiment.sst_reader import StanfordSentimentTreeBankDatasetReaderWithTokenizer


class SSTData(lit_dataset.Dataset):
    """Hand-written examples in the five-class Stanford Sentiment Treebank
    (SST-5) format, with labels from '0' (very negative) to '4' (very positive).
    """

def __init__(self, labels):
self._labels = labels
self._examples = [
{'sentence': 'This is the best movie ever!!!', 'label': '4'},
{'sentence': 'A good movie.', 'label': '3'},
{'sentence': 'A mediocre movie.', 'label': '1'},
{'sentence': 'It was such an awful movie...', 'label': '0'}
]

def spec(self):
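        # Tell LIT the name and type of each field in the examples above.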
return {
'sentence': lit_types.TextSegment(),
'label': lit_types.CategoryLabel(vocab=self._labels)
}


class SentimentClassifierModel(lit_model.Model):
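    """Wraps the trained AllenNLP classifier in LIT's Model interface."""
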
def __init__(self):
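        # Load the trained archive; cuda_device=-1 would run on CPU instead.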
cuda_device = 0
archive_file = 'model/model.tar.gz'
predictor_name = 'sentence_classifier_predictor'

archive = load_archive(
archive_file=archive_file,
cuda_device=cuda_device
)

predictor = Predictor.from_archive(archive, predictor_name=predictor_name)

self.predictor = predictor
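        # Recover the label strings in index order so their positions line up
        # with the model's probability vector.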
label_map = archive.model.vocab.get_index_to_token_vocabulary('labels')
self.labels = [label for _, label in sorted(label_map.items())]

def predict_minibatch(self, inputs):
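        # LIT passes a batch of examples; run the AllenNLP predictor one
        # example at a time and yield one output dict per example.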
for inst in inputs:
pred = self.predictor.predict(inst['sentence'])
            # Convert AllenNLP Token objects to plain strings for LIT's Tokens field.
            tokens = [str(t) for t in self.predictor._tokenizer.tokenize(inst['sentence'])]
yield {
'tokens': tokens,
'probas': np.array(pred['probs']),
'cls_emb': np.array(pred['cls_emb'])
            }

    def input_spec(self):
return {
"sentence": lit_types.TextSegment(),
"label": lit_types.CategoryLabel(vocab=self.labels, required=False)
}

def output_spec(self):
return {
"tokens": lit_types.Tokens(),
"probas": lit_types.MulticlassPreds(parent="label", vocab=self.labels),
"cls_emb": lit_types.Embeddings()
}


def main():
model = SentimentClassifierModel()
models = {"sst": model}
datasets = {"sst": SSTData(labels=model.labels)}

lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
lit_demo.serve()

if __name__ == '__main__':
main()
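
A quick sanity check of the wrapper before launching the UI is to call it directly; a minimal sketch, assuming the trained archive from the sentiment example exists at model/model.tar.gz:

    model = SentimentClassifierModel()
    outputs = list(model.predict_minibatch([{'sentence': 'A good movie.'}]))
    print(outputs[0]['probas'])         # one probability per label in model.labels
    print(outputs[0]['cls_emb'].shape)  # encoder output reused as the embedding

Running python examples/interpret/run_lit.py then starts the LIT dev server on its default port, unless overridden via the flags exposed by lit_nlp.server_flags.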
3 changes: 2 additions & 1 deletion examples/sentiment/sst_classifier.py
@@ -65,9 +65,10 @@ def forward(self,
         encoder_out = self.encoder(embeddings, mask)
         logits = self.linear(encoder_out)

+        probs = torch.softmax(logits, dim=-1)
         # In AllenNLP, the output of forward() is a dictionary.
         # Your output dictionary must contain a "loss" key for your model to be trained.
-        output = {"logits": logits}
+        output = {"logits": logits, "cls_emb": encoder_out, "probs": probs}
         if label is not None:
             self.accuracy(logits, label)
             self.f1_measure(logits, label)
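
The new "probs" entry is just the softmax of the logits, so each row forms a probability distribution that LIT reads through the wrapper's "probas" field; a minimal illustration with made-up logits:

    import torch

    logits = torch.tensor([[2.0, 0.5, -1.0, 0.1, -0.3]])
    probs = torch.softmax(logits, dim=-1)
    print(probs.sum(dim=-1))  # tensor([1.]) -- each row sums to one

The "cls_emb" entry simply exposes the encoder output so LIT's embedding projector has something to plot.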
5 changes: 2 additions & 3 deletions realworldnlp/predictors.py
@@ -1,6 +1,5 @@
 from allennlp.common import JsonDict
 from allennlp.data import DatasetReader, Instance
-from allennlp.data.tokenizers import SpacyTokenizer
 from allennlp.models import Model
 from allennlp.predictors import Predictor
 from overrides import overrides
@@ -13,7 +12,7 @@
 class SentenceClassifierPredictor(Predictor):
     def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
         super().__init__(model, dataset_reader)
-        self._tokenizer = SpacyTokenizer(language='en_core_web_sm', pos_tags=True)
+        self._tokenizer = dataset_reader._tokenizer

     def predict(self, sentence: str) -> JsonDict:
         return self.predict_json({"sentence" : sentence})
@@ -22,7 +21,7 @@ def predict(self, sentence: str) -> JsonDict:
     def _json_to_instance(self, json_dict: JsonDict) -> Instance:
         sentence = json_dict["sentence"]
         tokens = self._tokenizer.tokenize(sentence)
-        return self._dataset_reader.text_to_instance([str(t) for t in tokens])
+        return self._dataset_reader.text_to_instance(tokens)


 @Predictor.register("universal_pos_predictor")
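
Taking the tokenizer from the dataset reader instead of hard-coding SpacyTokenizer keeps prediction-time tokenization identical to training. A sketch of the effect, assuming the reader from the sentiment example can be constructed with defaults and model is a trained LstmClassifier:

    reader = StanfordSentimentTreeBankDatasetReaderWithTokenizer()
    predictor = SentenceClassifierPredictor(model, dataset_reader=reader)
    # predictor._tokenizer is reader._tokenizer, so predict('A good movie.')
    # splits the sentence exactly as the training data was split.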
