# Winoground CLIP Retrieval

Retrieve images from the LAION-5B index for Winoground captions, images, and caption+image combinations, using the `clip-retrieval` client.

## Install

In [None]:
# Install the clip-retrieval package into the kernel's environment.
# NOTE(review): the version is unpinned; consider pinning it so the notebook is reproducible.
%pip install clip-retrieval

## Setup CLIP Client

In [1]:
from IPython.display import Image, display
from clip_retrieval.clip_client import ClipClient, Modality

def log_result(result):
    """Print one retrieval hit's metadata and render its image inline.

    Args:
        result: Mapping that provides the keys "id", "caption", "url" and
            "similarity".
    """
    # Read every field up front so a missing key fails before anything is printed.
    fields = {key: result[key] for key in ("id", "caption", "url", "similarity")}
    for label, value in fields.items():
        print(f"{label}: {value}")
    display(Image(url=fields["url"], unconfined=True))

# Shared client used by every query in this notebook. It targets the kNN service
# at knn5.laion.ai and the "laion5B" index, with Modality.IMAGE and num_images=40.
# NOTE(review): aesthetic_score / aesthetic_weight presumably bias results toward
# images with a higher aesthetic rating — confirm against the clip-retrieval docs.
client = ClipClient(
    url="https://knn5.laion.ai/knn-service",
    indice_name="laion5B",
    aesthetic_score=9,
    aesthetic_weight=0.5,
    modality=Modality.IMAGE,
    num_images=40,
)

## Query by text

In [2]:
# Query the service with a raw text prompt and show only the first returned result.
cat_results = client.query(text="an image of a cat")
log_result(cat_results[0])

id: 518836491
caption: orange cat with supicious look stock photo
url: https://media.istockphoto.com/photos/orange-cat-with-supicious-look-picture-id907595140?k=6&amp;m=907595140&amp;s=612x612&amp;w=0&amp;h=4CTvSxNvv4sxSCPxViryha4kAjuxDbrXM5vy4VPOuzk=
similarity: 0.5591725707054138


## Query by image

In [3]:
# Query the service with an image given by URL and show only the first returned result.
beach_results = client.query(image="https://github.com/rom1504/clip-retrieval/raw/main/tests/test_clip_inference/test_images/321_421.jpg")
log_result(beach_results[0])

id: 574870177
caption: Palm trees in Orlando, Florida
url: https://www.polefitfreedom.com/wp-content/uploads/2018/03/Orlando.jpg
similarity: 0.9619366526603699


## Query by embedding

In [4]:
import clip  # pylint: disable=import-outside-toplevel
import torch

# Load the ViT-L/14 CLIP model on CPU. `model` produces the text/image embeddings
# and `preprocess` is the image transform applied before `model.encode_image`.
# NOTE(review): embeddings sent to the service presumably must come from the same
# CLIP variant the index was built with — confirm.
model, preprocess = clip.load("ViT-L/14", device="cpu", jit=True)

import urllib
import io
import numpy as np

def download_image(url):
    """Fetch the resource at ``url`` and return its bytes as an in-memory stream.

    The request is sent with a browser-like User-Agent header and a 10 second
    timeout.

    Args:
        url: Address of the resource to download.

    Returns:
        io.BytesIO: Buffer holding the complete response body.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
    """
    # A bare ``import urllib`` does not load the ``request`` submodule; import it
    # explicitly so this function does not depend on some other package having
    # imported it first.
    import urllib.request

    request = urllib.request.Request(
        url,
        data=None,
        headers={"User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"},
    )
    # The context manager closes the connection once the body has been read.
    with urllib.request.urlopen(request, timeout=10) as response:
        return io.BytesIO(response.read())

def normalized(a, axis=-1, order=2):
    """Scale ``a`` to unit norm along ``axis``, leaving zero-norm slices unchanged.

    Args:
        a: Array to normalize.
        axis: Axis along which the norm is taken.
        order: Norm order forwarded to ``np.linalg.norm``.

    Returns:
        ``a`` divided by its norm along ``axis``. Because the norms are promoted
        to at least 1-D and then expanded along ``axis``, a 1-D input of length
        ``d`` yields a result of shape ``(1, d)``.
    """
    norms = np.atleast_1d(np.linalg.norm(a, ord=order, axis=axis))
    # Substitute 1 for zero norms so those slices are divided by one, not zero.
    zero_mask = norms == 0
    norms[zero_mask] = 1
    divisor = np.expand_dims(norms, axis)
    return a / divisor

def get_text_emb(text):
    """Embed ``text`` with the loaded CLIP model and L2-normalize the result.

    Args:
        text: String to embed; it is tokenized with ``truncate=True``.

    Returns:
        float32 NumPy array: the first (and only) row of the normalized text
        features.
    """
    tokens = clip.tokenize([text], truncate=True).to("cpu")
    with torch.no_grad():
        features = model.encode_text(tokens)
        # Divide by the per-row L2 norm so the embedding has unit length.
        features /= features.norm(dim=-1, keepdim=True)
    return features.cpu().detach().numpy().astype("float32")[0]

from PIL import Image as pimage

def _encode_pil_image(image):
    """Embed an already-opened PIL image with CLIP and L2-normalize the result."""
    with torch.no_grad():
        image_emb = model.encode_image(preprocess(image).unsqueeze(0).to("cpu"))
        # Divide by the per-row L2 norm so the embedding has unit length.
        image_emb /= image_emb.norm(dim=-1, keepdim=True)
    return image_emb.cpu().detach().numpy().astype("float32")[0]


def get_image_emb_url(image_url):
    """Download the image at ``image_url`` and return its normalized CLIP embedding.

    Args:
        image_url: URL of the image; it is fetched with ``download_image``.

    Returns:
        float32 NumPy array: the first (and only) row of the normalized image
        features.
    """
    # The context manager releases the image's resources once it has been encoded.
    with pimage.open(download_image(image_url)) as image:
        return _encode_pil_image(image)


def get_image_emb(image_path):
    """Open the image at ``image_path`` and return its normalized CLIP embedding.

    Args:
        image_path: Path of the image file, passed to ``PIL.Image.open``.

    Returns:
        float32 NumPy array: the first (and only) row of the normalized image
        features.
    """
    # The context manager closes the underlying file once it has been encoded.
    with pimage.open(image_path) as image:
        return _encode_pil_image(image)

In [9]:
# Embed the prompt locally, then search with the vector itself; the embedding is
# converted to a plain Python list before being handed to the client.
red_tshirt_text_emb = get_text_emb("red tshirt")
red_tshirt_results = client.query(embedding_input=red_tshirt_text_emb.tolist())
log_result(red_tshirt_results[0])

id: 2195290014
caption: CTS29 100% Cotton T Shirt Crew Neck V Neck Long Sleeves Solid Maroon
url: https://cdn.shopify.com/s/files/1/1531/9423/products/CTS29_100_Cotton_T_Shirt_Crew_Neck_V_Neck_Long_Sleeves_Solid_Maroon_large.jpg?v=1476875650
similarity: 0.5375759601593018


In [6]:
# Embed an image fetched from a URL locally, then search with that embedding.
blue_dress_image_emb = get_image_emb_url("https://rukminim1.flixcart.com/image/612/612/kv8fbm80/dress/b/5/n/xs-b165-royal-blue-babiva-fashion-original-imag86psku5pbx2g.jpeg?q=70")
blue_dress_results = client.query(embedding_input=blue_dress_image_emb.tolist())
log_result(blue_dress_results[0])

id: 2463946620
caption: 8c7889e0b92b Cinderella Divine 1295 Long Chiffon Grecian Royal Blue Dress Mid Length  Sleeves V Neck ...
url: https://cdn.shopify.com/s/files/1/1417/0920/products/1295cd-royal-blue_cfcbd4bc-ed74-47c0-8659-c1b8691990df.jpg?v=1527650905
similarity: 0.9430058598518372


In [7]:
# Combine a text embedding and an image embedding into a single query vector.
# Both embeddings are recomputed here, so this cell does not depend on the two
# preceding query cells having been run.
red_tshirt_text_emb =  get_text_emb("red tshirt")
blue_dress_image_emb = get_image_emb_url("https://rukminim1.flixcart.com/image/612/612/kv8fbm80/dress/b/5/n/xs-b165-royal-blue-babiva-fashion-original-imag86psku5pbx2g.jpeg?q=70")
# `normalized` returns a 2-D array for a 1-D input, so [0] takes the single row.
mean_emb = normalized(red_tshirt_text_emb + blue_dress_image_emb)[0]
mean_results = client.query(embedding_input=mean_emb.tolist())
log_result(mean_results[0])

id: 2702080924
caption: CLEARANCE - Long Chiffon Grecian Red Dress Mid Length Sleeves V Neck (Size Medium)
url: https://cdn-img-3.wanelo.com/p/716/c27/0c0/aef7a32a4317370b6f7f14b/x354-q80.jpg
similarity: 0.8246004581451416


## Query by Winoground Captions and Images

In [10]:
import json
# Load the Winoground examples: one JSON object per non-blank line of the file.
# The encoding is given explicitly so decoding does not depend on the platform's
# locale default, and blank lines (e.g. a trailing newline) are skipped rather
# than being passed to json.loads, which would raise on them.
with open('../data/examples.jsonl', encoding='utf-8') as f:
    winoground = [json.loads(line) for line in f if line.strip()]

In [14]:
from tqdm import tqdm
# For every Winoground example, query the index with each caption, each image,
# and each caption+image combination, and store the result lists on the example
# dict under "<name>_retrieval" keys.
for example in tqdm(winoground):
    caption_0 = example["caption_0"]
    caption_1 = example["caption_1"]
    image_0 = f'../data/images/{example["image_0"]}.png'
    image_1 = f'../data/images/{example["image_1"]}.png'

    caption_0_emb = get_text_emb(caption_0)
    caption_1_emb = get_text_emb(caption_1)
    image_0_emb = get_image_emb(image_0)
    # Embed image_1 itself. Embedding image_0 here instead would make every
    # image_1-based query an exact duplicate of the image_0 one.
    image_1_emb = get_image_emb(image_1)
    # `normalized` returns a 2-D array for a 1-D input, so [0] takes the single row.
    caption_0_image_0_emb = normalized(caption_0_emb + image_0_emb)[0]
    caption_0_image_1_emb = normalized(caption_0_emb + image_1_emb)[0]
    caption_1_image_0_emb = normalized(caption_1_emb + image_0_emb)[0]
    caption_1_image_1_emb = normalized(caption_1_emb + image_1_emb)[0]

    example["caption_0_retrieval"] = client.query(embedding_input=caption_0_emb.tolist())
    example["caption_1_retrieval"] = client.query(embedding_input=caption_1_emb.tolist())
    example["image_0_retrieval"] = client.query(embedding_input=image_0_emb.tolist())
    example["image_1_retrieval"] = client.query(embedding_input=image_1_emb.tolist())
    example["caption_0_image_0_retrieval"] = client.query(embedding_input=caption_0_image_0_emb.tolist())
    example["caption_0_image_1_retrieval"] = client.query(embedding_input=caption_0_image_1_emb.tolist())
    example["caption_1_image_0_retrieval"] = client.query(embedding_input=caption_1_image_0_emb.tolist())
    example["caption_1_image_1_retrieval"] = client.query(embedding_input=caption_1_image_1_emb.tolist())

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [1:00:15<00:00,  9.04s/it]


In [15]:
# Write the examples, now carrying their retrieval results, to a JSON file
# indented by four spaces.
with open('clip_retrieval.json', 'w') as out_file:
    json.dump(winoground, out_file, indent=4)