In [1]:
import json
import os
import random
from dataclasses import dataclass, field

import clip
import torch

from generate import generate_single_image, load_model


  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


### config

In [2]:
@dataclass
class TestConfig:
    """Settings shared by the CLIP-score and HPSv2 evaluation cells.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. Without annotations the decorator generates an
    ``__init__`` with no parameters, so values could not be overridden with
    ``TestConfig(device=...)`` and ``repr(config)`` showed nothing.
    """

    # Torch device string used for the CLIP model and handed to load_model.
    device: str = "cuda:1"
    # Forwarded as num_inference_steps to generate_single_image.
    ours_steps: int = 8
    # Sub-directories created under ours_hpsv2_path. A default_factory gives
    # each instance its own list instead of one shared class-level object.
    hpsv2_sub_dirs: list = field(
        default_factory=lambda: ["anime", "photo", "concept-art", "paintings"]
    )
    # Output directory for images generated from the HPSv2 benchmark prompts.
    ours_hpsv2_path: str = "/data/liutao/data/ours_44k_8s_kl_hpsv2"
    # Output directory for images generated from the sampled COCO captions.
    ours_coco_path: str = "/data/liutao/data/ours_44k_8s_kl_coco"
    # Model name passed to clip.load.
    clip_model_id: str = "ViT-L/14@336px"
    # Second argument to load_model (a .pth checkpoint path).
    ours_model_id: str = "/data/20231212/SwiftBrush_reproduce_se_parallel/checkpoints_klloss/vsd_global_step44000_8nis_kl.pth"
    # First argument to load_model.
    ours_base_path: str = "/data/"
    # COCO captions annotation file (JSON) read by the caption-loading cell.
    coco_caption_path: str = "/data/dataset/coco2014-val/annotations/captions_val2014.json"
    # Number of captions drawn for the CLIP-score run.
    caption_num: int = 30000
    # Seed for caption sampling.
    seed: int = 2024


config = TestConfig()

In [3]:
# Create the output directories up front. exist_ok=True keeps the cell safe to
# re-run and removes the check-then-create race of a separate os.path.exists
# test; it also raises if one of the paths exists but is a regular file.
os.makedirs(config.ours_coco_path, exist_ok=True)
os.makedirs(config.ours_hpsv2_path, exist_ok=True)
for sub_dir in config.hpsv2_sub_dirs:
    os.makedirs(os.path.join(config.ours_hpsv2_path, sub_dir), exist_ok=True)

### load model

In [4]:
torch.cuda.empty_cache()
# Load CLIP directly onto the configured device. Without the device argument
# clip.load places the model on its default device first, and the model then
# has to be moved to config.device in a second step.
clip_model, clip_preprocess = clip.load(config.clip_model_id, device=config.device)

In [5]:
# load_model (from the local generate module) returns six objects; the first
# five are bundled as the `network` tuple for generate_single_image below.
(
    vae,
    tokenizer,
    text_encoder,
    unet,
    scheduler,
    alphas,
) = load_model(
    config.ours_base_path,
    config.ours_model_id,
    config.device,
)

[INFO] loading student unet checkpoint


  torch.utils._pytree._register_pytree_node(


## clip score

In [6]:
def get_clip_score(image, text):
    """Return the CLIP cosine similarity between an image and a caption.

    Uses the notebook-level ``clip_model`` / ``clip_preprocess`` pair and
    ``config.device``; nothing is loaded inside this function.

    Args:
        image: Image accepted by ``clip_preprocess`` (presumably a PIL image,
            as returned by ``generate_single_image`` -- TODO confirm).
        text: Caption string; tokenized with ``truncate=True``.

    Returns:
        The dot product of the L2-normalized image and text embeddings,
        as a Python float.
    """
    device = config.device

    # Single-element batches for both modalities, placed on the model's device.
    pixels = clip_preprocess(image).unsqueeze(0).to(device)
    tokens = clip.tokenize([text], truncate=True).to(device)

    # Inference only: no autograd graph is needed for the encoders.
    with torch.no_grad():
        img_emb = clip_model.encode_image(pixels)
        txt_emb = clip_model.encode_text(tokens)

    # Scale each embedding to unit length so the dot product is a cosine.
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

    return torch.matmul(img_emb, txt_emb.T).item()

### load coco30k_caption

In [7]:
# Read the COCO annotation file. The context manager closes the handle even
# if json.load raises, and the encoding is pinned so the read does not depend
# on the platform's default locale.
with open(config.coco_caption_path, encoding="utf-8") as coco_f:
    coco_annotations = json.load(coco_f)
captions = [annotation['caption'] for annotation in coco_annotations['annotations']]

# Seed the global RNG so the sampled subset is reproducible across runs.
random.seed(config.seed)
# NOTE(review): random.choices samples WITH replacement, so captions_30k can
# contain duplicates. Kept as-is so results stay comparable with earlier runs;
# switch to random.sample if unique captions are intended.
captions_30k = random.choices(captions, k=config.caption_num)
print(len(captions_30k),captions[0],captions[1])

30000 A bicycle replica with a clock as the front wheel. A black Honda motorcycle parked in front of a garage.


In [8]:
# Generate one image per sampled caption, score it with CLIP, save it, and
# report the running average every 1000 images.
count = 0  # stays 0 when captions_30k is empty
total_score = 0
network = (vae, tokenizer, text_encoder, unet, scheduler)
for count, prompt in enumerate(captions_30k, start=1):
    # Seed taken from config.seed (2024) instead of a repeated literal.
    image = generate_single_image(network=network,prompt=prompt,seed=config.seed,num_inference_steps=config.ours_steps)
    score = get_clip_score(image, prompt)
    # File names stay 0-based ("0.jpg", "1.jpg", ...) as before.
    save_name = str(count - 1)+".jpg"
    image.save(os.path.join(config.ours_coco_path,save_name))
    total_score += score
    if count % 1000 == 0:
        print("current num:",count,"current avg clip score:",total_score/count)
if count:
    print(f"AVG CLIP Score: {total_score/count}")
else:
    # Avoid a ZeroDivisionError when there was nothing to evaluate.
    print("AVG CLIP Score: n/a (no captions to evaluate)")

current num: 1000 current avg clip score: 0.2618109130859375
current num: 2000 current avg clip score: 0.26182708740234373
current num: 3000 current avg clip score: 0.26194840494791666
current num: 4000 current avg clip score: 0.26146597290039064
current num: 5000 current avg clip score: 0.26180556640625
current num: 6000 current avg clip score: 0.26171663411458335
current num: 7000 current avg clip score: 0.2617603236607143
current num: 8000 current avg clip score: 0.26200048828125
current num: 9000 current avg clip score: 0.2617371554904514
current num: 10000 current avg clip score: 0.2616024291992187
current num: 11000 current avg clip score: 0.2615211181640625
current num: 12000 current avg clip score: 0.2614950968424479
current num: 13000 current avg clip score: 0.2614385235126202
current num: 14000 current avg clip score: 0.2613802228655134
current num: 15000 current avg clip score: 0.2614111490885417
current num: 16000 current avg clip score: 0.26129731750488283
current num: 170

## hpsv2 score

In [9]:
import hpsv2
# Generate one image per HPSv2 benchmark prompt, stored as
# <ours_hpsv2_path>/<style>/<idx:05d>.jpg.
all_prompts = hpsv2.benchmark_prompts('all') 
network = (vae, tokenizer, text_encoder, unet, scheduler)
for style, prompts in all_prompts.items():
    # Create the style directory here so saving does not depend on the style
    # names matching config.hpsv2_sub_dirs.
    style_dir = os.path.join(config.ours_hpsv2_path, style)
    os.makedirs(style_dir, exist_ok=True)
    for idx, prompt in enumerate(prompts):
        # Seed taken from config.seed (2024) instead of a repeated literal.
        image = generate_single_image(network=network,prompt=prompt,seed=config.seed,num_inference_steps=config.ours_steps)
        image.save(os.path.join(style_dir, f"{idx:05d}.jpg"))

Failed to get repository contents: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/datasets/zhwang/HPDv2 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fe1dd777070>, 'Connection to huggingface.co timed out. (connect timeout=None)'))


In [10]:
# Score the images saved under config.ours_hpsv2_path with HPSv2; the recorded
# output below shows one result row per style sub-directory.
hpsv2.evaluate(config.ours_hpsv2_path) 

Loading model ...
Loading model successfully!
-----------benchmark score ---------------- 
ours_44k_8s_kl_hpsv2 paintings       26.11 	 0.1047
ours_44k_8s_kl_hpsv2 anime           26.46 	 0.1735
ours_44k_8s_kl_hpsv2 photo           26.39 	 0.2102
ours_44k_8s_kl_hpsv2 concept-art     26.00 	 0.1058
