In [15]:
os.getcwd() 

'/Users/franciszekruszkowski/Desktop/OPPLY'

In [85]:
import torch
from transformers import CLIPModel, CLIPProcessor

class CLIPFeatureExtractor:
    
    def __init__(self):
        model_name = "openai/clip-vit-base-patch32"
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    @torch.no_grad()
    def encode_text(self, text):
        
        # input text
        # torch tensor: (1,512) vector of text embedding
        
        inputs = self.processor(text=text, return_tensors="pt")
        inputs = inputs.to(self.device)
        
        text_features = self.model.get_text_features(**inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        
        return text_features
   

    @torch.no_grad()
    def encode_images(self, images):
        
        # input images
        # torch_tensor : (n,1,512) vector of image features
        
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = inputs.to(self.device)
        
        image_features = self.model.get_image_features(**inputs)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        
        return image_features

    
if __name__ == '__main__':
    pass

In [86]:
import numpy as np
from sentence_transformers import util
from PIL import Image
import os
from sklearn.preprocessing import MinMaxScaler

In [87]:
ft = CLIPFeatureExtractor()
almonds_text_embedding = ft.encode_text('almonds')

In [88]:
def load_image(fname):
    img = Image.open(fname)
    return img.convert('RGB')

In [89]:
def make_predictions_for_directory(directory):
    
    image_files = os.listdir(directory)
    image_files = [os.path.join(directory, f) for f in image_files]
    
    images = [load_image(f) for f in image_files]
    
    image_features = ft.encode_images(images)
    
    similarities = [util.cos_sim(img, almonds_text_embedding).item() for img in image_features]
    
    return np.mean(similarities)

In [128]:
def make_preds_for_examples(res_path='examples/'):
    
    irrelevant_dir = os.path.join(res_path, 'Irrelevant')
    relevant_dir = os.path.join(res_path, 'Relevant')

    irrelevant_examples = [os.path.join(irrelevant_dir, example) for example in os.listdir(irrelevant_dir)]
    relevant_examples = [os.path.join(relevant_dir, example) for example in os.listdir(relevant_dir)]
    
    examples = irrelevant_examples + relevant_examples
    
    preds = [make_predictions_for_directory(example) for example in examples]
    preds_dict = {example: pred for example, pred in zip(examples, preds)}
    
    
    ### Scaled 
    scaler = MinMaxScaler()

    preds_scaled = scaler.fit_transform(np.array([preds]).reshape(-1,1))
    preds_dict_scaled = {example: pred for example, pred in zip(examples, preds_scaled)}
    
    return preds_dict, preds_dict_scaled

In [129]:
def make_preds_for_google_images(res_path='google_images/'):
    
    image_dirs = os.listdir(res_path)
    image_dir_paths = [os.path.join(res_path, image_dir) for image_dir in image_dirs]
    
    labels = [img_dir.lower() for img_dir in image_dirs]
    
    avg_similarities = [make_predictions_for_directory(directory) for directory in (image_dir_paths)]
    
    preds_dict = {label: sim for label, sim in zip(labels, avg_similarities)}
    
    return preds_dict  

In [133]:
print('Making predictions for examples: ')
examples_dic,   preds_scaled = make_preds_for_examples()
examples_dic

Making predictions for examples: 


{'examples/Irrelevant/Example 1': 0.16390372753143312,
 'examples/Irrelevant/Example 3': 0.20035310144777652,
 'examples/Irrelevant/Example 2': 0.20734971463680268,
 'examples/Relevant/Example 1': 0.23983665242791175,
 'examples/Relevant/Example 3': 0.24710791371762753,
 'examples/Relevant/Example 2': 0.24388677423650568}

In [134]:
print('Making predictions for examples - MinMax SCALED: ')
preds_scaled

Making predictions for examples - MinMax SCALED: 


{'examples/Irrelevant/Example 1': array([0.]),
 'examples/Irrelevant/Example 3': array([0.4380714]),
 'examples/Irrelevant/Example 2': array([0.52216107]),
 'examples/Relevant/Example 1': array([0.91260943]),
 'examples/Relevant/Example 3': array([1.]),
 'examples/Relevant/Example 2': array([0.96128633])}

In [135]:
print('Making predictions for google images')
google_images_dic = make_preds_for_google_images()
google_images_dic

Making predictions for google images


  "Palette images with Transparency expressed in bytes should be "


{'maple syrup': 0.18799603870138526,
 'dates': 0.2504647237062454,
 'cocoa': 0.22287209704518318,
 'coconut': 0.2238942816549418,
 'peanuts': 0.2437329798936844,
 'oats': 0.21118771225214006,
 'almonds': 0.30445924401283264,
 'hazelnuts': 0.2631231632828712,
 'random': 0.18072939336299895,
 'sunflower oil': 0.1849773100444249}