In [None]:
import os
from urllib.request import urlretrieve
import pandas as pd
from datasets import load_dataset

# DiffusionDB prompt metadata ("large" variant), cached locally so later runs
# skip the download.
table_url = 'https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/metadata-large.parquet'
metadata_path = 'metadata.parquet'

# NOTE(review): not used by any visible cell; presumably left over from a
# download progress-bar hook — TODO remove once confirmed unused.
pbar = None

if not os.path.exists(metadata_path):
    print("retrieving metadata file")
    # Download under a temporary name and rename only after success, so an
    # interrupted download never leaves a truncated file that the existence
    # check above would later accept as a valid cache.
    partial_path = metadata_path + '.part'
    urlretrieve(table_url, partial_path)
    os.replace(partial_path, metadata_path)

# Read the table using Pandas
metadata_df = pd.read_parquet(metadata_path)

In [None]:
import importlib.util
import subprocess
import sys

# TODO: pin this wheel to a specific model version/revision for reproducibility.
TRF_WHEEL_URL = "https://huggingface.co/spacy/en_core_web_trf/resolve/main/en_core_web_trf-any-py3-none-any.whl"

spec = importlib.util.find_spec("en_core_web_trf")
if spec is None:
    print("Installing en_core_web_trf")
    # Run pip with the kernel's own interpreter (a bare `!pip` uses whichever
    # pip is first on PATH) and raise if the install fails.
    subprocess.check_call([sys.executable, "-m", "pip", "install", TRF_WHEEL_URL])
    # Make the freshly installed package visible to imports in this session.
    importlib.invalidate_caches()

In [None]:
import en_core_web_trf
import spacy

# Load the installed en_core_web_trf spaCy pipeline by its package name.
nlp = spacy.load("en_core_web_trf")

In [None]:
# Peek at the first ten raw prompts.
metadata_df["prompt"].iloc[:10]

In [None]:

# Number of distinct prompts to run through the parser.
n_samples = 100_000
SEED = 42

# De-duplicate (and drop missing prompts) *before* sampling. Sampling first and
# de-duplicating afterwards returns fewer than n_samples prompts whenever the
# sample contains repeats, and a trailing .head(n_samples) cannot refill it.
# min() avoids a ValueError if fewer unique prompts exist than requested.
unique_prompts = metadata_df["prompt"].dropna().drop_duplicates()
first_n_unique_prompts = unique_prompts.sample(
    n=min(n_samples, len(unique_prompts)), random_state=SEED
)
assert first_n_unique_prompts.is_unique
display(first_n_unique_prompts.head(5).tolist())
display(first_n_unique_prompts.shape)

In [None]:
from datasets import Dataset, Features
from spacy import displacy
from spacy.symbols import nsubj, VERB


# Entity labels treated as "proper" names. A chunk whose proper nouns come with
# one of these labels is NOT chosen as the subject (presumably these are mostly
# artist/studio credits such as "art by ..." — TODO confirm intent).
_PROPER_ENT_LABELS = {"PERSON", "ORG"}


def _split_subject_descriptor(doc):
    """Split one parsed prompt into ``(subject, descriptor)`` strings.

    Noun chunks are scanned in order. Until the subject chunk is reached, each
    chunk's root word is appended to the subject. The subject chunk is the first
    one that either contains an entity whose label is not PERSON/ORG, or contains
    a proper noun (PROPN) and no PERSON/ORG entity; its full text is appended to
    the subject as well. Every chunk after the subject chunk goes, in order, into
    the descriptor.
    """
    subject_parts, descriptor_parts = [], []
    subject_found = False
    for chunk in doc.noun_chunks:
        if subject_found:
            descriptor_parts.append(chunk.text)
            continue
        ent_labels = [ent.label_ for ent in chunk.ents]
        has_proper_ent = any(label in _PROPER_ENT_LABELS for label in ent_labels)
        has_other_ent = any(label not in _PROPER_ENT_LABELS for label in ent_labels)
        has_proper_noun = any(token.pos_ == "PROPN" for token in chunk)
        subject_parts.append(chunk.root.text)
        if has_other_ent or (has_proper_noun and not has_proper_ent):
            subject_parts.append(chunk.text)
            subject_found = True
    # Drop exact repeats (e.g. a one-word chunk whose root equals its text) while
    # keeping first-seen order; dict.fromkeys does this in O(n).
    subject = " ".join(dict.fromkeys(subject_parts))
    return subject, " ".join(descriptor_parts)


def process(batch):
    """Split a batch of prompts into ``subject`` and ``descriptor`` text.

    Intended as a ``Dataset.map(..., batched=True)`` callback: ``batch["prompt"]``
    is an iterable of prompts. Returns a dict with parallel ``"subject"`` and
    ``"descriptor"`` lists, one entry per prompt. Non-string prompts (e.g.
    missing values) are parsed as the empty string instead of crashing the map.
    Uses the module-level spaCy pipeline ``nlp``.
    """
    out = {
        "subject": [],
        "descriptor": [],
    }
    texts = (p if isinstance(p, str) else "" for p in batch["prompt"])
    # nlp.pipe batches the pipeline's work instead of one nlp() call per prompt.
    for doc in nlp.pipe(texts):
        subject, descriptor = _split_subject_descriptor(doc)
        out["subject"].append(subject)
        out["descriptor"].append(descriptor)
    return out


# Spot-check the subject/descriptor split on a few sampled prompts before the
# full dataset run. .iloc makes the positional slice explicit.
preview_prompts = first_n_unique_prompts.iloc[:10].tolist()
display([(p, process({"prompt": [p]})) for p in preview_prompts])

In [None]:
# Wrap the sampled prompts as a one-column ("prompt") frame and convert it to a
# Hugging Face Dataset; the pandas index is not kept as a column.
prompt_only_df = first_n_unique_prompts.to_frame(name="prompt")
dataset = Dataset.from_pandas(df=prompt_only_df, preserve_index=False)
dataset

In [None]:
# NOTE: this rebinds `dataset` and removes its "prompt" column, so running this
# cell a second time errors — re-run the Dataset.from_pandas cell first.
dataset = dataset.map(
    function=process,
    batched=True,
    batch_size=512,
    remove_columns=["prompt"],
)

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

display(dataset)
# SECURITY: an access token used to be hard-coded here; treat it as leaked and
# revoke it in the Hugging Face account settings. Take the token from the
# HF_TOKEN environment variable, or prompt for it without echoing.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=hf_token, add_to_git_credential=True)
dataset.push_to_hub("roborovski/diffusiondb-seq2seq")
