In [None]:
import json
import pandas as pd
import os
import random
import itertools
import pandas as pd
import numpy as np
import re
import plotly.express as px
import plotly.graph_objects as go
from scipy import stats

In [None]:
## Parameters
# PATHs
RESULTS_DIR = "../results"

# Reproducibility: seed the `random` module (used to pick a sample explanation)
# so Restart & Run All shows the same example every time.
SEED = 42
random.seed(SEED)

# Plotting: default figure width/height (multiplied by `scale` in each cell)
width = 800
height = 600
scale = 1

In [None]:
## Miscellaneous Functions
# START: COPIED FROM https://stackoverflow.com/questions/30069846/how-to-find-out-chinese-or-japanese-character-in-a-string-in-python
ranges = [
    {"from": ord("\u3300"), "to": ord("\u33ff")},  # compatibility ideographs
    {"from": ord("\ufe30"), "to": ord("\ufe4f")},  # compatibility ideographs
    {"from": ord("\uf900"), "to": ord("\ufaff")},  # compatibility ideographs
    {"from": ord("\U0002F800"), "to": ord("\U0002fa1f")},  # compatibility ideographs
    {"from": ord("\u3040"), "to": ord("\u309f")},  # Japanese Hiragana
    {"from": ord("\u30a0"), "to": ord("\u30ff")},  # Japanese Katakana
    {"from": ord("\u2e80"), "to": ord("\u2eff")},  # cjk radicals supplement
    {"from": ord("\u4e00"), "to": ord("\u9fff")},  # CJK Unified Ideographs
    {"from": ord("\u3400"), "to": ord("\u4dbf")},  # CJK Extension A
    {"from": ord("\U00020000"), "to": ord("\U0002a6df")},  # CJK Extension B
    {"from": ord("\U0002a700"), "to": ord("\U0002b73f")},  # CJK Extension C
    {"from": ord("\U0002b740"), "to": ord("\U0002b81f")},  # CJK Extension D
    {"from": ord("\U0002b820"), "to": ord("\U0002ceaf")},  # included as of Unicode 8.0
]


def is_cjk(char):
    """Return True if the single character ``char`` falls in one of the code-point ranges in ``ranges``."""
    code_point = ord(char)
    # Generator (not a list) so `any` can stop at the first matching block;
    # `block` avoids shadowing the builtin `range`.
    return any(block["from"] <= code_point <= block["to"] for block in ranges)


def cjk_substrings(string):
    """Yield every maximal run of consecutive CJK characters in ``string``, left to right.

    A run that reaches the end of ``string`` is yielded as well. The copied
    index-based loop raised IndexError in that case.
    """
    # Adapted from the Stack Overflow answer: groupby splits the string into
    # alternating CJK / non-CJK runs; keep only the CJK ones.
    for in_cjk_run, chars in itertools.groupby(string, key=is_cjk):
        if in_cjk_run:
            yield "".join(chars)


# END: COPIED FROM https://stackoverflow.com/questions/30069846/how-to-find-out-chinese-or-japanese-character-in-a-string-in-python


def set_default(obj):
    """``default=`` hook for ``json.dump``/``json.dumps`` that serializes sets as lists.

    Args:
        obj: An object the json encoder could not serialize natively.

    Returns:
        list: The elements of ``obj`` (in set iteration order) when it is a
        ``set`` or ``frozenset``.

    Raises:
        TypeError: For any other type, as the json ``default`` protocol expects.
    """
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def longest_continuous_common_subsequence(list1, list2):
    """Length of the longest contiguous run of equal elements shared by two sequences.

    This is the classic "longest common substring" problem, applied to
    arbitrary sequences such as word lists. It uses dynamic programming in
    O(len(list1) * len(list2)) time and O(len(list2)) extra space. The previous
    brute force re-scanned every candidate run from scratch.

    Args:
        list1: First sequence (only iterated).
        list2: Second sequence (iterated and measured with ``len``).

    Returns:
        int: Length of the longest common contiguous run; 0 if either is empty.
    """
    best = 0
    # prev[j] = length of the common run ending at the previous list1 element
    # and at list2[j - 1]; index 0 is a sentinel so prev[j - 1] is always valid.
    prev = [0] * (len(list2) + 1)
    for item in list1:
        curr = [0] * (len(list2) + 1)
        for j, other in enumerate(list2, start=1):
            if item == other:
                curr[j] = prev[j - 1] + 1
                if curr[j] > best:
                    best = curr[j]
        prev = curr
    return best


def remove_text_inside_quotes(text):
    """Return ``text`` with every double-quoted span removed, quote marks included.

    Matching is non-greedy, so ``'"a" b "c"'`` becomes ``' b '``. Because ``.``
    does not match a newline, a quote whose closing mark is on a later line is
    left untouched.
    """
    quoted_span = re.compile(r"\".*?\"")
    return quoted_span.sub("", text)

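## Correlation Analysis (Section 6.1)

Spearman rank correlation between each automated metric (ROUGE-L, METEOR, BLEURT) and each prompt-based score (Comparison, Fluency, Accuracy), pooled over the three mT0 model sizes.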

In [None]:
## Load performance metrics
# Full Performance Scores: one dict per mT0 model size. Every metric maps a row
# index (0-6) to the score of the training setup named in DATA_LABELS.
DATA_LABELS = {
    0: "One-Shot ($1$)",
    1: "Few-Shot ($10$)",
    2: "Few-Shot ($100$)",
    3: "Few-Shot ($1000$)",
    4: "Full-Shot ($2473$)",
    5: "English ($2473$)",  # the large table used to say just "English"
    6: "Full ($4946$)",
}
tab_large = {
    "data": dict(DATA_LABELS),
    "rouge": {0: 0.161, 1: 0.18, 2: 0.421, 3: 0.451, 4: 0.461, 5: 0.454, 6: 0.536},
    "meteor": {0: 0.116, 1: 0.135, 2: 0.258, 3: 0.376, 4: 0.388, 5: 0.399, 6: 0.489},
    "bleurt": {0: -0.631, 1: -0.597, 2: -0.21, 3: 0.057, 4: 0.04, 5: 0.057, 6: 0.161},
    "comparison": {0: 0.822, 1: 0.857, 2: 0.962, 3: 0.984, 4: 0.984, 5: 0.981, 6: 0.978},
    "fluency": {0: 1.073, 1: 1.092, 2: 3.72, 3: 4.022, 4: 4.067, 5: 4.221, 6: 3.867},
    "accuracy": {0: 3.981, 1: 4.011, 2: 4.642, 3: 4.704, 4: 4.747, 5: 4.771, 6: 4.763},
    "model": dict.fromkeys(DATA_LABELS, "large"),
}
tab_base = {
    "data": dict(DATA_LABELS),
    "rouge": {0: 0.184, 1: 0.178, 2: 0.377, 3: 0.438, 4: 0.45, 5: 0.458, 6: 0.499},
    "meteor": {0: 0.151, 1: 0.152, 2: 0.315, 3: 0.353, 4: 0.366, 5: 0.398, 6: 0.448},
    "bleurt": {0: -0.552, 1: -0.578, 2: -0.014, 3: 0.025, 4: 0.032, 5: 0.072, 6: 0.122},
    "comparison": {0: 0.838, 1: 0.819, 2: 0.962, 3: 0.981, 4: 0.984, 5: 0.984, 6: 0.994},
    "fluency": {0: 1.256, 1: 1.181, 2: 3.929, 3: 4.073, 4: 3.99, 5: 4.13, 6: 4.25},
    "accuracy": {0: 3.871, 1: 3.989, 2: 4.642, 3: 4.717, 4: 4.712, 5: 4.771, 6: 4.733},
    "model": dict.fromkeys(DATA_LABELS, "base"),
}
tab_small = {
    "data": dict(DATA_LABELS),
    "rouge": {0: 0.216, 1: 0.227, 2: 0.306, 3: 0.421, 4: 0.429, 5: 0.444, 6: 0.468},
    "meteor": {0: 0.203, 1: 0.22, 2: 0.235, 3: 0.325, 4: 0.333, 5: 0.363, 6: 0.403},
    "bleurt": {0: -0.48, 1: -0.336, 2: -0.183, 3: 0.001, 4: 0.009, 5: 0.045, 6: 0.083},
    "comparison": {0: 0.852, 1: 0.881, 2: 0.884, 3: 0.946, 4: 0.962, 5: 0.968, 6: 0.976},
    "fluency": {0: 3.232, 1: 3.288, 2: 3.536, 3: 3.604, 4: 3.773, 5: 3.855, 6: 3.841},
    "accuracy": {0: 4.051, 1: 4.105, 2: 4.057, 3: 4.553, 4: 4.604, 5: 4.714, 6: 4.73},
    "model": dict.fromkeys(DATA_LABELS, "small"),
}

# Load as Pandas Dataframe
tab_large = pd.DataFrame(tab_large)
tab_base = pd.DataFrame(tab_base)
tab_small = pd.DataFrame(tab_small)

# Columns kept in the combined table -> display name (insertion order = column order).
METRIC_COLUMNS = {
    "rouge": "ROUGE-L",
    "meteor": "METEOR",
    "bleurt": "BLEURT",
    "comparison": "Comparison",
    "fluency": "Fluency",
    "accuracy": "Accuracy",
    "model": "Model",
}

# Stack the three model sizes. ignore_index yields a fresh 0..n-1 index without
# the stray "index" column that reset_index() would add; renaming by name (not
# by position) keeps headers correct if the column selection changes.
table = pd.concat([tab_large, tab_base, tab_small], ignore_index=True)[
    list(METRIC_COLUMNS)
].rename(columns=METRIC_COLUMNS)

In [None]:
## Generate Plots
# Miscellaneous (color map and metric names to loop through)
color_map = {"large": "red", "base": "blue", "small": "green"}
colors = table["Model"].map(color_map)
automated = ["ROUGE-L", "METEOR", "BLEURT"]
prompting = ["Comparison", "Fluency", "Accuracy"]

# Figure size for this cell only. These are separate names so the notebook-wide
# `width`/`height`/`scale` used by later plotting cells are not overwritten.
corr_width, corr_height, corr_scale = 600, 400, 1
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)

# One scatter per (automated metric, prompt metric) pair, colored by model size
# with an OLS trendline. The title reports Spearman's rho over all rows of `table`.
for a_m, p_m in itertools.product(automated, prompting):
    correlation = np.round(stats.spearmanr(table[a_m], table[p_m]).statistic, 3)
    fig = px.scatter(
        table,
        x=a_m,
        y=p_m,
        color="Model",
        color_discrete_map=color_map,
        trendline="ols",
    )

    fig.update_layout(
        width=corr_width * corr_scale,
        height=corr_height * corr_scale,
        font=dict(size=20 * corr_scale),
        xaxis_title=a_m,
        yaxis_title=f"Prompt {p_m}",
        margin=dict(l=0, r=10, t=50, b=0),
        legend=dict(x=0.8, y=0.05, bordercolor="Black", borderwidth=1),
        title=f"Spearman Correlation: {correlation}",
        title_x=0.5,
        plot_bgcolor="white",
    )
    fig.update_traces(marker=dict(size=12), selector=dict(mode="markers"))
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    fig.show()
    fig.write_image(f"{a_m}_{p_m}.pdf")

## Explanation Statistics (Section 6)

In [None]:
## Get data statistics
# Load CSV files: fluency_results.csv of the Japanese 10-shot run of each mT0 size.
# Each model size reads its own run folder.
JA_FEW_SHOT_RUN = "mt0-{size}-language=ja-epochs=15-batch_size=4-shots=10"
ja_results = {
    size: pd.read_csv(
        os.path.join(RESULTS_DIR, JA_FEW_SHOT_RUN.format(size=size), "fluency_results.csv")
    )
    for size in ("large", "base", "small")
}
gold = ja_results["base"]["gold_output"]
large = ja_results["large"]["output"]
base = ja_results["base"]["output"]
small = ja_results["small"]["output"]

# Get lengths of explanations: whitespace-token count minus one.
# NOTE(review): "- 1" presumably discounts a leading label token, and Japanese
# text without spaces collapses to very few tokens here — confirm this measure.
gold_len = [len(i.split()) - 1 for i in gold]
large_len = [len(i.split()) - 1 for i in large]
base_len = [len(i.split()) - 1 for i in base]
small_len = [len(i.split()) - 1 for i in small]

In [None]:
## Generate plots
# One histogram trace per explanation source, in legend order.
length_sources = [
    (gold_len, "Ground Truth"),
    (large_len, "large"),
    (base_len, "base"),
    (small_len, "small"),
]
fig = go.Figure(
    data=[go.Histogram(x=lengths, name=label) for lengths, label in length_sources]
)

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Explanation Length",
    yaxis_title="Count",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.75, y=0.97, bordercolor="Black", borderwidth=1),
    plot_bgcolor="white",
)
fig.update_traces(marker={"opacity": 0.85})

# Boxed axes with light grid, applied identically to x and y.
for style_axes in (fig.update_xaxes, fig.update_yaxes):
    style_axes(
        mirror=True,
        ticks="outside",
        showline=True,
        linecolor="black",
        gridcolor="lightgrey",
    )

fig.show()
fig.write_image("explanation_lengths.pdf")

## Get Sample Explanations (Section 6)

In [None]:
# Load the English run of each mT0 size once.
# NOTE(review): assumes every run's rows are aligned on the same test examples,
# since a single row index is used across all of them.
EN_RUN = "mt0-{size}-language=en-epochs=15-batch_size=4"
en_results = {
    size: pd.read_csv(
        os.path.join(RESULTS_DIR, EN_RUN.format(size=size), "fluency_results.csv")
    )
    for size in ("large", "base", "small")
}
text = en_results["base"]["input"]
gold = en_results["base"]["gold_output"]
large = en_results["large"]["output"]
base = en_results["base"]["output"]
small = en_results["small"]["output"]

# randrange excludes len(text); randint(0, len(text)) could pick a row past the end.
index = random.randrange(len(text))

print("# Input: ", text[index])
print("# Ground Truth: ", gold[index])
print("# mt0-large: ", large[index])
print("# mt0-base: ", base[index])
print("# mt0-small: ", small[index])

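## Word Overlap Count (Section 6.2)

For every run, measure the length (in words) of the longest contiguous word sequence shared by the input text and the generated explanation. Compare it with the same measure for the ground-truth explanations.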

In [None]:
## Get statistics
# model_to_overlap[model_size][condition] -> one count per test example: the
# length of the longest contiguous word run shared by the input and the generated
# output (lower-cased, first whitespace token of each text skipped).
model_to_overlap = {}


def _results_csv(run_dir):
    """Return the results CSV inside a run folder, preferring fluency_results.csv."""
    file_names = sorted(os.listdir(run_dir))
    chosen = "fluency_results.csv" if "fluency_results.csv" in file_names else file_names[0]
    return os.path.join(run_dir, chosen)


def _overlap_counts(frame, column="output"):
    """Longest input/`column` word overlap for every row of `frame`."""
    return [
        longest_continuous_common_subsequence(
            row["input"].lower().split()[1:], row[column].lower().split()[1:]
        )
        for _, row in frame.iterrows()
    ]


# sorted() makes the condition order (and the `df` left for the next cell) deterministic.
for folder in sorted(os.listdir(RESULTS_DIR)):
    if "mt0" not in folder:
        continue
    if "shots" in folder:
        # e.g. "mt0-base-language=ja-epochs=15-batch_size=4-shots=10" -> "10"
        label_name = folder.split("=")[-1]
    elif "language=en" in folder:
        label_name = (
            "Native (Full Split)" if "arguments" in folder else "Native (English Split)"
        )
    else:
        continue

    # e.g. "mt0-base-..." -> "base"; parsed from the folder name, so hyphens in
    # RESULTS_DIR or OS path separators cannot shift the fields.
    model_size = folder.split("-")[1]
    df = pd.read_csv(_results_csv(os.path.join(RESULTS_DIR, folder)))
    model_to_overlap.setdefault(model_size, {})[label_name] = _overlap_counts(df)

In [None]:
## Generate Box Plots for overlap statistics
fig = go.Figure()

models = ["small", "base", "large"]
shots = ["1", "10", "100", "1000", "full", "Native"]  # not referenced below

# One box trace per model size. Each condition (shot count / native split)
# becomes an x category repeated once per example of that condition, so the
# category labels always line up with the counts (no fixed 100-example size).
for model_size in models:
    overlap_by_condition = model_to_overlap[model_size]
    overlap_counts = list(itertools.chain.from_iterable(overlap_by_condition.values()))
    conditions = list(
        itertools.chain.from_iterable(
            [condition] * len(counts) for condition, counts in overlap_by_condition.items()
        )
    )
    fig.add_trace(go.Box(y=overlap_counts, x=conditions, name=f"{model_size}"))

# Same measure for the ground-truth explanations. The unused overlap ratio,
# which divided by the explanation length and failed on empty explanations,
# is no longer computed.
# NOTE(review): `df` is whichever results CSV the previous cell read last;
# confirm it is the intended reference split.
ground_truth_overlap = [
    longest_continuous_common_subsequence(
        row["input"].lower().split()[1:], row["gold_output"].lower().split()[1:]
    )
    for _, row in df.iterrows()
]
fig.add_trace(
    go.Box(
        y=ground_truth_overlap,
        x=["Ground Truth"] * len(ground_truth_overlap),
        name="Ground Truth",
    )
)

axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Number of Shots",
    # The boxes show word counts (range 0-20), not ratios.
    yaxis_title="Longest Overlap with Input (Words)",
    yaxis_range=[0, 20],
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.84, y=0.97, bordercolor="Black", borderwidth=1),
    plot_bgcolor="white",
    boxmode="group",
    boxgroupgap=0.2,
    boxgap=0,
)
fig.update_traces(orientation="v")

fig.show()
fig.write_image("word_overlap.pdf")

## Dataset Statistics (Section 2)

In [None]:
## Get text lengths for English split
en_train = pd.read_csv("../data/training_lang=en-data=split_en.csv")

# Word counts. Text = part of the prompt after "[TEXT]: " and before
# "[EMOTION]: "; explanation = part of the target after "[RATIONALE]: ".
en_train_input = [
    len(prompt.split("[TEXT]: ")[-1].split("[EMOTION]: ")[0].split())
    for prompt in en_train["input"]
]
en_train_output = [
    len(target.split("[RATIONALE]: ")[-1].split())
    for target in en_train["output"]
]

fig = go.Figure(
    data=[
        go.Histogram(x=lengths, name=label)
        for lengths, label in ((en_train_input, "Text"), (en_train_output, "Explanation"))
    ]
)
fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Document Length",
    yaxis_title="Count",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.67, y=0.97, bordercolor="Black", borderwidth=1),
    barmode="overlay",
    plot_bgcolor="white",
)
fig.update_traces(overwrite=True, marker={"opacity": 0.7})

# Boxed axes with light grid, applied identically to x and y.
for style_axes in (fig.update_xaxes, fig.update_yaxes):
    style_axes(
        mirror=True,
        ticks="outside",
        showline=True,
        linecolor="black",
        gridcolor="lightgrey",
    )

fig.show()
fig.write_image("en_train.pdf")

In [None]:
## Get text lengths for Japanese split
ja_train = pd.read_csv("../data/training_lang=ja-data=split_ja.csv")

# Lengths are string lengths in characters (no word splitting), unlike the
# English cell. Text = part after "[文章]: " and before "[感情]: ";
# explanation = part after "[説明]: ".
ja_train_input = [
    len(prompt.split("[文章]: ")[-1].split("[感情]: ")[0])
    for prompt in ja_train["input"]
]
ja_train_output = [
    len(target.split("[説明]: ")[-1])
    for target in ja_train["output"]
]

fig = go.Figure(
    data=[
        go.Histogram(x=lengths, name=label)
        for lengths, label in ((ja_train_input, "Text"), (ja_train_output, "Explanation"))
    ]
)
fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Document Length",
    yaxis_title="Count",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.67, y=0.97, bordercolor="Black", borderwidth=1),
    barmode="overlay",
    plot_bgcolor="white",
)
fig.update_traces(overwrite=True, marker={"opacity": 0.7})

# Boxed axes with light grid, applied identically to x and y.
for style_axes in (fig.update_xaxes, fig.update_yaxes):
    style_axes(
        mirror=True,
        ticks="outside",
        showline=True,
        linecolor="black",
        gridcolor="lightgrey",
    )

fig.show()
fig.write_image("ja_train.pdf")

In [None]:
## Get text lengths for untranslated Japanese split
# Read both JSON files inside `with` so the handles are closed, and decode them
# as UTF-8 (the JSON standard) instead of the platform-default encoding.
with open("../data/lang=en-data=full.json", "r", encoding="utf-8") as full_file:
    json_data = json.load(full_file)
with open("../data/lang=en-data=split_en.json", "r", encoding="utf-8") as split_en_file:
    json_en_data = json.load(split_en_file)

# Records of the full dataset that are absent from the English split
# (list membership test: quadratic in the number of records).
ja_untranslated = [record for record in json_data if record not in json_en_data]
ja_untranslated = pd.DataFrame(ja_untranslated)[["text", "choice", "explanation"]]

# Whitespace-token counts.
ja_untranslated_train_input = [len(i.split()) for i in ja_untranslated["text"]]
ja_untranslated_train_output = [len(i.split()) for i in ja_untranslated["explanation"]]

fig = go.Figure()
fig.add_trace(go.Histogram(x=ja_untranslated_train_input, name="Text"))
fig.add_trace(go.Histogram(x=ja_untranslated_train_output, name="Explanation"))

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Document Length",
    yaxis_title="Count",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.67, y=0.97, bordercolor="Black", borderwidth=1),
)

fig.update_layout(barmode="overlay", plot_bgcolor="white")
fig.update_traces(overwrite=True, marker={"opacity": 0.7})

fig.update_xaxes(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_yaxes(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)


fig.show()
fig.write_image("ja_untranslated_train.pdf")

## 