In [129]:
from datasets import load_dataset
import pandas as pd

# Fetch the Hub dataset (cached under ./data) and flatten its train split into a DataFrame.
# NOTE: the name `dataset` is rebound later in the notebook to the Argilla dataset.
dataset = load_dataset(path="DIBT/img_prefs_style", cache_dir="data")
train_split = dataset["train"]
df = train_split.to_pandas()

Generating train split: 100%|██████████| 5/5 [00:00<00:00, 502.30 examples/s]


In [131]:
# Collapse every "*_prompt" column into a single `generation` column that holds, for each row,
# an array of the prompt values in `prompt_columns` order. The record-building cell later zips
# `prompt_columns` with these arrays, so the two must stay aligned.
prompt_columns = df.columns[df.columns.str.endswith("_prompt")]
# This cell drops the very columns it reads, so a second run would find none of them and leave
# `prompt_columns` empty (which would also make the record loop produce nothing). Fail loudly.
if prompt_columns.empty:
    raise RuntimeError(
        "No '*_prompt' columns found in `df`; this cell was probably already run. "
        "Re-run the data-loading cell first."
    )
# One array per row, in row order (same content as apply(lambda row: row.values, axis=1)).
df["generation"] = list(df[prompt_columns].to_numpy())
df = df.drop(columns=prompt_columns)

In [None]:
# Distinct category / sub-category labels, in order of first appearance.
# NOTE(review): neither list is referenced by the cells shown in this notebook.
categories, sub_categories = (
    list(df[column].unique()) for column in ("category", "subcategory")
)

In [None]:
import argilla as rg

import os
from getpass import getpass

# Credentials come from the environment (or an interactive prompt) rather than being hard-coded,
# so the key never lands in version control or in shared notebook outputs.
ARGILLA_API_URL = os.environ.get(
    "ARGILLA_API_URL", "https://dibt-image-preferences-argilla.hf.space"
)
ARGILLA_API_KEY = os.environ.get("ARGILLA_API_KEY") or getpass("Argilla API key: ")

client = rg.Argilla(
    api_url=ARGILLA_API_URL,
    api_key=ARGILLA_API_KEY,
)

In [152]:
from importlib import metadata


# Annotation schema for the preference task:
#   - one custom HTML field ("images", rendered by template.html) showing the prompt and both images,
#   - a single label question asking which image is preferred,
#   - terms metadata recording which models and which prompt evolution produced the pair.
# allow_extra_metadata=True lets records carry keys not declared below
# (the record-building cell also sends "category" and "sub_category").
_metadata_options = {
    "model_1": ["dev", "schnell"],
    "model_2": ["dev", "schnell"],
    "evolution": ["style", "quality"],
}

settings = rg.Settings(
    fields=[rg.CustomField("images", template="template.html")],
    questions=[
        rg.LabelQuestion(
            name="preference",
            description="Which image do you prefer given the prompt? ",
            labels=["image_1", "image_2", "both_good", "both_bad"],
        )
    ],
    metadata=[
        rg.TermsMetadataProperty(name=property_name, options=options)
        for property_name, options in _metadata_options.items()
    ],
    allow_extra_metadata=True,
)

# Rebinds `dataset` (previously the Hugging Face dataset) to the Argilla dataset;
# create() is left as the last expression so the cell displays the created dataset.
dataset = rg.Dataset(name="image_preferences", settings=settings)
dataset.create()



Dataset(id=UUID('fcd5d89c-6b67-4d98-8b87-2275c22a6f93') inserted_at=datetime.datetime(2024, 11, 11, 17, 54, 26, 308632) updated_at=datetime.datetime(2024, 11, 11, 17, 54, 27, 58266) name='image_preferences' status='ready' guidelines=None allow_extra_metadata=True distribution=OverlapTaskDistributionModel(strategy='overlap', min_submitted=1) workspace_id=UUID('943fe7cc-eebc-4da7-ba9d-da9f8a06a37a') last_activity_at=datetime.datetime(2024, 11, 11, 17, 54, 27, 58266))

In [None]:
from importlib import metadata
import requests
from io import BytesIO
from PIL import Image

# Hub repo the image files are served from (the same repo loaded at the top of the notebook),
# and a local image directory.
# NOTE(review): `output_dir` is not referenced anywhere in this notebook; images are linked by URL.
dataset_name, output_dir = "DIBT/img_prefs_style", "images"


def make_image_from_url(image_url, repo_id=None, revision="main"):
    """Build the public Hugging Face Hub URL of a file stored in a dataset repo.

    Despite its name, this neither downloads nor decodes an image. It only returns a
    ``resolve`` URL that can be embedded directly in the record.

    Args:
        image_url: Path of the file inside the repo (e.g. the ``"path"`` entry of an image column).
        repo_id: Dataset repo id. Defaults to the notebook-level ``dataset_name``.
        revision: Branch, tag or commit to resolve against. Defaults to ``"main"``.

    Returns:
        ``https://huggingface.co/datasets/<repo_id>/resolve/<revision>/<image_url>``.

    Raises:
        ValueError: if ``image_url`` is empty or None. Otherwise the record would point at the repo
            root, or string concatenation would fail with a TypeError.
    """
    if not image_url:
        raise ValueError(f"cannot build a Hub URL from an empty image path: {image_url!r}")
    if repo_id is None:
        repo_id = dataset_name
    base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/{revision}/"
    # Strip a leading "/" so the URL never contains ".../resolve/main//...".
    return base_url + image_url.lstrip("/")


# For each prompt evolution (the part before the first "_" of a "*_prompt" column, e.g. "style"),
# find the two image columns generated for it ("image_<evolution>_<model>"). This depends only on
# the column layout, so it is computed once instead of once per row and prompt.
image_columns_by_evolution = {}
for source in prompt_columns:
    evolution = source.split("_")[0]
    # Trailing "_" avoids matching columns of a longer evolution name that shares this prefix.
    image_columns = df.columns[df.columns.str.startswith(f"image_{evolution}_")]
    # The record compares exactly two images; fail clearly rather than silently taking the first two
    # columns (or hitting a bare IndexError) if the column layout differs.
    if len(image_columns) != 2:
        raise ValueError(
            f"expected 2 image columns for evolution {evolution!r}, found {list(image_columns)}"
        )
    image_columns_by_evolution[evolution] = image_columns

# One Argilla record per (row, prompt evolution): the evolved prompt plus the two generated images.
records = []
for _, row in df.iterrows():
    for source, generation in zip(prompt_columns, row["generation"]):
        evolution = source.split("_")[0]
        column_1, column_2 = image_columns_by_evolution[evolution]

        record = rg.Record(
            fields={
                "images": {
                    "image_1": make_image_from_url(row[column_1]["path"]),
                    "image_2": make_image_from_url(row[column_2]["path"]),
                    "prompt": generation,
                }
            },
            metadata={
                # The model name is the last "_"-separated token of the image column name.
                "model_1": column_1.split("_")[-1],
                "model_2": column_2.split("_")[-1],
                "evolution": evolution,
                "category": row.category,
                "sub_category": row.subcategory,
            },
        )
        records.append(record)

In [154]:
# Upload the records built in the previous cell to the Argilla dataset. At this point `dataset`
# refers to the rg.Dataset created above, not to the Hugging Face dataset loaded at the top.
# NOTE(review): the records are built without an explicit id, so re-running this cell presumably
# logs them again as duplicates. Confirm against Argilla's upsert semantics.
dataset.records.log(records)

Sending records...: 100%|██████████| 1/1 [00:00<00:00,  4.20batch/s]


DatasetRecords(Dataset(id=UUID('fcd5d89c-6b67-4d98-8b87-2275c22a6f93') inserted_at=datetime.datetime(2024, 11, 11, 17, 54, 26, 308632) updated_at=datetime.datetime(2024, 11, 11, 17, 54, 27, 58266) name='image_preferences' status='ready' guidelines=None allow_extra_metadata=True distribution=OverlapTaskDistributionModel(strategy='overlap', min_submitted=1) workspace_id=UUID('943fe7cc-eebc-4da7-ba9d-da9f8a06a37a') last_activity_at=datetime.datetime(2024, 11, 11, 17, 54, 27, 58266)))