In [14]:
import pickle
import os
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
# import matplotlib.pyplot as plt
# import io
# import random
# import time

import torch
# import torchvision.transforms as transforms
import torch.nn as nn

import clip
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation

# File Organization

In [15]:
def rename_files(source_dir):
    """
    Renames all jpg files in the source directory with their Design Labels.

    Parameters:
    source_dir: str, the path to the directory containing the jpg files.

    Returns:
    None
    """
    
    # List all files in the source directory
    files = os.listdir(source_dir)
    
    for file in files:

        # Check if the file is a jpg
        if file.endswith('.jpg') or file.endswith('.jp2'):

            # Get the file extension
            _, ext = os.path.splitext(file)

            # Skip the first VL in the file name
            first_vl_index = file.find('VL')

            # Find the next VL in the file name
            if first_vl_index != -1:
                start_index = file.find('VL', first_vl_index + 2)
                if start_index != -1:
                    end_index = file.find('.', start_index)
                    new_name = file[start_index:end_index] if end_index != -1 else file[start_index:]

                    # Rename the file
                    original_file_path = os.path.join(source_dir, file)
                    new_file_path = os.path.join(source_dir, new_name + ext)
                    os.rename(original_file_path, new_file_path)


def process_filenames(folder_path):
    # Iterate through each file in the folder
    for filename in os.listdir(folder_path):
        # Split the filename into parts based on the underscore
        parts = filename.split('_')
        
        # Ensure the filename has the correct number of parts
        if len(parts) >= 4:
            # Extract the part between the second and third underscore
            target_part = parts[2]
            
            # Replace the dot with an underscore in the extracted part
            new_part = target_part.replace('.', '_')
            
            # Get the file extension
            file_extension = os.path.splitext(filename)[1]
            
            # Construct the new filename
            new_filename = f"{new_part}{file_extension}"
            
            # Print or rename the file as needed
            print(f"Original filename: {filename}")
            print(f"Processed filename: {new_filename}\n")
            
            # To actually rename the file, uncomment the next line
            os.rename(os.path.join(folder_path, filename), os.path.join(folder_path, new_filename))

# # Example usage
# folder_path = '/Users/ilerisoy/Downloads/Classics/models'
# process_filenames(folder_path)


# Helpers

In [16]:
def open_image(image_path, convert_mode="RGB"):
    """
    Opens an image from the given path.

    Parameters:
    image_path: str, the path to the image.
    convert_mode: str, the mode to convert the image to. Options are "RGB" and "L".

    Returns:
    image: Image, the opened image.
    """

    assert convert_mode in ["RGB", "L"], "Invalid convert mode. Options are 'RGB' and 'L'."
    
    # Open the image
    image = Image.open(image_path)

    # Convert the image to specified mode
    image = image.convert(convert_mode)

    return image

def display_image(image):
    """
    Displays the image.

    Parameters:
    image: Image, the image to display.

    Returns:
    None
    """

    image.show()

def get_palette(num_cls):
    """ Returns the color map for visualizing the segmentation mask.
    Args:
        num_cls: Number of classes
    Returns:
        The color map
    """
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3
    return palette


def resize_images_in_folder(source_folder, destination_folder, size=(336, 336)):
    """
    Iterates through a given folder, opens each image file, resizes it to the specified size,
    and saves it to the specified destination folder.

    Parameters:
    - source_folder: str, the folder containing the images to resize.
    - destination_folder: str, the folder to save the resized images.
    - size: tuple, the target size for resizing (default is (336, 336)).

    Returns:
    - None
    """
    # Ensure the destination folder exists
    if not os.path.exists(destination_folder):
        os.makedirs(destination_folder)

    # Iterate through each file in the source folder
    for filename in os.listdir(source_folder):
        file_path = os.path.join(source_folder, filename)
        
        # Check if the file is an image
        try:
            with Image.open(file_path) as img:
                # Resize the image
                img_resized = img.resize(size, Image.ANTIALIAS)
                
                # Save the resized image to the destination folder
                save_path = os.path.join(destination_folder, filename)
                img_resized.save(save_path)
                print(f"Resized and saved {filename} to {save_path}")
        except IOError:
            print(f"Skipping non-image file: {filename}")


def convert_jp2_to_jpg(folder_path):
    for filename in os.listdir(folder_path):
        if filename.endswith('.jp2'):
            # Open the jp2 file
            jp2_path = os.path.join(folder_path, filename)
            img = Image.open(jp2_path)
            
            # Convert the filename to .jpg
            new_filename = os.path.splitext(filename)[0] + '.jpg'
            jpg_path = os.path.join(folder_path, new_filename)
            
            # Save the image as a jpg
            img.convert('RGB').save(jpg_path, 'JPEG')
            
            # Optionally, delete the original .jp2 file
            # os.remove(jp2_path)
            
            print(f"Converted {filename} to {new_filename}")

# Functions

In [17]:
def image_encoder(image, model, transform):
    """
    Use CLIP model to encode the image and save the tranformed version.
    First resizes the image to 224x224, then normalizes it, and finally encodes it.
    Then, encodes the image into a 512-dimensional feature vector.
    
    
    Parameters:
    - image: PIL.Image object, the image to encode.
    - model: CLIP model, the model used for encoding.
    - transform: CLIP transform, the transformation to required for CLIP.
    - save_folder: str, the folder to save the transformed image.
    - filename: str, the name of the file to save the transformed image as.

    Returns:
    - image_features: torch.Tensor, the encoded image.
    """

    # Load the CLIP model
    model = model.eval().to(DEVICE)

    # Preprocess the image
    image = transform(image).unsqueeze(0).to(DEVICE)

    # Encode the image
    with torch.no_grad():
        image_features = model.encode_image(image)

    return image_features

def create_reference_embeddings(source_dir, CLIP_model, CLIP_transform, convert_mode, dataset_name):
    """
    Creates the image embeddings for the images in the source directory and saves together with labels.
    
    Parameters:
    - source_dir: str, the path to the directory containing the images.
    - CLIP_model: CLIP model, the CLIP model to use for encoding.
    - CLIP_transform: CLIP transforms, the CLIP transformation to apply to the images.
    - convert_mode: str, the mode to convert the image to. Options are "RGB" and "L".
    - dataset_name: str, the name of the dataset.
    
    Returns:
    None
    """

    # Get the list of files in the source directory
    sub_files = os.listdir(source_dir)

    # Initialize the list of image features and labels
    design_features_list = []
    design_labels_list = []

    for file in sub_files:
        if file == ".DS_Store":
            continue
        print(f"Processing {file}...")

        # Get the path to the folder containing the images
        image_path = os.path.join(source_dir, file)

        # Load the images from the folder
        image = open_image(image_path, convert_mode)

        # Embed the image
        image_features = image_encoder(image, CLIP_model, CLIP_transform)

        # Append the image features and labels to the lists
        design_features_list.append(image_features)
        design_labels_list.append(file)

    # Save the image features and labels
    with open(f"embeddings_{dataset_name}.pkl", 'wb') as f:
        pickle.dump(design_features_list, f)
    with open(f"labels_{dataset_name}.pkl", 'wb') as f:
        pickle.dump(design_labels_list, f)

def get_segmentation_mask(image, processor, model):
    """
    Function to segment clothes in an image.

    Parameters:
    - image: PIL.Image object, the image to segment.
    - processor: SegformerImageProcessor object, the processor used to preprocess the image.
    - model: AutoModelForSemanticSegmentation object, the model used to segment the image.

    Returns:
    - pred_seg: torch.Tensor, the segmented image.
    """
    inputs = processor(images=image, return_tensors="pt")

    outputs = model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]

    # Create a mask for the labels 4, 5, 6, and 7
    mask = (pred_seg == 4) | (pred_seg == 5) | (pred_seg == 6) | (pred_seg == 7) | (pred_seg == 8) | (pred_seg == 16) | (pred_seg == 17)

    # Set all other labels to 0
    pred_seg[~mask] = 0

    # Set the labels 4, 5, 6, and 7 to 255
    pred_seg[mask] = 255

    return pred_seg
    
    # plt.imshow(pred_seg)

# Configuration

In [18]:
DEVICE = "mps"

# Source directory containing the scraped folders
designs_dir = "/Users/ilerisoy/Downloads/Classics/high_accuracy_designs"

# Name of the embeddings file
dataset_name = "high_accuracy_designs"

# Color mode for the images
convert_mode = "L"

# Load the CLIP model
CLIP_model, CLIP_transform = clip.load("ViT-L/14@336px")

# Segmentation model initialization
seg_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
seg_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

# Create the reference embeddings
# create_reference_embeddings(designs_dir, CLIP_model, CLIP_transform, convert_mode, dataset_name)

# Load the design database embeddings and labels
with open(f"embeddings_{dataset_name}_RGB.pkl", 'rb') as f:
    design_embeddings = pickle.load(f)
with open(f"labels_{dataset_name}_RGB.pkl", 'rb') as f:
    design_labels = pickle.load(f)

print(f'Total number of embeddings: {len(design_embeddings)}')
print(f'Type of design embeddings: {type(design_embeddings)}')
print(f'Design Labels: {design_labels}')
print(f'Length of Design Labels: {len(design_labels)}')

  return func(*args, **kwargs)


Total number of embeddings: 170
Type of design embeddings: <class 'list'>
Design Labels: ['VL48350_031.jp2', 'VL00014_308.jp2', 'VL04339_090.jp2', 'VL76950_202.jp2', 'VLH1498_020.jp2', 'VL03639_212.jp2', 'VL03636_136.jp2', 'VL65450_048.jp2', 'VL04918_075.jp2', 'VL0H596_111.jp2', 'VL03327_129.jp2', 'VL01003_124.jp2', 'VL00535_145.jp2', 'VL0H600_050.jp2', 'VL00017_342.jp2', 'VL00052_153.jp2', 'VL00864_001.jp2', 'VL01003_087.jp2', 'VL03143_077.jp2', 'VL03499_055.jp2', 'VL00017_346.jp2', 'VL00924_247.jp2', 'VL00511_279.jp2', 'VLA0455_030.jp2', 'VL00921_260.jp2', 'VL03988_252.jp2', 'VL03854_067.jp2', 'VL08863_083.jp2', 'VL00633_292.jp2', 'VL04528_083.jp2', 'VL0H418_012.jp2', 'VL01451_068.jp2', 'VL01681_071.jp2', 'VLH1455_048.jp2', 'VL00948_011.jp2', 'VL02511_224.jp2', 'VL49150_055.jp2', 'VL05124_148.jp2', 'VL45750_175.jp2', 'VL02210_034.jp2', 'VL03541_225.jp2', 'VL01178_253.jp2', 'VL01178_247.jp2', 'VL08682_009.jp2', 'VL0H907_145.jp2', 'VL04490_055.jp2', 'VL0H628_139.jp2', 'VL45750_216.jp2'

## Compare Similarities

In [19]:
# source_dir = "../data/models"
source_dir = "/Users/ilerisoy/Downloads/Classics/high_accuracy_models"
sub_files = os.listdir(source_dir)

# Initialize the vars to keep track of stats
match = 0
ds_strore_count = 0
failed_files = []
for file in sub_files:
    if file == ".DS_Store":
        ds_strore_count += 1
        continue
    # print(f"{file}")

    # Get the path to the folder containing the images
    image_path = os.path.join(source_dir, file)

    # Load the images from the folder
    image = open_image(image_path)
    
    # Get cloth segmentation mask
    segmented_image = get_segmentation_mask(image, seg_processor, seg_model)

    # Convert the tensor to a numpy array
    segmented_image = segmented_image.cpu().numpy()
    segmented_image = np.array(segmented_image, dtype=np.uint8)

    # Create a 3-channel mask
    segmented_image_3ch = np.stack([segmented_image] * 3, axis=-1)

    # Apply the mask to the input image
    filtered_image_np = np.where(segmented_image_3ch == 255, np.array(image), 0)

    # Convert the filtered image back to PIL format
    filtered_image = Image.fromarray(filtered_image_np, mode='RGB')

    filtered_image = filtered_image.convert("RGB")

    # Check if the save folder exists
    if not os.path.exists("filtered_images"):
        os.makedirs("filtered_images")
    
    # Save the filtered image
    filtered_image.save(f"filtered_images/{file[:-4]}_filtered.jpg")

    # # Display the filtered image
    # Image._show(filtered_image)

    # # Display the segmented image
    # plt.imshow(segmented_image)

    # Embed the image
    image_features = image_encoder(filtered_image, CLIP_model, transform=CLIP_transform)

    # Do cosine similarity with the design embeddings
    similarities = [torch.nn.functional.cosine_similarity(image_features, t) for t in design_embeddings]
    similarities = torch.stack(similarities)
    
    # Get the index of the most k similar designs
    k = 170
    top_k_similarities = similarities.T.topk(k)

    # Get the design labels of the top k similar designs
    top_k_design_labels = [design_labels[i] for i in top_k_similarities.indices[0]]

    temp_match = match
    for design_label in top_k_design_labels:
        # print(f"Design label: {design_label[:7]}")
        # print(f"File: {file[:7]}")
        if design_label[:6] == file[:6]:
            match += 1
            print(f"MATCH: {match} in {file}\n Top K similarity values: {top_k_similarities.values} \n Design Labels: {top_k_design_labels} \n")
            # print(f"Top K similarity values: {top_k_similarities.values}")
            # print(f"Top {k} similar designs for image {file}: {top_k_design_labels}")
            break
    
    if temp_match == match:
        print(f"######## in {file}\n Top K similarity values: {top_k_similarities.values} \n Design Labels: {top_k_design_labels} \n")
        failed_files.append(file)
        # print(f"Top {k} similar designs for image {file}: {top_k_design_labels}")

    # # Sort the similarities
    # sorted_similarities = torch.sort(similarities, descending=True)

    # # Sort the design labels based on the similarities
    # sorted_design_labels = np.array(design_labels)
    # sorted_design_labels = sorted_design_labels[sorted_similarities.indices.cpu().numpy()]

    # with open(f"output.txt", "w") as f:
    #     for score, label in zip(sorted_similarities, design_labels):
    #         f.write(f"Design Label: {label}\nSimilarity Score: {score}\n")


print(f"Match: {match}/{len(sub_files)-ds_strore_count}")
print(f"Failed files: {failed_files}")

MATCH: 1 in VL00014_308.jp2
 Top K similarity values: tensor([[0.7771, 0.7635, 0.7620, 0.7594, 0.7542, 0.7502, 0.7469, 0.7454, 0.7444,
         0.7417, 0.7410, 0.7397, 0.7393, 0.7376, 0.7357, 0.7338, 0.7317, 0.7302,
         0.7292, 0.7270, 0.7258, 0.7256, 0.7252, 0.7248, 0.7247, 0.7244, 0.7241,
         0.7231, 0.7224, 0.7219, 0.7204, 0.7203, 0.7182, 0.7154, 0.7134, 0.7133,
         0.7121, 0.7116, 0.7114, 0.7095, 0.7095, 0.7093, 0.7093, 0.7084, 0.7082,
         0.7081, 0.7079, 0.7059, 0.7053, 0.7038, 0.7033, 0.7030, 0.7023, 0.7001,
         0.6999, 0.6996, 0.6996, 0.6991, 0.6981, 0.6980, 0.6979, 0.6978, 0.6977,
         0.6973, 0.6967, 0.6966, 0.6960, 0.6958, 0.6958, 0.6955, 0.6945, 0.6930,
         0.6929, 0.6921, 0.6919, 0.6917, 0.6915, 0.6913, 0.6900, 0.6880, 0.6876,
         0.6865, 0.6863, 0.6861, 0.6853, 0.6844, 0.6817, 0.6815, 0.6814, 0.6813,
         0.6804, 0.6803, 0.6790, 0.6783, 0.6781, 0.6780, 0.6772, 0.6770, 0.6768,
         0.6765, 0.6763, 0.6754, 0.6752, 0.6749, 0.6749

KeyboardInterrupt: 