In [None]:
from PIL import Image
import numpy as np
import torch; import os
from matplotlib import colors
from transformers import CLIPProcessor, CLIPModel

# Pretrained CLIP ViT-B/32: the model is moved to the GPU, the processor prepares its text/image inputs.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).cuda()
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [None]:
# Generated samples must already exist as `out/<checkpoint type>/<index>.png`, one subfolder per
# diffusion-model checkpoint type. Produce them with `sdxl_kl.py` and `sdxl_ckpt_merged.py` from
# the `hpc_scripts` folder.
path_prefix = 'out'

# Legend label for each checkpoint type; the key order is also the processing order.
type_labels = {
    'kl': 'KL barycenter',
    'af': 'Biased Model 1',
    'wm': 'Biased Model 2',
    'merged': 'Checkpoint merging',
}
types = list(type_labels)

# Upper bound on the number of images scored per checkpoint type.
MAX_SAMPLES = 512

In [None]:
# Calculate gender CLIP scores.
# For each checkpoint type, score `out/<type>/0.png`, `1.png`, ... against both gender prompts.
# Scoring stops at the first missing index or after MAX_SAMPLES images, whichever comes first.
# Each stored entry is the image's logits_per_image row, ordered like `gender_prompts`:
# [logit vs. female prompt, logit vs. male prompt].
gender_prompts = ["a photo of a female scientist", "a photo of a male scientist"]
gender_logits_scores = {typ: [] for typ in types}

for typ in types:
    for idx in range(MAX_SAMPLES):
        filename = f"{path_prefix}/{typ}/{idx}.png"
        if not os.path.exists(filename):
            break
        # Use a context manager so the PNG file handle is released right after preprocessing.
        with Image.open(filename) as image:
            inputs = processor(text=gender_prompts, images=image,
                               return_tensors="pt", padding=True)
        # Follow the model's device instead of hard-coding `.cuda()`.
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        with torch.no_grad():
            logits_per_image = model(**inputs).logits_per_image
        gender_logits_scores[typ].append(logits_per_image[0].cpu().numpy())
    print(f'Finished processing {typ}! Totaling {len(gender_logits_scores[typ])} pictures...')

In [None]:
# Calculate ethnicity CLIP scores.
# For each checkpoint type, score `out/<type>/0.png`, `1.png`, ... against both ethnicity prompts.
# Scoring stops at the first missing index or after MAX_SAMPLES images, whichever comes first.
# Each stored entry is the image's logits_per_image row, ordered like `ethnic_prompts`:
# [logit vs. East Asian prompt, logit vs. White prompt].
ethnic_prompts = ["a photo of an East Asian scientist", "a photo of a White scientist"]
ethnic_logits_scores = {typ: [] for typ in types}

for typ in types:
    for idx in range(MAX_SAMPLES):
        filename = f"{path_prefix}/{typ}/{idx}.png"
        if not os.path.exists(filename):
            break
        # Use a context manager so the PNG file handle is released right after preprocessing.
        with Image.open(filename) as image:
            inputs = processor(text=ethnic_prompts, images=image,
                               return_tensors="pt", padding=True)
        # Follow the model's device instead of hard-coding `.cuda()`.
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        with torch.no_grad():
            logits_per_image = model(**inputs).logits_per_image
        ethnic_logits_scores[typ].append(logits_per_image[0].cpu().numpy())
    print(f'Finished processing {typ}! Totaling {len(ethnic_logits_scores[typ])} pictures...')

In [None]:
# json.dump cannot serialise numpy arrays, so replace every per-image score array with a plain
# Python list. The lists are rewritten in place, so both dictionaries keep the same list objects.
for score_dict in (gender_logits_scores, ethnic_logits_scores):
    for score_list in score_dict.values():
        score_list[:] = [row.tolist() if isinstance(row, np.ndarray) else row
                         for row in score_list]

In [None]:
import json
import os

# Local directory that caches the computed score dictionaries between sessions.
drive_path = 'cache'
os.makedirs(drive_path, exist_ok=True)

def save_json_to_drive(data, filename):
    """Write ``data`` as JSON to ``drive_path/filename``, overwriting any existing file."""
    target = os.path.join(drive_path, filename)
    with open(target, 'w') as handle:
        json.dump(data, handle)

save_json = 'Yes' #param ['Yes', 'No'] {'type':'string'}

if save_json == 'Yes':
    save_json_to_drive(gender_logits_scores, f'GenderCLIP_{MAX_SAMPLES}.json')
    save_json_to_drive(ethnic_logits_scores, f'EthnicCLIP_{MAX_SAMPLES}.json')

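## Visualize the distribution

Reload the CLIP scores cached in `cache/`, so the plots below can be regenerated without recomputing them.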

In [None]:
# Reload the score dictionaries that the save cell cached in `cache/`, so the plots can be
# produced without recomputing CLIP scores. A missing or corrupt file only prints a message.
import json
import os

drive_path = 'cache'

def load_json_from_drive(filename):
    """Return the parsed JSON content of ``drive_path/filename``."""
    with open(os.path.join(drive_path, filename), 'r') as handle:
        return json.load(handle)

try:
    gender_logits_scores = load_json_from_drive(f'GenderCLIP_{MAX_SAMPLES}.json')
    ethnic_logits_scores = load_json_from_drive(f'EthnicCLIP_{MAX_SAMPLES}.json')
except FileNotFoundError:
    print("One or more of the JSON files were not found. Please ensure the filenames and path are correct.")
except json.JSONDecodeError:
    print("Error decoding the JSON files. Please ensure they are valid JSON.")
else:
    print("Successfully loaded the dictionaries.")

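## Visualize 2D distribution

Each point is one generated image. Its x coordinate is the CLIP logit for the first prompt, and its y coordinate is the logit for the second prompt. Points below the `y = x` line are closer to the x-axis prompt; points above it are closer to the y-axis prompt.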

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def convert_data(Z: list):
    """Split a sequence of per-image score rows into separate x and y arrays.

    Args:
        Z: Sequence of rows (lists or arrays) with at least two entries each.
           Column 0 becomes the x coordinate and column 1 the y coordinate.
           Any further columns are ignored.

    Returns:
        Tuple ``(Z_x, Z_y)`` of 1-D numpy arrays, one value per row of ``Z``.
    """
    points = np.array(Z)
    return points[:, 0], points[:, 1]

In [None]:
#@title Gender CLIP scores
# Scatter each image's (female-prompt logit, male-prompt logit) pair, one series per checkpoint type.
cmaps_list = ['Purples', 'Greens', 'Blues', 'Reds']
plot_order = ['kl', 'merged', 'af', 'wm']
# Pick one representative colour from each colormap, paired with plot_order by position.
point_colors = [plt.get_cmap(name)(0.7) for name in cmaps_list]

fig, ax = plt.subplots(figsize=(6, 6))
# Track running bounds over both axes of every series, so the y = x guide spans all points.
lo, hi = np.inf, -np.inf
for typ, color in zip(plot_order, point_colors):
    Z_x, Z_y = convert_data(gender_logits_scores[typ])
    ax.scatter(Z_x, Z_y, color=color, label=type_labels[typ], alpha=0.3)
    lo = min(lo, Z_x.min(), Z_y.min())
    hi = max(hi, Z_x.max(), Z_y.max())

# Points below y = x score higher for the x-axis prompt; points above score higher for the y-axis prompt.
diagonal = np.linspace(lo, hi, 100)
ax.plot(diagonal, diagonal, label='y = x', color='black', linestyle='--')

ax.set_title('Image-to-text similarity (CLIP distances)')
ax.set_xlabel('100 x cosine_similarity(image, "a photo of a female scientist")')
ax.set_ylabel('100 x cosine_similarity(image, "a photo of a male scientist")')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
#@title Ethnic cosine similarity scores
# Scatter each image's (East Asian-prompt logit, White-prompt logit) pair, one series per checkpoint type.
cmaps_list = ['Purples', 'Greens', 'Blues', 'Reds']
plot_order = ['kl', 'merged', 'af', 'wm']
# Pick one representative colour from each colormap, paired with plot_order by position.
point_colors = [plt.get_cmap(name)(0.7) for name in cmaps_list]

fig, ax = plt.subplots(figsize=(6, 6))
# Track running bounds over both axes of every series, so the y = x guide spans all points.
lo, hi = np.inf, -np.inf
for typ, color in zip(plot_order, point_colors):
    Z_x, Z_y = convert_data(ethnic_logits_scores[typ])
    ax.scatter(Z_x, Z_y, color=color, label=type_labels[typ], alpha=0.3)
    lo = min(lo, Z_x.min(), Z_y.min())
    hi = max(hi, Z_x.max(), Z_y.max())

# Points below y = x score higher for the x-axis prompt; points above score higher for the y-axis prompt.
diagonal = np.linspace(lo, hi, 100)
ax.plot(diagonal, diagonal, label='y = x', color='black', linestyle='--')

ax.set_title('Image-to-text similarity (CLIP distances)')
ax.set_xlabel('100 x cosine_similarity(image, "a photo of an East Asian scientist")')
ax.set_ylabel('100 x cosine_similarity(image, "a photo of a White scientist")')
ax.legend()
fig.tight_layout()
plt.show()