In [1]:
import os
from typing import List
from typing import Tuple
import logging
FORMAT = "%(asctime)-15s %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO,
                    datefmt="%Y-%m-%d %H:%M")
logger = logging.getLogger(__name__)
from collections import defaultdict
from collections import Counter
import json
import torch
import numpy as np

from xlwt import Workbook

import sys
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..', 'dataset'))
from world import *
from vocabulary import Vocabulary as ReaSCANVocabulary
from object_vocabulary import *

import numpy as np
from typing import Tuple
from typing import List
from typing import Any
import matplotlib.pyplot as plt

In [2]:
def bar_plot(values: dict, title: str, save_path: str, errors=None, y_axis_label="Occurrence"):
    """Save a bar chart of ``values`` with bars ordered by ascending height.

    Args:
        values: mapping from bar label to bar height.
        title: figure title.
        save_path: destination handed to ``Figure.savefig``.
        errors: optional mapping from bar label to error-bar size. When it is
            non-empty it must contain every key of ``values``; ``None`` or an
            empty mapping draws no error bars.
        y_axis_label: label for the y axis.
    """
    # Sort on (height, label) so bars with equal height are ordered by label.
    pairs = sorted((height, label) for label, height in values.items())
    heights = [height for height, _ in pairs]
    labels = [label for _, label in pairs]
    bar_errors = [errors[label] for label in labels] if errors else None
    positions = np.arange(len(labels))

    # Draw on a dedicated figure so nothing already on the current pyplot
    # figure ends up in the saved image, then close exactly that figure.
    fig, ax = plt.subplots()
    ax.bar(positions, heights, yerr=bar_errors, align='center', alpha=0.5)
    fig.subplots_adjust(bottom=0.2)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=90, fontsize="xx-small")
    ax.set_ylabel(y_axis_label)
    ax.set_title(title)
    fig.savefig(save_path)
    plt.close(fig)


def grouped_bar_plot(values: dict, group_one_key: Any, group_two_key: Any, title: str, save_path: str,
                     errors_group_one=None, errors_group_two=None, y_axis_label="Occurence", sort_on_key=True):
    """Save a chart with two side-by-side bars per label.

    Args:
        values: mapping from label to a mapping that holds a height under
            ``group_one_key`` and ``group_two_key`` (a ``Counter`` yields 0 for a
            missing group; a plain dict raises ``KeyError``).
        group_one_key: key of the first bar in each pair; also its legend text.
        group_two_key: key of the second bar in each pair; also its legend text.
        title: figure title.
        save_path: destination handed to ``Figure.savefig``.
        errors_group_one: optional label -> error-bar size for the first bars.
        errors_group_two: optional label -> error-bar size for the second bars.
            Each error mapping is optional independently of the other.
        y_axis_label: label for the y axis (default text kept unchanged for
            backward compatibility).
        sort_on_key: order the pairs by label when True, otherwise keep the
            iteration order of ``values``.
    """
    items = list(values.items())
    if sort_on_key:
        # Dict keys are unique, so sorting on the key alone gives the same order
        # as sorting the (key, value) tuples but never compares the values.
        items.sort(key=lambda item: item[0])
    labels = [label for label, _ in items]
    heights_one = [groups[group_one_key] for _, groups in items]
    heights_two = [groups[group_two_key] for _, groups in items]
    errors_one = [errors_group_one[label] for label in labels] if errors_group_one else None
    errors_two = [errors_group_two[label] for label in labels] if errors_group_two else None
    positions = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots()
    bars_one = ax.bar(positions, heights_one, width, yerr=errors_one, align='center', alpha=0.5)
    bars_two = ax.bar(positions + width, heights_two, width, yerr=errors_two, align='center', alpha=0.5)
    fig.subplots_adjust(bottom=0.2)
    # Group one is centred at x and group two at x + width, so put the tick
    # label midway between them instead of under group one only.
    ax.set_xticks(positions + width / 2)
    ax.set_xticklabels(labels, rotation=90, fontsize="xx-small")
    ax.set_ylabel(y_axis_label)
    ax.set_title(title)
    if labels:
        # With no bars there is no first patch to use as a legend handle.
        ax.legend((bars_one[0], bars_two[0]), (str(group_one_key), str(group_two_key)))

    fig.savefig(save_path)
    plt.close(fig)

In [36]:
# Identify the run to analyse: its pattern name, random seed and data split.
# All three are interpolated into the testing_logs paths used below.
pattern, seed, split = "e", 88, "test"

In [37]:
predictions_file = f"../../../testing_logs/{pattern}-random-seed-{seed}/{split}_{pattern}-random-seed-{seed}.json"
output_file = f"../../../testing_logs/{pattern}-random-seed-{seed}/{split}_{pattern}-random-seed-{seed}-analysis.txt"
save_directory = f"../../../testing_logs/{pattern}-random-seed-{seed}/"

# Command vocabulary used to interpret the words of each predicted example.
intransitive_verbs = ["walk"]
transitive_verbs = ["push", "pull"]
adverbs = ["while zigzagging", "while spinning", "cautiously", "hesitantly"]
nouns = ["circle", "cylinder", "square", "box"]
color_adjectives = ["red", "blue", "green", "yellow"]
size_adjectives = ["big", "small"]
relative_pronouns = ["that is"]
relation_clauses = ["in the same row as",
                    "in the same column as",
                    "in the same color as",
                    "in the same shape as",
                    "in the same size as",
                    "inside of"]
vocabulary = Vocabulary.initialize(intransitive_verbs=intransitive_verbs,
                                   transitive_verbs=transitive_verbs, adverbs=adverbs, nouns=nouns,
                                   color_adjectives=color_adjectives,
                                   size_adjectives=size_adjectives,
                                   relative_pronouns=relative_pronouns,
                                   relation_clauses=relation_clauses)

# Object vocabulary built from the semantic shapes/colours (not used in this cell).
min_object_size = 1
max_object_size = 4
object_vocabulary = ObjectVocabulary(shapes=vocabulary.get_semantic_shapes(),
                                     colors=vocabulary.get_semantic_colors(),
                                     min_size=min_object_size, max_size=max_object_size)

# An explicit exception (unlike `assert`) still fires when Python runs with -O.
if not os.path.exists(predictions_file):
    raise FileNotFoundError(
        "Trying to open a non-existing predictions file: {}".format(predictions_file))

# For every dimension, map each observed value to the scores of the examples that had it.
ANALYSIS_DIMENSIONS = ("target_length", "input_length", "verb_in_command", "manner",
                       "distance_to_target", "direction_to_target", "actual_target")
error_analysis = {
    dimension: defaultdict(lambda: {"accuracy": [], "exact_match": []})
    for dimension in ANALYSIS_DIMENSIONS
}

with open(predictions_file, 'r') as infile:
    data = json.load(infile)
logger.info("Running error analysis on {} examples.".format(len(data)))

known_adverbs = vocabulary.get_adverbs()
all_accuracies = []
exact_matches = []
for predicted_example in data:
    # Scores of the current example.
    accuracy = predicted_example["accuracy"]
    exact_match = predicted_example["exact_match"]
    all_accuracies.append(accuracy)
    exact_matches.append(exact_match)

    command_tokens = predicted_example["input"]

    # Find the adverb that ends the command, if any. Some adverbs span two words
    # ("while spinning"), so match on the end of the joined command. This works
    # whether the tokenizer kept such an adverb as one token or split it.
    # NOTE(review): assumes get_adverbs() returns the adverb strings given to
    # Vocabulary.initialize above; confirm.
    command_text = " ".join(command_tokens)
    adverb = ""
    for candidate in known_adverbs:
        if command_text == candidate or command_text.endswith(" " + candidate):
            adverb = candidate
            break
    manner = vocabulary.translate_word(adverb)

    # NOTE(review): only the first entry of the stored situation is decoded;
    # presumably there is one situation per example; confirm.
    situation = Situation.from_representation(predicted_example["situation"][0])
    target = situation.target_object.object
    example_information = {
        "input_length": len(command_tokens),
        "verb_in_command": vocabulary.translate_word(command_tokens[0]),
        "manner": manner,
        "target_length": len(predicted_example["target"]),
        "actual_target": ' '.join([str(target.size), target.color, target.shape]),
        "direction_to_target": situation.direction_to_target,
        "distance_to_target": situation.distance_to_target,
    }

    # File this example's scores under its value for every dimension.
    for dimension, groups in error_analysis.items():
        bucket = groups[example_information[dimension]]
        bucket["accuracy"].append(accuracy)
        bucket["exact_match"].append(exact_match)

2021-06-14 15:27 Running error analysis on 8003 examples.


In [38]:
# Write the per-dimension breakdown to a text report and an .xls workbook, and save plots.
SHEET_COLUMNS = ("Num examples", "Mean accuracy", "Std. accuracy", "Exact Match",
                 "Not Exact Match", "Exact Match Percentage")
workbook = Workbook()
exact_matches_counter = Counter(exact_matches)
with open(output_file, 'w') as outfile:
    outfile.write("Error Analysis\n\n")
    outfile.write(" Mean accuracy: {}\n".format(np.mean(np.array(all_accuracies))))
    outfile.write(" Num. exact matches: {}\n".format(exact_matches_counter[True]))
    outfile.write(" Num not exact matches: {}\n\n".format(exact_matches_counter[False]))
    for key, values in error_analysis.items():
        # One sheet per dimension; header row = dimension name + one column per statistic.
        sheet = workbook.add_sheet(key)
        for column, header in enumerate((key,) + SHEET_COLUMNS):
            sheet.write(0, column, header)
        outfile.write("\nDimension {}\n\n".format(key))
        means = {}
        standard_deviations = {}
        num_examples = {}
        exact_match_distributions = {}
        exact_match_relative_distributions = {}
        for row, (item_key, item_values) in enumerate(values.items(), start=1):
            accuracies = np.array(item_values["accuracy"])
            mean_accuracy = np.mean(accuracies)
            standard_deviation = np.std(accuracies)
            count = len(item_values["accuracy"])
            exact_match_distribution = Counter(item_values["exact_match"])
            num_exact = exact_match_distribution[True]
            num_not_exact = exact_match_distribution[False]
            # NOTE(review): assumes exact_match values are booleans, so every
            # group contributes at least one to this denominator.
            exact_match_ratio = num_exact / (num_not_exact + num_exact)

            means[item_key] = mean_accuracy
            standard_deviations[item_key] = standard_deviation
            num_examples[item_key] = count
            exact_match_distributions[item_key] = exact_match_distribution
            exact_match_relative_distributions[item_key] = exact_match_ratio

            outfile.write("  {}:{}\n\n".format(key, item_key))
            outfile.write("    Num. examples: {}\n".format(count))
            outfile.write("    Mean accuracy: {}\n".format(mean_accuracy))
            outfile.write("    Min. accuracy: {}\n".format(np.min(accuracies)))
            outfile.write("    Max. accuracy: {}\n".format(np.max(accuracies)))
            outfile.write("    Std. accuracy: {}\n".format(standard_deviation))
            outfile.write("    Num. exact match: {}\n".format(num_exact))
            outfile.write("    Num. not exact match: {}\n\n".format(num_not_exact))

            row_cells = (item_key, count, mean_accuracy, standard_deviation,
                         num_exact, num_not_exact, exact_match_ratio)
            for column, cell in enumerate(row_cells):
                sheet.write(row, column, cell)
        outfile.write("\n\n\n")
        bar_plot(means, title=key, save_path=os.path.join(save_directory, key + '_accuracy'),
                 errors=standard_deviations, y_axis_label="accuracy")
        bar_plot(exact_match_relative_distributions, title=key,
                 save_path=os.path.join(save_directory, key + '_exact_match_rel'),
                 errors={}, y_axis_label="Exact Match Percentage")
        grouped_bar_plot(values=exact_match_distributions, group_one_key=True, group_two_key=False,
                         title=key + ' Exact Matches',
                         save_path=os.path.join(save_directory, key + '_exact_match'),
                         sort_on_key=True)
    # Swap only the final extension; splitting on ".txt" would cut at its first
    # occurrence anywhere in the path.
    outfile_excel = os.path.splitext(output_file)[0] + ".xls"
    workbook.save(outfile_excel)

In [39]:
# Fraction of all analysed examples whose prediction was an exact match.
num_exact_matches = exact_matches_counter[True]
num_exact_matches / (num_exact_matches + exact_matches_counter[False])

0.26252655254279644