In [None]:
import os
import json
import numpy as np
from IPython.display import display
from tqdm.auto import tqdm, trange
from datasets import load_dataset, load_from_disk
from open_clip import get_tokenizer
from metrics import (
    load_perplexity_model_and_tokenizer,
    compute_prompt_perplexity,
    load_aesthetics_and_artifacts_models,
    compute_aesthetics_and_artifacts_scores,
)

# Utilities: the CLIP tokenizer used below to filter/normalize prompts, and the
# model bundles passed to compute_prompt_perplexity and
# compute_aesthetics_and_artifacts_scores when scoring the dataset.
clip_tokenizer, ppl_models, aa_models = (
    get_tokenizer("ViT-g-14"),
    load_perplexity_model_and_tokenizer(),
    load_aesthetics_and_artifacts_models(),
)

# Download the "2m_random_100k" configuration of DiffusionDB, keep only its
# "train" split, and store an unfiltered copy on disk before any processing.
dataset = load_dataset("poloclub/diffusiondb", "2m_random_100k")["train"]
dataset.save_to_disk(
    "/fs/nexus-scratch/mcding/Watermark-Experiments/datasets/source/diffusiondb/2m_random_100k"
)

# Filter generation setup and NSFW scores.
def _is_standard_sfw_example(example):
    """Return True for rows generated at 512x512 with 50 k_lms steps and CFG 7,
    whose image NSFW score is < 0.2 and prompt NSFW score is < 0.1."""
    # Resolution check
    if not (example["width"] == 512 and example["height"] == 512):
        return False
    # Diffusion hyperparameter check
    if not (example["step"] == 50 and example["cfg"] == 7 and example["sampler"] == "k_lms"):
        return False
    # NSFW thresholds for both the image and the prompt
    return example["image_nsfw"] < 0.2 and example["prompt_nsfw"] < 0.1


dataset = dataset.filter(_is_standard_sfw_example)

# Drop generation metadata and NSFW scores; none of these columns are read by
# the later steps of this cell.
dataset = dataset.remove_columns(
    "seed step cfg sampler width height user_name timestamp image_nsfw prompt_nsfw".split()
)

# Filter and normalize prompts according to the CLIP tokenizer.
# Each prompt is tokenized exactly once (the previous version encoded every
# prompt twice: once inside `filter` and again for normalization). Prompts whose
# token count falls outside (0, 75] are dropped; survivors are replaced by
# decode(encode(prompt)), so later steps (deduplication, export) operate on the
# tokenizer's round-tripped text.
# NOTE(review): 75 presumably leaves room for CLIP's start/end tokens in a
# 77-token context — TODO confirm.
kept_indices = []
normalized_prompts = []
for index, prompt in enumerate(tqdm(dataset["prompt"])):
    tokens = clip_tokenizer.encode(prompt)
    if 0 < len(tokens) <= 75:
        kept_indices.append(index)
        normalized_prompts.append(clip_tokenizer.decode(tokens))
dataset = dataset.select(kept_indices)
dataset = dataset.add_column("normalized_prompt", normalized_prompts)
dataset = dataset.remove_columns(["prompt"])
dataset = dataset.rename_column("normalized_prompt", "prompt")

# Remove duplicated prompts, keeping the first row in which each prompt occurs.
unique_prompts = dataset.unique("prompt")
prompt_to_index = {}
all_prompts = dataset["prompt"]
for i in trange(len(all_prompts)):
    # setdefault only records the index the first time a prompt is seen
    prompt_to_index.setdefault(all_prompts[i], i)
dataset = dataset.select(list(prompt_to_index.values()))
# Sanity check: exactly one row per distinct prompt survives.
assert len(dataset) == len(unique_prompts)


# Add perplexity, aesthetic, and artifact scores.
# Prompts are small, so the whole column is loaded at once. Images are visited
# row by row: `dataset["image"]` would materialize every decoded image as a
# list in memory simultaneously (tens of GB at 512x512 for tens of thousands of
# rows), whereas iterating the dataset decodes one row at a time.
ppls = [compute_prompt_perplexity(prompt, ppl_models) for prompt in tqdm(dataset["prompt"])]
aesthetics = []
artifacts = []
for example in tqdm(dataset):
    aesthetic, artifact = compute_aesthetics_and_artifacts_scores(example["image"], aa_models)
    aesthetics.append(aesthetic)
    artifacts.append(artifact)
dataset = dataset.add_column("ppl", ppls)
dataset = dataset.add_column("aesthetic", aesthetics)
dataset = dataset.add_column("artifact", artifacts)


# Calculate the score used for ranking.
# Currently the score is aesthetic + (10 - artifact): a higher aesthetic score
# and a lower artifact score both push an image up the ranking. Perplexity is
# left out because it is not a good indicator.
aesthetic_scores = np.asarray(dataset["aesthetic"])
artifact_scores = np.asarray(dataset["artifact"])
key = aesthetic_scores + (10 - artifact_scores)
dataset_appended = dataset.add_column("key", key)
# Best-scoring rows first
dataset = dataset_appended.sort("key", reverse=True)


# Output the top-ranked images as PNG files and their prompts as a JSON map
# from the string index to the prompt.
selected_size = 5000
image_dir = "/fs/nexus-projects/HuangWM/datasets/main/diffusiondb/real"
prompts_path = "/fs/nexus-projects/HuangWM/datasets/main/diffusiondb/prompts.json"

# Don't crash with IndexError halfway through if fewer rows are available.
num_exported = min(selected_size, len(dataset))
if num_exported < selected_size:
    print(
        f"Warning: only {len(dataset)} rows available; "
        f"exporting {num_exported} instead of {selected_size}."
    )

# image_dir sits inside the directory that receives prompts.json, so this
# creates both locations.
os.makedirs(image_dir, exist_ok=True)

prompt_dict = {}
for i in trange(num_exported):
    row = dataset[i]  # decode the row (including the image) only once
    row["image"].save(os.path.join(image_dir, f"{i}.png"))
    prompt_dict[str(i)] = row["prompt"]

# ensure_ascii=False writes raw non-ASCII characters, so pin the file encoding
# instead of relying on the platform's default locale encoding.
with open(prompts_path, "w", encoding="utf-8") as json_file:
    json.dump(prompt_dict, json_file, ensure_ascii=False, indent=4)


# Save the ranked dataset under $DATASET_DIR.
# Indexing os.environ (instead of .get) fails fast with a KeyError naming the
# missing variable, rather than a confusing TypeError from os.path.join(None, ...).
dataset.save_to_disk(
    os.path.join(os.environ["DATASET_DIR"], "./selected/diffusiondb_from_100k")
)

In [5]:
import os
import numpy as np
from datasets import load_from_disk

# Reload a previously saved intermediate dataset from $DATASET_DIR and rank it
# with the same score used in the main pipeline: aesthetic + (10 - artifact).
# os.environ[...] raises a KeyError naming DATASET_DIR if it is unset, instead
# of a TypeError from os.path.join(None, ...).
dataset = load_from_disk(
    os.path.join(os.environ["DATASET_DIR"], "./selected/diffusiondb_temp4")
)
# This intermediate copy stores the artifact score as "artifacts". Rename only
# when needed, so the cell does not raise on data already using "artifact".
if "artifacts" in dataset.column_names:
    dataset = dataset.rename_column("artifacts", "artifact")
key = np.array(dataset["aesthetic"]) + (10 - np.array(dataset["artifact"]))
dataset_appended = dataset.add_column("key", key)
dataset = dataset_appended.sort("key", reverse=True)
print(dataset)
print(dataset[0])

Dataset({
    features: ['image', 'prompt', 'ppl', 'aesthetic', 'artifact', 'key'],
    num_rows: 28875
})
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x7F612203FF10>, 'prompt': 'digital painting of a lake at sunset suronded by forests and mountains , great lakes , fjords , sun reflecting on the water , open sky , clouds , fantasy art , concept art , video game art , pastel colours , volumetric lighting , highly detailed , artem cheboka , rhads , artstation , 4 k , 8 k ', 'ppl': 85.329345703125, 'aesthetic': 10.430047035217285, 'artifact': 0.3915136754512787, 'key': 20.038533359766006}


  table = cls._concat_blocks(blocks, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
