# Generate CLIP Embeddings for Polyvore Dataset

This notebook replicates the embedding generation process from the *OutfitTransformer* repository ([owj0421/outfit-transformer](https://github.com/owj0421/outfit-transformer)), specifically the script `src/run/1_generate_clip_embeddings.py`. It generates CLIP embeddings for the Polyvore dataset using the `patrickjohncyh/fashion-clip` model, saving them in the same format (`polyvore_{rank}.pkl`) and location (`{polyvore_dir}/precomputed_clip_embeddings`). The goal is to produce embeddings identical to the official implementation, which we can verify by comparing with the official outputs.

## Setup
- **Dataset**: Polyvore (~251,008 items, `item_metadata.json`, `images/`).
- **Model**: `patrickjohncyh/fashion-clip` (CLIP ViT-B/32).
- **Output**: Embeddings saved as `{polyvore_dir}/precomputed_clip_embeddings/polyvore_{rank}.pkl`.
- **Environment**: Python 3.8+, PyTorch 1.9.0, transformers, numpy, etc.

## Prerequisites
Run the following to install dependencies:
```bash
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install transformers==4.9.2 numpy pickle5 tqdm wandb pillow
pip install git+https://github.com/patrickjohncyh/fashion-clip.git
```

In [1]:
# Import libraries
import json
import logging
import os
import pathlib
import pickle

import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from PIL import Image

from transformers import CLIPProcessor, CLIPModel

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Define paths
SRC_DIR = pathlib.Path(os.getcwd()).absolute()
LOGS_DIR = SRC_DIR / 'logs'
os.makedirs(LOGS_DIR, exist_ok=True)

POLYVORE_DIR = './datasets/polyvore'  # Adjust if your dataset is elsewhere
POLYVORE_PRECOMPUTED_CLIP_EMBEDDING_DIR = f"{POLYVORE_DIR}/precomputed_clip_embeddings"
POLYVORE_METADATA_PATH = f"{POLYVORE_DIR}/item_metadata.json"
POLYVORE_IMAGE_DATA_PATH = f"{POLYVORE_DIR}/images/{{item_id}}.jpg"

# Configure logging
# Configure logging
logging.basicConfig(
    filename=LOGS_DIR / 'precompute_clip_embedding.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('precompute_clip_embedding')

  from .autonotebook import tqdm as notebook_tqdm


## Define Utility Functions

We replicate the dataset loading, model setup, and distributed processing logic from the official implementation.

In [2]:
# Import required libraries for utility functions
import random
import numpy as np
import torch
from PIL import Image
import json

# Define FashionItem class (from datatypes.py)
class FashionItem:
    def __init__(self, item_id, category, image=None, description="", metadata=None, embedding=None):
        self.item_id = item_id
        self.category = category
        self.image = image
        self.description = description
        self.metadata = metadata or {}
        self.embedding = embedding

# Load metadata (from polyvore.py)
def load_metadata(dataset_dir):
    metadata = {}
    with open(POLYVORE_METADATA_PATH.format(dataset_dir=dataset_dir), 'r') as f:
        metadata_ = json.load(f)
        for item in metadata_:
            metadata[item['item_id']] = item
    logger.info(f"Loaded {len(metadata)} metadata")
    print(f"Loaded {len(metadata)} metadata")
    return metadata

# Load image (from polyvore.py)
def load_image(dataset_dir, item_id, size=(224, 224)):
    image_path = POLYVORE_IMAGE_DATA_PATH.format(dataset_dir=dataset_dir, item_id=item_id)
    try:
        image = Image.open(image_path).convert('RGB')
        return image
    except Exception as e:
        logger.error(f"Error loading image {image_path}: {e}")
        print(f"Error loading image {image_path}: {e}")
        return None

# Load item (from polyvore.py)
def load_item(dataset_dir, metadata, item_id, should_load_image=False, embedding_dict=None):
    metadata_ = metadata[item_id]
    return FashionItem(
        item_id=metadata_['item_id'],
        category=metadata_['semantic_category'],
        image=load_image(dataset_dir, metadata_['item_id']) if should_load_image else None,
        description=metadata_['title'] if metadata_.get('title') else metadata_['url_name'],
        metadata=metadata_,
        embedding=embedding_dict[item_id] if embedding_dict else None
    )

# PolyvoreItemDataset (from polyvore.py)
class PolyvoreItemDataset:
    def __init__(self, dataset_dir, metadata=None, embedding_dict=None, load_image=False):
        self.dataset_dir = dataset_dir
        self.metadata = metadata if metadata else load_metadata(dataset_dir)
        self.load_image = load_image
        self.embedding_dict = embedding_dict
        self.all_item_ids = list(self.metadata.keys())

    def __len__(self):
        return len(self.all_item_ids)

    def __getitem__(self, idx):
        return load_item(self.dataset_dir, self.metadata, self.all_item_ids[idx], 
                         should_load_image=self.load_image, embedding_dict=self.embedding_dict)

# Collate function (from collate_fn.py)
def item_collate_fn(batch):
    return [item for item in batch]

# Distributed setup (from utils/distributed_utils.py)
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

# Seed everything (from utils/utils.py)
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Configure Arguments

We replicate the argument parsing from `1_generate_clip_embeddings.py`.

In [3]:
# Define arguments (from parse_args)
class Args:
    model_type = 'clip'
    polyvore_dir = './datasets/polyvore'
    polyvore_type = 'nondisjoint'
    batch_sz_per_gpu = 128
    n_workers_per_gpu = 4
    checkpoint = None
    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
    demo = False

args = Args()
print(f"Arguments: {vars(args)}")

# Set seed
seed_everything(42)

Arguments: {}


## Define Model Loading

We replicate the model loading logic from `models/load.py`, using `patrickjohncyh/fashion-clip`.

In [4]:
# Model loading (simplified from models/load.py)
def load_model(model_type='clip', checkpoint=None):
    if model_type != 'clip':
        raise ValueError("Only 'clip' model_type is supported in this notebook")
    
    model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
    processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
    model.eval()
    
    if checkpoint:
        state_dict = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    
    return model, processor

# Precompute CLIP embedding function
def precompute_clip_embedding(model, processor, batch):
    images = [item.image for item in batch]
    texts = [item.description for item in batch]
    
    # Process images
    inputs = processor(images=images, text=texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate embeddings
    with torch.no_grad():
        outputs = model(**inputs)
        image_embeds = outputs.image_embeds  # (batch_size, 512)
        text_embeds = outputs.text_embeds   # (batch_size, 512)
        embeddings = torch.cat((image_embeds, text_embeds), dim=-1)  # (batch_size, 1024)
    
    return embeddings.cpu().numpy()

## Compute Embeddings

We replicate the `compute` function from `1_generate_clip_embeddings.py`, adapting it for DDP in a notebook.

In [5]:
def compute(rank, world_size, args):
    # Setup DDP
    setup(rank, world_size)
    logger.info(f"Logger Setup Completed", extra={'rank': rank})
    
    # Setup dataloader
    item_dataset = PolyvoreItemDataset(
        dataset_dir=args.polyvore_dir,
        load_image=True
    )
    
    n_items = len(item_dataset)
    n_items_per_gpu = n_items // world_size
    start_idx = n_items_per_gpu * rank
    end_idx = start_idx + n_items_per_gpu if rank < world_size - 1 else n_items
    item_dataset = torch.utils.data.Subset(item_dataset, range(start_idx, end_idx))
    
    item_dataloader = DataLoader(
        dataset=item_dataset,
        batch_size=args.batch_sz_per_gpu,
        shuffle=False,
        num_workers=args.n_workers_per_gpu,
        collate_fn=item_collate_fn
    )
    logger.info(f"Dataloaders Setup Completed", extra={'rank': rank})
    
    # Load model
    model, processor = load_model(model_type=args.model_type, checkpoint=args.checkpoint)
    model.to(rank)
    if world_size > 1:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    logger.info(f"Model Loaded", extra={'rank': rank})
    
    # Compute embeddings
    all_ids, all_embeddings = [], []
    with torch.no_grad():
        for batch in tqdm(item_dataloader, desc=f"Rank {rank}"):
            if args.demo and len(all_embeddings) > 10:
                break
            
            embeddings = precompute_clip_embedding(model.module if world_size > 1 else model, processor, batch)
            all_ids.extend([item.item_id for item in batch])
            all_embeddings.append(embeddings)
    
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    logger.info(f"Computed {len(all_embeddings)} embeddings", extra={'rank': rank})
    print(f"Rank {rank}: Computed {len(all_embeddings)} embeddings")
    
    # Save embeddings
    save_dir = POLYVORE_PRECOMPUTED_CLIP_EMBEDDING_DIR.format(polyvore_dir=args.polyvore_dir)
    os.makedirs(save_dir, exist_ok=True)
    save_path = f"{save_dir}/polyvore_{rank}.pkl"
    with open(save_path, 'wb') as f:
        pickle.dump({'ids': all_ids, 'embeddings': all_embeddings}, f)
    logger.info(f"Saved embeddings to {save_path}", extra={'rank': rank})
    print(f"Rank {rank}: Saved embeddings to {save_path}")
    
    # Cleanup DDP
    cleanup()

# Run computation
world_size = args.world_size
if world_size > 1:
    mp.spawn(
        compute,
        args=(world_size, args),
        nprocs=world_size,
        join=True
    )
else:
    compute(0, 1, args)

Loaded 251008 metadata


Rank 0: 100%|██████████| 1961/1961 [12:47<00:00,  2.55it/s]


Rank 0: Computed 251008 embeddings
Rank 0: Saved embeddings to ./datasets/polyvore/precomputed_clip_embeddings/polyvore_0.pkl
