In [None]:
import json
from pathlib import Path
import matplotlib
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np

stats_path = Path("../stats/vidhoi/")

# Predicate (interaction) class names. A name's position in this list is its
# integer id: `split_pred` below and the per-class bar plots index by it.
name_pred = [
    "lean_on", "watch", "above", "next_to", "behind", "away", "towards", "in_front_of", "hit", "hold",  # 0-9
    "wave", "pat", "carry", "point_to", "touch", "play(instrument)", "release", "ride", "grab", "lift",  # 10-19
    "use", "press", "inside", "caress", "pull", "get_on", "cut", "hug", "bite", "open",  # 20-29
    "close", "throw", "kick", "drive", "get_off", "push", "wave_hand_to", "feed", "chase", "kiss",  # 30-39
    "speak_to", "beneath", "smell", "clean", "lick", "squeeze", "shake_hand_with", "knock", "hold_hand_of", "shout_at",  # 40-49
]
# Predicate ids grouped by category. "spatial" and "action" are disjoint and
# together cover every id; "temporal" is an overlapping tag used for colouring.
split_pred = {
    "spatial": [2, 3, 4, 5, 6, 7, 22, 41],
    "action": [0, 1, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42, 43, 44, 45, 46, 47, 48, 49],
    "temporal": [5, 6, 8, 10, 11, 16, 18, 19, 21, 23, 24, 25, 29, 30, 31, 32, 34, 35, 36, 38, 42, 44, 45, 47, 49]
}
# Guard the hand-written index tables: a typo here would silently mis-colour
# or mis-label bars in every plot below.
assert sorted(split_pred["spatial"] + split_pred["action"]) == list(range(len(name_pred))), \
    "spatial/action ids must partition range(len(name_pred))"
assert set(split_pred["temporal"]) <= set(range(len(name_pred))), "temporal ids out of range"
print(name_pred)
print(split_pred)

In [None]:
# Bar colour per predicate id: orange for temporal predicates, blue otherwise.
color = [
    "#ff7f0e" if pred_id in split_pred["temporal"] else "#1f77b4"
    for pred_id, _pred_name in enumerate(name_pred)
]
# Legend handles that match the two bar colours above.
blue_patch = mpatches.Patch(label="Non-temporal interaction", color="#1f77b4")
orange_patch = mpatches.Patch(label="Temporal interaction", color="#ff7f0e")


def draw_bar(hist, title, sorted_idx=None):
    """Plot a per-predicate bar chart plus a log-scale inset of the same bars.

    Args:
        hist: Bar heights, one per bar. When ``sorted_idx`` is given, ``hist``
            is expected to already be reordered (e.g. ``hist[sorted_idx]``).
        title: Title for the main (linear y-axis) plot.
        sorted_idx: Optional sequence of predicate ids giving the class shown
            at each bar position. Tick labels AND bar colours follow this
            order, so each bar keeps its own name and temporal colour.

    Reads the module-level ``name_pred``, ``color``, ``blue_patch`` and
    ``orange_patch``. Shows the figure and returns None.
    """
    if sorted_idx is None:
        bar_names = name_pred
        bar_colors = color
    else:
        # `color` is indexed by predicate id, exactly like `name_pred`, so it
        # must be permuted the same way; otherwise the temporal/non-temporal
        # colours would describe the wrong bars in sorted plots.
        bar_names = [name_pred[i] for i in sorted_idx]
        bar_colors = [color[i] for i in sorted_idx]
    positions = range(len(hist))

    fig, ax = plt.subplots(figsize=(20, 8))
    ax.bar(positions, hist, color=bar_colors)
    ax.set_xticks(positions)
    ax.set_xticklabels(bar_names, rotation=90, fontsize=16)
    ax.tick_params(axis="y", labelsize=14)
    ax.set_ylim([0, None])
    ax.set_title(title, fontsize=20)

    # Inset placed in axes-fraction coordinates (x0, y0, width, height) that
    # redraws the same bars with a logarithmic y-axis.
    axins: plt.Axes = ax.inset_axes((0.25, 0.3, 0.74, 0.68))
    axins.bar(positions, hist, color=bar_colors)
    axins.set_xticks(positions)
    axins.set_xticklabels(bar_names, rotation=90, fontsize=12)
    axins.tick_params(axis="y", labelsize=12)
    axins.set_yscale("log")
    # Lower limit just above 10**0. NOTE(review): bars with a count <= 1 end
    # below this limit and are not visible in the inset — confirm intended.
    axins.set_ylim([1.01, None])
    axins.set_title("Frequency in Log-Scale", fontsize=18, y=0.9)
    axins.legend(handles=[blue_patch, orange_patch], prop={"size": 16})

    plt.show()

# Each stats file is read exactly once; every figure below is derived from
# these two dicts. JSON text is UTF-8, so decode it explicitly rather than
# relying on the platform default encoding.
with (stats_path / "class_weight_train.json").open("r", encoding="utf-8") as f:
    full_set_info = json.load(f)
with (stats_path / "class_weight_val.json").open("r", encoding="utf-8") as f:
    test_set_info = json.load(f)

# "Training Set" in the predicate-distribution plots is the element-wise sum of
# the train and val histograms stored in class_weight_train.json; the
# "Validation Set" there is the val histogram from class_weight_val.json.
trainval_hist = np.array(full_set_info["interaction_train_hist"]) + np.array(full_set_info["interaction_val_hist"])
val_split_hist = np.array(test_set_info["interaction_val_hist"])

### sorted (descending frequency; stable sort keeps tied classes in id order)
for hist, title in (
    (trainval_hist, "VidHOI Training Set Predicate Class Distribution"),
    (val_split_hist, "VidHOI Validation Set Predicate Class Distribution"),
):
    sorted_idx = np.argsort(-hist, kind="stable")
    draw_bar(hist[sorted_idx], title, sorted_idx)


### not sorted
print("=" * 40)
draw_bar(trainval_hist, "VidHOI Training Set Predicate Class Distribution")
draw_bar(val_split_hist, "VidHOI Validation Set Predicate Class Distribution")


print("=" * 40)
# Per-split views: the train and val histograms from class_weight_train.json,
# and the val histogram from class_weight_val.json labelled as the test set.
interaction_train_hist = full_set_info["interaction_train_hist"]
interaction_val_hist = full_set_info["interaction_val_hist"]
interaction_test_hist = test_set_info["interaction_val_hist"]
draw_bar(interaction_train_hist, "VidHOI Training Set Interaction Frequency")
draw_bar(interaction_val_hist, "VidHOI Validation Set Interaction Frequency")
draw_bar(interaction_test_hist, "VidHOI Test Set Interaction Frequency")



