In [2]:
import os
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import csv
import json
import pandas as pd

In [3]:
# Load the CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Write embeddings to CSV file

In [5]:
# Directory containing your images
image_dir = 'screenshots'

# CSV file to save embeddings
csv_file = 'image_embeddings.csv'

# Open a CSV file for writing
with open(csv_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    # Write the header row
    writer.writerow(["Filename", "Embeddings"])
    
    # Loop through all images in the directory
    for image_name in os.listdir(image_dir):
        # Make sure to process only files with a .jpg or .png extension
        if image_name.endswith('.jpg') or image_name.endswith('.png'):
            image_path = os.path.join(image_dir, image_name)
            
            # Open the image
            image = Image.open(image_path)
            
            # Process the image for the CLIP model
            inputs = processor(images=image, return_tensors="pt", padding=True)
            
            # Generate embeddings
            with torch.no_grad():
                embeddings = model.get_image_features(**inputs).numpy()
                
            # Convert the embeddings to a list for CSV writing
            embeddings_list = embeddings.flatten().tolist()
            
            # Write the filename and embeddings to the CSV file
            writer.writerow([image_name, embeddings_list])

print("Embeddings have been saved to", csv_file)

Embeddings have been saved to image_embeddings.csv


## Reading from file

In [31]:
import pandas as pd

# Load the CSV file
df = pd.read_csv('image_embeddings.csv')

# Assuming embeddings are in the second column and are stored as strings
# Example of embedding: "[0.23, 0.35, 0.11]"
import ast  # ast.literal_eval safely evaluates a string containing a Python literal expression

df['embeddings'] = df['Embeddings'].apply(ast.literal_eval)
# Accessing the embedding for the first row
first_embedding = df.loc[0, 'embeddings']
print(first_embedding)
print(len(first_embedding))

[-0.29299911856651306, -0.24236251413822174, -0.08342461287975311, -0.46731022000312805, 0.17175191640853882, -0.6144365668296814, 0.050357967615127563, 0.6644703149795532, 0.9215517044067383, -0.10427722334861755, 0.3126069903373718, -0.20633041858673096, -0.8325822949409485, -0.4152643382549286, -0.102508544921875, 0.05112642049789429, -0.5552076697349548, -0.26692235469818115, -0.0949239432811737, -0.057202622294425964, -0.24518036842346191, -0.5673624277114868, 0.09919030964374542, -0.159003347158432, 0.2435503602027893, 0.06483319401741028, -0.022691786289215088, 0.07246550917625427, 0.14191289246082306, 0.05346570163965225, 0.2500624358654022, -0.10493884235620499, 0.2194877713918686, 0.20308303833007812, 0.29780036211013794, -0.35999947786331177, -0.32517102360725403, -0.09154203534126282, -0.3530856966972351, -1.3351821899414062, -0.00230446457862854, 0.10342222452163696, -0.05525994300842285, -0.6574844717979431, -0.14613592624664307, 0.37750765681266785, 0.5655860900878906, -

In [32]:
import pandas as pd
import umap

# Assuming 'df' is your DataFrame and it contains an 'embeddings' column with your embeddings data
# Convert embeddings list into a proper format if necessary
embeddings = list(df['embeddings'])

# Initialize UMAP. Reduce dimensionality to 2D for easy visualization.
# Create a UMAP instance with custom parameters
reducer = umap.UMAP(
    n_neighbors=30,
    n_components=2,
    metric='euclidean',
    min_dist=0.2,
    spread=1.5,
    learning_rate=1.0,
    n_epochs=200,
    init='spectral'
)
umap_embeddings = reducer.fit_transform(embeddings)

failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!
  warn(


### Prepare Labels

In [46]:
#  Prepare labels
import re

def extract_substring(s):
    # Use a regular expression to find the point at which to stop
    match = re.search(r'_(p|m|sw|s)', s)
    if match:
        return s[:match.start()]
    return s  # Return the whole string if no match is found

# Apply the function to the 'Label' column
df['Label'] = df['Filename'].apply(extract_substring)

# Display the DataFrame to see the original and trimmed labels
print(df)

                                        Filename  \
0                       heatmap_sw_1_2_s_1_0.png   
1           two_by_two_p_4_m_10_sw_0_7_s_1_2.png   
2        multiple_view_p_2_m_12_sw_0_7_s_0_7.png   
3      two_by_two_uneven_w_m_20_sw_0_7_s_0_7.png   
4         multiple_view_p_1_m_8_sw_1_2_s_1_0.png   
...                                          ...   
7291  three_composite_v_p_0_m_0_sw_1_2_s_2_0.png   
7292    two_by_two_uneven_h_p_0_sw_0_7_s_1_2.png   
7293        three_composite_m_9_sw_0_7_s_1_2.png   
7294   multi_view_link_p_0_m_11_sw_1_0_s_0_7.png   
7295     multiple_view_p_3_m_13_sw_1_2_s_1_0.png   

                                             Embeddings  \
0     [-0.29299911856651306, -0.24236251413822174, -...   
1     [-0.024562444537878036, -0.5847567319869995, 0...   
2     [0.05260123312473297, -0.2164071798324585, 0.3...   
3     [0.25502264499664307, -0.36298295855522156, 0....   
4     [0.0062620267271995544, 0.049051493406295776, ...   
...                  

In [35]:
print(len(umap_embeddings))

# Convert the embeddings to a DataFrame
embedding_df = pd.DataFrame(umap_embeddings, columns=['UMAP_1', 'UMAP_2'])

# Add the labels to the DataFrame
embedding_df['Label'] = df['Label']

# Display the first few rows of the DataFrame
print(embedding_df.head())

7296
      UMAP_1     UMAP_2                Label
0  10.697206  18.870228              heatmap
1   1.476405   3.224320           two_by_two
2  16.091751  13.461968        multiple_view
3   6.637461  11.374885  two_by_two_uneven_w
4   3.145026 -16.165962        multiple_view


### Visualize 🌱

In [43]:
# API Reference: https://github.com/flekschas/jupyter-scatter
# and also https://github.com/flekschas/regl-scatterplot/#properties
config = {
    "size": 5,
    "axes_labels": True,
    "height": 800,
    "background": "dark",
    "legend": True,
    # "aspectRatio": 1,
    "opacity": 0.8,
    "axes_grid": True
}

In [44]:
# Plotting the results using jupyter scatter
import jscatter

jscatter.Scatter(
    data=embedding_df, x='UMAP_1', y='UMAP_2', color_by='Label', **config
).show()

HBox(children=(VBox(children=(Button(button_style='primary', icon='arrows', layout=Layout(width='36px'), style…