
When I train with mixed precision, I get a "ValueError: Attempting to unscale FP16 gradients" error. #1778

Closed
aihao2000 opened this issue Jul 26, 2023 · 11 comments

aihao2000 commented Jul 26, 2023

System Info

- `Accelerate` version: 0.21.0
- Platform: Linux-5.19.0-46-generic-x86_64-with-glibc2.35
- Python version: 3.11.4
- Numpy version: 1.25.0
- PyTorch version (GPU?): 2.0.1 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 62.72 GB
- GPU type: NVIDIA RTX A4000
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: NO
        - mixed_precision: no
        - use_cpu: False
        - num_processes: 1
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: False
        - main_training_function: main
        - downcast_bf16: False
        - tpu_use_cluster: False
        - tpu_use_sudo: False

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

When running "accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)" or "optimizer.step()", the following check raises an error: if (not allow_fp16) and param.grad.dtype == torch.float16: raise ValueError("Attempting to unscale FP16 gradients.")

from train_config import TrainConfig
import logging
import os
import math
from pathlib import Path
from attr import dataclass
from tqdm.auto import tqdm
import random
import accelerate
import diffusers
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import transformers
from torchvision import transforms
from packaging import version
import shutil
from accelerate.logging import get_logger
import PIL
from PIL import Image
import numpy as np
import datasets
import itertools

logger = get_logger(__name__)


def train(config: TrainConfig):
    accelerator = (
        accelerate.Accelerator(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            mixed_precision=config.mixed_precision,
            log_with=config.log_with,
            project_config=accelerate.utils.ProjectConfiguration(
                project_dir=config.output_dir,
                logging_dir=os.path.join(config.output_dir, "logs"),
            ),
        )
        if config.accelerator is None
        else config.accelerator
    )
    if config.seed is not None:
        accelerate.utils.set_seed(config.seed)

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
    """
        日志设置
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    """加载预训练组件"""
    tokenizer: transformers.CLIPTokenizer = (
        transformers.CLIPTokenizer.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="tokenizer",
        )
        if config.tokenizer is None
        else config.tokenizer
    )

    text_encoder: transformers.CLIPTextModel = (
        (
            transformers.CLIPTextModel.from_pretrained(
                config.pretrained_model_name_or_path,
                subfolder="text_encoder",
                revision=config.revision,
            )
        )
        if config.text_encoder is None
        else config.text_encoder
    )

    noise_scheduler: diffusers.DDPMScheduler = (
        (
            diffusers.DDPMScheduler.from_pretrained(
                config.pretrained_model_name_or_path, subfolder="scheduler"
            )
        )
        if config.noise_scheduler is None
        else config.noise_scheduler
    )

    vae: diffusers.AutoencoderKL = (
        diffusers.AutoencoderKL.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="vae",
            revision=config.revision,
        )
        if config.vae is None
        else config.vae
    )

    unet: diffusers.UNet2DConditionModel = (
        (
            diffusers.UNet2DConditionModel.from_pretrained(
                config.pretrained_model_name_or_path,
                subfolder="unet",
                revision=config.revision,
            )
        )
        if config.unet is None
        else config.unet
    )

    unet.requires_grad_(True)
    text_encoder.requires_grad_(True)
    vae.requires_grad_(False)
    unet.train()
    text_encoder.train()

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    if config.enable_xformers_memory_efficient_attention:
        unet.enable_xformers_memory_efficient_attention()

    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if config.scale_lr:
        config.learning_rate = (
            config.learning_rate
            * config.gradient_accumulation_steps
            * config.train_batch_size
            * accelerator.num_processes
        )
    """优化器设置"""
    if config.optimizer is None:
        if config.use_8bit_adam:
            import bitsandbytes as bnb

            optimizer = bnb.optim.AdamW8bit(
                itertools.chain(unet.parameters(), text_encoder.parameters()),
                lr=config.learning_rate,
                betas=(config.adam_beta1, config.adam_beta2),
                weight_decay=config.adam_weight_decay,
                eps=config.adam_epsilon,
            )
        else:
            optimizer: torch.optim.AdamW = torch.optim.AdamW(
                itertools.chain(unet.parameters(), text_encoder.parameters()),
                lr=config.learning_rate,
                betas=(config.adam_beta1, config.adam_beta2),
                weight_decay=config.adam_weight_decay,
                eps=config.adam_epsilon,
            )
    else:
        optimizer = config.optimizer

    """dataset setting"""
    train_dataset = datasets.load_dataset(
        "data",
        data_dir="data",
        name="page spilt",
        split="train",
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True
    )

    """lr scheduler"""
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / config.gradient_accumulation_steps
    )
    if config.max_train_steps is None:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = diffusers.optimization.get_scheduler(
        config.lr_scheduler_name,
        optimizer,
        num_warmup_steps=config.lr_warmup_steps * config.gradient_accumulation_steps,
        num_training_steps=config.max_train_steps * config.gradient_accumulation_steps,
        num_cycles=config.lr_num_cycles * config.gradient_accumulation_steps,
    )
    """训练准备"""
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    """
        accelerator分布式后,重新计算训练参数
    """
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / config.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch

    config.num_train_epochs = math.ceil(
        config.max_train_steps / num_update_steps_per_epoch
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(
            "train_gen_2d",
            config={
                "learning_rate": config.learning_rate,
                "batch_size": config.train_batch_size,
            },
        )

    total_batch_size = (
        config.train_batch_size
        * accelerator.num_processes
        * config.gradient_accumulation_steps
    )
    """训练开始"""
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {config.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {config.max_train_steps}")
    global_step = 0
    first_epoch = 0
    """加载保存节点"""
    if config.resume_train_from_accelerator_state:
        if config.resume_train_from_accelerator_state == "latest":
            # Get the most recent checkpoint
            dirs = os.listdir(config.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            accelerator_ckpt_path = dirs[-1] if len(dirs) > 0 else None

        if config.accelerator_ckpt_path is None:
            accelerator.print(
                f"Checkpoint '{config.resume_train_from_accelerator_state}' does not exist. Starting a new training run."
            )
            config.resume_train_from_accelerator_state = None
        else:
            accelerator.print(f"Resuming from checkpoint {accelerator_ckpt_path}")
            accelerator.load_state(accelerator_ckpt_path)
            global_step = int(accelerator_ckpt_path.split("-")[1])

            resume_global_step = global_step * config.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (
                num_update_steps_per_epoch * config.gradient_accumulation_steps
            )
    progress_bar = tqdm(
        range(global_step, config.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    """训练"""
    for epoch in range(first_epoch, config.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            if (
                config.resume_train_from_accelerator_state
                and epoch == first_epoch
                and step < resume_step
            ):
                if step % config.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(text_encoder), accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(
                    batch["pixel_values"].to(dtype=weight_dtype)
                ).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                batch_size, channels, h, w = latents.shape
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (batch_size,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
                    # Duplicate the noisy latents along the channel dimension to match the UNet input.
                    noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=1)
                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(
                    batch["input_ids"].to(text_encoder.device),
                    attention_mask=batch["attention_mask"].to(text_encoder.device),
                )[0].to(dtype=weight_dtype)

                # Predict the noise residual
                model_pred = unet(
                    noisy_latents, timesteps, encoder_hidden_states
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(
                        f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                    )

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = itertools.chain(
                        unet.parameters(), text_encoder.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                images = []
                progress_bar.update(1)
                global_step += 1
                if global_step % config.save_steps == 0:
                    """
                    TODO:保存模型
                    """

                    """
                        验证
                    """
                    if (
                        config.output_dir is not None
                        and global_step % config.validation_steps == 0
                    ):
                        """
                        TODO:验证当前模型效果
                        """

            """
                TODO:打印日志
            """
            """
                TODO:打印日志
            """

            if global_step >= config.max_train_steps:
                break

Expected behavior

Training should run with fp16 mixed precision.

@sgugger
Collaborator

sgugger commented Jul 26, 2023

cc @muellerzr as you are investigating this right now ;-)

muellerzr self-assigned this Jul 26, 2023
@pacman100
Contributor

Hello @AisingioroHao0, you shouldn't explicitly call model.half() or model.to(torch.float16) when using amp. See this PyTorch forum message: https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372/14

Could you please remove the following snippet and see if things work?

-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-    unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
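
For context, a minimal sketch of the pattern amp expects (a toy linear layer stands in for unet/text_encoder, and it assumes a CUDA device so fp16 grad scaling is available): the trainable model stays in fp32 and goes through accelerator.prepare, and the grad scaler then handles loss scaling and gradient unscaling.

import torch
import accelerate

# Minimal sketch: under amp, the trainable model stays in fp32 and is handed
# to accelerator.prepare; autocast/GradScaler then handle the fp16 compute,
# loss scaling, and gradient unscaling.
accelerator = accelerate.Accelerator(mixed_precision="fp16")

model = torch.nn.Linear(8, 8)  # stand-in for unet / text_encoder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# No model.half() or model.to(torch.float16) here.
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(4, 8, device=accelerator.device)
loss = model(x).float().mean()

accelerator.backward(loss)  # scales the loss when fp16 is active
accelerator.clip_grad_norm_(model.parameters(), 1.0)  # unscales fp32 grads, then clips
optimizer.step()
optimizer.zero_grad()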

@aihao2000
Author

aihao2000 commented Jul 27, 2023

@pacman100
I use the weight_dtype variable during training. If I try to remove the ".to" calls, I get "RuntimeError: Input type (c10::Half) and bias type (float) should be the same" when running: latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()

@aihao2000
Author

Hello @pacman100,
if I remove all the weight_dtype code, I still get an error.
This is the new code:

from train_config import TrainConfig
import logging
import os
import math
from pathlib import Path
from attr import dataclass
from tqdm.auto import tqdm
import random
import accelerate
import diffusers
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import transformers
from torchvision import transforms
from packaging import version
import shutil
from accelerate.logging import get_logger
import PIL
from PIL import Image
import numpy as np
import datasets
import itertools

logger = get_logger(__name__)


def train(config: TrainConfig):
    accelerator = (
        accelerate.Accelerator(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            mixed_precision=config.mixed_precision,
            log_with=config.log_with,
            project_config=accelerate.utils.ProjectConfiguration(
                project_dir=config.output_dir,
                logging_dir=os.path.join(config.output_dir, "logs"),
            ),
        )
        if config.accelerator is None
        else config.accelerator
    )
    if config.seed is not None:
        accelerate.utils.set_seed(config.seed)

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
    """
        日志设置
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    """加载预训练组件"""
    tokenizer: transformers.CLIPTokenizer = (
        transformers.CLIPTokenizer.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="tokenizer",
        )
        if config.tokenizer is None
        else config.tokenizer
    )

    text_encoder: transformers.CLIPTextModel = (
        (
            transformers.CLIPTextModel.from_pretrained(
                config.pretrained_model_name_or_path,
                subfolder="text_encoder",
                revision=config.revision,
            )
        )
        if config.text_encoder is None
        else config.text_encoder
    )

    noise_scheduler: diffusers.DDPMScheduler = (
        (
            diffusers.DDPMScheduler.from_pretrained(
                config.pretrained_model_name_or_path, subfolder="scheduler"
            )
        )
        if config.noise_scheduler is None
        else config.noise_scheduler
    )

    vae: diffusers.AutoencoderKL = (
        diffusers.AutoencoderKL.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="vae",
            revision=config.revision,
        )
        if config.vae is None
        else config.vae
    )

    unet: diffusers.UNet2DConditionModel = (
        (
            diffusers.UNet2DConditionModel.from_pretrained(
                config.pretrained_model_name_or_path,
                subfolder="unet",
                revision=config.revision,
            )
        )
        if config.unet is None
        else config.unet
    )
    unet.requires_grad_(True)
    text_encoder.requires_grad_(True)
    vae.requires_grad_(False)

    unet.train()
    text_encoder.train()

    # weight_dtype = torch.float32
    # if accelerator.mixed_precision == "fp16":
    #     weight_dtype = torch.float16
    # elif accelerator.mixed_precision == "bf16":
    #     weight_dtype = torch.bfloat16

    # unet.to(accelerator.device, dtype=weight_dtype)
    # vae.to(accelerator.device, dtype=weight_dtype)
    # text_encoder.to(accelerator.device, dtype=weight_dtype)

    if config.enable_xformers_memory_efficient_attention:
        unet.enable_xformers_memory_efficient_attention()

    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if config.scale_lr:
        config.learning_rate = (
            config.learning_rate
            * config.gradient_accumulation_steps
            * config.train_batch_size
            * accelerator.num_processes
        )
    """优化器设置"""
    if config.optimizer is None:
        if config.use_8bit_adam:
            import bitsandbytes as bnb

            optimizer = bnb.optim.AdamW8bit(
                itertools.chain(unet.parameters(), text_encoder.parameters()),
                lr=config.learning_rate,
                betas=(config.adam_beta1, config.adam_beta2),
                weight_decay=config.adam_weight_decay,
                eps=config.adam_epsilon,
            )
        else:
            optimizer: torch.optim.AdamW = torch.optim.AdamW(
                itertools.chain(unet.parameters(), text_encoder.parameters()),
                lr=config.learning_rate,
                betas=(config.adam_beta1, config.adam_beta2),
                weight_decay=config.adam_weight_decay,
                eps=config.adam_epsilon,
            )
    else:
        optimizer = config.optimizer

    """数据集设置"""
    train_dataset = datasets.load_dataset(
        "comic_generator/datasets/comics",
        data_dir="comic_generator/datasets/comics",
        name="page spilt",
        split="train",
    )
    image_preprocess = transforms.Compose(
        [
            transforms.Resize(config.train_image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def transform(batch):
        pixel_values = [
            image_preprocess(image.convert("RGB")) for image in batch["image"]
        ]
        res = tokenizer(
            [
                f"{comic_name},{chapter_name},{idx},"
                for comic_name, chapter_name, idx in zip(
                    batch["comic_name"], batch["chapter_name"], batch["idx"]
                )
            ],
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = res.input_ids
        attention_mask = res.attention_mask
        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    train_dataset.set_transform(transform)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True
    )

    """lr scheduler"""
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / config.gradient_accumulation_steps
    )
    if config.max_train_steps is None:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = diffusers.optimization.get_scheduler(
        config.lr_scheduler_name,
        optimizer,
        num_warmup_steps=config.lr_warmup_steps * config.gradient_accumulation_steps,
        num_training_steps=config.max_train_steps * config.gradient_accumulation_steps,
        num_cycles=config.lr_num_cycles * config.gradient_accumulation_steps,
    )
    """训练准备"""
    text_encoder, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, unet, optimizer, train_dataloader, lr_scheduler
    )

    """
        accelerator分布式后,重新计算训练参数
    """
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / config.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        config.max_train_steps = config.num_train_epochs * num_update_steps_per_epoch

    config.num_train_epochs = math.ceil(
        config.max_train_steps / num_update_steps_per_epoch
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(
            "train_gen_2d",
            config={
                "learning_rate": config.learning_rate,
                "batch_size": config.train_batch_size,
            },
        )

    total_batch_size = (
        config.train_batch_size
        * accelerator.num_processes
        * config.gradient_accumulation_steps
    )
    """训练开始"""
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {config.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {config.max_train_steps}")
    global_step = 0
    first_epoch = 0
    """加载保存节点"""
    if config.resume_train_from_accelerator_state:
        if config.resume_train_from_accelerator_state == "latest":
            # Get the most recent checkpoint
            dirs = os.listdir(config.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            accelerator_ckpt_path = dirs[-1] if len(dirs) > 0 else None

        if config.accelerator_ckpt_path is None:
            accelerator.print(
                f"Checkpoint '{config.resume_train_from_accelerator_state}' does not exist. Starting a new training run."
            )
            config.resume_train_from_accelerator_state = None
        else:
            accelerator.print(f"Resuming from checkpoint {accelerator_ckpt_path}")
            accelerator.load_state(accelerator_ckpt_path)
            global_step = int(accelerator_ckpt_path.split("-")[1])

            resume_global_step = global_step * config.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (
                num_update_steps_per_epoch * config.gradient_accumulation_steps
            )
    progress_bar = tqdm(
        range(global_step, config.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    """训练"""
    for epoch in range(first_epoch, config.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            if (
                config.resume_train_from_accelerator_state
                and epoch == first_epoch
                and step < resume_step
            ):
                if step % config.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(text_encoder), accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(
                    batch["pixel_values"]
                ).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                batch_size, channels, h, w = latents.shape
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (batch_size,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
                    # Duplicate the noisy latents along the channel dimension to match the UNet input.
                    noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=1)
                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(
                    batch["input_ids"].to(text_encoder.device),
                    attention_mask=batch["attention_mask"].to(text_encoder.device),
                )[0]

                # Predict the noise residual
                model_pred = unet(
                    noisy_latents, timesteps, encoder_hidden_states
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(
                        f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                    )

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = itertools.chain(
                        unet.parameters(), text_encoder.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                images = []
                progress_bar.update(1)
                global_step += 1
                if global_step % config.save_steps == 0:
                    """
                    TODO:保存模型
                    """

                    """
                        验证
                    """
                    if (
                        config.output_dir is not None
                        and global_step % config.validation_steps == 0
                    ):
                        """
                        TODO:验证当前模型效果
                        """

            """
                TODO:打印日志
            """

            if global_step >= config.max_train_steps:
                break

@pacman100
Contributor

The reason for the new error is that you don't prepare the vae, so it isn't handled by amp. Does keeping the code below work:

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
-    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
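
Put together, the suggested setup looks roughly like the sketch below (model_id is a placeholder for config.pretrained_model_name_or_path; this is an illustration of the idea, not a drop-in patch): only the frozen vae is cast to the half-precision dtype, while the trainable unet and text_encoder stay in fp32 and are handled by accelerator.prepare.

import itertools
import torch
import accelerate
import diffusers
import transformers

accelerator = accelerate.Accelerator(mixed_precision="fp16")

# Placeholder for config.pretrained_model_name_or_path.
model_id = "path-to-stable-diffusion-checkpoint"
unet = diffusers.UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = transformers.CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = diffusers.AutoencoderKL.from_pretrained(model_id, subfolder="vae")

weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Frozen model: safe to cast and move manually.
vae.requires_grad_(False)
vae.to(accelerator.device, dtype=weight_dtype)

# Trainable models: keep fp32 and let prepare() handle device + mixed precision.
unet.requires_grad_(True)
text_encoder.requires_grad_(True)
optimizer = torch.optim.AdamW(
    itertools.chain(unet.parameters(), text_encoder.parameters()), lr=1e-5
)
unet, text_encoder, optimizer = accelerator.prepare(unet, text_encoder, optimizer)

# In the training loop, inputs to the half-precision vae are cast to match it, e.g.:
# latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()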

@aihao2000
Author

@pacman100 It now runs successfully, thank you very much. Do unet and text_encoder not need ".to()" because they have trainable parameters? And do we need accelerator.prepare regardless of whether the model contains trainable parameters or not? I'm a little confused about accelerate.

@muellerzr
Collaborator

@AisingioroHao0 unet and text_encoder were sent through accelerator.prepare, so they got converted to mixed precision and placed on the right device.

Re: training parameters, if it's purely an eval model that is never trained, then you should just load the model onto the device itself (so model.to(accelerator.device)). However, in the case of using mixed precision (like we have here), you can do accelerator.prepare_model(model, evaluation_mode=True), which will set up mixed precision and the device properly without wrapping it in DDP etc.
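
A tiny sketch of the two paths described above (toy modules stand in for the real models; assumes a CUDA device and an accelerate version where evaluation_mode is available):

import torch
import accelerate

accelerator = accelerate.Accelerator(mixed_precision="fp16")

trainable = torch.nn.Linear(8, 8)  # stands in for unet / text_encoder
frozen = torch.nn.Linear(8, 8)     # stands in for the frozen vae
frozen.requires_grad_(False)

optimizer = torch.optim.AdamW(trainable.parameters(), lr=1e-4)

# Trainable model: full prepare (device placement, mixed precision, DDP wrapping, ...).
trainable, optimizer = accelerator.prepare(trainable, optimizer)

# Frozen model: device placement and mixed precision handling, without DDP wrapping.
frozen = accelerator.prepare_model(frozen, evaluation_mode=True)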

@aihao2000
Author

@muellerzr Thank you for your explanation. So, in other words, compared to my implementation above, the most elegant way to use mixed_precision in the accelerate framework is: models with trainable parameters are passed to "accelerator.prepare", and models without trainable parameters are passed to "accelerator.prepare_model(model, evaluation_mode=True)". The model outputs during training do not require any type conversion like ".to(dtype=weight_dtype)". Right? I wrote my script based on the training examples in the diffusers repository. So what are the benefits of this implementation?

@muellerzr
Collaborator

Correct!

The benefit here is doing it "right" when using Accelerate, so e.g. you don't have to worry about mixed precision yourself. Which example was this from? I can pass it along to the diffusers team and we can look at using the right code. This feature is quite new.

@aihao2000
Author

@muellerzr The training code under diffusers/examples, such as https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

@wddwzwhhxx

wddwzwhhxx commented Apr 17, 2024

In the code above, how is fp16 training specified? I didn't see where mixed_precision takes effect. Is it through accelerate config? @muellerzr Looking forward to your reply, thank you!
