In [None]:
import ast
import itertools
import json
import os
from collections import Counter

import pandas as pd
import torch
from fiftyone import zoo as foz
from matplotlib import pyplot as plt

In [None]:
# Root of the dissertation project tree. The default is the author's local checkout;
# set the CP4101_PROJECT_ROOT environment variable to run this notebook elsewhere.
# With the variable unset, every path below resolves exactly as before.
PROJECT_ROOT = os.environ.get(
    "CP4101_PROJECT_ROOT",
    "/Users/xihaochen/Documents/National University of Singapore/Modules/2223 Sem 2/CP4101 B. Comp. Dissertation/Project",
)

# COCO-2017 directory: used as the FiftyOne zoo `dataset_dir` and as the location
# of coco2017.csv.
COCO_17_DATA_DIR = os.path.abspath(os.path.join(PROJECT_ROOT, "coco-org", "coco2017"))
# Output directory for plots describing the COCO-2017 dataset itself.
COCO_17_PLOT_DIR = os.path.abspath(os.path.join(PROJECT_ROOT, "coco-org", "notebooks", "plots"))

# Output directory for plots of model results in the main project.
PLOT_DIR = os.path.abspath(os.path.join(PROJECT_ROOT, "main-project", "notebooks", "plots"))

In [None]:
# Class list of COCO-2017 as recorded in the FiftyOne zoo dataset info. The info
# object is rendered to a compact JSON string and parsed back into a plain dict.
json_info_str = foz.load_zoo_dataset_info(
    "coco-2017", dataset_dir=COCO_17_DATA_DIR
).to_str(pretty_print=False)
info_dict = json.loads(json_info_str)
all_classes = info_dict["classes"]

# Keep only non-numeric names. Purely numeric entries are presumably placeholder ids
# rather than concept names — TODO confirm against the zoo class list.
target_classes = [name for name in all_classes if not name.isnumeric()]

print(*(len(names) for names in (all_classes, target_classes)))

In [None]:
# Per-image annotations. Each `labels` cell holds a Python literal (an iterable of
# concept names) serialised as text.
df = pd.read_csv(os.path.join(COCO_17_DATA_DIR, "coco2017.csv"))
# Parse with `ast.literal_eval` instead of `eval`: it accepts only Python literals,
# so a malformed or tampered CSV cell cannot execute arbitrary code.
df["labels"] = df["labels"].apply(ast.literal_eval)
labels = df["labels"].tolist()
# Total number of occurrences of each concept across all images.
c = Counter(itertools.chain.from_iterable(labels))

In [None]:
# Concept frequency table: one row per concept (index = concept name), with the
# most frequent concepts first.
dfc = (
    pd.DataFrame.from_dict(c, orient="index", columns=["count"])
    .sort_values(by="count", ascending=False)
)
dfc.head()

In [None]:
# Bar chart of the 20 most frequent concepts (dfc is sorted by descending count).
ax = dfc[:20]["count"].plot(
    kind="bar",
    figsize=(10, 6),
    legend=False,
    width=0.8,
    color="tab:blue",
)
ax.set_xlabel("Concept (top 20)")
ax.set_ylabel("Frequency")
# Rotate the existing tick labels and right-align them so long names stay readable.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.grid(alpha=0.5, axis="y")
ax.set_title("Coco-2017: Distribution of top 20 concepts")

ax.figure.savefig(os.path.join(COCO_17_PLOT_DIR, "coco2017_top20_concepts.png"), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Load the epoch-5 BEiT + CLIP checkpoint onto the CPU and list the entries it stores.
# torch.load unpickles the file, so only point it at checkpoints you produced yourself.
checkpoint = torch.load(
    "/Users/xihaochen/Documents/National University of Singapore/Modules/2223 Sem 2/CP4101 B. Comp. Dissertation/Project/main-project/checkpoints/coco2017-beit_clip-epoch_5.pt",
    map_location="cpu",
)
checkpoint.keys()

In [None]:
# Attach per-concept precision from the currently loaded `checkpoint`. The last entry
# of each report list is used (presumably the final epoch — TODO confirm).
i2i_reports = checkpoint["report_i2i"][-1]
t2t_reports = checkpoint["report_t2t"][-1]

# Look precision up only for concepts already indexed in `dfc`. Writing through
# `dfc.loc[key, col]` for every report key has two problems:
# - it appends a new row (with NaN count) for any key that is not a concept in `dfc`;
# - it fails on any report entry whose value is not a dict.
# Concepts absent from a report get a missing value.
dfc["i2i_precision"] = [
    i2i_reports[concept]["precision"] if concept in i2i_reports else None
    for concept in dfc.index
]
dfc["t2t_precision"] = [
    t2t_reports[concept]["precision"] if concept in t2t_reports else None
    for concept in dfc.index
]
dfc = dfc.sort_values(by="count", ascending=False)
dfc.head()

In [None]:
# Frequency of the 20 most common concepts (bars, left axis) plotted against the
# BEiT + CLIP model's per-concept precision in both directions (lines, right axis).
top_20 = dfc[:20]
fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()

# Both axes share the categorical x-axis keyed by concept name.
ax1.bar(top_20.index, top_20["count"], color="tab:blue", label="Frequency")
# NOTE(review): the values are each report's "precision" entry but are labelled "AP";
# confirm the reports store average precision under that key.
ax2.plot(top_20.index, top_20["i2i_precision"], color="tab:orange", label="I2I AP", marker="+")
ax2.plot(top_20.index, top_20["t2t_precision"], color="tab:green", label="T2T AP", marker=".")

ax1.set_xlabel("Concept (top 20)")
ax1.set_ylabel("Frequency")
# Rotate the existing category labels instead of calling set_xticklabels(). That call
# swaps in a FixedFormatter without fixed tick locations, and matplotlib warns about it.
plt.setp(ax1.get_xticklabels(), rotation=45, ha="right")

ax2.set_ylabel("Precision")
ax2.grid(alpha=0.5, axis="y")

# twinx() makes ax2 the current axes, so a bare plt.legend() listed only the two
# precision lines and dropped the "Frequency" bars. Merge the handles of both axes.
bar_handles, bar_labels = ax1.get_legend_handles_labels()
line_handles, line_labels = ax2.get_legend_handles_labels()
ax2.legend(bar_handles + line_handles, bar_labels + line_labels)

ax1.set_title("TransforMMER (BEiT + CLIP) AP of top 20 most frequent concepts of Coco-2017")
fig.savefig(os.path.join(PLOT_DIR, "beit_clip_full_top20_concepts.png"), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Replace `checkpoint` with the epoch-15 `clip_clip` checkpoint from the
# coco2017-augmented directory (loaded onto the CPU), then show how many topics it covers.
# torch.load unpickles the file, so only point it at checkpoints you produced yourself.
checkpoint = torch.load(
    "/Users/xihaochen/Documents/National University of Singapore/Modules/2223 Sem 2/CP4101 B. Comp. Dissertation/Project/main-project/checkpoints/coco2017-augmented/coco2017-clip_clip-epoch_15.pt",
    map_location="cpu",
)
len(checkpoint["topics"])

In [None]:
# Peek at which entries the last image-to-image report contains before tabulating it.
last_i2i_report = checkpoint["report_i2i"][-1]
last_i2i_report.keys()

In [None]:
# Attach per-concept precision from the currently loaded `checkpoint`. The last entry
# of each report list is used (presumably the final epoch — TODO confirm).
# NOTE: this overwrites the precision columns filled from any previously loaded checkpoint.
i2i_reports = checkpoint["report_i2i"][-1]
t2t_reports = checkpoint["report_t2t"][-1]

# Look precision up only for concepts already indexed in `dfc`. Writing through
# `dfc.loc[key, col]` for every report key has two problems:
# - it appends a new row (with NaN count) for any key that is not a concept in `dfc`;
# - it fails on any report entry whose value is not a dict.
# Concepts absent from a report get a missing value.
dfc["i2i_precision"] = [
    i2i_reports[concept]["precision"] if concept in i2i_reports else None
    for concept in dfc.index
]
dfc["t2t_precision"] = [
    t2t_reports[concept]["precision"] if concept in t2t_reports else None
    for concept in dfc.index
]
dfc = dfc.sort_values(by="count", ascending=False)
dfc.head()

In [None]:
# Frequency of the 20 most common concepts (bars, left axis) plotted against the
# per-concept precision of the `clip_clip` checkpoint loaded above (lines, right axis).
# NOTE(review): the bar heights come from coco2017.csv, while this checkpoint comes from
# the coco2017-augmented directory; confirm those frequencies are the right reference.
top_20 = dfc[:20]
fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()

# Both axes share the categorical x-axis keyed by concept name.
ax1.bar(top_20.index, top_20["count"], color="tab:blue", label="Frequency")
# NOTE(review): the values are each report's "precision" entry but are labelled "AP";
# confirm the reports store average precision under that key.
ax2.plot(top_20.index, top_20["i2i_precision"], color="tab:orange", label="I2I AP", marker="+")
ax2.plot(top_20.index, top_20["t2t_precision"], color="tab:green", label="T2T AP", marker=".")

ax1.set_xlabel("Concept (top 20)")
ax1.set_ylabel("Frequency")
# Rotate the existing category labels instead of calling set_xticklabels(). That call
# swaps in a FixedFormatter without fixed tick locations, and matplotlib warns about it.
plt.setp(ax1.get_xticklabels(), rotation=45, ha="right")

ax2.set_ylabel("Precision")
ax2.grid(alpha=0.5, axis="y")

# twinx() makes ax2 the current axes, so a bare plt.legend() listed only the two
# precision lines and dropped the "Frequency" bars. Merge the handles of both axes.
bar_handles, bar_labels = ax1.get_legend_handles_labels()
line_handles, line_labels = ax2.get_legend_handles_labels()
ax2.legend(bar_handles + line_handles, bar_labels + line_labels)

# The checkpoint is `coco2017-clip_clip-epoch_15.pt`, so label it CLIP + CLIP, not BEiT + CLIP.
ax1.set_title("TransforMMER (CLIP + CLIP) AP of top 20 most frequent concepts of Coco-2017-A")
# Saving is disabled for this figure; uncomment to export it.
# fig.savefig(os.path.join(PLOT_DIR, "clip_clip_aug_top20_concepts.png"), dpi=300, bbox_inches='tight')
plt.show()