In [None]:
!pip install transformers

In [None]:
from PIL import Image
import requests
import torch

from transformers import CLIPProcessor, CLIPModel
import torch.nn.functional as F
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the CLIP checkpoint and its matching processor, and place the model
# on the selected device.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

def sim_score(images, prompts, weights):
    """Score how far each image is from the text prompts in CLIP embedding space.

    For every prompt, the squared great-circle distance between the normalized
    image embedding and the normalized text embedding is computed, multiplied
    by that prompt's weight, and accumulated. Lower values mean a closer match.

    Args:
        images: Image or list of images accepted by the module-level ``processor``.
        prompts: Text prompt or list of prompts to compare the images against.
        weights: One multiplier per prompt. Prompts and weights are paired with
            ``zip``, so surplus entries on either side are ignored.

    Returns:
        A 1-D tensor on ``device`` holding one accumulated score per image,
        computed without gradient tracking.
    """
    # Use the `prompts` argument (not a notebook-level global) so the function
    # scores against whatever the caller passes in.
    inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: keep the forward pass inside no_grad as well so that no
    # autograd graph is built for the embeddings.
    with torch.no_grad():
        outputs = model(**inputs)
        image_embed = outputs.image_embeds
        target_embeds = outputs.text_embeds

        input_normed = F.normalize(image_embed, dim=1)
        loss = torch.zeros(image_embed.shape[0], device=device)
        for text_embed, weight in zip(target_embeds, weights):
            embed_normed = F.normalize(text_embed, dim=0)
            # Squared great-circle distance: for unit vectors the chord length
            # is 2*sin(theta/2), so arcsin(chord/2) recovers theta/2.
            dists = input_normed.sub(embed_normed).norm(dim=1).div(2).arcsin().pow(2).mul(2)
            loss += dists * weight
        return loss

In [None]:
from datasets import load_dataset
from pathlib import Path

# Load the "ceyda/smithsonian_butterflies" dataset with datasets.load_dataset.
# A later cell indexes the result with ["train"], so a "train" split is expected.
dataset = load_dataset("ceyda/smithsonian_butterflies")

# Print a summary of what was loaded.
print(dataset)

In [None]:
# Optional data processing stuff
import os
def bg(ex):
    """Replace the example's ``image`` with the PNG named after its ``image_hash``.

    The file is opened from ``./data_smith_higres/transparent/`` and the
    updated example is returned.
    """
    png_path = f"./data_smith_higres/transparent/{ex['image_hash']}.png"
    ex["image"] = Image.open(png_path)
    return ex


# Drop examples that have no PNG on disk, swap in the PNG for the rest,
# then continue with the "train" split only.
dataset = (
    dataset
    .filter(lambda ex: os.path.exists(f"./data_smith_higres/transparent/{ex['image_hash']}.png"))
    .map(bg)
)["train"]

In [None]:
# We can vary this to find what works best
prompt = ['pretty butterfly'] # one or more text prompts (multiple prompts not tested)

def calc(ex):
    """Add a CLIP ``sim_score`` to every example in a batch.

    Written for ``Dataset.map(..., batched=True)``: ``ex`` maps column names to
    lists, and ``ex['image']`` holds one image per example in the batch.

    Returns:
        The same batch with ``ex['sim_score']`` set to one score per image.
    """
    im = [x.convert("RGB") for x in ex['image']] # necessary because my images are rgba
    # One equal weight per prompt; a fixed single-element weight list would make
    # sim_score's zip() ignore every prompt after the first.
    loss = sim_score(im, prompt, [1] * len(prompt))
    # Assign the whole column at once: this works whether or not the dataset
    # already has a 'sim_score' column and does not rely on mutating the
    # batch's existing list in place.
    ex['sim_score'] = list(loss.cpu().numpy())
    return ex


In [None]:
# Score the whole dataset with calc, passing it batches of 20 examples.
dset_w_sim = dataset.map(calc,batched=True,batch_size=20)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Order the scored dataset by ascending sim_score (lowest distance first).
# NOTE(review): this name shadows Python's built-in sorted(); it is kept
# because the following cells read it.
sorted = dset_w_sim.sort('sim_score')

# Show the first n_rows x n_rows images of the ordered dataset.
n_rows = 4
fig, axs = plt.subplots(nrows=n_rows, ncols=n_rows, figsize=(9, 9))
for i, ax in enumerate(axs.ravel()):
    ax.imshow(np.array(sorted[i]['image']))
plt.tight_layout()

In [None]:
# Compare to those with the worst scores (the end of the ordered dataset).
# NOTE(review): indexing starts at -5, so the last four examples are never
# shown - confirm that skipping them is intentional.
n_rows = 4
fig, axs = plt.subplots(nrows=n_rows, ncols=n_rows, figsize=(9, 9))
for i, ax in enumerate(axs.ravel()):
    ax.imshow(np.array(sorted[-i - 5]['image']))
plt.tight_layout()

In [None]:
# How are we looking further down the list?
# Observation recorded by the author: "Pretty good at 5k".
# NOTE(review): that remark mentions 5k while the offset below is 3000 - confirm
# which value the observation refers to.
offset=3000
n_rows = 4
fig, axs = plt.subplots(nrows=n_rows, ncols=n_rows, figsize=(9, 9))
for i, ax in enumerate(axs.ravel()):
    ax.imshow(np.array(sorted[offset + i]['image']))
plt.tight_layout()

In [None]:
# And even further?
# Observation recorded by the author: not quite so good - some caterpillars etc.
# NOTE(review): this offset (2000) is smaller than the previous cell's (3000),
# so it is not actually "further" down the list - confirm the intended value.
offset=2000
n_rows = 4
fig, axs = plt.subplots(nrows=n_rows, ncols=n_rows, figsize=(9, 9))
for i, ax in enumerate(axs.ravel()):
    ax.imshow(np.array(sorted[offset + i]['image']))
plt.tight_layout()

In [None]:
# Let's grab the top 10,000 (lowest sim_score) examples without changing their
# order, as we would if we did sorted[:10000].
n_keep = 10000
# sorted[n_keep] raises IndexError when there are n_keep or fewer examples; in
# that case use an infinite threshold so the filter keeps everything (barring
# NaN scores). The comparison is strict, so examples tied with the threshold
# score are dropped.
score_thresh = sorted[n_keep]['sim_score'] if n_keep < len(sorted) else float('inf')
filtered = dset_w_sim.filter(lambda x: x['sim_score'] < score_thresh)
len(filtered)

In [None]:
from huggingface_hub import notebook_login

# Authenticate with the Hugging Face Hub; needed before push_to_hub below.
notebook_login()

In [None]:
# Upload the scored dataset to the Hub under the given repo id.
# NOTE(review): this pushes dset_w_sim (every scored example), not the
# `filtered` subset computed in an earlier cell - confirm that is intended.
dset_w_sim.push_to_hub('ceyda/smithsonian_butterflies_transparent')

In [None]:
# Display the first scored example as a quick sanity check.
dset_w_sim[0]