# Test deletions statistics extractor

In [2]:
import contextlib
import json
import os
from collections import Counter

import pandas as pd
from torchtext.data import get_tokenizer

## Utils

In [3]:
# Input/output locations, all given relative to the current working directory.
LOGS_PATH = "../logs/"
DATA_PATH = "../generated/datasets"
SAVE_PATH = "../generated/statistics"
SAVE_EXPERIMENT_PATH = "../generated/experiments_statistics"

# The keys of LOGS_DIR, grouped into two lists.
RLS = ["a2c", "reinforce", "dqn"]
BASELINES = ["cnn1d", "cnn2d", "lstm", "bilstm", "lstm_attn", "bilstm_attn"]

# Short model identifier -> directory name joined onto LOGS_PATH (when reading
# experiment logs) and onto SAVE_EXPERIMENT_PATH (when writing statistics).
LOGS_DIR: dict[str, str] = {
    # identifiers listed in RLS
    "a2c": "a2c",
    "reinforce": "reinforce",
    "dqn": "dqn",
    # identifiers listed in BASELINES
    "cnn1d": "cnn_1d",
    "cnn2d": "cnn_2d",
    "lstm": "lstm",
    "bilstm": "lstm_bi",
    "lstm_attn": "lstm_attn",
    "bilstm_attn": "lstm_bi_attn",
}

In [4]:
# Module-level tokenizer shared by calculate_frequencies; built once from
# torchtext's "basic_english" preset.
tokenizer = get_tokenizer("basic_english")

In [5]:
def load_json(path: str) -> dict:
    """Read the JSON document stored at ``path`` and return the parsed value.

    The file is opened explicitly as UTF-8 so that non-ASCII content is decoded
    the same way on every platform instead of depending on the locale's default
    encoding.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The value decoded by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def save_json(path: str, data) -> None:
    """Serialise ``data`` as JSON into the file at ``path``.

    The file is opened in write mode, so any existing content is replaced.
    """
    with open(path, mode="w") as handle:
        json.dump(data, handle)


def soft_mkdir(path: str) -> None:
    """Create directory ``path`` (and missing parents) if it does not exist yet.

    An already existing directory is not an error. Any other failure, such as
    missing permissions or a path component that is a regular file, is raised
    here instead of being silently ignored, so the problem surfaces at its
    source rather than at a later write into the missing directory.

    Args:
        path: Directory to create.

    Raises:
        OSError: If the directory cannot be created.
    """
    os.makedirs(path, exist_ok=True)

In [6]:
def save_df(df: pd.DataFrame, path: str):
    """Write ``df`` to ``path`` as CSV without the index.

    Columns whose name matches the pattern ``^Unnamed`` are left out of the
    written file.
    """
    is_unnamed = df.columns.str.contains("^Unnamed")
    kept = df.loc[:, ~is_unnamed]
    kept.to_csv(path, index=False)

## Calculate test data frequency

In [7]:
def get_dataset_name(dataset_size: str) -> str:
    """Return the test dataset name for ``dataset_size``.

    An empty size maps to ``"test"``; any other value is appended after an
    underscore, e.g. ``"sm"`` gives ``"test_sm"``.
    """
    suffix = "" if dataset_size == "" else f"_{dataset_size}"
    return f"test{suffix}"


def calculate_frequencies(dataset_name: str) -> Counter:
    """Count token occurrences in the ``<dataset_name>.csv`` file under DATA_PATH.

    Every value of the ``target`` column, followed by every value of the
    ``candidate`` column, is passed through the module-level ``tokenizer`` and
    the resulting tokens are tallied.
    """
    csv_path = os.path.join(DATA_PATH, f"{dataset_name}.csv")
    frame = pd.read_csv(csv_path)
    sentences = frame.target.to_list() + frame.candidate.to_list()

    counts: Counter = Counter()
    for sentence in sentences:
        counts.update(tokenizer(sentence))
    return counts


def save_frequencies(dataset_name: str, frequencies: Counter):
    """Store ``frequencies`` as ``<dataset_name>.json`` inside SAVE_PATH.

    SAVE_PATH is created first via ``soft_mkdir``.
    """
    soft_mkdir(SAVE_PATH)
    target_file = os.path.join(SAVE_PATH, f"{dataset_name}.json")
    save_json(target_file, frequencies)


# for size in ["", "sm","md"]:
#   name = get_dataset_name(size)
#   save_frequencies(name, calculate_frequencies(name))

## Calculate and save experiment statistics

In [8]:
def load_experiment_data(model: str, experiment: str, epoch: int) -> tuple[Counter, str]:
    """Load the deletion counts and test-set size of one experiment checkpoint.

    Reads ``configs.json`` from ``LOGS_PATH/<model>/<experiment>`` for its
    ``TEST_SIZE`` entry, then ``best/<epoch>/data/test_deletions.json`` from
    the same directory.

    Returns:
        A pair ``(deletions, test_size)`` where ``deletions`` is the loaded
        JSON wrapped in a ``Counter``.
    """
    run_dir = os.path.join(LOGS_PATH, model, experiment)

    config = load_json(os.path.join(run_dir, "configs.json"))
    test_size = config["TEST_SIZE"]

    deletions_file = os.path.join(
        run_dir, "best", str(epoch), "data", "test_deletions.json"
    )
    deletions = Counter(load_json(deletions_file))
    return deletions, test_size


def load_frequencies(dataset_size: str) -> Counter:
    """Load the token frequencies saved under SAVE_PATH for ``dataset_size``.

    The file name is derived with ``get_dataset_name`` and the loaded JSON is
    returned as a ``Counter``.
    """
    file_name = f"{get_dataset_name(dataset_size)}.json"
    stored = load_json(os.path.join(SAVE_PATH, file_name))
    return Counter(stored)


def build_experiment_statistics(
    statistics: Counter, test_frequencies: Counter
) -> pd.DataFrame:
    """Build a per-token table of deletion counts against test-set frequency.

    One row is produced for every token in ``test_frequencies``. Tokens that
    do not occur in ``statistics`` get a deletion count of 0 (``Counter``
    lookup semantics).

    Args:
        statistics: Number of deletions per token.
        test_frequencies: Number of occurrences per token in the test data.

    Returns:
        A DataFrame with columns ``Token``, ``Deleted``, ``Total`` and
        ``Ratio`` (``Deleted / Total``), ordered by ``Ratio`` descending, then
        ``Total`` descending, then ``Token`` ascending, with a fresh 0-based
        index.
    """
    rows = [
        (
            token,
            statistics[token],
            total,
            # A zero total would otherwise raise ZeroDivisionError; report 0.0.
            statistics[token] / total if total else 0.0,
        )
        for token, total in test_frequencies.items()
    ]
    df = pd.DataFrame(rows, columns=["Token", "Deleted", "Total", "Ratio"])
    # Tokens are unique, so these three keys define a total order and no
    # pre-sorting of the rows is needed.
    return df.sort_values(
        ["Ratio", "Total", "Token"], ascending=[False, False, True]
    ).reset_index(drop=True)


def save_experiment_statistics(
    statistics: pd.DataFrame, model: str, experiment: str, epoch: int
):
    """Write ``statistics`` to ``SAVE_EXPERIMENT_PATH/<model>/<experiment>/<epoch>.csv``.

    The target directory is created first via ``soft_mkdir`` and the table is
    written with ``save_df``.
    """
    target_dir = os.path.join(SAVE_EXPERIMENT_PATH, model, experiment)
    soft_mkdir(target_dir)
    csv_file = os.path.join(target_dir, f"{epoch}.csv")
    save_df(statistics, csv_file)


def load_save_pipeline(model: str, experiment: str, epoch: int) -> pd.DataFrame:
    """Build, save and return the deletion statistics for one experiment epoch.

    ``model`` is a key of LOGS_DIR; the mapped directory name is used both for
    reading the experiment logs and for saving the resulting table.
    """
    model_dir = LOGS_DIR[model]
    deletions, test_size = load_experiment_data(model_dir, experiment, epoch)
    frequencies = load_frequencies(test_size)

    table = build_experiment_statistics(deletions, frequencies)
    save_experiment_statistics(table, model_dir, experiment, epoch)
    return table

In [9]:
# Run the pipeline for a single (model, experiment, epoch) combination.
# The commented-out calls below are other combinations kept for reference.
# load_save_pipeline("a2c", "23_04",1410)
# load_save_pipeline("reinforce", "23_04",1400)
# load_save_pipeline("dqn", "23_04",460)
# load_save_pipeline("dqn", "23_04_",1000)
load_save_pipeline("dqn", "23_04_2", 1140)

Unnamed: 0,Token,Deleted,Total,Ratio
0,alan,19,19,1.0
1,curbishley,19,19,1.0
2,billion,18,18,1.0
3,friday,18,18,1.0
4,turnpike,18,18,1.0
...,...,...,...,...
1139,vibrant,0,1,0.0
1140,warm,0,1,0.0
1141,watch,0,1,0.0
1142,watched,0,1,0.0
