In [None]:
import json
import re

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Mean of the `fluency` column, rounded to 3 decimals. Point the path at the
# "-shots=full" results directory to evaluate that run instead.
df = pd.read_csv(
    "../results/mt0-small-language=ja-epochs=15-batch_size=4/fluency_results.csv"
)

fluency_values = df["fluency"].tolist()
np.round(np.mean(fluency_values), 3)

In [None]:
from astropy.table import Table

# Per-model result tables exported as LaTeX, one file per model size.
tab_large = Table.read("mt0-large.tex").to_pandas()
tab_base = Table.read("mt0-base.tex").to_pandas()
tab_small = Table.read("mt0-small.tex").to_pandas()

# Metric columns shared by all three tables, in file order (after "data").
metric_columns = ["rouge", "meteor", "bleurt", "comparison", "fluency", "accuracy"]
model_tables = [("large", tab_large), ("base", tab_base), ("small", tab_small)]

for model_name, model_table in model_tables:
    # Rename positionally, then tag every row with the model it came from.
    model_table.columns = ["data", *metric_columns]
    model_table["model"] = model_name

# Stack the three tables (dropping "data") under a fresh 0..n-1 index.
table = pd.concat(
    [model_table[[*metric_columns, "model"]] for _, model_table in model_tables],
    ignore_index=True,
)

# Display names used as axis labels in the plots below.
table.columns = [
    "ROUGE-L",
    "METEOR",
    "BLEURT",
    "Comparison",
    "Fluency",
    "Accuracy",
    "Model",
]

In [None]:
from scipy import stats

# Spearman rank correlation of ROUGE-L against each prompt-based metric,
# printed in the order Comparison, Fluency, Accuracy.
rouge_correlations = [
    np.round(stats.spearmanr(table["ROUGE-L"], table[prompt_metric]).statistic, 3)
    for prompt_metric in ("Comparison", "Fluency", "Accuracy")
]
print(*rouge_correlations)

In [None]:
# NOTE(review): color_map / colors are never passed to px.scatter, so the
# figures use plotly's default palette. Kept so the namespace is unchanged.
color_map = {"large": "red", "base": "blue", "small": "green"}
colors = table["Model"].map(color_map)

automated = ["ROUGE-L", "METEOR", "BLEURT"]
prompting = ["Comparison", "Fluency", "Accuracy"]

# Figure geometry and axis styling are the same for every plot: define once.
width = 600
height = 400
scale = 1
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)

# One scatter plot (with per-model OLS trendlines) for every pairing of an
# automated metric with a prompt-based metric; each is shown and saved as PDF.
for a_m in automated:
    for p_m in prompting:
        correlation = np.round(stats.spearmanr(table[a_m], table[p_m]).statistic, 3)

        # px.scatter builds the complete figure; no empty go.Figure() is needed.
        fig = px.scatter(table, x=a_m, y=p_m, color="Model", trendline="ols")

        fig.update_layout(
            width=width * scale,
            height=height * scale,
            font=dict(size=20 * scale),
            xaxis_title=a_m,
            yaxis_title=f"Prompt {p_m}",
            margin=dict(l=0, r=10, t=50, b=0),
            legend=dict(x=0.8, y=0.05, bordercolor="Black", borderwidth=1),
        )

        # Only the marker traces are enlarged; trendline traces are untouched.
        fig.update_traces(marker=dict(size=12), selector=dict(mode="markers"))

        fig.update_layout(
            title=f"Spearman Correlation: {correlation}",
            title_x=0.5,
            barmode="overlay",
            plot_bgcolor="white",
        )

        fig.update_xaxes(**axis_style)
        fig.update_yaxes(**axis_style)

        fig.show()
        fig.write_image(f"{a_m}_{p_m}.pdf")

In [None]:
# NOTE(review): color_map / colors are never passed to px.scatter, so the
# figure uses plotly's default palette. Kept so the namespace is unchanged.
color_map = {"large": "red", "base": "blue", "small": "green"}
colors = table["Model"].map(color_map)

width = 600
height = 400
scale = 1

# METEOR vs. prompt-based comparison score, one colour and OLS trendline per
# model. px.scatter builds the whole figure, so no empty go.Figure() is
# created first.
fig = px.scatter(
    table,
    x="METEOR",
    y="Comparison",
    color="Model",
    trendline="ols",
)

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="METEOR",
    yaxis_title="Prompt Comparison",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.8, y=0.05, bordercolor="Black", borderwidth=1),
)

# Only the marker traces are enlarged; trendline traces are untouched.
fig.update_traces(marker=dict(size=12), selector=dict(mode="markers"))

fig.update_layout(barmode="overlay", plot_bgcolor="white")

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()

In [None]:
# NOTE(review): color_map / colors are never passed to px.scatter, so the
# figure uses plotly's default palette. Kept so the namespace is unchanged.
color_map = {"large": "red", "base": "blue", "small": "green"}
colors = table["Model"].map(color_map)

width = 600
height = 400
scale = 1

# BLEURT vs. prompt-based comparison score, one colour and OLS trendline per
# model. px.scatter builds the whole figure, so no empty go.Figure() is
# created first.
fig = px.scatter(
    table,
    x="BLEURT",
    y="Comparison",
    color="Model",
    trendline="ols",
)

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="BLEURT",
    yaxis_title="Prompt Comparison",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.8, y=0.05, bordercolor="Black", borderwidth=1),
)

# Only the marker traces are enlarged; trendline traces are untouched.
fig.update_traces(marker=dict(size=12), selector=dict(mode="markers"))

fig.update_layout(barmode="overlay", plot_bgcolor="white")

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()

In [None]:
# NOTE(review): color_map / colors are never passed to px.scatter, so the
# figure uses plotly's default palette. Kept so the namespace is unchanged.
color_map = {"large": "red", "base": "blue", "small": "green"}
colors = table["Model"].map(color_map)

width = 600
height = 400
scale = 1

# ROUGE-L vs. prompt-based comparison score, one colour and OLS trendline per
# model. px.scatter builds the whole figure, so no empty go.Figure() is
# created first.
fig = px.scatter(
    table,
    x="ROUGE-L",
    y="Comparison",
    color="Model",
    trendline="ols",
)

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="ROUGE-L",
    yaxis_title="Prompt Comparison",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.8, y=0.05, bordercolor="Black", borderwidth=1),
)

# Only the marker traces are enlarged; trendline traces are untouched.
fig.update_traces(marker=dict(size=12), selector=dict(mode="markers"))

fig.update_layout(barmode="overlay", plot_bgcolor="white")

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()
# Written once; the export was previously repeated three times to the same file.
fig.write_image("rouge_comparison.pdf")

In [None]:
# NOTE(review): color_map / colors are never passed to px.scatter, so the
# figure uses plotly's default palette. Kept so the namespace is unchanged.
color_map = {"large": "red", "base": "blue", "small": "green"}
colors = table["Model"].map(color_map)

width = 600
height = 400
scale = 1

# ROUGE-L vs. prompt-based fluency score, one colour and OLS trendline per
# model. px.scatter builds the whole figure, so no empty go.Figure() is
# created first.
fig = px.scatter(
    table,
    x="ROUGE-L",
    y="Fluency",
    color="Model",
    trendline="ols",
)

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="ROUGE-L",
    yaxis_title="Prompt Fluency",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.8, y=0.05, bordercolor="Black", borderwidth=1),
)

# Only the marker traces are enlarged; trendline traces are untouched.
fig.update_traces(marker=dict(size=12), selector=dict(mode="markers"))

fig.update_layout(barmode="overlay", plot_bgcolor="white")

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()
# Written once; the export was previously repeated three times to the same file.
fig.write_image("rouge_fluency.pdf")

In [None]:
# NOTE(review): color_map / colors are never passed to px.scatter, so the
# figure uses plotly's default palette. Kept so the namespace is unchanged.
color_map = {"large": "red", "base": "blue", "small": "green"}
colors = table["Model"].map(color_map)

width = 600
height = 400
scale = 1

# ROUGE-L vs. prompt-based accuracy score, one colour and OLS trendline per
# model. px.scatter builds the whole figure, so no empty go.Figure() is
# created first.
fig = px.scatter(
    table,
    x="ROUGE-L",
    y="Accuracy",
    color="Model",
    trendline="ols",
)

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="ROUGE-L",
    yaxis_title="Prompt Accuracy",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.8, y=0.05, bordercolor="Black", borderwidth=1),
)

# Only the marker traces are enlarged; trendline traces are untouched.
fig.update_traces(marker=dict(size=12), selector=dict(mode="markers"))

fig.update_layout(barmode="overlay", plot_bgcolor="white")

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()
fig.write_image("rouge_accuracy.pdf")

In [None]:
# Collect one fluency score per row of the 10-shot mt0-small Japanese run and
# print how many rows were scored plus their mean (3 decimals).
# NOTE: uses `is_cjk`, which is defined in a later cell of this notebook; run
# that cell first on a fresh kernel.
df = pd.read_csv(
    "../results/mt0-small-language=ja-epochs=15-batch_size=4-shots=10/fluency_results.csv"
)

filtered_scores = []
for _, row in df.iterrows():
    # Strip the rationale markers so their "1"/"2" are not read as scores.
    text = row["fluency"].replace("[RATIONALE 1]", "").replace("[RATIONALE 2]", "")
    found = re.findall(r"\d+", text)

    if found:
        # Compare numerically: min() on the raw digit strings is lexicographic
        # (e.g. "10" sorts before "9") and would pick the wrong value.
        score = min(int(number) for number in found)
    elif any(is_cjk(char) for char in row["output"]):
        # No number in the rating, but the generation contains CJK text: score 0.
        # Checking characters directly avoids the IndexError that
        # cjk_substrings can raise when the text ends in a CJK character.
        score = 0
    else:
        # No number and no CJK text in the generation: leave the row out.
        continue

    filtered_scores.append(score)

print(len(filtered_scores), np.round(np.mean(filtered_scores), 3))

In [None]:
# Display the per-row scores collected by the previous cell.
filtered_scores

In [None]:
# Turn every record of the full English dataset into a prompt / target pair.
# The file is opened in a context manager so the handle is closed promptly.
with open("../data/lang=en-data=full.json", "r") as handle:
    data = json.load(handle)

inputs = []
outputs = []

for item in data:
    # NOTE(review): the template has a single placeholder, so item["choice"] is
    # passed but never rendered. Confirm the emotion is meant to be left out.
    inputs.append("[TEXT]: {} [EMOTION]: ".format(item["text"], item["choice"]))
    outputs.append("[RATIONALE]: {}".format(item["explanation"]))

In [None]:
# Two-column frame pairing each prompt with its target rationale.
df = pd.DataFrame({"input": inputs, "output": outputs})

In [None]:
# Write the current `df` to CSV without the index column.
df.to_csv("../data/lang=en-data=full.csv", index=False)

In [None]:
# Dump the raw JSON records to CSV as-is (one column per JSON key).
# NOTE(review): this writes to the same path as the `df.to_csv(...)` cell
# above and therefore replaces that file. Confirm which version is wanted.
with open("../data/lang=en-data=full.json", "r") as handle:
    pd.DataFrame(json.load(handle)).to_csv(
        "../data/lang=en-data=full.csv", index=False
    )

In [None]:
# Share of "YES" and "NO" verdicts in the comparison results of this run.
comparison = pd.read_csv(
    "../results/mt0-small-language=en-epochs=15-batch_size=4-arguments=full_data/comparison_results.csv"
)

total_rows = len(comparison)
yes_share = sum(comparison["comparison"] == "YES") / total_rows
no_share = sum(comparison["comparison"] == "NO") / total_rows

print(yes_share, no_share)

In [None]:
# -*- coding:utf-8 -*-
# Inclusive code-point ranges that `is_cjk` treats as CJK.
ranges = [
    {"from": ord("\u3300"), "to": ord("\u33ff")},  # compatibility ideographs
    {"from": ord("\ufe30"), "to": ord("\ufe4f")},  # compatibility ideographs
    {"from": ord("\uf900"), "to": ord("\ufaff")},  # compatibility ideographs
    {"from": ord("\U0002F800"), "to": ord("\U0002fa1f")},  # compatibility ideographs
    {"from": ord("\u3040"), "to": ord("\u309f")},  # Japanese Hiragana
    {"from": ord("\u30a0"), "to": ord("\u30ff")},  # Japanese Katakana
    {"from": ord("\u2e80"), "to": ord("\u2eff")},  # cjk radicals supplement
    {"from": ord("\u4e00"), "to": ord("\u9fff")},
    {"from": ord("\u3400"), "to": ord("\u4dbf")},
    {"from": ord("\U00020000"), "to": ord("\U0002a6df")},
    {"from": ord("\U0002a700"), "to": ord("\U0002b73f")},
    {"from": ord("\U0002b740"), "to": ord("\U0002b81f")},
    {"from": ord("\U0002b820"), "to": ord("\U0002ceaf")},  # included as of Unicode 8.0
]


def is_cjk(char):
    """Return True if the single character `char` lies in any entry of `ranges`."""
    code_point = ord(char)
    # Loop variable is not called `range`, so the builtin is not shadowed.
    return any(bounds["from"] <= code_point <= bounds["to"] for bounds in ranges)


def cjk_substrings(string):
    """Yield each maximal run of consecutive CJK characters in `string`.

    Runs are yielded left to right. Nothing is yielded for an empty string or
    a string without CJK characters.
    """
    i = 0
    length = len(string)
    while i < length:
        if is_cjk(string[i]):
            start = i
            # Check the bound first: a run reaching the end of the string
            # previously indexed past it and raised IndexError.
            while i < length and is_cjk(string[i]):
                i += 1
            yield string[start:i]
        # Either string[i] was not CJK, or it is the non-CJK character that
        # ended the run just yielded; skip it in both cases.
        i += 1

In [None]:
comparison = pd.read_csv(
    "../results/mt0-small-language=en-epochs=15-batch_size=4/comparison_results.csv"
)

# Force the verdict to "NO" for every row whose generation still contains a
# CJK character once the "[説明]:" marker is removed. Testing characters with
# `is_cjk` gives the same rows as before without relying on cjk_substrings
# raising IndexError for text that ends in a CJK character.
contains_cjk = comparison["output"].map(
    lambda text: any(is_cjk(char) for char in text.replace("[説明]:", ""))
)
comparison.loc[contains_cjk, "comparison"] = "NO"

# (share of "YES", share of "NO"), each rounded to 3 decimals.
np.round(sum(comparison["comparison"] == "YES") / len(comparison), 3), np.round(
    sum(comparison["comparison"] == "NO") / len(comparison), 3
)

In [None]:
# `automated` is the list of metric names defined in the correlation-plot
# cell, so indexing it with "bleurt" raises TypeError. Load the scores instead.
# NOTE(review): the path comes from the commented-out load earlier in the
# notebook. Confirm this is the intended run.
automated_results = pd.read_csv(
    "../results/mt0-base-language=en-epochs=15-batch_size=4-arguments=full_data/results_all.csv"
)
np.mean(automated_results["bleurt"])

In [None]:
file_dir = "mt0-small-language=en-epochs=15-batch_size=4-arguments=full_data"

# Read the results once (previously the same CSV was parsed three times) and
# print the mean BLEURT, ROUGE and METEOR, each rounded to 3 decimals.
results_all = pd.read_csv(f"../results/{file_dir}/results_all.csv")

print(
    np.round(np.mean(results_all["bleurt"]), 3),
    np.round(np.mean(results_all["rouge"]), 3),
    np.round(np.mean(results_all["meteor"]), 3),
)

In [None]:
# Mean ROUGE and METEOR for the full-shot mt0-large Japanese run. The CSV is
# read once instead of once per metric.
full_shot_results = pd.read_csv(
    "../results/mt0-large-language=ja-epochs=15-batch_size=4-shots=full/results.csv"
)

print(
    np.mean(full_shot_results["rouge"]),
    np.mean(full_shot_results["meteor"]),
)

In [None]:
import json

In [None]:
# Load the full dataset and its English split; context managers close the
# file handles as soon as parsing is done.
with open("../data/lang=en-data=full.json", "r") as handle:
    full_data = json.load(handle)
with open("../data/lang=en-data=split_en.json", "r") as handle:
    en_data = json.load(handle)

In [None]:
import random

In [None]:
# Draw 500 records without replacement. A dedicated seeded generator makes
# the sample (which is written to disk below) identical on every run.
# NOTE(review): a seeded draw differs from any sample produced earlier
# without a seed.
sampled_en_data = random.Random(0).sample(full_data, 500)

In [None]:
# Number the sampled records 0..n-1. Each record is copied first: the sampled
# dicts are the same objects held in `full_data`, so adding the key in place
# would also alter `full_data` and break later equality checks against it.
sampled_en_data = [
    {**record, "translation_id": position}
    for position, record in enumerate(sampled_en_data)
]

In [None]:
# Records of `full_data` that have no equal entry in `en_data` (list
# membership, i.e. compared by value).
js_data = [i for i in full_data if i not in en_data]

In [None]:
# Persist the sampled records. The context manager guarantees the file is
# flushed and closed once writing finishes.
with open("../data/lang=en-data=emotion_test.json", "w") as handle:
    json.dump(sampled_en_data, handle, indent=4)

In [None]:
from munch import Munch

In [None]:
# Scratch check: build a Munch from a nested dict and read its "model" entry
# through attribute access.
Munch.fromDict({"model": {"tokenizer": "a", "checkpoint": "b"}}).model

In [None]:
# Emotion-check outputs of mt0-large for the English and Japanese inputs.
en_labels = pd.read_csv("../results/emotion_check/mt0-large-lang=en.csv")
ja_labels = pd.read_csv("../results/emotion_check/mt0-large-lang=ja.csv")

In [None]:
# Side-by-side view: input text, English prediction, gold label and Japanese
# prediction. The Japanese column is aligned on the English frame's index.
all_labels = en_labels[["input", "output", "gold_output"]].rename(
    columns={"output": "en", "gold_output": "ground_truth"}
)
all_labels["ja"] = ja_labels["output"]

In [None]:
# Row counts of the Japanese and English prediction columns, in that order.
len(ja_labels["output"]), len(en_labels["output"])

In [None]:
# Display the combined label table.
all_labels

In [None]:
# Load the English and Japanese emotion test sets; context managers close the
# file handles as soon as parsing is done.
with open("../data/lang=en-data=emotion_test.json", "r") as handle:
    en_data = json.load(handle)
with open("../data/lang=ja-data=emotion_test.json", "r") as handle:
    ja_data = json.load(handle)

In [None]:
# Tabular view of the Japanese test records (rebinds the notebook-wide `df`).
df = pd.DataFrame(ja_data)

In [None]:
# Keep only the three fields of interest and stage them in temp.csv.
df[["text", "choice", "translation_id"]].to_csv("temp.csv", index=False)

In [None]:
# Show the first row of the staged CSV.
pd.read_csv("temp.csv").iloc[0]

In [None]:
# Rebuild a list of plain dicts (text / choice / translation_id) from the
# staged CSV, one dict per row.
data = [
    {field: row[field] for field in ("text", "choice", "translation_id")}
    for _, row in pd.read_csv("temp.csv").iterrows()
]

In [None]:
# Write the records back to the Japanese test-set file. The context manager
# guarantees the file is flushed and closed once writing finishes.
# NOTE(review): this replaces the file `ja_data` was loaded from, keeping
# only the fields present in `data`.
with open("../data/lang=ja-data=emotion_test.json", "w") as handle:
    json.dump(data, handle, indent=4)

In [None]:
# Reload the English emotion-check outputs from disk.
en_labels = pd.read_csv("../results/emotion_check/mt0-large-lang=en.csv")

In [None]:
# Reload the Japanese emotion-check outputs from disk.
ja_labels = pd.read_csv("../results/emotion_check/mt0-large-lang=ja.csv")

In [None]:
# English emotion name -> Japanese label, in the order used throughout the
# notebook.
# NOTE(review): Plutchik's wheel also includes "surprise"; confirm that its
# omission here is intentional.
plutchik_en_to_ja = dict(
    [
        ("joy", "喜び"),
        ("trust", "信頼"),
        ("fear", "恐れ"),
        ("sadness", "悲しみ"),
        ("disgust", "嫌悪"),
        ("anger", "怒り"),
        ("anticipation", "期待"),
    ]
)

# Reverse lookup: Japanese label -> English emotion name.
inv_map = dict(zip(plutchik_en_to_ja.values(), plutchik_en_to_ja.keys()))

In [None]:
# Non-canonical model outputs and the English emotion each is folded into.
label_variants = {
    "sadness": ["哀悼", "悲伤", "悲劇", "悲喜", "悲痛", "悲观", "痛感"],
    "joy": ["喜好", "喜", "感受", "感激"],
    "disgust": ["嫌惡"],
    "fear": ["害怕", "恐怖心", "恐慌", "恐怖", "恐懼"],
    "anticipation": ["希望", "自信"],
}

# Canonical Japanese labels first, then the variants on top.
label_to_emotion = dict(inv_map)
label_to_emotion.update(
    {
        variant: emotion
        for emotion, variants in label_variants.items()
        for variant in variants
    }
)

# Translate every recognised output to its English emotion name. Two defects
# of the previous if/elif chain are fixed here:
#   * "希望"/"自信" were compared to a list with `==`, which is never true,
#     so they were never mapped to "anticipation";
#   * a row matching no branch received the previous row's emotion (or raised
#     NameError on the first row). Unrecognised outputs are now left as-is.
ja_labels["output"] = ja_labels["output"].map(
    lambda label: label_to_emotion.get(label, label)
)

In [None]:
# Distinct values predicted for the English inputs.
set(en_labels["output"])

In [None]:
# Display the English emotion-check table.
en_labels

In [None]:
# Load the English prompting results; the context manager closes the handle.
with open(
    "../results/emotion_check/prompting-lang=en-data=emotion_test.json", "r"
) as handle:
    en_label = json.load(handle)

known_emotions = list(plutchik_en_to_ja.keys())

# Normalise each model label: lower-case it, keep it if it is exactly one
# known emotion, otherwise replace it with the comma-joined known emotions it
# mentions (an empty string when it mentions none).
for item in en_label:
    label = item["model_label"].lower()
    if label not in known_emotions:
        # Substring match, in the order of `known_emotions`.
        label = ", ".join(emotion for emotion in known_emotions if emotion in label)
    item["model_label"] = label

In [None]:
# Load the Japanese prompting results; the context manager closes the handle.
with open(
    "../results/emotion_check/prompting-lang=ja-data=emotion_test.json", "r"
) as handle:
    ja_label = json.load(handle)

japanese_labels = list(inv_map.keys())

# Translate each model label to English: an exact Japanese label maps
# directly; anything else becomes the comma-joined English names of every
# Japanese label it contains (an empty string when it contains none).
for item in ja_label:
    label = item["model_label"]
    if label in japanese_labels:
        item["model_label"] = inv_map[label]
    else:
        item["model_label"] = ", ".join(
            inv_map[japanese] for japanese in japanese_labels if japanese in label
        )

In [None]:
# Emotion names grouped by polarity. The cells that assign `general_label`
# read `positive_emotions`.
positive_emotions = ["joy", "trust", "anticipation"]
negative_emotions = ["fear", "sadness", "disgust", "anger"]

In [None]:
# Derive a coarse polarity for every English prediction:
#   ""                              -> "missing"
#   mentions any positive emotion   -> "positive"
#   otherwise                       -> "negative"
# The previous loop had the two labels swapped and overwrote the result on
# every iteration, so only the last entry of `positive_emotions` decided it.
for item in en_label:
    label = item["model_label"]
    if label == "":
        item["general_label"] = "missing"
    elif any(emotion in label for emotion in positive_emotions):
        item["general_label"] = "positive"
    else:
        item["general_label"] = "negative"

In [None]:
# Derive a coarse polarity for every Japanese prediction:
#   ""                              -> "missing"
#   mentions any positive emotion   -> "positive"
#   otherwise                       -> "negative"
# The previous loop had the two labels swapped and overwrote the result on
# every iteration, so only the last entry of `positive_emotions` decided it.
for item in ja_label:
    label = item["model_label"]
    if label == "":
        item["general_label"] = "missing"
    elif any(emotion in label for emotion in positive_emotions):
        item["general_label"] = "positive"
    else:
        item["general_label"] = "negative"

In [None]:
# One DataFrame per language, one row per prompting record.
en_label_df = pd.DataFrame.from_records(en_label)
ja_label_df = pd.DataFrame.from_records(ja_label)

In [None]:
# Combine both languages in one table. The English frame supplies the index;
# the Japanese columns are aligned on it.
full_label_df = pd.DataFrame(index=en_label_df.index).assign(
    text=en_label_df["text"],
    en=en_label_df["model_label"],
    ja=ja_label_df["model_label"],
    en_general=en_label_df["general_label"],
    ja_general=ja_label_df["general_label"],
    ground_truth=en_label_df["ground_truth"],
)

In [None]:
# Per-row agreement between the English and Japanese emotion labels:
#   -1    at least one side has no label
#   True  the two label lists share at least one emotion
#   False otherwise
# Previously, rows where both sides were empty ended up True, because the two
# empty strings compared equal after -1 had been assigned.
same = []
for index, row in full_label_df.iterrows():
    en_emotions = row["en"].split(", ")
    ja_emotions = row["ja"].split(", ")

    if en_emotions == [""] or ja_emotions == [""]:
        row_label = -1
    else:
        row_label = any(emotion in ja_emotions for emotion in en_emotions)

    same.append(row_label)

In [None]:
# Per-row agreement between the English and Japanese polarity labels:
#   -1    at least one side is "missing"
#   True  both sides carry the same polarity
#   False otherwise
# The Japanese check previously compared a list left over from the previous
# cell to "missing" (never true), and rows where both sides were missing were
# then overwritten with True.
same_general = []
for index, row in full_label_df.iterrows():
    if row["en_general"] == "missing" or row["ja_general"] == "missing":
        row_label = -1
    else:
        row_label = row["en_general"] == row["ja_general"]

    same_general.append(row_label)

In [None]:
# Attach both agreement columns to the combined table.
full_label_df["same"] = same
full_label_df["same_general"] = same_general

In [None]:
# Row counts for `same` equal to -1, True and False, in that order.
len(full_label_df[full_label_df["same"] == -1]), len(
    full_label_df[full_label_df["same"] == True]
), len(full_label_df[full_label_df["same"] == False])

In [None]:
# Row counts for `same_general` equal to True, False and -1, in that order.
len(full_label_df[full_label_df["same_general"] == True]), len(
    full_label_df[full_label_df["same_general"] == False]
), len(full_label_df[full_label_df["same_general"] == -1])

In [None]:
# Rows whose English polarity label is "missing".
full_label_df[full_label_df["en_general"] == "missing"]

In [None]:
import plotly.graph_objects as go

# Length distributions of the English training texts and explanations,
# measured in whitespace-separated tokens.
en_train = pd.read_csv("../data/training_lang=en-data=split_en.csv")
en_train_input = [
    # Text between the "[TEXT]: " and "[EMOTION]: " markers.
    len(i.split("[TEXT]: ")[-1].split("[EMOTION]: ")[0].split())
    for i in en_train["input"]
]
en_train_output = [
    # Everything after the "[RATIONALE]: " marker.
    len(i.split("[RATIONALE]: ")[-1].split())
    for i in en_train["output"]
]

width = 600
height = 400
scale = 1

fig = go.Figure()
fig.add_trace(go.Histogram(x=en_train_input, name="Text"))
fig.add_trace(go.Histogram(x=en_train_output, name="Explanation"))

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Document Length",
    yaxis_title="Count",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.67, y=0.97, bordercolor="Black", borderwidth=1),
)

# Overlay the two histograms and make them translucent so both stay visible.
fig.update_layout(barmode="overlay", plot_bgcolor="white")
fig.update_traces(overwrite=True, marker={"opacity": 0.7})

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()
# Written once; the export was previously repeated three times to the same file.
fig.write_image("en_train.pdf")

In [None]:
width = 600
height = 400
scale = 1

# Length distributions of the Japanese training texts and explanations,
# measured in characters (len() of the extracted string).
ja_train = pd.read_csv("../data/training_lang=ja-data=split_ja.csv")
ja_train_input = [
    # Text between the "[文章]: " and "[感情]: " markers.
    len(i.split("[文章]: ")[-1].split("[感情]: ")[0])
    for i in ja_train["input"]
]
# Everything after the "[説明]: " marker.
ja_train_output = [len(i.split("[説明]: ")[-1]) for i in ja_train["output"]]

fig = go.Figure()
fig.add_trace(go.Histogram(x=ja_train_input, name="Text"))
fig.add_trace(go.Histogram(x=ja_train_output, name="Explanation"))

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Document Length",
    yaxis_title="Count",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.67, y=0.97, bordercolor="Black", borderwidth=1),
)

# Overlay the two histograms and make them translucent so both stay visible.
fig.update_layout(barmode="overlay", plot_bgcolor="white")
fig.update_traces(overwrite=True, marker={"opacity": 0.7})

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()
# Written once; the export was previously repeated three times to the same file.
fig.write_image("ja_train.pdf")

In [None]:
# Records of the full dataset that have no equal entry in the English split,
# reduced to their text / choice / explanation fields. Context managers close
# the file handles as soon as parsing is done.
with open("../data/lang=en-data=full.json", "r") as handle:
    json_data = json.load(handle)
with open("../data/lang=en-data=split_en.json", "r") as handle:
    json_en_data = json.load(handle)

ja_untranslated = [i for i in json_data if i not in json_en_data]
ja_untranslated = pd.DataFrame(ja_untranslated)[["text", "choice", "explanation"]]

In [None]:
width = 600
height = 400
scale = 1

# Records of the full dataset that have no equal entry in the English split.
# Context managers close the file handles as soon as parsing is done.
with open("../data/lang=en-data=full.json", "r") as handle:
    json_data = json.load(handle)
with open("../data/lang=en-data=split_en.json", "r") as handle:
    json_en_data = json.load(handle)

ja_untranslated = [i for i in json_data if i not in json_en_data]
ja_untranslated = pd.DataFrame(ja_untranslated)[["text", "choice", "explanation"]]

# Lengths in whitespace-separated tokens.
ja_untranslated_train_input = [len(i.split()) for i in ja_untranslated["text"]]
ja_untranslated_train_output = [len(i.split()) for i in ja_untranslated["explanation"]]

fig = go.Figure()
fig.add_trace(go.Histogram(x=ja_untranslated_train_input, name="Text"))
fig.add_trace(go.Histogram(x=ja_untranslated_train_output, name="Explanation"))

fig.update_layout(
    width=width * scale,
    height=height * scale,
    font=dict(size=20 * scale),
    xaxis_title="Document Length",
    yaxis_title="Count",
    margin=dict(l=0, r=10, t=10, b=0),
    legend=dict(x=0.67, y=0.97, bordercolor="Black", borderwidth=1),
)

# Overlay the two histograms and make them translucent so both stay visible.
fig.update_layout(barmode="overlay", plot_bgcolor="white")
fig.update_traces(overwrite=True, marker={"opacity": 0.7})

# Same boxed, light-grid styling on both axes.
axis_style = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
fig.update_xaxes(**axis_style)
fig.update_yaxes(**axis_style)

fig.show()
# Written once; the export was previously repeated three times to the same file.
fig.write_image("ja_untranslated_train.pdf")