In [None]:
# !pip install transformers accelerate datasets diffusers Pillow==9.4.0

# Imagenet-1k-recaptioned with AR
* center cropping breaks a lot of images
* instead, define aspect ratios, including portrait and lanscape
* resize image to closest AR (mult. of 32px)
* aspect crop
* store in HF dataset, for each image include AR info
* remove augmentations again

In [15]:
import torch
import random
import os
from datasets import load_dataset, Dataset, DatasetDict
from diffusers import AutoencoderDC
# from torch.utils.data import DataLoader
from utils import make_grid, PIL_to_latent, latent_to_PIL, dcae_scalingf
from tqdm import tqdm

from utils_preprocess import resize, pad

In [None]:
# from local_secrets import hf_token
# from huggingface_hub import login
# login(token=hf_token)

# Load IN1k recaptions dataset

In [None]:
ds = load_dataset("visual-layer/imagenet-1k-vl-enriched", cache_dir="~/ssd-2TB/hf_cache")
ds

In [None]:
print("splits", ds.keys())
print("features", ds["train"].features.keys())

## Inspect augmentation before actually processing

In [None]:
## Test run
resizeTo = 256
split = "train"

ASPECT_RATIOS = [
    ("AR_1_to_1", 1),
    ("AR_4_to_3", 4/3),
    ("AR_3_to_4", 3/4),
]

for i in [random.randint(0, len(ds[split])) for _ in range(10)]:
    img=ds[split][i]["image"]
    label=ds[split][i]["caption_enriched"]

    print("image dimension", img.size)
    display(img)

    images = []
    images.append( pad(img).resize((256,256)) )
    images.append( pad(resize(img, resizeTo=resizeTo, ARs=ASPECT_RATIOS, debug=True)[1]).resize((256,256)) )

    print(label)
    display(make_grid(images))

# Load DCAE

In [None]:
model = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

dcae = AutoencoderDC.from_pretrained(model, subfolder="vae", torch_dtype=dtype).to(device)

# Batch augment and create dataset

In [21]:
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    print("DDP run")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    pass
else:
    print("Non DDP run")
    rank = 0
    world_size = 2

split="train"
# ddp
# ds["train"]
indices = list(range(len(ds[split])))
indices_rank = indices[rank : len(ds[split]) : world_size]
len(indices), len(indices_rank), len(indices_rank) * world_size

Non DDP run


(1281167, 640584, 1281168)

In [35]:
indices = list(range(20))
world_size=2
for rank in range(world_size):
    indices_rank = indices[rank : len(ds[split]) : world_size]
    print("rank",rank,indices_rank)

rank 0 [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
rank 1 [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]


In [31]:
[0:10:2]

SyntaxError: invalid syntax (4003291929.py, line 1)

In [None]:
test_run = True
resizeTo = 256
ASPECT_RATIOS = [
    ("AR_1_to_1", 1),
    ("AR_4_to_3", 4/3),
    ("AR_3_to_4", 3/4),
]
hf_dataset = "g-ronimo/IN1k256-AR-buckets-latents_dc-ae-f32c32-sana-1.0"
dcae_batch_size = 32
upload_every = 100_000
splits=["train", "validation"]
col_img="image"
col_label="caption_enriched"

def process_dcae_batch(batch):
    labels = batch["labels"]
    images = batch["images"]
    latents = PIL_to_latent(images, dcae).cpu()   
    
    return [
        dict(label=label, latent = latents[None,i])
        for i, label in enumerate(labels)
    ]

for split in splits:
    dataset_list = {}   # list of dicts per AR bucket, each containing list of {label=.., latent=..}
    parts_uploaded = {}   # dict of int per AR bucket
    dcae_batches = {}   # buffer, collect samples and batch process when full
    samples_uploaded = 0

    for i, d in tqdm(enumerate(ds[split]), total=len(ds[split]), desc=f"Processing split {split}"):
        img=d[col_img]
        label=d[col_label]
        ar_bucket, img = resize(img, resizeTo=resizeTo, ARs=ASPECT_RATIOS)

        # fill dcae-queue  
        if not ar_bucket in dcae_batches: dcae_batches[ar_bucket]={"labels": [], "images": []}
        dcae_batches[ar_bucket]["labels"].append(label)
        dcae_batches[ar_bucket]["images"].append(img)
        del ar_bucket

        # process batch if full or at the end
        ar_buckets = list(dcae_batches.keys())
        for ar_bucket in ar_buckets:
            target_split = f"{split}_{ar_bucket}"     # name of split the images of this batch belong to

            if (
                # batch is full -> process
                (len(dcae_batches[ar_bucket]["labels"]) >= dcae_batch_size)
                or 
                # batch is not full but we reached end of dataset -> process
                (i == len(ds[split])-1 and len(dcae_batches[ar_bucket]["labels"]) > 0)
            ):
                if target_split not in dataset_list: 
                    dataset_list[target_split] = []                
                latents = process_dcae_batch(dcae_batches[ar_bucket])
                dataset_list[target_split].extend(latents)

                # empty the dcae batch we just processed
                dcae_batches[ar_bucket]={"labels": [], "images": []}

        # upload to HF if we gathered more than upload_every OR reached the end 
        target_splits = list(dataset_list.keys())
        for target_split in target_splits:
            if (
                # processed enough -> upload
                (len(dataset_list[target_split]) >= upload_every)
                or 
                # reached end of dataset -> upload
                (i == len(ds[split])-1 and len(dataset_list[target_split]) > 0)
            ):
                if target_split not in parts_uploaded: 
                    parts_uploaded[target_split]=0
                if not test_run:
                    Dataset.from_list(dataset_list[target_split]).push_to_hub(
                        hf_dataset, 
                        split=f"{target_split}.part_{parts_uploaded[target_split]}", 
                        num_shards=1
                    )
                parts_uploaded[target_split]+=1
                samples_uploaded += len(dataset_list[target_split])
                print("Uploaded",len(dataset_list[target_split]), "samples of split", target_split, "part", parts_uploaded[target_split])
                dataset_list[target_split]=[]   

    print("split", split, "total samples uploaded:", samples_uploaded)

In [None]:
# dataset_list.keys()

In [None]:
# # check a few samples
# num_samples = 12
# dataset = dataset_list
# for split in dataset:
#     print("split", split)
#     for idx in [random.randint(0, len(dataset[split])-1) for _ in range(num_samples)]:
#         latent = torch.Tensor(dataset[split][idx]["latent"])
#         label = dataset[split][idx]["label"]
#         print(label, latent.shape)
#         display(
#             make_grid(
#                 [latent_to_PIL(latent.to(dcae.dtype).to(dcae.device), dcae)]
#             )
#         )