Skip to content

Commit

Permalink
Merge pull request #1056 from kohya-ss/dev
Browse files Browse the repository at this point in the history
fix vram usage in LoRA training
  • Loading branch information
kohya-ss committed Jan 17, 2024
2 parents e6b15c7 + 0395a35 commit d2a99a1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum

## Change History

### Jan 17, 2024 / 2024/1/17: v0.8.1

- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`).
- Text Encoders were not moved to CPU.
- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053)

- LoRA 等の学習スクリプト(`train_network.py``sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。
- Text Encoder が GPU に保持されたままになっていました。
- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。

### Jan 15, 2024 / 2024/1/15: v0.8.0

- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
Expand Down
4 changes: 2 additions & 2 deletions sdxl_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def cache_text_encoder_outputs_if_needed(
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device)
text_encoders[1].to(accelerator.device)
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)

def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
Expand Down
12 changes: 6 additions & 6 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
for t_enc in text_encoders:
t_enc.to(accelerator.device)
t_enc.to(accelerator.device, dtype=weight_dtype)

def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device)
Expand Down Expand Up @@ -278,6 +278,7 @@ def train(self, args):
accelerator.wait_for_everyone()

# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
Expand Down Expand Up @@ -394,8 +395,7 @@ def train(self, args):
for t_enc in text_encoders:
t_enc.requires_grad_(False)

# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
else:
Expand All @@ -407,8 +407,8 @@ def train(self, args):
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set

network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)

if args.gradient_checkpointing:
Expand Down Expand Up @@ -685,7 +685,7 @@ def train(self, args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
Expand Down

1 comment on commit d2a99a1

@FurkanGozukara
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this consider training only single Text Encoder?

On OneTrainer I used same settings. Training only Text Encoder 1 and not 2. It has substantial lower VRAM usage.

Please sign in to comment.