In [None]:
import os
import json
import numpy as np
from IPython.display import display
from tqdm.auto import tqdm, trange
from torchvision.datasets import CocoCaptions
from torchvision.transforms import Compose, CenterCrop, Resize
from datasets import load_dataset, load_from_disk, Dataset
from open_clip import get_tokenizer
from metrics import (
    load_perplexity_model_and_tokenizer,
    compute_prompt_perplexity,
    load_aesthetics_and_artifacts_models,
    compute_aesthetics_and_artifacts_scores,
)

# Tokenizer and scoring models used to filter and rank the dataset
clip_tokenizer = get_tokenizer("ViT-g-14")
ppl_models = load_perplexity_model_and_tokenizer()
aa_models = load_aesthetics_and_artifacts_models()

# Raw MSCOCO val2017 sources, and the converted HuggingFace copy built from them.
MSCOCO_IMAGE_ROOT = "/fs/nexus-projects/HuangWM/datasets/source/MSCOCO/val2017/"
MSCOCO_CAPTIONS_FILE = (
    "/fs/nexus-projects/HuangWM/datasets/source/MSCOCO/annotations/captions_val2017.json"
)
MSCOCO_5K_DIR = "/fs/nexus-projects/HuangWM/datasets/source/mscoco_5k"


def _build_mscoco_5k(image_root, captions_file):
    """Convert MSCOCO captions into a HuggingFace ``Dataset`` of 512x512 images.

    Each example has an ``image`` (center-cropped to a square whose side is the
    shorter image edge, then resized to 512x512) and a ``prompt`` (only the
    first caption of that image).

    Args:
        image_root: directory holding the MSCOCO images.
        captions_file: path to the MSCOCO captions annotation JSON.

    Returns:
        A ``datasets.Dataset`` with ``image`` and ``prompt`` columns.
    """
    coco_dataset = CocoCaptions(root=image_root, annFile=captions_file)

    def gen():
        for i in trange(len(coco_dataset)):
            # Index once: every coco_dataset[i] lookup reloads the image from disk.
            image, captions = coco_dataset[i]
            # Crop size depends on this image's dimensions, so build per image.
            transform = Compose([CenterCrop(min(image.size)), Resize((512, 512))])
            # Only use the first caption
            yield {"image": transform(image), "prompt": captions[0]}

    return Dataset.from_generator(gen)


# Convert once and cache on disk; later runs (and Restart & Run All) just reload it.
if os.path.isdir(MSCOCO_5K_DIR):
    dataset = load_from_disk(MSCOCO_5K_DIR)
else:
    dataset = _build_mscoco_5k(MSCOCO_IMAGE_ROOT, MSCOCO_CAPTIONS_FILE)
    dataset.save_to_disk(MSCOCO_5K_DIR)

In [None]:
# Drop score columns left over from a previous run of this cell: add_column
# raises ValueError on duplicate column names, so a re-run would otherwise fail.
stale_columns = [
    name
    for name in ("ppl", "aesthetic", "artifact", "key")
    if name in dataset.column_names
]
if stale_columns:
    dataset = dataset.remove_columns(stale_columns)

# Filter and normalize prompts according to the CLIP tokenizer.
# The 75-token cap presumably leaves room for CLIP's start/end tokens within a
# 77-token context. NOTE(review): confirm against the tokenizer's context length.
# For MSCOCO, this does not filter out any data.
dataset = dataset.filter(
    lambda data: 0 < len(clip_tokenizer.encode(data["prompt"])) <= 75
)
# Store each prompt in its tokenizer round-tripped form, decode(encode(prompt)).
normalized_prompts = [
    clip_tokenizer.decode(clip_tokenizer.encode(prompt))
    for prompt in tqdm(dataset["prompt"], desc="normalize prompts")
]
# Replace the original column; "prompt" ends up last, as with add/remove/rename.
dataset = dataset.remove_columns(["prompt"]).add_column("prompt", normalized_prompts)

# Add perplexity, aesthetic, and artifact scores
ppls = [
    compute_prompt_perplexity(prompt, ppl_models)
    for prompt in tqdm(dataset["prompt"], desc="perplexity")
]
aesthetics = []
artifacts = []
# Iterate row by row so images are decoded one at a time, instead of
# materializing every decoded image at once via dataset["image"].
for example in tqdm(dataset, desc="aesthetics/artifacts"):
    aesthetic, artifact = compute_aesthetics_and_artifacts_scores(
        example["image"], aa_models
    )
    aesthetics.append(aesthetic)
    artifacts.append(artifact)
dataset = dataset.add_column("ppl", ppls)
dataset = dataset.add_column("aesthetic", aesthetics)
dataset = dataset.add_column("artifact", artifacts)

# Calculate the score for ranking.
# Currently, the score is defined as: aesthetic + (10 - artifact).
# Perplexity is not included in the score because it is not a good indicator.
# For MSCOCO, ranking does not change which images are used, since all 5k images
# are used; it only changes their order.
key = np.array(dataset["aesthetic"]) + (10 - np.array(dataset["artifact"]))
dataset_appended = dataset.add_column("key", key)
dataset = dataset_appended.sort("key", reverse=True)

In [None]:
# Output the top-ranked images as PNGs and their prompts as JSON
# (keys are the row indices as strings, matching the PNG file names).
output_dir = "/fs/nexus-projects/HuangWM/datasets/main/mscoco"
image_dir = os.path.join(output_dir, "real")
# image.save does not create missing directories.
os.makedirs(image_dir, exist_ok=True)

# Cap at the dataset size so a smaller (filtered) dataset can't raise IndexError.
selected_size = min(5000, len(dataset))
prompt_dict = {}
for i in trange(selected_size):
    # Fetch the row once; every dataset[i] access decodes the image again.
    row = dataset[i]
    row["image"].save(os.path.join(image_dir, f"{i}.png"))
    prompt_dict[str(i)] = row["prompt"]
# Explicit UTF-8: ensure_ascii=False writes raw non-ASCII characters, which would
# otherwise depend on the platform's default encoding.
with open(
    os.path.join(output_dir, "prompts.json"),
    "w",
    encoding="utf-8",
) as json_file:
    json.dump(prompt_dict, json_file, ensure_ascii=False, indent=4)


# Save the ranked dataset to disk
dataset.save_to_disk("/fs/nexus-projects/HuangWM/datasets/source/mscoco_5k_ranked")