In [1]:
import torch
import os
import copy
import math
import logging
import itertools
import shutil
import pandas as pd

import diffusers
import transformers
from tqdm.auto import tqdm

from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3Pipeline,
)
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.optimization import get_scheduler
from torch.utils.data import Dataset



# Module logger used by the setup and training cells below.
logger: logging.Logger = logging.getLogger(__name__)
# Placeholder; not referenced by any later cell in this notebook.
global_config = None

In [2]:
# Dataset index CSV; the columns used below are `description`, `character` and `im_path`.
dataset_path = 'genshin_dataset'
train_df_path = os.path.join(dataset_path, 'dataset1_sd3_emb2.csv')
train_df = pd.read_csv(train_df_path)
prompts = train_df['description'].unique()

# nunique(dropna=False) counts NaN as a value, matching len(Series.unique()).
print('Number of prompts: ', len(prompts))
print('Number of characters: ', train_df['character'].nunique(dropna=False))
print('Number of imgs: ', train_df['im_path'].nunique(dropna=False))
print('Total dataset: ', train_df.shape[0])

Number of prompts:  1102
Number of characters:  56
Number of imgs:  557
Total dataset:  11140


### Dataset and collate_fn

In [3]:
class GenshinDataset(Dataset):
    """Map-style dataset over a DataFrame of images, captions and precomputed prompt embeddings.

    Each row of ``data`` is read through column names taken from ``args``:
      * ``args.image_column``      - path to an image file,
      * ``args.caption_column``    - the caption text,
      * ``args.embeddings``        - path to a tensor file loaded with ``torch.load``,
      * ``args.pooled_embeddings`` - path to a tensor file loaded with ``torch.load``.

    Args:
        data: pandas DataFrame indexed positionally via ``.iloc``.
        args: config object providing the column names above plus ``resolution``,
            ``center_crop`` and ``random_flip``.
    """

    def __init__(self, data, args):
        self.df = data
        self.size = args.resolution
        self.custom_instance_prompts = True  # captions come per-row from the DataFrame

        self.image_column = args.image_column
        self.caption_column = args.caption_column
        self.emb_column = args.embeddings
        self.pooled_emb_column = args.pooled_embeddings

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
                transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                # Map ToTensor's [0, 1] range to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.num_instance_images = len(self.df)
        self._length = self.num_instance_images

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return a dict with the transformed image, its caption and both embedding tensors."""
        row = self.df.iloc[index]

        # Open the image as a context manager so the file handle is closed after decoding;
        # the tensor is produced inside the block so nothing depends on the closed file.
        with Image.open(row[self.image_column]) as raw_image:
            image = exif_transpose(raw_image)  # apply the EXIF orientation tag, if any
            if image.mode != "RGB":
                image = image.convert("RGB")
            instance_image = self.image_transforms(image)

        return {
            "instance_images": instance_image,
            "instance_prompt": row[self.caption_column],
            # weights_only=True restricts unpickling to tensors and plain containers.
            "instance_emb": torch.load(row[self.emb_column], weights_only=True),
            "instance_pooled_emb": torch.load(row[self.pooled_emb_column], weights_only=True),
        }


def collate_fn(examples):
    """Merge a list of ``GenshinDataset`` examples into one batch dict.

    Tensor fields are stacked along a new leading batch dimension, made contiguous and cast to
    float32; captions are kept as a plain Python list.

    Returns:
        dict with keys ``pixel_values``, ``prompts``, ``embeddings``, ``pooled_embeddings``.
    """
    def _stack_float(key):
        stacked = torch.stack([example[key] for example in examples])
        return stacked.to(memory_format=torch.contiguous_format).float()

    return {
        "pixel_values": _stack_float("instance_images"),
        "prompts": [example["instance_prompt"] for example in examples],
        "embeddings": _stack_float("instance_emb"),
        "pooled_embeddings": _stack_float("instance_pooled_emb"),
    }

### Parameters for training

In [4]:
from dataclasses import dataclass
from pathlib import Path

@dataclass
class TrainParams:
    """Hyper-parameters and paths for SD3 LoRA fine-tuning on the Genshin dataset.

    Annotated attributes are dataclass fields: ``__init__`` sets them on the instance, so they
    appear in ``vars(args)``. Un-annotated attributes (``exp_name``, ``train_dataframe``,
    ``env_local_rank``) are plain class attributes: readable as ``args.<name>`` but absent from
    ``vars(args)``.
    """

    exp_name = 'sd3_exp4'

    # --- data ------------------------------------------------------------------
    train_dataframe_path:str = train_df_path
    train_dataframe = train_df  # class attribute (no annotation), so not part of vars(args)
    dataset_name:str = 'dataset1_genshin_dataset'
    caption_column:str = 'description'
    image_column:str = 'im_path'
    embeddings:str = 'embeddings'  # column with paths to saved prompt-embedding tensors
    pooled_embeddings:str = 'pooled_embeddings'  # column with paths to saved pooled-embedding tensors

    # --- model / output paths ----------------------------------------------------
    output_dir:str = f"models/{exp_name}"
    pretrained_model_name_or_path:str = "stabilityai/stable-diffusion-3-medium-diffusers"
    cache_dir:str = os.path.join(output_dir, 'cache')
    logging_dir:Path = Path(os.path.join(output_dir, 'logs'))
    seed:int = 1991

    # --- flow-matching timestep sampling / loss weighting --------------------------
    # If truthy, model outputs are preconditioned as in EDM (arXiv:2206.00364); this changes
    # how the training target is computed.
    precondition_outputs:int = 1
    # One of "sigma_sqrt", "logit_normal", "mode", "cosmap", "none".
    # "none" means uniform timestep sampling and uniform loss weighting.
    weighting_scheme:str = "none"
    mode_scale:float = 1.29  # only used by the "mode" weighting scheme
    logit_mean:float = 0.0  # only used by the "logit_normal" weighting scheme
    logit_std:float = 1.0  # only used by the "logit_normal" weighting scheme

    # --- image preprocessing -------------------------------------------------------
    resolution:int = 256
    center_crop:bool = True  # center crop when True, random crop otherwise
    random_flip:bool = True
    dataloader_num_workers:int = 2

    # --- accelerator -----------------------------------------------------------------
    gradient_accumulation_steps:int = 1
    mixed_precision:str = "fp16"  # one of "no", "fp16", "bf16"
    report_to:str = "tensorboard"
    gradient_checkpointing:bool = True

    # --- model -----------------------------------------------------------------------
    revision:str = None
    variant:str = "fp16"
    rank:int = 4  # LoRA rank
    lora_alpha:int = 16

    # --- optimization ------------------------------------------------------------------
    learning_rate:float = 1e-4
    text_encoder_lr:float = 5e-6
    guidance_scale:float = 3.5  # not read by the training loop in this notebook
    scale_lr:bool = False
    # One of "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup".
    lr_scheduler:str = "constant"
    lr_warmup_steps:int = 500
    lr_num_cycles:int = 1  # number of hard restarts for "cosine_with_restarts"
    lr_power:float = 1.0  # power factor of the "polynomial" scheduler

    train_batch_size:int = 32
    sample_batch_size:int = 10
    num_train_epochs:int = 30  # recomputed from max_train_steps after accelerator.prepare
    max_train_steps:int = 5000

    # --- optimizer -----------------------------------------------------------------------
    optimizer:str = "AdamW"  # only AdamW (optionally 8-bit) is implemented by the optimizer cell
    use_8bit_adam:bool = True  # bitsandbytes AdamW8bit instead of torch.optim.AdamW
    # 'epsilon', 'v_prediction' or None; None means the scheduler's default
    # (`noise_scheduler.config.prediction_type`).
    prediction_type:str = None
    local_rank:int = 1  # replaced by the LOCAL_RANK environment variable when set (see __post_init__)

    adam_beta1:float = 0.9
    adam_beta2:float = 0.999
    adam_weight_decay:float = 1e-2
    adam_weight_decay_text_encoder:float = 1e-3
    adam_epsilon:float = 1e-08

    # Prodigy settings (ignored for AdamW). beta3=None means sqrt(beta2).
    prodigy_beta3:float = None
    prodigy_decouple:bool = True  # AdamW-style decoupled weight decay

    # --- loss extras -------------------------------------------------------------------------
    snr_gamma:float = None  # SNR weighting gamma for loss rebalancing (5.0 is a common choice)
    noise_offset:float = 0  # scale of the noise offset

    # --- validation ----------------------------------------------------------------------------
    validation_prompt:str = "A girl with blue eyes, black hair and a bright smile. The girl wears a red dress and a ponytail with a flower decoration."
    num_validation_images:int = 4  # images generated per validation run with `validation_prompt`
    validation_epochs:int = 1
    max_train_samples:int = None

    max_grad_norm:float = 1.0

    # --- checkpointing ---------------------------------------------------------------------------
    checkpointing_steps:int = 200  # save full training state every N optimizer updates
    checkpoints_total_limit:int = 10  # maximum number of checkpoints kept on disk

    # LOCAL_RANK as seen when the class was defined; kept as a class attribute for compatibility.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))

    def __post_init__(self):
        """Override ``local_rank`` from the LOCAL_RANK environment variable when it is set.

        This must run on the instance: the class body has no ``self``, so doing it there raised
        NameError whenever LOCAL_RANK was present.
        """
        env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
        if env_local_rank != -1 and env_local_rank != self.local_rank:
            self.local_rank = env_local_rank

args = TrainParams()


In [5]:
import json
from pathlib import Path

# Persist the run configuration next to the checkpoints.
os.makedirs(args.output_dir, exist_ok=True)
# Build a NEW dict: vars(args) is the instance's own __dict__, so editing it in place would
# mutate `args` and make this cell fail (KeyError) when re-run. Path values are not
# JSON-serializable, so they are stored as strings instead of being dropped.
args_dict = {key: str(value) if isinstance(value, Path) else value for key, value in vars(args).items()}
# ensure_ascii=False may emit non-ASCII characters, so pin the file encoding.
with open(os.path.join(args.output_dir, 'args.json'), 'w', encoding='utf-8') as file:
    file.write(json.dumps(args_dict, ensure_ascii=False))

### Train preparation

In [6]:
# --- Accelerator ------------------------------------------------------------------
accelerator_project_config = ProjectConfiguration(
    project_dir=args.output_dir,
    logging_dir=args.logging_dir,
)
# Only relevant under DDP: lets it tolerate parameters that receive no gradient.
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    mixed_precision=args.mixed_precision,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    kwargs_handlers=[kwargs],
    project_config=accelerator_project_config,
    log_with=args.report_to,
)

# Native AMP is switched off when Apple MPS is available.
if torch.backends.mps.is_available():
    accelerator.native_amp = False

# --- Logging -----------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger.info(accelerator.state)

# Library logs: warnings/info on the local main process, errors only elsewhere.
if accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    for _library_logging in (transformers.utils.logging, diffusers.utils.logging):
        _library_logging.set_verbosity_error()

set_seed(args.seed)

if accelerator.is_main_process and args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
08/20/2024 12:35:54 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16



In [7]:
# Flow-matching scheduler; the deep copy is what the training loop queries for timesteps/sigmas.
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)

# Base models are frozen; only the LoRA adapters added in the next cell are trained.
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
transformer = SD3Transformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer")
transformer.requires_grad_(False)
vae.requires_grad_(False)

# Storage dtype of the frozen transformer, following the accelerator's mixed-precision mode.
weight_dtype = (
    torch.float16 if accelerator.mixed_precision == "fp16"
    else torch.bfloat16 if accelerator.mixed_precision == "bf16"
    else torch.float32
)

# The VAE stays in float32; the transformer uses weight_dtype.
vae.to(accelerator.device, dtype=torch.float32)
transformer.to(accelerator.device, dtype=weight_dtype)

if args.gradient_checkpointing:
    transformer.enable_gradient_checkpointing()

{'max_shift', 'base_image_seq_len', 'max_image_seq_len', 'base_shift', 'use_dynamic_shifting'} was not found in config. Values will be initialized to default values.
{'mid_block_add_attention'} was not found in config. Values will be initialized to default values.


In [8]:
# Attach LoRA adapters to every submodule named to_q / to_k / to_v / to_out.0
# (the attention projections of the SD3 transformer).
transformer_lora_config = LoraConfig(
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    init_lora_weights="gaussian",
    lora_alpha=args.lora_alpha,
    r=args.rank,
)
transformer.add_adapter(transformer_lora_config)

In [9]:
def unwrap_model(model):
    """Strip the accelerate wrapper and, if present, the ``torch.compile`` wrapper from ``model``."""
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped

In [10]:
# Custom save hook so that `accelerator.save_state(...)` writes the transformer LoRA weights in the
# diffusers LoRA format. Only a save hook is registered in this notebook (no load hook).
def save_model_hook(models, weights, output_dir):
    """Pre-hook for ``accelerator.save_state``: save only the transformer's LoRA weights.

    Args:
        models: models accelerate is about to save; each must be an instance of the unwrapped
            transformer's class.
        weights: accelerate's list of weights to save; one entry is popped per model so the full
            model is not saved a second time.
        output_dir: checkpoint directory supplied by accelerate.

    Raises:
        ValueError: if ``models`` contains anything other than the transformer.
    """
    if not accelerator.is_main_process:
        return

    transformer_lora_layers_to_save = None
    for model in models:
        if not isinstance(model, type(unwrap_model(transformer))):
            raise ValueError(f"unexpected save model: {model.__class__}")
        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
        weights.pop()

    # No text-encoder LoRA is trained here, so both text-encoder entries are None.
    StableDiffusion3Pipeline.save_lora_weights(
        output_dir,
        transformer_lora_layers=transformer_lora_layers_to_save,
        text_encoder_lora_layers=None,
        text_encoder_2_lora_layers=None,
    )

accelerator.register_save_state_pre_hook(save_model_hook)

<torch.utils.hooks.RemovableHandle at 0x7a3480e42dd0>

In [11]:
# Under fp16 mixed precision, upcast only the trainable (LoRA) parameters to float32.
if args.mixed_precision == "fp16":
    models = [transformer]
    cast_training_params(models, dtype=torch.float32)

# The base transformer was frozen, so only LoRA parameters require gradients.
transformer_lora_parameters = [p for p in transformer.parameters() if p.requires_grad]

if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
params_to_optimize = [transformer_parameters_with_lr]

# Only AdamW (torch or bitsandbytes 8-bit) is implemented; make any other choice visible
# instead of silently ignoring it.
if str(args.optimizer).lower() != "adamw":
    logger.warning(
        f"Unsupported choice of optimizer: {args.optimizer}. Only AdamW is implemented here; defaulting to AdamW."
    )

if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError as err:
        raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") from err
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

optimizer = optimizer_class(
    params_to_optimize,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

In [12]:
# Dataset and DataLoader creation.
train_dataset = GenshinDataset(data=args.train_dataframe, args=args)

# collate_fn is passed directly rather than wrapped in a lambda: the wrapper added nothing, and
# lambdas cannot be pickled when DataLoader workers are started with the "spawn" method.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=args.dataloader_num_workers,
)

In [13]:
# Step budget: derive max_train_steps from the epoch count only when it was not given explicitly.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
overrode_max_train_steps = args.max_train_steps is None
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

# Warm-up and total steps are scaled by the process count.
# NOTE(review): presumably because the prepared scheduler advances once per process — confirm.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    power=args.lr_power,
    num_cycles=args.lr_num_cycles,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
)

In [14]:
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    transformer,
    optimizer,
    train_dataloader,
    lr_scheduler,
)

# prepare() may change the dataloader length (e.g. sharding), so redo the step arithmetic:
# first the step budget (only if it was derived), then the epoch count that covers it.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = num_update_steps_per_epoch * args.num_train_epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

In [15]:
# Start the configured tracker under the experiment name (main process only).
# The config dict is not passed. NOTE(review): presumably because some vars(args) values
# (Path, None) are rejected by the tracker — confirm before enabling it.
if accelerator.is_main_process:
    tracker_name = args.exp_name
    accelerator.init_trackers(project_name=tracker_name)

### Train

In [16]:
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each sampled timestep.

    Args:
        timesteps: 1-D tensor; every value must occur exactly once in
            ``noise_scheduler_copy.timesteps`` (``.item()`` raises otherwise).
        n_dim: desired rank of the result; trailing size-1 dims are appended so the sigmas
            broadcast against a batch of that rank.
        dtype: dtype of the returned sigmas.

    Returns:
        Tensor on ``accelerator.device`` of shape ``(len(timesteps), 1, ..., 1)`` with at least
        ``n_dim`` dims.
    """
    device = accelerator.device
    all_sigmas = noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
    schedule = noise_scheduler_copy.timesteps.to(device)

    positions = []
    for t in timesteps.to(device):
        positions.append((schedule == t).nonzero().item())

    sigma = all_sigmas[positions].flatten()
    missing_dims = n_dim - sigma.dim()
    if missing_dims > 0:
        sigma = sigma.reshape(tuple(sigma.shape) + (1,) * missing_dims)
    return sigma

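**TODO:** add a validation step, e.g. generate `num_validation_images` images from `validation_prompt` every `validation_epochs` epochs. These settings already exist in `TrainParams` but are not used by the training loop yet.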

In [17]:
    
def _prune_checkpoints(output_dir, total_limit):
    """Delete the oldest ``checkpoint-<step>`` directories in ``output_dir``.

    Leaves at most ``total_limit - 1`` of them, so that after the next save no more than
    ``total_limit`` checkpoints exist. Directory names must look like ``checkpoint-<int>``.
    """
    checkpoints = sorted(
        (name for name in os.listdir(output_dir) if name.startswith("checkpoint")),
        key=lambda name: int(name.split("-")[1]),
    )
    if len(checkpoints) < total_limit:
        return

    num_to_remove = len(checkpoints) - total_limit + 1
    removing_checkpoints = checkpoints[:num_to_remove]
    logger.info(
        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
    for removing_checkpoint in removing_checkpoints:
        shutil.rmtree(os.path.join(output_dir, removing_checkpoint))


total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
initial_global_step = 0

progress_bar = tqdm(
    range(0, args.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
    transformer.train()
    for step, batch in enumerate(train_dataloader):
        models_to_accumulate = [transformer]
        # Accelerator.accumulate takes the models as positional varargs, so unpack the list.
        with accelerator.accumulate(*models_to_accumulate):
            pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
            # Prompt embeddings are precomputed; batch["prompts"] is intentionally not read here.
            # (Assigning it to `prompts` would clobber the notebook-global list from the data cell.)
            prompt_embeds, pooled_prompt_embeds = batch["embeddings"], batch["pooled_embeddings"]

            # Encode images into the VAE latent space, normalized by the VAE's shift/scale factors.
            model_input = vae.encode(pixel_values).latent_dist.sample()
            model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
            model_input = model_input.to(dtype=weight_dtype)

            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]

            # Sample one timestep per image; the sampling density depends on args.weighting_scheme.
            u = compute_density_for_timestep_sampling(
                weighting_scheme=args.weighting_scheme,
                batch_size=bsz,
                logit_mean=args.logit_mean,
                logit_std=args.logit_std,
                mode_scale=args.mode_scale,
            )
            indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

            # Flow matching: zt = (1 - sigma) * x + sigma * noise.
            sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
            noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

            # Transformer prediction (compared with the flow target below).
            model_pred = transformer(
                hidden_states=noisy_model_input,
                timestep=timesteps,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False,
            )[0]

            # EDM-style output preconditioning (Section 5 of https://arxiv.org/abs/2206.00364):
            # turns the prediction into an estimate of the clean latents.
            if args.precondition_outputs:
                model_pred = model_pred * (-sigmas) + noisy_model_input

            # Per-sample loss weights for the chosen weighting scheme.
            weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

            # Flow-matching target: clean latents when preconditioned, otherwise noise - x.
            if args.precondition_outputs:
                target = model_input
            else:
                target = noise - model_input

            # Weighted MSE, averaged per sample and then over the batch.
            loss = torch.mean(
                (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                1,
            )
            loss = loss.mean()

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(transformer_lora_parameters, args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # True only when accelerate has just performed an optimizer update.
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
                # Prune _before_ saving so the new checkpoint keeps us within the limit.
                if args.checkpoints_total_limit is not None:
                    _prune_checkpoints(args.output_dir, args.checkpoints_total_limit)

                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)
                logger.info(f"Saved state to {save_path}")

        logs = {"train_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

        if global_step >= args.max_train_steps:
            break

progress_bar.close()

# Save the final LoRA layers.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    transformer = unwrap_model(transformer)
    transformer = transformer.to(torch.float32)
    transformer_lora_layers = get_peft_model_state_dict(transformer)

    # Only the transformer carries LoRA weights; the text encoders were not trained.
    text_encoder_lora_layers = None
    text_encoder_2_lora_layers = None

    StableDiffusion3Pipeline.save_lora_weights(
        save_directory=args.output_dir,
        transformer_lora_layers=transformer_lora_layers,
        text_encoder_lora_layers=text_encoder_lora_layers,
        text_encoder_2_lora_layers=text_encoder_2_lora_layers,
    )

    # Reload the base pipeline and attach the saved LoRA weights to check that they load.
    # No images are generated in this cell.
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    pipeline.load_lora_weights(args.output_dir)

accelerator.end_training()

08/20/2024 12:35:59 - INFO - __main__ - ***** Running training *****
08/20/2024 12:35:59 - INFO - __main__ -   Num examples = 11140
08/20/2024 12:35:59 - INFO - __main__ -   Num batches each epoch = 349
08/20/2024 12:35:59 - INFO - __main__ -   Num Epochs = 15
08/20/2024 12:35:59 - INFO - __main__ -   Instantaneous batch size per device = 32
08/20/2024 12:35:59 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 32
08/20/2024 12:35:59 - INFO - __main__ -   Gradient Accumulation steps = 1
08/20/2024 12:35:59 - INFO - __main__ -   Total optimization steps = 5000


Steps:   0%|          | 0/5000 [00:00<?, ?it/s]

  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
08/20/2024 12:43:17 - INFO - accelerate.accelerator - Saving current state to models/sd3_exp4/checkpoint-200
Model weights saved in models/sd3_exp4/checkpoint-200/pytorch_lora_weights.safetensors
08/20/2024 12:43:18 - INFO - accelerate.checkpointing - Optimizer state saved in models/sd3_exp4/checkpoint-200/optimizer.bin
08/20/2024 12:43:18 - INFO - accelerate.checkpointing - Scheduler state saved in models/sd3_exp4/checkpoint-200/scheduler.bin
08/20/2024 12:43:18 - INFO - accelerate.checkpointing - Sampler state for dataloader 0 saved in models/sd3_exp4/checkpoint-200/sampler.bin
08/20/2024 12:43:18 - INFO - accelerate.checkpointing - Gradient scaler state saved in models/sd3_exp4/checkpoint-200/scaler.pt
08/20/2024 12:43:18 - INFO - accelerate.checkpointing - Random states saved in models/sd3_exp4/checkpoint-200/random_states_0.pkl
08/20/2024 12:43:18 - INFO - __

Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]

Loaded transformer as SD3Transformer2DModel from `transformer` subfolder of stabilityai/stable-diffusion-3-medium-diffusers.
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of stabilityai/stable-diffusion-3-medium-diffusers.
Loaded text_encoder as CLIPTextModelWithProjection from `text_encoder` subfolder of stabilityai/stable-diffusion-3-medium-diffusers.
{'max_shift', 'base_image_seq_len', 'max_image_seq_len', 'base_shift', 'use_dynamic_shifting'} was not found in config. Values will be initialized to default values.
Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-3-medium-diffusers.
{'mid_block_add_attention'} was not found in config. Values will be initialized to default values.
Loaded vae as AutoencoderKL from `vae` subfolder of stabilityai/stable-diffusion-3-medium-diffusers.
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loaded tokenizer_3 as T5TokenizerFast from 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded text_encoder_3 as T5EncoderModel from `text_encoder_3` subfolder of stabilityai/stable-diffusion-3-medium-diffusers.
