In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import faiss

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
resnet = models.resnet50(pretrained=True)
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
resnet.to(device)
resnet.eval()



Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [5]:
dir = "./dataset/DeepFashion2/"
csv_paths = {
    "train": os.path.join(dir, "img_info_dataframes/train.csv"),
    "validation": os.path.join(dir, "img_info_dataframes/validation.csv"),
    "test": os.path.join(dir, "img_info_dataframes/test.csv")
}
image_dirs = {
    "train": os.path.join(dir, "deepfashion2_original_images/train/image"),
    "validation": os.path.join(dir, "deepfashion2_original_images/validation/image"),
    "test": os.path.join(dir, "deepfashion2_original_images/test/test/image")
}
embeddings_dir = "./embeddings/"

In [6]:
embeddings_dir = "./embeddings/"

def extract_features(split):
    print(f"\nüîπ Extracting features for {split} set...")

    df = pd.read_csv(csv_paths[split])
    image_list = df['path'].tolist()

    embeddings = []
    image_names = []
    missing_count = 0

    for img_name in tqdm(image_list, desc=f"Processing {split} images"):
        img_filename = os.path.basename(img_name)
        img_path = os.path.join(image_dirs[split], img_filename)

        if not os.path.exists(img_path):
            print(f"‚ùå Image not found: {img_path}")
            missing_count += 1
            continue

        img = Image.open(img_path).convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            features = resnet(img_tensor).squeeze().cpu().numpy()

        embeddings.append(features)
        image_names.append(img_filename)

    embeddings = np.array(embeddings)

    np.save(os.path.join(embeddings_dir, f"deepfashion2_{split}_embeddings.npy"), embeddings)
    np.save(os.path.join(embeddings_dir, f"deepfashion2_{split}_image_names.npy"), np.array(image_names))

    print(f"‚úÖ Feature extraction for {split} complete. Saved embeddings in `{embeddings_dir}`.")
    print(f"‚ùó Missing images: {missing_count}/{len(image_list)}")

In [7]:
for split in ["train", "validation", "test"]:
    extract_features(split)


üîπ Extracting features for train set...


Processing train images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 312186/312186 [1:13:52<00:00, 70.43it/s]


‚úÖ Feature extraction for train complete. Saved embeddings in `./embeddings/`.
‚ùó Missing images: 0/312186

üîπ Extracting features for validation set...


Processing validation images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 52490/52490 [12:43<00:00, 68.77it/s]


‚úÖ Feature extraction for validation complete. Saved embeddings in `./embeddings/`.
‚ùó Missing images: 0/52490

üîπ Extracting features for test set...


Processing test images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 62629/62629 [16:46<00:00, 62.24it/s]


‚úÖ Feature extraction for test complete. Saved embeddings in `./embeddings/`.
‚ùó Missing images: 0/62629


In [8]:
for split, path in csv_paths.items():
    df = pd.read_csv(path)
    print(f"Columns in {split}.csv: {df.columns.tolist()}")

Columns in train.csv: ['path', 'segmentation', 'landmarks', 'b_box', 'category_id', 'category_name', 'scale', 'viewpoint', 'occlusion', 'zoom_in', 'img_height', 'img_width']
Columns in validation.csv: ['path', 'segmentation', 'landmarks', 'b_box', 'category_id', 'category_name', 'scale', 'viewpoint', 'occlusion', 'zoom_in', 'img_height', 'img_width']
Columns in test.csv: ['path', 'img_height', 'img_width']
