In [1]:
from dotenv import load_dotenv
load_dotenv()
from PIL import Image
import os
from tqdm import tqdm
from pathlib import Path
import pandas as pd
import glob
import numpy as np
import torch
import seaborn as sns
from dreamsim import dreamsim
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


In [None]:
# Root folder that holds every dataset. The DATASETS_ROOT environment variable
# takes precedence; otherwise the placeholder path below is used.
default_root = "/default/path/to/datasets"
dataset_root = os.getenv("DATASETS_ROOT", default_root)
print(f"dataset_root: {dataset_root}")

# Everything this notebook reads lives under the MOSAIC subfolder.
save_root = os.path.join(dataset_root, "MOSAIC")

In [3]:
# Read the BMD stimulus table (tab-separated) and keep its 'filename' column,
# which drives the per-video loop further down.
stiminfo_path = os.path.join(
    save_root, "stimuli", "datasets_stiminfo", "bmd_stiminfo.tsv"
)
bmd_stiminfo = pd.read_table(stiminfo_path)
filenames = bmd_stiminfo['filename']

In [None]:
model_name = 'dreamsim'
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model weights are cached under $CACHE/.cache. Fall back to the user's home
# directory when CACHE is unset, because os.path.join(None, ...) raises TypeError.
cache_root = os.getenv("CACHE") or os.path.expanduser("~")

# Pass the detected device explicitly: dreamsim() defaults to device="cuda",
# which ignores the choice made above and fails on CPU-only machines.
model, preprocess = dreamsim(
    pretrained=True,
    device=device,
    cache_dir=os.path.join(cache_root, ".cache"),
)

In [None]:
"""
We use the second frame and the second-to-last frame because, in some videos, the first or last frame is a black screen left over from the frame-conversion process.
"""
# For every video, compare the DreamSim embeddings of an early frame and a late
# frame and record their cosine similarity.
frames_root = os.path.join(save_root, "stimuli", "frames")
similarity = []

# Inference only: no_grad avoids building an autograd graph for each embedding.
with torch.no_grad():
    for filename in tqdm(filenames):
        stem = Path(filename).stem
        frame_dir = os.path.join(frames_root, stem)

        # Locate the second frame. The last '_'-separated token of its file stem
        # is parsed as the video's total frame count.
        first_frame = glob.glob(os.path.join(frame_dir, f"{stem}_frame-0002_*.jpg"))
        assert len(first_frame) == 1, (
            f"expected exactly one second frame for '{stem}', found {len(first_frame)}"
        )
        total_frames = int(Path(first_frame[0]).stem.split('_')[-1])

        # Locate the second-to-last frame using that frame count.
        last_frame = glob.glob(
            os.path.join(frame_dir, f"{stem}_frame-{total_frames-1:04}_{total_frames:04}.jpg")
        )
        assert len(last_frame) == 1, (
            f"expected exactly one second-to-last frame for '{stem}', found {len(last_frame)}"
        )

        # Open the images in context managers so the file handles are released
        # as soon as preprocessing has produced the tensors.
        with Image.open(first_frame[0]) as first_image:
            imgA = preprocess(first_image).to(device)
        with Image.open(last_frame[0]) as last_image:
            imgB = preprocess(last_image).to(device)

        # Embed both frames and store the scalar cosine similarity.
        imgA_embedding = model.embed(imgA).detach().cpu().numpy()
        imgB_embedding = model.embed(imgB).detach().cpu().numpy()
        similarity.append(cosine_similarity(imgA_embedding, imgB_embedding)[0][0])

In [None]:
# Average early-vs-late frame similarity across all processed videos.
mean_similarity = np.mean(similarity)
print(mean_similarity)

In [None]:
# Report the mean, then show the full distribution of per-video similarities.
print(np.mean(similarity))

# Use an explicit figure/axes and pass the values as `y=` so the violin is
# vertical regardless of how the installed seaborn treats a positional argument.
fig, ax = plt.subplots()
sns.violinplot(y=np.array(similarity), ax=ax)
ax.set(
    title=f"{model_name}: second vs. second-to-last frame similarity",
    ylabel="Cosine similarity",
)
plt.show()