# CheckList tests of two SRL systems

In this notebook, we'll apply CheckList tests to two Semantic Role Labeling (SRL) models:
1. A logistic regression model, trained on three features.
2. A DistilBERT model, fine-tuned on a CoNLL SRL dataset.

## Importing dependencies

In [1]:
# For the BERT model

import time
import pandas as pd
import transformers
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
from datasets import Dataset
from utils import read_data_as_sentence,map_labels_in_dataframe,tokenize_and_align_labels,get_label_mapping,get_labels_from_map,load_srl_model,load_dataset,compute_metrics,write_predictions_to_csv,compute_evaluation_metrics_from_csv, print_sentences
from bert_srl import main, define_args

# For the logistic regression model
import json
import sys
import pickle
from datetime import datetime
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

sys.path.append('feature_extraction')
from extract_position_rel2pred import extract_word_position_and_voice
from extract_dependency_path import extract_dependency_paths
from extract_predicate import extract_predicate_lemma

## 1. Declare standalone functions

### 1.1 Logistic regression

In [2]:
def extract_features(data):
    """
    Turn pre-formatted sentences into feature dictionaries for the logistic regression model.

    Args:
        data (list): Sentences; each sentence is a list of token dicts with at least
            the keys 'form' and 'predicate' (as produced by `format_sentence`).

    Returns:
        list[dict]: One dict per non-predicate token (tokens whose 'predicate' is '_'),
        in sentence order, with the keys 'token', 'position_rel2pred' and
        'dep_path+lemma'.
    """
    samples = []

    for sentence in data:
        # Sentence-level features, computed once and then indexed per token.
        positions_rel2pred, verb_voice = extract_word_position_and_voice(sentence)
        d_paths = extract_dependency_paths(sentence)
        predicate_lemma = extract_predicate_lemma(sentence)

        # The predicate token itself ('predicate' != '_') produces no sample.
        samples.extend(
            {
                'token': token['form'],
                'position_rel2pred': positions_rel2pred[position] + verb_voice,
                'dep_path+lemma': d_paths[position] + predicate_lemma,
            }
            for position, token in enumerate(sentence)
            if token['predicate'] == '_'
        )

    return samples

def format_sentence(sentence, predicate_location, argument_labels, predicate_form):
    """
    Convert parallel word/label lists into the token-dict format used by the feature extractors.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector; the position holding 1 is the predicate.
        argument_labels (list): A list of argument labels, indexed like `sentence`.
        predicate_form (str): Value stored under 'predicate' for the predicate token
            (callers in this notebook pass the predicate sense label); every other token gets '_'.

    Returns:
        list[dict]: One dict per word with the keys 'form', 'predicate' and 'argument'.
    """
    return [
        {
            "form": word,
            "predicate": predicate_form if predicate_location[position] == 1 else "_",
            "argument": argument_labels[position],
        }
        for position, word in enumerate(sentence)
    ]

def classify_sentence_logreg(sentence, predicate_location, argument_labels, predicate_sense, model, vectorizer):
    """
    Predict argument labels for one sentence with the logistic regression SRL model.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector, indicating the location of the predicate.
        argument_labels (list): A list of argument labels.
        predicate_sense (str): The sense label of the predicate.
        model: Fitted classifier exposing `predict`.
        vectorizer: Fitted feature vectorizer exposing `transform`.

    Returns:
        numpy.ndarray: The model's predictions with '_' inserted at the predicate index.
        `extract_features` emits no sample for the predicate token, so with exactly one
        predicate marked this yields one label per word of `sentence`.
    """
    token_dicts = format_sentence(sentence, predicate_location, argument_labels, predicate_sense)
    feature_vectors = vectorizer.transform(extract_features([token_dicts]))
    predicted = model.predict(feature_vectors)

    # Put a placeholder back where the predicate was skipped so positions line up with `sentence`.
    predicate_index = predicate_location.index(1)
    return np.insert(predicted, predicate_index, '_')

# Example input: "ran" (index 2) is the predicate and "dog" carries the ARG0 label.
sentence = ["The", "dog", "ran", "and", "the", "man", "fell", "."]
predicate_location = [0, 0, 1, 0, 0, 0, 0, 0]
argument_labels = ['_', 'ARG0', '_', '_', '_', '_', '_', '_']
predicate_sense = "run.01"


def _load_pickled(path):
    """Unpickle and return the object stored at `path`.

    Security: pickle.load can execute arbitrary code; only use it on model files you produced yourself.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


vectorizer = _load_pickled("learned-models/vectorizer.pkl")
model = _load_pickled("learned-models/model.pkl")

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


### 1.2 DistilBERT

In [3]:
def create_input_sequence(sentence, predicate_position, argument_labels):
    """
    Build the single-row DataFrame used as DistilBERT inference input for one sentence.

    The predicate word is appended after a "[SEP]" marker, i.e. the input becomes
    ``sentence + ["[SEP]", predicate_word]``; the two appended positions get ``None`` labels.

    Parameters:
    - sentence (list of str): The words in the sentence.
    - predicate_position (list of int): One-hot encoding indicating the predicate position.
    - argument_labels (list of str): The argument labels for each token in the sentence.

    Returns:
    - DataFrame with one row and the columns 'input_form' and 'argument' (each cell holds a list).

    Raises:
    - AssertionError if the three input lists differ in length.
    - ValueError (from list.index) if `predicate_position` contains no 1.
    """
    assert len(sentence) == len(predicate_position) == len(argument_labels), "Input lists must have the same length."

    predicate_word = sentence[predicate_position.index(1)]
    record = {
        "input_form": sentence + ['[SEP]', predicate_word],
        "argument": argument_labels + [None, None],
    }
    return pd.DataFrame([record])

def map_labels_to_words(predicted_labels, gold_labels, dataset, subword_tokenizer=None):
    """
    Merge WordPiece sub-word tokens back into words and pair each word with its labels.

    A sub-word starting with "##" is glued onto the previous word; each word keeps the
    gold and predicted label of its first sub-word.

    Args:
        predicted_labels (list of list): Predicted labels, one inner list per example in `dataset`.
        gold_labels (list of list): Gold labels, aligned the same way as `predicted_labels`.
        dataset: Indexable collection whose i-th element has an "input_ids" entry.
        subword_tokenizer: Object providing `convert_ids_to_tokens`. Defaults to the
            notebook-global `tokenizer` (the only tokenizer earlier versions could use).

    Returns:
        pandas.DataFrame with the columns "word", "gold_label" and "predicted_label",
        one row per reconstructed word, all examples concatenated in order.
    """
    # Only fall back to the global when no tokenizer was passed explicitly.
    active_tokenizer = tokenizer if subword_tokenizer is None else subword_tokenizer

    rows = []
    for example_idx, (example_preds, example_gold) in enumerate(zip(predicted_labels, gold_labels)):
        # NOTE(review): skip_special_tokens=True drops [CLS]/[SEP]. If the label lists still hold
        # entries for those positions, the zip below is shifted; confirm how labels are aligned.
        subword_tokens = active_tokenizer.convert_ids_to_tokens(
            dataset[example_idx]["input_ids"], skip_special_tokens=True
        )

        current_word = ""
        current_gold = None
        current_pred = None

        for subword, gold, pred in zip(subword_tokens, example_gold, example_preds):
            if subword.startswith("##"):
                # Continuation piece: extend the word, keep the first piece's labels.
                current_word += subword[2:]
                continue
            if current_word:
                rows.append((current_word, current_gold, current_pred))
            current_word, current_gold, current_pred = subword, gold, pred

        # Flush the last word of this example.
        if current_word:
            rows.append((current_word, current_gold, current_pred))

    return pd.DataFrame(rows, columns=["word", "gold_label", "predicted_label"])

# Label name -> class id used to encode argument labels for DistilBERT inference.
# The None entry covers the two positions create_input_sequence appends ([SEP] + predicate).
BERT_SRL_LABEL_MAP = {'_': 0, 'ARG0': 1, 'ARG1': 2, 'ARG1-DSP': 3, 'ARG2': 4, 'ARG3': 5, 'ARG4': 6, 'ARG5': 7, 'ARGA': 8, 'ARGM-ADJ': 9, 'ARGM-ADV': 10, 'ARGM-CAU': 11, 'ARGM-COM': 12, 'ARGM-CXN': 13, 'ARGM-DIR': 14, 'ARGM-DIS': 15, 'ARGM-EXT': 16, 'ARGM-GOL': 17, 'ARGM-LOC': 18, 'ARGM-LVB': 19, 'ARGM-MNR': 20, 'ARGM-MOD': 21, 'ARGM-NEG': 22, 'ARGM-PRD': 23, 'ARGM-PRP': 24, 'ARGM-PRR': 25, 'ARGM-REC': 26, 'ARGM-TMP': 27, 'C-ARG0': 28, 'C-ARG1': 29, 'C-ARG1-DSP': 30, 'C-ARG2': 31, 'C-ARG3': 32, 'C-ARG4': 33, 'C-ARGM-ADV': 34, 'C-ARGM-COM': 35, 'C-ARGM-CXN': 36, 'C-ARGM-DIR': 37, 'C-ARGM-EXT': 38, 'C-ARGM-GOL': 39, 'C-ARGM-LOC': 40, 'C-ARGM-MNR': 41, 'C-ARGM-PRP': 42, 'C-ARGM-PRR': 43, 'C-ARGM-TMP': 44, 'R-ARG0': 45, 'R-ARG1': 46, 'R-ARG2': 47, 'R-ARG3': 48, 'R-ARG4': 49, 'R-ARGM-ADJ': 50, 'R-ARGM-ADV': 51, 'R-ARGM-CAU': 52, 'R-ARGM-COM': 53, 'R-ARGM-DIR': 54, 'R-ARGM-GOL': 55, 'R-ARGM-LOC': 56, 'R-ARGM-MNR': 57, 'R-ARGM-TMP': 58, None: None}


def classify_sentence_bert(sentence, predicate_location, argument_labels, predicate_sense, trainer, tokenizer, label_map=None):
    """
    The standalone function that takes a sentence and predicts the argument labels, using DistilBERT.

    Args:
        sentence (list): A list of words.
        predicate_location (list): A one-hot vector, indicating the location of the predicate.
        argument_labels (list): A list of argument labels. They are passed through the label
            alignment step, and positions whose aligned label is -100 are dropped from the output.
        predicate_sense (str): Accepted for signature parity with `classify_sentence_logreg`;
            not used by this model.
        trainer: The HuggingFace Trainer instance to predict labels with.
        tokenizer: The HuggingFace Tokenizer to tokenize input sequences with.
        label_map (dict, optional): Label name -> class id mapping. Defaults to a copy of
            BERT_SRL_LABEL_MAP.

    Returns:
        list: Predicted label names for every position whose aligned label is not -100.
    """
    # Copy so helpers can never mutate the shared module-level mapping.
    label_map = dict(BERT_SRL_LABEL_MAP if label_map is None else label_map)

    inference_input = create_input_sequence(sentence, predicate_location, argument_labels)
    inference_data = map_labels_in_dataframe(inference_input, label_map)
    # NOTE(review): with label_all_tokens=True, a word split into several sub-words may yield one
    # prediction per sub-word, so the output can be longer than `sentence`; confirm against
    # tokenize_and_align_labels before zipping the result with word-level gold labels.
    tokenized_input = tokenize_and_align_labels(tokenizer, inference_data, label_all_tokens=True)
    dataset_inference_sample = load_dataset(tokenized_input)
    label_list = get_labels_from_map(label_map)

    predictions, labels, _ = trainer.predict(dataset_inference_sample)
    argmax_predictions = np.argmax(predictions, axis=2)

    # Keep only positions that carry a label (-100 marks ignored/special positions).
    predicted_labels = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(argmax_predictions, labels)
    ]

    # A single sentence was predicted, so return its label sequence.
    return predicted_labels[0]

# Restore the fine-tuned DistilBERT tokenizer and token-classification model from disk.
tokenizer = AutoTokenizer.from_pretrained(
    "learned-models/tokenizer.save_pretrained.distillbert-base-uncased-finetuned-srl"
)
bert_model = AutoModelForTokenClassification.from_pretrained(
    "learned-models/model.save_pretrained.distillbert-base-uncased-finetuned-srl"
)

# In this notebook the Trainer is only used for trainer.predict; TrainingArguments requires output_dir.
# NOTE(review): newer transformers releases deprecate Trainer(tokenizer=...) in favour of
# processing_class=...; the stray output below this cell may be that warning. Confirm before upgrading.
training_args = TrainingArguments(
    output_dir="learned-models/trainer.save_model.distillbert-base-uncased-finetuned-srl"
)
trainer = Trainer(model=bert_model, args=training_args, tokenizer=tokenizer)

  trainer = Trainer(


## 2. Load CheckList dataset

### 2.1 Load the test suite

Load `testing-dataset.json` from the `data` directory. It is parsed into a Python list of capability dictionaries, each holding its tests and their samples.

In [37]:
# Parse the CheckList suite: a list of capabilities, each with its tests and samples.
with open("data/testing-dataset.json", mode="r", encoding="utf-8") as dataset_file:
    checklist_dataset = json.load(dataset_file)

In [38]:
def get_relevant_index(argument_labels, item):
    """Return the position of the single label in `argument_labels` that is not "_".

    Args:
        argument_labels (list): Gold labels for one item.
        item: The item the labels came from; only used in the error message.

    Raises:
        AssertionError: if there is not exactly one non-"_" label.
    """
    relevant_indices = [
        position
        for position, label in enumerate(argument_labels)
        if label != "_"
    ]
    assert len(relevant_indices) == 1, f"Expected exactly one relevant label in item: {item}, found: {relevant_indices}"
    return relevant_indices[0]

def _one_hot(position, length):
    """Return a list of `length` ints that is 1 at `position` and 0 elsewhere.

    Plain-Python equivalent of
    ``torch.nn.functional.one_hot(torch.tensor([position]), length).squeeze(0).tolist()``.

    Raises:
        ValueError: if `position` is not a valid index for `length` tokens.
    """
    if not 0 <= position < length:
        raise ValueError(f"Predicate position {position} is out of range for {length} tokens.")
    return [int(i == position) for i in range(length)]


def _predict_both(item):
    """Run both SRL models on one CheckList item.

    Args:
        item (dict): Test item with 'tokens', 'argument_labels', 'predicate_name' and
            'predicate_position' (integer index of the predicate token).

    Returns:
        tuple: (argument_labels, logreg_inference, bert_inference).
    """
    tokens = item['tokens']
    argument_labels = item['argument_labels']
    predicate_sense = item['predicate_name']
    predicate_position = _one_hot(item['predicate_position'], len(tokens))

    logreg_inference = classify_sentence_logreg(
        tokens, predicate_position, argument_labels, predicate_sense, model, vectorizer
    )
    bert_inference = classify_sentence_bert(
        tokens, predicate_position, argument_labels, predicate_sense, trainer, tokenizer
    )
    return argument_labels, logreg_inference, bert_inference


def _failure_rate(failures, total):
    """Failure percentage, or None when nothing was evaluated."""
    return 100.0 * failures / total if total else None


def _format_rate(rate):
    """Render a failure rate for printing ('n/a' instead of 'None%' when undefined)."""
    return "n/a" if rate is None else f"{rate}%"


# Collect results for every capability, test and sample; written to JSON at the end.
aggregated_results = {"capabilities": []}

for capability in checklist_dataset:
    cap_name = capability['capability_name']
    cap_description = capability['capability_description']
    cap_category = capability['capability_category']
    print(f"Capability: {cap_name} ({cap_category})")

    cap_result = {
        "capability_name": cap_name,
        "capability_description": cap_description,
        "capability_category": cap_category,
        "tests": []
    }

    # Totals for this capability, rolled up from the per-test counters below.
    cap_total_gold = 0
    cap_failures_logreg = 0
    cap_failures_bert = 0

    for test in capability['tests']:
        test_type = test['test_type']
        test_result = {
            "test_name": test['test_name'],
            "test_description": test['test_description'],
            "test_type": test_type,
            "samples": []
        }

        print(f"\n  * Test: {test_result['test_name']} ({test_result['test_type']})")

        test_total_gold = 0
        test_failures_logreg = 0
        test_failures_bert = 0

        for sample in test['samples']:
            sample_results = []  # one entry per evaluated label in this sample

            if test_type == "MFT":
                # MFT: every non-"_" gold label of every item is checked against each model.
                for item in sample:
                    argument_labels, logreg_inference, bert_inference = _predict_both(item)

                    # NOTE(review): zip truncates silently if bert_inference has a different
                    # length than argument_labels (see classify_sentence_bert); confirm alignment.
                    for gold_label, logreg_label, bert_label in zip(argument_labels, logreg_inference, bert_inference):
                        if gold_label == "_":
                            continue

                        logreg_correct = (gold_label == logreg_label)
                        bert_correct = (gold_label == bert_label)

                        test_total_gold += 1
                        if not logreg_correct:
                            test_failures_logreg += 1
                        if not bert_correct:
                            test_failures_bert += 1

                        sample_results.append({
                            "gold_label": gold_label,
                            "logreg_label": logreg_label,
                            "bert_label": bert_label,
                            "logreg_correct": logreg_correct,
                            "bert_correct": bert_correct
                        })

            elif test_type == "INV":
                # INV: a pair of items; a model fails when its predictions for the two
                # relevant (non-"_") positions differ.
                if len(sample) != 2:
                    print("Warning: INV test sample does not contain exactly 2 items.")
                    continue

                item1, item2 = sample
                relevant_index_1 = get_relevant_index(item1['argument_labels'], item1)
                _, logreg_inference_1, bert_inference_1 = _predict_both(item1)
                relevant_index_2 = get_relevant_index(item2['argument_labels'], item2)
                _, logreg_inference_2, bert_inference_2 = _predict_both(item2)

                pred_logreg_1 = logreg_inference_1[relevant_index_1]
                pred_bert_1 = bert_inference_1[relevant_index_1]
                pred_logreg_2 = logreg_inference_2[relevant_index_2]
                pred_bert_2 = bert_inference_2[relevant_index_2]

                logreg_correct = (pred_logreg_1 == pred_logreg_2)
                bert_correct = (pred_bert_1 == pred_bert_2)

                test_total_gold += 1
                if not logreg_correct:
                    test_failures_logreg += 1
                if not bert_correct:
                    test_failures_bert += 1

                sample_results.append({
                    "relevant_index_item1": relevant_index_1,
                    "relevant_index_item2": relevant_index_2,
                    "logreg_prediction_item1": pred_logreg_1,
                    "logreg_prediction_item2": pred_logreg_2,
                    "bert_prediction_item1": pred_bert_1,
                    "bert_prediction_item2": pred_bert_2,
                    "logreg_correct": logreg_correct,
                    "bert_correct": bert_correct
                })

            else:
                print(f"Warning: unknown test type {test_type!r}; sample skipped.")
                continue

            test_result["samples"].append(sample_results)

        # Failure rates for this test, as percentages (None if nothing was evaluated).
        test_result["failure_rate_logreg"] = _failure_rate(test_failures_logreg, test_total_gold)
        test_result["failure_rate_bert"] = _failure_rate(test_failures_bert, test_total_gold)

        print(f"""
      * Failure rates:
            Logistic regression: {_format_rate(test_result["failure_rate_logreg"])}
            DistillBERT: {_format_rate(test_result["failure_rate_bert"])}
""")

        cap_result["tests"].append(test_result)

        cap_total_gold += test_total_gold
        cap_failures_logreg += test_failures_logreg
        cap_failures_bert += test_failures_bert

    # Overall failure rates for the capability.
    cap_result["failure_rate_logreg"] = _failure_rate(cap_failures_logreg, cap_total_gold)
    cap_result["failure_rate_bert"] = _failure_rate(cap_failures_bert, cap_total_gold)

    print(f"""
  > Failure rates (total for capability):
        Logistic regression: {_format_rate(cap_result["failure_rate_logreg"])}
        DistillBERT: {_format_rate(cap_result["failure_rate_bert"])}

=====================================================================================
""")

    aggregated_results["capabilities"].append(cap_result)

# Write the aggregated results and details to a JSON file.
with open("checklist-results.json", "w", encoding="utf-8") as outfile:
    json.dump(aggregated_results, outfile, indent=2)

Capability: Long-distance dependencies between predicate and ARG0 (syntactic)

  * Test: Effect of injecting relative clause between predicate and ARG0 (INV)



      * Failure rates:
            Logistic regression: 20.0%
            DistillBERT: 30.0%


  * Test: Sentences without relative clause between predicate and ARG0 (MFT)



      * Failure rates:
            Logistic regression: 70.0%
            DistillBERT: 20.0%


  * Test: Sentences with relative clause between predicate and ARG0 (MFT)



      * Failure rates:
            Logistic regression: 90.0%
            DistillBERT: 50.0%


  > Failure rates (total for capability):
        Logistic regression: 60.0%
        DistillBERT: 33.333333333333336%


Capability: Long-distance dependencies between predicate and ARG1 (syntactic)

  * Test: Effect of injecting adverbial or participial phrase between predicate and ARG1 (INV)



      * Failure rates:
            Logistic regression: 60.0%
            DistillBERT: 60.0%


  * Test: Sentence without adverbial or participial phrase between predicate and ARG1 (MFT)



      * Failure rates:
            Logistic regression: 10.0%
            DistillBERT: 0.0%


  * Test: Sentence with adverbial or participial phrase between predicate and ARG1 (MFT)



      * Failure rates:
            Logistic regression: 70.0%
            DistillBERT: 60.0%


  > Failure rates (total for capability):
        Logistic regression: 46.666666666666664%
        DistillBERT: 40.0%


Capability: Robustness to noise in the form of typos (lexical)

  * Test: Effect of typos in proper nouns as ARG0 on ARG0 labeling (INV)



      * Failure rates:
            Logistic regression: 30.0%
            DistillBERT: 20.0%


  * Test: Effect of typos in common nouns as ARG0 on ARG0 labeling (INV)



      * Failure rates:
            Logistic regression: 20.0%
            DistillBERT: 0.0%


  * Test: Sentences without typos in ARG0 for proper noun ARG0 labeling (MFT)



      * Failure rates:
            Logistic regression: 60.0%
            DistillBERT: 0.0%


  * Test: Sentences with typos in ARG0 for proper noun ARG0 labeling (MFT)



      * Failure rates:
            Logistic regression: 90.0%
            DistillBERT: 20.0%


  * Test: Sentences without typos in ARG0 for common noun ARG0 labeling (MFT)



      * Failure rates:
            Logistic regression: 60.0%
            DistillBERT: 0.0%


  * Test: Sentences with typos in ARG0 for common noun ARG0 labeling (MFT)



      * Failure rates:
            Logistic regression: 80.0%
            DistillBERT: 0.0%


  > Failure rates (total for capability):
        Logistic regression: 56.666666666666664%
        DistillBERT: 6.666666666666667%


Capability: Effect of semantic atypicality in active voice syntactically simple SVO sentences on ARG0 labeling (lexical)

  * Test: Animate objects as ARG0 (MFT)



      * Failure rates:
            Logistic regression: 90.0%
            DistillBERT: 0.0%


  * Test: Inanimate objects as ARG0 (MFT)



      * Failure rates:
            Logistic regression: 100.0%
            DistillBERT: 40.0%


  * Test: Animate versus inanimate concepts as ARG0 (INV)



      * Failure rates:
            Logistic regression: 10.0%
            DistillBERT: 40.0%


  * Test: Non-abstract concepts as ARG0 (MFT)



      * Failure rates:
            Logistic regression: 100.0%
            DistillBERT: 70.0%


  * Test: Abstract concepts as ARG0 (MFT)



      * Failure rates:
            Logistic regression: 100.0%
            DistillBERT: 90.0%


  * Test: Abstract vs non-abstract concepts as ARG0 (INV)



      * Failure rates:
            Logistic regression: 0.0%
            DistillBERT: 30.0%


  > Failure rates (total for capability):
        Logistic regression: 66.66666666666667%
        DistillBERT: 45.0%


Capability: Dealing with dative verb alternations (syntactic)

  * Test: Dative verb alternations effect on ARG1 (INV)



      * Failure rates:
            Logistic regression: 0.0%
            DistillBERT: 0.0%


  * Test: Dative verb alternations with prepositional dative construction for ARG1 (MFT)



      * Failure rates:
            Logistic regression: 80.0%
            DistillBERT: 30.0%


  * Test: Dative verb alternations with double object construction for ARG1 (MFT)



      * Failure rates:
            Logistic regression: 80.0%
            DistillBERT: 30.0%


  > Failure rates (total for capability):
        Logistic regression: 53.333333333333336%
        DistillBERT: 20.0%




In [39]:
# Re-read the aggregated results written by the evaluation cell.
with open("checklist-results.json", mode="r", encoding="utf-8") as results_file:
    checklist_results = json.load(results_file)

In [46]:
# Per-capability summary. Rates are rounded to one decimal; "n/a" marks a capability with no
# evaluated labels (its stored rate is None).
for capability in checklist_results['capabilities']:
    formatted_rates = {
        key: "n/a" if capability[key] is None else f"{capability[key]:.1f}%"
        for key in ("failure_rate_logreg", "failure_rate_bert")
    }
    print(capability['capability_name'])
    print(f"\tFailure rate logistic regression: {formatted_rates['failure_rate_logreg']}")
    print(f"\tFailure rate DistillBERT        : {formatted_rates['failure_rate_bert']}")

Long-distance dependencies between predicate and ARG0
	Failure rate logistic regression: 60.0%
	Failure rate DistillBERT        : 33.333333333333336%
Long-distance dependencies between predicate and ARG1
	Failure rate logistic regression: 46.666666666666664%
	Failure rate DistillBERT        : 40.0%
Robustness to noise in the form of typos
	Failure rate logistic regression: 56.666666666666664%
	Failure rate DistillBERT        : 6.666666666666667%
Effect of semantic atypicality in active voice syntactically simple SVO sentences on ARG0 labeling
	Failure rate logistic regression: 66.66666666666667%
	Failure rate DistillBERT        : 45.0%
Dealing with dative verb alternations
	Failure rate logistic regression: 53.333333333333336%
	Failure rate DistillBERT        : 20.0%
