In [30]:
import checklist 
import spacy
from transformers import pipeline
from transformers import AutoModelForTokenClassification, AutoTokenizer
from typing import List


In [31]:
nlp = spacy.load('en_core_web_sm')



In [32]:
tokenizer = AutoTokenizer.from_pretrained("./results/model/checkpoint-240/")
model = AutoModelForTokenClassification.from_pretrained("./results/model/checkpoint-240/")

In [33]:
token_classifier = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy= "average")

In [69]:
def predict_model(inputs):
    """
    Wrapper function for model prediction
    Args: 
    """
    predictions = token_classifier(inputs)
    # predictions = [p['word'] for p in predictions]
    return predictions

In [70]:
def read_lines(filename: str) -> List[str]:
    """
    Read templates from a .txt file containing templates.
    Args: 
        Path to file (str)
    Returns: 
        List of templates (str)
    """
    with open(filename, 'r') as f:
        lines = f.read().splitlines()
    return lines

In [71]:
# Wrapper just returns dummy confidence in addition to predictions
from checklist.pred_wrapper import PredictorWrapper
predict_and_conf = PredictorWrapper.wrap_predict(predict_model)

In [72]:
import checklist
from checklist.editor import Editor
from checklist.perturb import Perturb
from checklist.test_types import MFT, INV, DIR
from checklist.expect import Expect
editor = Editor()

# Person Tests

In [73]:
def found_people(x, pred, conf, label=None, meta=None):
    people = set([meta['first_name'].lower(), meta['last_name'].lower()])
    pass_ = True
    for x in pred:
        words = set(x['word'].split())
        if len(words & people) == len(words | people) and x['entity_group'] != 'PERSON':
            pass_ = False
        if len(words & people) != len(words | people) and x['entity_group'] == 'PERSON':
            pass_ = False
    return pass_
expect_fn = Expect.single(found_people)

In [74]:
def format_ner(x, pred, conf, label=None, meta=None):
    print(pred)
    return ' '.join(['%s(%s)' % (x, x['entity_group']) for x in pred])

In [75]:
person_templates = read_lines('./data/person_test_templates.txt')

In [82]:
t = editor.template('{first_name} {last_name}',  meta=True, nsamples=300)
test = MFT(**t, expect=expect_fn)
test.run(predict_and_conf)
test.visual_summary()

Predicting 300 examples


TestSummarizer(stats={'npassed': 288, 'nfailed': 12, 'nfiltered': 0}, summarizer={'name': None, 'description':…

# GPE Tests

In [87]:
# This assumes that pred is a spacy Doc, and that 'meta' contains 'first_name' and 'last_name'.
def found_city(x, pred, conf, label=None, meta=None):
    city = set([meta['city'].lower()])
    pass_ = True
    for p in pred:
        words = set(p['word'].split())
        if len(words & city) == len(words | city) and p['entity_group'] != 'GPE':
            # print(words, city)
            pass_ = False
        if len(words & city) != len(words | city) and p['entity_group'] == 'GPE':
            # print(words, city)
            pass_ = False
    return pass_
expect_fn = Expect.single(found_city)

In [88]:
city_templates = read_lines('./data/city_templates_testing.txt')

In [89]:
t = editor.template('{city}',  meta=True, nsamples=300)
test = MFT(**t, expect=expect_fn)
test.run(predict_and_conf)
test.summary(format_example_fn=format_ner)

Predicting 300 examples


In [86]:
test.visual_summary()

TestSummarizer(stats={'npassed': 204, 'nfailed': 93, 'nfiltered': 0}, summarizer={'name': None, 'description':…