In [None]:
import glob
import pickle
import torch
import pandas as pd
import fsspec
from baselines.apply_filter import load_metadata
import os
import random
import matplotlib.pyplot as plt
import seaborn as sns
from math import log

In [None]:
# Load the DataComp-small metadata and the MERU x_time scores from one
# randomly chosen score shard, then join the two on `uid`.
SCORES_GLOB = '/local1/siting/scores/*.pt'
METADATA_DIR = '/local1/datasets/datacomp_small/metadata/'
SEED = 0  # fixes which score shard is sampled so re-runs analyse the same data

# Sort so the seeded choice does not depend on filesystem listing order.
file_list = sorted(glob.glob(SCORES_GLOB))
if not file_list:
    raise FileNotFoundError(f"no score files match {SCORES_GLOB!r}")
df_original = load_metadata(METADATA_DIR, num_workers=os.cpu_count())

file_path = random.Random(SEED).choice(file_list)
# NOTE(review): torch.load unpickles the file — only load score shards you trust.
zipped_content = torch.load(file_path)
# Each entry is unpacked as a (uid, score) pair.
meru_uid_collection, meru_score_collection = zip(*zipped_content)
df = pd.DataFrame({'uid': meru_uid_collection, 'l_xtime': meru_score_collection})
# An inner merge (the default) already keeps only uids present in both frames,
# so no separate `isin` pre-filter is needed.
merged_df = pd.merge(df_original, df, on='uid')

merged_df['text_length'] = merged_df['text'].str.len()
# math.log(0) raises ValueError; empty or missing captions become NaN instead.
merged_df['log_text_length'] = merged_df['text_length'].map(
    lambda n: log(n) if n > 0 else float('nan')
)

In [None]:
def _plot_cdf(values, title, xlabel, out_path):
    """Plot a KDE-smoothed cumulative distribution of `values`, save it, then show it.

    Args:
        values: 1-D numeric data passed straight to `sns.kdeplot`.
        title: figure title.
        xlabel: x-axis label.
        out_path: file the figure is written to.

    The figure is saved *before* `plt.show()`: in Jupyter's inline backend
    `show` closes the current figure, so a `plt.savefig` after it writes an
    empty page.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    sns.kdeplot(values, cumulative=True, bw_adjust=0.5, ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Cumulative Density')
    fig.savefig(out_path)
    plt.show()


_plot_cdf(merged_df["clip_l14_similarity_score"], 'CLIP score CDF', 'CLIP score', "./kde-clip.pdf")
_plot_cdf(merged_df["l_xtime"], 'MERU x_time CDF', "MERU x_time", "./kde-meru.pdf")

In [None]:
# Scatter of MERU x_time vs CLIP score. Each (x_thr, y_thr) pair is drawn as a
# dashed red corner bounding the region x >= x_thr and y >= y_thr.
# NOTE(review): pairs presumably correspond to matched filtering rates — confirm their source.
XTIME_CLIP_THRESHOLDS = [
    (3.1830077171325684, 0.24194336),
    (3.210522413253784, 0.22106934),
    (3.214315414428711, 0.20251465),
    (3.217371940612793, 0.1595459),
    (3.2183597087860107, -0.11401367),
]

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(merged_df["l_xtime"], merged_df["clip_l14_similarity_score"], s=1)

x_max = merged_df["l_xtime"].max()
y_max = merged_df["clip_l14_similarity_score"].max()
for x_thr, y_thr in XTIME_CLIP_THRESHOLDS:
    ax.hlines(y=y_thr, xmin=x_thr, xmax=x_max, colors='red', linestyles='--')
    ax.vlines(x=x_thr, ymin=y_thr, ymax=y_max, colors='red', linestyles='--')

ax.set_xlabel("MERU x_time")
ax.set_ylabel("CLIP score")
# Save before show: the inline backend closes the figure on show, which would
# leave a subsequent savefig with an empty canvas.
fig.savefig("scatter.pdf")
plt.show()

In [None]:
# Named versions of the thresholds drawn below; the legend labels are the
# original author's descriptions of them.
XTIME_75 = 3.210522413253784    # legend: "MERU x_time 75%"
XTIME_30 = 3.2183597087860107   # legend: "MERU x_time 30%"
TEXT_LENGTH_30 = 52             # legend: "Text length 30%" (characters)

# MERU x_time vs raw caption length.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(merged_df["l_xtime"], merged_df["text_length"], s=1)
ax.vlines(x=XTIME_75, ymin=merged_df["text_length"].min(), ymax=merged_df["text_length"].max(),
          colors='orange', linestyles='--', label=r"MERU x_time $75\%$")
ax.set_xlabel("MERU x_time")
ax.set_ylabel("Text length")
ax.legend()
# Save before show: the inline backend closes the figure on show.
fig.savefig("scatter-length.pdf")
plt.show()

# MERU x_time vs log caption length, with both 30% thresholds.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(merged_df["l_xtime"], merged_df["log_text_length"], s=1)
ax.vlines(x=XTIME_30, ymin=merged_df["log_text_length"].min(), ymax=merged_df["log_text_length"].max(),
          colors='red', linestyles='--', label=r"MERU x_time $30\%$")
ax.hlines(y=log(TEXT_LENGTH_30), xmin=merged_df["l_xtime"].min(), xmax=merged_df["l_xtime"].max(),
          colors='orange', linestyles='--', label=r"Text length $30\%$")
ax.set_xlabel("MERU x_time")
ax.set_ylabel("Log text length")
ax.legend()
fig.savefig("scatter-length-log.pdf")
plt.show()

In [None]:
XTIME_30 = 3.2183597087860107  # MERU x_time threshold used for both selections


def _dump_captions(frame, path):
    """Write the `text` column of `frame` to `path` via `Series.to_string(index=False)`.

    `display.max_colwidth` is lifted only for this call (instead of globally)
    so long captions are not truncated. An empty selection writes an empty
    file rather than pandas' empty-Series placeholder text.
    """
    captions = frame["text"]
    with pd.option_context("display.max_colwidth", None):
        body = "" if captions.empty else captions.to_string(index=False)
    # Explicit UTF-8: captions may contain non-ASCII text, and the platform
    # default encoding is not guaranteed to handle it.
    with open(path, "w", encoding="utf-8") as f:
        f.write(body)


# Very short captions (log length < 2.5) whose x_time is at or above the threshold.
short_captions_df = merged_df[(merged_df["log_text_length"] < 2.5) & (merged_df["l_xtime"] >= XTIME_30)]
_dump_captions(short_captions_df, "shortcaptions.txt")

# Very long captions (log length >= 6.0) whose x_time is below the threshold.
long_captions_df = merged_df[(merged_df["log_text_length"] >= 6.0) & (merged_df["l_xtime"] < XTIME_30)]
_dump_captions(long_captions_df, "longcaptions.txt")

In [None]:
# Fraction of samples on which the caption-length and MERU x_time thresholds
# agree: both at/above their threshold, or both below it.
XTIME_30 = 3.2183597087860107
LOG_TEXT_LENGTH_30 = log(52)

both_above_df = merged_df[(merged_df["log_text_length"] >= LOG_TEXT_LENGTH_30) & (merged_df["l_xtime"] >= XTIME_30)]
both_below_df = merged_df[(merged_df["log_text_length"] < LOG_TEXT_LENGTH_30) & (merged_df["l_xtime"] < XTIME_30)]
# Guard against division by zero when the merged frame is empty.
agreement_rate = (len(both_above_df) + len(both_below_df)) / len(merged_df) if len(merged_df) else float('nan')
print(agreement_rate)