In [1]:
import torch
import json
import clip
from PIL import Image
import os
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from torchmetrics.functional.multimodal import clip_score
from functools import partial
from pipeline_rf import RectifiedFlowPipeline
import random
from diffusers import AutoPipelineForText2Image
import json

# Seed Python's `random` module; the caption-sampling cells re-seed with the same value.
random.seed(2024)
# Device string handed to the model / pipeline `.to(...)` and `load_model(...)` calls below.
device = "cuda:0"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import hashlib

def text_to_hash(text):
    """Return the SHA-256 digest of ``text`` as a hexadecimal string.

    The text is encoded as UTF-8 before hashing.
    """
    encoded = text.encode('utf-8')
    return hashlib.sha256(encoded).hexdigest()

### load model

In [None]:
# Show the model names the installed `clip` package reports as available.
clip.available_models()

In [3]:
# Load the 'ViT-L/14@336px' CLIP checkpoint; `clip.load` returns the model together
# with the image preprocessing callable used by `get_clip_score`.
clip_model, clip_preprocess = clip.load('ViT-L/14@336px')
# Place the model on the notebook-wide `device`.
clip_model = clip_model.to(device)

In [None]:
# Stable Diffusion 2.1 from a local checkpoint: fp16 weights, Euler discrete scheduler.
model_id = "/data/model/stable-diffusion-2-1"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to(device)
# Turn off the pipeline's progress bar.
sd_pipe.set_progress_bar_config(disable=True)

In [None]:
# InstaFlow pipeline from a local checkpoint, with the safety checker disabled.
instaflow_pipe = RectifiedFlowPipeline.from_pretrained("/data/model/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float32, safety_checker=None, requires_safety_checker=False) 
### switch to torch.float32 for higher quality

instaflow_pipe.to(device)  ### if GPU is not available, comment this line
# Turn off the pipeline's progress bar.
instaflow_pipe.set_progress_bar_config(disable=True)

In [None]:
# Smoke test: one single-step InstaFlow generation for the literal prompt "caption".
image = instaflow_pipe(prompt="caption", num_inference_steps=1, guidance_scale=0.0).images[0]

In [2]:
# SDXL-Turbo from a local checkpoint, loading the fp16 weight variant.
sdxl_turbo_pipe = AutoPipelineForText2Image.from_pretrained("/data/model/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
sdxl_turbo_pipe.to(device)
# Turn off the pipeline's progress bar.
sdxl_turbo_pipe.set_progress_bar_config(disable=True)

Loading pipeline components...: 100%|██████████| 7/7 [06:04<00:00, 52.09s/it] 


In [3]:
# LCM (Dreamshaper v7) pipeline from a local checkpoint, with the safety checker disabled.
lcm_pipe = DiffusionPipeline.from_pretrained("/data/model/LCM_Dreamshaper_v7", safety_checker=None, requires_safety_checker=False)
# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
# `device=` / `dtype=` replace the deprecated `torch_device=` / `torch_dtype=` keywords,
# which emit deprecation warnings scheduled for removal in diffusers 0.27.0.
lcm_pipe.to(device=device, dtype=torch.float32)
# Turn off the pipeline's progress bar.
lcm_pipe.set_progress_bar_config(disable=True)

Loading pipeline components...: 100%|██████████| 6/6 [00:03<00:00,  1.71it/s]
  deprecate("torch_dtype", "0.27.0", "")
  deprecate("torch_device", "0.27.0", "")


In [None]:
# Smoke test: one single-step 512x512 LCM generation for the literal prompt "caption".
image = lcm_pipe(prompt="caption", height=512, width=512, num_inference_steps=1, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images[0]
# Bare last expression so the notebook displays the generated image.
image

## clip score

In [4]:
def get_clip_score(image, text):
    """Return the CLIP cosine similarity between an image and a text prompt.

    Uses the notebook-level ``clip_model`` / ``clip_preprocess`` pair.

    Args:
        image: Image accepted by ``clip_preprocess`` (the callers in this
            notebook pass PIL images).
        text: Prompt string; it is tokenized with ``truncate=True`` so
            over-long prompts are cut instead of raising.

    Returns:
        The cosine similarity of the L2-normalized image and text embeddings,
        as a Python float.
    """
    # Follow the device the CLIP weights actually live on instead of re-deriving
    # "cuda"/"cpu" locally, which could disagree with the notebook's `device`.
    model_device = next(clip_model.parameters()).device

    # Preprocess the image and tokenize the text, each as a batch of one
    image_input = clip_preprocess(image).unsqueeze(0).to(model_device)
    text_input = clip.tokenize([text], truncate=True).to(model_device)

    # Inference only: no gradients are needed for scoring
    with torch.no_grad():
        image_features = clip_model.encode_image(image_input)
        text_features = clip_model.encode_text(text_input)

        # Normalize so the dot product below is a cosine similarity
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Local name avoids shadowing the `clip_score` imported from torchmetrics
        similarity = torch.matmul(image_features, text_features.T).item()

    return similarity

### image with prompt test

In [None]:
# Path to the folder containing your images
folder_path = "/home/liutao/workspace/data/ours_coco30k_test"

total_score = 0
count = 0

# Loop through each file in the folder
for filename in os.listdir(folder_path):
    # Skip anything that does not carry a known image extension
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
        continue
    image_path = os.path.join(folder_path, filename)
    # The file name without its extension is used as the prompt for the image
    text = os.path.splitext(filename)[0]
    # Context manager closes each file handle; otherwise handles pile up over the folder
    with Image.open(image_path) as image:
        score = get_clip_score(image, text)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:", count,f"current AVG CLIP Score: {total_score/count}")

# Guard the average: an empty folder would otherwise raise ZeroDivisionError
if count:
    print(f"AVG CLIP Score: {total_score/count}")
else:
    print(f"No images found in {folder_path}")


In [None]:
# Load the .npz file
# data = np.load('/data/20231212/SwiftBrush_reproduce_final20231227/val2014_captions.npz')
# captions = data['captions'][()]
# print(len(captions),captions[0],captions[1])
# data.close()

### load coco30k_caption

In [5]:
# Read the COCO 2014 validation caption annotations; the context manager closes
# the file even if parsing fails.
with open('/data/dataset/coco2014-val/annotations/captions_val2014.json') as coco_f:
    coco_annotations = json.load(coco_f)
# Collect every caption string, in annotation order.
captions = [annotation['caption'] for annotation in coco_annotations['annotations']]
print(len(captions),captions[0],captions[1])

202654 A bicycle replica with a clock as the front wheel. A black Honda motorcycle parked in front of a garage.


In [6]:
# Re-seed so this cell draws the same captions every time it is run.
random.seed(2024)
# NOTE(review): random.choices draws WITH replacement, so the 30k list may contain
# repeated captions; random.sample would give 30k distinct ones — confirm which is intended.
captions_30k = random.choices(captions, k=30000)
print(len(captions_30k),captions_30k[0],captions_30k[1])

30000 a little green cart filled with assorted suitcases  A woman in an odd outfit on a bed


### load valid_caption

In [None]:
# Specify the path to your JSONL file
jsonl_file_path = '/data/20231212/SwiftBrush_reproduce_se_parallel/JourneyDB/valid/valid_prompt.jsonl'
# Initialize an empty list to store the prompts
prompts_list = []
with open(jsonl_file_path) as f:
    raw_text = f.read()
try:
    # First try the file as one JSON document and collect its items
    prompts_list.extend(json.loads(raw_text))
except json.JSONDecodeError:
    # The path ends in .jsonl: if the file really is JSON Lines (one document per
    # line), whole-file parsing fails, so decode each non-empty line separately.
    prompts_list.extend(json.loads(line) for line in raw_text.splitlines() if line.strip())
print(len(prompts_list),prompts_list[0])

In [None]:
# Re-seed so this cell draws the same prompts every time it is run.
random.seed(2024)
# Rebinds `captions_30k`, this time sampling from `prompts_list`.
# NOTE(review): random.choices draws WITH replacement, so repeated prompts are possible.
captions_30k = random.choices(prompts_list, k=30000)
print(len(captions_30k),captions_30k[0],captions_30k[1])

### ours

In [7]:
from generate import generate_single_image, load_model
vae, tokenizer, text_encoder, unet, scheduler, alphas = load_model("/data/", "/data/20231212/SwiftBrush_reproduce_se_parallel/checkpoints/vsd_global_step54000-coco.pth", device)
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/ours_4_coco30k/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    image = generate_single_image(network=(vae, tokenizer, text_encoder, unet, scheduler),prompt=caption,seed=2024)
    score = get_clip_score(image, caption)
    try:
        # Files are named by running index (0-based)
        image.save(output_dir+str(count)+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

[INFO] loading student unet checkpoint
current num: 1000 current avg clip score: 0.2511959228515625
current num: 2000 current avg clip score: 0.251618408203125
current num: 3000 current avg clip score: 0.2516767578125
current num: 4000 current avg clip score: 0.25137857055664065
current num: 5000 current avg clip score: 0.251731982421875
current num: 6000 current avg clip score: 0.2516671346028646
current num: 7000 current avg clip score: 0.25156898716517856
current num: 8000 current avg clip score: 0.2519543151855469
current num: 9000 current avg clip score: 0.2520390353732639
current num: 10000 current avg clip score: 0.25204867553710936
current num: 11000 current avg clip score: 0.2521516945578835
current num: 12000 current avg clip score: 0.25220989481608075
current num: 13000 current avg clip score: 0.2521011915940505
current num: 14000 current avg clip score: 0.2520157950265067
current num: 15000 current avg clip score: 0.2520740844726562
current num: 16000 current avg clip score

In [8]:
from generate import generate_single_image, load_model
vae, tokenizer, text_encoder, unet, scheduler, alphas = load_model("/data/", "/data/20231212/SwiftBrush_reproduce_se_parallel/checkpoints/vsd_global_step54000-mixcoco.pth", device)
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/ours_5_coco30k/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    image = generate_single_image(network=(vae, tokenizer, text_encoder, unet, scheduler),prompt=caption,seed=2024)
    score = get_clip_score(image, caption)
    try:
        # Files are named by running index (0-based)
        image.save(output_dir+str(count)+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

[INFO] loading student unet checkpoint


current num: 1000 current avg clip score: 0.248734375
current num: 2000 current avg clip score: 0.2490621337890625
current num: 3000 current avg clip score: 0.2491714884440104
current num: 4000 current avg clip score: 0.24889283752441407
current num: 5000 current avg clip score: 0.24929661865234376
current num: 6000 current avg clip score: 0.24908493041992188
current num: 7000 current avg clip score: 0.24897997174944198
current num: 8000 current avg clip score: 0.24921016693115233
current num: 9000 current avg clip score: 0.24911797417534723
current num: 10000 current avg clip score: 0.24897572021484374
current num: 11000 current avg clip score: 0.2489399746981534
current num: 12000 current avg clip score: 0.24896912638346355
current num: 13000 current avg clip score: 0.24885927170973557
current num: 14000 current avg clip score: 0.24883101981026787
current num: 15000 current avg clip score: 0.2488192626953125
current num: 16000 current avg clip score: 0.2488178024291992
current num: 1

### instaflow

In [None]:
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/instaflow_coco30k/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    image = instaflow_pipe(prompt=caption, num_inference_steps=1, guidance_scale=0.0).images[0]
    score = get_clip_score(image, caption)
    try:
        # The caption itself is the file name, so captions that are not valid
        # file names (e.g. containing "/") cannot be saved and are reported below.
        image.save(output_dir+caption+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

### sdxl_turbo

In [None]:
# Recorded results from earlier runs:
#smallset_test >>> instaflow:0.26 sd_1_step:0.138 sd_25_step:0.22
#instaflow coco30k clip_score: 0.2580452107747396
#sdxl_turbo_4_step coco30K clip_score: 0.27137984619140626
#sdxl_turbo_1_step coco30K clip_score: 0.2724981628417969
#lcm coco30k_4_step clip_score:
#lcm coco30k_1_step clip_score:
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/sdxl_turbo_4_step_coco30K/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    image = sdxl_turbo_pipe(prompt=caption, num_inference_steps=4, guidance_scale=0.0).images[0]
    score = get_clip_score(image, caption)
    try:
        # The caption itself is the file name, so captions that are not valid
        # file names (e.g. containing "/") cannot be saved and are reported below.
        image.save(output_dir+caption+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

In [None]:
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/sdxl_turbo_1_step_coco30k/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    image = sdxl_turbo_pipe(prompt=caption, num_inference_steps=1, guidance_scale=0.0).images[0]
    score = get_clip_score(image, caption)
    try:
        # The caption itself is the file name, so captions that are not valid
        # file names (e.g. containing "/") cannot be saved and are reported below.
        image.save(output_dir+caption+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

### lcm

In [None]:
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/prompt_test_lcm4/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    # NOTE(review): `truncation=True` is passed only in this cell; confirm the LCM
    # pipeline's __call__ accepts that keyword.
    image = lcm_pipe(prompt=caption, height=512, width=512, num_inference_steps=4, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil", truncation=True).images[0]
    score = get_clip_score(image, caption)
    try:
        # File name is the SHA-256 hex digest of the caption, so any prompt
        # maps to a valid fixed-length name.
        image.save(output_dir+text_to_hash(caption)+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    # This cell reports every 100 samples
    if count % 100 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

In [None]:
# Score single-step LCM generations for every sampled caption (images are not saved).
total_score = 0
count = 0
for case_number, caption in enumerate(captions_30k):
    result = lcm_pipe(
        prompt=caption,
        num_inference_steps=1,
        guidance_scale=8.0,
        lcm_origin_steps=50,
        output_type="pil",
    )
    image = result.images[0]
    score = get_clip_score(image, caption)
    # Saving is disabled in this cell:
    # try:
    #     image.save("/home/liutao/workspace/data/lcm_1_step_coco30K/"+caption+".jpg")
    # except:
    #     print(caption)
    total_score += score
    count = case_number + 1
    # Report the running mean every 1000 samples
    if not count % 1000:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

In [None]:
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/lcm_4_step_512_coco30K/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    image = lcm_pipe(prompt=caption, height=512, width=512, num_inference_steps=4, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images[0]
    score = get_clip_score(image, caption)
    try:
        # The caption itself is the file name, so captions that are not valid
        # file names (e.g. containing "/") cannot be saved and are reported below.
        image.save(output_dir+caption+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

In [None]:
# Create the output folder up front; otherwise every save fails and is only reported per caption.
output_dir = "/home/liutao/workspace/data/lcm_1_step_512_coco30K/"
os.makedirs(output_dir, exist_ok=True)
count = 0
total_score = 0
for caption in captions_30k:
    image = lcm_pipe(prompt=caption, height=512, width=512, num_inference_steps=1, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images[0]
    score = get_clip_score(image, caption)
    try:
        # The caption itself is the file name, so captions that are not valid
        # file names (e.g. containing "/") cannot be saved and are reported below.
        image.save(output_dir+caption+".jpg")
    except (OSError, ValueError) as err:
        # Best effort: a failed save is reported but does not stop scoring.
        # Only save-related errors are caught, so Ctrl-C still interrupts the run.
        print(caption, err)
    total_score += score
    count += 1
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
print(f"AVG CLIP Score: {total_score/count}")

## hpsv2 score

In [4]:
import hpsv2

Failed to get repository contents: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/datasets/zhwang/HPDv2 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fc93d768f70>, 'Connection to huggingface.co timed out. (connect timeout=None)'))


In [3]:

from generate import generate_single_image, load_model
vae, tokenizer, text_encoder, unet, scheduler, alphas = load_model("/data/", "/data/20231212/SwiftBrush_reproduce_se_parallel/checkpoints/vsd_global_step54000-coco.pth", device)
# Get benchmark prompts (<style> = all, anime, concept-art, paintings, photo)
all_prompts = hpsv2.benchmark_prompts('all')
path = "/home/liutao/workspace/data/ours_4_hpsv2"
# Iterate over the benchmark prompts to generate images
for style, prompts in all_prompts.items():
    # Images go to <path>/<style>/<idx>.jpg; create the folder first because
    # image.save does not create missing directories.
    style_dir = os.path.join(path, style)
    os.makedirs(style_dir, exist_ok=True)
    for idx, prompt in enumerate(prompts):
        image = generate_single_image(network=(vae, tokenizer, text_encoder, unet, scheduler),prompt=prompt,seed=2024)
        image.save(os.path.join(style_dir, f"{idx:05d}.jpg"))
# Score the generated images with HPSv2
hpsv2.evaluate(path)

Failed to get repository contents: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/datasets/zhwang/HPDv2 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7f41deb4bcd0>, 'Connection to huggingface.co timed out. (connect timeout=None)'))
[INFO] loading student unet checkpoint
Loading model ...
Loading model successfully!




-----------benchmark score ---------------- 
ours_4_hpsv2 anime           25.76 	 0.1923
ours_4_hpsv2 concept-art     25.41 	 0.0901
ours_4_hpsv2 photo           26.27 	 0.2177
ours_4_hpsv2 paintings       25.42 	 0.1553


In [4]:
import hpsv2
from generate import generate_single_image, load_model
vae, tokenizer, text_encoder, unet, scheduler, alphas = load_model("/data/", "/data/20231212/SwiftBrush_reproduce_se_parallel/checkpoints/vsd_global_step54000-mixcoco.pth", device)
# Get benchmark prompts (<style> = all, anime, concept-art, paintings, photo)
all_prompts = hpsv2.benchmark_prompts('all')
path = "/home/liutao/workspace/data/ours_5_hpsv2"
# Iterate over the benchmark prompts to generate images
for style, prompts in all_prompts.items():
    # Images go to <path>/<style>/<idx>.jpg; create the folder first because
    # image.save does not create missing directories.
    style_dir = os.path.join(path, style)
    os.makedirs(style_dir, exist_ok=True)
    for idx, prompt in enumerate(prompts):
        image = generate_single_image(network=(vae, tokenizer, text_encoder, unet, scheduler),prompt=prompt,seed=2024)
        image.save(os.path.join(style_dir, f"{idx:05d}.jpg"))
# Score the generated images with HPSv2
hpsv2.evaluate(path)

[INFO] loading student unet checkpoint
Loading model ...
Loading model successfully!
-----------benchmark score ---------------- 
ours_5_hpsv2 anime           26.10 	 0.1443
ours_5_hpsv2 concept-art     25.77 	 0.1013
ours_5_hpsv2 photo           26.97 	 0.1935
ours_5_hpsv2 paintings       25.90 	 0.1088


### sdxl_4

In [5]:
# Get benchmark prompts (<style> = all, anime, concept-art, paintings, photo)
all_prompts = hpsv2.benchmark_prompts('all')
path = "/home/liutao/workspace/data/sdxl4_hpsv2"
# Iterate over the benchmark prompts to generate images
for style, prompts in all_prompts.items():
    # Images go to <path>/<style>/<idx>.jpg; create the folder first because
    # image.save does not create missing directories.
    style_dir = os.path.join(path, style)
    os.makedirs(style_dir, exist_ok=True)
    for idx, prompt in enumerate(prompts):
        image = sdxl_turbo_pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
        image.save(os.path.join(style_dir, f"{idx:05d}.jpg"))

# Score the generated images with HPSv2
hpsv2.evaluate(path)

Token indices sequence length is longer than the specified maximum sequence length for this model (80 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['yamamoto.']
Token indices sequence length is longer than the specified maximum sequence length for this model (80 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['yamamoto.']
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['tooth wu, wlop, beeple, and greg rutkowski.']
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['tooth wu, wlop, beeple, and greg rutkowski.']


Loading model ...
Loading model successfully!




-----------benchmark score ---------------- 
sdxl4_hpsv2 anime           28.65 	 0.1071
sdxl4_hpsv2 concept-art     27.84 	 0.0932
sdxl4_hpsv2 photo           27.88 	 0.2003
sdxl4_hpsv2 paintings       27.97 	 0.1341


### lcm4

In [6]:
# Get benchmark prompts (<style> = all, anime, concept-art, paintings, photo)
all_prompts = hpsv2.benchmark_prompts('all')
path = "/home/liutao/workspace/data/lcm4_hpsv2"
# Iterate over the benchmark prompts to generate images
for style, prompts in all_prompts.items():
    # Images go to <path>/<style>/<idx>.jpg; create the folder first because
    # image.save does not create missing directories.
    style_dir = os.path.join(path, style)
    os.makedirs(style_dir, exist_ok=True)
    for idx, prompt in enumerate(prompts):
        image = lcm_pipe(prompt=prompt, width=512, height=512, num_inference_steps=4, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images[0]
        image.save(os.path.join(style_dir, f"{idx:05d}.jpg"))

# Score the generated images with HPSv2
hpsv2.evaluate(path)

Token indices sequence length is longer than the specified maximum sequence length for this model (80 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['yamamoto.']
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: ['tooth wu, wlop, beeple, and greg rutkowski.']


Loading model ...
Loading model successfully!
-----------benchmark score ---------------- 
lcm4_hpsv2 anime           26.66 	 0.0932
lcm4_hpsv2 concept-art     26.22 	 0.0851
lcm4_hpsv2 photo           26.34 	 0.1967
lcm4_hpsv2 paintings       26.26 	 0.1468


## fid score