# Load dataset from Huggingface

Optimize hf dataset for fast search and retrival.

In [68]:
if dataset_name == "cifar100":
    hf_dataset = load_dataset("clip-benchmark/wds_vtab-cifar100", split=split, streaming=False)
elif dataset_name == "caltech101":
    hf_dataset = load_dataset("clip-benchmark/wds_vtab-caltech101", split=split, streaming=False)
elif dataset_name == "food101":
    hf_dataset = load_dataset("clip-benchmark/wds_food101", split=split, streaming=False)
elif dataset_name == "cars":
    hf_dataset = load_dataset("clip-benchmark/wds_cars", split=split, streaming=False)
elif dataset_name == "country211":
    hf_dataset = load_dataset("clip-benchmark/wds_country211", split=split, streaming=False)
elif dataset_name == "sun397":
    hf_dataset = load_dataset("clip-benchmark/wds_sun397", split=split, streaming=False)
elif dataset_name == "fer2013":
    hf_dataset = load_dataset("clip-benchmark/wds_fer2013", split=split, streaming=False)
elif dataset_name == "aircraft":
    hf_dataset = load_dataset("clip-benchmark/wds_fgvc_aircraft", split=split, streaming=False)
elif dataset_name == "imagenetv2":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenetv2", split=split, streaming=False)
elif dataset_name == "imagenet-o":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenet-o", split=split, streaming=False)
elif dataset_name == "pets":
    hf_dataset = load_dataset("clip-benchmark/wds_vtab-pets", split=split, streaming=False)
elif dataset_name == "imagenet-a":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenet-a", split=split, streaming=False)
elif dataset_name == "imagenet-r":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenet-r", split=split, streaming=False)
elif dataset_name == "cub":
    hf_dataset = load_dataset("lxs784/cub-200-2011-clip-benchmark", split=split, streaming=False)

if "webp" in hf_dataset[0] and hf_dataset[0]["webp"] is not None:
    image_key = "webp"
elif hf_dataset[0]["jpg"] is not None:
    image_key = "jpg"
    
dataset_name += "-" + split
print(dataset_name)

pets-train-train


In [69]:
import os
import json
from pathlib import Path
from PIL import Image

def optimize_hf_to_lightning(hf_dataset, output_dir, image_key="webp", id_key="__key__", label_key="cls"):
    """
    Iterates over the Hugging Face dataset and saves each sample to disk in a format
    that Lightning's StreamingDataset can read. An index file (index.json) is created.
    
    Each sample is stored as:
      - An image file in JPEG format
      - A metadata entry in the index that records the file path and label
    
    Parameters:
      hf_dataset: The Hugging Face dataset (can be streaming or in-memory)
      output_dir: Directory where the optimized dataset will be stored.
      image_key: Field name in the dataset containing image data.
      id_key: Field name to use as a unique identifier.
      label_key: Field name containing label or class information.
    Returns:
      The output directory path (which contains the data and index).
    """

    os.makedirs(output_dir, exist_ok=True)
    index = {}
    # serializer = JPEGSerializer()  # Can be used to serialize images if desired.
    
    # Iterate over the dataset and write each sample.
    for sample in hf_dataset:
        uid = sample[id_key]
        # Define a file path for the image.
        image_filename = f"{uid}.jpeg"
        image_path = os.path.join(output_dir, image_filename)
        
        # Get the image. Depending on your dataset, it might already be a PIL Image.
        image = sample[image_key]
        if not isinstance(image, Image.Image):
            # If image is not a PIL image, try converting it.
            image = Image.fromarray(image)
            
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Save the image in JPEG format.
        image.save(image_path, format="JPEG")
        
        # Record metadata in the index.
        index[uid] = {
            "image_path": image_filename,  # Store relative path
            "label": sample[label_key],
        }
    
    # Write out the index file.
    index_path = os.path.join(output_dir, "index.json")
    with open(index_path, "w") as f:
        json.dump(index, f)
    
    return output_dir

In [70]:
import os
import io
import json
from PIL import Image
import imagehash
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class HFDataset(Dataset):

    def __init__(self, root_dir, index_file, lookup=None, transform=None):
        self.root_dir = root_dir
        with open(os.path.join(root_dir, index_file), "r") as f:
            self.index_data = json.load(f)
        self.lookup = lookup
        self.samples = list(self.index_data.items())
        self.uid_to_sample = dict(self.samples)
        self.transform = transform 


    def __len__(self):
        return len(self.samples)
        
    def __getitem__(self, index):
        uid, sample = self.samples[index]
        image_path = os.path.join(self.root_dir, sample["image_path"])
        pil_image = Image.open(image_path).convert("RGB")
        text = self.lookup[sample["label"]] if self.lookup else sample["label"]

        ahash = str(imagehash.average_hash(pil_image))
        phash = str(imagehash.phash(pil_image))

        return index, text, ahash, phash, uid

    def get_by_id(self, uid):
        """
        Retrieve a raw PIL image and metadata by its unique identifier.
        """
        # if uid not in self.uid_to_sample:
        #     raise KeyError(f"UID: {uid} not found in dataset.")
        sample = self.uid_to_sample[uid]
        image_path = os.path.join(self.root_dir, sample["image_path"])
        pil_image = Image.open(image_path).convert("RGB")
        text = self.lookup[sample["label"]] if self.lookup else sample["label"]
        ahash = imagehash.average_hash(pil_image)
        phash = str(imagehash.phash(pil_image))

        return pil_image, text, ahash, phash

In [71]:
optimized_dir = f"data/optimized_dataset/{dataset_name}"

if not os.path.exists(os.path.join(optimized_dir, "index.json")):
    optimize_hf_to_lightning(hf_dataset, optimized_dir, image_key=image_key)

dataset = HFDataset(
        index_file = "index.json",
        root_dir=optimized_dir,
        lookup=classes if classes else None,
        # transform = transform
        )

dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
for _, texts, ahashes, phashes, uids in dataloader:
    print(texts, ahashes, phashes, uids)
    break
sample_uid = dataset.samples[0][0]
pil_image, text, ahash, phash = dataset.get_by_id(sample_uid)
pil_image.show()

KeyboardInterrupt: 

# Load the laion400m dataset

In [13]:
"""load the laion400m dataset for image retrival"""
import os
from lightning_cloud.utils.data_connection import add_s3_connection
from lightning.data import StreamingDataset, StreamingDataLoader
from lightning.data.streaming.serializers import JPEGSerializer
import torchvision.transforms.v2 as T
from tqdm import tqdm
import imagehash
import torchvision.transforms as T
import matplotlib.pyplot as plt 
import numpy as np
import torch
import json
import concurrent
from PIL import Image
import io

# 1. Add the prepared dataset to your teamspace
add_s3_connection("laoin-400m")

# 2. Create the streaming dataset
class LAOINStreamingDataset(StreamingDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.serializer = JPEGSerializer()

    def __getitem__(self, index):
        id, image, text, _, _, _ = super().__getitem__(index)
        
        return Image.open(io.BytesIO(image)), text, str(id)

laion = LAOINStreamingDataset(input_dir="/teamspace/s3_connections/laoin-400m")

# FILTER WITH CLIP

In [None]:
"""
EXTRA CELL FOR RANDOM TEST
plot original image and overlap images
one row per plot
"""
from PIL import Image, UnidentifiedImageError
from io import BytesIO
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import textwrap
import json
import clip
import torch
import numpy as np
import glob

def resize_image(image, target_size=(256, 256)):
    return image.resize(target_size, Image.Resampling.LANCZOS)

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def show_match_results_clip(dataset, results, output_dir, final_results):
    error_find = 0
    error_filter = 0
    correct_find = 0
    correct_filter = 0
    to_inspect = {}
    for uid, match_indices in tqdm(results.items(), desc=f"verifying duplicate images in {dataset_name}"):
        original_image, original_text, _, _= dataset.get_by_id(uid)
        original_image_resized = resize_image(original_image)
        orig_input = preprocess(original_image).unsqueeze(0).to(device)
        with torch.no_grad():
            orig_features = model.encode_image(orig_input)
            orig_features /= orig_features.norm(dim=-1, keepdim=True)

        correct = 0
        for j in range (len(match_indices)):
            idx = match_indices[j]
            match_image, match_text, _ = laion[idx]
            match_input = preprocess(match_image).unsqueeze(0).to(device)
            with torch.no_grad():
                match_features = model.encode_image(match_input)
                match_features /= match_features.norm(dim=-1, keepdim=True)
            similarity = (orig_features @ match_features.T).item()
            if similarity >= sim_threshold:
                correct += 1

        if correct > 0 and uid not in final_results:
            error_find += 1
            to_inspect[uid] = match_indices
        elif correct > 0 and uid in final_results:
            correct_find += 1
        elif correct == 0 and uid in final_results:
            error_filter += 1
        else:
            correct_filter += 1
    print(f"error_find: {error_find}, \nerror_filter: {error_filter}, \ncorrect_find: {correct_find}, \ncorrect_filter: {correct_filter}")
    print("details in error_find: ", to_inspect, "\n")
input_dir = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/match_indices_4"
if not os.path.exists(input_dir):
    input_dir = f"/teamspace/studios/find-overlaps-in-laion-400m/data/intermediate/{dataset_name}/match_indices_4"
input_file = os.path.join(input_dir, "combined_results.json")
output_dir = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-clip"
os.makedirs(output_dir, exist_ok=True)
with open(f"/teamspace/studios/this_studio/data/final/{dataset_name}/final_results.json", "r") as f:
    final_results = json.load(f)

with open(input_file, "r") as f:
    results = json.load(f)
    show_match_results_clip(dataset, results, output_dir, final_results)
print("Plotted all images: ", output_dir)

correct = glob.glob(f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-clip/*.png")


print("final_results(hand picked):", len(final_results))
print("results filtered with clip:", len(correct))
print("error rate: ", round(abs(len(final_results) - len(correct))/ len(correct), 4) * 100, "%")

In [4]:
"""
EXTRA CELL FOR RANDOM TEST
plot original image and overlap images
one row per plot
"""
import os
from PIL import Image, UnidentifiedImageError
from io import BytesIO
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import textwrap
import json
import numpy as np
import glob

def inspect_clip_results(dataset_name):
    input_dir = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/match_indices_4"
    if not os.path.exists(input_dir):
        input_dir = f"/teamspace/studios/find-overlaps-in-laion-400m/data/intermediate/{dataset_name}/match_indices_4"
    input_file = os.path.join(input_dir, "combined_results.json")
    output_dir = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-clip"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"/teamspace/studios/this_studio/data/final/{dataset_name}/final_results.json", "r") as f:
        final_results = json.load(f)

    clip_results_fullpaths = glob.glob(f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-clip/*.png")
    clip_results = [os.path.basename(path).split('.')[0] for path in clip_results_fullpaths]
    # print(clip_results)

    error_find = 0
    error_filter = 0
    correct_find = 0
    correct_filter = 0
    error_find_list = []
    error_filter_list = []

    with open(input_file, "r") as f:
        results = json.load(f)
        
    for uid in results.keys():
        if uid in clip_results and uid not in final_results:
            error_find += 1
            error_find_list.append(uid)
        elif uid in clip_results and uid in final_results:
            correct_find += 1
        elif uid not in clip_results and uid in final_results:
            error_filter += 1
            error_filter_list.append(uid)
        else:
            correct_filter += 1
    print("dataset_name =",dataset_name)
    print(f"images before filtering: {len(results.keys())} \nmanual filtered results: {len(final_results.keys())} \nerror_find: {error_find} \nerror_filter: {error_filter} \ncorrect_find: {correct_find} \ncorrect_filter: {correct_filter}")
    print(f"false_positives = {error_find_list} \nfalse_negatives = {error_filter_list}\n")

dataset_names = ["cifar100-train", "cifar100-test", "caltech101-train","caltech101-test", "pets-train", "pets-test", "imagenetv2-test", "imagenet-a-test",
                    "food101-train", "food101-test", "aircraft-train", "aircraft-test", "cars-train", "cars-test", "sun397-test"]

for dataset_name in dataset_names:
    inspect_clip_results(dataset_name)

dataset_name = cifar100-train
images before filtering: 402 
manual filtered results: 445 
error_find: 0 
error_filter: 149 
correct_find: 0 
correct_filter: 253
false_positives = [] 
false_negatives = ['s0021593', 's0022113', 's0022296', 's0022578', 's0022751', 's0022774', 's0023005', 's0023826', 's0023868', 's0024160', 's0025313', 's0025717', 's0026175', 's0026226', 's0026408', 's0026544', 's0026809', 's0026856', 's0027055', 's0027271', 's0027664', 's0028238', 's0028470', 's0028537', 's0028792', 's0028811', 's0028826', 's0029131', 's0029158', 's0029510', 's0029537', 's0029562', 's0030357', 's0030950', 's0030973', 's0031182', 's0031435', 's0031625', 's0032030', 's0032440', 's0001241', 's0001244', 's0001308', 's0001659', 's0001739', 's0001893', 's0002174', 's0002504', 's0002597', 's0002747', 's0002792', 's0003175', 's0003209', 's0004417', 's0004689', 's0004763', 's0004942', 's0005122', 's0005236', 's0006484', 's0007830', 's0007886', 's0009231', 's0009563', 's0010490', 's0010810', 's0011

In [1]:
"""
DELETE
plot original image and overlap images
one row per plot
"""
import os
from PIL import Image, UnidentifiedImageError
from io import BytesIO
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import textwrap
import json
import numpy as np
import glob

def inspect_clip_results(dataset_name):
    input_dir = f"/teamspace/studios/this_studio/data/intermediate/cifar100/match_indices_4"
    if not os.path.exists(input_dir):
        input_dir = f"/teamspace/studios/find-overlaps-in-laion-400m/data/intermediate/{dataset_name}/match_indices_4"
    input_file = os.path.join(input_dir, "combined_results.json")
    output_dir = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-clip"
    os.makedirs(output_dir, exist_ok=True)
    with open(f"/teamspace/studios/this_studio/data/final/{dataset_name}/final_results.json", "r") as f:
        final_results = json.load(f)

    clip_results_fullpaths = glob.glob(f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-clip/*.png")
    clip_results = [os.path.basename(path).split('.')[0] for path in clip_results_fullpaths]
    # print(clip_results)

    error_find = 0
    error_filter = 0
    correct_find = 0
    correct_filter = 0
    error_find_list = []
    error_filter_list = []

    with open(input_file, "r") as f:
        results = json.load(f)
        
    for uid in results.keys():
        if uid in clip_results and uid not in final_results:
            error_find += 1
            error_find_list.append(uid)
        elif uid in clip_results and uid in final_results:
            correct_find += 1
        elif uid not in clip_results and uid in final_results:
            error_filter += 1
            error_filter_list.append(uid)
        else:
            correct_filter += 1
    print("dataset_name =",dataset_name)
    print(f"images before filtering: {len(results.keys())} \nmanual filtered results: {len(final_results.keys())} \nerror_find: {error_find} \nerror_filter: {error_filter} \ncorrect_find: {correct_find} \ncorrect_filter: {correct_filter}")
    print(f"false_positives = {error_find_list} \nfalse_negatives = {error_filter_list}\n")

# dataset_names = ["cifar100-train", "cifar100-test", "caltech101-train","caltech101-test", "pets-train", "pets-test", "imagenetv2-test", "imagenet-a-test",
#                     "food101-train", "food101-test", "aircraft-train", "aircraft-test", "cars-train", "cars-test", "sun397-test"]

# for dataset_name in dataset_names:
    # inspect_clip_results(dataset_name)
dataset_name = "cifar100-train"
inspect_clip_results(dataset_name)

dataset_name = cifar100-train
images before filtering: 2074 
manual filtered results: 445 
error_find: 0 
error_filter: 445 
correct_find: 0 
correct_filter: 1629
false_positives = [] 
false_negatives = ['s0034179', 's0034210', 's0034265', 's0034288', 's0034693', 's0034749', 's0034800', 's0035041', 's0035126', 's0035246', 's0035402', 's0035403', 's0035420', 's0035472', 's0035530', 's0035708', 's0035718', 's0035760', 's0035787', 's0035874', 's0035882', 's0036013', 's0036022', 's0036045', 's0036196', 's0036436', 's0036509', 's0010818', 's0010837', 's0010873', 's0010879', 's0010987', 's0011173', 's0011196', 's0011537', 's0011674', 's0011883', 's0012233', 's0012280', 's0012364', 's0012483', 's0012606', 's0012613', 's0004513', 's0004689', 's0004763', 's0004767', 's0004808', 's0004942', 's0005122', 's0005209', 's0005236', 's0005282', 's0005373', 's0005542', 's0005960', 's0006234', 's0006378', 's0006379', 's0006471', 's0006484', 's0006569', 's0006696', 's0030245', 's0030304', 's0030357', 's00

In [None]:
input_dir = 

# Collect Results

In [None]:
"""
EXTRA CELL FOR RANDOM TEST
plot original image and overlap images
one row per plot
"""
from PIL import Image, UnidentifiedImageError
from io import BytesIO
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import textwrap
import json
import clip
import torch
import numpy as np
import glob

def resize_image(image, target_size=(256, 256)):
    return image.resize(target_size, Image.Resampling.LANCZOS)
    
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def single_plot(dataset_name, uid, match_indices, output_dir, k=5):
    cols = k + 2

    fig, axes = plt.subplots(1, cols, figsize=(cols * 3, 3))
    axes[0].text(0.5, 0.5, uid, fontsize=24, ha='center', va='center')
    axes[0].axis("off")

    original_image, original_text, _, _= dataset.get_by_id(uid)
    original_image_resized = resize_image(original_image)
    axes[1].imshow(original_image_resized)
    wrapped_caption = "\n".join(textwrap.wrap(original_text, width=24))
    axes[1].set_title(wrapped_caption)
    axes[1].axis('off')
    orig_input = preprocess(original_image).unsqueeze(0).to(device)
    with torch.no_grad():
        orig_features = model.encode_image(orig_input)
        orig_features /= orig_features.norm(dim=-1, keepdim=True)

    for j in range (k):
        ax = axes[j + 2]
        if j >= len(match_indices):
            ax.imshow(np.ones((1, 1, 3)))
            ax.axis('off')
        else:
            idx = match_indices[j]
            match_image, match_text, _ = laion[idx]
            match_input = preprocess(match_image).unsqueeze(0).to(device)
            with torch.no_grad():
                match_features = model.encode_image(match_input)
                match_features /= match_features.norm(dim=-1, keepdim=True)
            similarity = (orig_features @ match_features.T).item()

            ax.imshow(match_image)
            caption = f"sim: {similarity:.2f}\n" + match_text
            wrapped_lines = textwrap.wrap(caption, width=24)
            wrapped_caption_match = "\n".join(wrapped_lines[:2])
            ax.set_title(wrapped_caption_match, fontsize=8)
            ax.axis('off')
            
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{dataset_name}-{uid}.png"))
    plt.close(fig)

In [145]:
# dataset_name = caltech101-train
# false_positives = ['s0001023', 's0001262', 's0001799', 's0002531'] 
# false_negatives = ['s0000302']

# dataset_name = "caltech101"
# split = "test"
# false_positives = ['s0000108', 's0000716', 's0002165', 's0003359', 's0004898', 's0005496', 's0005888', 's0006014'] 
# false_negatives = ['s0002108', 's0002482', 's0002578', 's0003155', 's0003214', 's0005345']

# dataset_name = "pets-train"
# false_positives = [] 
# false_negatives = ['s0000997', 's0001765', 's0002224']

# dataset_name = "pets-test"
# false_positives = [] 
# false_negatives = ['s0001332', 's0001979', 's0003416']

# dataset_name = "imagenetv2-test"
# false_positives = ['s0002250'] 
# false_negatives = ['s0000078', 's0000354', 's0002538', 's0003656', 's0003772', 's0003874', 's0004143', 's0004767', 's0004799', 's0005602', 's0005604', 's0005648', 's0006662', 's0006761', 's0007936', 's0008001', 's0008166', 's0008247', 's0008516', 's0008600', 's0008846', 's0009537']

# dataset_name = "imagenet-a-test"
# false_positives = [] 
# false_negatives = ['s0004766', 's0004984', 's0005449', 's0005684', 's0006332', 's0006778']

# dataset_name = "food101-train"
# false_positives = [] 
# false_negatives = ['s0007669', 's0055625']

# dataset_name = "food101-test"
# false_positives = [] 
# false_negatives = []

# dataset_name = "aircraft-train"
# false_positives = ['s0000666'] 
# false_negatives = ['s0000794', 's0000933', 's0001078', 's0003288']

# dataset_name = "aircraft-test"
# false_positives = ['s0001730', 's0000626', 's0000631', 's0001108', 's0001123'] 
# false_negatives = ['s0002205', 's0000725']

# dataset_name = "cars-train"
# false_positives = ['s0002945', 's0005470', 's0000936'] 
# false_negatives = ['s0005929', 's0006020', 's0006027', 's0006200', 's0006293', 's0006303', 's0006320', 's0006392', 's0006452', 's0006457', 's0006774', 's0006821', 's0006828', 's0006930', 's0007050', 's0007170', 's0002489', 's0002744', 's0002851', 's0002896', 's0002971', 's0003003', 's0003012', 's0003037', 's0003038', 's0003151', 's0003198', 's0003253', 's0003488', 's0003506', 's0003512', 's0003554', 's0003561', 's0007217', 's0007290', 's0007345', 's0007353', 's0007516', 's0007520', 's0007565', 's0007578', 's0007840', 's0007887', 's0007923', 's0007981', 's0008077', 's0008104', 's0004698', 's0004793', 's0004957', 's0004959', 's0005054', 's0005058', 's0005175', 's0005329', 's0005342', 's0005358', 's0005378', 's0005567', 's0005617', 's0005625', 's0005645', 's0005665', 's0005668', 's0005719', 's0005851', 's0005854', 's0005903', 's0003816', 's0003841', 's0004019', 's0004022', 's0004041', 's0004074', 's0004131', 's0004165', 's0004225', 's0004248', 's0004258', 's0004322', 's0004343', 's0004423', 's0004479', 's0004522', 's0004669', 's0000051', 's0000076', 's0000116', 's0000125', 's0000251', 's0000285', 's0000339', 's0000498', 's0000549', 's0000576', 's0000607', 's0000646', 's0000795', 's0000820', 's0000837', 's0000943', 's0001164', 's0001332', 's0001420', 's0001430', 's0001431', 's0001485', 's0001490', 's0001495', 's0001519', 's0001525', 's0001547', 's0001616', 's0001636', 's0001661', 's0001771', 's0001867', 's0001884', 's0002028', 's0002037', 's0002065', 's0002205', 's0002230', 's0002369']

# dataset_name = "cars-test"
# false_positives = ['s0006177', 's0007658', 's0000302', 's0001053', 's0001189'] 
# false_negatives = ['s0005918', 's0005920', 's0005973', 's0006135', 's0006193', 's0006199', 's0006211', 's0006238', 's0006260', 's0006299', 's0006302', 's0006318', 's0006445', 's0006527', 's0006554', 's0006614', 's0006653', 's0006681', 's0006743', 's0002131', 's0002148', 's0002177', 's0002196', 's0002239', 's0002295', 's0002302', 's0002322', 's0002466', 's0002516', 's0002583', 's0002693', 's0002763', 's0002776', 's0002806', 's0002829', 's0002937', 's0003000', 's0003053', 's0003122', 's0003126', 's0003262', 's0006818', 's0006886', 's0006964', 's0007087', 's0007108', 's0007121', 's0007352', 's0007475', 's0007482', 's0007514', 's0007654', 's0007823', 's0007866', 's0007867', 's0007874', 's0007877', 's0007894', 's0008031', 's0004691', 's0004863', 's0004923', 's0004929', 's0005072', 's0005304', 's0005388', 's0005448', 's0005550', 's0005706', 's0005848', 's0005849', 's0005870', 's0003388', 's0003433', 's0003493', 's0003571', 's0003711', 's0003718', 's0003761', 's0003787', 's0003899', 's0004031', 's0004070', 's0004158', 's0004184', 's0004232', 's0004281', 's0004314', 's0004456', 's0004482', 's0004493', 's0004503', 's0004530', 's0004546', 's0004577', 's0004637', 's0004648', 's0000065', 's0000151', 's0000183', 's0000287', 's0000331', 's0000346', 's0000365', 's0000447', 's0000487', 's0000508', 's0000596', 's0000659', 's0000671', 's0000765', 's0000792', 's0000795', 's0000858', 's0000907', 's0000975', 's0001015', 's0001109', 's0001176', 's0001269', 's0001465', 's0001496', 's0001548', 's0001619', 's0001622', 's0001647', 's0001697', 's0001763', 's0001774', 's0001826', 's0001859', 's0001922', 's0002069', 's0002074']

dataset_name = "sun397-test"
false_positives = [] 
false_negatives = ['s0090324', 's0090517', 's0090622', 's0090661', 's0090874', 's0091453', 's0091801', 's0003140', 's0003334', 's0003660', 's0003824', 's0004984', 's0005158', 's0005192', 's0006391', 's0019648', 's0020049', 's0021660', 's0022098', 's0043402', 's0027635', 's0028201', 's0028581', 's0044441', 's0044769', 's0044959', 's0045150', 's0045251', 's0029093', 's0029257', 's0029499', 's0031894', 's0032296', 's0100550', 's0100704', 's0100787', 's0100933', 's0101050', 's0101133', 's0101370', 's0102165', 's0102175', 's0102289', 's0102650', 's0102663', 's0102748', 's0102817', 's0103134', 's0077994', 's0078808', 's0079276', 's0079568', 's0079618', 's0081222', 's0081327', 's0081382', 's0081446', 's0013141', 's0013438', 's0013498', 's0013717', 's0014414', 's0014479', 's0014501', 's0014676', 's0015941', 's0015964', 's0097661', 's0097675', 's0097867', 's0097871', 's0098862', 's0098934', 's0099182', 's0099446', 's0100275', 's0050401', 's0050499', 's0051201', 's0051450', 's0052297', 's0000234', 's0000749', 's0001165', 's0001388', 's0001408', 's0002159', 's0002483', 's0016509', 's0016641', 's0017507', 's0017510', 's0017535', 's0017539', 's0017633', 's0017776', 's0019179', 's0019436', 's0019539', 's0019544', 's0033655', 's0034398', 's0034612', 's0034614', 's0035191', 's0036520', 's0054585', 's0054692', 's0054697', 's0054698', 's0054755', 's0054784', 's0054786', 's0054811', 's0054864', 's0055159', 's0061165', 's0061359', 's0062938', 's0063240', 's0063716', 's0103485', 's0103749', 's0104318', 's0104607', 's0104820', 's0088396', 's0089159', 's0089208', 's0089406', 's0089471', 's0089555', 's0089608', 's0089925', 's0090194', 's0085808', 's0085969', 's0086393', 's0086671', 's0086975', 's0087165', 's0091886', 's0092398', 's0092449', 's0092828', 's0093405', 's0094280', 's0094288', 's0106247', 's0107507', 's0107702', 's0108404', 's0108572', 's0108676', 's0011896', 's0012067', 's0040688', 's0094944', 's0095872', 's0096934', 's0096953', 's0096993', 's0097213', 's0006652', 's0006717', 's0006922', 's0007356', 's0008422', 's0008435', 's0059900', 's0060482', 's0060498', 's0061129', 's0071178', 's0071954', 's0071956', 's0072361', 's0073150', 's0083447', 's0083626', 's0083799', 's0084168', 's0084379', 's0084813', 's0085081', 's0085738', 's0081812', 's0081919', 's0081944', 's0082076', 's0082744', 's0009765', 's0010411', 's0074776', 's0075174', 's0075185', 's0075873', 's0076597', 's0077093', 's0077635', 's0025077', 's0025177', 's0026085', 's0026132', 's0026216', 's0026619', 's0026700', 's0026771', 's0026952', 's0026996', 's0047978', 's0048193', 's0048258', 's0048314', 's0048516', 's0048849', 's0048989', 's0049039', 's0049735', 's0036932', 's0037566', 's0037776', 's0038571', 's0038782', 's0039241', 's0064626', 's0064912', 's0064934', 's0065369', 's0065548', 's0066405', 's0066495', 's0066681', 's0066875', 's0008717', 's0056104', 's0056787', 's0057056', 's0058875', 's0059550', 's0059624', 's0022626', 's0023457', 's0024501', 's0024829', 's0046684', 's0047555', 's0047721', 's0047803', 's0047864', 's0067511', 's0067691', 's0067845', 's0069646', 's0070371']

# plot fp and fn

In [146]:
from datasets import load_dataset
from torchvision import transforms
import json
"""
Parameters:
"""
if dataset_name.split('-')[-1] == "train":
    dataset_name = dataset_name[:-6]
    split = "train"
else:
    dataset_name = dataset_name[:-5]
    split = "test"
print(dataset_name, split)

# dataset_name = "caltech101"
# split = "train"
# sim_threshold = 0.85

k = 5

classes = json.load(open(f"data/classes_{dataset_name}.json", "r"))
print(len(classes))

sun397 test
397


In [147]:
if dataset_name == "cifar100":
    hf_dataset = load_dataset("clip-benchmark/wds_vtab-cifar100", split=split, streaming=False)
elif dataset_name == "caltech101":
    hf_dataset = load_dataset("clip-benchmark/wds_vtab-caltech101", split=split, streaming=False)
elif dataset_name == "food101":
    hf_dataset = load_dataset("clip-benchmark/wds_food101", split=split, streaming=False)
elif dataset_name == "cars":
    hf_dataset = load_dataset("clip-benchmark/wds_cars", split=split, streaming=False)
elif dataset_name == "country211":
    hf_dataset = load_dataset("clip-benchmark/wds_country211", split=split, streaming=False)
elif dataset_name == "sun397":
    hf_dataset = load_dataset("clip-benchmark/wds_sun397", split=split, streaming=False)
elif dataset_name == "fer2013":
    hf_dataset = load_dataset("clip-benchmark/wds_fer2013", split=split, streaming=False)
elif dataset_name == "aircraft":
    hf_dataset = load_dataset("clip-benchmark/wds_fgvc_aircraft", split=split, streaming=False)
elif dataset_name == "imagenetv2":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenetv2", split=split, streaming=False)
elif dataset_name == "imagenet-o":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenet-o", split=split, streaming=False)
elif dataset_name == "pets":
    hf_dataset = load_dataset("clip-benchmark/wds_vtab-pets", split=split, streaming=False)
elif dataset_name == "imagenet-a":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenet-a", split=split, streaming=False)
elif dataset_name == "imagenet-r":
    hf_dataset = load_dataset("clip-benchmark/wds_imagenet-r", split=split, streaming=False)

if "webp" in hf_dataset[0] and hf_dataset[0]["webp"] is not None:
    image_key = "webp"
elif hf_dataset[0]["jpg"] is not None:
    image_key = "jpg"
    
dataset_name += "-" + split
print(dataset_name)

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/76 [00:00<?, ?it/s]

sun397-test


In [148]:
import os
import io
import json
from PIL import Image
import imagehash
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class HFDataset(Dataset):

    def __init__(self, root_dir, index_file, lookup=None, transform=None):
        self.root_dir = root_dir
        with open(os.path.join(root_dir, index_file), "r") as f:
            self.index_data = json.load(f)
        self.lookup = lookup
        self.samples = list(self.index_data.items())
        self.uid_to_sample = dict(self.samples)
        self.transform = transform 


    def __len__(self):
        return len(self.samples)
        
    def __getitem__(self, index):
        uid, sample = self.samples[index]
        image_path = os.path.join(self.root_dir, sample["image_path"])
        pil_image = Image.open(image_path).convert("RGB")
        text = self.lookup[sample["label"]] if self.lookup else sample["label"]

        ahash = str(imagehash.average_hash(pil_image))
        phash = str(imagehash.phash(pil_image))

        return index, text, ahash, phash, uid

    def get_by_id(self, uid):
        """
        Retrieve a raw PIL image and metadata by its unique identifier.
        """
        # if uid not in self.uid_to_sample:
        #     raise KeyError(f"UID: {uid} not found in dataset.")
        sample = self.uid_to_sample[uid]
        image_path = os.path.join(self.root_dir, sample["image_path"])
        pil_image = Image.open(image_path).convert("RGB")
        text = self.lookup[sample["label"]] if self.lookup else sample["label"]
        ahash = imagehash.average_hash(pil_image)
        phash = str(imagehash.phash(pil_image))

        return pil_image, text, ahash, phash

In [149]:
optimized_dir = f"data/optimized_dataset/{dataset_name}"

if not os.path.exists(os.path.join(optimized_dir, "index.json")):
    optimize_hf_to_lightning(hf_dataset, optimized_dir, image_key=image_key)

dataset = HFDataset(
        index_file = "index.json",
        root_dir=optimized_dir,
        lookup=classes if classes else None,
        # transform = transform
        )

dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
# for _, texts, ahashes, phashes, uids in dataloader:
#     print(texts, ahashes, phashes, uids)
#     break
# sample_uid = dataset.samples[0][0]
# pil_image, text, ahash, phash = dataset.get_by_id(sample_uid)
# pil_image.show()

In [150]:
# move images
import os
import shutil
from tqdm import tqdm

fp_path = "clip_mistakes/false_positives"
fn_path = "clip_mistakes/false_negatives"
os.makedirs(fp_path, exist_ok=True)
os.makedirs(fn_path, exist_ok=True)

source_dir_fp = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-clip"
source_dir_fn = f"data/final/{dataset_name}/final_results.json"

for uid in false_positives:
    image_path = os.path.join(source_dir_fp, f"{uid}.png")
    new_path = os.path.join(fp_path, f"{dataset_name}-{uid}.png")
    shutil.copy(image_path, new_path)
print("done moving false positives(error find)")
source_json = json.load(open(source_dir_fn, "r"))
for uid in tqdm(false_negatives, total=len(false_negatives)):
    match_indices = source_json[uid]
    single_plot(dataset_name, uid, match_indices, fn_path, k=5)
print("done moving false negatives(error filter)")

done moving false positives(error find)


  0%|          | 0/255 [00:00<?, ?it/s]

100%|██████████| 255/255 [15:38<00:00,  3.68s/it]

done moving false negatives(error filter)





In [151]:
# """
# EXTRA CELL FOR RANDOM TEST
# plot original image and overlap images
# one row per plot
# """
# from PIL import Image, UnidentifiedImageError
# from io import BytesIO
# import matplotlib.pyplot as plt
# from tqdm import tqdm
# import pandas as pd
# import textwrap
# import json
# import glob

# def resize_image(image, target_size=(256, 256)):
#     return image.resize(target_size, Image.Resampling.LANCZOS)

# def show_match_results_single_wahash(dataset, results, output_dir, k=5):

#     cols = k + 2
#     for uid, match_indices in tqdm(results.items(), desc=f"plotting duplicate images for {dataset_name}"):
#         fig, axes = plt.subplots(1, cols, figsize=(cols * 3, 3))
#         axes[0].text(0.5, 0.5, uid, fontsize=24, ha='center', va='center')
#         axes[0].axis("off")

#         original_image, original_text, ahash, phash= dataset.get_by_id(uid)
#         original_image_resized = resize_image(original_image)
#         axes[1].imshow(original_image_resized)
#         wrapped_caption = "\n".join(textwrap.wrap(original_text, width=24))
#         axes[1].set_title(wrapped_caption)
#         axes[1].axis('off')

#         correct = 0
#         for j in range (k):
#             ax = axes[j + 2]
#             if j >= len(match_indices):
#                 ax.imshow(np.ones((1, 1, 3)))
#             else:
#                 idx = match_indices[j]
#                 match_image, match_text, _ = laion[idx]
#                 laion_phash = imagehash.phash(match_image)
#                 p_dist = abs(imagehash.hex_to_hash(phash) - laion_phash)
#                 laion_ahash = imagehash.average_hash(match_image)
#                 a_dist = abs(ahash - laion_ahash)
#                 if a_dist <= 4:
#                     ax.imshow(match_image)
#                     warapped_lines = "a_dist: " + str(a_dist) + ", p_dist: " + str(a_dist)
#                     wrapped_caption_match = "\n".join(wrapped_lines[:2])
#                     ax.set_title(wrapped_caption_match, fontsize=8)
#                     correct += 1
#                 else:
#                     ax.imshow(np.ones((1, 1, 3)))
#             ax.axis('off')
#         if correct > 0:
#             plt.tight_layout()
#             plt.savefig(os.path.join(output_dir, f"{uid}.png"))
#         plt.close(fig)

# # find-overlaps-in-laion-400m
# input_dir = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/match_indices_{threshold}"
# input_file = os.path.join(input_dir, "combined_results.json")
# output_dir = f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-ahash"
# os.makedirs(output_dir, exist_ok=True)

# with open(input_file, "r") as f:
#     results = json.load(f)
#     show_match_results_single_wahash(dataset, results, output_dir, k)
# print("Plotted all images: ", output_dir)

# correct = glob.glob(f"/teamspace/studios/this_studio/data/intermediate/{dataset_name}/plots-ahash/*.png")

# with open(f"/teamspace/studios/this_studio/data/final/{dataset_name}/final_results.json", "r") as f:
#     final_results = json.load(f)

# print("final_results(hand picked):", len(final_results))
# print("results filtered with average hash:", len(correct))
# print("error rate: ", round(abs(len(final_results) - len(correct))/ len(correct), 4) * 100, "%")