In [None]:
import sys
import jsonlines
from collections import defaultdict
from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
import spacy

sys.path.append("..")
import src.utils.datatool as dtool  # noqa: E402

In [None]:
import matplotlib.pyplot as plt
import matplotlib.font_manager
import matplotlib.image as mpimg
# List the font families matplotlib can resolve on this machine.
print(f"available fonts: {sorted([f.name for f in matplotlib.font_manager.fontManager.ttflist])}")

# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so use whichever spelling this install provides.
# If neither exists, the old name is passed through and matplotlib raises its
# usual error.
_muted_style = "seaborn-v0_8-muted" if "seaborn-v0_8-muted" in plt.style.available else "seaborn-muted"
plt.style.use(_muted_style)

# Notebook-wide matplotlib defaults, applied in a single update call.
plt.rcParams.update({
    # resolution and export
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.format": "pdf",
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.1,
    # titles and base font
    "figure.titlesize": 18,
    "axes.titlesize": 18,
    "font.family": "Helvetica",
    "font.size": 18,
    # lines, labels, ticks, legend
    "lines.linewidth": 2,
    "axes.labelsize": 16,
    "axes.labelweight": "bold",
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "legend.fontsize": 16,
    "axes.linewidth": 2,
    "axes.titlepad": 6,
    # math text, markers, legend frame
    "mathtext.fontset": "dejavuserif",
    "mathtext.it": "serif:italic",
    "lines.marker": "",
    "legend.frameon": False,
})

In [None]:
# Read every record of the VTT metadata file into an in-memory list.
with jsonlines.open("/data/vtt/meta/vtt.jsonl") as reader:
    data = [record for record in reader]

In [None]:
data[0]

## Language Compositional Generalization

In [None]:
def list2count(_list):
    """Return a dict mapping each distinct item of `_list` to its number of
    occurrences, with the keys in ascending order."""
    tally = {}
    for item in _list:
        tally[item] = tally.get(item, 0) + 1
    # dicts keep insertion order, so building from the sorted pairs sorts the keys
    return dict(sorted(tally.items()))

In [None]:
# python -m spacy download en_core_web_sm
nlp = spacy.load("en_core_web_sm")
lemmatizer = nlp.get_pipe("lemmatizer")

# sentences: sample["ori"] -> whitespace-split length of every step label
# words:     "all" / sample["ori"] -> token -> count ("," and "." excluded)
# words_all: every spaCy token of every step label, punctuation included
sentences = defaultdict(list)
words = defaultdict(lambda: defaultdict(int))
words_all = []
for sample in tqdm(data):
    origin = sample["ori"]
    for step in sample["annotation"]:
        label = step['label']
        sentences[origin].append(len(label.split()))
        for token in nlp(label):
            word = str(token)
            words_all.append(word)
            if word in (",", "."):
                continue
            words["all"][word] += 1
            words[origin][word] += 1

# Turn each list of label lengths into a sorted {length: frequency} mapping.
sentences_count = {origin: list2count(lengths) for origin, lengths in sentences.items()}

In [None]:
# Distinct tokens over all step labels. Punctuation is included because
# words_all is filled before the "," / "." filter is applied.
unique_words_all = set(words_all)
len(unique_words_all)

In [None]:
# stat_all:    scope -> statistic name -> count
# stat_unique: scope -> statistic name -> set of distinct values
# Every sample is counted under three scopes: "all", its "ori" value and its
# "split" value.
stat_all = defaultdict(lambda: defaultdict(int))
stat_unique = defaultdict(lambda: defaultdict(set))
for sample in data:
    steps = sample['annotation']
    for scope in ('all', sample['ori'], sample['split']):
        counts = stat_all[scope]
        counts['Samples'] += 1
        counts['Transformations'] += len(steps)
        # each sample contributes one more state than it has transformations
        counts['States'] += (len(steps) + 1)

        uniques = stat_unique[scope]
        uniques['Categories'].add(sample['category'])
        uniques['Topics'].add(sample['topic'])
        for step in steps:
            uniques['transformations'].add(step['label'])

# Copy the size of every distinct-value set into stat_all next to the raw counts.
for dataset, info in stat_unique.items():
    for key, values in info.items():
        column = "Unique Transformations" if key == "transformations" else key
        stat_all[dataset][column] = len(values)

In [None]:
### words in unique transformations

In [None]:
# Vocabulary of the distinct transformation labels.
# words_all is rebound here (it was a list earlier) to the set of all tokens,
# punctuation included; words_cnt counts token occurrences over the distinct
# labels, skipping "," and ".".
words_all = set()
words_cnt = defaultdict(int)
for label in stat_unique['all']['transformations']:
    tokens = [str(token) for token in nlp(label)]
    words_all.update(tokens)
    for word in tokens:
        if word not in (",", "."):
            words_cnt[word] += 1

In [None]:
len(stat_unique['all']['transformations'])

In [None]:
print(len(words_all))

In [None]:
top_words

In [None]:
t_words_cnt = pd.Series(words_cnt).sort_values()[::-1]
top_words = t_words_cnt[2:52]
top_words_str = list(top_words.index)
top_words_cnt = list(top_words.values)

In [None]:
# Bar chart of the selected top words, saved as top_words.pdf.
width, height = plt.figaspect(0.15)
font_size = 16
colormap = "tab20b"
colors = plt.get_cmap(colormap).colors

# Overrides for this figure only. rc_context restores the notebook-wide
# defaults on exit instead of overwriting them for all later cells.
top_words_rc = {
    "figure.dpi": 200,
    "savefig.dpi": 300,
    "axes.labelsize": font_size + 2,
    "axes.labelweight": "normal",
    "legend.fontsize": font_size,
    "xtick.labelsize": font_size,
    "ytick.labelsize": font_size,
    "axes.linewidth": 1,
}
with plt.rc_context(top_words_rc):
    fig, ax = plt.subplots(figsize=(width, height))
    ax.bar(top_words_str, top_words_cnt, color=colors[2])
    # vertical tick labels so the 50 words do not overlap
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_ylabel("Count")
    ax.margins(x=0.005)
    # saving inside the context draws the figure while the overrides are active
    fig.savefig("top_words.pdf", dpi=300)

In [None]:
# Distinct transformation labels of the train split, in sorted order.
unique_train_transformations = sorted(stat_unique['train']['transformations'])
print(len(unique_train_transformations))

In [None]:
### TTNet transformations

In [None]:
# Collect every distinct prediction found in the TTNet run's detail file.
results = dtool.read_jsonlines("/log/exp/vtt/VTTDataModule.TTNetDiff.TellingLossV1.2022-09-19_21-44-42/detail.jsonl")
unique_ttnet_transformations = {pred for sample in results for pred in sample["preds"]}
print(f"unique ttnet transformations: {len(unique_ttnet_transformations)}")

In [None]:
# no useful unique transformations from TTNet
ttnet_transformations_only = unique_ttnet_transformations - set(unique_train_transformations)
ttnet_transformations_only

In [None]:
### CrossTask related videos

In [None]:
# Load the task list; the cells below read each task's "type" and "steps" fields.
tasks = dtool.read_jsonlines("../docs/lists/tasks.jsonl")
print(len(tasks))

In [None]:
"add" not in words_all

In [None]:
def is_task_valid(task):
    """Return True when every spaCy token of every step in task["steps"] is in
    words_all; otherwise print the first unknown token and return False."""
    for step in task["steps"]:
        for token in nlp(step):
            text = str(token)
            if text in words_all:
                continue
            print(text)
            return False
    return True


# Keep the tasks of type 'related' whose steps use only known tokens.
candidates_tasks = []
for task in tasks:
    if task['type'] != 'related':
        continue
    if is_task_valid(task):
        candidates_tasks.append(task)

In [None]:
candidates_tasks

In [None]:
# Short experiment name -> log directory of that run.
# The following cell reads "detail.jsonl" from each of these directories.
EXPERIMENTS = {
    "cst": "/log/exp/vtt/VTTDataModule.CST.GenerationLoss.2022-09-17_00-06-25",
    "glacnet": "/log/exp/vtt/VTTDataModule.GLACNet.GenerationLoss.2022-09-18_17-56-36",
    "densecap": "/log/exp/vtt/VTTDataModule.DenseCap.GenerationLoss.2022-09-25_15-15-34",
    "ttnet_base": "/log/exp/vtt/VTTDataModule.TTNetMTM.GenerationLoss.2022-09-16_10-59-03",
    "ttnet": "/log/exp/vtt/VTTDataModule.TTNetDiff.TellingLossV1.2022-09-19_21-44-42",
}

In [None]:
# For each experiment, print the predictions of the last record in its detail file.
for exp, exp_path in EXPERIMENTS.items():
    results = dtool.read_jsonlines(Path(exp_path) / "detail.jsonl")
    print(f"{exp}:")
    print(list(results[-1]["preds"]))

## Combination Generalization

In [None]:
# process_set: category -> "All" / capitalized split name -> set of processes,
# where a process is the "-"-joined sequence of a sample's step labels.
# (The loop variable below is called `topic`, but the key is sample["category"].)
process_set = defaultdict(lambda: defaultdict(set))
for sample in data:
    p = "-".join(step["label"] for step in sample["annotation"])
    per_category = process_set[sample["category"]]
    per_category["All"].add(p)
    per_category[sample["split"].capitalize()].add(p)

# process: category -> number of distinct processes per split, plus
# "Val Unique" / "Test Unique" = processes of that split never seen in Train.
process = {}
for topic, sets in process_set.items():
    # indexing the defaultdict also creates an empty "Train" entry when a
    # category has no train samples
    s_train = sets['Train']
    row = {}
    for split, s in sets.items():
        row[split] = len(s)
        if split in ("Val", "Test"):
            row[f"{split} Unique"] = len(s - s_train)
    process[topic] = row

In [None]:
# Transpose so that each outer key of `process` becomes a row and each
# per-split statistic a column.
df = pd.DataFrame(process).T

In [None]:
# Rebuild the table from `process` instead of updating the existing `df` in
# place. Re-running this cell then cannot add a second "Total" row that also
# sums the previous one.
df = pd.DataFrame(process).T.sort_index()
df.loc["Total"] = df.sum()
# fixed column order for the report
df = df[["Train", "Val", "Val Unique", "Test", "Test Unique", "All"]]
df

In [None]:
print(df.style.to_latex(hrules=True))

In [None]:
# Dataset metadata file (the same path is opened directly where `data` is loaded).
META_FILE = Path("/data/vtt/meta/vtt.jsonl")
# NOTE(review): identical to the EXPERIMENTS mapping defined earlier in this
# notebook; the redefinition looks redundant. Consider keeping a single copy.
EXPERIMENTS = {
    "cst": "/log/exp/vtt/VTTDataModule.CST.GenerationLoss.2022-09-17_00-06-25",
    "glacnet": "/log/exp/vtt/VTTDataModule.GLACNet.GenerationLoss.2022-09-18_17-56-36",
    "densecap": "/log/exp/vtt/VTTDataModule.DenseCap.GenerationLoss.2022-09-25_15-15-34",
    "ttnet_base": "/log/exp/vtt/VTTDataModule.TTNetMTM.GenerationLoss.2022-09-16_10-59-03",
    "ttnet": "/log/exp/vtt/VTTDataModule.TTNetDiff.TellingLossV1.2022-09-19_21-44-42",
}

In [None]:
# Samples of the metadata file whose "split" is "test".
# NOTE(review): presumably JSONLList applies the predicate as a filter and
# keeps file order; confirm in src.utils.datatool.
test_samples = dtool.JSONLList(META_FILE, lambda x: x["split"] == "test").samples

In [None]:
# Union of the train-split processes over all categories.
process_set_train = set().union(*(sets["Train"] for sets in process_set.values()))
print(len(process_set_train))

In [None]:
# Split the positions of test_samples into two lists:
#   test_share: the sample's process also occurs in the train split
#   test_only:  the sample's process never occurs in the train split
test_share, test_only = [], []
for i, sample in enumerate(test_samples):
    p = "-".join(step["label"] for step in sample["annotation"])
    bucket = test_share if p in process_set_train else test_only
    bucket.append(i)
print(f"share: {len(test_share)}")
print(f"only: {len(test_only)}")

In [None]:
# Names of the automatic metrics averaged by default.
METRICS = ["BLEU_4", "ROUGE", "METEOR", "CIDEr", "SPICE", "BERTScore"]


def compute_metrics(results, metrics=METRICS):
    """Average each metric over a collection of result records.

    Args:
        results: iterable of mappings; each must hold an entry for every name
            in `metrics`. An entry is either a single value or a list/tuple of
            values. Sequences are flattened, so every contained value carries
            the same weight in the mean.
        metrics: names of the entries to average. Defaults to METRICS.

    Returns:
        dict mapping each metric name to the np.mean of all collected values.
        The dict is empty when `results` is empty.

    Raises:
        KeyError: if a record lacks one of the requested metrics.
    """
    scores = defaultdict(list)
    for result in results:
        for metric in metrics:
            value = result[metric]
            # isinstance instead of an exact type check: tuples and list
            # subclasses are flattened too, rather than being appended whole
            # and breaking the mean when mixed with scalar entries
            if isinstance(value, (list, tuple)):
                scores[metric].extend(value)
            else:
                scores[metric].append(value)
    return {key: np.mean(value) for key, value in scores.items()}

### Automatic Metrics

In [None]:
scores = {}
for exp, exp_path in EXPERIMENTS.items():
    result_path = Path(exp_path) / "detail.jsonl"
    results = dtool.read_jsonlines(result_path)
    scores[exp] = compute_metrics(results)
df_scores = pd.DataFrame(scores).T
df_scores

In [None]:
print(df_scores[["CIDEr"]].style.format(precision=2).to_latex(hrules=True, ))

In [None]:
scores = {}
for exp, exp_path in EXPERIMENTS.items():
    result_path = Path(exp_path) / "detail.jsonl"
    results = dtool.read_jsonlines(result_path)
    # share
    results = [results[i] for i in test_share]
    scores[exp] = compute_metrics(results)
df_scores = pd.DataFrame(scores).T
df_scores

In [None]:
print(df_scores[["CIDEr"]].style.format(precision=2).to_latex(hrules=True, ))

In [None]:
scores = {}
for exp, exp_path in EXPERIMENTS.items():
    result_path = Path(exp_path) / "detail.jsonl"
    results = dtool.read_jsonlines(result_path)
    # only
    results = [results[i] for i in test_only]
    scores[exp] = compute_metrics(results)
df_scores = pd.DataFrame(scores).T
df_scores

In [None]:
print(df_scores[["CIDEr"]].style.format(precision=2).to_latex(hrules=True, ))

### Human Results

In [None]:
# Directory holding one "<experiment>.jsonl" human-results file per experiment.
HUMAN_RESULTS_DIR = Path("../docs/lists/human_results")
# Experiment names; each is used as the file stem inside HUMAN_RESULTS_DIR.
EXPS = ["cst", "glacnet", "densecap", "ttnet_base", "ttnet"]
# Fields averaged from every human-results record.
HUMAN_METRICS = ["fluency", "relevance", "logical_soundness"]

In [None]:
scores = {}
for exp in EXPS:
    path = HUMAN_RESULTS_DIR / f"{exp}.jsonl"
    results = dtool.read_jsonlines(path)
    scores[exp] = compute_metrics(results, HUMAN_METRICS)
df_scores = pd.DataFrame(scores).T
df_scores

In [None]:
print(df_scores.style.format(precision=2).to_latex(hrules=True, ))

In [None]:
scores = {}
for exp in EXPS:
    path = HUMAN_RESULTS_DIR / f"{exp}.jsonl"
    results = dtool.read_jsonlines(path)
    results = [x for x in results if x["index"] in test_share]
    scores[exp] = compute_metrics(results, HUMAN_METRICS)
df_scores = pd.DataFrame(scores).T
df_scores

In [None]:
print(df_scores.style.format(precision=2).to_latex(hrules=True, ))

In [None]:
scores = {}
for exp in EXPS:
    path = HUMAN_RESULTS_DIR / f"{exp}.jsonl"
    results = dtool.read_jsonlines(path)
    results = [x for x in results if x["index"] in test_only]
    scores[exp] = compute_metrics(results, HUMAN_METRICS)
df_scores = pd.DataFrame(scores).T
df_scores

In [None]:
print(df_scores.style.format(precision=2).to_latex(hrules=True, ))