In [27]:
import os, sys
import pandas as pd
from IPython.display import clear_output

import numpy as np

import warnings

from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
pd.options.mode.chained_assignment = None
from sklearn.manifold import TSNE
from transformers import AutoModel, AutoProcessor
from PIL import Image
from torchvision import transforms
import torch
import scripts.utils as utils

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# directory & file hierarchy
proj_dir = os.path.abspath('../..')
stimuli_dir = os.path.join(proj_dir,'stimuli')
results_dir = os.path.join(proj_dir,'results')
plot_dir = os.path.join(results_dir,'plots')
ms_plot_dir = os.path.join(plot_dir,'ms')
analysis_dir = os.path.join(proj_dir,'analysis')
data_dir = os.path.join(proj_dir,'data')
# plot_dir = os.path.join(results_dir,'plots')
# csv_dir = os.path.join(results_dir,'csv')
exp_dir = os.path.abspath(os.path.join(proj_dir,'experiments'))

## add helpers to python path
if os.path.join(proj_dir,'utils') not in sys.path:
    sys.path.append(os.path.join(proj_dir,'utils'))

def make_dir_if_not_exists(dir_name):   
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    return dir_name

In [12]:
things_main_df = pd.read_csv(os.path.join(data_dir, 'things_concepts.tsv'), sep='\t') ### things concepts metadata
things_plus_df= pd.read_csv(os.path.join(data_dir,'THINGSplus_categories.tsv'),sep='\t')
things_plus_dict = dict(zip(things_plus_df.uniqueID, things_plus_df.category))

In [13]:
things_plus_dict

{'aardvark': 'mammal',
 'alligator': 'animal',
 'alpaca': 'mammal',
 'ant': 'insect',
 'anteater': 'mammal',
 'antelope': 'mammal',
 'badger': 'mammal',
 'barnacle': 'sea animal',
 'bat1': 'mammal',
 'bear': 'mammal',
 'beaver': 'mammal',
 'bee': 'insect',
 'beetle': 'insect',
 'bird': 'bird',
 'bison': 'mammal',
 'blowfish': 'seafood',
 'boar': 'mammal',
 'bug': 'insect',
 'bull': 'mammal',
 'butterfly': 'insect',
 'calf1': 'mammal',
 'camel': 'mammal',
 'cardinal': 'bird',
 'cat': 'mammal',
 'caterpillar': 'insect',
 'catfish': 'seafood',
 'cheetah': 'mammal',
 'chick': 'bird',
 'chicken2': 'farm animal',
 'chihuahua': 'mammal',
 'chinchilla': 'mammal',
 'chipmunk': 'mammal',
 'clam': 'seafood',
 'cobra': 'animal',
 'cockatoo': 'bird',
 'cockroach': 'insect',
 'coral': 'sea animal',
 'cougar': 'mammal',
 'cow': 'mammal',
 'coyote': 'mammal',
 'crab': 'seafood',
 'crayfish': 'seafood',
 'crow': 'bird',
 'dalmatian': 'mammal',
 'deer': 'mammal',
 'dog': 'mammal',
 'dolphin': 'sea anima

In [14]:
sketch_trials = pd.read_csv(os.path.join(data_dir,'things-drawings-prod-clean.csv'))

In [15]:
### apply things_plus_dict to the uniqueID column in sketch_trials
sketch_trials['category'] = sketch_trials['concept'].map(things_plus_dict)
sketch_trials['category'].value_counts()

category
food                      3002
tool                      1483
container                 1250
mammal                    1216
sports equipment           841
weapon                     790
vehicle                    756
furniture                  634
vegetable                  608
home decor                 589
clothing                   565
musical instrument         542
body part                  512
fruit                      500
electronic device          487
toy                        475
hardware                   474
part of car                464
plant                      461
seafood                    423
personal hygiene item      410
scientific equipment       404
bird                       361
medical equipment          328
kitchen appliance          325
women's clothing           312
school supply              301
arts and crafts supply     300
drink                      297
insect                     267
game                       260
home appliance             259

In [28]:

# Load the data
image_folder = os.path.join(stimuli_dir, 'sketch_pngs')
# Load the model and tokenizer
model_name = 'google/siglip-so400m-patch14-384'  # Replace with the name of your model
model = AutoModel.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)


In [None]:

def embed_image(image_path):
    image = Image.open(image_path).convert('RGB')
    inputs = processor(text=None, images=image, padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
    
   
    return outputs.squeeze().numpy().shape

# Embed all images
embeddings = []
for sketch_id in tqdm(sketch_trials['sketch_id'], desc="Embedding images"):
    image_path = os.path.join(image_folder, f'{sketch_id}.png')  # Adjust the file extension if needed
    embedding = embed_image(image_path)
    embeddings.append(embedding)

embeddings = np.array(embeddings)

# Perform t-SNE
tsne = TSNE(n_components=2, random_state=42)
tsne_results = tsne.fit_transform(embeddings)

# Plot the t-SNE results
plt.figure(figsize=(10, 8))
scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=sketch_trials['category'].astype('category').cat.codes, cmap='viridis')
plt.colorbar(scatter, ticks=range(len(sketch_trials['category'].unique())), label='Category')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.title('t-SNE of Sketch Embeddings')
plt.show()

Embedding images:   0%|          | 1/28627 [00:07<58:00:17,  7.29s/it]

In [35]:

i = Image.open(image_path).convert('RGB')


(1152,)