In [49]:
import zipfile
from pathlib import Path

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Load a pre-trained sentence-transformer model once; it is the default encoder
# for every similarity computation in this notebook.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

def get_average_cosine_similarity(prompts, captions, embedding_model=model):
    """Return the mean cosine similarity between aligned pairs of sentences.

    ``prompts[i]`` is compared only with ``captions[i]`` (pairwise, not
    all-vs-all), and the per-pair similarities are averaged.

    Parameters
    ----------
    prompts, captions : sequence of str
        Sentences to compare. Both must have the same, non-zero length, and
        index ``i`` of each must refer to the same item.
    embedding_model : object with an ``encode`` method, optional
        Encoder used to embed the sentences. Defaults to ``model`` as it was
        when this cell ran (Python binds defaults at definition time, so
        reassigning ``model`` afterwards does not change the default).

    Returns
    -------
    float
        Mean pairwise cosine similarity, rounded to 3 decimal places.

    Raises
    ------
    ValueError
        If the inputs have different lengths or are empty.
    """
    n_prompts, n_captions = len(prompts), len(captions)
    # Misaligned inputs would otherwise raise an opaque IndexError (too few
    # captions) or silently ignore the surplus captions, skewing the mean.
    if n_prompts != n_captions:
        raise ValueError(
            f"prompts and captions must have the same length, got {n_prompts} and {n_captions}"
        )
    if n_prompts == 0:
        raise ValueError("prompts and captions must not be empty")

    prompt_embeddings = embedding_model.encode(prompts)
    caption_embeddings = embedding_model.encode(captions)

    # cosine_similarity expects 2-D inputs, so each embedding is wrapped in a
    # one-row list; [0, 0] pulls the scalar out of the resulting 1x1 matrix.
    similarities = [
        float(cosine_similarity([prompt_vec], [caption_vec])[0, 0])
        for prompt_vec, caption_vec in tqdm(
            zip(prompt_embeddings, caption_embeddings), total=n_prompts
        )
    ]

    return round(sum(similarities) / len(similarities), 3)

In [50]:
# Example usage: two paraphrase pairs, so the mean similarity should come out high.
get_average_cosine_similarity(
    prompts=["The quick brown fox jumps over the lazy dog.", "I love coding in Python"],
    captions=["A lazy dog is jumped over by a quick brown fox.", "Python is my favorite programming language."],
)

100%|██████████| 2/2 [00:00<00:00, 1673.70it/s]


0.872

# Compute mean cosine similarity between prompts and captions - CARS

In [51]:
# Paths are built with pathlib rather than backslash-separated string literals:
# those contain (invalid) escape sequences -- a SyntaxWarning on Python 3.12+,
# and '\b' would even become a backspace -- and only resolve on Windows.
car_prompts_archive = Path('datasets_generated_ready_models', 'SD_XL', 'cars_zipped.zip')
car_prompts_name = 'cars_prompts.txt'
car_captions_sdxl_path = Path('datasets_generated_ready_models', 'SD_XL', 'image_captions_cars_zipped.txt')
car_captions_dalle3_path = Path('datasets_generated_ready_models', 'DALLE3', 'image_captions_dalle_cars_zipped.txt')

# The prompts are stored as a member of a zip archive; read its bytes and decode.
with zipfile.ZipFile(car_prompts_archive) as z:
    car_prompts = z.read(car_prompts_name).decode('utf-8').splitlines()

car_captions_sdxl = car_captions_sdxl_path.read_text(encoding='utf-8').splitlines()
car_captions_dalle3 = car_captions_dalle3_path.read_text(encoding='utf-8').splitlines()

# Similarities are computed line by line, so line i of every file must refer to the same item.
assert len(car_prompts) == len(car_captions_sdxl) == len(car_captions_dalle3), (
    f"misaligned CARS data: {len(car_prompts)} prompts, "
    f"{len(car_captions_sdxl)} SD_XL captions, {len(car_captions_dalle3)} DALL-E 3 captions"
)

In [56]:
# CARS: score each generator's captions against the prompts, then the two
# generators' captions against each other.
mean_similarity_sdxl = get_average_cosine_similarity(car_prompts, car_captions_sdxl)
mean_similarity_dalle3 = get_average_cosine_similarity(car_prompts, car_captions_dalle3)
mean_similarity_sdxl_dalle3 = get_average_cosine_similarity(car_captions_sdxl, car_captions_dalle3)

print(
    f"Mean similarity between prompts and SD_XL captions: {mean_similarity_sdxl}",
    f"Mean similarity between prompts and DALL-E 3 captions: {mean_similarity_dalle3}",
    f"Mean similarity between SD_XL and DALL-E 3 captions: {mean_similarity_sdxl_dalle3}",
    sep="\n",
)

100%|██████████| 50/50 [00:00<00:00, 2275.38it/s]
100%|██████████| 50/50 [00:00<00:00, 1692.89it/s]
100%|██████████| 50/50 [00:00<00:00, 2287.72it/s]

Mean similarity between prompts and SD_XL captions: 0.278
Mean similarity between prompts and DALL-E 3 captions: 0.26
Mean similarity between SD_XL and DALL-E 3 captions: 0.569





# Compute mean cosine similarity between prompts and captions - WILDLIFE

In [57]:
# Paths are built with pathlib rather than backslash-separated string literals:
# those contain (invalid) escape sequences -- a SyntaxWarning on Python 3.12+,
# and '\b' would even become a backspace -- and only resolve on Windows.
wildlife_prompts_archive = Path('datasets_generated_ready_models', 'SD_XL', 'wildlife_zipped.zip')
wildlife_prompts_name = 'wildlife_prompts.txt'
wildlife_captions_sdxl_path = Path('datasets_generated_ready_models', 'SD_XL', 'image_captions_wildlife_zipped.txt')
wildlife_captions_dalle3_path = Path('datasets_generated_ready_models', 'DALLE3', 'image_captions_dalle_wildlife_zipped.txt')

# The prompts are stored as a member of a zip archive; read its bytes and decode.
with zipfile.ZipFile(wildlife_prompts_archive) as z:
    wildlife_prompts = z.read(wildlife_prompts_name).decode('utf-8').splitlines()

wildlife_captions_sdxl = wildlife_captions_sdxl_path.read_text(encoding='utf-8').splitlines()
wildlife_captions_dalle3 = wildlife_captions_dalle3_path.read_text(encoding='utf-8').splitlines()

# Similarities are computed line by line, so line i of every file must refer to the same item.
assert len(wildlife_prompts) == len(wildlife_captions_sdxl) == len(wildlife_captions_dalle3), (
    f"misaligned WILDLIFE data: {len(wildlife_prompts)} prompts, "
    f"{len(wildlife_captions_sdxl)} SD_XL captions, {len(wildlife_captions_dalle3)} DALL-E 3 captions"
)

In [58]:
# WILDLIFE: score each generator's captions against the prompts, then the two
# generators' captions against each other.
mean_similarity_sdxl = get_average_cosine_similarity(wildlife_prompts, wildlife_captions_sdxl)
mean_similarity_dalle3 = get_average_cosine_similarity(wildlife_prompts, wildlife_captions_dalle3)
mean_similarity_sdxl_dalle3 = get_average_cosine_similarity(wildlife_captions_sdxl, wildlife_captions_dalle3)

print(
    f"Mean similarity between prompts and SD_XL captions: {mean_similarity_sdxl}",
    f"Mean similarity between prompts and DALL-E 3 captions: {mean_similarity_dalle3}",
    f"Mean similarity between SD_XL and DALL-E 3 captions: {mean_similarity_sdxl_dalle3}",
    sep="\n",
)

100%|██████████| 50/50 [00:00<00:00, 2683.63it/s]
100%|██████████| 50/50 [00:00<00:00, 2246.38it/s]
100%|██████████| 50/50 [00:00<00:00, 2225.09it/s]

Mean similarity between prompts and SD_XL captions: 0.259
Mean similarity between prompts and DALL-E 3 captions: 0.248
Mean similarity between SD_XL and DALL-E 3 captions: 0.751





# Compute mean cosine similarity between prompts and captions - BRECAHAD

In [61]:
# Paths are built with pathlib rather than backslash-separated string literals:
# those contain escape sequences ('\b' is a backspace, so 'SD_XL\brecahad...'
# needed manual doubling) and only resolve on Windows.
brecahad_prompts_archive = Path('datasets_generated_ready_models', 'SD_XL', 'brecahad_zipped.zip')
brecahad_prompts_name = 'brecahad_prompts.txt'
brecahad_captions_sdxl_path = Path('datasets_generated_ready_models', 'SD_XL', 'image_captions_brecahad_zipped.txt')
brecahad_captions_dalle3_path = Path('datasets_generated_ready_models', 'DALLE3', 'image_captions_dalle_brecahad_zipped.txt')

# 0-based index of line 4, whose prompt has no corresponding DALL-E 3 image.
BRECAHAD_MISSING_DALLE3_INDEX = 3

# The prompts are stored as a member of a zip archive; read its bytes and decode.
with zipfile.ZipFile(brecahad_prompts_archive) as z:
    brecahad_prompts = z.read(brecahad_prompts_name).decode('utf-8').splitlines()

brecahad_captions_sdxl = brecahad_captions_sdxl_path.read_text(encoding='utf-8').splitlines()
brecahad_captions_dalle3 = brecahad_captions_dalle3_path.read_text(encoding='utf-8').splitlines()

# Drop the entry with no DALL-E 3 image from the prompts and SD_XL captions so
# all three lists stay aligned. The lists are freshly read above, so re-running
# this cell is safe.
del brecahad_prompts[BRECAHAD_MISSING_DALLE3_INDEX]
del brecahad_captions_sdxl[BRECAHAD_MISSING_DALLE3_INDEX]

# Similarities are computed line by line, so line i of every list must refer to the same item.
assert len(brecahad_prompts) == len(brecahad_captions_sdxl) == len(brecahad_captions_dalle3), (
    f"misaligned BRECAHAD data: {len(brecahad_prompts)} prompts, "
    f"{len(brecahad_captions_sdxl)} SD_XL captions, {len(brecahad_captions_dalle3)} DALL-E 3 captions"
)

In [62]:
# BRECAHAD: score each generator's captions against the prompts, then the two
# generators' captions against each other.
mean_similarity_sdxl = get_average_cosine_similarity(brecahad_prompts, brecahad_captions_sdxl)
mean_similarity_dalle3 = get_average_cosine_similarity(brecahad_prompts, brecahad_captions_dalle3)
mean_similarity_sdxl_dalle3 = get_average_cosine_similarity(brecahad_captions_sdxl, brecahad_captions_dalle3)

print(
    f"Mean similarity between prompts and SD_XL captions: {mean_similarity_sdxl}",
    f"Mean similarity between prompts and DALL-E 3 captions: {mean_similarity_dalle3}",
    f"Mean similarity between SD_XL and DALL-E 3 captions: {mean_similarity_sdxl_dalle3}",
    sep="\n",
)

100%|██████████| 9/9 [00:00<00:00, 2249.89it/s]
100%|██████████| 9/9 [00:00<00:00, 563.62it/s]
100%|██████████| 9/9 [00:00<00:00, 984.32it/s]

Mean similarity between prompts and SD_XL captions: 0.214
Mean similarity between prompts and DALL-E 3 captions: 0.167
Mean similarity between SD_XL and DALL-E 3 captions: 0.416



