In [None]:
# Install the necessary packages (see requirements.txt in the repository root folder).

In [None]:
import base64
import mimetypes
import os

import google.generativeai as genai
import numpy as np
import requests
import torch
import torchvision.transforms.functional as F
import tqdm
from diffusers import DiffusionPipeline
from PIL import Image
from vendi_score import vendi

In [None]:
# API keys
# Prefer keys that are already exported in the environment so real secrets never
# have to be written into (and committed with) the notebook. The placeholders are
# only a fallback and must be replaced before the VLM cells are run.
# OpenAI: setdefault keeps an existing OPENAI_API_KEY instead of overwriting it.
os.environ.setdefault('OPENAI_API_KEY', 'YOUR_OPENAI_API_KEY')
# Gemini: read the key from GOOGLE_API_KEY when it is set.
genai.configure(api_key=os.environ.get('GOOGLE_API_KEY', 'YOUR_GEMINI_API_KEY'))
gemini_model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest")

In [None]:
def get_gemini_response(image_path, prompt):
    """Ask the Gemini model a question about an image and return its text reply.

    Args:
        image_path: Path of the image file to send to the model.
        prompt: Text prompt sent together with the image.

    Returns:
        The `.text` of the response produced by `gemini_model.generate_content`.
    """
    # Open the image in a context manager so the file handle is released after
    # every call instead of staying open for the lifetime of the kernel.
    with Image.open(image_path) as image_pil:
        return gemini_model.generate_content([image_pil, prompt]).text


In [None]:
def get_gpt4_response(image_path, question):
    """Ask GPT-4 Turbo a question about an image through the OpenAI REST API.

    Args:
        image_path: Path of the image file; it is sent inline as a base64 data URL.
        question: Text question sent together with the image.

    Returns:
        The message content of the first choice in the API response.

    Raises:
        requests.HTTPError: If the API answers with an error status code.
        requests.Timeout: If the API does not answer within the timeout.
    """
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # Label the data URL with the file's real media type (the notebook saves
    # PNG files); fall back to the previous hard-coded JPEG label if unknown.
    mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
    }

    payload = {
        "model": "gpt-4-turbo",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"{question}"
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{base64_image}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": 300
    }

    # A timeout keeps a stalled connection from hanging the notebook forever.
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=120,
    )
    # Surface API errors (bad key, rate limit, ...) as an HTTP error instead of
    # an opaque KeyError on the missing 'choices' field.
    response.raise_for_status()

    return response.json()['choices'][0]['message']['content']

In [None]:
def prompt_stable_diffusion_xl(prompt,
                               negative_prompt,
                               base,
                               refiner,
                               use_refiner = False,
                               n_steps= 40,
                               HIGH_NOISE_FRAC = 0.8,
                               seed=1):
    """Generate one image for `prompt`, optionally finishing with the refiner.

    Args:
        prompt: Text prompt passed to the pipeline(s).
        negative_prompt: Negative text prompt passed to the pipeline(s).
        base: Callable base pipeline; its result must expose `.images`.
        refiner: Callable refiner pipeline; only used when `use_refiner` is True.
        use_refiner: If True, the base pipeline stops at `HIGH_NOISE_FRAC` and
            hands its latents to the refiner. If False, the base pipeline runs
            the full schedule and returns its own decoded image.
        n_steps: Number of inference steps passed to each pipeline call.
        HIGH_NOISE_FRAC: Fraction of the schedule handled by the base pipeline
            when the refiner is used (`denoising_end` / `denoising_start`).
        seed: Integer seed for the random generator shared by both calls.

    Returns:
        The first image of the final pipeline call.
    """
    # Pipelines take a `generator`, not a `seed` keyword, so the seed has to be
    # wrapped in a torch.Generator to actually make the output reproducible.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    if not use_refiner:
        # Without a refiner the base pipeline must denoise completely and decode
        # the result itself; partially denoised latents are not a viewable image.
        return base(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=n_steps,
            generator=generator,
        ).images[0]

    # Run both experts: the base handles the high-noise part of the schedule
    # and returns latents, the refiner finishes from where the base stopped.
    latents = base(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=n_steps,
        denoising_end=HIGH_NOISE_FRAC,
        output_type="latent",
        generator=generator,
    ).images

    return refiner(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=n_steps,
        denoising_start=HIGH_NOISE_FRAC,
        image=latents,
        generator=generator,
    ).images[0]

Set the parameters in the next cell.
Use 'concept_space_samples' to provide examples of the artifacts that the T2I model is expected to output. This helps increase the geo-tagger's accuracy.
Use 'sample_prompt' for the input prompt. The prompt needs to be under-specified.
More details on the parameters are given in the paper: https://arxiv.org/abs/2407.06863


In [None]:
# Checkpoint identifier passed to DiffusionPipeline.from_pretrained.
model_path =  "stabilityai/stable-diffusion-xl-base-1.0" #@param
# Short model label; used as the sub-folder name for the generated images.
model_name = 'sdxl' #@param
# Prompt used for image generation and quoted in the geo-tagging question.
sample_prompt = "Image of traditional clothing" #@param
# Comma-separated example artifacts inserted into the geo-tagging question.
concept_space_samples = "sari, kimono, jeans, kurta" #@param
# Index of the CUDA device the pipelines are moved to (f"cuda:{gpu_id}").
gpu_id = 0
# Selects which VLM helper tags the images.
geo_tagger_vlm = 'gemini' # choose between {'gemini', 'gpt4'}

Generate images

In [None]:
# Load the pipeline in half precision and move it to the selected GPU.
base = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
base.to(f"cuda:{gpu_id}")
# NOTE(review): the refiner is loaded from the same `model_path` as the base
# pipeline (sharing its second text encoder and VAE) - confirm whether a
# dedicated refiner checkpoint was intended here.
refiner = DiffusionPipeline.from_pretrained(model_path, text_encoder_2=base.text_encoder_2, vae=base.vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
refiner.to(f"cuda:{gpu_id}")
num_images = 16
save_folder = 'diversity_images'
save_model_folder = os.path.join(save_folder, model_name)

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(save_model_folder, exist_ok=True)

# The negative prompt is the same for every image, so build it once.
negative_prompt = "multiple artifacts, blurry, painting, cartoon, artificial, nsfw, bad quality, bad anatomy, worst quality, low quality, low resolutions, extra fingers, blur, blurry, ugly, wrongs proportions, watermark, image artifacts, lowres, ugly, jpeg artifacts, deformed, noisy image"

# One image per seed; the seed doubles as the file name.
for seed_ in tqdm.tqdm(range(num_images)):

    img = prompt_stable_diffusion_xl(
            prompt=sample_prompt,
            base=base,
            refiner=refiner,
            use_refiner = True,
            negative_prompt = negative_prompt,
            seed = seed_
    )

    image_name = f"{seed_}.png"
    save_path = os.path.join(save_model_folder, image_name)
    img.save(save_path)

Geo-tag images

In [None]:
def _parse_geo_tags(response_text):
    """Split a "continent, country, artifact" reply into three stripped fields.

    Only the first two commas are treated as separators, so an artifact name
    that itself contains a comma stays in one piece. Missing fields are returned
    as empty strings instead of raising on unpacking.
    """
    parts = [part.strip() for part in response_text.strip().split(',', 2)]
    parts += [''] * (3 - len(parts))
    return parts


# os.listdir returns entries in arbitrary order; sort them so the annotation
# order (and therefore the batching used by the diversity score) is reproducible.
images = os.listdir(save_model_folder)
image_paths = sorted(
    os.path.join(save_model_folder, item)
    for item in images
    if item.endswith('.png')
)


question1 = f"""Given is an image generated for prompt: {sample_prompt}.
                Which country is the artifact in image from?
                Which continent does the country belong to?
                What is the name of the cultural artifact? For example, cultural artifacts can be {concept_space_samples}
                You need to tag the image to the closest country/continent and name the artifact.
                Finally just output 3 terms "continent, country and artifact name" in a comma separated fashion and nothing else".
            """

# Map the configured tagger name to its helper and fail fast on a typo instead
# of silently producing empty labels for every image.
geo_taggers = {'gemini': get_gemini_response, 'gpt4': get_gpt4_response}
if geo_tagger_vlm not in geo_taggers:
    raise ValueError(f"geo_tagger_vlm must be one of {sorted(geo_taggers)}, got {geo_tagger_vlm!r}")
geo_tagger = geo_taggers[geo_tagger_vlm]

all_annotations = []

for image_path in tqdm.tqdm(image_paths):

    annotation = {}
    annotation['image_name'] = os.path.basename(image_path)

    continent, country, artifact = _parse_geo_tags(geo_tagger(image_path, question1))
    if not (continent and country and artifact):
        # Keep going, but make malformed replies visible to the reader.
        tqdm.tqdm.write(f"Warning: incomplete geo-tag for {annotation['image_name']}")

    annotation['label'] = {'continent': continent, 'country': country, 'artifact': artifact}

    all_annotations.append(annotation)

Calculate Cultural Diversity score

In [None]:
def calculate_cultural_diversity(labels, similarity_function, _global=False, batch_size = 8):
    """Return the mean normalized Vendi score over consecutive batches of labels.

    The labels are split into consecutive chunks of `batch_size`; each chunk's
    Vendi score is divided by `batch_size` and the per-chunk values are averaged.

    Args:
        labels: List of annotation labels. With `_global=True` every label must
            be a mapping with 'continent', 'country' and 'artifact' keys.
        similarity_function: Function of two sample tuples used as the
            similarity kernel for the Vendi score.
        _global: If True, each sample is the (continent, country, artifact)
            tuple of a label (global prompts). If False, each sample is
            (label, None, None) (within-culture prompts).
        batch_size: Number of labels per chunk; also the normalization constant.

    Returns:
        The mean of the normalized per-chunk Vendi scores (NaN if `labels` is
        empty).

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # NOTE(review): the 32 / 24 thresholds are hard-coded and look tied to the
    # default batch_size of 8 - confirm the intended rule before generalizing.
    if len(labels) < 32:
        labels = labels[:24]

    all_vendi = []
    for start in range(0, len(labels), batch_size):
        chunk = labels[start:start + batch_size]
        if _global:
            samples = [(item['continent'], item['country'], item['artifact']) for item in chunk]
        else:
            samples = [(item, None, None) for item in chunk]
        # NOTE(review): a shorter final chunk is still divided by batch_size,
        # exactly as before - verify that this is the intended normalization.
        all_vendi.append(vendi.vendi_score(samples, similarity_function) / batch_size)

    if not all_vendi:
        # No labels: return NaN explicitly instead of via NumPy's
        # "mean of empty slice" warning path.
        return np.float64(np.nan)

    return np.array(all_vendi).mean()

In [None]:
# Keep only the label dict of every annotation; the image names are not needed for scoring.
samples = [annotation['label'] for annotation in all_annotations]

In [None]:
def similarity_function(a, b):
    """Weighted match between two (continent, country, artifact) sample tuples.

    Hierarchical similarity function described in Section 5 of the paper:
    https://arxiv.org/abs/2407.06863
    With the weights below only the first field (continent) contributes.
    """
    continent_weight, country_weight, artifact_weight = 1, 0, 0
    continent_match = int(a[0] == b[0])
    country_match = int(a[1] == b[1])
    artifact_match = int(a[2] == b[2])
    return (continent_weight * continent_match
            + country_weight * country_match
            + artifact_weight * artifact_match)

In [None]:
# Score the tagged images as a global prompt; the value is shown as the cell output.
calculate_cultural_diversity(
    samples,
    similarity_function,
    _global=True,
)