In [1]:
import pickle
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
import random
import time

import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import clip
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation

# File Organization

In [2]:
def rename_files(source_dir):
    """
    Renames all jpg files in the source directory with their Design Labels.

    The design label is the text that starts at the second occurrence of
    "VL" in the file name and ends before the first "." that follows it.
    Files that do not contain a second "VL" are left untouched.

    A file is also left untouched (and a message is printed) when an entry
    with the computed name already exists in the directory, so an existing
    file is never silently overwritten.

    Parameters:
    source_dir: str, the path to the directory containing the jpg files.

    Returns:
    None
    """

    # Snapshot the directory listing; renames below do not affect the iteration
    for file in os.listdir(source_dir):

        # Only jpg files are renamed
        if not file.endswith('.jpg'):
            continue

        # Get the file extension
        _, ext = os.path.splitext(file)

        # Locate the first VL, then the second VL that starts the design label
        first_vl_index = file.find('VL')
        if first_vl_index == -1:
            continue
        start_index = file.find('VL', first_vl_index + 2)
        if start_index == -1:
            continue

        # The label ends at the first dot after the second VL (or at the end of the name)
        end_index = file.find('.', start_index)
        new_name = file[start_index:end_index] if end_index != -1 else file[start_index:]

        original_file_path = os.path.join(source_dir, file)
        new_file_path = os.path.join(source_dir, new_name + ext)

        # os.rename replaces an existing destination on POSIX systems;
        # refuse to clobber another design file.
        if os.path.exists(new_file_path):
            print(f"Skipping {file}: {new_name + ext} already exists.")
            continue

        # Rename the file
        os.rename(original_file_path, new_file_path)

# Helpers

In [3]:
def open_image(image_path, convert_mode):
    """
    Opens an image from the given path and converts it to the requested mode.

    The file is opened inside a context manager so the underlying file
    handle is released as soon as the pixel data has been converted; the
    returned image is the converted copy and no longer depends on the file.

    Parameters:
    image_path: str, the path to the image.
    convert_mode: str, the mode to convert the image to. Options are "RGB" and "L".

    Returns:
    image: Image, the opened image converted to convert_mode.

    Raises:
    AssertionError: if convert_mode is not "RGB" or "L".
    """

    assert convert_mode in ["RGB", "L"], "Invalid convert mode. Options are 'RGB' and 'L'."

    # Image.open is lazy and keeps the file open; convert() loads the data and
    # returns a new image, after which the source handle can be closed.
    with Image.open(image_path) as source:
        image = source.convert(convert_mode)

    return image

def display_image(image):
    """
    Shows the given image by calling its own show() method.

    Parameters:
    image: Image, the image to display.

    Returns:
    None
    """
    image.show()
    return None

def get_palette(num_cls):
    """ Builds the flat color map used to visualize a segmentation mask.

    The bits of each class index are distributed over the red, green and
    blue components (three bits at a time, starting at the most significant
    bit of each component), so class 0 maps to black.
    Args:
        num_cls: Number of classes
    Returns:
        A flat list [r0, g0, b0, r1, g1, b1, ...] of length 3 * num_cls
    """
    palette = []
    for label in range(num_cls):
        red = green = blue = 0
        shift = 7
        remaining = label
        while remaining:
            # Consume three bits of the label: one each for R, G and B
            red |= (remaining & 1) << shift
            green |= ((remaining >> 1) & 1) << shift
            blue |= ((remaining >> 2) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.extend((red, green, blue))
    return palette


def resize_image_to_tensor(image):
    """
    Turns an RGB PIL image into a batched 1x3x336x336 tensor.

    Parameters:
    - image: PIL.Image object, the image to resize.

    Returns:
    - tensor: torch.Tensor, the resized image as a 1x3x336x336 tensor.
    """
    # Resize to 336x336 first, then convert to a tensor scaled to [0, 1]
    resize_then_tensor = transforms.Compose([
        transforms.Resize((336, 336)),
        transforms.ToTensor(),
    ])

    # Prepend the batch dimension -> 1x3x336x336
    return resize_then_tensor(image).unsqueeze(0)

# Functions

In [4]:
def image_encoder(image, model, transform, save_folder=None, filename=None, device=None):
    """
    Use a CLIP model to encode an image into a feature vector.

    The image is first passed through `transform` (for CLIP this is the
    preprocessing returned by `clip.load`, whose target resolution depends on
    the chosen model), a batch dimension is added, and the batch is encoded
    with `model.encode_image` without tracking gradients. The size of the
    returned feature vector depends on the model.

    Parameters:
    - image: PIL.Image object, the image to encode.
    - model: CLIP model, the model used for encoding.
    - transform: CLIP transform, the transformation required for CLIP.
    - save_folder: str or None, currently unused. Saving of the transformed
      image is disabled; the parameter is kept (now optional) so existing
      callers keep working.
    - filename: str or None, currently unused, kept for the same reason.
    - device: str or torch.device or None, where to run the model. When None
      the notebook-level DEVICE setting is used.

    Returns:
    - image_features: torch.Tensor, the encoded image (one row per image, i.e. a batch of 1).
    """

    # Resolve the device lazily so DEVICE is only required when no device is passed
    target_device = DEVICE if device is None else device

    # Put the model in inference mode on the target device
    model = model.eval().to(target_device)

    # Preprocess the image and add the batch dimension
    batch = transform(image).unsqueeze(0).to(target_device)

    # Encode the image without building an autograd graph
    with torch.no_grad():
        image_features = model.encode_image(batch)

    return image_features

def create_reference_embeddings(source_dir, CLIP_model, CLIP_transform, convert_mode, save_folder):
    """
    Creates the image embeddings for the images in the source directory and saves together with labels.

    Every regular, non-hidden file in `source_dir` is opened, encoded with
    CLIP and appended to the result. Hidden entries (names starting with ".",
    e.g. ".DS_Store" or ".ipynb_checkpoints") and sub-directories are skipped
    because they cannot be opened as images.

    The results are pickled to '../data/design_embeddings_{convert_mode}.pkl'
    and '../data/design_labels_{convert_mode}.pkl' (paths relative to the
    current working directory). The two lists are index-aligned.

    Parameters:
    - source_dir: str, the path to the directory containing the images.
    - CLIP_model: CLIP model, the CLIP model to use for encoding.
    - CLIP_transform: CLIP transforms, the CLIP transformation to apply to the images.
    - convert_mode: str, the mode to convert the image to. Options are "RGB" and "L".
    - save_folder: str, forwarded to image_encoder (which currently does not use it).

    Returns:
    None
    """

    # Initialize the list of image features and labels
    design_features_list = []
    design_labels_list = []

    for file in os.listdir(source_dir):
        image_path = os.path.join(source_dir, file)

        # Skip hidden entries and anything that is not a regular file
        if file.startswith(".") or not os.path.isfile(image_path):
            continue
        print(f"Processing {file}...")

        # Load the image
        image = open_image(image_path, convert_mode)

        # Embed the image
        image_features = image_encoder(image, CLIP_model, CLIP_transform, save_folder, filename=file)

        # Keep features and labels in the same order
        design_features_list.append(image_features)
        design_labels_list.append(file)

    # Save the image features and labels
    with open(f'../data/design_embeddings_{convert_mode}.pkl', 'wb') as f:
        pickle.dump(design_features_list, f)
    with open(f'../data/design_labels_{convert_mode}.pkl', 'wb') as f:
        pickle.dump(design_labels_list, f)

def get_segmentation_mask(image, processor, model, labels=(4, 5, 6, 7, 8, 16, 17)):
    """
    Function to segment clothes in an image.

    Runs the segmentation model without gradient tracking, upsamples the
    logits to the image resolution, takes the per-pixel argmax and converts
    the result into a binary mask.

    Parameters:
    - image: PIL.Image object, the image to segment.
    - processor: SegformerImageProcessor object, the processor used to preprocess the image.
    - model: AutoModelForSemanticSegmentation object, the model used to segment the image.
    - labels: iterable of int, the class ids that are kept in the mask.
      The default is the set used by this notebook.
      NOTE(review): for "mattmdjaga/segformer_b2_clothes" these ids appear to be
      the clothing/accessory classes (upper-clothes, skirt, pants, dress, belt,
      bag, scarf) -- confirm against the model card.

    Returns:
    - pred_seg: torch.Tensor of shape (height, width), 255 where the predicted
      class is in `labels` and 0 everywhere else.
    """
    # Pure inference: do not build an autograd graph for the forward pass
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits.cpu()

        # PIL reports size as (width, height); interpolate expects (height, width)
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )

        # Per-pixel class id for the single image in the batch
        pred_seg = upsampled_logits.argmax(dim=1)[0]

    # Create a mask of all pixels whose class is one of the requested labels
    mask = torch.zeros_like(pred_seg, dtype=torch.bool)
    for label in labels:
        mask |= pred_seg == label

    # Binarize: everything else becomes 0, the selected labels become 255
    pred_seg[~mask] = 0
    pred_seg[mask] = 255

    return pred_seg

## Triplet Functions

In [25]:
def image_to_tensor(image):
    """
    Turns a PIL image into a tensor using torchvision's ToTensor transform.

    Parameters:
    image: PIL Image, the image to convert.

    Returns:
    tensor: Tensor, the converted tensor.
    """
    to_tensor = transforms.ToTensor()
    return to_tensor(image)

def tensor_to_image(tensor):
    """
    Turns a tensor into a PIL image using torchvision's ToPILImage transform.

    Parameters:
    tensor: Tensor, the tensor image to convert.

    Returns:
    image: PIL Image, the converted image.
    """
    to_pil = transforms.ToPILImage()
    return to_pil(tensor)


def apply_random_rotation(image_tensor, degrees=30):
    """
    Rotates the given tensor image by a randomly drawn angle.

    Parameters:
    image_tensor: Tensor, the input image tensor.
    degrees: int or tuple, range of degrees to select from.

    Returns:
    rotated_tensor: Tensor, the image tensor with random rotation applied.
    """
    # Build the transform for the requested range and apply it in one go
    rotate = transforms.RandomRotation(degrees=degrees)
    return rotate(image_tensor)


def downsample_and_upsample(image_tensor, downsample_level=5):
    """
    Downsamples an input tensor to a specified level and then upsamples it to the original size.

    Both steps use bilinear interpolation, so the result has the original
    shape but has lost fine detail. The downsampled size is clamped to at
    least 1x1 so that images smaller than `downsample_level` are still
    handled instead of failing on a zero-sized target.

    A proper range for downsample_level is 5 to 10.

    Parameters:
    image_tensor: Tensor, the input image tensor of shape (C, H, W).
    downsample_level: int, the factor by which to downsample (must be >= 1).

    Returns:
    upsampled_tensor: Tensor, the upsampled image tensor of shape (C, H, W).

    Raises:
    ValueError: if downsample_level is smaller than 1.
    """
    if downsample_level < 1:
        raise ValueError("downsample_level must be >= 1.")

    # Get the original spatial size (H, W) of the image tensor
    original_size = image_tensor.shape[-2:]

    # Calculate the downsampled size, never going below one pixel per axis
    downsampled_size = (
        max(1, original_size[0] // downsample_level),
        max(1, original_size[1] // downsample_level),
    )

    # F.interpolate expects a batch dimension, so add it once up front
    batch = image_tensor.unsqueeze(0)

    # Downsample, then upsample back to the original size
    downsampled = F.interpolate(batch, size=downsampled_size, mode='bilinear', align_corners=False)
    upsampled = F.interpolate(downsampled, size=original_size, mode='bilinear', align_corners=False)

    return upsampled.squeeze(0)


# def gaussian_blur(image_tensor, kernel_size=5, sigma=1.0):
#     """
#     Applies a Gaussian blur to a given tensor.

#     Parameters:
#     image_tensor: Tensor, the input image tensor.
#     kernel_size: int, the size of the Gaussian kernel.
#     sigma: float, the standard deviation of the Gaussian kernel.

#     Returns:
#     blurred_tensor: Tensor, the blurred image tensor.
#     """
#     # Define the Gaussian blur transform
#     gaussian_blur = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)

#     # Apply the Gaussian blur to the image tensor
#     blurred_tensor = gaussian_blur(image_tensor)

#     return blurred_tensor


def random_jpeg_compression(image_tensor, min_quality=30, max_quality=70):
    """
    Round-trips the image through an in-memory JPEG encoded at a random quality,
    which introduces compression artifacts that mimic low-quality images.

    Parameters:
    image_tensor: Tensor, the input image tensor.
    min_quality: int, the minimum JPEG quality.
    max_quality: int, the maximum JPEG quality.

    Returns:
    compressed_tensor: Tensor, the compressed image tensor.
    """
    # Tensor -> PIL so the image can be JPEG-encoded
    pil_image = transforms.ToPILImage()(image_tensor)

    # Pick the quality level for this call (both bounds inclusive)
    jpeg_quality = random.randint(min_quality, max_quality)

    # Encode to an in-memory buffer and rewind it for reading
    jpeg_bytes = io.BytesIO()
    pil_image.save(jpeg_bytes, format='JPEG', quality=jpeg_quality)
    jpeg_bytes.seek(0)

    # Decode the JPEG again and go back to a tensor
    return transforms.ToTensor()(Image.open(jpeg_bytes))


def random_mask(image_tensor, mask_size=500, area_to_mask=4000000):
    """
    Randomly masks out (zeroes) square regions of the image tensor.

    The number of masks is area_to_mask // (mask_size * mask_size), so with
    the defaults roughly 4M pixels are masked in total (masks may overlap).
    When the image is smaller than mask_size along an axis, the mask is
    clipped to the image size along that axis instead of failing.

    Proper range for mask_size is 500 to 1000.

    Parameters:
    image_tensor: Tensor, the input image tensor of shape (C, H, W).
    mask_size: int, the side length of each square mask in pixels.
    area_to_mask: int, the total area (in pixels) to mask; determines the number of masks.

    Returns:
    masked_tensor: Tensor, the masked image tensor (the input is not modified).
    """
    # Get the dimensions of the image tensor
    _, height, width = image_tensor.shape

    # Work on a copy so the caller's tensor stays untouched
    masked_tensor = image_tensor.clone()

    # Calculate the number of masks needed to cover the requested area
    num_masks = area_to_mask // (mask_size * mask_size)

    print(f"Number of masks: {num_masks}")

    # Clip the mask to the image so the random offsets below always have a valid range
    mask_height = min(mask_size, height)
    mask_width = min(mask_size, width)

    for _ in range(num_masks):
        # Randomly select the top-left corner of the mask
        top = random.randint(0, height - mask_height)
        left = random.randint(0, width - mask_width)

        # Apply the mask by setting the selected region to zero in every channel
        masked_tensor[:, top:top + mask_height, left:left + mask_width] = 0

    return masked_tensor


def random_color_jitter(image_tensor, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1):
    """
    Randomly perturbs brightness, contrast, saturation and hue of the tensor image.

    Parameters:
    image_tensor: Tensor, the input image tensor.
    brightness: float or tuple, how much to jitter brightness.
    contrast: float or tuple, how much to jitter contrast.
    saturation: float or tuple, how much to jitter saturation.
    hue: float or tuple, how much to jitter hue.

    Returns:
    jittered_tensor: Tensor, the image tensor with random color jitter applied.
    """
    # Collect the jitter strengths and hand them to torchvision's ColorJitter
    jitter_settings = dict(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
    jitter = transforms.ColorJitter(**jitter_settings)
    return jitter(image_tensor)


def add_synthetic_shadows(image_tensor, num_shadows=3, shadow_intensity=0.5, shadow_color=(0, 0, 0)):
    """
    Adds synthetic shadows to an image tensor to mimic uneven lighting conditions.

    Each shadow is a randomly placed, randomly rotated ellipse. Inside the
    ellipse every pixel is blended towards `shadow_color` by `shadow_intensity`.

    Parameters:
    - image_tensor (torch.Tensor): The input image tensor with shape (C, H, W); the
      first three channels are treated as RGB.
    - num_shadows (int): Number of shadow shapes to add.
    - shadow_intensity (float): The intensity of the shadows (0 = no shadow, 1 = fully shadow_color).
    - shadow_color (tuple): The color of the shadow in RGB, on the same value scale as
      image_tensor.

    Returns:
    - torch.Tensor: The image tensor with synthetic shadows (the input is not modified).
    """

    _, H, W = image_tensor.shape
    shadow_image = image_tensor.clone()

    # The pixel coordinate grid is the same for every shadow, so build it once
    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')

    # Semi-axis ranges: between 1/8 and 1/2 of the image size, but never 0
    # (a zero axis divides by zero) and never an empty range on tiny images.
    min_axis_x = max(1, W // 8)
    max_axis_x = max(min_axis_x + 1, W // 2)
    min_axis_y = max(1, H // 8)
    max_axis_y = max(min_axis_y + 1, H // 2)

    for _ in range(num_shadows):
        # Randomly generate an ellipse
        center_x = np.random.randint(0, W)
        center_y = np.random.randint(0, H)
        axis_x = np.random.randint(min_axis_x, max_axis_x)
        axis_y = np.random.randint(min_axis_y, max_axis_y)

        # The angle is drawn in degrees; torch.cos/sin expect radians
        angle = torch.tensor(float(np.deg2rad(np.random.uniform(0, 180))))
        cos_a = torch.cos(angle)
        sin_a = torch.sin(angle)

        # Offsets of every pixel from the ellipse center
        dx = X - center_x
        dy = Y - center_y

        # Apply the (rotated) ellipse equation; values <= 1 are inside the ellipse
        ellipse = ((dx * cos_a + dy * sin_a) ** 2) / axis_x ** 2 + \
                  ((dx * sin_a - dy * cos_a) ** 2) / axis_y ** 2
        mask = ellipse <= 1

        # Blend the masked region towards the shadow color, channel by channel
        for i in range(3):  # Assuming image is RGB
            shadow_image[i][mask] = (shadow_image[i][mask] * (1 - shadow_intensity) +
                                      shadow_color[i] * shadow_intensity)

    return shadow_image


def apply_random_shearing(image_tensor, shear):
    """
    Shears the given tensor image by a randomly drawn amount (no rotation is applied).

    Parameters:
    image_tensor: Tensor, the input image tensor.
    shear: float or tuple, range of degrees to select from for shearing.

    Returns:
    sheared_tensor: Tensor, the image tensor with random shearing applied.
    """
    # degrees=0 disables the rotation part of RandomAffine, leaving only the shear
    shear_only = transforms.RandomAffine(degrees=0, shear=shear)
    return shear_only(image_tensor)


def apply_perspective_transform(image_tensor, distortion_scale=0.5, p=1.0):
    """
    Warps the image with a random perspective transformation to simulate
    viewing it from a different angle.

    Parameters:
    image_tensor: Tensor, the input image tensor.
    distortion_scale: float, the degree of distortion (0 to 1).
    p: float, probability of applying the transformation.

    Returns:
    transformed_tensor: Tensor, the image tensor with perspective transformation applied.
    """
    # Configure torchvision's RandomPerspective and apply it directly
    warp = transforms.RandomPerspective(distortion_scale=distortion_scale, p=p)
    return warp(image_tensor)


def apply_photographic_transformations(image_tensor, gamma_range=(0.8, 1.2), exposure_range=(0.8, 1.2), lighting_direction_range=(0.8, 1.2)):
    """
    Simulates different photographic conditions by chaining three random
    adjustments: gamma, then brightness ("exposure"), then contrast (used here
    as a stand-in for lighting direction).

    Parameters:
    image_tensor: Tensor, the input image tensor.
    gamma_range: tuple, range of gamma values to select from.
    exposure_range: tuple, range of exposure values to select from.
    lighting_direction_range: tuple, range of lighting direction values to select from.

    Returns:
    transformed_tensor: Tensor, the image tensor with photographic transformations applied.
    """
    adjust = transforms.functional

    # Step 1: random gamma correction
    low, high = gamma_range
    result = adjust.adjust_gamma(image_tensor, random.uniform(low, high))

    # Step 2: random exposure change, expressed as a brightness factor
    low, high = exposure_range
    result = adjust.adjust_brightness(result, random.uniform(low, high))

    # Step 3: random "lighting direction" change, expressed as a contrast factor
    low, high = lighting_direction_range
    result = adjust.adjust_contrast(result, random.uniform(low, high))

    return result

## Test Cell

In [5]:
# # Load an image and convert to tensor
# image = Image.open('../data/designs/VL00815.jpg')
# image_tensor = transforms.ToTensor()(image)

# # Apply photographic transformations to the image tensor
# transformed_tensor = apply_photographic_transformations(image_tensor, gamma_range=(0.8, 1.2), exposure_range=(0.8, 1.2), lighting_direction_range=(0.8, 1.2))

# # Convert back to PIL image to visualize
# transformed_image = transforms.ToPILImage()(transformed_tensor)
# transformed_image.show()

# Configuration

In [5]:
# Prefer Apple-silicon MPS (the original setting) and fall back to CUDA or CPU
# so the notebook also runs on machines without MPS support.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Source directory containing the scraped folders
designs_dir = "../data/designs"

# Image mode used everywhere an image is opened ("RGB" or "L")
convert_mode = "RGB"

# Load the CLIP model together with its matching preprocessing transform
CLIP_model, CLIP_transform = clip.load("ViT-L/14@336px")

# Segmentation model initialization
seg_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
seg_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

# Create the reference embeddings (re-encodes every design on each run)
create_reference_embeddings(designs_dir, CLIP_model, CLIP_transform, convert_mode=convert_mode, save_folder="")

# Load the design database embeddings and labels.
# NOTE: pickle.load executes arbitrary code from the file; only load the files
# written by create_reference_embeddings above, never untrusted pickles.
with open(f'../data/design_embeddings_{convert_mode}.pkl', 'rb') as f:
    design_embeddings = pickle.load(f)
with open(f'../data/design_labels_{convert_mode}.pkl', 'rb') as f:
    design_labels = pickle.load(f)

print(f'Total number of embeddings: {len(design_embeddings)}')
print(f'Type of design embeddings: {type(design_embeddings)}')
print(f'Design Labels: {design_labels}')
print(f'Length of Design Labels: {len(design_labels)}')

  return func(*args, **kwargs)


Processing VL0H516.jpg...
Processing VL00562.jpg...
Processing VL03916.jpg...
Processing VL00760.jpg...
Processing VL49600.jpg...
Processing VL58650.jpg...
Processing VL08932.jpg...
Processing VL44050.jpg...
Processing VL00564.jpg...
Processing VL54350.jpg...
Processing VL08759.jpg...
Processing VL03816.jpg...
Processing VL03784.jpg...
Processing VL02918.jpg...
Processing VL03541.jpg...
Processing VL03999.jpg...
Processing VL48350.jpg...
Processing VL73650.jpg...
Processing VL8870.jpg...
Processing VL04009.jpg...
Processing VL2961R.jpg...
Processing VLH1167.jpg...
Processing VL01201.jpg...
Processing VLA0020.jpg...
Processing VLS8589.jpg...
Processing VL80021.jpg...
Processing VL2961Rotated.jpg...
Processing VL04490.jpg...
Processing VL00815.jpg...
Processing VL00633.jpg...
Processing VL65450.jpg...
Total number of embeddings: 31
Type of design embeddings: <class 'list'>
Design Labels: ['VL0H516.jpg', 'VL00562.jpg', 'VL03916.jpg', 'VL00760.jpg', 'VL49600.jpg', 'VL58650.jpg', 'VL08932.j

## Create triplets

It will take approximately 1 hour to transform 250 designs (about 12 seconds per design).

In [71]:
def create_triplets(source_dir):
    """
    Work-in-progress triplet creation: encodes a design image and looks up its
    most similar reference designs.

    Relies on the notebook-level names convert_mode, CLIP_model,
    CLIP_transform, design_embeddings and design_labels.
    The augmentation steps are currently disabled (commented out) and the loop
    stops after the first processed file.

    Parameters:
    source_dir: str, the path to the directory containing the design images.

    Returns:
    None
    """

    # List all files in the source directory
    sub_files = os.listdir(source_dir)

    for file in sub_files:
        if file == ".DS_Store" or file == ".ipynb_checkpoints":
            continue
        print(f"Processing {file}...")

        start = time.time()

        # Get the path to the image
        image_path = os.path.join(source_dir, file)

        # Load the image
        image = open_image(image_path, convert_mode=convert_mode)

        # Convert image to tensor (input for the augmentations below)
        image_tensor = image_to_tensor(image)

        # # Apply random rotation to the image tensor
        # rotated_tensor = apply_random_rotation(image_tensor, degrees=30)

        # # Apply downsample and upsample to the image tensor
        # downsampled_upsampled_tensor = downsample_and_upsample(image_tensor, downsample_level=10)

        # # Apply random JPEG compression to the image tensor
        # compressed_tensor = random_jpeg_compression(image_tensor, min_quality=30, max_quality=50)

        # # Apply random masks to the image tensor
        # masked_tensor = random_mask(image_tensor, mask_size=900)

        # # Apply color jitter to the image tensor
        # jittered_tensor = random_color_jitter(image_tensor, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)

        # # Add synthetic shadows to the image tensor
        # shadowed_tensor = add_synthetic_shadows(image_tensor)

        # # Apply perspective transformation to the image tensor
        # transformed_tensor = apply_perspective_transform(image_tensor, distortion_scale=0.5)

        # # Apply random shearing to the image tensor
        # sheared_tensor = apply_random_shearing(image_tensor, shear=20)

        # # Apply photographic transformations to the image tensor
        # photographic_tensor = apply_photographic_transformations(image_tensor, gamma_range=(0.8, 1.2), exposure_range=(0.8, 1.2), lighting_direction_range=(0.8, 1.2))

        print(f"Time taken: {time.time() - start}")

        # Embed the image. save_folder/filename are passed explicitly because
        # image_encoder declares them in its signature.
        image_features = image_encoder(image, CLIP_model, CLIP_transform, save_folder="", filename=file)

        # Do cosine similarity with the design embeddings
        similarities = [torch.nn.functional.cosine_similarity(image_features, t) for t in design_embeddings]
        similarities = torch.stack(similarities)

        # Get the index of the most k similar designs; topk raises when k
        # exceeds the number of reference embeddings, so cap it.
        k = min(10, len(design_embeddings))
        top_k_similarities = similarities.T.topk(k)

        # Get the design labels of the top k similar designs
        top_k_design_labels = [design_labels[i] for i in top_k_similarities.indices[0]]

        # Only the last five entries of the top-k list (the least similar of them) are shown
        print(f"Top K similarity values: {top_k_similarities.values[0,-5:]}")
        print(f"Top K design labels: {top_k_design_labels[-5:]}")

        # Work in progress: stop after the first design
        break


create_triplets(designs_dir)

Processing VL0H516.jpg...
Time taken: 0.3835258483886719
Top K similarity values: tensor([0.7664, 0.7593, 0.7535, 0.7511, 0.7481], device='mps:0')
Top K design labels: ['VL04490.jpg', 'VL48350.jpg', 'VL00633.jpg', 'VL01201.jpg', 'VL73650.jpg']


## Compare Similarities

In [6]:
# Evaluation: for every model photo, keep only the segmented clothing pixels,
# embed the result with CLIP and check whether the matching design appears
# among the k most similar reference designs.
source_dir = "../data/models"
sub_files = os.listdir(source_dir)

# Initialize the vars to keep track of stats
match = 0             # number of photos whose design was found in the top k
ds_strore_count = 0   # number of skipped ".DS_Store" entries (excluded from the total)
failed_files = []     # photos for which no top-k label matched
for file in sub_files:
    if file == ".DS_Store":
        ds_strore_count += 1
        continue
    # print(f"{file}")

    # Get the path to the image
    image_path = os.path.join(source_dir, file)

    # Load the image
    image = open_image(image_path, convert_mode=convert_mode)
    
    # Get cloth segmentation mask (255 = selected clothing labels, 0 = everything else)
    segmented_image = get_segmentation_mask(image, seg_processor, seg_model)

    # Convert the tensor to a uint8 numpy array
    segmented_image = segmented_image.cpu().numpy()
    segmented_image = np.array(segmented_image, dtype=np.uint8)

    # Create a 3-channel mask by repeating the mask along a new last axis
    segmented_image_3ch = np.stack([segmented_image] * 3, axis=-1)

    # Apply the mask to the input image: keep masked pixels, zero out the rest
    filtered_image_np = np.where(segmented_image_3ch == 255, np.array(image), 0)

    # Convert the filtered image back to PIL format
    filtered_image = Image.fromarray(filtered_image_np, mode='RGB')

    # # Save the filtered image
    # filtered_image.save(f"../data/filtered_images/{file[:-4]}_filtered.jpg")

    # # Display the filtered image
    # Image._show(filtered_image)

    # # Display the segmented image
    # plt.imshow(segmented_image)

    # Embed the image
    image_features = image_encoder(filtered_image, CLIP_model, transform=CLIP_transform, save_folder="", filename=file)

    # Do cosine similarity with the design embeddings
    # (assumes every embedding is a batch of one, i.e. shape (1, D) -- TODO confirm)
    similarities = [torch.nn.functional.cosine_similarity(image_features, t) for t in design_embeddings]
    similarities = torch.stack(similarities)
    # print(f"Shape of similarities: {similarities.shape}")
    
    # Get the index of the most k similar designs
    k = 5
    top_k_similarities = similarities.T.topk(k)

    # print(f"Top K similarity values: {top_k_similarities.values}")

    # Get the design labels of the top k similar designs
    top_k_design_labels = [design_labels[i] for i in top_k_similarities.indices[0]]

    # print(f"Top K similarity values: {top_k_similarities.values}")
    # print(f"Top {k} similar designs for image {file}: {top_k_design_labels}")
    # print("################################")

    # Remember the counter so a miss can be detected after the loop below
    temp_match = match
    for design_label in top_k_design_labels:
        # print(f"Design label: {design_label[:7]}")
        # print(f"File: {file[:7]}")
        # NOTE(review): only the first 6 characters are compared, so distinct
        # labels sharing that prefix (e.g. "VL00562.jpg" and "VL00564.jpg")
        # count as a match -- confirm this is intended.
        if design_label[:6] == file[:6]:
            match += 1
            print(f"MATCH: {match} in {file}   || Top K similarity values: {top_k_similarities.values}")
            # print(f"Top K similarity values: {top_k_similarities.values}")
            # print(f"Top {k} similar designs for image {file}: {top_k_design_labels}")
            break
    
    # The counter did not change: none of the top-k labels matched this file
    if temp_match == match:
        print(f"{file}")
        print(f"Top K similarity values: {top_k_similarities.values}")
        failed_files.append(file)
        # print(f"Top {k} similar designs for image {file}: {top_k_design_labels}")


# Summary: matches out of all processed files (".DS_Store" entries excluded)
print(f"Match: {match}/{len(sub_files)-ds_strore_count}")
print(f"Failed files: {failed_files}")

MATCH: 1 in VL0H516.jpg   || Top K similarity values: tensor([[0.7430, 0.7414, 0.7175, 0.7074, 0.7039]], device='mps:0')
VLXXXBeatlesCROPPED.png
Top K similarity values: tensor([[0.7188, 0.7029, 0.7028, 0.6957, 0.6882]], device='mps:0')
VL00562.jpg
Top K similarity values: tensor([[0.7314, 0.7084, 0.7044, 0.7041, 0.7018]], device='mps:0')
MATCH: 2 in VL03916.jpg   || Top K similarity values: tensor([[0.7601, 0.7437, 0.7270, 0.7234, 0.7125]], device='mps:0')
MATCH: 3 in VL00760.jpg   || Top K similarity values: tensor([[0.7787, 0.7510, 0.7396, 0.7393, 0.7377]], device='mps:0')
MATCH: 4 in VL49600.jpg   || Top K similarity values: tensor([[0.7247, 0.7187, 0.7180, 0.7148, 0.7144]], device='mps:0')
MATCH: 5 in VL58650.jpg   || Top K similarity values: tensor([[0.7906, 0.7355, 0.7337, 0.7320, 0.7284]], device='mps:0')
MATCH: 6 in VL08932.jpg   || Top K similarity values: tensor([[0.7355, 0.7245, 0.7115, 0.7007, 0.7003]], device='mps:0')
MATCH: 7 in VL44050.jpg   || Top K similarity values: 