In [4]:
import transformers
from transformers import ViTImageProcessor, ViTForImageClassification
import PIL
import requests
import glob
import os
from scipy import spatial
from IPython.display import Image, display
import pandas as pd
from tqdm import tqdm
import torch

In [9]:
device = "mps" if torch.backends.mps.is_available() else "cpu"

In [11]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model     = model.to(device)

In [12]:
def load_image(image_filename):
    # print(image_filename)
    image = PIL.Image.open(image_filename)
    return image

In [13]:
def load_images_folder(folder_path):
    images = []
    for img_filename in tqdm(os.listdir(folder_path)):
        print(img_filename)
        if img_filename==".DS_Store":
            continue
        img = load_image(os.path.join(folder_path, img_filename))
        images.append(img)
    return images

In [14]:
def process_input(images_list, processor):
    inputs = processor(images=images_list, return_tensors="pt")
    return inputs

In [15]:
def get_vit_features(model, inputs):
    outputs = model(**inputs)
    logits = outputs.logits
    return logits

In [16]:
def get_cosine_similarity_for_two_images(features_1, features_2):
    # cosine similarity
    cosine_similarity = 1 - spatial.distance.cosine(features_1, features_2)
    return cosine_similarity

In [17]:
def get_filenames(source_dir):
    l = os.listdir(source_dir)
    if ".DS_Store" in l:
        l.remove(".DS_Store")
    return l

In [18]:
caps_500 = load_images_folder("/Users/carlalasry/Desktop/captures_tps_500/")

100%|███████████████████████████████████████| 501/501 [00:00<00:00, 4576.31it/s]

1714387534789742600_ID_0-2.png
1715055090508796000_ID_0-6.png
1714692071182372900_ID_0-0.png
1714695703676153900_ID_0-4.png
1714642961368637400_ID_0-1.png
1714731946413736000_ID_0-1.png
1714719557370892300_ID_0-1.png
1714990876431859700_ID_0-5.png
1714390510837289000_ID_0-7.png
1714533169253916700_ID_0-0.png
1714643713730351000_ID_0-2.png
1714396700661063700_ID_0-6.png
1714628870931529700_ID_0-1.png
1714994978394017800_ID_0-1.png
1714664700500795400_ID_0-4.png
1714705048124072000_ID_0-3.png
1714365818910920700_ID_0-1.png
1714691227108057000_ID_0-6.png
1715059994841366500_ID_0-4.png
1714360737096376300_ID_0-6.png
1715066498948579300_ID_0-0.png
1714390863020380200_ID_0-0.png
1714743909176955000_ID_0-1.png
1714692790929158100_ID_0-5.png
1714693370540007400_ID_0-4.png
1715020268084170800_ID_0-3.png
1715064846069829600_ID_0-6.png
1715052721955999700_ID_0-5.png
1714716874014531600_ID_0-5.png
1714702945955618800_ID_0-5.png
1714361583511117800_ID_0-5.png
1714371790718980000_ID_0-0.png
17150596




In [19]:
caps_filenames = get_filenames("/Users/carlalasry/Desktop/captures_tps_500/")

In [20]:
streams_500 = load_images_folder("/Users/carlalasry/Desktop/streams_tps_500/")

100%|███████████████████████████████████████| 500/500 [00:00<00:00, 4836.57it/s]

1714370472986042400_10820_1679706957.782.png
1715015393296269300_10602_1679860718.757.png
1715059333315707000_10600_1679871206.775.png
1714618258042151000_10820_1679766031.039.png
1715033944614797300_12101_1679865126.405.png
1715041164945637400_12101_1679866869.54.png
1714743273127530500_23005_1679795840.718.png
1714358918416814000_12101_1679704202.466.png
1714706058523512800_10800_1679786965.405.png
1714654075468931000_10602_1679774587.255.png
1714692790929158100_23005_1679783802.3.png
1714687392931127300_10801_1679782512.143.png
1714746089053212700_10810_1679796517.954.png
1714550967313129500_20000_1679749977.292.png
1714673696158826500_10806_1679779247.243.png
1714693370540007400_10801_1679783942.863.png
1715036731822116900_25001_1679865803.79.png
1714735887339868200_12100_1679794080.741.png
1714992880965140500_12101_1679855351.208.png
1714622527478747100_10602_1679767053.296.png
1715008798931882000_14011_1679859146.467.png
1714719557370892300_23005_1679790186.114.png
17149709913372




In [21]:
streams_filenames = get_filenames("/Users/carlalasry/Desktop/streams_tps_500/")

In [22]:
inputs_cap = process_input(caps_500, processor)
inputs_stream = process_input(streams_500, processor)

In [25]:
outputs_cap = get_vit_features(model, inputs_cap.to(device))

In [27]:
outputs_stream = get_vit_features(model, inputs_stream.to(device))

RuntimeError: MPS backend out of memory (MPS allocated: 73.79 GB, other allocations: 7.71 GB, max allowed: 81.60 GB). Tried to allocate 288.57 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
def get_all_results(caps_filenames, caps_images, outputs_caps, streams_filenames, streams_images, outputs_streams):
    results = {}
    for cap_filename, cap_img, cap in zip(caps_filenames, caps_images, outputs_caps):
        cap_sess_id = cap_filename.split("_")[0]
        results[cap_sess_id] = {}
        for stream_filename, stream_img, stream in zip(streams_filenames, streams_images, outputs_streams):
            cosine_similarity = get_cosine_similarity_for_two_images(cap.tolist(), stream.tolist())
            stream_sess_id = stream_filename.split("_")[0]
            results[cap_sess_id][stream_sess_id] = cosine_similarity
    return results

In [48]:
all_res = get_all_results(caps_filenames, caps_500, outputs_cap, streams_filenames, streams_500, outputs_stream)

In [None]:
all_res