In [None]:
import numba as nb
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image

In [None]:
# `njit` already implies nopython mode; passing nopython=True to it only makes
# numba emit a RuntimeWarning, so the flag is omitted.
@nb.njit(cache=True)
def cossim(vec1, vec2):
    """Cosine similarity of two arrays of matching shape.

    Computes sum(vec1 * vec2) / (||vec1|| * ||vec2||), i.e. all elements are
    treated as one flat vector.

    Returns:
        The cosine similarity, or 0.0 when either input has zero norm (where
        cosine similarity is undefined) instead of dividing by zero.
    """
    # Multiplying the two norms (rather than taking sqrt of the product of the
    # squared norms) avoids forming a value that can overflow float32.
    norm = np.sqrt(np.sum(vec1 ** 2)) * np.sqrt(np.sum(vec2 ** 2))
    if norm == 0.0:
        return 0.0
    return np.sum(vec1 * vec2) / norm

# `njit` already implies nopython mode; passing nopython=True to it only makes
# numba emit a RuntimeWarning, so the flag is omitted.
@nb.njit(cache=True, parallel=True)
def cossim_mat(embeds1, embeds2):
    """Pairwise cosine similarities between the items of two embedding arrays.

    Args:
        embeds1: array whose first axis indexes n1 items.
        embeds2: array whose first axis indexes n2 items; its items must have
            the same shape as the items of ``embeds1``.

    Returns:
        float32 array ``sims`` of shape (n1, n2) with
        ``sims[i, j] = cos(embeds1[i], embeds2[j])``. Pairs in which either
        item has zero norm are left at 0.0 instead of dividing by zero.
    """
    n1 = embeds1.shape[0]
    n2 = embeds2.shape[0]
    # Each item's norm is computed once here, not once per (i, j) pair.
    norms1 = np.empty(n1, dtype=np.float64)
    for a in nb.prange(n1):
        norms1[a] = np.sqrt(np.sum(embeds1[a] ** 2))
    norms2 = np.empty(n2, dtype=np.float64)
    for b in nb.prange(n2):
        norms2[b] = np.sqrt(np.sum(embeds2[b] ** 2))
    sims = np.zeros((n1, n2), dtype=np.float32)
    for i in nb.prange(n1):
        for j in range(n2):
            denom = norms1[i] * norms2[j]
            # `!= 0.0` (not `> 0.0`) so NaN inputs still propagate as NaN.
            if denom != 0.0:
                sims[i, j] = np.sum(embeds1[i] * embeds2[j]) / denom
    return sims

In [None]:
# The first ten PASCAL VOC category names, in alphabetical order. Presumably
# these are the classes covered by the "10-10" split data used below -- confirm.
classes = (
    "aeroplane bicycle bird boat bottle "
    "bus car cat chair cow"
).split()

In [None]:
def load_embed_dict(dset, model, overlap, root="/home/thesis/marx/wilson_gen/hugface", class_names=None):
    """Load the per-class embedding arrays of one dataset / feature extractor.

    Reads ``{root}/{model}_embeds/{dset}/10-10{overlap}/{cl}.npy`` for every
    class name ``cl`` and casts the array to float32.

    Args:
        dset: dataset sub-directory, e.g. "voc" or a replay-data directory name.
        model: embedding-folder prefix, e.g. "clip" or "dino".
        overlap: suffix appended to the "10-10" directory name, e.g. "-ov".
        root: directory containing the ``{model}_embeds`` folders.
        class_names: iterable of class names to load; defaults to the
            notebook-level ``classes`` list.

    Returns:
        dict mapping each class name (in iteration order) to its float32 array.
    """
    if class_names is None:
        class_names = classes
    embed_dir = os.path.join(root, f"{model}_embeds", dset, f"10-10{overlap}")
    return {
        cl: np.load(os.path.join(embed_dir, f"{cl}.npy")).astype(np.float32)
        for cl in class_names
    }

In [None]:
# Reference embeddings of the real VOC images, per class, for both feature
# extractors (CLIP loaded first, then DINO).
voc_clip, voc_dino = (load_embed_dict("voc", model, "-ov") for model in ("clip", "dino"))

In [None]:
replay_dir = "replay_data_lora"
overlap = "-ov"
top_k = 2

replay_clip = load_embed_dict(replay_dir, "clip", overlap)
replay_dino = load_embed_dict(replay_dir, "dino", overlap)
fig, axes = plt.subplots(len(classes), 2 * top_k + 2, figsize=(30, 50))
for i, cl in enumerate(classes):
    cossim_voc_clip = cossim_mat(voc_clip[cl], voc_clip[cl])
    cossim_voc_dino = cossim_mat(voc_dino[cl], voc_dino[cl])
    cossim_mat_clip = cossim_mat(replay_clip[cl], voc_clip[cl])
    cossim_mat_dino = cossim_mat(replay_dino[cl], voc_dino[cl])
    best_clip_inds = np.argsort(cossim_mat_clip.mean(axis=1))[::-1][:top_k]
    best_dino_inds = np.argsort(cossim_mat_dino.mean(axis=1))[::-1][:top_k]
    best_voc_clip = np.argsort(cossim_voc_clip.mean(axis=1))[::-1][0]
    best_voc_dino = np.argsort(cossim_voc_dino.mean(axis=1))[::-1][0]
    voc_img_dino = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[best_voc_dino]
    axes[i, 0].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_dino}"))
    axes[i, 0].axis("off")
    axes[i, 0].set_title(f"best VOC {cl} (DINO)")
    for j, ind in enumerate(best_dino_inds):
        axes[i, j+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+1].axis("off")
        axes[i, j+1].set_title(f"DINO {j+1}. best")
    voc_img_clip = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[best_voc_clip]
    axes[i, top_k+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_clip}"))
    axes[i, top_k+1].axis("off")
    axes[i, top_k+1].set_title(f"best VOC {cl} (CLIP)")
    for j, ind in enumerate(best_clip_inds):
        axes[i, j+2+top_k].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+2+top_k].axis("off")
        axes[i, j+2+top_k].set_title(f"CLIP {j+1}. best")
plt.tight_layout()
plt.show()

In [None]:
replay_dir = "replay_data_baseline"
overlap = "-ov"
top_k = 2

replay_clip = load_embed_dict(replay_dir, "clip", overlap)
replay_dino = load_embed_dict(replay_dir, "dino", overlap)
fig, axes = plt.subplots(len(classes), 2 * top_k + 2, figsize=(30, 50))
for i, cl in enumerate(classes):
    cossim_voc_clip = cossim_mat(voc_clip[cl], voc_clip[cl])
    cossim_voc_dino = cossim_mat(voc_dino[cl], voc_dino[cl])
    cossim_mat_clip = cossim_mat(replay_clip[cl], voc_clip[cl])
    cossim_mat_dino = cossim_mat(replay_dino[cl], voc_dino[cl])
    best_clip_inds = np.argsort(cossim_mat_clip.mean(axis=1))[::-1][:top_k]
    best_dino_inds = np.argsort(cossim_mat_dino.mean(axis=1))[::-1][:top_k]
    best_voc_clip = np.argsort(cossim_voc_clip.mean(axis=1))[::-1][0]
    best_voc_dino = np.argsort(cossim_voc_dino.mean(axis=1))[::-1][0]
    voc_img_dino = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[best_voc_dino]
    axes[i, 0].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_dino}"))
    axes[i, 0].axis("off")
    axes[i, 0].set_title(f"best VOC {cl} (DINO)")
    for j, ind in enumerate(best_dino_inds):
        axes[i, j+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+1].axis("off")
        axes[i, j+1].set_title(f"DINO {j+1}. best")
    voc_img_clip = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[best_voc_clip]
    axes[i, top_k+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_clip}"))
    axes[i, top_k+1].axis("off")
    axes[i, top_k+1].set_title(f"best VOC {cl} (CLIP)")
    for j, ind in enumerate(best_clip_inds):
        axes[i, j+2+top_k].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+2+top_k].axis("off")
        axes[i, j+2+top_k].set_title(f"CLIP {j+1}. best")
plt.tight_layout()
plt.show()

In [None]:
replay_dir = "replay_data_lora"
overlap = "-ov"
top_k = 2

replay_clip = load_embed_dict(replay_dir, "clip", overlap)
replay_dino = load_embed_dict(replay_dir, "dino", overlap)
fig, axes = plt.subplots(len(classes), 2 * top_k + 2, figsize=(30, 50))
for i, cl in enumerate(classes):
    cossim_voc_clip = cossim_mat(voc_clip[cl], voc_clip[cl])
    cossim_voc_dino = cossim_mat(voc_dino[cl], voc_dino[cl])
    cossim_mat_clip = cossim_mat(replay_clip[cl], voc_clip[cl])
    cossim_mat_dino = cossim_mat(replay_dino[cl], voc_dino[cl])
    worst_clip_inds = np.argsort(cossim_mat_clip.mean(axis=1))[:top_k]
    worst_dino_inds = np.argsort(cossim_mat_dino.mean(axis=1))[:top_k]
    worst_voc_clip = np.argsort(cossim_voc_clip.mean(axis=1))[0]
    worst_voc_dino = np.argsort(cossim_voc_dino.mean(axis=1))[0]
    voc_img_dino = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[worst_voc_dino]
    axes[i, 0].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_dino}"))
    axes[i, 0].axis("off")
    axes[i, 0].set_title(f"worst VOC {cl} (DINO)")
    for j, ind in enumerate(worst_dino_inds):
        axes[i, j+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+1].axis("off")
        axes[i, j+1].set_title(f"DINO {j+1}. worst")
    voc_img_clip = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[worst_voc_clip]
    axes[i, top_k+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_clip}"))
    axes[i, top_k+1].axis("off")
    axes[i, top_k+1].set_title(f"worst VOC {cl} (CLIP)")
    for j, ind in enumerate(worst_clip_inds):
        axes[i, j+2+top_k].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+2+top_k].axis("off")
        axes[i, j+2+top_k].set_title(f"CLIP {j+1}. worst")
plt.tight_layout()
plt.show()

In [None]:
replay_dir = "replay_data_baseline"
overlap = "-ov"
top_k = 2

replay_clip = load_embed_dict(replay_dir, "clip", overlap)
replay_dino = load_embed_dict(replay_dir, "dino", overlap)
fig, axes = plt.subplots(len(classes), 2 * top_k + 2, figsize=(30, 50))
for i, cl in enumerate(classes):
    cossim_voc_clip = cossim_mat(voc_clip[cl], voc_clip[cl])
    cossim_voc_dino = cossim_mat(voc_dino[cl], voc_dino[cl])
    cossim_mat_clip = cossim_mat(replay_clip[cl], voc_clip[cl])
    cossim_mat_dino = cossim_mat(replay_dino[cl], voc_dino[cl])
    worst_clip_inds = np.argsort(cossim_mat_clip.mean(axis=1))[:top_k]
    worst_dino_inds = np.argsort(cossim_mat_dino.mean(axis=1))[:top_k]
    worst_voc_clip = np.argsort(cossim_voc_clip.mean(axis=1))[0]
    worst_voc_dino = np.argsort(cossim_voc_dino.mean(axis=1))[0]
    voc_img_dino = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[worst_voc_dino]
    axes[i, 0].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_dino}"))
    axes[i, 0].axis("off")
    axes[i, 0].set_title(f"worst VOC {cl} (DINO)")
    for j, ind in enumerate(worst_dino_inds):
        axes[i, j+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+1].axis("off")
        axes[i, j+1].set_title(f"DINO {j+1}. worst")
    voc_img_clip = sorted(os.listdir(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images"))[worst_voc_clip]
    axes[i, top_k+1].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/replay_data_voc/10-10{overlap}/{cl}/images/{voc_img_clip}"))
    axes[i, top_k+1].axis("off")
    axes[i, top_k+1].set_title(f"worst VOC {cl} (CLIP)")
    for j, ind in enumerate(worst_clip_inds):
        axes[i, j+2+top_k].imshow(Image.open(f"/home/thesis/marx/wilson_gen/WILSON/{replay_dir}/10-10{overlap}/{cl}/images/{str(ind).zfill(5)}.jpg"))
        axes[i, j+2+top_k].axis("off")
        axes[i, j+2+top_k].set_title(f"CLIP {j+1}. worst")
plt.tight_layout()
plt.show()