In [None]:
# Install the third-party packages this notebook imports (PIL, torch, open_clip, requests).
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install pillow
%pip install torch
%pip install open_clip_torch
%pip install requests

In [None]:
import torch 
import open_clip

# Pick the compute device once; the `device`, `model` and `preprocess` globals
# defined in this cell are used by the functions in the cells below.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

print("Loading model...")

# create_model_and_transforms returns (model, train_transform, eval_transform);
# only the eval transform is needed for inference, so the middle value is dropped.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32',
    pretrained='laion2b_s34b_b79k',
    device=device
)

# open_clip hands the model back in training mode; switch to eval mode for
# inference. Assigning the result keeps the cell from echoing the model repr.
model = model.eval()

In [3]:
import requests

def load_keywords_from_urls(url_list, timeout=10):
    """Download newline-separated keyword lists and merge them into one list.

    Args:
        url_list: Iterable of URLs, each pointing at a plain-text file with
            one keyword per line.
        timeout: Seconds to wait for each request before giving up. Without a
            timeout `requests.get` can block indefinitely on a stalled server.

    Returns:
        A list of stripped, non-empty keywords in URL order. URLs that fail to
        download are reported with a message and skipped.
    """
    keyword_list = []
    for url in url_list:
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error loading keywords from URL ({url}): {e}")
            continue

        # Strip surrounding whitespace and drop blank lines so that no empty
        # string ends up being treated as a keyword.
        stripped_lines = (line.strip() for line in response.text.splitlines())
        keyword_list.extend(keyword for keyword in stripped_lines if keyword)

    print(f"Loaded {len(keyword_list)} keywords...")
    return keyword_list

In [4]:
import os

def load_keywords_from_files(directory):
    """Collect keywords from every `.txt` file directly inside `directory`.

    Each file is expected to hold one keyword per line. Files are read in
    sorted filename order: `os.listdir` returns entries in arbitrary order,
    and callers pair keyword positions with rows of a saved feature tensor,
    so the order has to be the same on every run.

    Args:
        directory: Path of the directory to scan (not searched recursively).

    Returns:
        A list of stripped, non-empty keywords.
    """
    keyword_list = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.txt'):
            continue
        file_path = os.path.join(directory, filename)
        # A sub-directory may also be named "*.txt"; opening it would fail.
        if not os.path.isfile(file_path):
            continue
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                keyword = line.strip()
                # Skip blank lines so no empty string becomes a keyword.
                if keyword:
                    keyword_list.append(keyword)
    print(f"Loaded {len(keyword_list)} keywords...")
    return keyword_list


In [5]:
import torch
import open_clip
import requests
from PIL import Image
from io import BytesIO

def generate_keywords_for_image(image_url, candidate_keywords, text_features, top_k=5, batch_size=100):
    """Return the candidate keywords that best match the image at `image_url`.

    The image is downloaded, run through the notebook-level `preprocess`
    transform and `model.encode_image`, and compared against `text_features`
    by cosine similarity.

    Args:
        image_url: URL of the image to describe.
        candidate_keywords: Sequence of keywords; entry `i` must correspond to
            row `i` of `text_features`.
        text_features: 2-D tensor of encoded keywords, one row per keyword.
        top_k: Maximum number of keywords to return. Clamped to the number of
            rows in `text_features`.
        batch_size: Unused; kept so existing callers keep working.

    Returns:
        A list of up to `top_k` keywords ordered from best to worst match, or
        an empty list if the image could not be downloaded.
    """
    print("Retrieving image from URL...")
    try:
        # A timeout keeps a stalled server from blocking the notebook forever.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error loading image from URL: {e}")
        return []

    image = Image.open(BytesIO(response.content))

    print("Preprocessing image...")
    image_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        print("Encoding image...")
        image_features = model.encode_image(image_tensor)

        print("Calculating similarity scores...")
        # Features loaded from disk may live on another device or use another
        # dtype than the freshly encoded image features.
        text_features = text_features.to(device=image_features.device, dtype=image_features.dtype)
        # L2-normalise both sides so the dot product is a cosine similarity;
        # otherwise keywords whose embeddings have a large norm win regardless
        # of how well they match the image.
        image_features = torch.nn.functional.normalize(image_features, dim=-1)
        text_features = torch.nn.functional.normalize(text_features, dim=-1)
        similarities = torch.matmul(image_features, text_features.T)

        print("Getting top keywords...")
        # torch.topk raises when k exceeds the number of candidates.
        k = max(0, min(top_k, similarities.shape[-1]))
        top_indices = torch.topk(similarities, k, dim=-1).indices[0].tolist()

    return [candidate_keywords[i] for i in top_indices]

In [6]:
import torch

def encode_keywords(candidate_keywords, batch_size=1000):
    """Encode keywords into text features with the notebook-level CLIP model.

    Keywords are tokenized with `open_clip.tokenize` and passed through
    `model.encode_text` in batches of `batch_size`.

    Args:
        candidate_keywords: Non-empty list of keyword strings.
        batch_size: Number of keywords encoded per forward pass.

    Returns:
        A tensor with one row per keyword, in the same order as
        `candidate_keywords`.

    Raises:
        ValueError: If `candidate_keywords` is empty or `batch_size` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    total_keywords = len(candidate_keywords)
    if total_keywords == 0:
        # torch.cat on an empty list fails with an opaque error; fail early.
        raise ValueError("candidate_keywords is empty; nothing to encode")

    print("Tokenizing and encoding candidate keywords...")
    total_batches = (total_keywords + batch_size - 1) // batch_size
    text_features = []
    for batch_number, start in enumerate(range(0, total_keywords, batch_size), start=1):
        # Report the batch index, not the offset of its first keyword.
        print(f"Batch {batch_number}/{total_batches}...")
        batch_keywords = candidate_keywords[start:start + batch_size]
        batch_tokens = open_clip.tokenize(batch_keywords).to(device)
        with torch.no_grad():
            batch_features = model.encode_text(batch_tokens)
        text_features.append(batch_features)

    # Concatenate the encoded features from all batches
    return torch.cat(text_features, dim=0)

In [None]:
# Extra packages imported by the feature-saving cell below (h5py, numpy).
%pip install h5py
%pip install numpy

In [7]:
import torch
import os
import h5py
import numpy as np

def save_features_to_hdf5(keywords, features, file_path):
    """Write keywords and their feature vectors to an HDF5 file.

    The file is created (or truncated) and receives two datasets:
    'keywords', a byte-string array, and 'features'.

    Args:
        keywords: Sequence of keywords (`str` or `bytes`).
        features: Array-like or torch tensor of feature vectors.
        file_path: Destination path of the HDF5 file.
    """
    # NumPy's 'S' dtype only accepts ASCII text, so non-ASCII keywords raise
    # UnicodeEncodeError. Encode to UTF-8 bytes explicitly; ASCII keywords
    # produce exactly the same bytes as before.
    keywords_array = np.array(
        [keyword.encode('utf-8') if isinstance(keyword, str) else keyword for keyword in keywords],
        dtype='S',
    )

    # Hand a plain NumPy array to h5py: a tensor that lives on a GPU or still
    # tracks gradients cannot be converted implicitly.
    if isinstance(features, torch.Tensor):
        features = features.detach().cpu().numpy()

    # Open the HDF5 file for writing
    with h5py.File(file_path, 'w') as hf:
        hf.create_dataset('keywords', data=keywords_array)
        hf.create_dataset('features', data=features)

def encode_and_save_keywords(candidate_keywords, feature_file_path, batch_size=1000, overwrite=False):
    """Encode keywords and store the resulting features with `torch.save`.

    Only the features are stored, not the keywords. Callers must keep the
    keyword list in the same order as the rows of the saved tensor.

    Args:
        candidate_keywords: Keywords to encode.
        feature_file_path: Path of the feature file to write.
        batch_size: Batch size forwarded to `encode_keywords`.
        overwrite: When False and the file already exists, the new features
            are appended after the rows already stored there. When True the
            file is replaced.
    """
    text_features = encode_keywords(candidate_keywords, batch_size)

    if os.path.exists(feature_file_path) and not overwrite:
        # Load onto the device of the new features: the file may have been
        # written on a different device, which would make torch.load or
        # torch.cat fail.
        # NOTE: torch.load unpickles the file; only load files you trust.
        existing_features = torch.load(feature_file_path, map_location=text_features.device)
        text_features = torch.cat([existing_features, text_features], dim=0)

    # Save the features to a file
    torch.save(text_features, feature_file_path)

In [None]:
# Disabled alternative: build the keyword list from remote text files instead of local ones.
# keyword_urls = [
#     'https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt',
#     'https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt',
#     'https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt'
# ]
# candidate_keywords = load_keywords_from_urls(keyword_urls)
# print(f"Loaded {len(candidate_keywords)} keywords")

# encode_and_save_keywords(candidate_keywords, "keyword_features.pt")

In [None]:
# Rebuild the feature file from scratch on every run. With the default
# overwrite=False, re-running this cell appends the same rows again and the
# saved features no longer line up with the keyword list.
color_keywords = load_keywords_from_files("features/test")
encode_and_save_keywords(color_keywords, "keyword_features.pt", overwrite=True)

In [None]:
image_url = 'https://images.unsplash.com/photo-1682027888746-25b1af7bd47f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1350&q=80'

# map_location lets a feature file written on a GPU machine load on a CPU-only one.
text_features = torch.load("keyword_features.pt", map_location=device)

# Row i of text_features belongs to candidate_keywords[i], so the keywords must
# come from the same directory that was encoded into keyword_features.pt.
candidate_keywords = load_keywords_from_files("features/test")
assert len(candidate_keywords) == text_features.shape[0], (
    f"{len(candidate_keywords)} keywords but {text_features.shape[0]} feature rows; "
    "re-encode the keywords so both line up"
)

print(f"Loaded {text_features.size()} features")
img_keywords = generate_keywords_for_image(image_url, candidate_keywords, text_features, top_k=20)

print(img_keywords)