In [None]:
# Install OpenAI's CLIP package directly from its GitHub repository.
# NOTE(review): the URL is not pinned to a tag or commit, so a later run may
# pull different upstream code — consider appending `@<commit>` for reproducibility.
%pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import ssl
import os
import numpy as np
from PIL import Image
import clip

# NOTE(review): this swaps the default HTTPS context factory for the
# unverified one, which disables TLS certificate verification for HTTPS
# connections that rely on the default context (e.g. urllib downloads) for
# the whole kernel session. It is kept so existing downloads keep working,
# but the safer fix is to repair the local CA bundle and delete this line.
ssl._create_default_https_context = ssl._create_unverified_context

# Device selection: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    print("Using GPU")
    device = torch.device("cuda")
# `torch.backends.mps` does not exist on older PyTorch builds; look it up
# defensively so those installs fall through to CPU instead of raising
# AttributeError.
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    print("Using MPS")
    device = torch.device("mps")
else:
    print("Using CPU")
    device = torch.device("cpu")

# Working directories used by the notebook; exist_ok makes the cell re-runnable.
os.makedirs('output', exist_ok=True)
os.makedirs('data', exist_ok=True)

In [None]:
from tqdm import tqdm # For progress bars

def precompute_clip_embeddings(output_path, image_root='data/celeba_hq/img_align',
                               model_name="ViT-B/32", batch_size=None):
    """Encode every image under ``image_root`` with CLIP and save the result.

    The images are read with ``torchvision.datasets.ImageFolder`` (so they
    must live in subfolders of ``image_root``), preprocessed with the
    transform returned by ``clip.load``, and passed through
    ``model.encode_image`` in fixed order (``shuffle=False``). The per-batch
    features are concatenated along dim 0 and stored with ``torch.save`` as
    ``{'embeddings': tensor}``.

    Relies on the module-level ``device`` chosen earlier in the notebook.

    Args:
        output_path: Filesystem path (str or path-like) of the ``.pt`` file
            to create.
        image_root: Root directory handed to ``ImageFolder``.
        model_name: CLIP model identifier handed to ``clip.load``.
        batch_size: Images per forward pass. ``None`` keeps the original
            heuristic: 50 on CPU, 100 otherwise.

    Returns:
        None. The embeddings are written to ``output_path``.
    """
    # Resolve the temp path first so an unusable ``output_path`` fails before
    # the expensive encoding pass rather than after it.
    tmp_path = os.fspath(output_path) + ".tmp"

    print("Loading CLIP model...")
    model, preprocess = clip.load(model_name, device=device)

    # ImageFolder loads all images from the subfolders of the root directory
    dataset = datasets.ImageFolder(
        root=image_root,
        transform=preprocess
    )

    print(f"Found {len(dataset)} images.")

    if batch_size is None:
        batch_size = 50 if device.type == 'cpu' else 100
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    all_embeddings = []
    with torch.no_grad():
        for images, _ in tqdm(dataloader):
            images = images.to(device)
            image_features = model.encode_image(images)
            # Move each batch to host memory so the results do not accumulate on the accelerator.
            all_embeddings.append(image_features.cpu())

    embeddings = torch.cat(all_embeddings, dim=0)
    print(f"Final embeddings shape: {embeddings.shape}")

    # Write to a sibling temp file and move it into place so an interrupted
    # save never leaves a truncated file at ``output_path`` (callers use the
    # file's mere existence as the "already computed" signal).
    try:
        torch.save({'embeddings': embeddings}, tmp_path)
        os.replace(tmp_path, output_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up the partial file left behind by a failed or interrupted save.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Saved all embeddings to {output_path}")

In [None]:
# Location of the cached CLIP embeddings, relative to the notebook's working directory.
embeddings_path = 'celeba_hq_clip_embeddings.pt'

# Encoding the dataset is slow, so reuse the cached file when it is present
# and only run the CLIP pass when it is missing.
if os.path.exists(embeddings_path):
    print(f"Embeddings file already exists: {embeddings_path}")
else:
    precompute_clip_embeddings(embeddings_path)