In [None]:
import json
import matplotlib.pyplot as plt
from pathlib import Path
import re
import pandas as pd
import seaborn as sns
sns.set_theme()

In [None]:
# --- Experiment selection ----------------------------------------------------
# Change LANG to switch the whole notebook to another language pair instead of
# commenting configuration blocks in and out.
LANG = "enzh"

# Per-language-pair settings:
#   tgt    - target-language code, used in the results directory name
#   data   - dataset tag, used in directory and file names
#   models - model-name templates; "{0}" is replaced by the delay value
CONFIGS = {
    "deen": {
        "tgt": "en",
        "data": "wmt15",
        "models": [
            "wait_{0}_deen_distill",
            "wait_{0}_deen_mon",
            "wait_{0}_deen_reorder",
            "ctc_delay{0}",
            "ctc_delay{0}_mon",
            "ctc_delay{0}_reorder",
            # "sinkhorn_delay{0}" was disabled for this pair in the original
            # configuration, so it stays out of the list.
            "sinkhorn_delay{0}_ft",
        ],
    },
    "enzh": {
        "tgt": "zh",
        "data": "cwmt",
        "models": [
            "wait_{0}_enzh_distill",
            "wait_{0}_enzh_mon",
            "wait_{0}_enzh_reorder",
            "ctc_delay{0}",
            "ctc_delay{0}_mon",
            "ctc_delay{0}_reorder",
            "sinkhorn_delay{0}",
            "sinkhorn_delay{0}_ft",
        ],
    },
}

TGT = CONFIGS[LANG]["tgt"]
DATA = CONFIGS[LANG]["data"]
# Copy the list so later edits to MODELS never alter the CONFIGS table.
MODELS = list(CONFIGS[LANG]["models"])
RESULT = Path(f"{DATA}_{TGT}-results")

# Trial indices; each one corresponds to a "scores.<trial>" file per model.
TRIALS = [1, 2, 3]

# Directory that receives the generated figures.
OUTPUT = Path("./graphs/")
OUTPUT.mkdir(parents=True, exist_ok=True)

# Metrics collected per formatted model name, filled in by the code below.
data = {}

def mean(l):
    """Return the arithmetic mean of the numbers in ``l``.

    An empty ``l`` raises ``ZeroDivisionError``.
    """
    total = sum(l)
    count = len(l)
    return total / count

# Collect the latency metrics of every model at every delay, averaged over
# the trials listed in TRIALS.
for delay in range(1, 10, 2):
    for model in MODELS:
        m = model.format(delay)
        AL = []
        AL_CA = []
        for t in TRIALS:
            # A context manager closes each score file as soon as it is parsed.
            with open(RESULT / f"{m}.{DATA}" / f"scores.{t}") as f:
                dat = json.load(f)
            AL.append(dat["Latency"]["AL"])
            AL_CA.append(dat["Latency"]["AL_CA"])

        # Report models whose AL is not the same in every trial. Comparing the
        # values directly avoids both float rounding in the mean and the case
        # where differing values happen to average to the first one.
        if len(set(AL)) > 1:
            print(f"{m}: {AL}")
        data[m] = {
            "AL": mean(AL),
            "AL_CA": mean(AL_CA),
        }
        
def getmeanvar(t):
    """Parse a ``"<score> (<mean> ± <spread>"`` table cell into three floats.

    Args:
        t: Text containing a number, followed by an opening parenthesis, a
            second number, the ``±`` sign and a third number. Surrounding
            text and whitespace are ignored.

    Returns:
        A ``(score, mean, spread)`` tuple of floats, in the order they appear.

    Raises:
        ValueError: If ``t`` does not contain the expected pattern.
    """
    m = re.search(r"(?P<bleu>\d+\.?\d*)\s*\((?P<mean>\d+\.?\d*)\s*±\s*(?P<var>\d+\.?\d*)", t)
    if m is None:
        # Fail with the offending text instead of an AttributeError on None.
        raise ValueError(f"cannot parse score cell: {t!r}")
    return float(m.group("bleu")), float(m.group("mean")), float(m.group("var"))
        
# Attach the quality scores parsed from each "delay<N>-systems" report to the
# entries that already exist in `data`.
for delay in range(1, 10, 2):
    baseline = MODELS[0]  # NOTE(review): assigned but not read in this cell
    report = RESULT / f"quality-results.{DATA}" / f"delay{delay}-systems"
    with open(report) as f:
        for line in f:
            for m in MODELS:
                # Only rows naming this system are relevant.
                if f"{m.format(delay)}.{DATA}" not in line:
                    continue
                # Rows mentioning "pairwise" are skipped.
                if "pairwise" in line:
                    continue
                cells = line.split("│")
                # Columns 2 and 3 of the "│"-separated row are parsed as the
                # BLEU and chrF cells respectively.
                bleu, bmu, bvar = getmeanvar(cells[2])
                chrf, cmu, cvar = getmeanvar(cells[3])
                quality = {
                    "BLEU": bleu,
                    "BLEU-mu": bmu,
                    "BLEU-var": bvar,
                    "chrF": chrf,
                    "chrF-mu": cmu,
                    "chrF-var": cvar,
                }
                data[m.format(delay)].update(quality)

In [None]:
# Tabulate everything collected in `data`: one row per key, values rounded to
# two decimals for display.
(
    pd.DataFrame
    .from_dict(data, orient="index")
    .round(decimals=2)
)

In [None]:
# Axis labels to show for each metric key stored in `data`.
formal = dict(
    BLEU="BLEU",
    chrF="chrF2",
    AL="AL",
    AL_CA="AL-CA",
)

In [None]:
# Legend labels, in plotting order.
MODELNAMES = [
    "wait-$k$",
    "wait-$k$+Pseudo",
    "wait-$k$+Reorder",
    "CTC",
    "CTC+Pseudo",
    "CTC+Reorder",
#     "CTC+ASN (Scratch)",
    "CTC+ASN (Ours)",
]
# Matplotlib format strings (marker + line), index-aligned with MODELNAMES.
STYLES = [
    "^-",
    "s-",
    "d-",
    "o-",
    "p-",
    "*-",
    "P-"
]

# MODELS may contain the "sinkhorn_delay{0}" template, whose label
# ("CTC+ASN (Scratch)") is disabled above. Pairing MODELS and MODELNAMES purely
# by index would then attach the "CTC+ASN (Ours)" label to that model's numbers,
# so the template is dropped unless its label is enabled.
PLOTTED_MODELS = [
    tpl for tpl in MODELS
    if tpl != "sinkhorn_delay{0}" or "CTC+ASN (Scratch)" in MODELNAMES
]
if len(PLOTTED_MODELS) != len(MODELNAMES):
    raise ValueError(
        f"{len(PLOTTED_MODELS)} model templates cannot be paired with "
        f"{len(MODELNAMES)} display names; check MODELS and MODELNAMES"
    )


def collect_series(key):
    """Return ``{display name: [data value of `key` at delays 1, 3, 5, 7, 9]}``.

    Display names come from MODELNAMES and are paired position by position
    with the templates in PLOTTED_MODELS.
    """
    return {
        name: [data[tpl.format(k)][key] for k in range(1, 10, 2)]
        for name, tpl in zip(MODELNAMES, PLOTTED_MODELS)
    }


BLEUs = collect_series("BLEU")
chrFs = collect_series("chrF")
ALs = collect_series("AL")
AL_CAs = collect_series("AL_CA")




In [None]:
def plot(x, y, xlabel, ylabel, legend=False, file=None, figsize=(5, 3), dpi=200):
    """Draw one line per entry of MODELNAMES and optionally save the figure.

    Args:
        x: Mapping from display name to the x values of that model.
        y: Mapping from display name to the y values of that model.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        legend: Whether to draw a legend.
        file: Destination of the saved figure (``str`` or path object);
            nothing is saved when ``None``.
        figsize: Figure size in inches, passed to ``plt.figure``.
        dpi: Resolution used for both the figure and the saved file.

    The display names and line styles are read from the module-level
    MODELNAMES and STYLES; styles are reused cyclically when there are more
    names than styles.
    """
    fig = plt.figure(figsize=figsize, dpi=dpi)
    ax = fig.add_subplot(111)

    for i, m in enumerate(MODELNAMES):
        ax.plot(x[m], y[m], STYLES[i % len(STYLES)], label=m)

    if legend:
        ax.legend() # loc='upper right'
    ax.autoscale()

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    fig.show()
    if file is not None:
        # Save through the figure object itself rather than pyplot's implicit
        # "current figure", so the file always contains this figure.
        fig.savefig(file, dpi=dpi)

plot(ALs, BLEUs, formal["AL"], formal["BLEU"], file=OUTPUT / f"bleu-al-{LANG}.pdf")

In [None]:
# BLEU against AL-CA; the legend is drawn on this figure.
plot(
    x=AL_CAs,
    y=BLEUs,
    xlabel=formal["AL_CA"],
    ylabel=formal["BLEU"],
    legend=True,
    file=OUTPUT / f"bleu-alca-{LANG}.pdf",
)

In [None]:
# chrF against AL, without a legend.
plot(
    x=ALs,
    y=chrFs,
    xlabel=formal["AL"],
    ylabel=formal["chrF"],
    file=OUTPUT / f"chrf-al-{LANG}.pdf",
)

In [None]:
# chrF against AL-CA; the legend is drawn on this figure.
plot(
    x=AL_CAs,
    y=chrFs,
    xlabel=formal["AL_CA"],
    ylabel=formal["chrF"],
    legend=True,
    file=OUTPUT / f"chrf-alca-{LANG}.pdf",
)

## Oracle Order

In [None]:
# oracle
# Templates analysed in this section. A dedicated name is used so that the
# MODELS list is not overwritten by this cell.
ORACLE_MODELS = [
    "sinkhorn_delay{0}_ft",
]


def _parse_verbose_score(vscore):
    """Split a verbose score string into ``(n-gram values, BP)``.

    The first whitespace-separated token holds the "/"-joined n-gram values
    and the fourth token holds the BP value; all are returned as floats.
    NOTE(review): this token layout looks like sacreBLEU's verbose output -
    confirm against the tool that writes these files.
    """
    tokens = vscore.split()
    return [float(p) for p in tokens[0].split("/")], float(tokens[3])


for delay in range(1, 10, 2):
    for m in ORACLE_MODELS:
        model = m.format(delay)
        # default
        with open(RESULT / f"quality-results.{DATA}" / "verbose" / model) as f:
            ngrams, bp = _parse_verbose_score(json.load(f)["verbose_score"])
        data[model].update({f"{i+1}-gram": p for i, p in enumerate(ngrams)})
        data[model].update({
            "BP": bp,
        })
        # oracle
        with open("oracle_order" / RESULT / model / "score") as f:
            scores = json.load(f)
        ngrams, bp = _parse_verbose_score(scores["verbose_score"])
        data[model].update({f"oracle {i+1}-gram": p for i, p in enumerate(ngrams)})
        data[model].update({
            "oracle BLEU": float(scores["score"]),
            "oracle BP": bp,
        })

# Columns to display: the regular scores first, then their "oracle ..." twins.
show = ["BLEU", "BP"]
show.extend(f"{n}-gram" for n in range(1, 5))
show.extend([f"oracle {name}" for name in show])

df = pd.DataFrame.from_dict(data, orient="index").loc[:, show]
# Keep only the rows that actually have an oracle BLEU value.
df = df.loc[df["oracle BLEU"].notna()]
df.round(2)

## Ablation

In [None]:
# ABLATIONS = [
#     "sinkhorn_delay3_unittemp",
#     "sinkhorn_delay3_nonoise",
#     "sinkhorn_delay3_softmax",
# ]
# ABLATIONNAMES = [
#     "Unit temperature",
#     "Zero noise",
#     "Low-temp noised softmax",
# ]
# for model in ABLATIONS:
#     m = model
#     AL = []
#     AL_CA = []
#     TRIALS = [1,]
#     for t in TRIALS:
#         dat = json.load(open(RESULT/f"{m}.{DATA}"/f"scores.{t}"))
#         AL += [dat["Latency"]["AL"]]
#         AL_CA += [dat["Latency"]["AL_CA"]]

#     if mean(AL) != AL[0]:
#         print(f"{m}: {AL}")
#     data[m] = {
#         "AL": mean(AL),
#         "AL_CA": mean(AL_CA)
#     }
        
# with open(RESULT/f"quality-results.{DATA}"/f"ablation-systems") as f:
#     for line in f:
#         for m in ABLATIONS:
#             if m in line:
#                 bleu, bmu, bvar = getmeanvar(line.split("│")[2])
#                 chrf, cmu, cvar = getmeanvar(line.split("│")[3])
#                 data[m.format(delay)].update({
#                     "BLEU": bleu,
#                     "BLEU-mu": bmu,
#                     "BLEU-var": bvar,
#                     "chrF": chrf,
#                     "chrF-mu": cmu,
#                     "chrF-var": cvar,
#                 })

In [None]:
# Templates compared in this section together with their legend labels,
# format strings and line colours; the four lists are index-aligned.
MODELS = ["sinkhorn_delay{0}", "sinkhorn_delay{0}_ft"]
MODELNAMES = ["Scratch", "Weight init."]
STYLES = ["^-", "P-"]
COLORS = ["#1f77b4", "#e377c2"]

# For every label, the BLEU / AL values at delays 1, 3, 5, 7, 9.
BLEUs = {
    name: [data[tpl.format(k)]["BLEU"] for k in range(1, 10, 2)]
    for name, tpl in zip(MODELNAMES, MODELS)
}
ALs = {
    name: [data[tpl.format(k)]["AL"] for k in range(1, 10, 2)]
    for name, tpl in zip(MODELNAMES, MODELS)
}

# BLEUs.update({
#     m: [ data[ABLATIONS[i]]["BLEU"]]
#     for i,m in enumerate(ABLATIONNAMES)
# })
# ALs.update({
#     m: [ data[ABLATIONS[i]]["AL"]]
#     for i,m in enumerate(ABLATIONNAMES)
# })
# MODELNAMES += ABLATIONNAMES

# Inputs and geometry of the figure: BLEU against AL, with a legend.
x, y, xlabel, ylabel, legend = ALs, BLEUs, formal["AL"], formal["BLEU"], True
figsize = (5, 2.2)
dpi = 200

fig = plt.figure(figsize=figsize, dpi=dpi)
ax = fig.add_subplot(111)

# One line per label; styles and colours are reused cyclically.
for i, m in enumerate(MODELNAMES):
    ax.plot(
        x[m], y[m],
        STYLES[i % len(STYLES)],
        color=COLORS[i % len(COLORS)],
        label=m
    )

if legend:
    ax.legend() # loc='upper right'
ax.autoscale()

ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
fig.tight_layout()
fig.show()
# Save through the figure object itself rather than pyplot's implicit
# "current figure", so the file always contains this figure.
fig.savefig(OUTPUT / "weight-init.pdf", dpi=dpi)

In [None]:
# Store the "<n>-gram" values from the verbose report of every template in
# MODELS, at every delay.
for delay in range(1, 10, 2):
    for m in MODELS:
        model = m.format(delay)
        report = RESULT / f"quality-results.{DATA}" / "verbose" / model
        with open(report) as f:
            vscore = json.load(f)["verbose_score"]
            # The first whitespace-separated token is the "/"-joined list.
            ngrams = vscore.split()[0].split("/")
            data[model].update(
                (f"{i+1}-gram", float(Ngram)) for i, Ngram in enumerate(ngrams)
            )

In [None]:
# For each n-gram order, the "<n>-gram" value of the "_ft" model minus that of
# the plain model, at delays 1, 3, 5, 7, 9.
DIFF = {
    f"{n}-gram": [
        data[f"sinkhorn_delay{k}_ft"][f"{n}-gram"] - data[f"sinkhorn_delay{k}"][f"{n}-gram"]
        for k in range(1, 10, 2)
    ]
    for n in range(1, 5)
}
STYLES = [
    "^-",
    "s-",
    "d-",
    "o-",
    "p-",
    "*-",
    "P-"
]
figsize = (5, 2.5)
dpi = 200

fig = plt.figure(figsize=figsize, dpi=dpi)
ax = fig.add_subplot(111)

# The DIFF keys are already the "<n>-gram" legend labels.
for i, m in enumerate(DIFF):
    ax.plot(
        range(1, 10, 2), DIFF[m],
        STYLES[i % len(STYLES)],
        label=m
    )

# Tick setup does not depend on the plotted series, so it is done once.
ax.set_xticks(range(1, 10, 2))
ax.set_yticks(range(0, 5, 2))
ax.set_yticklabels([f"+{j}" for j in range(0, 5, 2)])

# The legend is always drawn; the figure no longer depends on a `legend`
# variable assigned by a different cell.
ax.legend(loc='upper right', ncol=2)
ax.autoscale()

ax.set_xlabel("k")
# ax.set_ylabel("Improvement")
fig.tight_layout()
fig.show()

In [None]:
# oracle


In [None]:
import re
def kAR(align, k, reverse=False):
    """Return the percentage of alignment links counted as inverted for ``k``.

    Every ``<number>-<number>`` pair found in ``align`` is one link ``i-j``
    (first number ``i``, second number ``j``). A link is counted when
    ``i - k + 1 > j``.

    Args:
        align: Text containing the alignment links.
        k: Offset used in the comparison above.
        reverse: When true, the two numbers of every link swap roles.

    Returns:
        ``counted / total * 100`` as a float, or ``0.0`` when ``align``
        contains no link at all.
    """
    first, second = ("j", "i") if reverse else ("i", "j")
    inv = tot = 0
    for m in re.finditer(r"(?P<i>[0-9]+)-(?P<j>[0-9]+)", align):
        tot += 1
        if int(m.group(first)) - k + 1 > int(m.group(second)):
            inv += 1
    if tot == 0:
        # No links: report 0% instead of dividing by zero.
        return 0.0
    return inv / tot * 100

In [None]:
# k-AR for k = 1..9 on the en-zh validation alignments.
with open("anticipation/alignments/valid.en-zh_1000000") as f:
    align_en_zh = f.read()
kar_enzh_valid = list(map(lambda k: kAR(align_en_zh, k), range(1, 10)))

# Same computation on the en-zh test alignments.
with open("anticipation/alignments/test.en-zh_1000000") as f:
    align_en_zh = f.read()
kar_enzh_test = list(map(lambda k: kAR(align_en_zh, k), range(1, 10)))

In [None]:
# k-AR for k = 1..9 on the de-en validation alignments.
with open("anticipation/alignments/valid.de-en_1000000") as f:
    align_de_en = f.read()
kar_deen_valid = list(map(lambda k: kAR(align_de_en, k), range(1, 10)))

# Same computation on the de-en test alignments.
with open("anticipation/alignments/test.de-en_1000000") as f:
    align_de_en = f.read()
kar_deen_test = list(map(lambda k: kAR(align_de_en, k), range(1, 10)))

In [None]:
fig = plt.figure(figsize=(5, 3), dpi=200)
ax = fig.add_subplot(111)

# (values, matplotlib format string, legend label) of each curve, drawn
# against k = 1..9.
for kar_values, fmt, label in (
    (kar_enzh_valid, "b^-", "CWMT dev"),
    (kar_enzh_test, "gs-", "CWMT test"),
    (kar_deen_valid, "yD-", "WMT15 dev"),
    (kar_deen_test, "ro-", "WMT15 test"),
):
    ax.plot(range(1, 10), kar_values, fmt, label=label)

# One tick per value of k.
ax.set_xticks(range(1, 10))

ax.legend()
ax.autoscale()

ax.set_xlabel("$k$")
ax.set_ylabel("$k$-AR (%)")
fig.tight_layout()
fig.show()

In [None]:
import re
import numpy as np

def plot_align(file):
    """Compute the relative alignment distance of every link in a file.

    Despite its name this function draws nothing; it only returns the values.

    Each line of ``file`` is scanned for ``<i>-<j>`` links. Within a line,
    ``i`` and ``j`` are divided by the largest ``i`` and the largest ``j`` of
    that line, and the distance of a link is the absolute difference of the
    two normalised positions. Lines without any link are skipped.

    Args:
        file: Path of the alignment file.

    Returns:
        A flat list with one distance per link, in file order.
    """
    pattern = re.compile(r"(?P<i>[0-9]+)-(?P<j>[0-9]+)")
    dists = []
    with open(file) as f:
        # Iterate lazily instead of loading every line into memory first.
        for line in f:
            links = [(int(si), int(sj)) for si, sj in pattern.findall(line)]
            if not links:
                # Nothing to normalise against; max() of an empty list would
                # raise ValueError.
                continue

            max_i = max(i for i, _ in links)
            max_j = max(j for _, j in links)
            for i, j in links:
                # The small epsilon avoids 0/0 when the largest index is 0.
                tgt = i / (max_i + 1e-9)
                hyp = j / (max_j + 1e-9)
                dists.append(abs(tgt - hyp))

    return dists

In [None]:
# Relative alignment distances of the four alignment files, computed in the
# order listed.
(
    enzh_hist_valid,
    enzh_hist_test,
    deen_hist_valid,
    deen_hist_test,
) = map(
    plot_align,
    (
        "anticipation/alignments/valid.en-zh_1000000",
        "anticipation/alignments/test.en-zh_1000000",
        "anticipation/alignments/valid.de-en_1000000",
        "anticipation/alignments/test.de-en_1000000",
    ),
)

In [None]:
# Figure geometry is set here explicitly so that this cell does not depend on
# whichever earlier cell last assigned `figsize` / `dpi`.
figsize = (5, 2.5)
dpi = 200

fig = plt.figure(figsize=figsize, dpi=dpi)
ax = fig.add_subplot(111)

# Series to draw, keyed by legend label. Only the validation sets are shown;
# enzh_hist_test / deen_hist_test can be added to this mapping to include the
# test sets as well.
hist_series = {
    "En-Zh": enzh_hist_valid,
    "De-En": deen_hist_valid,
}

ax.hist(
    tuple(hist_series.values()),
    bins=10,
    # A weight of 100/N per sample makes each bar the percentage of that
    # series falling into the bin.
    weights=tuple(
        np.ones(len(dists)) / len(dists) * 100
        for dists in hist_series.values()
    ),
    label=tuple(hist_series),
)


ax.legend()
ax.autoscale()

ax.set_xlabel("Relative Alignment Distance")
ax.set_ylabel("Alignments (%)")
fig.tight_layout()
fig.show()

In [None]:
def metric(metric, displayname, runs, rundisplaynames, max_step=50000, log_interval=50):
    """Plot the logged history of one training metric for several wandb runs.

    Args:
        metric: Metric name; the history key read is ``train_inner/<metric>``.
        displayname: Label of the y axis.
        runs: Run paths passed to ``wandb.Api().run``.
        rundisplaynames: Legend labels, paired with ``runs`` in order.
        max_step: Upper step bound passed to ``scan_history``.
        log_interval: Step distance assumed between consecutive history rows.
            The x position of the n-th row is ``n * log_interval``.
            NOTE(review): this assumes the metric was logged at a fixed
            interval - confirm against the training configuration.
    """
    # Imported here so the rest of the notebook works without wandb.
    import wandb
    api = wandb.Api()

    key = f"train_inner/{metric}"

    datas = []
    for rname in runs:
        run = api.run(rname)

        history = run.scan_history(keys=[key], max_step=max_step)
        datas.append([row[key] for row in history])

    fig = plt.figure(figsize=(6, 4), dpi=200)
    ax = fig.add_subplot(111)

    for d, name in zip(datas, rundisplaynames):
        # Steps are reconstructed from the row index, not read from wandb.
        steps = [t * log_interval for t in range(len(d))]
        ax.plot(steps, d, label=name)

    ax.legend()
    ax.autoscale()

    ax.set_xlabel("Steps")
    ax.set_ylabel(displayname)
    fig.tight_layout()
    fig.show()