In [None]:
from dotenv import load_dotenv
load_dotenv()
import os
import sys
sys.path.append(os.getenv('PYTHONPATH')) 
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from itertools import islice
from pathlib import Path
import glob

In [None]:
def _load_stimulus_image(filename, stim_path):
    """Load the PIL image that represents stimulus ``filename``.

    ``.mp4`` videos are represented by their pre-extracted middle frame, looked up
    as ``<stim_path>/frames_middle/<stem>*.jpg``; any other file is read from
    ``<stim_path>/raw/<filename>``.

    Raises:
    FileNotFoundError: if the image file, or a video's middle frame, is missing.
    ValueError: if more than one middle frame matches a video.
    """
    if Path(filename).suffix == '.mp4':
        # glob.escape stops stems containing '[', '*' or '?' being read as wildcards.
        # NOTE(review): the trailing '*' also matches longer stems sharing this prefix
        # (e.g. 'clip1' vs 'clip10'); the uniqueness check below surfaces that case.
        pattern = os.path.join(stim_path, "frames_middle", f"{glob.escape(Path(filename).stem)}*.jpg")
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError(f"No middle frame for {filename!r} matching {pattern!r}")
        if len(matches) > 1:
            raise ValueError(f"Expected one middle frame for {filename!r}, found {len(matches)}: {sorted(matches)}")
        image_file = matches[0]
    else:
        image_file = os.path.join(stim_path, "raw", filename)
    # Copy inside the context manager so the pixels are loaded and the file handle is closed.
    with Image.open(image_file) as img:
        return img.copy()


def _seen_by(stiminfo, filename, repetition_columns):
    """Return the repetition columns (with '_reps' stripped) whose value is > 0 in
    the first ``stiminfo`` row for ``filename``; an empty list if there is no such row."""
    rows = stiminfo[stiminfo['filename'] == filename]
    if rows.empty:
        return []
    first_row = rows.iloc[0]
    return [col.replace('_reps', '') for col in repetition_columns if first_row[col] > 0]


def visualize_similar_images(image_data, stiminfo, stim_path='./', show_seen_by=False):
    """
    Visualize similar images in rows. Each row shows one test image followed by
    its similar images.

    Parameters:
    image_data (dict): maps each test-image filename to a sequence of similar-image
                       filenames. One row is drawn per key, in iteration order.
    stiminfo (pandas.DataFrame): stimulus table with a 'filename' column. Columns whose
                       name contains '_reps_' are read as repetition counts; the table
                       is only consulted when ``show_seen_by`` is True.
    stim_path (str): stimulus root containing 'raw/' (images) and 'frames_middle/'
                     (middle frames standing in for .mp4 videos).
    show_seen_by (bool): if True, append to each title the repetition columns (with
                         '_reps' removed) that have a positive count for that file.

    Raises:
    FileNotFoundError: if an image or a video's middle frame is missing.
    ValueError: if more than one middle frame matches a video.
    """
    if not image_data:
        # Nothing to draw; max() and plt.subplots(0, ...) would both fail on empty input.
        return

    fs = 10  # font size for subplot titles
    repetition_columns = (
        [col for col in stiminfo.columns if '_reps_' in col] if show_seen_by else []
    )

    num_rows = len(image_data)
    # One column for the test image plus one per similar image in the longest row.
    num_cols = 1 + max(len(images) for images in image_data.values())
    # squeeze=False always returns a 2-D array of Axes, so axes[row][col] also works
    # for a single row or for rows without any similar images.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3 * num_rows), squeeze=False)

    for row_axes, (test_image, similar_images) in zip(axes, image_data.items()):
        panels = [("Test Image", test_image)]
        panels += [("Similar Train", sim_image) for sim_image in similar_images]
        for ax, (label, filename) in zip(row_axes, panels):
            title = f"{label}\n{filename}"
            if show_seen_by:
                title += f"\n{_seen_by(stiminfo, filename, repetition_columns)}"
            ax.imshow(_load_stimulus_image(filename, stim_path))
            ax.set_title(title, fontsize=fs)
        # Hide axes on every panel, including unused trailing panels in short rows.
        for ax in row_axes:
            ax.axis("off")

    plt.tight_layout()
    plt.show()

def retrieve_image(filename, stim_path='./'):
    """Display one stimulus in its own figure, titled with its filename.

    Parameters:
    filename (str): stimulus filename. ``.mp4`` videos are shown via their middle
                    frame ``<stim_path>/frames_middle/<stem>*.jpg``; any other file
                    is read from ``<stim_path>/raw/<filename>``.
    stim_path (str): stimulus root containing 'raw/' and 'frames_middle/'.

    Raises:
    FileNotFoundError: if the image file, or a video's middle frame, is missing.
    ValueError: if more than one middle frame matches a video.
    """
    if Path(filename).suffix == '.mp4':
        # glob.escape stops stems containing '[', '*' or '?' being read as wildcards.
        # NOTE(review): the trailing '*' also matches longer stems sharing this prefix
        # (e.g. 'clip1' vs 'clip10'); the uniqueness check below surfaces that case.
        pattern = os.path.join(stim_path, "frames_middle", f"{glob.escape(Path(filename).stem)}*.jpg")
        matches = glob.glob(pattern)
        if not matches:
            raise FileNotFoundError(f"No middle frame for {filename!r} matching {pattern!r}")
        if len(matches) > 1:
            raise ValueError(f"Expected one middle frame for {filename!r}, found {len(matches)}: {sorted(matches)}")
        image_file = matches[0]
    else:
        image_file = os.path.join(stim_path, "raw", filename)

    # Copy inside the context manager so the pixels are loaded and the file handle is closed.
    with Image.open(image_file) as img:
        stimulus = img.copy()

    _, ax = plt.subplots()
    ax.imshow(stimulus)
    ax.set_title(filename)  # identifies each figure when called in a loop
    ax.axis("off")
    plt.show()

In [None]:
# Root of all datasets; DATASETS_ROOT falls back to a placeholder path when unset.
dataset_root = os.getenv("DATASETS_ROOT", "/default/path/to/datasets")
save_root = os.path.join(dataset_root, "MOSAIC")

# Test stimulus -> stimuli excluded for being perceptually similar to it.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
exclusions_path = os.path.join(save_root, 'perceptually_similar_exclusions.pkl')
with open(exclusions_path, 'rb') as exclusions_file:
    excluded_stim = pickle.load(exclusions_file)

stiminfo_path = os.path.join(save_root, "stimuli", "datasets_stiminfo", "compiled_dataset_stiminfo.tsv")
compiled_stiminfo = pd.read_table(stiminfo_path, low_memory=False)

datasets = ['NSD', 'BMD', 'BOLD5000', 'THINGS', 'GOD', 'deeprecon', 'HAD', 'NOD']

In [None]:
# Each value of excluded_stim lists the stimuli flagged as too similar to that key's
# test stimulus. If a stimulus is flagged by more than one test stimulus, summing the
# list lengths counts it more than once, so also report the number of unique stimuli.
count = sum(len(similar) for similar in excluded_stim.values())
removed_stimuli = set().union(*excluded_stim.values())
print(f"{len(removed_stimuli)} unique stimuli were removed because of high similarity to one of "
      f"{len(excluded_stim)} test stimuli ({count} similar pairs in total).")


In [None]:
# Preview only the first N_PREVIEW_TESTS test stimuli, keeping at most
# MAX_SIMILAR_PER_TEST similar stimuli each so the figure grid stays readable.
N_PREVIEW_TESTS = 50
MAX_SIMILAR_PER_TEST = 15
sliced_data = dict(islice(excluded_stim.items(), N_PREVIEW_TESTS))
data = {test: similar[:MAX_SIMILAR_PER_TEST] for test, similar in sliced_data.items()}
# One line per test stimulus (full, untruncated list); replaces a duplicate dump of the whole dict.
for test, similar in sliced_data.items():
    print(test, similar)

In [None]:
# Given a test image (a key of `data`), display it followed by each of its similar images.
# Note: `data` holds only the previewed test stimuli, each truncated; use
# excluded_stim[test_img] to see every similar image.
stimuli_dir = os.path.join(save_root, 'stimuli')
test_img = '000000112734.jpg'
retrieve_image(test_img, stim_path=stimuli_dir)
# Look the neighbours up via `test_img` so the two filenames can never drift apart.
for img in data[test_img]:
    retrieve_image(img, stim_path=stimuli_dir)

In [None]:
# Show one stimulus on its own, looked up by filename.
retrieve_image(filename='000000021718.jpg', stim_path=os.path.join(save_root, 'stimuli'))

In [None]:
# Grid view: one row per test stimulus in `data`, its similar stimuli to the right.
visualize_similar_images(
    data,
    compiled_stiminfo,
    stim_path=os.path.join(save_root, 'stimuli'),
)