# Library installation

In [None]:
!pip install transformers

In [None]:
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
import pandas as pd
from transformers import CLIPProcessor, CLIPModel

# Define functions and prepare dataset



## Order prompts based on an image

In [None]:
# Load the CLIP model and its pre-processor from the same checkpoint name.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

def order_prompts(model, processor, image, prompt_list):
  """Rank text prompts by how well the model matches them to an image.

  Args:
    model: callable invoked as ``model(**inputs)``; its result must expose a
      ``logits_per_image`` tensor (e.g. a ``CLIPModel``).
    processor: callable invoked with ``text``, ``images``,
      ``return_tensors="pt"`` and ``padding=True``; it returns the keyword
      inputs for ``model`` (e.g. a ``CLIPProcessor``).
    image: image handed to ``processor`` as ``images``.
    prompt_list: candidate prompt strings.

  Returns:
    A ``pandas.DataFrame`` with columns ``prob`` and ``prompt``, one row per
    prompt, sorted by descending probability. Only the first row of
    ``logits_per_image`` is used, so the scores describe a single image.
  """
  # Local import: the notebook's import cell does not bring torch into scope.
  import torch

  inputs = processor(text=prompt_list, images=image, return_tensors="pt", padding=True)

  # Inference only: disable autograd so no graph is recorded for the forward pass.
  with torch.no_grad():
    outputs = model(**inputs)

  # logits_per_image holds the image-text similarity scores; softmax over the
  # prompt axis turns them into per-prompt probabilities.
  probs = outputs.logits_per_image.softmax(dim=1)

  # .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
  first_image_probs = probs[0].detach().cpu().numpy()

  df = pd.DataFrame({'prob': first_image_probs, 'prompt': list(prompt_list)})
  return df.sort_values('prob', ascending=False)

## Compute truth table cells

In [None]:
from glob import glob
def count_framed_and_unframed_rations_from_path(file_path):
  """Return the share of entries in a folder that are / are not named ``framed_*``.

  An entry counts as "framed" when the part of its file name before the first
  ``_`` is exactly ``"framed"``; every other entry counts as "unframed".

  Args:
    file_path: folder whose direct children are inspected (non-recursive).

  Returns:
    A ``(framed_ratio, unframed_ratio)`` tuple of floats. Both are ``0.0``
    when the folder has no matching entries (empty or missing folder), instead
    of raising ``ZeroDivisionError``.
  """
  import os

  entry_paths = glob(file_path + "/*")
  total = len(entry_paths)
  if total == 0:
    # Nothing to count: avoid dividing by zero.
    return 0.0, 0.0

  # os.path.basename copes with whatever separator glob returns.
  framed_count = sum(
      1 for entry_path in entry_paths
      if os.path.basename(entry_path).split('_')[0] == "framed"
  )
  unframed_count = total - framed_count

  return framed_count / total, unframed_count / total

# Run the sorting

In [None]:
# Create the two destination folders for the sorted images
# (-p also creates missing parents and does not fail if they already exist).
!mkdir -p "/content/drive/MyDrive/AI/clip_interrogator/framed_unframed/output/framed"
!mkdir -p "/content/drive/MyDrive/AI/clip_interrogator/framed_unframed/output/unframed"

In [None]:
# setup params
import os
from glob import glob

frame_prompt_list = ['picture of a framed painting', 'picture of an unframed painting']
input_path = "/content/drive/MyDrive/AI/clip_interrogator/framed_unframed/input"
output_path = "/content/drive/MyDrive/AI/clip_interrogator/framed_unframed/output"

# make sure both destination folders exist before saving into them
for label in ('framed', 'unframed'):
  os.makedirs(os.path.join(output_path, label), exist_ok=True)

# load file paths (every entry exactly two levels below input_path)
filename_list = sorted(glob(input_path+"/*/*", recursive = True))

# run sorting: classify each image and save it under the predicted label
for curr_filename in tqdm(filename_list):
  # the context manager closes the underlying file handle after each image
  with Image.open(curr_filename) as image:
    df_out = order_prompts(model, processor, image, frame_prompt_list)
    # first row of the sorted frame is the most probable prompt
    answer = df_out.iloc[0]['prompt']

    # 'framed' is a substring of 'unframed', so test for 'unframed' explicitly
    label = 'unframed' if 'unframed' in answer else 'framed'

    filename = os.path.basename(curr_filename)
    # build the path from its parts rather than rewriting the string 'output'
    target_file_path = os.path.join(output_path, label, filename)
    image.save(target_file_path)

# Compute Truth Table

In [None]:
# Folders holding the images sorted by predicted label.
_OUTPUT_ROOT = "/content/drive/MyDrive/AI/clip_interrogator/framed_unframed/output"
framed_output = _OUTPUT_ROOT + "/framed"
unframed_output = _OUTPUT_ROOT + "/unframed"

In [None]:
# For the folder of images predicted as framed: (share of files whose name
# prefix before the first '_' is "framed", share of all other files).
# NOTE(review): presumably the file-name prefix is the ground-truth label — confirm.
count_framed_and_unframed_rations_from_path(framed_output)

In [None]:
# For the folder of images predicted as unframed: (share of files whose name
# prefix before the first '_' is "framed", share of all other files).
# NOTE(review): presumably the file-name prefix is the ground-truth label — confirm.
count_framed_and_unframed_rations_from_path(unframed_output)