In [16]:
import pickle
import clip
import csv
import torch
from PIL import Image
import os
import pandas as pd

In [2]:
# Load the pickled descriptions: a dict keyed by Pokémon type whose values are
# lists of description strings (see the displayed output below).
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
file_path = 'pokemon_descriptions.pkl'
with open(file_path, mode='rb') as pkl_handle:
    description_data = pickle.load(pkl_handle)

description_data

{'electric': ['Electric type, small, playful, squirrel-like, adorned in bright blue fur, long tufted ears that stand upright, white zigzag patterns running along its sides, bushy tail resembling a lightning bolt, nimble arms with four tiny fingers on fore',
  'Entry: Electric type, medium-sized, fluffy, bird-like, adorned with bright blue feathers, long crest of feathers on its head, vibrant yellow beak, and sharp orange talons. It has a pair of electric-blue wings that shimmer with energy',
  'Electric-type, small and agile, resembling a vibrant yellow lizard. Its smooth, bright yellow skin is adorned with dark blue spots that shimmer when it moves. It has large, expressive eyes that sparkle with mischief and a wide, toothy grin',
  '**Description:** Electric type, small and nimble, resembling a playful squirrel. Its vibrant blue fur shimmers with a sparkly sheen, accentuating its lively personality. Large, tufted ears stand upright, tipped with bright yellow that catches the light',


In [3]:
# Tabular view of the descriptions: each dict key becomes a column and each
# list position becomes a row, so .shape gives (descriptions per type, types).
description_df = pd.DataFrame.from_dict(description_data)
description_df.shape

(30, 18)

In [4]:
# Use the GPU when PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the model together with its matching image preprocessing
# callable; both are reused by the scoring loop further down.
model, preprocess = clip.load("ViT-B/32", device=device)

100%|███████████████████████████████████████| 354M/354M [02:55<00:00, 2.01MiB/s]


In [7]:
# For every description, score its candidate images with CLIP and keep the
# path of the best-matching one.
# Candidates are read from images/<type>/<i>_<j>.png, where i is the 1-based
# description index and j runs from 1 to IMAGES_PER_DESCRIPTION.
IMAGES_PER_DESCRIPTION = 5

# Maps Pokémon type -> list of best image paths, in description order.
best_images = {}

for pokemon_type, descriptions in description_data.items():
    type_folder = os.path.join("images", pokemon_type)

    for i, description in enumerate(descriptions, start=1):
        # The text embedding does not depend on the candidate image, so encode
        # and normalise it once per description rather than once per image.
        # NOTE(review): clip.tokenize raises for over-length text unless
        # truncate=True is passed - confirm all descriptions fit.
        text_tokens = clip.tokenize([description]).to(device)
        with torch.no_grad():
            text_embedding = model.encode_text(text_tokens)
            text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)

        similarities = []
        for j in range(1, IMAGES_PER_DESCRIPTION + 1):
            image_path = os.path.join(type_folder, f"{i}_{j}.png")
            # Close the image file as soon as its pixels are in a tensor.
            with Image.open(image_path) as pil_image:
                image = preprocess(pil_image).unsqueeze(0).to(device)

            with torch.no_grad():
                image_embedding = model.encode_image(image)
                image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
                # Both embeddings are unit-length, so the dot product is the
                # cosine similarity.
                cosine_similarity = (text_embedding @ image_embedding.T).cpu().item()
            similarities.append((cosine_similarity, image_path))

        # max() returns the first maximum, so ties resolve to the lowest j.
        best_image = max(similarities, key=lambda x: x[0])[1]
        best_images.setdefault(pokemon_type, []).append(best_image)

print(best_images)

{'electric': ['images/electric/1_3.png', 'images/electric/2_5.png', 'images/electric/3_1.png', 'images/electric/4_2.png', 'images/electric/5_3.png', 'images/electric/6_4.png', 'images/electric/7_4.png', 'images/electric/8_5.png', 'images/electric/9_1.png', 'images/electric/10_3.png', 'images/electric/11_2.png', 'images/electric/12_3.png', 'images/electric/13_5.png', 'images/electric/14_4.png', 'images/electric/15_5.png', 'images/electric/16_2.png', 'images/electric/17_3.png', 'images/electric/18_5.png', 'images/electric/19_4.png', 'images/electric/20_3.png', 'images/electric/21_1.png', 'images/electric/22_2.png', 'images/electric/23_1.png', 'images/electric/24_1.png', 'images/electric/25_5.png', 'images/electric/26_5.png', 'images/electric/27_3.png', 'images/electric/28_1.png', 'images/electric/29_2.png', 'images/electric/30_5.png'], 'fire': ['images/fire/1_3.png', 'images/fire/2_2.png', 'images/fire/3_4.png', 'images/fire/4_4.png', 'images/fire/5_3.png', 'images/fire/6_5.png', 'images

In [12]:
# One record per description, paired by position with the image path chosen
# for it in best_images.
rows = [
    {
        "Type": type_name,
        "Description": text,
        "Image Path": best_images[type_name][position],
    }
    for type_name, type_descriptions in description_data.items()
    for position, text in enumerate(type_descriptions)
]

best_pair = pd.DataFrame(rows)
best_pair

Unnamed: 0,Type,Description,Image Path
0,electric,"Electric type, small, playful, squirrel-like, ...",images/electric/1_3.png
1,electric,"Entry: Electric type, medium-sized, fluffy, bi...",images/electric/2_5.png
2,electric,"Electric-type, small and agile, resembling a v...",images/electric/3_1.png
3,electric,"**Description:** Electric type, small and nimb...",images/electric/4_2.png
4,electric,"Entry: Electric type, small and agile, resembl...",images/electric/5_3.png
...,...,...,...
535,normal,"Normal type, A small, fluffy creature resembli...",images/normal/26_3.png
536,normal,"Normal-type, This Pokémon resembles a small, r...",images/normal/27_5.png
537,normal,"Normal type, This Pokémon resembles a small, f...",images/normal/28_4.png
538,normal,"Normal Type: This Pokémon resembles a small, f...",images/normal/29_3.png


In [15]:
# Persist the description/image pairs; index=False leaves the row index out of the file.
best_pair.to_csv(path_or_buf='best_pair.csv', index=False)