In [1]:
import json
import open_clip
import torch
import os
from PIL import Image

# Input locations: the clustering result and the folder holding the images
# that the cluster entries refer to.
filename = 'clusters.json'
image_folder_path = "yars_data/photos"

# Load the cluster assignments. Later cells iterate this object with
# `.items()`, treating keys as cluster labels and values as lists of image
# file names. JSON is UTF-8 by specification, so do not rely on the locale
# default encoding.
with open(filename, 'r', encoding='utf-8') as file:
    clusters_loaded = json.load(file)


# CoCa captioning model plus its image preprocessing transform. The middle
# return value is discarded because it is not needed here.
model, _, transform = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k",
)

# The notebook only runs inference, so put the model in eval mode
# (open_clip's README notes models are returned in train mode by default).
model.eval()

# Prefer the GPU when one is present; otherwise stay on CPU and say so.
if torch.cuda.is_available():
    model = model.cuda()
else:
    print("CUDA is not available. Running on CPU.")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Maps cluster label -> list of generated captions, one per image.
cluster_captions = {}

# Clusters smaller than this are skipped as likely outliers.
MIN_CLUSTER_SIZE = 5

# Run inference on whatever device the model actually lives on, so the CPU
# fallback works instead of crashing on an unconditional `.cuda()`.
device = next(model.parameters()).device
use_cuda = device.type == "cuda"

# NOTE: the file is opened in append mode, so re-running this cell adds to
# (rather than replaces) any captions already written.
with open("cluster_captions.txt", "a") as file:
    for cluster, images in clusters_loaded.items():
        # Skip the '-1' label (presumably the clustering's noise bucket --
        # TODO confirm) and clusters too small to be meaningful.
        if cluster == '-1' or len(images) < MIN_CLUSTER_SIZE:
            continue

        captions = cluster_captions.setdefault(cluster, [])

        for image_id in images:
            image_path = os.path.join(image_folder_path, image_id)
            # Context manager closes the underlying file handle once the
            # pixels have been converted and copied into a tensor.
            with Image.open(image_path) as raw_image:
                im = transform(raw_image.convert("RGB")).unsqueeze(0)
            im = im.to(device)

            # Mixed precision only makes sense (and is only enabled) on CUDA.
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_cuda):
                generated = model.generate(im, generation_type='top_k')

            # Keep only the text between the start and end markers.
            caption = open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", "")
            captions.append(caption)
            file.write(f"{cluster}: {caption}\n")
            # '\r' keeps progress on a single, continuously overwritten line.
            print(f"Processed cluster {cluster}: {image_id}", end='\r', flush=True)


Processed cluster 56: qzElzsnHc9yvFgxeWKy-qA.jpg

In [7]:
from collections import Counter
import re
import nltk
from nltk.corpus import stopwords

# Make sure the NLTK stop-word corpus is available locally, then keep the
# English stop words in a set so membership tests are cheap.
nltk.download('stopwords')
stop_words = {*stopwords.words('english')}


def tokenize_and_count(captions):
    """Count word occurrences across a collection of caption strings.

    The captions are joined with single spaces and lower-cased, then every
    run of word characters (regex ``\\b\\w+\\b``) is counted.

    Args:
        captions: iterable of caption strings.

    Returns:
        A ``collections.Counter`` mapping each lower-cased word to the
        number of times it appears.
    """
    lowered = ' '.join(captions).lower()

    tally = Counter()
    for match in re.finditer(r'\b\w+\b', lowered):
        tally[match.group()] += 1

    return tally

# For each cluster, print its non-stop-word caption vocabulary as
# (word, count) pairs, most frequent first.
for cluster, captions in cluster_captions.items():
    freq = tokenize_and_count(captions)

    # Drop stop words, keeping the remaining word -> count pairs.
    filtered_counts = dict(
        pair for pair in freq.items() if pair[0] not in stop_words
    )

    ranked = list(filtered_counts.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    print(ranked)



[('food', 6), ('table', 6), ('rice', 4), ('plate', 3), ('bowl', 3), ('plates', 3), ('meat', 1), ('vegetables', 1), ('fork', 1), ('curry', 1), ('bowls', 1), ('variety', 1), ('shown', 1), ('water', 1)]
[('plate', 4), ('ribs', 3), ('fries', 3), ('coleslaw', 3), ('beer', 3), ('wooden', 2), ('cutting', 2), ('board', 2), ('meat', 2), ('chicken', 2), ('wings', 2), ('celery', 2), ('drink', 1), ('sticks', 1), ('bowl', 1), ('side', 1)]
[('plate', 21), ('sushi', 20), ('food', 8), ('variety', 6), ('table', 6), ('sauce', 6), ('toppings', 5), ('roll', 5), ('rolls', 3), ('close', 3), ('rice', 2), ('shrimp', 1), ('items', 1), ('candle', 1), ('background', 1), ('chopsticks', 1), ('green', 1), ('onions', 1), ('wooden', 1), ('topped', 1), ('covered', 1), ('two', 1), ('plates', 1), ('salmon', 1), ('avocado', 1), ('fork', 1)]
[('table', 10), ('lemon', 8), ('oysters', 7), ('plate', 6), ('black', 6), ('white', 6), ('checkered', 6), ('food', 4), ('cloth', 4), ('tray', 3), ('wedge', 3), ('sauce', 2), ('couple'

[nltk_data] Downloading package stopwords to /home/qz2190/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
