# Analysis 

## Imports

In [None]:
import copy
import itertools
import os
import glob
import time
import pickle
import sys

import numpy as np
import pandas as pd
import seaborn as sns
from collections import defaultdict
from joblib import Parallel, delayed
from tqdm.notebook import tqdm
from emlangkit.language import Language
from matplotlib import pyplot as plt
from matplotlib.ticker import PercentFormatter
import matplotlib as mpl

# Workaround so we can re-use the project functions
module_path = os.path.abspath(os.path.join("../"))
if module_path not in sys.path:
    sys.path.append(module_path)

from tpg.utils.npmi import (
    compute_compositional_ngrams_positionals_npmi,
    compute_compositional_ngrams_integers_npmi,
    compute_non_compositional_npmi,
)
from tpg.utils.dict_utils import default_to_regular

In [None]:
import matplotlib.pylab as pylab

# Seaborn theme used by every figure: pastel colours on a white grid.
sns.set(palette="pastel")
sns.set_style("whitegrid")
palette = sns.color_palette()

# Enlarged fonts for legends, axis labels, titles and tick labels.
_LARGE_FONT = "32"
params = {
    "legend.title_fontsize": _LARGE_FONT,
    "legend.fontsize": "24",
    "axes.labelsize": _LARGE_FONT,
    "axes.titlesize": _LARGE_FONT,
    "xtick.labelsize": "22",
    "ytick.labelsize": "26",
}
pylab.rcParams.update(params)

## Global defs

In [None]:
# Top-n values and NPMI confidence thresholds swept by the cells below.
top_ns = [1, 2, 3, 5, 10, 15]
# i / 10 is correctly rounded, so these equal the literals 0.1 ... 0.9 exactly.
confidences = [step / 10 for step in range(1, 10)]

## Dataset loading

In [None]:
# Load every exchange log. The run id and architecture are the 2nd and 3rd
# "-"-separated fields of the file name.
# Sorted so the match_<i> numbering built from this order is reproducible:
# glob order is filesystem-dependent.
all_files = sorted(glob.glob(os.path.join("./data/logs/", "*.json")))
li = []
params = []
for filename in tqdm(all_files):
    # Split only the file name so a "-" in the directory path cannot shift fields.
    split = os.path.basename(filename).split("-")
    run_id = split[1]
    architecture = split[2]
    params.append([run_id, architecture])
    df = pd.read_json(filename, orient="index")
    if run_id != "test":
        # TODO: uncomment the line below to drop the first 35000 exchanges
        # (quick, partial analysis); keep it commented for the full analysis.
        # df.drop(index=[f"exchange_{x}" for x in range(35000)], inplace=True)
        pass
    # These columns hold per-exchange lists; turn each cell into a numpy array.
    for k in [
        "sequence",
        "cut_inputs",
        "tds",
        "message",
        "guess",
        "target",
        "target_id",
    ]:
        df[k] = df[k].apply(lambda x: np.array(x))
    li.append(df)

In [None]:
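# This code is commented out so we don't hammer the WandB servers every time we run the analysis.
# We run it only once and save the result to a pickle (the next cell loads
# ./data/runs_full_df_v5.pickle, so rename the saved file accordingly).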

# import wandb
#
# wandb.login()
# api = wandb.Api(timeout=60)
#
# runs = api.runs("user/TPGv5")
# summary_list, config_list, name_list = [], [], []
# for run in tqdm(runs):
#     summary_list.append(
#         run.history(
#             samples=400,
#         )
#     )
#
#     config_list.append({k: v for k, v in run.config.items()})
#
#     name_list.append(run.name.split("-")[1])
#
# runs_full_df = pd.DataFrame(
#     {
#         "summary": summary_list,
#         "config": config_list,
#         "name": name_list,
#     }
# )
#
# runs_full_df.to_pickle("./data/runs_full_df.pickle")

In [None]:
# Load the cached W&B run data and keep only the columns used below.
# NOTE: the commented-out W&B cell above writes runs_full_df.pickle, not the
# _v5 file read here.
runs_full_df = pd.read_pickle("./data/runs_full_df_v5.pickle")[
    ["summary", "config", "name"]
]

In [None]:
def df_stats(x: pd.DataFrame):
    """
    Summarise the validation-accuracy history of one run.

    Parameters
    ----------
    x: pd.DataFrame
        Run history with ``val_acc`` and ``epoch`` columns.

    Returns
    -------
    dict or None
        Epoch and accuracy at the maximum accuracy, at the first row reaching
        0.75 and 0.85 accuracy (``-1`` for both if never reached) and at the
        last epoch. ``None`` if the run has no usable ``val_acc`` values.
    """
    # Some are empty, we'll drop them later
    x = x[x["val_acc"].notnull()]
    try:
        max_acc_index = x["val_acc"].idxmax()
    except (TypeError, ValueError):
        # Empty (or non-comparable) accuracy history.
        return None

    def first_reaching(threshold):
        # Index label of the first row with val_acc >= threshold, else None.
        # Callers compare against None explicitly: a label of 0 is valid, and
        # the previous truthiness test turned it into "never reached".
        reached = x["val_acc"].ge(threshold)
        return reached.idxmax() if reached.any() else None

    def value_at(column, index):
        # Value of ``column`` at ``index``, or -1 when the threshold was never hit.
        return x[column][index] if index is not None else -1

    over_75_index = first_reaching(0.75)
    over_85_index = first_reaching(0.85)
    end_acc_index = x["epoch"].idxmax()

    return {
        "max_acc_epoch": x["epoch"][max_acc_index],
        "max_acc_value": x["val_acc"][max_acc_index],
        "over_75_epoch": value_at("epoch", over_75_index),
        "over_75_value": value_at("val_acc", over_75_index),
        "over_85_epoch": value_at("epoch", over_85_index),
        "over_85_value": value_at("val_acc", over_85_index),
        "end_acc_epoch": x["epoch"][end_acc_index],
        "end_acc_value": x["val_acc"][end_acc_index],
    }

In [None]:
# Flatten each run's summary statistics and config into columns.
# ``df_stats`` returns None for runs without logged validation accuracy; those
# rows are kept and get NaN in every summary column. (A ``.dropna()`` here
# would drop nothing: assigning a shorter Series back to a column re-aligns it
# on the index and re-inserts the missing rows as NaN.)
runs_full_df["summary"] = runs_full_df["summary"].apply(df_stats)
# NOTE(review): the joins below assume runs_full_df still has its default
# 0..n-1 index, matching the index of the json_normalize output.
df_temp = pd.json_normalize(runs_full_df.pop("config"))
runs_full_df = runs_full_df.join(df_temp)
df_temp = pd.json_normalize(runs_full_df.pop("summary"))
runs_full_df = runs_full_df.join(df_temp)
runs_full_df = runs_full_df.set_index("name")

In [None]:
# One entry per loaded log, in the same order as ``li`` / ``params``.
matches = {}
for position, (run_id, architecture) in zip(range(len(li)), params):
    matches[f"match_{position}"] = {"run_id": run_id, "architecture": architecture}

In [None]:
# Copy each run's logged hyper-parameters onto its match; runs named "test"
# are skipped.
run_config_keys = [
    "max_epochs",
    "dataset_size",
    "num_distractors",
    "seq_length",
    "seq_window",
    "repeat_chance",
    "max_length",
    "vocab_size",
    "one_hot",
]
for match in tqdm(matches):
    run_id = matches[match]["run_id"]
    if run_id == "test":
        continue
    # ``.loc[[label]]`` always returns a DataFrame, whether the run name occurs
    # once or several times in the index, so ``.iloc[0]`` is always valid.
    # ``.loc[label]`` on a unique name returns a row Series whose column lookup
    # is a bare scalar with no ``.iloc``. The lookup is done once per match.
    run_rows = runs_full_df.loc[[f"{run_id}"]]
    for key in run_config_keys:
        matches[match][key] = run_rows[key].iloc[0]

In [None]:
# Store every logged column of each match's DataFrame as one numpy array on
# the match dict (``matches`` was built in the same order as ``li``).
for match, frame in zip(matches, li):
    for col in frame.columns:
        matches[match][col] = np.array(list(frame[col]))

del li

### Test set accuracy

In [None]:
# Fraction of logged exchanges where the guess equals the target id.
for match in tqdm(matches):
    guesses = matches[match]["guess"].flatten()
    targets = matches[match]["target_id"].flatten()
    # Vectorised count; the builtin ``sum`` would iterate the bool array in Python.
    correct = np.count_nonzero(guesses == targets)
    total = len(targets)
    matches[match]["test_len"] = total
    matches[match]["test_acc"] = correct / total

### Prune unconverged runs

In [None]:
# Have to pull out the keys, otherwise dict size changes and loop fails
mtcs = list(matches.keys())
for match in tqdm(mtcs):
    if matches[match]["test_acc"] < 0.75:
        del matches[match]

# Re-number matches for later easier processing
matches = {f"match_{i}": v for i, (k, v) in enumerate(matches.items())}

## Temporal references across dataset

In [None]:
def compute_language_stats(match_to_compute, prev_horizon=30) -> tuple:
    """
    Compute the language stats for a match.

    This uses the emlangkit ``Language`` class.

    Parameters
    ----------
    match_to_compute: dict
        Match for which to compute the stats. Its ``message`` entry is passed
        as the messages and its ``tds`` entry as the observations.
    prev_horizon: int, optional
        Passed through to ``Language``; defaults to 30, the value that was
        previously hard-coded here.

    Returns
    -------
    tuple
        ``(mpn_val, mi_val)``: the results of ``Language.mpn()`` and
        ``Language.mutual_information()``. ``mpn_val`` is compared
        element-wise later in this notebook, so it is presumably array-like
        rather than a single float.
    """
    lang = Language(
        messages=match_to_compute["message"],
        observations=match_to_compute["tds"],
        prev_horizon=prev_horizon,
    )
    return lang.mpn(), lang.mutual_information()

In [None]:
start_time = time.perf_counter()

# Dispatch in a fixed key order and store each result under the key it was
# computed for. joblib returns results in input order, so this no longer
# relies on the keys being exactly match_0 .. match_{n-1}.
match_keys = list(matches)
results = Parallel(n_jobs=os.cpu_count(), verbose=10)(
    delayed(compute_language_stats)(match_to_compute=matches[match])
    for match in match_keys
)

for match, result in zip(match_keys, results):
    matches[match]["mpn_val"] = copy.deepcopy(result[0])
    matches[match]["mi_val"] = copy.deepcopy(result[1])

finish_time = time.perf_counter()
print(f"Computing stats finished in {finish_time-start_time} seconds")
del results

## Temporal references within dataset

In [None]:
def compute_trwd_stats(match_to_compute) -> (dict, dict):
    """
    Collect per-message statistics for temporal references within the dataset.

    For each test exchange, the stringified message is used as a key. For that
    key the function counts:

    - how often the message was sent (``total``);
    - how often ``guess`` equalled ``target_id[x][0]`` (``correct``);
    - which values sat next to the first masked (``-1``) slot among the first
      five entries of ``cut_inputs`` (``obs_neighbours``: ``l1``..``l4`` to
      the left and ``r1``..``r4`` to the right, nearest first).

    When the masked slot is at index 0, 1, 3 or 4, the per-message counter
    ``begin``, ``begin+1``, ``end-1`` or ``end`` is incremented, as is the
    matching entry of the second return value. The middle slot has no counter.

    Parameters
    ----------
    match_to_compute: dict
        Match with ``test_len``, ``message``, ``target_id``, ``guess`` and
        ``cut_inputs`` entries.

    Returns
    -------
    tuple of (dict, dict)
        The per-message statistics and the per-position masked-slot counts.
    """
    tpg_dict = defaultdict(
        lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
    )
    obs_counts_dict = {x: 0 for x in ["begin", "begin+1", "end-1", "end"]}
    # Only indices 0..4 of each cut_inputs row are inspected.
    window = 5
    # Counter name for each masked-slot index; index 2 has none.
    position_labels = {0: "begin", 1: "begin+1", 3: "end-1", 4: "end"}

    messages = match_to_compute["message"]
    target_ids = match_to_compute["target_id"]
    guesses = match_to_compute["guess"]
    cut_inputs = match_to_compute["cut_inputs"]

    for x in range(match_to_compute["test_len"]):
        # Stringify the message once; it keys every update below.
        msg_stats = tpg_dict[f"{messages[x]}"]
        # The inner defaultdict would create nested dicts, so plain counters
        # and the list are initialised explicitly.
        for counter in ("total", "correct"):
            if counter not in msg_stats:
                msg_stats[counter] = 0
        if "indices" not in msg_stats:
            msg_stats["indices"] = []

        msg_stats["total"] += 1
        if target_ids[x][0] == guesses[x]:
            msg_stats["correct"] += 1

        s_obs = cut_inputs[x]
        # Index of the first masked slot, or None if there is none.
        hole = next((i for i in range(window) if s_obs[i] == -1), None)
        if hole is None:
            continue

        neighbours = msg_stats["obs_neighbours"]
        # Left neighbours first (l1 = slot just before the hole), then right.
        for dist in range(1, hole + 1):
            neighbours[f"l{dist}"][f"{s_obs[hole - dist]}"] += 1
        for dist in range(1, window - hole):
            neighbours[f"r{dist}"][f"{s_obs[hole + dist]}"] += 1

        label = position_labels.get(hole)
        if label is not None:
            if label not in msg_stats:
                msg_stats[label] = 0
            msg_stats[label] += 1
            obs_counts_dict[label] += 1

    # NOTE(review): nothing in this function appends to "indices", so
    # "indices_unq" is always an empty dict here.
    for msg_stats in tpg_dict.values():
        values, counts = np.unique(np.array(msg_stats["indices"]), return_counts=True)
        msg_stats["indices_unq"] = dict(zip(values, counts))
    return tpg_dict, obs_counts_dict

In [None]:
start_time = time.perf_counter()

# Dispatch in a fixed key order and store each result under the key it was
# computed for (joblib returns results in input order).
match_keys = list(matches)
results = Parallel(n_jobs=os.cpu_count(), verbose=10)(
    delayed(compute_trwd_stats)(match_to_compute=matches[match]) for match in match_keys
)

for match, result in zip(match_keys, results):
    matches[match]["tpg_stats"] = copy.deepcopy(result[0])
    matches[match]["obs_counts"] = copy.deepcopy(result[1])

finish_time = time.perf_counter()
print(f"Computing stats finished in {finish_time-start_time} seconds")
del results

### Non-compositional NPMI

Here we calculate the NPMI for non-compositional messages, including those which are used as special operators, i.e., "begin", "begin+1", "end-1" or "end".

In [None]:
# Calculate the normalised pointwise mutual information for non-compositional messages
# Some divisions involve zeros/NaNs, so the corresponding NumPy warnings are
# silenced for this block only. ``np.errstate`` restores the previous settings
# on exit, even if an exception is raised part-way through.
with np.errstate(divide="ignore", invalid="ignore"):
    # Check for the top_n integer/pos combinations
    for top_n in top_ns:
        for match in tqdm(matches):
            non_compositional_npmi_dict = compute_non_compositional_npmi(
                matches[match], top_n
            )
            matches[match][f"nc_npmi_{top_n}"] = non_compositional_npmi_dict

### Compositional (n-grams) NPMI

Here, we assume some compositionality and test whether specific n-grams refer to a position, together with another n-gram referring to the integer in that position.

In [None]:
# Generate all n-grams of length 1-3 over symbols 0..25 as space-separated
# strings mapped to their length, e.g. "16 20" -> 2, so they can be checked
# against stringified messages: "16 20" in "[1 16 20]".
n_grams = {
    " ".join(str(symbol) for symbol in n_gram): length
    for length in (1, 2, 3)
    for n_gram in itertools.product(range(26), repeat=length)
}

In [None]:
start_time = time.perf_counter()

# Fixed key order for dispatch; results (returned in input order) are stored
# under the key they were computed for.
match_keys = list(matches)
for top_n in top_ns:
    results = Parallel(n_jobs=os.cpu_count(), verbose=10)(
        delayed(compute_compositional_ngrams_integers_npmi)(
            match=matches[match], n_grams=n_grams, top_n=top_n
        )
        for match in match_keys
    )

    for match, result in zip(match_keys, results):
        matches[match][f"ngram_npmi_integers_{top_n}"] = copy.deepcopy(result[0])
        # Only the first stored "ngrams_pruned" value is kept.
        if "ngrams_pruned" not in matches[match]:
            matches[match]["ngrams_pruned"] = copy.deepcopy(result[1])

finish_time = time.perf_counter()
print(f"Computing stats finished in {finish_time-start_time} seconds")
del results

In [None]:
start_time = time.perf_counter()

# Fixed key order for dispatch; results (returned in input order) are stored
# under the key they were computed for.
match_keys = list(matches)
for confidence in confidences:
    for top_n in top_ns:
        results = Parallel(n_jobs=os.cpu_count(), verbose=10)(
            delayed(compute_compositional_ngrams_positionals_npmi)(
                match=matches[match],
                n_grams=n_grams,
                confidence=confidence,
                top_n=top_n,
                scale=10,
            )
            for match in match_keys
        )

        for match, result in zip(match_keys, results):
            matches[match][
                f"ngram_npmi_positionals_{top_n}_{confidence}"
            ] = copy.deepcopy(result)

finish_time = time.perf_counter()
print(f"Computing metrics finished in {finish_time-start_time} seconds")
del results

## Visualising data

In [None]:
# Reset, for every match, the sets of "<top_n>_<confidence>" settings under
# which each kind of message/n-gram was identified by the cells below.
for match_data in matches.values():
    for flag in (
        "non_compositional_emerged",
        "non_compositional_reserved_emerged",
        "non_compositional_int_emerged",
        "compositional_pos_emerged",
        "compositional_int_emerged",
    ):
        match_data[flag] = set()

In [None]:
# Find all the messages that are non-compositional
SPECIAL_POSITIONS = ["begin", "begin+1", "end-1", "end"]


def _message_tokens(msg):
    """Split a stringified message such as ``"[ 1 12  3]"`` into symbol strings."""
    return msg.replace("[", "").replace("]", "").split()


def _message_to_array(msg):
    """Parse a stringified message such as ``"[ 1 12  3]"`` into an int8 array."""
    return np.fromstring(
        msg.replace("[", "").replace("]", "").strip(),
        sep=" ",
        dtype=np.int8,
    )


# Symbol sets of every observed message, per match, for the reserved-symbol
# check below. Whole tokens are compared so that "1" does not match inside
# "11". The previous ``msg_c[0].join(" ")`` evaluated to " ", which matched
# almost every message.
tpg_token_sets = {
    match: [set(_message_tokens(msg1)) for msg1 in matches[match]["tpg_stats"]]
    for match in matches
}

nc_dicts = {}

for top_n in tqdm(top_ns):
    for confidence in confidences:
        setting = f"{top_n}_{confidence}"
        non_compositional_message_translation_dict = {}
        for match in matches:
            non_compositional_identified = []
            translation = {
                "arch": matches[match]["architecture"],
                "run_id": matches[match]["run_id"],
                "positional_messages": {x: [] for x in SPECIAL_POSITIONS},
                "other_messages": defaultdict(lambda: defaultdict(list)),
            }
            non_compositional_message_translation_dict[match] = translation
            non_compositional_npmi_dict = matches[match][f"nc_npmi_{top_n}"]
            for msg in non_compositional_npmi_dict:
                msg_npmi = non_compositional_npmi_dict[msg]
                # Messages whose NPMI with a special position reaches the threshold.
                for special in SPECIAL_POSITIONS:
                    if msg_npmi[special] >= confidence:
                        non_compositional_identified.append(msg)
                        translation["positional_messages"][special].append(
                            _message_to_array(msg)
                        )
                        matches[match]["non_compositional_emerged"].add(setting)
                # Every other key holds an "npmi" score and the "ints" it relates to.
                for pos in msg_npmi:
                    if pos in SPECIAL_POSITIONS:
                        continue
                    if msg_npmi[pos]["npmi"] >= confidence:
                        ints = [int(x) for x in msg_npmi[pos]["ints"]]
                        for x in ints:
                            translation["other_messages"][pos][x].append(
                                _message_to_array(msg)
                            )
                        matches[match]["non_compositional_int_emerged"].add(setting)

            # Flag a positional message built from one repeated symbol
            # (e.g. "[7 7 7]") whose symbol appears in at most two distinct
            # observed messages.
            for msg in non_compositional_identified:
                msg_c = _message_tokens(msg)
                # Messages shorter than 3 symbols previously raised IndexError.
                if len(msg_c) < 3 or not (msg_c[0] == msg_c[1] == msg_c[2]):
                    continue
                count = sum(msg_c[0] in tokens for tokens in tpg_token_sets[match])
                if count <= 2:
                    matches[match]["non_compositional_reserved_emerged"].add(setting)

        nc_dicts[
            f"topn_{top_n}-confidence_{confidence}"
        ] = non_compositional_message_translation_dict

In [None]:
# Find all n-grams that may represent some integers
c_dicts = {}
for top_n in tqdm(top_ns):
    for confidence in confidences:
        setting = f"{top_n}_{confidence}"
        compositional_message_translation_dict = {}
        for match in matches:
            translation = {
                "arch": matches[match]["architecture"],
                "run_id": matches[match]["run_id"],
                # format is {requested_pos_reference: {needed_pos: [ngrams]}}
                "positional_ngrams": defaultdict(lambda: defaultdict(list)),
                # format is {requested_int_reference: {needed_pos: [ngrams]}}
                "integer_ngrams": defaultdict(lambda: defaultdict(list)),
            }
            compositional_message_translation_dict[match] = translation
            ngram_npmi_integers_dict = matches[match][f"ngram_npmi_integers_{top_n}"]
            ngram_npmi_positionals_dict = matches[match][
                f"ngram_npmi_positionals_{top_n}_{confidence}"
            ]

            if len(ngram_npmi_positionals_dict) > 1:
                matches[match]["compositional_pos_emerged"].add(setting)

            if len(ngram_npmi_integers_dict) > 1:
                matches[match]["compositional_int_emerged"].add(setting)

            # NOTE(review): this cell keeps values strictly ``> confidence``
            # while the non-compositional cell uses ``>= confidence``; confirm intended.
            for ngram, per_pos in ngram_npmi_integers_dict.items():
                # ``split()`` ignores repeated spaces; both n-gram kinds are
                # parsed the same way (the positional branch previously kept
                # empty tokens, which uint8 cannot parse).
                ngram_np = np.array(ngram.split(), dtype=np.uint8)
                for pos, stats in per_pos.items():
                    if len(stats) == 0:
                        continue
                    if stats["value"] > confidence:
                        for x in stats["integers"]:
                            translation["integer_ngrams"][pos][int(x)].append(ngram_np)
            for ngram, per_pos in ngram_npmi_positionals_dict.items():
                ngram_np = np.array(ngram.split(), dtype=np.uint8)
                for pos, referents in per_pos.items():
                    if len(referents) == 0:
                        continue
                    for referent_pos, value in referents.items():
                        if value > confidence:
                            translation["positional_ngrams"][pos][referent_pos].append(
                                ngram_np
                            )
        c_dicts[
            f"topn_{top_n}-confidence_{confidence}"
        ] = compositional_message_translation_dict

In [None]:
# Find all runs where temporal references have emerged, i.e. any MPN value
# above 99. ``mpn_val`` is compared element-wise, so it is presumably a numpy
# array; ``np.any`` also handles scalars and multi-dimensional results, where
# the builtin ``any`` would raise.
for match in matches:
    if np.any(matches[match]["mpn_val"] > 99):
        matches[match]["mpn_emerged"] = 1

### Save the data
This data is needed to run the evaluations.

In [None]:
# ``default_to_regular`` (tpg.utils.dict_utils) is applied before pickling;
# presumably it turns the lambda-backed defaultdicts into plain dicts, which
# the standard pickle module cannot serialise otherwise.
matches_dict = default_to_regular(matches)
with open("matches.pickle", mode="wb") as out_file:
    pickle.dump(matches_dict, out_file, pickle.HIGHEST_PROTOCOL)

In [None]:
# Persist the non-compositional translation dictionaries for the evaluation.
nc_dicts = default_to_regular(nc_dicts)
with open("dictionary_nc.pickle", mode="wb") as out_file:
    pickle.dump(nc_dicts, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Persist the compositional (n-gram) translation dictionaries for the evaluation.
c_dicts = default_to_regular(c_dicts)
with open("dictionary_c.pickle", mode="wb") as out_file:
    pickle.dump(c_dicts, out_file, protocol=pickle.HIGHEST_PROTOCOL)

### Reload the data

In [None]:
# Reload the match data saved above (only unpickle files you produced yourself).
with open("matches.pickle", mode="rb") as pickle_file:
    matches = pickle.load(pickle_file)

In [None]:
# Reload the non-compositional dictionaries saved above.
with open("dictionary_nc.pickle", mode="rb") as pickle_file:
    nc_dicts = pickle.load(pickle_file)

In [None]:
# Reload the compositional dictionaries saved above.
with open("dictionary_c.pickle", mode="rb") as pickle_file:
    c_dicts = pickle.load(pickle_file)

### Load the Accuracy Data
This data comes from the evaluation. To produce it, save the dictionaries above, run the evaluation, and then come back to this point.

In [None]:
# Accuracy results written by the separate evaluation step.
with open("agent_accuracy_full.pickle", mode="rb") as pickle_file:
    agent_accuracy_full = pickle.load(pickle_file)

In [None]:
# Dictionary-length results written by the separate evaluation step.
with open("agent_dict_lens_full.pickle", mode="rb") as pickle_file:
    agent_dict_lens_full = pickle.load(pickle_file)

## Plots

In [None]:
# Per GRU architecture and "<top_n>_<confidence>" setting: accumulate the run
# count, summed test accuracy and how many runs have each emergence flag,
# then divide everything except the count by the number of runs.
EMERGENCE_FLAGS = (
    "non_compositional_emerged",
    "non_compositional_int_emerged",
    "non_compositional_reserved_emerged",
    "compositional_int_emerged",
    "compositional_pos_emerged",
)

emergence_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
for top_n in tqdm(top_ns):
    for confidence in confidences:
        setting = f"{top_n}_{confidence}"
        for match_data in matches.values():
            arch_name = match_data["architecture"]
            # We skip the other architectures
            if "GRU" not in arch_name:
                continue
            arch_stats = emergence_dict[arch_name]
            arch_stats["total"][setting] += 1
            arch_stats["accuracy"][setting] += match_data["test_acc"]
            for flag in EMERGENCE_FLAGS:
                arch_stats[flag][setting] += 1 if setting in match_data[flag] else 0

for top_n in tqdm(top_ns):
    for confidence in confidences:
        setting = f"{top_n}_{confidence}"
        for arch_stats in emergence_dict.values():
            run_count = arch_stats["total"][setting]
            for stat in ("accuracy",) + EMERGENCE_FLAGS:
                arch_stats[stat][setting] /= run_count

In [None]:
# Average each BaseGRU statistic over all confidence thresholds for one top_n.
report_top_n = 1
for stat_name, per_setting in emergence_dict["BaseGRU"].items():
    # Settings are keyed "<top_n>_<confidence>"; compare the top_n part exactly.
    # A substring test such as '"1_" in key' would also match top_n = 11, 21, ...
    total = sum(
        value
        for setting, value in per_setting.items()
        if setting.split("_")[0] == str(report_top_n)
    )
    # One setting per confidence threshold (previously a hard-coded 9).
    print(f"Average {stat_name} is {total/len(confidences)}")

## Vocabulary coverage

In [None]:
coverages_p_full = []
coverages_nc_full = []

for confidence in confidences:
    for top_n in top_ns:
        setting_key = f"topn_{top_n}-confidence_{confidence}"
        coverages_p = []
        coverages_nc = []
        # NOTE(review): these two are never filled in this cell.
        coverages_c = []
        coverages_total = []

        for match, match_data in matches.items():
            # We are only interested in BaseGRU
            if match_data["architecture"] != "BaseGRU":
                continue
            n_distinct_messages = len(set(match_data["tpg_stats"]))
            translation = nc_dicts[setting_key][match]

            # Non-compositional messages are quite straightforward.
            # positional_messages maps each special position to a list of
            # messages, so this counts (position, message) entries.
            count_nc_p = sum(
                len(msgs) for msgs in translation["positional_messages"].values()
            )
            # NOTE(review): other_messages maps pos -> {int: [messages]}, so this
            # counts (pos, int) entries rather than messages; confirm intended.
            count_nc_o = sum(
                len(by_int) for by_int in translation["other_messages"].values()
            )
            coverages_p.append(count_nc_p / n_distinct_messages)
            coverages_nc.append(count_nc_o / n_distinct_messages)

        # There is no top_n in positional messages
        coverages_p_full.append({"confidence": confidence, "mean": np.mean(coverages_p)})
        coverages_nc_full.append(
            {"top_n": top_n, "confidence": confidence, "mean": np.mean(coverages_nc)}
        )

In [None]:
# Start offset inside a message for each n-gram family; ``None`` marks the
# position-invariant family ("inv_npmi"), whose n-grams may start anywhere.
_NGRAM_FAMILY_START = {
    "npmi_pos_0": 0,
    "npmi_pos_1": 1,
    "npmi_pos_2": 2,
    "inv_npmi": None,
}


def _ngram_at(ngram, msg, start_pos):
    """Return True if ``ngram`` occurs in ``msg`` (a list) at ``start_pos``.

    When ``start_pos`` is None, every start offset in ``range(4 - len(ngram))``
    is tried, i.e. the n-gram must fit inside the first three message symbols.
    """
    ngram = list(ngram)
    if start_pos is not None:
        return ngram == msg[start_pos : start_pos + len(ngram)]
    return any(
        ngram == msg[start : start + len(ngram)] for start in range(4 - len(ngram))
    )


def _int_ngram_found(msg, observation, match_int_c_dict):
    """Return True if ``msg`` contains an integer n-gram tied to an observed integer.

    Only the first five observation slots are inspected (fast path that
    assumes short observations).
    """
    candidates = observation[:5]
    for family, start_pos in _NGRAM_FAMILY_START.items():
        ngrams_by_int = match_int_c_dict[family]
        for value in candidates:
            # Test membership, not truthiness, so that integer 0 is checked too.
            if value in ngrams_by_int and any(
                _ngram_at(ngram, msg, start_pos) for ngram in ngrams_by_int[value]
            ):
                return True
    return False


def _valid_relative_positions(observation):
    """Return the relative-position labels usable for ``observation``.

    The target is the first -1 entry. Labels are "l<k>"/"r<k>" for the value
    k slots left/right of the target (k = 1..4, within bounds) plus "l"/"r"
    for the list of all such values on each side. Labels whose value is -1
    and empty sides are excluded.
    """
    target_id = np.where(observation == -1)[0][0]
    relative = {"l": [], "r": []}
    for offset in range(1, 5):
        if target_id - offset >= 0:
            value = observation[target_id - offset].item()
            relative[f"l{offset}"] = value
            relative["l"].append(value)
        if target_id + offset < len(observation):
            value = observation[target_id + offset].item()
            relative[f"r{offset}"] = value
            relative["r"].append(value)
    return {
        label for label, value in relative.items() if value != -1 and value != []
    }


def _pos_ngram_found(msg, observation, match_pos_c_dict):
    """Return True if ``msg`` contains a positional n-gram with a valid reference position."""
    valid_labels = _valid_relative_positions(observation)
    for family, start_pos in _NGRAM_FAMILY_START.items():
        for ref_pos, ngrams in match_pos_c_dict[family].items():
            if ref_pos in valid_labels and any(
                _ngram_at(ngram, msg, start_pos) for ngram in ngrams
            ):
                return True
    return False


def compute_compo_stats(
    matchh, match_int_c_dict, match_pos_c_dict, top_nn, confidencee
):
    """Compute compositional n-gram coverage for one run and one setting.

    Parameter names keep their original spelling because callers pass them
    by keyword.

    Parameters
    ----------
    matchh : dict
        One run's record. Reads "message" (iterable of messages),
        "cut_inputs" (observations aligned with the messages, presumably
        numpy arrays in which -1 marks the target) and
        "compositional_pos_emerged" (container of "{top_n}_{confidence}" keys).
    match_int_c_dict : dict
        For each family ("npmi_pos_0", "npmi_pos_1", "npmi_pos_2", "inv_npmi"),
        a dict from an observation integer to its associated n-grams.
    match_pos_c_dict : dict
        For the same families, a dict from a relative-position label
        (e.g. "l1", "r2", "l", "r") to its associated n-grams.
    top_nn, confidencee :
        The setting; echoed back and used to build the emergence key.

    Returns
    -------
    tuple
        ``(coverages_compo_int, coverages_compo_pos, confidencee, top_nn)``.
        ``coverages_compo_int`` is a one-element list with the fraction of
        messages containing at least one matching integer n-gram.
        ``coverages_compo_pos`` is the analogous one-element list for
        positional n-grams, or empty if they did not emerge for this setting.
        Both lists are empty when the run has no messages.
    """
    messages = matchh["message"]
    observations = matchh["cut_inputs"]
    # Loop-invariant: positional coverage only counts when positional n-grams
    # emerged for this (top_n, confidence) setting.
    pos_emerged = f"{top_nn}_{confidencee}" in matchh["compositional_pos_emerged"]

    n_messages = len(messages)
    if n_messages == 0:
        # Nothing to cover; empty lists are skipped by the downstream cells.
        return [], [], confidencee, top_nn

    int_ngrams_present = 0
    pos_ngrams_present = 0
    for msg, observation in zip(messages, observations):
        msg = list(msg)
        int_ngrams_present += _int_ngram_found(msg, observation, match_int_c_dict)
        if pos_emerged:
            pos_ngrams_present += _pos_ngram_found(msg, observation, match_pos_c_dict)

    coverages_compo_int = [int_ngrams_present / n_messages]
    coverages_compo_pos = [pos_ngrams_present / n_messages] if pos_emerged else []
    return coverages_compo_int, coverages_compo_pos, confidencee, top_nn

In [None]:
# Run compute_compo_stats for every (BaseGRU run, confidence, top_n) setting
# in parallel, then flatten the results into per-setting records.
coverages_compo_pos_full = []
coverages_compo_int_full = []

# os.cpu_count() may return None; fall back to a single CPU in that case.
n_jobs = (os.cpu_count() or 1) * 2

results = Parallel(n_jobs=n_jobs, verbose=10)(
    delayed(compute_compo_stats)(
        matchh=matches[match],
        confidencee=confidence,
        top_nn=top_n,
        match_int_c_dict=c_dicts[f"topn_{top_n}-confidence_{confidence}"][match][
            "integer_ngrams"
        ],
        match_pos_c_dict=c_dicts[f"topn_{top_n}-confidence_{confidence}"][match][
            "positional_ngrams"
        ],
    )
    for match in matches
    if matches[match]["architecture"] == "BaseGRU"
    for confidence in confidences
    for top_n in top_ns
)

# NOTE: "mean" holds the (possibly empty) list returned by compute_compo_stats,
# not a scalar; it is unwrapped further down.
for res_int, res_pos, conf, topnn in results:
    coverages_compo_pos_full.append(
        {"confidence": conf, "top_n": topnn, "mean": res_pos}
    )
    coverages_compo_int_full.append(
        {"confidence": conf, "top_n": topnn, "mean": res_int}
    )

In [None]:
# Cache both coverage result lists on disk so they can be reloaded without
# recomputing them.
for out_path, coverage_records in (
    ("coverages_compo_pos_full.pickle", coverages_compo_pos_full),
    ("coverages_compo_int_full.pickle", coverages_compo_int_full),
):
    with open(out_path, "wb") as out_file:
        pickle.dump(coverage_records, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached coverage results. pickle can execute arbitrary code on
# load, so only read files this notebook wrote itself.
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, "rb") as in_file:
        return pickle.load(in_file)


coverages_compo_pos_full = _load_pickle("coverages_compo_pos_full.pickle")
coverages_compo_int_full = _load_pickle("coverages_compo_int_full.pickle")

In [None]:
# Keep only settings that produced a positional coverage value (its "mean" is
# an empty list when positional n-grams did not emerge) and unwrap the
# single-element "mean" list into a scalar. New dicts are built instead of
# mutating coverages_compo_pos_full in place, so re-running this cell neither
# drops entries nor fails on already-unwrapped values.
n_coverages_compo_pos_full = [
    {**record, "mean": record["mean"][0]}
    for record in coverages_compo_pos_full
    if record["mean"]
]

In [None]:
# Keep only settings with an integer coverage value and unwrap the
# single-element "mean" list into a scalar. New dicts are built instead of
# mutating coverages_compo_int_full in place, so re-running this cell neither
# drops entries nor fails on already-unwrapped values.
n_coverages_compo_int_full = [
    {**record, "mean": record["mean"][0]}
    for record in coverages_compo_int_full
    if record["mean"]
]

In [None]:
# Positional n-gram coverage: column means over all settings, then the mean
# per (confidence, top_n) setting (displayed as the cell output).
df = pd.DataFrame(n_coverages_compo_pos_full)
overall_means = df.mean()
print(overall_means)
df.groupby(by=["confidence", "top_n"]).mean()

In [None]:
# Integer n-gram coverage: column means over all settings, then the mean per
# (confidence, top_n) setting (displayed as the cell output).
df = pd.DataFrame(n_coverages_compo_int_full)
overall_means = df.mean()
print(overall_means)
df.groupby(by=["confidence", "top_n"]).mean()

## Accuracy and Generalisation Plots

In [None]:
# Box plot of max accuracy for a hand-picked set of runs, grouped by whether
# the inputs were scalar or one-hot encoded.
one_hot_run_ids = ["DeGcChUp", "PaMMJSFQ", "PGszP2Fb", "HNEVXAXA"]
one_hot_runs = runs_full_df.loc[runs_full_df.index.isin(one_hot_run_ids)]

fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(data=one_hot_runs, x="one_hot", y="max_acc_value", ax=ax)
ax.set(
    xlabel="Input type",
    ylabel="Max accuracy",
    xticklabels=["Scalar", "One Hot"],
)
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
ax.set_ylim([0, 1])
for out_path in ("one_hot.pdf", "one_hot.png"):
    plt.savefig(out_path, bbox_inches="tight", pad_inches=0)
plt.show()

In [None]:
# Runs trained on scalar (non-one-hot) inputs.
not_one_hot_df = runs_full_df.loc[runs_full_df["one_hot"].eq(False)]

In [None]:
# Max accuracy by training sequence length, split by sender hidden size, with
# hatched (instead of filled) boxes.
fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(
    data=runs_full_df,
    x="seq_length",
    y="max_acc_value",
    hue="sender_hidden",
    palette=palette,
    ax=ax,  # draw explicitly on this figure's axes, like the one-hot plot
).set(
    xlabel="Sequence Length",
    ylabel="Max accuracy",
)
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
ax.set_ylim([0, 1])

# Hatch pattern per box, in the order boxes appear in ax.patches. zip() stops
# at the shorter sequence, so boxes beyond the fifth keep their default fill.
hatches = ["///", "///", "///", "///", ".."]
# noinspection PyUnresolvedReferences
box_patches = [
    patch for patch in ax.patches if isinstance(patch, mpl.patches.PathPatch)
]
for patch, hatch in zip(box_patches, hatches):
    patch.set_hatch(hatch)
    # Reuse the fill colour for the hatch/edge and clear the fill itself.
    fc = patch.get_facecolor()
    patch.set_edgecolor(fc)
    patch.set_facecolor("none")

handles, _ = ax.get_legend_handles_labels()
legend = ax.legend(
    handles,
    [64, 128],
    title="Hidden Size",
    ncols=2,
    bbox_to_anchor=(0.45, 1.25),
    loc="upper center",
    labelspacing=0.35,
    columnspacing=1,
    handletextpad=0.7,
)
# Mirror the box hatching in the legend swatches.
for lp, hatch in zip(legend.get_patches(), ["///", ".."]):
    lp.set_hatch(hatch)
    fc = lp.get_facecolor()
    lp.set_edgecolor(fc)
    lp.set_facecolor("none")
plt.savefig("seq_len.pdf", bbox_inches="tight", pad_inches=0)
plt.savefig("seq_len.png", bbox_inches="tight", pad_inches=0)
plt.show()

In [None]:
# Max accuracy by vocabulary size, restricted to runs with sequence length 60.
fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(
    data=runs_full_df[runs_full_df["seq_length"] == 60],
    x="vocab_size",
    y="max_acc_value",
    ax=ax,  # draw explicitly on this figure's axes, not the implicit current axes
).set(
    xlabel="Vocabulary size",
    ylabel="Max accuracy",
)
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0))
ax.set_ylim([0, 1])
plt.savefig("vocab.pdf", bbox_inches="tight", pad_inches=0)
plt.savefig("vocab.png", bbox_inches="tight", pad_inches=0)
plt.show()

### Evaluation of different sequence lengths

The evaluation data loaded below (`data/eval_data.pickle`) is generated by `eval_model.py`.

In [None]:
# Evaluation results produced by eval_model.py. pickle can execute arbitrary
# code on load, so only read a file you generated yourself.
eval_data_path = "data/eval_data.pickle"
with open(eval_data_path, "rb") as eval_file:
    eval_data = pickle.load(eval_file)

In [None]:
# One record per (evaluated run, sequence shortening): the run's training
# sequence length, how much the evaluation sequence was shortened by, and the
# mean of the recorded BaseGRU scores for that shortening.
data_list = [
    {
        "train_seq_length": not_one_hot_df["seq_length"][run_id],
        "neg_seq": shortening,
        "average": sum(scores) / len(scores),
    }
    for run_id, per_arch in eval_data.items()
    for shortening, scores in per_arch["BaseGRU"].items()
]

fig_data = pd.DataFrame(data_list)

In [None]:
# Accuracy on shortened evaluation sequences, grouped by training sequence
# length, drawn with hatched (instead of filled) boxes.
fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(
    data=fig_data,
    x="neg_seq",
    y="average",
    hue="train_seq_length",
    palette=palette,
    ax=ax,  # draw explicitly on this figure's axes, like the one-hot plot
).set(
    xlabel="Seq. shortened by",
    ylabel="Avg. accuracy",
)

# One pattern per legend entry (hue level), repeated for that level's boxes in
# the order boxes appear in ax.patches. The per-level box counts (4, 3, 5, 5)
# are assumed from the current data -- adjust them if the data changes.
legend_hatches = ["..", "///", "OO", "xx"]
hatches = [".."] * 4 + ["///"] * 3 + ["OO"] * 5 + ["xx"] * 5
# noinspection PyUnresolvedReferences
box_patches = [
    patch for patch in ax.patches if isinstance(patch, mpl.patches.PathPatch)
]
for patch, hatch in zip(box_patches, hatches):
    patch.set_hatch(hatch)
    # Reuse the fill colour for the hatch/edge and clear the fill itself.
    fc = patch.get_facecolor()
    patch.set_edgecolor(fc)
    patch.set_facecolor("none")

handles, labels = ax.get_legend_handles_labels()
legend = ax.legend(
    handles,
    labels,
    title="Training Seq. Length",
    ncols=2,
    bbox_to_anchor=(0.45, 1.35),
    loc="upper center",
    labelspacing=0.35,
    columnspacing=1,
    handletextpad=0.7,
)
# Mirror the box hatching in the legend swatches.
for lp, hatch in zip(legend.get_patches(), legend_hatches):
    lp.set_hatch(hatch)
    fc = lp.get_facecolor()
    lp.set_edgecolor(fc)
    lp.set_facecolor("none")
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0))
ax.set_ylim([0, 1])
plt.savefig("eval.pdf", bbox_inches="tight", pad_inches=0)
plt.savefig("eval.png", bbox_inches="tight", pad_inches=0)
plt.show()

In [None]:
# Mean accuracy per training sequence length (rows) and evaluation-time
# sequence shortening (columns).
table = fig_data.pivot_table(
    values="average",
    index="train_seq_length",
    columns="neg_seq",
    aggfunc="mean",
)

In [None]:
# Express column 0 (no shortening) as a percentage rounded to 2 decimals and
# every other column as the percentage-point change relative to that rounded
# baseline.
baseline = 0
table[baseline] = (table[baseline] * 100).round(2)
for shortening in [col for col in table.columns if col != baseline]:
    table[shortening] = (table[shortening] * 100 - table[baseline]).round(2)

In [None]:
# Display the table: column 0 holds the baseline accuracy (%), the other
# columns hold percentage-point differences from it.
table