In [None]:
import torch
import clip
from PIL import Image
import os
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from torchmetrics.functional.multimodal import clip_score
from functools import partial


In [None]:
# Stable Diffusion 2.1 (base) from a local checkpoint, driven by the Euler
# discrete scheduler, with fp16 weights placed on the GPU.
model_id = "/data/model/stable-diffusion-2-1-base"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")

In [None]:
# OpenAI CLIP ViT-B/32 plus its matching image preprocessing transform, used by
# get_clip_score below. Load it directly onto the target device rather than
# loading and then unconditionally moving it to "cuda" (which fails on a
# CPU-only machine and disagrees with get_clip_score's CPU fallback).
clip_device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load('ViT-B/32', device=clip_device)

In [None]:
def get_clip_score(image, text):
    """Return the CLIP cosine similarity between one image and one text.

    Uses the module-level ``clip_model`` / ``clip_preprocess`` obtained from
    ``clip.load``. The value returned is the raw cosine similarity of the
    L2-normalized embeddings. It is NOT multiplied by 100, so it is on a
    different scale from torchmetrics' ``clip_score`` used elsewhere in this
    notebook.

    Args:
        image: a PIL image (e.g. a pipeline output or an ``Image.open`` result).
        text: the caption to compare against. Captions longer than CLIP's
            context length are truncated instead of raising.

    Returns:
        float: the cosine similarity.
    """
    # Send the inputs to the device the model actually lives on, rather than
    # guessing it independently of where clip_model was loaded.
    device = next(clip_model.parameters()).device

    # Preprocess the image (adds a batch dim of 1) and tokenize the text.
    image_input = clip_preprocess(image).unsqueeze(0).to(device)
    # truncate=True: without it clip.tokenize raises a RuntimeError for texts
    # that exceed the 77-token context (long captions / file names).
    text_input = clip.tokenize([text], truncate=True).to(device)

    with torch.no_grad():
        image_features = clip_model.encode_image(image_input)
        text_features = clip_model.encode_text(text_input)

    # L2-normalize so the dot product below is a cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # (1, D) @ (D, 1) -> (1, 1); .item() turns it into a Python float.
    # Named `score` so it does not shadow the imported torchmetrics `clip_score`.
    score = torch.matmul(image_features, text_features.T).item()
    return score

In [None]:

# torchmetrics' functional CLIP score bound to a Hugging Face CLIP checkpoint.
# Its result is scaled by 100, unlike get_clip_score's raw cosine similarity.
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")


def calculate_clip_score(images, prompts):
    """Compute the torchmetrics CLIP score for generated images and their prompts.

    Args:
        images: float array shaped (N, H, W, C), or a single (H, W, C) image,
            with values expected in [0, 1] (as returned by the diffusers
            pipeline with ``output_type="np"`` in this notebook).
        prompts: the prompt(s) matching ``images``, in the form accepted by
            torchmetrics' ``clip_score`` (a list of N strings, or one string).

    Returns:
        float: the score rounded to 4 decimal places.
    """
    images = np.asarray(images)
    if images.ndim == 3:
        # Accept a single HWC image by adding a batch dimension.
        images = images[None]
    # Clip before casting: values slightly outside [0, 1] would otherwise wrap
    # around in uint8 (e.g. a small negative value becoming ~255).
    images_int = (np.clip(images, 0.0, 1.0) * 255).astype("uint8")
    # (N, H, W, C) -> (N, C, H, W), the channels-first layout torchmetrics takes.
    images_tensor = torch.from_numpy(images_int).permute(0, 3, 1, 2)
    score = clip_score_fn(images_tensor, prompts).detach()
    return round(float(score), 4)

In [None]:
# Batch evaluation with torchmetrics' CLIP score (scale: 100 * cosine similarity).
prompts = ["a photo of an astronaut riding a horse on mars"]
# Fixed seed so the generated images, and therefore the score, are reproducible.
generator = torch.Generator(device="cuda").manual_seed(0)
images = pipe(
    prompts,
    num_inference_steps=25,
    output_type="np",
    generator=generator,
).images

sd_clip_score = calculate_clip_score(images, prompts)
print(f"CLIP score: {sd_clip_score}")
# NOTE(review): a previously recorded value was "CLIP score: 35.7038" from an
# unseeded run (provenance unclear); expect a different number under this seed.

In [None]:
from IPython.display import display

# Single-image sanity check with the raw-cosine get_clip_score.
prompt = "a photo of an astronaut riding a horse on mars"
# Fixed seed so the image and its score are reproducible across runs.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]

# A bare `image` that is not the last line of a cell is never rendered by
# Jupyter, so display it explicitly before printing the score.
display(image)
print(get_clip_score(image, prompt))

In [None]:
# Score a folder of images against the caption encoded in each file name and
# report the average raw CLIP cosine score (get_clip_score).
# NOTE(review): assumes each file is named "<caption>.<ext>" — confirm.
folder_path = "/home/liutao/workspace/distill/swift_photo_with_text"
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

image_list = []       # PIL images, fully loaded into memory
image_name_list = []  # corresponding file names (including the extension)

# sorted() makes the processing order deterministic across filesystems.
for filename in sorted(os.listdir(folder_path)):
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue
    image_path = os.path.join(folder_path, filename)
    # Image.open is lazy and keeps the file handle open; copy() forces the
    # pixels to load so the handle can be closed here (avoids "too many open
    # files" on large folders).
    with Image.open(image_path) as img:
        image_list.append(img.copy())
    image_name_list.append(filename)

if not image_list:
    raise ValueError(f"No images with extensions {IMAGE_EXTENSIONS} found in {folder_path}")

total_score = 0.0
for image, filename in zip(image_list, image_name_list):
    # Strip the extension: otherwise ".jpg"/".png" becomes part of the text
    # that CLIP compares against the image.
    caption = os.path.splitext(filename)[0]
    total_score += get_clip_score(image, caption)

avg_score = total_score / len(image_list)
print(f"AVG CLIP Score: {avg_score}")


In [None]:
# Generate one image per validation caption and report the running and final
# average raw CLIP cosine score (get_clip_score).
CAPTIONS_PATH = '/data/20231212/SwiftBrush_reproduce_final20231227/val2014_captions.npz'
MAX_CAPTIONS = 30000  # evaluate at most this many captions
LOG_EVERY = 10        # print a running average every LOG_EVERY captions

# The context manager closes the .npz archive even if reading fails.
with np.load(CAPTIONS_PATH) as data:
    # [()] returns the stored array as-is (and unwraps it if it is 0-d).
    captions = data['captions'][()]

# One seeded generator for the whole loop keeps the sequence of generated
# images, and therefore the scores, reproducible.
generator = torch.Generator(device="cuda").manual_seed(0)

count = 0
total_score = 0.0
for caption in captions:
    if count >= MAX_CAPTIONS:
        break
    image = pipe(caption, num_inference_steps=25, generator=generator).images[0]
    total_score += get_clip_score(image, caption)
    count += 1
    if count % LOG_EVERY == 0:
        print("current num:", count, "current avg clip score:", total_score / count)

if count == 0:
    raise ValueError(f"No captions found in {CAPTIONS_PATH}")
avg_score = total_score / count
print(f"AVG CLIP Score: {avg_score}")
