In [1]:
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, util

In [2]:
# Load the table stored in blip3.csv (path is relative to the notebook's working directory).
data = pd.read_csv(filepath_or_buffer='blip3.csv')

In [3]:
# Load the GANFD image lookup table; the merge cell reads its `full_ID` and `condition` columns.
ganfd = pd.read_csv(filepath_or_buffer='../GANFD/image_lookup.csv')

In [4]:
# Attach each row's `condition` from the GANFD lookup (image_id matched against
# full_ID). The left join keeps every row of `data`; the duplicated join key
# `full_ID` is discarded once the merge is done.
merged_data = (
    data
    .merge(
        ganfd[['full_ID', 'condition']],
        left_on='image_id',
        right_on='full_ID',
        how='left',
    )
    .drop(columns=['full_ID'])
)

In [5]:
# Column-oriented dict: cosines[column][row_label] -> value.
# This is the layout `cosines_by_image` indexes into.
cosines = merged_data.to_dict(orient='dict')

## Define Functions

In [6]:
def cosines_by_image(data_dict, model_name):
    """Compute within-condition pairwise cosine similarities of response embeddings.

    Every response is embedded with the requested SentenceTransformer model.
    Rows are grouped by their condition, and each unordered pair of rows that
    share a condition contributes one row to the result.

    Parameters
    ----------
    data_dict : dict
        Column-oriented mapping in the layout of ``DataFrame.to_dict()``, i.e.
        ``data_dict[column][row_key] -> value``. The columns ``'response'``,
        ``'condition'`` and ``'image_id'`` are read and must share the same
        row keys. Side effect: the array returned by ``model.encode`` is
        stored under ``data_dict['embedding']``.
    model_name : str
        Model identifier handed to ``SentenceTransformer``.

    Returns
    -------
    pandas.DataFrame
        Columns ``condition``, ``image_1``, ``image_2`` and ``cosine``.
        Conditions appear in order of first appearance in the input, and
        pairs within a condition follow input order, so the output is
        reproducible across runs. The columns are present even when there
        are no pairs.

    Notes
    -----
    Rows whose condition is missing (NaN/None) are never paired.
    """
    model = SentenceTransformer(model_name)

    # Freeze the row keys once so that position p in `embeddings` always
    # refers to row_keys[p], whatever the keys actually are.
    row_keys = list(data_dict['response'])
    embeddings = model.encode([data_dict['response'][key] for key in row_keys])
    data_dict['embedding'] = embeddings

    # Group row positions by condition in a single pass. A dict keeps
    # first-appearance order, unlike a set of strings whose iteration order
    # varies between interpreter runs.
    positions_by_condition = {}
    for position, key in enumerate(row_keys):
        condition = data_dict['condition'][key]
        if pd.isna(condition):
            # A missing condition cannot be matched against anything.
            continue
        positions_by_condition.setdefault(condition, []).append(position)

    cosine_sim_results = []
    for condition, positions in positions_by_condition.items():
        if len(positions) < 2:
            # A single row has nothing to be compared with.
            continue

        condition_images = [data_dict['image_id'][row_keys[p]] for p in positions]
        condition_embeddings = np.asarray([embeddings[p] for p in positions])

        # One call gives the full similarity matrix for this condition;
        # only the strict upper triangle is reported (each pair once).
        sim_matrix = cosine_similarity(condition_embeddings)

        for i in range(len(positions)):
            for j in range(i + 1, len(positions)):
                cosine_sim_results.append({
                    'condition': condition,
                    'image_1': condition_images[i],
                    'image_2': condition_images[j],
                    'cosine': sim_matrix[i][j]
                })

    results_df = pd.DataFrame(
        cosine_sim_results,
        columns=['condition', 'image_1', 'image_2', 'cosine'],
    )
    return results_df

In [7]:
# Score every within-condition pair using the all-mpnet-base-v2 model, then
# write the pair table to disk without the DataFrame index column.
mpnetbase_df = cosines_by_image(
    data_dict=cosines,
    model_name='sentence-transformers/all-mpnet-base-v2',
)
mpnetbase_df.to_csv('blip3_cosines.csv', index=False)