In [1]:
import ssl
import os

# SECURITY NOTE(review): this swaps the ssl module's default HTTPS context
# factory for the unverified one, for the whole kernel process. Any code that
# relies on that default (e.g. urllib / http.client) will stop validating
# server certificates from here on. Presumably a workaround for a local
# certificate error when downloading the model archives below -- confirm, and
# prefer fixing the CA bundle over disabling verification.
ssl._create_default_https_context = ssl._create_unverified_context

from allennlp.predictors.predictor import Predictor
# The imported name is never referenced in this notebook; presumably imported
# for its side effect of registering the tagging models/predictors that the
# archives below require -- TODO confirm.
import allennlp_models.tagging

# Predictor built from the "structured-prediction-srl-bert" archive (2020.12.15).
predictor_bert = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz")

# Predictor built from the "openie-model" archive (2020.03.26); the rest of the
# notebook labels this one "BiLSTM".
predictor_bilstm = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/openie-model.2020.03.26.tar.gz")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### INV: Sentences with em dash phrases.

In [2]:
import json

# The fixture holds sentence pairs containing non-ASCII em dashes, so decode it
# explicitly as UTF-8 instead of relying on the platform's default encoding
# (which is not UTF-8 on every OS and would mis-decode or reject the file).
with open('data/em_dash_inv.json', 'r', encoding='utf-8') as infile:
    sentences = json.load(infile)

# Show the first pair as a quick check of the expected
# {'original': ..., 'em_dash': ...} record layout.
print(sentences[0])

{'original': 'She sang the song in front of a large audience.', 'em_dash': 'She sang the song—beautifully, powerfully, and passionately—in front of a large audience.'}


In [3]:
import json

def _em_dash_span(words, sentence):
    """Return the token indices of the first two em-dash tokens in ``words``.

    Raises:
        ValueError: if ``words`` does not contain two standalone '—' tokens.
            ``sentence`` is only used to make the error message actionable.
    """
    try:
        first = words.index('—')
        second = words.index('—', first + 1)
    except ValueError:
        raise ValueError(
            f"Expected two standalone em-dash tokens in the tokenized sentence: {sentence!r}"
        ) from None
    return first, second


def _tags_by_verb(prediction):
    """Map each predicate's verb string to its tag sequence.

    If the same verb string occurs more than once, the last occurrence wins.
    """
    return {frame['verb']: frame['tags'] for frame in prediction['verbs']}


def arg_class_em_dash_inv(predictor, sentences, model_name):
    """Run the em-dash invariance test for argument classification.

    For every sentence pair, the predictor is run on the plain sentence and on
    the variant with an em-dash-delimited phrase inserted. The tags of the
    inserted span (both dashes and everything between them) are dropped from the
    dashed prediction, and the remaining tags are compared, verb by verb, with
    the tags predicted for the plain sentence. A pair fails when a verb of the
    plain sentence is missing from the dashed prediction or any compared tag
    differs.

    Side effect: writes the inputs, raw predictions and per-pair failure flag to
    ``model_outputs/arg_class_em_dash_inv_{model_name}.json`` (the directory is
    created if needed).

    Args:
        predictor: object whose ``predict(sentence)`` returns a dict with a
            ``'words'`` list and a ``'verbs'`` list of dicts carrying ``'verb'``
            and ``'tags'``.
        sentences: iterable of dicts with ``'original'`` and ``'em_dash'`` keys.
        model_name: label used in the output file name.

    Returns:
        The fraction (0.0-1.0) of sentence pairs that failed; ``0.0`` when
        ``sentences`` is empty.

    Raises:
        ValueError: if a dashed sentence is not tokenized with two standalone
            '—' tokens.
    """
    failures = []
    model_outputs = []

    for sentence in sentences:
        original = sentence['original']
        dashed = sentence['em_dash']

        original_pred = predictor.predict(original)
        dashed_pred = predictor.predict(dashed)

        # Locate the inserted span so its tags can be cut out below.
        em_index1, em_index2 = _em_dash_span(dashed_pred['words'], dashed)

        original_arguments = _tags_by_verb(original_pred)
        # Drop the tags of the em-dash span (inclusive of both dashes) so the
        # remaining positions line up with the plain sentence's tokens.
        dashed_arguments = {
            verb: tags[:em_index1] + tags[em_index2 + 1:]
            for verb, tags in _tags_by_verb(dashed_pred).items()
        }

        # Only verbs of the plain sentence are checked; extra verbs found inside
        # the inserted phrase are ignored. zip() stops at the shorter sequence,
        # so trailing tags are not compared if the two lengths differ.
        failure = any(
            verb not in dashed_arguments
            or any(
                expected != predicted
                for expected, predicted in zip(expected_tags, dashed_arguments[verb])
            )
            for verb, expected_tags in original_arguments.items()
        )
        failures.append(int(failure))

        model_outputs.append({
            'input': {'original': original, 'em_dash': dashed},
            'output': {'original_pred': original_pred, 'dashed_pred': dashed_pred},
            'failure': failure
        })

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("model_outputs", exist_ok=True)
    with open(f'model_outputs/arg_class_em_dash_inv_{model_name}.json', 'w', encoding='utf-8') as f:
        json.dump(model_outputs, f, indent=4)

    # Guard against an empty test set instead of dividing by zero.
    if not failures:
        return 0.0
    return sum(failures) / len(failures)


# Evaluate each model in turn and print its result scaled by 100.
bert_failure_rate = arg_class_em_dash_inv(predictor_bert, sentences, 'BERT')
print("(BERT) Argument Classification INV to em dash phrases:", 100 * bert_failure_rate)

bilstm_failure_rate = arg_class_em_dash_inv(predictor_bilstm, sentences, 'BiLSTM')
print("(BiLSTM) Argument Classification INV to em dash phrases:", 100 * bilstm_failure_rate)

(BERT) Argument Classification INV to em dash phrases: 24.0
(BiLSTM) Argument Classification INV to em dash phrases: 90.0
