EMA: refactor to support CPU offload, step-skipping, and DiT models | pixart: reduce max grad norm by default, forcibly #521

Merged · 19 commits · Jun 22, 2024
25 changes: 24 additions & 1 deletion OPTIONS.md
@@ -141,6 +141,27 @@ A lot of settings are instead set through the [dataloader config](/documentation

## 🛠 Advanced Optimizations

### `--use_ema`

- **What**: Keeps an exponential moving average (EMA) of the model weights over the training run, which works like periodically back-merging the model into itself.
- **Why**: It can improve training stability at the cost of additional system resources and a slight increase in training runtime; a minimal sketch of the update rule follows.
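
For intuition, here is a minimal sketch of the underlying update rule (illustrative only, not the trainer's actual EMA implementation):

```python
import torch

@torch.no_grad()
def ema_update(shadow_params, live_params, decay=0.999):
    """One EMA step: blend each shadow weight toward its live counterpart."""
    for shadow, live in zip(shadow_params, live_params):
        # shadow = decay * shadow + (1 - decay) * live
        shadow.mul_(decay).add_(live, alpha=1 - decay)
```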

### `--ema_device`

- **Choices**: `cpu`, `accelerator`, default: `cpu`
- **What**: The device used to store the EMA weights; `accelerator` places them on the GPU instead of the CPU.
- **Why**: The default CPU placement may result in a substantial slowdown on some systems, in which case `accelerator` can be faster. Note that `--ema_cpu_only` will override this value if provided.

### `--ema_cpu_only`

- **What**: Keeps EMA weights on the CPU. The default behaviour is to move the EMA weights to the GPU before updating them.
- **Why**: Moving the EMA weights to the GPU for each update is often unnecessary, since the update can run nearly as quickly on the CPU. However, some systems experience a substantial slowdown when updating on the CPU, so by default the weights are moved to the GPU for the update; see the sketch below for the trade-off.
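
As a rough sketch of the trade-off (assumed helper and variable names, not the actual hook): by default the shadow weights live on the device chosen by `--ema_device` and hop to the accelerator for each update, while `--ema_cpu_only` keeps the whole update on the CPU.

```python
import torch

@torch.no_grad()
def update_ema_with_offload(shadow_params, live_params, decay=0.999,
                            device="cuda", cpu_only=False):
    """Illustrative EMA update with optional CPU offload of the shadow weights."""
    for shadow, live in zip(shadow_params, live_params):
        if cpu_only:
            live = live.detach().to("cpu")        # run the update entirely on the CPU
        else:
            shadow.data = shadow.data.to(device)  # temporarily move the shadow weight to the GPU
        shadow.mul_(decay).add_(live, alpha=1 - decay)
        if not cpu_only:
            shadow.data = shadow.data.to("cpu")   # offload back to keep VRAM free
```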

### `--ema_update_interval`

- **What**: The number of optimizer steps between EMA updates, reducing how frequently the EMA shadow parameters are refreshed.
- **Why**: Updating the EMA weights on every step can be an unnecessary use of resources. Providing `--ema_update_interval=100` will update the EMA weights only once every 100 optimizer steps; a sketch of the step-skipping check follows.
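
A minimal sketch of the step-skipping check (the `global_step` counter is an assumption for illustration):

```python
def should_update_ema(global_step, update_interval=None):
    """Return True on optimizer steps where the EMA weights should be refreshed."""
    return update_interval is None or global_step % update_interval == 0

# With --ema_update_interval=100, only steps 0, 100, 200, ... trigger an EMA update.
assert should_update_ema(200, update_interval=100)
assert not should_update_ema(250, update_interval=100)
```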

### `--gradient_accumulation_steps`

- **What**: Number of update steps to accumulate before performing a backward/update pass, essentially splitting the work over multiple batches to save memory at the cost of a higher training runtime.
@@ -179,7 +200,9 @@ A lot of settings are instead set through the [dataloader config](/documentation
### `--resume_from_checkpoint`

- **What**: Specifies if and from where to resume training.
- **Why**: Allows you to continue training from a saved state, either manually specified or the latest available. A checkpoint is composed of a `unet` and optionally, an `ema_unet`. The `unet` may be dropped into any Diffusers layout SDXL model, allowing it to be used as a normal model would.
- **Why**: Allows you to continue training from a saved state, either manually specified or the latest available. A checkpoint is composed of a `unet` and, optionally, a `unet_ema` subfolder. The `unet` may be dropped into any Diffusers-layout SDXL model, allowing it to be used as a normal model would; an example checkpoint layout is shown below.

> ℹ️ Transformer models such as PixArt, SD3, or Hunyuan use the `transformer` and `transformer_ema` subfolder names.
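
For orientation, a checkpoint directory saved with `--use_ema` might look roughly like this (a hypothetical example; exact contents vary with configuration and trainer version):

```
checkpoint-1000/
├── unet/                 # Diffusers-layout model weights (transformer/ for DiT models)
├── unet_ema/             # EMA shadow weights (transformer_ema/ for DiT models)
├── training_state.json   # resume bookkeeping written by the save hook
└── ...                   # optimizer, scheduler, and RNG state saved by Accelerate
```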

---

4 changes: 3 additions & 1 deletion README.md
@@ -82,7 +82,9 @@ Simply point your base model to a Stable Diffusion 3 checkpoint and set `STABLE_

## Hardware Requirements

EMA (exponential moving average) weights are a memory-heavy affair, but provide fantastic results at the end of training. Without it, training can still be done, but more care must be taken not to drastically change the model leading to "catastrophic forgetting".
EMA (exponential moving average) weights are a memory-heavy affair, but provide fantastic results at the end of training. Options like `--ema_cpu_only` can improve this situation by loading EMA weights onto the CPU and then keeping them there.

Without EMA, more care must be taken not to drastically change the model and cause "catastrophic forgetting", for example through the use of regularisation data.

### GPU vendors

47 changes: 47 additions & 0 deletions helpers/arguments.py
@@ -788,6 +788,43 @@ def parse_args(input_args=None):
action="store_true",
help="Whether to use EMA (exponential moving average) model.",
)
parser.add_argument(
"--ema_device",
choices=["cpu", "accelerator"],
default="cpu",
help=(
"The device to use for the EMA model. If set to 'accelerator', the EMA model will be placed on the accelerator."
" This provides the fastest EMA update times, but is not ultimately necessary for EMA to function."
),
)
parser.add_argument(
"--ema_cpu_only",
action="store_true",
default=False,
help=(
"When using EMA, the shadow model is moved to the accelerator before we update its parameters."
" When provided, this option will disable the moving of the EMA model to the accelerator."
" This will save a lot of VRAM at the cost of a lot of time for updates. It is recommended to also supply"
" --ema_update_interval to reduce the number of updates to eg. every 100 steps."
),
)
parser.add_argument(
"--ema_foreach_disable",
action="store_true",
default=False,
help=(
"By default, we use torch._foreach functions for updating the shadow parameters, which should be fast."
" When provided, this option will disable the foreach methods and use vanilla EMA updates."
),
)
parser.add_argument(
"--ema_update_interval",
type=int,
default=None,
help=(
"The number of optimization steps between EMA updates. If not provided, EMA network will update on every step."
),
)
parser.add_argument(
"--ema_decay",
type=float,
@@ -1646,4 +1683,14 @@ def parse_args(input_args=None):
"Disabling Compel long-prompt weighting for SD3 inference, as it does not support Stable Diffusion 3."
)
args.disable_compel = True

if args.use_ema and args.ema_cpu_only:
args.ema_device = "cpu"

if args.pixart_sigma and not args.i_know_what_i_am_doing:
if args.max_grad_norm != 0.01:
logger.warning(
f"PixArt Sigma requires --max_grad_norm=0.01 to prevent model collapse. Overriding value. Set this value manually to disable this warning."
)
args.max_grad_norm = 0.01
return args
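
The `--ema_foreach_disable` flag above refers to `torch._foreach_*` batched tensor operations. As a rough illustration of a foreach-style EMA update (a sketch that assumes all shadow and live tensors share a device and dtype; it is not SimpleTuner's actual implementation):

```python
import torch

@torch.no_grad()
def foreach_ema_update(shadow_params, live_params, decay=0.999):
    """Apply shadow = decay * shadow + (1 - decay) * live across all tensors at once."""
    # The _foreach_* ops replace a per-tensor Python loop with a few batched kernels.
    torch._foreach_mul_(shadow_params, decay)
    torch._foreach_add_(shadow_params, live_params, alpha=1 - decay)
```

Disabling foreach falls back to a plain per-parameter loop, which is slower but numerically equivalent.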
10 changes: 5 additions & 5 deletions helpers/legacy/sd_files.py
@@ -52,7 +52,7 @@ def register_file_hooks(
text_encoder,
text_encoder_cls,
use_deepspeed_optimizer,
ema_unet=None,
ema_model=None,
controlnet=None,
):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -127,7 +127,7 @@ def save_model_hook(models, weights, output_dir):
weights.pop()

if args.use_ema:
ema_unet.save_pretrained(os.path.join(output_dir, "ema_unet"))
ema_model.save_pretrained(os.path.join(output_dir, "ema_model"))

def load_model_hook(models, input_dir):
training_state_path = os.path.join(input_dir, "training_state.json")
@@ -184,10 +184,10 @@ def load_model_hook(models, input_dir):

if args.use_ema:
load_model = EMAModel.from_pretrained(
os.path.join(input_dir, "ema_unet"), UNet2DConditionModel
os.path.join(input_dir, "ema_model"), UNet2DConditionModel
)
ema_unet.load_state_dict(load_model.state_dict())
ema_unet.to(accelerator.device)
ema_model.load_state_dict(load_model.state_dict())
ema_model.to(accelerator.device)
del load_model
if args.model_type == "full":
return_exception = False
6 changes: 0 additions & 6 deletions helpers/sdxl/pipeline.py
@@ -37,8 +37,6 @@
from diffusers.models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
@@ -905,8 +903,6 @@ def upcast_vae(self):
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
@@ -2367,8 +2363,6 @@ def upcast_vae(self):
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
48 changes: 33 additions & 15 deletions helpers/sdxl/save_hooks.py
@@ -14,7 +14,7 @@
import os, logging, shutil, torch, json
from safetensors import safe_open
from safetensors.torch import save_file

from tqdm import tqdm

logger = logging.getLogger("SDXLSaveHook")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL") or "INFO")
@@ -24,6 +24,14 @@
except ImportError:
logger.error("This release requires the latest version of Diffusers.")

try:
from diffusers.models import PixArtTransformer2DModel
except Exception as e:
logger.error(
f"Can not load Pixart Sigma model class. This release requires the latest version of Diffusers: {e}"
)
raise e


def merge_safetensors_files(directory):
json_file_name = "diffusion_pytorch_model.safetensors.index.json"
@@ -71,7 +79,7 @@ def __init__(
args,
unet,
transformer,
ema_unet,
ema_model,
text_encoder_1,
text_encoder_2,
text_encoder_3,
@@ -84,9 +92,20 @@
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
self.text_encoder_3 = text_encoder_3
self.ema_unet = ema_unet
self.ema_model = ema_model
self.accelerator = accelerator
self.use_deepspeed_optimizer = use_deepspeed_optimizer
self.ema_model_cls = None
self.ema_model_subdir = None
if unet is not None:
self.ema_model_subdir = "unet_ema"
self.ema_model_cls = UNet2DConditionModel
if transformer is not None:
self.ema_model_subdir = "transformer_ema"
if self.args.sd3:
self.ema_model_cls = SD3Transformer2DModel
elif self.args.pixart_sigma:
self.ema_model_cls = PixArtTransformer2DModel

def save_model_hook(self, models, weights, output_dir):
# Write "training_state.json" to the output directory containing the training state
@@ -167,7 +186,11 @@ def save_model_hook(self, models, weights, output_dir):
os.makedirs(temporary_dir, exist_ok=True)

if self.args.use_ema:
self.ema_unet.save_pretrained(os.path.join(temporary_dir, "unet_ema"))
tqdm.write("Saving EMA model")
self.ema_model.save_pretrained(
os.path.join(temporary_dir, self.ema_model_subdir),
max_shard_size="10GB",
)

if self.unet is not None:
sub_dir = "unet"
@@ -176,7 +199,9 @@ def save_model_hook(self, models, weights, output_dir):
if self.args.controlnet:
sub_dir = "controlnet"
for model in models:
model.save_pretrained(os.path.join(temporary_dir, sub_dir))
model.save_pretrained(
os.path.join(temporary_dir, sub_dir), max_shard_size="10GB"
)
merge_safetensors_files(os.path.join(temporary_dir, sub_dir))
if weights:
weights.pop() # Pop the last weight
@@ -295,10 +320,10 @@

if self.args.use_ema:
load_model = EMAModel.from_pretrained(
os.path.join(input_dir, "unet_ema"), UNet2DConditionModel
os.path.join(input_dir, self.ema_model_subdir), self.ema_model_cls
)
self.ema_unet.load_state_dict(load_model.state_dict())
self.ema_unet.to(self.accelerator.device)
self.ema_model.load_state_dict(load_model.state_dict())
self.ema_model.to(self.accelerator.device)
del load_model
if self.args.model_type == "full":
return_exception = False
@@ -338,13 +363,6 @@ def load_model_hook(self, models, input_dir):
)
elif self.args.pixart_sigma:
# load pixart sigma checkpoint
try:
from diffusers.models import PixArtTransformer2DModel
except Exception as e:
logger.error(
f"Can not load Pixart Sigma model class. This release requires the latest version of Diffusers: {e}"
)
raise e
load_model = PixArtTransformer2DModel.from_pretrained(
input_dir, subfolder="transformer"
)