Device support improvements (MPS) #1054

Merged
merged 3 commits on Jan 31, 2024
6 changes: 2 additions & 4 deletions fine_tune.py
@@ -2,7 +2,6 @@
# XXX dropped option: hypernetwork training

import argparse
import gc
import math
import os
from multiprocessing import Value
@@ -11,6 +10,7 @@
from tqdm import tqdm
import torch

from library.device_utils import clean_memory
from library.ipex_interop import init_ipex

init_ipex()
@@ -158,9 +158,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory()

accelerator.wait_for_everyone()

3 changes: 2 additions & 1 deletion finetune/make_captions.py
@@ -15,8 +15,9 @@
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.device_utils import get_preferred_device

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()


IMAGE_SIZE = 384
4 changes: 2 additions & 2 deletions finetune/make_captions_by_git.py
@@ -10,9 +10,9 @@
from transformers.generation.utils import GenerationMixin

import library.train_util as train_util
from library.device_utils import get_preferred_device


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()

PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
4 changes: 3 additions & 1 deletion finetune/prepare_buckets_latents.py
@@ -14,7 +14,9 @@
import library.model_util as model_util
import library.train_util as train_util

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from library.device_utils import get_preferred_device

DEVICE = get_preferred_device()

IMAGE_TRANSFORMS = transforms.Compose(
[
9 changes: 4 additions & 5 deletions gen_img_diffusers.py
@@ -66,6 +66,7 @@
import numpy as np
import torch

from library.device_utils import clean_memory, get_preferred_device
from library.ipex_interop import init_ipex

init_ipex()
@@ -888,8 +889,7 @@ def __call__(
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
else:
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
init_latents = []
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode(
@@ -1047,8 +1047,7 @@ def __call__(
if vae_batch_size >= batch_size:
image = self.vae.decode(latents).sample
else:
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
images = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
images.append(
@@ -2325,7 +2324,7 @@ def __getattr__(self, item):
scheduler.config.clip_sample = True

# determine the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "mps" is not taken into account
device = get_preferred_device()

# create the pipeline (a copy of the custom pipeline)
if args.vae_slices:
34 changes: 34 additions & 0 deletions library/device_utils.py
@@ -0,0 +1,34 @@
import functools
import gc

import torch

try:
    HAS_CUDA = torch.cuda.is_available()
except Exception:
    HAS_CUDA = False

try:
    HAS_MPS = torch.backends.mps.is_available()
except Exception:
    HAS_MPS = False


def clean_memory():
    gc.collect()
    if HAS_CUDA:
        torch.cuda.empty_cache()
    if HAS_MPS:
        torch.mps.empty_cache()


@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    if HAS_CUDA:
        device = torch.device("cuda")
    elif HAS_MPS:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"get_preferred_device() -> {device}")
    return device
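
The remaining diffs swap the scripts' CUDA-only device selection and cache clearing for these two helpers. A minimal usage sketch of that pattern follows; the Linear module is only a placeholder standing in for the real VAE or U-Net:

import torch

from library.device_utils import clean_memory, get_preferred_device

device = get_preferred_device()            # cuda, mps, or cpu depending on availability
model = torch.nn.Linear(4, 4).to(device)   # placeholder module standing in for a VAE / U-Net

# ... memory-heavy work on the device, e.g. caching latents ...

model.to("cpu")
clean_memory()  # gc.collect() plus empty_cache() for whichever backend is present
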
5 changes: 2 additions & 3 deletions library/sdxl_train_util.py
@@ -1,5 +1,4 @@
import argparse
import gc
import math
import os
from typing import Optional
@@ -8,6 +7,7 @@
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.device_utils import clean_memory
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet.to(accelerator.device)
vae.to(accelerator.device)

gc.collect()
torch.cuda.empty_cache()
clean_memory()
accelerator.wait_for_everyone()

return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
11 changes: 5 additions & 6 deletions library/train_util.py
@@ -20,7 +20,6 @@
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
import gc
import glob
import math
import os
@@ -67,6 +66,7 @@

# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.device_utils import clean_memory
from library.original_unet import UNet2DConditionModel

# Tokenizer: use the pre-provided one instead of loading it from the checkpoint
@@ -2278,8 +2278,7 @@ def cache_batch_latents(
info.latents_flipped = flipped_latent

# FIXME this slows down caching a lot, specify this as an option
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()


def cache_batch_text_encoder_outputs(
@@ -3919,6 +3918,7 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
)
print("accelerator device:", accelerator.device)
return accelerator


@@ -4005,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
unet.to(accelerator.device)
vae.to(accelerator.device)

gc.collect()
torch.cuda.empty_cache()
clean_memory()
accelerator.wait_for_everyone()

return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4815,7 +4814,7 @@ def sample_images_common(

# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
clean_memory()

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
4 changes: 3 additions & 1 deletion networks/lora_diffusers.py
@@ -11,6 +11,8 @@
from transformers import CLIPTextModel
import torch

from library.device_utils import get_preferred_device


def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
@@ -476,7 +478,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_preferred_device()

parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
3 changes: 2 additions & 1 deletion networks/lora_interrogator.py
@@ -9,11 +9,12 @@

import library.model_util as model_util
import lora
from library.device_utils import get_preferred_device

TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from here

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = get_preferred_device()


def interrogate(args):
12 changes: 5 additions & 7 deletions sdxl_gen_img.py
@@ -18,6 +18,7 @@
import numpy as np
import torch

from library.device_utils import clean_memory, get_preferred_device
from library.ipex_interop import init_ipex

init_ipex()
@@ -640,8 +641,7 @@ def __call__(
init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
else:
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
init_latents = []
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode(
@@ -780,8 +780,7 @@ def __call__(
if vae_batch_size >= batch_size:
image = self.vae.decode(latents.to(self.vae.dtype)).sample
else:
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
images = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
images.append(
@@ -796,8 +795,7 @@
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()

if output_type == "pil":
# image = self.numpy_to_pil(image)
@@ -1497,7 +1495,7 @@ def __getattr__(self, item):
# scheduler.config.clip_sample = True

# determine the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "mps" is not taken into account
device = get_preferred_device()

# create the pipeline (a copy of the custom pipeline)
if args.vae_slices:
3 changes: 2 additions & 1 deletion sdxl_minimal_inference.py
@@ -10,6 +10,7 @@
import numpy as np
import torch

from library.device_utils import get_preferred_device
from library.ipex_interop import init_ipex

init_ipex()
@@ -85,7 +86,7 @@ def get_timestep_embedding(x, outdim):
guidance_scale = 7
seed = None # 1

DEVICE = "cuda"
DEVICE = get_preferred_device()
DTYPE = torch.float16 # bfloat16 may work

parser = argparse.ArgumentParser()
9 changes: 3 additions & 6 deletions sdxl_train.py
@@ -1,7 +1,6 @@
# training with captions

import argparse
import gc
import math
import os
from multiprocessing import Value
@@ -11,6 +10,7 @@
from tqdm import tqdm
import torch

from library.device_utils import clean_memory
from library.ipex_interop import init_ipex

init_ipex()
@@ -252,9 +252,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory()

accelerator.wait_for_everyone()

@@ -407,8 +405,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
9 changes: 3 additions & 6 deletions sdxl_train_control_net_lllite.py
@@ -2,7 +2,6 @@
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward

import argparse
import gc
import json
import math
import os
@@ -15,6 +14,7 @@
from tqdm import tqdm
import torch

from library.device_utils import clean_memory
from library.ipex_interop import init_ipex

init_ipex()
@@ -164,9 +164,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory()

accelerator.wait_for_everyone()

@@ -291,8 +289,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)