# micro diffusion: JourneyDB data preparation

Dataset: [JourneyDB on Hugging Face](https://huggingface.co/datasets/JourneyDB/JourneyDB)

In [None]:
# Inspect the GPU(s) and NVIDIA driver visible to this kernel.
!nvidia-smi # CUDA Version 13.0

In [None]:
# Print the version of the `python` found on the shell PATH (the author saw 3.12.11).
!python --version # 3.12.11

In [None]:
# PyTorch 2.7.0 / torchvision 0.22.0 / torchaudio 2.7.0 wheels from the CUDA 12.8 (cu128) index.
%pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu128

In [None]:
%pip install -U \
    accelerate \
    diffusers \
    huggingface_hub \
    torch==2.7.0 \
    torchvision \
    transformers \
    timm \
    open_clip_torch \
    easydict \
    einops \
    mosaicml-streaming \
    torchmetrics \
    tqdm \
    pandas \
    fastparquet \
    omegaconf \
    datasets \
    hydra-core \
    beautifulsoup4

In [None]:
# The `mosaicml` package with its tensorboard and wandb extras.
# NOTE(review): nothing in this notebook imports it; presumably it is needed by later training code. Confirm.
%pip install "mosaicml[tensorboard, wandb]"

In [None]:
import os

# Paths: all data and model artifacts live under ~/.cache/micro_diffusion.
USER_ROOT = os.path.expanduser("~")
CACHE_DIR = os.path.join(USER_ROOT, ".cache", "micro_diffusion")
DATA_DIR = os.path.join(CACHE_DIR, "data")
MODEL_DIR = os.path.join(CACHE_DIR, "models")

# Downloaded archives and their extracted contents.
COMPRESSED_DIR = os.path.join(DATA_DIR, 'compressed')
# Processed images and annotation files, one sub-directory per split.
RAW_DIR = os.path.join(DATA_DIR, 'raw')

TRAIN_IMGS_DIR = os.path.join(RAW_DIR, 'train', 'imgs')
VALID_IMGS_DIR = os.path.join(RAW_DIR, 'valid', 'imgs')
TEST_DIR = os.path.join(RAW_DIR, 'test')

# Create every directory used below; exist_ok makes re-running the cell safe.
for _path in (
    DATA_DIR,
    MODEL_DIR,
    COMPRESSED_DIR,
    RAW_DIR,
    TRAIN_IMGS_DIR,
    VALID_IMGS_DIR,
    TEST_DIR,
):
    os.makedirs(_path, exist_ok=True)

In [None]:
import os
import shutil
import argparse
import subprocess
import numpy as np
from glob import iglob
from multiprocessing import Pool
from torchvision import transforms
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError

In [None]:
def download_and_process_metadata():
    """Download the JourneyDB annotations and test images and stage them in RAW_DIR.

    Steps:
      1. Download the annotation archives and the test-image archive from the
         ``JourneyDB/JourneyDB`` dataset repo into ``COMPRESSED_DIR``.
      2. Extract each archive into the folder it was downloaded to.
      3. Move the extracted test images to ``RAW_DIR/test/imgs`` and copy the
         test, valid and train annotation ``.jsonl`` files into
         ``RAW_DIR/<split>/``.

    The train annotation file is copied last. The notebook uses its presence
    as the "metadata already prepared" marker, so an interrupted run is
    retried instead of leaving the other files missing.

    Raises:
        subprocess.CalledProcessError: if ``tar`` fails on an archive.
    """
    # Only using a single process for downloading metadata
    metadata_files = [
        ('data/train', 'train_anno.jsonl.tgz'),
        ('data/train', 'train_anno_realease_repath.jsonl.tgz'),
        ('data/valid', 'valid_anno_repath.jsonl.tgz'),
        ('data/test', 'test_questions.jsonl.tgz'),
        ('data/test', 'imgs.tgz'),
    ]

    for subfolder, filename in metadata_files:
        # `local_dir_use_symlinks` is deprecated and ignored by recent
        # huggingface_hub releases, so it is no longer passed.
        hf_hub_download(
            repo_id="JourneyDB/JourneyDB",
            repo_type="dataset",
            subfolder=subfolder,
            filename=filename,
            local_dir=COMPRESSED_DIR,
        )

    for subfolder, filename in metadata_files:
        extract_dir = os.path.join(COMPRESSED_DIR, subfolder)
        # Argument list (no shell) so paths with spaces are passed intact;
        # check=True stops before copying files out of a failed extraction.
        subprocess.run(
            ['tar', '-xvzf', os.path.join(extract_dir, filename), '-C', extract_dir],
            check=True,
        )

    # Move the extracted test images into place. Skip if a previous run
    # already did, since shutil.move refuses an existing destination.
    test_imgs_dst = os.path.join(RAW_DIR, 'test', 'imgs')
    if not os.path.exists(test_imgs_dst):
        os.makedirs(os.path.dirname(test_imgs_dst), exist_ok=True)
        shutil.move(
            os.path.join(COMPRESSED_DIR, 'data', 'test', 'imgs'),
            test_imgs_dst,
        )

    annotation_files = [
        ('test', 'test_questions.jsonl'),
        ('valid', 'valid_anno_repath.jsonl'),
        ('train', 'train_anno_realease_repath.jsonl'),  # last: completion marker
    ]
    for split, fname in annotation_files:
        dst_dir = os.path.join(RAW_DIR, split)
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copy(
            os.path.join(COMPRESSED_DIR, 'data', split, fname),
            os.path.join(dst_dir, fname),
        )

# Prepare metadata only once; the train annotation file copied into RAW_DIR
# acts as the completion marker.
_TRAIN_ANNO_MARKER = os.path.join(RAW_DIR, 'train', 'train_anno_realease_repath.jsonl')
if not os.path.exists(_TRAIN_ANNO_MARKER):
    download_and_process_metadata()

In [None]:
# Archive indices to fetch for each split (here only archive 000).
valid_ids = list(np.arange(1))

# (split, archive index) work items: all train archives first, then valid.
pool_args = [(split, i) for split in ('train', 'valid') for i in valid_ids]
pool_args

In [None]:
# Shorter-side bounds in pixels: larger images are downsized to
# max_image_size, and images smaller than min_image_size are skipped.
max_image_size, min_image_size = 512, 64

In [None]:
def download_uncompress_resize(
    valid_ids: list,
    max_image_size: int,
    min_image_size: int,
    split: str,
    idx: int,
):
    """Download, uncompress, and resize images for a given archive index.

    The archive ``data/{split}/imgs/{idx:03}.tgz`` is fetched from the
    JourneyDB dataset repo into ``COMPRESSED_DIR`` (skipped when its extracted
    folder already exists), unpacked, and the ``.tgz`` deleted. Each image in
    the extracted folder is then written to ``RAW_DIR/{split}/imgs/{idx:03}/``:
    - Images whose shorter side is below ``min_image_size`` are skipped and
      left in place.
    - Images whose shorter side exceeds ``max_image_size`` are bicubically
      resized so that side equals ``max_image_size``.
    Source images that were written successfully are deleted.

    If ``tar`` fails, the partial extraction is removed and the archive kept,
    so a later run can retry.

    Args:
        valid_ids: Archive indices this run may process.
        max_image_size: Target shorter-side length for large images.
        min_image_size: Minimum accepted shorter-side length.
        split: ``'train'`` or ``'valid'``.
        idx: Archive index to process; must be in ``valid_ids``.
    """
    assert split in ('train', 'valid')
    assert idx in valid_ids

    image_exts = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
    archive_name = f'{idx:>03}'
    imgs_root = os.path.join(COMPRESSED_DIR, 'data', split, 'imgs')
    extracted_dir = os.path.join(imgs_root, archive_name)
    tgz_path = os.path.join(imgs_root, f'{archive_name}.tgz')
    out_dir = os.path.join(RAW_DIR, split, 'imgs', archive_name)

    print(f"Downloading idx: {idx}")
    if not os.path.isdir(extracted_dir):
        # `local_dir_use_symlinks` is deprecated and ignored by recent
        # huggingface_hub releases, so it is no longer passed.
        hf_hub_download(
            repo_id="JourneyDB/JourneyDB",
            repo_type="dataset",
            subfolder=f'data/{split}/imgs',
            filename=f'{archive_name}.tgz',
            local_dir=COMPRESSED_DIR,
        )
    print(f"Downloaded idx: {idx}")

    print(f"Extracting idx: {idx}")
    if not os.path.isdir(extracted_dir):
        # Argument list (no shell) so paths with spaces are passed intact.
        result = subprocess.run(['tar', '-xzf', tgz_path, '-C', imgs_root])
        if result.returncode != 0:
            # Keep the archive and drop any partial output; otherwise the
            # archive would be deleted below and the next run would skip this
            # index because the (partial) folder exists.
            print(f"Extraction failed for idx: {idx} (tar exit code {result.returncode})")
            shutil.rmtree(extracted_dir, ignore_errors=True)
            return
    print(f"Extracted idx: {idx}")

    print(f"Removing idx: {idx}")
    if os.path.exists(tgz_path):
        os.remove(tgz_path)
    print(f"Removed idx: {idx}")

    # add bicubic downsize
    downsize = transforms.Resize(
        max_image_size,
        antialias=True,
        interpolation=transforms.InterpolationMode.BICUBIC,
    )

    print(f"Downsizing idx: {idx}")
    os.makedirs(out_dir, exist_ok=True)
    for src in iglob(os.path.join(extracted_dir, '*')):
        save_path = os.path.join(out_dir, os.path.basename(src))

        if os.path.exists(save_path):
            continue
        if not src.lower().endswith(image_exts):
            continue

        try:
            # Context manager closes the file handle Image.open keeps open.
            with Image.open(src) as img:
                w, h = img.size
                # Check the lower bound first so too-small images are never resized.
                if min(w, h) < min_image_size:
                    print(
                        f'Skipping image with resolution ({h}, {w}) - '
                        f'Since at least one side has resolution below {min_image_size}'
                    )
                    continue
                if min(w, h) > max_image_size:
                    img = downsize(img)
                img.save(save_path)
            os.remove(src)
        except (UnidentifiedImageError, OSError) as e:
            print(f"Error {e}, File: {src}")
    print(f'Downsized idx: {idx}')


num_proc = 4

# One job per (split, archive index). Each worker downloads, extracts and
# downsizes its archive independently.
_resize_jobs = [
    (valid_ids, max_image_size, min_image_size, split, idx)
    for split, idx in pool_args
]
with Pool(processes=num_proc) as pool:
    pool.starmap(download_uncompress_resize, _resize_jobs)

In [None]:
# !cd micro_diffusion/micro_diffusion/datasets/prepare/jdb && python download.py --datadir $DATA_DIR --max_image_size 512 --min_image_size 256 --valid_ids 0 1 --num_proc 2

## Convert

In [None]:
import os
import json
from glob import glob
from argparse import ArgumentParser
from PIL import Image
from streaming.base import MDSWriter
from tqdm import tqdm

In [None]:
def convert_to_mds(
    images_dir: str,
    captions_jsonl: str,
    local_mds_dir: str,
):
    """Converts JourneyDB dataset to mds format.

    Reads one JSON object per line from ``captions_jsonl``. Each entry whose
    image sits in an archive folder present under ``images_dir`` is written
    to an MDS dataset in ``local_mds_dir`` with these columns:
    - ``jpg``: the image;
    - ``caption``: the entry's ``prompt`` field;
    - ``width`` and ``height``.
    Images that cannot be read are reported and skipped.

    Args:
        images_dir: Directory holding one sub-folder per downloaded archive.
        captions_jsonl: Annotation file with ``prompt`` and ``img_path`` keys.
        local_mds_dir: Output directory for the MDS shards.
    """
    columns = {
        'width': 'int32',
        'height': 'int32',
        'jpg': 'jpeg',
        'caption': 'str',
    }

    writer = MDSWriter(
        out=local_mds_dir,
        columns=columns,
        compression=None,
        size_limit=256 * (2**20),
        max_workers=64,
    )

    # Archive folders present locally, in case only a subset of the data was
    # downloaded. A set keeps the per-sample membership test O(1).
    available_archives = {
        os.path.basename(p) for p in glob(os.path.join(images_dir, '*'))
    }

    # Read the annotations up front (closing the file) so tqdm knows the total.
    with open(captions_jsonl, 'r', encoding='utf-8') as fh:
        metadata = fh.readlines()

    for line in tqdm(metadata):
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        d = json.loads(line)
        # lstrip removes only the leading "./" (or "/"); strip() would also
        # eat trailing '.' or '/' characters of the file name.
        cap, p = d['prompt'], d['img_path'].lstrip('./')

        if os.path.dirname(p) not in available_archives:
            continue

        try:
            # Context manager closes the file handle Image.open keeps open.
            with Image.open(os.path.join(images_dir, p)) as img:
                w, h = img.size
                mds_sample = {
                    'jpg': img,
                    'caption': cap,
                    'width': w,
                    'height': h,
                }
                writer.write(mds_sample)
        except Exception as e:
            print(
                "Something went wrong in reading caption, "
                f"skipping writing this sample in mds. Error: {e}"
            )

    writer.finish()


# python convert.py --images_dir ./datadir/jdb/raw/train/imgs/ --captions_jsonl ./datadir/jdb/raw/train/train_anno_realease_repath.jsonl --local_mds_dir ./datadir/jdb/mds/train/
# python convert.py --images_dir ./datadir/jdb/raw/valid/imgs/ --captions_jsonl ./datadir/jdb/raw/valid/valid_anno_repath.jsonl --local_mds_dir ./datadir/jdb/mds/valid/

def _convert_split_if_missing(images_dir, captions_jsonl, local_mds_dir):
    """Run convert_to_mds for one split unless its MDS output directory already exists."""
    if os.path.exists(local_mds_dir):
        return
    convert_to_mds(
        images_dir=images_dir,
        captions_jsonl=captions_jsonl,
        local_mds_dir=local_mds_dir,
    )


TRAIN_IMAGES_DIR = os.path.join(RAW_DIR, 'train', 'imgs')
TRAIN_CAPTIONS_JSONL = os.path.join(RAW_DIR, 'train', 'train_anno_realease_repath.jsonl')
TRAIN_LOCAL_MDS_DIR = os.path.join(DATA_DIR, 'mds', 'train')
_convert_split_if_missing(TRAIN_IMAGES_DIR, TRAIN_CAPTIONS_JSONL, TRAIN_LOCAL_MDS_DIR)

VALID_IMAGES_DIR = os.path.join(RAW_DIR, 'valid', 'imgs')
VALID_CAPTIONS_JSONL = os.path.join(RAW_DIR, 'valid', 'valid_anno_repath.jsonl')
VALID_LOCAL_MDS_DIR = os.path.join(DATA_DIR, 'mds', 'valid')
_convert_split_if_missing(VALID_IMAGES_DIR, VALID_CAPTIONS_JSONL, VALID_LOCAL_MDS_DIR)

In [None]:
# Disabled: this cell would call the upstream micro_diffusion `precompute.py`,
# which nothing in this notebook checks out. The cells below reimplement the
# same precompute step in-process. As written, the trailing `\` also joined
# `cd ... &&` with a commented-out line, leaving `&&` without a command.
# !cd micro_diffusion/micro_diffusion/datasets/prepare/jdb && \
#     accelerate launch --num_processes 8 precompute.py --datadir $DATADIR/jdb/mds/train/ --savedir $DATADIR/jdb/mds_latents_sdxl1_dfnclipH14/train/ --vae stabilityai/stable-diffusion-xl-base-1.0 --text_encoder openclip:hf-hub:apple/DFN5B-CLIP-ViT-H-14-378 --batch_size 32

### Precompute Latents


In [None]:
import os
import time
from argparse import ArgumentParser

import numpy as np
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from streaming import MDSWriter
from streaming.base.util import merge_index
from tqdm import tqdm

In [None]:
import math
from collections.abc import Iterable
from itertools import repeat
from typing import Optional, Tuple, Dict, Union, List, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchmetrics import Metric

import open_clip
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel, 
    T5Tokenizer
)


class simple_2_hf_tokenizer_wrapper:
    """Adapter that gives an OpenCLIP tokenizer a HuggingFace-style call signature.

    The wrapped tokenizer's ``context_length`` is exposed as
    ``model_max_length``. Calling the wrapper forwards the text together with
    ``max_length`` (as ``context_length``) and returns the result under
    ``'input_ids'``.

    Args:
        tokenizer (Any): OpenCLIP tokenizer instance
    """
    def __init__(self, tokenizer: Any):
        self.tokenizer = tokenizer
        self.model_max_length = tokenizer.context_length

    def __call__(
        self,
        caption: str,
        padding: str = 'max_length',
        max_length: Optional[int] = None,
        truncation: bool = True,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        # `padding`, `truncation` and extra kwargs are accepted only for
        # HuggingFace interface compatibility and are not forwarded.
        token_ids = self.tokenizer(caption, context_length=max_length)
        return {'input_ids': token_ids}


class UniversalTokenizer:
    """Universal tokenizer supporting multiple model types.

    ``name`` selects the backend:
    - ``"openclip:<model>"`` uses ``open_clip.get_tokenizer`` wrapped in
      ``simple_2_hf_tokenizer_wrapper``;
    - ``"DeepFloyd/t5-v1_1-xxl"`` uses ``T5Tokenizer``;
    - any other name loads a HuggingFace ``CLIPTokenizer`` from its
      ``tokenizer`` subfolder.
    ``model_max_length`` comes from ``text_encoder_embedding_format``.

    Args:
        name (str): Name/path of the tokenizer to load
    """
    def __init__(self, name: str):
        self.name = name
        s, _ = text_encoder_embedding_format(name)
        if self.name.startswith("openclip:"):
            # removeprefix drops the literal "openclip:" scheme. lstrip('openclip:')
            # strips any leading run of the characters o/p/e/n/c/l/i/:, which
            # mangles names such as "openclip:coca_ViT-B-32" -> "a_ViT-B-32".
            self.tokenizer = simple_2_hf_tokenizer_wrapper(
                open_clip.get_tokenizer(name.removeprefix('openclip:'))
            )
            assert s == self.tokenizer.model_max_length, "simply check of text_encoder_embedding_format"
        elif self.name == "DeepFloyd/t5-v1_1-xxl":
            self.tokenizer = T5Tokenizer.from_pretrained(name) # for t5 we would use a smaller than max_seq_length
        else:
            self.tokenizer = CLIPTokenizer.from_pretrained(name, subfolder='tokenizer')
            assert s == self.tokenizer.model_max_length, "simply check of text_encoder_embedding_format"
        self.model_max_length = s

    def tokenize(self, captions: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """Tokenize ``captions``, padding/truncating to the model's max length.

        Returns:
            ``{'input_ids': ...}``. For T5, the dict also holds
            ``'attention_mask'``.
        """
        if self.name == "DeepFloyd/t5-v1_1-xxl":
            text_tokens_and_mask = self.tokenizer(
                captions,
                padding='max_length',
                max_length=self.model_max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors='pt'
            )
            return {
                'input_ids': text_tokens_and_mask['input_ids'],
                'attention_mask': text_tokens_and_mask['attention_mask']
            }
        else:
            # Avoid attention mask for CLIP tokenizers as they are not used
            tokenized_caption = self.tokenizer(
                captions,
                padding='max_length',
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors='pt'
            )['input_ids']
            return {'input_ids': tokenized_caption}

In [None]:
from typing import Callable, Dict, List, Optional, Sequence, Union
from torch.utils.data import DataLoader
from torchvision import transforms
from streaming import Stream, StreamingDataset

# from micro_diffusion.models.utils import UniversalTokenizer

class StreamingJdbDatasetForPreCompute(StreamingDataset):
    """Streaming dataset that resizes images to user-provided resolutions and tokenizes captions.

    Each item is a dict with:
      * ``caption_key``: token ids from ``UniversalTokenizer.tokenize``;
      * ``f'{caption_key}_attention_mask'``: only when the tokenizer returns one;
      * ``image_{i}``: ``transforms_list[i]`` applied to the RGB image;
      * ``sample``: the raw MDS sample.

    Args:
        streams: MDS streams to read from.
        transforms_list: One transform per output image (``image_0``, ``image_1``, ...).
        batch_size: Batch size forwarded to ``StreamingDataset``.
        tokenizer_name: Name passed to ``UniversalTokenizer``.
        shuffle: Whether ``StreamingDataset`` shuffles samples.
        caption_key: Sample key holding the caption text.
    """

    def __init__(
        self,
        streams: Sequence[Stream],
        transforms_list: List[Callable],
        batch_size: int,
        tokenizer_name: str,
        shuffle: bool = False,
        caption_key: str = 'caption',
    ):
        # Validate before the (comparatively expensive) dataset/tokenizer setup.
        assert transforms_list is not None, 'Must provide transforms to resize and center crop images'
        super().__init__(
            streams=streams,
            shuffle=shuffle,
            batch_size=batch_size,
        )

        self.transforms_list = transforms_list
        self.caption_key = caption_key
        self.tokenizer = UniversalTokenizer(tokenizer_name)
        print("Created tokenizer: ", tokenizer_name)

    def __getitem__(self, index: int) -> Dict:
        sample = super().__getitem__(index)
        ret = {}

        out = self.tokenizer.tokenize(sample[self.caption_key])
        ret[self.caption_key] = out['input_ids'].clone().detach()
        if 'attention_mask' in out:
            ret[f'{self.caption_key}_attention_mask'] = out['attention_mask'].clone().detach()

        if self.transforms_list:
            # Convert once; every transform consumes the same RGB image.
            img = sample['jpg']
            if img.mode != 'RGB':
                img = img.convert('RGB')
            for i, tr in enumerate(self.transforms_list):
                ret[f'image_{i}'] = tr(img)

        ret['sample'] = sample
        return ret


def _collate_as_lists(list_of_dict: List[Dict]) -> Dict:
    """Gather a batch of sample dicts into one dict of per-key lists (no stacking).

    Defined at module level rather than nested in the builder so it can be
    pickled, e.g. by DataLoader worker processes started with ``spawn``.
    """
    out = {k: [] for k in list_of_dict[0].keys()}
    for d in list_of_dict:
        for k, v in d.items():
            out[k].append(v)
    return out


def build_streaming_jdb_precompute_dataloader(
    datadir: Union[List[str], str],
    batch_size: int,
    resize_sizes: Optional[List[int]] = None,
    drop_last: bool = False,
    shuffle: bool = True,
    caption_key: Optional[str] = None,
    tokenizer_name: Optional[str] = None,
    **dataloader_kwargs,
) -> DataLoader:
    """Builds a streaming mds dataloader returning multiple image sizes and text captions.

    Args:
        datadir: One or more local MDS directories.
        batch_size: Batch size for both the streaming dataset and the DataLoader.
        resize_sizes: Square sizes; each yields an ``image_{i}`` entry that is
            bicubic-resized, center-cropped, and normalized to [-1, 1].
        drop_last: Forwarded to ``DataLoader``.
        shuffle: Forwarded to the streaming dataset.
        caption_key: Caption column name; ``None`` means ``'caption'``.
        tokenizer_name: Name passed to ``UniversalTokenizer``.
        **dataloader_kwargs: Extra ``DataLoader`` arguments.

    Returns:
        A ``DataLoader`` whose batches are dicts mapping each sample key to a
        list of per-sample values.
    """
    assert resize_sizes is not None, 'Must provide target resolution for image resizing'
    # Passing None through would make the dataset look up sample[None].
    if caption_key is None:
        caption_key = 'caption'
    datadir = [datadir] if isinstance(datadir, str) else datadir
    streams = [Stream(remote=None, local=local_dir) for local_dir in datadir]

    transforms_list = [
        transforms.Compose([
            transforms.Resize(
                resize,
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        for resize in resize_sizes
    ]

    dataset = StreamingJdbDatasetForPreCompute(
        streams=streams,
        shuffle=shuffle,
        transforms_list=transforms_list,
        batch_size=batch_size,
        caption_key=caption_key,
        tokenizer_name=tokenizer_name,
    )

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        collate_fn=_collate_as_lists,
        **dataloader_kwargs,
    )

    return dataloader

In [None]:


def text_encoder_embedding_format(enc: str) -> Tuple[int, int]:
    """Returns sequence length and token embedding dimension for text encoder.

    Raises:
        ValueError: if ``enc`` is not one of the known encoder names.
    """
    # SD 1.x checkpoints ship the CLIP ViT-L/14 text encoder (hidden size 768).
    # SD 2.x ships OpenCLIP ViT-H/14 (hidden size 1024).
    formats = {
        'stabilityai/stable-diffusion-2-base': (77, 1024),
        'runwayml/stable-diffusion-v1-5': (77, 768),
        'CompVis/stable-diffusion-v1-4': (77, 768),
        'openclip:hf-hub:apple/DFN5B-CLIP-ViT-H-14-378': (77, 1024),
        'DeepFloyd/t5-v1_1-xxl': (120, 4096),
    }
    if enc not in formats:
        raise ValueError(f'Please specifcy the sequence and embedding size of {enc} encoder')
    return formats[enc]
    
    

In [None]:

class UniversalTextEncoder(torch.nn.Module):
    """Universal text encoder supporting multiple model types.

    ``name`` selects the backend:
    - ``"openclip:<model>"`` loads a pretrained OpenCLIP model wrapped in
      ``openclip_text_encoder``;
    - ``"DeepFloyd/t5-v1_1-xxl"`` loads ``T5EncoderModel``;
    - any other name loads a HuggingFace ``CLIPTextModel`` from its
      ``text_encoder`` subfolder.

    Args:
        name (str): Name/path of the model to load
        dtype (str): Key of ``DATA_TYPES``. For HF models it sets the weight
            dtype; for OpenCLIP it sets the autocast dtype of the wrapper.
        pretrained (bool, True): Whether to load pretrained weights
    """
    def __init__(self, name: str, dtype: str, pretrained: bool = True):
        super().__init__()
        self.name = name
        if self.name.startswith("openclip:"):
            assert pretrained, 'Load default pretrained model from openclip'
            # removeprefix drops the literal "openclip:" scheme; lstrip('openclip:')
            # would strip a *set* of leading characters and mangle some names.
            clip_model = open_clip.create_model_and_transforms(
                name.removeprefix('openclip:')
            )[0]
            # openclip_text_encoder's parameter is `dtype`; a `torch_dtype`
            # keyword would land in its **kwargs and be silently ignored.
            self.encoder = openclip_text_encoder(
                clip_model,
                dtype=DATA_TYPES[dtype],
            )
        elif self.name == "DeepFloyd/t5-v1_1-xxl":
            self.encoder = T5EncoderModel.from_pretrained(
                name,
                torch_dtype=DATA_TYPES[dtype],
                pretrained=pretrained
            )
        else:
            self.encoder = CLIPTextModel.from_pretrained(
                name,
                subfolder='text_encoder',
                torch_dtype=DATA_TYPES[dtype],
                pretrained=pretrained
            )

    def encode(self, tokenized_caption: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode token ids; callers use element ``[0]`` of the result as the embedding.

        For T5, the ``last_hidden_state`` gets an extra axis at dim 1 and is
        returned as ``(emb, None)``. Other backends return the wrapped
        encoder's output unchanged; ``attention_mask`` is used only for T5.
        """
        if self.name == "DeepFloyd/t5-v1_1-xxl":
            out = self.encoder(
                tokenized_caption,
                attention_mask=attention_mask
            )['last_hidden_state']
            out = out.unsqueeze(dim=1)
            return out, None
        else:
            return self.encoder(tokenized_caption)

In [None]:
# Maps dtype names (as used for `model_dtype` / `save_dtype`) to torch dtypes.
DATA_TYPES = dict(
    float16=torch.float16,
    bfloat16=torch.bfloat16,
    float32=torch.float32,
)

In [None]:
class openclip_text_encoder(torch.nn.Module):
    """OpenCLIP text-tower wrapper that returns per-token embeddings.

    Runs the token embedding, the transformer and the final LayerNorm of an
    OpenCLIP model, without pooling or projection. The hidden states get an
    extra singleton axis, giving ``[batch_size, 1, n_ctx, width]``.

    Args:
        clip_model (Any): OpenCLIP model instance
        dtype (torch.dtype, torch.float32): dtype used for CUDA autocast in ``forward``
    """
    def __init__(self, clip_model: Any, dtype: torch.dtype = torch.float32, **kwargs) -> None:
        super().__init__()
        self.clip_model = clip_model
        self.device = None
        self.dtype = dtype

    def forward_fxn(self, text: torch.Tensor) -> Tuple[torch.Tensor, None]:
        cast_dtype = self.clip_model.transformer.get_cast_dtype()
        x = self.clip_model.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.clip_model.positional_embedding.to(cast_dtype)
        # Sequence-first OpenCLIP transformers take LND input. A transformer
        # exposing `batch_first=True` (newer open_clip releases; verify against
        # the installed version) takes NLD directly. A missing attribute is
        # treated as the legacy sequence-first layout.
        seq_first = not getattr(self.clip_model.transformer, 'batch_first', False)
        if seq_first:
            x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x, attn_mask=self.clip_model.attn_mask)
        if seq_first:
            x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x = x.unsqueeze(dim=1) # [batch_size, 1, n_ctx, transformer.width] expected for text_emb
        return x, None # HF encoders expected to return multiple values with first being text emb

    def forward(self, text: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, None]:
        with torch.autocast(device_type='cuda', dtype=self.dtype):
            return self.forward_fxn(text)

In [None]:
def main(
    datadir: str,
    savedir: str = "",
    image_resolutions: Optional[list] = None,
    save_images: bool = False,
    model_dtype: str = "bfloat16",
    save_dtype: str = "float16",
    vae: str = "stabilityai/stable-diffusion-xl-base-1.0",
    text_encoder: str = "openclip:hf-hub:apple/DFN5B-CLIP-ViT-H-14-378",
    batch_size: int = 32,
    seed: int = 2024,
):
    """Precompute image and text latents and store them in MDS format.

    For every resolution ``r`` in ``image_resolutions`` (default
    ``[256, 512]``), images are resized and center-cropped to ``r x r`` and
    encoded with the VAE. The sampled latents are scaled by
    ``vae.config.scaling_factor``, cast to ``save_dtype`` and stored as raw
    bytes in column ``latents_{r}``.

    Captions are tokenized and encoded with ``text_encoder``. The first
    encoder output is flattened and stored as bytes in ``caption_latents``.

    Each process writes its own shard directory ``savedir/<process_index>``;
    the main process then merges the shard indices into ``savedir``.

    Args:
        datadir: Local MDS dataset with ``jpg`` and ``caption`` columns.
        savedir: Output directory for the latent MDS dataset.
        image_resolutions: Square crop sizes to encode. Defaults to [256, 512].
        save_images: Also store the source image in a ``jpg`` column.
        model_dtype: ``DATA_TYPES`` key for model weights and autocast.
        save_dtype: ``DATA_TYPES`` key the stored latents are cast to.
        vae: Hub id of a model with an ``AutoencoderKL`` in its ``vae`` subfolder.
        text_encoder: Name understood by ``UniversalTextEncoder``.
        batch_size: Per-process batch size.
        seed: Base random seed, offset by the process index.
    """
    if image_resolutions is None:
        image_resolutions = [256, 512]
    cap_key = 'caption'  # Hardcoding the image caption key to 'caption' in MDS dataset
    mask_key = f'{cap_key}_attention_mask'
    # The dataloader yields image_{i} for image_resolutions[i]; store it as latents_{res}.
    latent_keys = [f'latents_{res}' for res in image_resolutions]

    accelerator = Accelerator()
    device = accelerator.device
    device_idx = int(accelerator.process_index)

    # Per-process seeds so each rank draws different VAE posterior samples.
    torch.manual_seed(device_idx + seed)
    torch.cuda.manual_seed(device_idx + seed)
    np.random.seed(device_idx + seed)

    dataloader = build_streaming_jdb_precompute_dataloader(
        datadir=[datadir],
        batch_size=batch_size,
        resize_sizes=image_resolutions,
        drop_last=False,
        shuffle=False,
        caption_key=cap_key,
        tokenizer_name=text_encoder,
        pin_memory=True,
    )
    print(f'Device: {device_idx}, Dataloader sample count: {len(dataloader.dataset)}')

    vae_model = AutoencoderKL.from_pretrained(
        vae,
        subfolder='vae',  # Change subfolder to appropriate one in hf_hub, if needed
        torch_dtype=DATA_TYPES[model_dtype],
    )
    print("Created VAE: ", vae_model)
    assert isinstance(vae_model, AutoencoderKL)

    text_model = UniversalTextEncoder(
        text_encoder,
        dtype=model_dtype,
        pretrained=True,
    )
    print("Created text encoder: ", text_model)

    vae_model = vae_model.to(device)
    text_model = text_model.to(device)

    columns = {
        cap_key: 'str',
        f'{cap_key}_latents': 'bytes',
    }
    columns.update({key: 'bytes' for key in latent_keys})
    if save_images:
        columns['jpg'] = 'jpeg'

    writer = MDSWriter(
        out=os.path.join(savedir, str(accelerator.process_index)),
        columns=columns,
        compression=None,
        size_limit=256 * (2**20),
        max_workers=64,
    )

    compute_dtype = DATA_TYPES[model_dtype]
    latent_dtype = DATA_TYPES[save_dtype]

    for batch in tqdm(dataloader):
        captions = torch.stack(batch[cap_key]).to(device)
        attention_mask = None
        if mask_key in batch:
            attention_mask = torch.stack(batch[mask_key]).to(device)

        # Autocast on the accelerator's actual device instead of hard-coding 'cuda'.
        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=compute_dtype):
            latents = []
            for i in range(len(image_resolutions)):
                images = torch.stack(batch[f'image_{i}']).to(device)
                encoded = vae_model.encode(images)
                assert isinstance(encoded, AutoencoderKLOutput)
                latents.append(
                    (
                        encoded['latent_dist'].sample().data * vae_model.config.scaling_factor
                    ).to(latent_dtype)
                )

            conditioning = text_model.encode(
                captions.view(-1, captions.shape[-1]),
                attention_mask=attention_mask,
            )[0].to(latent_dtype)

        try:
            # The device -> host copies synchronize with the GPU, which is
            # typically where asynchronous CUDA errors from above surface.
            latents_np = [lat.detach().cpu().numpy() for lat in latents]
            conditioning_np = conditioning.detach().cpu().numpy()

            # Write the batch to the MDS file
            for i, sample in enumerate(batch['sample']):
                mds_sample = {
                    cap_key: sample[cap_key],
                    f'{cap_key}_latents': np.reshape(conditioning_np[i], -1).tobytes(),
                }
                for key, lat in zip(latent_keys, latents_np):
                    mds_sample[key] = lat[i].tobytes()
                if save_images:
                    mds_sample['jpg'] = sample['jpg']
                writer.write(mds_sample)
        except RuntimeError:
            print('Runtime error CUDA, skipping this batch')

    writer.finish()

    # Wait for all processes to finish
    accelerator.wait_for_everyone()
    print(f'Process {accelerator.process_index} finished')
    time.sleep(10)

    # Merge the mds shards created by each device (only do on main process)
    if accelerator.is_main_process:
        shards_metadata = [
            os.path.join(savedir, str(i), 'index.json')
            for i in range(accelerator.num_processes)
        ]
        merge_index(shards_metadata, out=savedir, keep_local=True)

# for split in train valid; do
#     accelerate launch --multi_gpu --num_processes 8 precompute.py --datadir ./datadir/jdb/mds/$split/ --savedir ./datadir/jdb/mds_latents_sdxl1_dfnclipH14/$split/ --vae stabilityai/stable-diffusion-xl-base-1.0 --text_encoder openclip:hf-hub:apple/DFN5B-CLIP-ViT-H-14-378 --batch_size 32
# done

# Precompute latents for the train split. Like the MDS conversion cells, skip
# the work when a previous run already produced the output. The merged
# index.json in the output directory is written only at the very end of main.
TRAIN_LATENTS_DIR = os.path.join(DATA_DIR, 'mds_latents_sdxl1_dfnclipH14', 'train')

if not os.path.exists(os.path.join(TRAIN_LATENTS_DIR, 'index.json')):
    main(
        datadir=os.path.join(DATA_DIR, 'mds', 'train'),
        savedir=TRAIN_LATENTS_DIR,
        image_resolutions=[256, 512],
        save_images=False,
        model_dtype="bfloat16",
        save_dtype="float16",
        vae="stabilityai/stable-diffusion-xl-base-1.0",
        text_encoder="openclip:hf-hub:apple/DFN5B-CLIP-ViT-H-14-378",
        batch_size=32,
        seed=2024,
    )