In [2]:
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import torch
import clip
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from glob import glob

# Load CLIP model
# `device` is kept because other cells may read it; all inference in this
# notebook runs on the CPU (inputs are never moved to another device).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load straight onto the CPU. Without an explicit device, clip.load picks CUDA
# when it is available, and moving that model to the CPU afterwards keeps the
# numeric precision that was chosen for the GPU.
model_viT_32, preprocess_viT_32 = clip.load(
    "./openaiclipweights/clip/CLIP/models/ViT-B-32.pt", device="cpu"
)
model_viT_32.eval()

# Precompute image embeddings for every JPEG in the current working directory
files = glob('*.jpeg') + glob('*.jpg')
if not files:
    # np.vstack on an empty list fails with an unhelpful message; fail clearly instead.
    raise FileNotFoundError(
        "No .jpeg/.jpg images found in the current working directory; "
        "nothing to embed."
    )

image_embeddings = []

# Inference only: disable autograd once for the whole loop.
with torch.no_grad():
    for file in files:
        # The context manager guarantees the file is closed even if decoding fails.
        with Image.open(file) as raw_image:
            image = preprocess_viT_32(raw_image.convert("RGB")).unsqueeze(0)
        image_embeddings.append(model_viT_32.encode_image(image).cpu().numpy())

# One row per entry of `files`, in the same order.
image_embeddings = np.vstack(image_embeddings)

# Candidate text descriptions. When an image is supplied to process_input,
# each entry is encoded with the CLIP text encoder and the entry with the
# highest similarity to the image is reported as the best match.
QUERIES = [
    "A red toy car",
    "A blue toy car",
    "A pink toy car",
    "A black toy car",
    "A white toy car",
    "A silver toy car",
    "A toy car with racing stripes",
    "A toy car with big wheels",
    "A vintage toy car",
    "A toy car collection",
    "A toy car in a box",
    "A toy car with a driver inside",
    "A toy car with many passengers inside",
    "A toy car with an open roof",
    "A toy car and a toy truck",
    "A toy car on top of a toy truck",
    "A GTR toy car",
    "A Mazda toy car",
    "A Bugatti toy car",
    "Barbie toy car",
    "Two toy cars",
    "Four toy cars",
    "Eight toy cars",
    "A toy car in the UK",
    "A toy car moving in a city",
    "A police toy car"
]

# GUI functions
def _unit_rows(matrix):
    """Return a float32 copy of a 2-D array with each row scaled to unit L2 norm.

    Rows whose norm is zero are returned unchanged to avoid dividing by zero.
    """
    matrix = np.asarray(matrix, dtype=np.float32)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix / np.where(norms == 0.0, 1.0, norms)


# Cache of normalised text embeddings for QUERIES, keyed by the query tuple so
# that editing QUERIES in another cell invalidates the cached value.
_query_embedding_cache = {}


def _query_text_embeddings():
    """Return unit-norm CLIP text embeddings for QUERIES, computing them once."""
    key = tuple(QUERIES)
    if key not in _query_embedding_cache:
        with torch.no_grad():
            encoded = model_viT_32.encode_text(clip.tokenize(QUERIES)).cpu().numpy()
        _query_embedding_cache.clear()
        _query_embedding_cache[key] = _unit_rows(encoded)
    return _query_embedding_cache[key]


def process_input(user_input):
    """Match a text query against the indexed images, or an image against QUERIES.

    Args:
        user_input: Either a ``str`` query or a ``PIL.Image.Image``.
            * ``str``: the indexed image with the highest cosine similarity to
              the query is displayed via ``show_image`` and the query is echoed
              in ``result_text``.
            * ``Image.Image``: the entry of ``QUERIES`` with the highest cosine
              similarity to the image is written to ``result_text``.
            Any other type is ignored.

    Returns:
        None. Results are reported through the GUI widgets.
    """
    if isinstance(user_input, str):
        # Text input
        try:
            tokens = clip.tokenize(user_input)
        except RuntimeError as exc:
            # clip.tokenize raises RuntimeError when the text does not fit the
            # model's context length; tell the user instead of failing silently
            # inside the Tk callback.
            result_text.set(f"Query could not be processed: {exc}")
            return

        with torch.no_grad():
            text_embedding = model_viT_32.encode_text(tokens).cpu().numpy()

        # Cosine similarity: normalise both sides so that images with large
        # embedding norms are not favoured by the raw dot product.
        similarities = _unit_rows(image_embeddings) @ _unit_rows(text_embedding)[0]
        best_match_idx = int(np.argmax(similarities))

        # show_image copies the pixels it needs, so the file can be closed
        # as soon as it returns.
        with Image.open(files[best_match_idx]) as best_image:
            show_image(best_image)
        result_text.set(f"Query: {user_input}")

    elif isinstance(user_input, Image.Image):
        # Image input
        with torch.no_grad():
            image_embedding = model_viT_32.encode_image(
                preprocess_viT_32(user_input).unsqueeze(0)
            ).cpu().numpy()

        # The query embeddings are cached, so repeated uploads do not re-run
        # the text encoder over the whole QUERIES list.
        similarities = _query_text_embeddings() @ _unit_rows(image_embedding)[0]
        best_match_idx = int(np.argmax(similarities))
        result_text.set(f"Best Match: {QUERIES[best_match_idx]}")

def show_image(image):
    """Render ``image`` in the display panel, scaled to 250x250 pixels."""
    thumbnail = image.resize((250, 250))
    photo = ImageTk.PhotoImage(thumbnail)
    # Tk does not hold its own reference to the photo; store one on the
    # widget so the image is not garbage collected while displayed.
    panel.image = photo
    panel.configure(image=photo)

def browse_file():
    """Let the user pick a JPEG file, display it and find its best-matching query.

    Does nothing if the dialog is cancelled. If the chosen file cannot be
    opened or decoded, the error is shown in ``result_text`` instead of being
    raised inside the Tk callback.
    """
    filename = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.jpeg")])
    if not filename:
        return

    try:
        # The context manager closes the file even when decoding fails part-way.
        with Image.open(filename) as image_input:
            show_image(image_input)
            process_input(image_input)
    except OSError as exc:
        # Pillow reports unreadable or corrupt image files as OSError subclasses.
        result_text.set(f"Could not open image: {exc}")

def submit_text():
    """Send the current contents of the text entry to ``process_input``.

    An empty entry is ignored.
    """
    query = text_entry.get()
    if not query:
        return
    process_input(query)

# Build the Tkinter window
root = tk.Tk()
root.title("CLIP Model GUI")

# Shared padding options so the layout calls stay uniform
_outer_pad = {"padx": 10, "pady": 10}
_cell_pad = {"padx": 5, "pady": 5}

# Container holding the input controls
frame = tk.Frame(root)
frame.pack(**_outer_pad)

# Row 0: button that opens the image file picker
browse_button = tk.Button(frame, text="Upload Image", command=browse_file)
browse_button.grid(row=0, column=0, **_cell_pad)

# Row 1, left: free-text query field
text_entry = tk.Entry(frame, width=50)
text_entry.grid(row=1, column=0, **_cell_pad)

# Row 1, right: button that submits the typed query
submit_button = tk.Button(frame, text="Submit Text", command=submit_text)
submit_button.grid(row=1, column=1, **_cell_pad)

# Label used as the image display area
panel = tk.Label(root)
panel.pack(**_outer_pad)

# Label bound to a StringVar; shows the echoed query or the best match
result_text = tk.StringVar()
result_label = tk.Label(root, textvariable=result_text, font=("Helvetica", 16))
result_label.pack(**_outer_pad)

# Enter the Tk event loop; this call blocks until the window is closed
root.mainloop()


In [9]:
# Show a header, the embedding matrix, and its shape, one item per line.
print("Image Embeddings:", image_embeddings, image_embeddings.shape, sep="\n")

Image Embeddings:
[[ 0.10161975 -0.40000352 -0.24615626 ...  0.5217787  -0.26385424
  -0.15515895]
 [ 0.2412628   0.2906397   0.09077024 ... -0.15605108 -0.02410448
  -0.1958622 ]
 [ 0.01464751  0.5035768  -0.2592774  ... -0.5374011  -0.0487541
   0.3548571 ]
 ...
 [ 0.07610247  0.02436707 -0.30105963 ...  0.12192594 -0.5808713
   0.6463735 ]
 [-0.0241044   0.02134677 -0.21804848 ...  0.2786589  -0.14082618
   0.1319272 ]
 [-0.00218999 -0.5937021   0.3551427  ...  0.33694646  0.03228282
   0.4029902 ]]
(13, 512)


In [10]:
print("Query Embeddings:")
# Inference only: no_grad stops the text encoder from recording an autograd
# graph, so no detach is needed before converting to NumPy.
with torch.no_grad():
    query_embeddings = model_viT_32.encode_text(clip.tokenize(QUERIES)).cpu().numpy()
print(query_embeddings)
print(query_embeddings.shape)

Query Embeddings:
[[-0.11993366  0.0638367   0.04534612 ... -0.0104454  -0.46732277
  -0.02355841]
 [-0.03105263 -0.134765    0.03857469 ...  0.19692582 -0.62035906
   0.10483619]
 [-0.3232686  -0.17871684  0.17166413 ...  0.16348962 -0.1307304
   0.20045945]
 ...
 [-0.44422528  0.221954   -0.03552963 ... -0.2808385  -0.3952454
  -0.07301737]
 [-0.07197786 -0.00512628  0.03806114 ...  0.17857713 -0.2079634
   0.0567353 ]
 [-0.1759924   0.19732346 -0.02507975 ...  0.09942378 -0.5817207
   0.05117435]]
(26, 512)
