In [None]:
# Split the dataset into smaller subfolders so it can be accessed
import os
from pathlib import Path
import shutil

# Source folder holding the flat image dump, and the folder that receives the
# numbered part_N subfolders.
SRC = Path("/content/drive/MyDrive/FFHQ")
DEST = Path("/content/drive/MyDrive/FFHQ_SPLIT")
DEST.mkdir(exist_ok=True)

# Maximum number of images placed in a single part_N subfolder.
IMAGES_PER_FOLDER = 3000

# Only image files directly inside SRC are picked up; sorting keeps the
# assignment of images to subfolders stable.
files = sorted(
    entry for entry in SRC.iterdir()
    if entry.suffix.lower() in [".png", ".jpg", ".jpeg"]
)

# part_1 is created up front, so it exists even when SRC holds no images.
folder_index = 1
count = 0
subfolder = DEST / f"part_{folder_index}"
subfolder.mkdir(exist_ok=True)

for position, img in enumerate(files):
    # Derive the target folder and the slot inside it from the running position.
    completed_folders, slot = divmod(position, IMAGES_PER_FOLDER)
    folder_index = completed_folders + 1

    # Slot 0 of any folder after the first means the previous folder is full.
    if slot == 0 and position > 0:
        subfolder = DEST / f"part_{folder_index}"
        subfolder.mkdir(exist_ok=True)

    shutil.move(str(img), str(subfolder / img.name))

    # count is the number of images now sitting in the current subfolder.
    count = slot + 1
    if count % 500 == 0:
        print(f"Moved {count} images into {subfolder}")

In [None]:
!pip install git+https://github.com/openai/CLIP.git

import os
import json
from pathlib import Path
from PIL import Image
import torch
import clip
from torchvision import transforms
import gc
import signal

#timeout exception to deal with files that take too long to load
class TimeoutException(Exception):
    """Raised by the SIGALRM handler when processing an image takes too long."""

def timeout_handler(signum, frame):
    """Signal handler that turns an alarm signal into a TimeoutException.

    Both arguments are supplied by the signal machinery and are not used.
    """
    raise TimeoutException

# Route SIGALRM to timeout_handler so that an expired signal.alarm() surfaces
# as a TimeoutException in the running code. SIGALRM is a Unix-only signal.
signal.signal(signal.SIGALRM, timeout_handler)

# Configuration
# Mount Google Drive so the /content/drive paths below resolve.
from google.colab import drive
drive.mount('/content/drive')

# Root of the split dataset (scanned recursively) and the folder that
# receives the label JSON files; the output folder is created if missing.
FFHQ_ROOT = Path("/content/drive/MyDrive/FFHQ_SPLIT")
OUTPUT_DIR = Path("/content/drive/MyDrive/HairCLIP")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Number of images labelled per batch; each batch is saved to its own JSON file.
BATCH_SIZE = 5000
# An image whose best score is at or above this is flagged "use_for_training".
HIGH_CONF_THRESHOLD = 0.3
# An image whose best score is below this is left out of the output entirely.
MED_CONF_THRESHOLD = 0.1
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Candidate hairstyle labels. The order matters: the index of the best-scoring
# prompt is used to look the label up in this list.
HAIRSTYLES = [
    "buzz cut hairstyle",
    "fade hairstyle",
    "high fade hairstyle",
    "low fade hairstyle",
    "taper fade hairstyle",
    "middle part hairstyle",
    "side part hairstyle",
    "wolfcut hairstyle",
    "mullet hairstyle",
    "bob cut hairstyle",
    "pixie cut hairstyle",
    "ponytail hairstyle",
    "long curly hairstyle",
    "long straight hairstyle",
    "short curly hairstyle",
    "short straight hairstyle",
    "bowl cut hairstyle",
    "undercut hairstyle",
    "textured crop hairstyle",
    "modern mullet hairstyle",
    "curtained flow hairstyle",
    "hockey flow hairstyle",
    "butterfly cut hairstyle",
    "curtain bangs hairstyle",
    "short length hairstyle",
    "medium length hairstyle",
    "long length hairstyle",
    "dreads hairstyle",
    "braids hairstyle",
]

# Load the CLIP model and its matching image preprocessing onto the device.
print(f"Loading CLIP model on {DEVICE}...")
clip_model, preprocess_clip = clip.load("ViT-B/32", device=DEVICE)

# One prompt per hairstyle; row i of text_features corresponds to HAIRSTYLES[i].
prompts = [f"A photo of a person with a {style}." for style in HAIRSTYLES]
text_tokens = clip.tokenize(prompts)
text_tokens = text_tokens.to(DEVICE)

# The prompts never change, so encode them once and scale every row to unit
# length; image features can then be compared with a plain dot product.
with torch.no_grad():
    text_features = clip_model.encode_text(text_tokens)
    text_norms = text_features.norm(dim=-1, keepdim=True)
    text_features /= text_norms

#finding all images to process
def scan_images_recursively(root):
    """Return the paths of all image files found under ``root``.

    The directory tree is walked recursively and every file whose name ends
    in .png, .jpg or .jpeg (case-insensitive) is collected.

    Args:
        root: Directory to search (string or path-like).

    Returns:
        A sorted list of path strings. Sorting makes the result independent
        of the filesystem's listing order, so slicing the list into batches
        gives the same batches on every run. A missing or empty directory
        yields an empty list.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    images = []
    for root_dir, _dirs, files in os.walk(root):
        for fname in files:
            if fname.lower().endswith(image_suffixes):
                images.append(os.path.join(root_dir, fname))
    # os.walk gives no ordering guarantee; sort for reproducible batching.
    images.sort()
    return images

# Discover every image up front so the batch count is known before labelling.
print("Scanning FFHQ_SPLIT recursively...")
all_images = scan_images_recursively(FFHQ_ROOT)
total_images = len(all_images)
print(f"Found {total_images} images across subfolders.")

# Ceiling division: a trailing partial batch still counts as one batch.
full_batches, leftover_images = divmod(total_images, BATCH_SIZE)
num_batches = full_batches + (1 if leftover_images else 0)
print(f"Processing in {num_batches} batches...")


def clip_label_image(img_path):
    """Pick the hairstyle prompt that best matches one image.

    The image is encoded with CLIP and compared against the precomputed,
    unit-length ``text_features``. Because the image features are normalised
    here as well, each similarity is a cosine similarity.

    Args:
        img_path: Path of the image file to label.

    Returns:
        A ``(label, score)`` tuple: the entry of ``HAIRSTYLES`` with the
        highest similarity, and that similarity as a Python float.

    Raises:
        Whatever PIL raises for an unreadable or corrupt image file.
    """
    # The context manager releases the file handle even when decoding fails
    # or is interrupted part-way (for example by the SIGALRM timeout).
    with Image.open(img_path) as raw_image:
        img = preprocess_clip(raw_image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        img_feat = clip_model.encode_image(img)
        img_feat /= img_feat.norm(dim=-1, keepdim=True)
        # One similarity per hairstyle prompt.
        sims = (img_feat @ text_features.T).squeeze(0)
        best_idx = sims.argmax().item()
        score = sims[best_idx].item()
    return HAIRSTYLES[best_idx], score

# Zero-shot labelling: give every image the hairstyle whose prompt CLIP scores
# highest, and write one JSON file of labels per batch.

# Per-image time budget in seconds, enforced through SIGALRM.
IMAGE_TIMEOUT_SECONDS = 5

for batch_idx in range(num_batches):
    batch_images = all_images[batch_idx*BATCH_SIZE : (batch_idx+1)*BATCH_SIZE]
    batch_labels = {}

    print(f"\nProcessing batch {batch_idx+1}/{num_batches} ({len(batch_images)} images)...")

    for img_i, img_path in enumerate(batch_images):

        if img_i % 200 == 0:
            print(f"  Processed {img_i}/{len(batch_images)} images...", flush=True)

        try:
            signal.alarm(IMAGE_TIMEOUT_SECONDS)
            try:
                label, score = clip_label_image(img_path)
            finally:
                # Disarm on every exit path, including KeyboardInterrupt, so a
                # stale alarm can never fire later in unrelated code or while
                # an error is being reported below.
                signal.alarm(0)

        except TimeoutException:
            print(f"⏳ Timeout on {img_path}, skipping.")
            continue

        except Exception as e:
            # Best effort: one unreadable image must not abort the whole run.
            print(f"❌ Error processing {img_path}: {e}")
            continue

        # Images below the medium threshold are dropped; only those at or above
        # the high threshold are flagged for training. Entries are keyed by the
        # bare file name.
        if score >= MED_CONF_THRESHOLD:
            batch_labels[os.path.basename(img_path)] = {
                "label": label,
                "confidence": float(score),
                "use_for_training": score >= HIGH_CONF_THRESHOLD
            }

    batch_file = OUTPUT_DIR / f"ffhq_labels_batch_{batch_idx}.json"
    with open(batch_file, "w") as f:
        json.dump(batch_labels, f, indent=4)

    print(f"Saved batch labels to {batch_file}")

    # Free Python and GPU memory between batches.
    gc.collect()
    torch.cuda.empty_cache()


# Combine every per-batch JSON found in OUTPUT_DIR into one labels file.
print("\nMerging all batch JSONs...")
final_labels = {}

# Files are visited in sorted path order (so batch_10 comes before batch_2);
# the order only matters if two batches contain the same file-name key, in
# which case the file merged later wins.
batch_files = sorted(OUTPUT_DIR.glob("ffhq_labels_batch_*.json"))
if not batch_files:
    raise ValueError("No batch JSON files found!")

for batch_file in batch_files:
    batch_contents = json.loads(batch_file.read_text())
    final_labels.update(batch_contents)

final_file = OUTPUT_DIR / "ffhq_labels.json"
final_file.write_text(json.dumps(final_labels, indent=4))

print(f"All batches merged. Final labels saved to {final_file}")