In [None]:
from collections import defaultdict
import wandb
import pandas as pd

In [None]:
# Shared W&B handles used by the cells below: the API client, the account's
# default entity, and the fixed project name.
api = wandb.Api()
entity, project = api.default_entity, "vtt"

# get runs from the project
def filter_runs(filters=None, sort=None):
    """Fetch runs from ``entity/project`` and keep the ones usable for the table.

    A run is kept only when its summary contains ``"test/CIDEr"`` and its
    config contains ``"model/_target_"``.

    Args:
        filters: Query passed through to ``api.runs`` (``None`` for no filter).
        sort: Optional key function; when given, the kept runs are returned
            sorted by it in ascending order.

    Returns:
        A list of the runs that passed the check. The number of kept runs is
        also printed.
    """
    project_path = f"{entity}/{project}"

    kept = []
    for candidate in api.runs(project_path, filters=filters):
        # Check the summary first and the config second.
        if "test/CIDEr" not in candidate.summary:
            continue
        if "model/_target_" not in candidate.config:
            continue
        kept.append(candidate)

    if sort is not None:
        kept.sort(key=sort)

    print(f"Find {len(kept)} runs in {entity}/{project}")
    return kept

In [None]:
# Select the runs tagged "miss" and order them by ascending test CIDEr.
filters = dict(tags={"$in": ["miss"]})
runs = filter_runs(filters=filters, sort=lambda r: r.summary["test/CIDEr"])

In [None]:
import matplotlib.pyplot as plt
import matplotlib.font_manager

# List the font families Matplotlib can see (de-duplicated: the font manager has
# one entry per font file, so family names repeat) to help pick `font.family`.
print(f"available fonts: {sorted({f.name for f in matplotlib.font_manager.fontManager.ttflist})}")

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and newer
# releases no longer accept the old names, so use whichever spelling this
# installation provides. Falling back to the legacy name keeps the original
# error if neither exists.
_muted_style = (
    "seaborn-v0_8-muted"
    if "seaborn-v0_8-muted" in plt.style.available
    else "seaborn-muted"
)
plt.style.use(_muted_style)

plt.rcParams.update({
    # Figure resolution and how figures are written to disk.
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "savefig.format": "pdf",
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.1,
    # Titles and base font.
    "figure.titlesize": 18,
    "axes.titlesize": 18,
    "font.family": "Helvetica",
    "font.size": 18,
    # Lines, markers, axis labels, ticks and legend text.
    "lines.linewidth": 2,
    "scatter.marker": "o",
    "axes.labelsize": 16,
    "axes.labelweight": "bold",
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "legend.fontsize": 16,
    "axes.linewidth": 2,
    "axes.titlepad": 6,
    # Math text rendering and remaining defaults.
    "mathtext.fontset": "dejavuserif",
    "mathtext.it": "serif:italic",
    "lines.marker": "",
    "legend.frameon": False,
})

In [None]:
def _display_name(config):
    """Return the plot label for a run, derived from its config mapping.

    The label starts as the last dotted component of ``model/_target_`` and is
    then adjusted:

    * any name containing ``"TTNet"`` becomes ``"TTNet"`` / ``"TTNet w/o MTM"``
      (for ``TTNetDiff``, depending on whether ``model/mask_ratio`` is
      positive) or the mathtext label ``TTNet$_\\text{Base}$`` otherwise;
    * any other model whose image encoder is CLIP gets a trailing ``*``.

    Args:
        config: Mapping holding at least ``"model/_target_"``; ``TTNetDiff``
            runs must also provide ``"model/mask_ratio"``.

    Returns:
        The label string used as the DataFrame index / legend entry.
    """
    model_name = config["model/_target_"].split(".")[-1]

    # Runs without an explicit image encoder are treated as ResNet152.
    if "model/image_encoder" in config:
        image_encoder = config["model/image_encoder"]
    else:
        image_encoder = "ResNet152"
    image_encoder = image_encoder.replace("resnet", "ResNet")
    if image_encoder == "ViT-L/14":
        image_encoder = "CLIP"
    elif image_encoder == "inception_v3":
        image_encoder = "InceptionV3"

    if "TTNet" in model_name:
        if model_name == "TTNetDiff":
            if config["model/mask_ratio"] > 0:
                model_name = "TTNet"
            else:
                model_name = "TTNet w/o MTM"
        else:
            model_name = "TTNet$_\\text{Base}$"
    elif image_encoder == "CLIP":
        model_name += "*"
    return model_name


results = defaultdict(list)
models = []
# Walk the runs in reverse of the order they were fetched in.
for run in runs[::-1]:
    models.append(_display_name(run.config))
    # Scores are scaled by 100 for reporting.
    results["Full"].append(run.summary["test/CIDEr"] * 100)
    results["Randomly mask one"].append(run.summary["miss_one_test/CIDEr"] * 100)
    results["Initial & Final"].append(run.summary["init_fin_only_test/CIDEr"] * 100)

# Build the table once, after all runs are collected. This avoids rebuilding the
# frame on every iteration and guarantees `df` exists even when `runs` is empty.
df = pd.DataFrame(results, index=models)

In [None]:
# Show the per-model CIDEr table (one row per run, one column per test setting).
df

In [None]:
# Inspect the values of the last table row, i.e. the array one plotted line is
# drawn from. Reads from `df` directly so this cell works on a fresh kernel; the
# loop variable `row` is only bound by the plotting cell further down.
df.iloc[-1].values

In [None]:
# One line per model across the three test settings, saved to miss.pdf.
axes = plt.gca()
x_positions = [0, 1, 2]
for label, row in df.iterrows():
    # `label` is the row's index entry, i.e. the model name shown in the legend.
    axes.plot(x_positions, row.values, 's', markersize=12, linestyle='-', linewidth=5, label=label)
axes.set_xlabel("States")
axes.set_ylabel("CIDEr")
axes.set_xticks(x_positions)
axes.set_xticklabels(['full', 'randomly mask one', 'start & end only'])
axes.legend()
axes.figure.savefig("miss.pdf")