In [7]:
import torch
import clip
from PIL import Image
import pandas as pd
import os
from tqdm import tqdm
import numpy as np

# Select the compute device: use the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP "ViT-B/32" model on that device; clip.load also returns the
# image preprocessing callable that is applied to every image before encoding
model, preprocess = clip.load("ViT-B/32", device=device)

def generate_attribute_prompts():
    """Build the text prompts used to classify each attribute.

    Returns:
        A dict keyed by attribute name. Each value maps a class label to the
        sentence describing that class. Label insertion order is significant
        to callers that map a prompt index back to its label.
    """
    # Hair colour is a four-way choice labelled 0..3 in this order
    hair_prompts = dict(enumerate((
        "a photo of a person with blonde hair",
        "a photo of a person with brown hair",
        "a photo of a person with black hair",
        "a photo of a person with other hair color",
    )))

    def binary(positive, negative):
        # Binary attributes use label 1 for "present" and -1 for "absent",
        # with the positive label inserted first
        return {1: positive, -1: negative}

    return {
        'hair_color': hair_prompts,
        'pale_skin': binary(
            "a photo of a person with pale skin",
            "a photo of a person with non-pale skin",
        ),
        'male': binary(
            "a photo of a man",
            "a photo of a woman",
        ),
        'no_beard': binary(
            "a photo of a person without a beard",
            "a photo of a person with a beard",
        ),
    }

def evaluate_attributes(image_path, attribute_prompts, model, preprocess, device):
    """Predict one label per attribute for a single image using CLIP.

    Args:
        image_path: Path of the image file, opened with PIL.
        attribute_prompts: Mapping of attribute name -> {label: prompt text}.
        model: Object exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable applied to the opened PIL image; its result must
            support ``.unsqueeze(0).to(device)``.
        device: Device that the image and token tensors are moved to.

    Returns:
        Dict mapping each attribute name to the label whose prompt has the
        highest similarity to the image.
    """
    # Open the image through a context manager so the underlying file handle
    # is released deterministically, even if preprocessing raises
    with Image.open(image_path) as pil_image:
        # unsqueeze(0) adds the leading batch dimension of size 1
        image = preprocess(pil_image).unsqueeze(0).to(device)

    results = {}

    with torch.no_grad():
        # Encode the image once and L2-normalise it; it is reused for every
        # attribute below
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        for attr, prompts in attribute_prompts.items():
            # Labels and prompt texts are taken in the same (insertion) order,
            # so index i of the similarity row corresponds to labels[i]
            labels = list(prompts)
            text_inputs = clip.tokenize(list(prompts.values())).to(device)

            # Encode and L2-normalise the candidate prompts
            text_features = model.encode_text(text_inputs)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            # Scaled dot products of the normalised features, turned into a
            # distribution over this attribute's prompts
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

            # Row 0 is the single image in the batch
            pred_idx = similarity[0].argmax().item()
            results[attr] = labels[pred_idx]

    return results

def evaluate_clip_performance(df, image_dir):
    """Measure CLIP's zero-shot accuracy for every attribute over a dataset.

    Uses the module-level ``model``, ``preprocess`` and ``device`` objects.

    Args:
        df: DataFrame with an ``image_id`` column plus the ground-truth
            columns ``Hair_Color``, ``Pale_Skin``, ``Male`` and ``No_Beard``.
        image_dir: Directory that the ``image_id`` values are joined onto.

    Returns:
        A tuple ``(accuracies, per_class_accuracies)``:
        ``accuracies`` maps attribute -> overall accuracy, and
        ``per_class_accuracies`` maps attribute -> {true label: accuracy}.
        An accuracy is 0 when no sample was counted for it.
    """
    attribute_prompts = generate_attribute_prompts()

    # Attribute name -> DataFrame column holding its ground truth. Built once
    # here rather than on every attribute of every row.
    column_mapping = {
        'hair_color': 'Hair_Color',
        'pale_skin': 'Pale_Skin',
        'male': 'Male',
        'no_beard': 'No_Beard'
    }

    # Counters are derived from the prompt table so the two cannot drift apart
    results = {
        attr: {
            'correct': 0,
            'total': 0,
            'per_class': {
                label: {'correct': 0, 'total': 0} for label in prompts
            },
        }
        for attr, prompts in attribute_prompts.items()
    }

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing images"):
        image_path = os.path.join(image_dir, row['image_id'])

        try:
            predictions = evaluate_attributes(image_path, attribute_prompts, model, preprocess, device)
            # Resolve every ground-truth label before touching any counter, so
            # a failure on this row leaves all counts untouched instead of
            # half-updated
            outcomes = [
                (attr, row[column_mapping[attr]], pred)
                for attr, pred in predictions.items()
            ]
        except Exception as e:
            # Best effort: report the row and keep going with the rest
            print(f"Error processing {image_path}: {str(e)}")
            continue

        for attr, true_label, pred in outcomes:
            attr_counts = results[attr]
            # A ground-truth label with no prompt is still counted (it can
            # never be predicted, so it only ever adds to 'total')
            class_counts = attr_counts['per_class'].setdefault(
                true_label, {'correct': 0, 'total': 0}
            )

            attr_counts['total'] += 1
            class_counts['total'] += 1
            if pred == true_label:
                attr_counts['correct'] += 1
                class_counts['correct'] += 1

    # Turn the raw counts into accuracies, guarding against empty buckets
    accuracies = {}
    per_class_accuracies = {}

    for attr, data in results.items():
        accuracies[attr] = data['correct'] / data['total'] if data['total'] > 0 else 0
        per_class_accuracies[attr] = {
            label: counts['correct'] / counts['total'] if counts['total'] > 0 else 0
            for label, counts in data['per_class'].items()
        }

    return accuracies, per_class_accuracies

if __name__ == "__main__":
    # Load the CelebA attribute table
    df = pd.read_csv(r"/home/omrid/Desktop/jungo /projectCLIPvae/celeba_dataset/list_attr_celeba.csv")

    # Keep only the first 10000 images. This is done BEFORE the row-wise hair
    # colour mapping so the (slow) per-row apply runs on 10000 rows instead of
    # the whole table; .copy() makes the subset an independent frame so the
    # column assignment below does not write into a slice of the original.
    df = df.head(10000).copy()

    def haircolor(x):
        """Collapse the one-hot hair columns of a row into a single label.

        Precedence is blond (0), then brown (1), then black (2); any row with
        none of those set to 1 becomes 3.
        """
        if x["Blond_Hair"] == 1:
            return 0
        if x["Brown_Hair"] == 1:
            return 1
        if x["Black_Hair"] == 1:
            return 2
        return 3

    # Apply hair color mapping and select relevant columns
    df["Hair_Color"] = df.apply(haircolor, axis=1)
    df = df[["image_id", "Hair_Color", 'Pale_Skin', "Male", "No_Beard"]]

    # Directory containing the image files named by the image_id column
    image_dir = "/home/omrid/Desktop/jungo /projectCLIPvae/celeba_dataset/img_align_celeba/img_align_celeba/"

    # Evaluate CLIP
    accuracies, per_class_accuracies = evaluate_clip_performance(df, image_dir)

    # Print results
    print("\nOverall Accuracies:")
    for attr_name, overall_acc in accuracies.items():
        print(f"{attr_name}: {overall_acc:.4f}")

    print("\nPer-class Accuracies:")
    for attr_name, class_accs in per_class_accuracies.items():
        print(f"\n{attr_name}:")
        for label, class_acc in class_accs.items():
            print(f"  Class {label}: {class_acc:.4f}")

Processing images: 100%|██████████| 10000/10000 [06:43<00:00, 24.78it/s]


Overall Accuracies:
hair_color: 0.4821
pale_skin: 0.8708
male: 0.9890
no_beard: 0.8248

Per-class Accuracies:

hair_color:
  Class 0: 0.8712
  Class 1: 0.7500
  Class 2: 0.6152
  Class 3: 0.1457

pale_skin:
  Class 1: 0.1986
  Class -1: 0.9020

male:
  Class 1: 0.9810
  Class -1: 0.9948

no_beard:
  Class 1: 0.9795
  Class -1: 0.0505



