In [3]:
import torch
from torchvision import transforms
from PIL import Image
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SSCD_MODEL_PATH = "/home/ubuntu/Developer/dataunlearning/checkpoints/classifiers/sscd_disc_mixup.torchscript.pt"
model = torch.jit.load(SSCD_MODEL_PATH).to('cuda')

In [3]:
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
)
sscd_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

In [1]:
IMAGE_DIR = '/home/ubuntu/Developer/dataunlearning/data/examples/manual/stallone/images'
BATCH_SIZE = 16

In [4]:
# Load and compute embeddings for all images in batches
embeddings = []
image_files = sorted([f for f in os.listdir(IMAGE_DIR) if f.endswith(".png")])
print(image_files)

with torch.no_grad():
    for i in range(0, len(image_files), BATCH_SIZE):
        batch_files = image_files[i:i+BATCH_SIZE]
        batch_images = []
        
        for image_file in batch_files:
            image_path = os.path.join(IMAGE_DIR, image_file)
            img = Image.open(image_path).convert('RGB')
            batch_images.append(sscd_transform(img))
        
        batch = torch.stack(batch_images).to('cuda')
        batch_embeddings = model(batch)
        embeddings.append(batch_embeddings)

embeddings = torch.cat(embeddings, dim=0)

['image_000.png', 'image_001.png', 'image_002.png', 'image_003.png', 'image_004.png', 'image_005.png', 'image_006.png', 'image_007.png', 'image_008.png', 'image_009.png', 'image_010.png', 'image_011.png', 'image_012.png', 'image_013.png', 'image_014.png', 'image_015.png', 'image_016.png', 'image_017.png', 'image_018.png', 'image_019.png', 'image_020.png', 'image_021.png', 'image_022.png', 'image_023.png', 'image_024.png', 'image_025.png', 'image_026.png', 'image_027.png', 'image_028.png', 'image_029.png', 'image_030.png', 'image_031.png', 'image_032.png', 'image_033.png', 'image_034.png', 'image_035.png', 'image_036.png', 'image_037.png', 'image_038.png', 'image_039.png', 'image_040.png', 'image_041.png', 'image_042.png', 'image_043.png', 'image_044.png', 'image_045.png', 'image_046.png', 'image_047.png', 'image_048.png', 'image_049.png', 'image_050.png', 'image_051.png', 'image_052.png', 'image_053.png', 'image_054.png', 'image_055.png', 'image_056.png', 'image_057.png', 'image_058.pn

NameError: name 'sscd_transform' is not defined

In [11]:
# Compute similarities for a given index
index = 6  # Change this to the desired index
query_embedding = embeddings[index]

similarities = torch.matmul(query_embedding, embeddings.T).squeeze()

In [13]:
print(similarities)

tensor([ 0.0686,  0.1157,  0.4641,  0.4284,  0.1390,  0.3418,  1.0000,  0.1676,
         0.1097,  0.4163,  0.0440,  0.2360,  0.5935,  0.0961,  0.4086,  0.0315,
         0.6302,  0.1257,  0.0184,  0.5222,  0.2936,  0.1716,  0.6055,  0.2554,
         0.2349,  0.2602,  0.4827,  0.3334,  0.0658,  0.6232, -0.0026,  0.1347,
         0.4415,  0.0995,  0.6586,  0.0440,  0.5676,  0.4527,  0.4695,  0.1406,
         0.1409,  0.1063,  0.1117,  0.1961,  0.6566,  0.4685,  0.4433,  0.2017,
         0.0912,  0.2524,  0.3964,  0.3502,  0.2420,  0.1105,  0.0517,  0.2889,
         0.2310,  0.4873,  0.0423,  0.4430,  0.4850,  0.2552,  0.2722,  0.4578,
         0.0464,  0.2186,  0.5384,  0.5274,  0.6258,  0.5908,  0.0258,  0.0760,
         0.3305,  0.1807,  0.5153,  0.4224,  0.5664,  0.5899,  0.0291,  0.5845,
         0.2894,  0.0987,  0.0557,  0.2318,  0.5865,  0.1614,  0.0010,  0.6636,
         0.4240,  0.5979,  0.7028,  0.0158,  0.2090,  0.1725,  0.6484,  0.1859,
         0.4413,  0.0011,  0.5301,  0.19

In [14]:
# Find images with similarity above 0.7
threshold = 0.4
similar_indices = torch.where(similarities > threshold)[0].tolist()

print(f"Images with similarity above {threshold} to index {index}:")
for i in similar_indices:
    print(f"- {image_files[i]}")

Images with similarity above 0.4 to index 6:
- image_002.png
- image_003.png
- image_006.png
- image_009.png
- image_012.png
- image_014.png
- image_016.png
- image_019.png
- image_022.png
- image_026.png
- image_029.png
- image_032.png
- image_034.png
- image_036.png
- image_037.png
- image_038.png
- image_044.png
- image_045.png
- image_046.png
- image_057.png
- image_059.png
- image_060.png
- image_063.png
- image_066.png
- image_067.png
- image_068.png
- image_069.png
- image_074.png
- image_075.png
- image_076.png
- image_077.png
- image_079.png
- image_084.png
- image_087.png
- image_088.png
- image_089.png
- image_090.png
- image_094.png
- image_096.png
- image_098.png
