# Test deletions statistics extractor

In [1]:
import contextlib
import json
import os
from collections import Counter

import pandas as pd
from torchtext.data import get_tokenizer

## Utils

In [2]:
# Root directory of the experiment logs; load_experiment_data reads
# "<LOGS_PATH>/<model dir>/<experiment>/..." from here.
LOGS_PATH = "../logs/"
# Directory containing the dataset CSV files read by calculate_frequencies.
DATA_PATH = "../generated/datasets"
# Directory where the per-dataset token frequency JSON files are stored.
SAVE_PATH = "../generated/statistics"
# Directory where the per-experiment statistics CSV files are stored.
SAVE_EXPERIMENT_PATH = "../generated/experiments_statistics"

# Maps a model key (as passed to load_save_pipeline) to the name of that
# model's directory under LOGS_PATH and SAVE_EXPERIMENT_PATH.
LOGS_DIR: dict[str, str] = {
    "a2c": "a2c",
    "reinforce": "reinforce",
    "dqn": "dqn",
    "cnn1d": "cnn_1d",
    "cnn2d": "cnn_2d",
    "lstm": "lstm",
    "bilstm": "lstm_bi",
    "lstm_attn": "lstm_attn",
    "bilstm_attn": "lstm_bi_attn",
}

# Model keys of the (presumably reinforcement-learning) models.
# NOTE(review): not referenced anywhere in the visible code - confirm usage.
RLS = [
    "a2c",
    "reinforce",
    "dqn",
]

# Model keys of the (presumably baseline) models.
# NOTE(review): not referenced anywhere in the visible code - confirm usage.
BASELINES = [
    "cnn1d",
    "cnn2d",
    "lstm",
    "bilstm",
    "lstm_attn",
    "bilstm_attn",
]

In [3]:
# torchtext "basic_english" tokenizer; calculate_frequencies calls it on every
# sentence to split it into tokens before counting.
tokenizer = get_tokenizer("basic_english")

In [4]:
def load_json(path: str) -> dict:
    """Read the JSON document stored at ``path`` and return the parsed value.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default locale encoding.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed JSON content (a dict for the files used in this notebook).
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json(path: str, data) -> None:
    """Serialise ``data`` as JSON into the file at ``path``.

    An existing file at ``path`` is overwritten.
    """
    with open(path, mode="w") as out_file:
        json.dump(data, fp=out_file)


def soft_mkdir(path: str) -> None:
    """Create the directory ``path`` (and missing parents) if it is absent.

    An already existing directory is not an error. Any other failure
    (for example a permission problem, or a path component that is a regular
    file) is raised as ``OSError`` instead of being silently ignored, so the
    problem surfaces here rather than at a later write into the directory.
    """
    os.makedirs(path, exist_ok=True)

In [5]:
def save_df(df: pd.DataFrame, path: str):
    """Write ``df`` to ``path`` as CSV without the index.

    Columns whose name starts with "Unnamed" are left out of the output.
    """
    is_unnamed = df.columns.str.contains("^Unnamed")
    kept_columns = df.loc[:, ~is_unnamed]
    kept_columns.to_csv(path, index=False)

## Calculate test data frequency

In [6]:
def get_dataset_name(dataset_size: str) -> str:
    """Return the test dataset name for a size tag.

    The empty tag maps to "test"; any other tag is appended after an
    underscore, e.g. "sm" -> "test_sm".
    """
    suffix = "" if dataset_size == "" else f"_{dataset_size}"
    return "test" + suffix


def calculate_frequencies(dataset_name: str) -> Counter:
    """Count token occurrences in a dataset CSV.

    Reads "<DATA_PATH>/<dataset_name>.csv" and tokenizes every sentence of
    its ``target`` column followed by every sentence of its ``candidate``
    column.

    Returns:
        A Counter mapping each token to its number of occurrences.
    """
    csv_path = os.path.join(DATA_PATH, f"{dataset_name}.csv")
    frame = pd.read_csv(csv_path)

    counts = Counter()
    for column in (frame.target, frame.candidate):
        for sentence in column.to_list():
            counts.update(tokenizer(sentence))
    return counts


def save_frequencies(dataset_name: str, frequencies: Counter):
    """Store token frequencies as "<SAVE_PATH>/<dataset_name>.json".

    The SAVE_PATH directory is created first when it does not exist yet.
    """
    soft_mkdir(SAVE_PATH)
    target_file = os.path.join(SAVE_PATH, f"{dataset_name}.json")
    save_json(target_file, frequencies)


# for size in ["", "sm","md"]:
#   name = get_dataset_name(size)
#   save_frequencies(name, calculate_frequencies(name))

## Calculate and save experiment statistics

In [7]:
def load_experiment_data(model: str, experiment: str, epoch: int) -> tuple[Counter, str]:
    """Load the test deletion counts and the test-size tag of an experiment.

    Args:
        model: Name of the model directory under LOGS_PATH.
        experiment: Name of the experiment directory inside the model directory.
        epoch: Epoch whose "best/<epoch>/data/test_deletions.json" is read.

    Returns:
        A pair of (deletion counts as a Counter, value of "TEST_SIZE" from
        the experiment's "configs.json").
    """
    root = os.path.join(LOGS_PATH, model, experiment)

    config = load_json(os.path.join(root, "configs.json"))
    test_size = config["TEST_SIZE"]

    deletions_file = os.path.join(
        root, "best", str(epoch), "data", "test_deletions.json"
    )
    return Counter(load_json(deletions_file)), test_size


def load_frequencies(dataset_size: str) -> Counter:
    """Load the stored token frequencies of the test dataset for a size tag.

    The file read is "<SAVE_PATH>/<get_dataset_name(dataset_size)>.json".
    """
    file_name = f"{get_dataset_name(dataset_size)}.json"
    raw_counts = load_json(os.path.join(SAVE_PATH, file_name))
    return Counter(raw_counts)


def build_experiment_statistics(
    statistics: Counter, test_frequencies: Counter
) -> pd.DataFrame:
    """Build a per-token table of deletion counts and deletion ratios.

    Args:
        statistics: Mapping token -> number of times the token was deleted.
            Tokens missing from this mapping count as 0 deletions.
        test_frequencies: Mapping token -> number of occurrences in the test
            data. Only tokens present here appear in the result.

    Returns:
        A DataFrame with columns "Token", "Deleted", "Total" and "Ratio"
        (``Deleted / Total``), ordered by descending ratio, then descending
        total, then ascending token. A token whose total is 0 gets ratio 0.0
        instead of raising ``ZeroDivisionError``.
    """
    rows = []
    for token, total in test_frequencies.items():
        deleted = statistics.get(token, 0)
        # Guard against a zero count so one bad entry cannot abort the table.
        ratio = deleted / total if total else 0.0
        rows.append((token, deleted, total, ratio))

    # No pre-sorting of the rows is needed: sort_values below fully determines
    # the final order.
    table = pd.DataFrame(rows, columns=["Token", "Deleted", "Total", "Ratio"])
    return table.sort_values(
        ["Ratio", "Total", "Token"], ascending=[False, False, True]
    ).reset_index(drop=True)


def save_experiment_statistics(
    statistics: pd.DataFrame, model: str, experiment: str, epoch: int
):
    """Write an experiment's statistics table to disk as CSV.

    The file is "<SAVE_EXPERIMENT_PATH>/<model>/<experiment>/<epoch>.csv";
    the containing directory is created when it is missing.
    """
    out_dir = os.path.join(SAVE_EXPERIMENT_PATH, model, experiment)
    soft_mkdir(out_dir)

    csv_file = os.path.join(out_dir, f"{epoch}.csv")
    save_df(statistics, csv_file)


def load_save_pipeline(model: str, experiment: str, epoch: int) -> pd.DataFrame:
    """Load an experiment's deletions, build its statistics table and save it.

    Args:
        model: Model key; must be a key of LOGS_DIR.
        experiment: Experiment directory name.
        epoch: Epoch whose test deletions are processed.

    Returns:
        The statistics table that was written to disk.
    """
    log_dir = LOGS_DIR[model]
    deletions, dataset_size = load_experiment_data(log_dir, experiment, epoch)
    frequencies = load_frequencies(dataset_size)

    table = build_experiment_statistics(deletions, frequencies)
    save_experiment_statistics(table, log_dir, experiment, epoch)
    return table

In [8]:
# load_save_pipeline("a2c", "23_04",1410)
# load_save_pipeline("reinforce", "23_04",1400)
# load_save_pipeline("dqn", "23_04",460)
# load_save_pipeline("dqn", "23_04_",1000)
# load_save_pipeline("dqn", "23_04_2", 1140)

## Compare statistics

In [16]:
def load_experiment_statistics(model: str, experiment: str, epoch: int) -> dict:
    """Load a saved statistics CSV as a token -> (deleted, total, ratio) dict.

    Args:
        model: Model key (a key of LOGS_DIR) or the model's directory name.
            Keys are translated through LOGS_DIR, matching the directory that
            load_save_pipeline writes to; unknown names are used unchanged.
        experiment: Experiment directory name.
        epoch: Epoch whose "<epoch>.csv" is read.

    Returns:
        Dict mapping every value of the "Token" column to the tuple
        ("Deleted", "Total", "Ratio") of its row.
    """
    model_dir = LOGS_DIR.get(model, model)
    csv_path = os.path.join(
        SAVE_EXPERIMENT_PATH, model_dir, experiment, f"{epoch}.csv"
    )
    # Tokens are arbitrary words: keep them as strings and disable the default
    # NA parsing, otherwise tokens such as "null" or "nan" are read back as
    # float NaN and can no longer be looked up by name.
    df = pd.read_csv(csv_path, dtype={"Token": str}, keep_default_na=False)

    columns = (
        df["Token"].tolist(),
        df["Deleted"].tolist(),
        df["Total"].tolist(),
        df["Ratio"].tolist(),
    )
    return {
        token: (deleted, total, ratio)
        for token, deleted, total, ratio in zip(*columns)
    }


def compare_statistics(data: list[tuple[str, str, int]], words: list[str]):
    """Print the saved statistics of several runs side by side for some tokens.

    For every word, its upper-cased form is printed, followed by one line
    "<label>: (deleted, total, ratio)" per run and a blank line.

    Args:
        data: Runs to compare, each as (model, experiment, epoch).
        words: Tokens to look up in every run.

    Raises:
        KeyError: If a word is missing from the statistics of a run.
    """
    # A model that occurs once is labelled by its name alone. When the same
    # model is listed several times, the experiment and epoch are added so
    # the runs stay distinguishable instead of overwriting each other.
    model_counts = Counter(model for model, _, _ in data)
    labelled_runs = []
    for model, experiment, epoch in data:
        if model_counts[model] == 1:
            label = model
        else:
            label = f"{model}/{experiment}/{epoch}"
        labelled_runs.append(
            (label, load_experiment_statistics(model, experiment, epoch))
        )

    for w in words:
        print(w.upper())
        for label, stats in labelled_runs:
            print(f"{label}: {stats[w]}")
        print()

In [126]:
def statistics_intersection(data: list[tuple[str, str, int]], n: int = 50):
    """Find tokens shared by the top-n and bottom-n ratio lists of all runs.

    Args:
        data: Runs to compare, each as (model, experiment, epoch).
        n: Number of tokens taken from each end of every run's ranking.

    Returns:
        A pair of sets: tokens that are among the ``n`` highest-ratio tokens
        of every run, and tokens that are among the ``n`` lowest-ratio tokens
        of every run. Both sets are empty when ``data`` is empty.
    """
    if not data:
        return set(), set()

    # Keep one entry per run (not per model name) so several runs of the same
    # model all take part in the intersection.
    runs = [load_experiment_statistics(m, exp, ep) for (m, exp, ep) in data]

    def extreme_tokens(stats: dict, highest: bool) -> set:
        # Index 2 of each value is the deletion ratio; the sort is stable, so
        # ties keep the order in which tokens appear in the statistics.
        ranked = sorted(stats, key=lambda tok: stats[tok][2], reverse=highest)
        return set(ranked[:n])

    most_sets = [extreme_tokens(stats, True) for stats in runs]
    least_sets = [extreme_tokens(stats, False) for stats in runs]
    return set.intersection(*most_sets), set.intersection(*least_sets)


# b, a = statistics_intersection([("reinforce", "23_04",1400), ("a2c", "23_04",1410), ("dqn", "23_04_2", 1140)],n=100)
# sorted(b,key=len)

In [127]:
# ["say", "got", "now", "who", "under"]
# ["new","must", "need", "while", "understand"]
# ["the","this", "his","no","but"]

In [128]:
# Print (Deleted, Total, Ratio) of the three listed runs for a hand-picked set of tokens.
compare_statistics(
    [("reinforce", "23_04", 1400), ("a2c", "23_04", 1410), ("dqn", "23_04_2", 1140)],
    ["got", "now", "who", "under"],
)

GOT
reinforce: (10, 10, 1.0)
a2c: (10, 10, 1.0)
dqn: (10, 10, 1.0)

NOW
reinforce: (9, 9, 1.0)
a2c: (9, 9, 1.0)
dqn: (9, 9, 1.0)

WHO
reinforce: (11, 13, 0.8461538461538461)
a2c: (13, 13, 1.0)
dqn: (10, 13, 0.7692307692307693)

UNDER
reinforce: (5, 8, 0.625)
a2c: (8, 8, 1.0)
dqn: (7, 8, 0.875)



In [129]:
# Print (Deleted, Total, Ratio) of the three listed runs for a second hand-picked set of tokens.
compare_statistics(
    [("reinforce", "23_04", 1400), ("a2c", "23_04", 1410), ("dqn", "23_04_2", 1140)],
    ["new", "must", "need", "while", "understand"],
)

NEW
reinforce: (0, 13, 0.0)
a2c: (0, 13, 0.0)
dqn: (0, 13, 0.0)

MUST
reinforce: (0, 7, 0.0)
a2c: (0, 7, 0.0)
dqn: (0, 7, 0.0)

NEED
reinforce: (0, 5, 0.0)
a2c: (0, 5, 0.0)
dqn: (0, 5, 0.0)

WHILE
reinforce: (0, 6, 0.0)
a2c: (0, 6, 0.0)
dqn: (0, 6, 0.0)

UNDERSTAND
reinforce: (0, 14, 0.0)
a2c: (0, 14, 0.0)
dqn: (0, 14, 0.0)



In [130]:
# Print (Deleted, Total, Ratio) of the three listed runs for a third hand-picked set of tokens.
compare_statistics(
    [("reinforce", "23_04", 1400), ("a2c", "23_04", 1410), ("dqn", "23_04_2", 1140)],
    ["the", "this", "his", "no", "but"],
)

THE
reinforce: (315, 581, 0.5421686746987951)
a2c: (563, 581, 0.9690189328743546)
dqn: (158, 581, 0.2719449225473322)

THIS
reinforce: (11, 27, 0.4074074074074074)
a2c: (27, 27, 1.0)
dqn: (26, 27, 0.9629629629629628)

HIS
reinforce: (16, 27, 0.5925925925925926)
a2c: (27, 27, 1.0)
dqn: (16, 27, 0.5925925925925926)

NO
reinforce: (5, 5, 1.0)
a2c: (5, 5, 1.0)
dqn: (0, 5, 0.0)

BUT
reinforce: (10, 20, 0.5)
a2c: (20, 20, 1.0)
dqn: (18, 20, 0.9)

