20 changes: 20 additions & 0 deletions examples/research_projects/intel_opts/README.md
@@ -11,6 +11,26 @@ We accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch
## Accelerating the inference for Stable Diffusion using Bfloat16

We start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is designed to support standard Stable Diffusion models running in Bfloat16.
```bash
pip install diffusers transformers accelerate scipy safetensors
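# Intel Extension for PyTorch is assumed to be installed already (see the fine-tuning section above); if not:
pip install intel-extension-for-pytorch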

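# Pin OpenMP threads to cores (KMP_AFFINITY) and keep the post-parallel spin-wait short (KMP_BLOCKTIME)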
export KMP_BLOCKTIME=1
export KMP_SETTINGS=1
export KMP_AFFINITY=granularity=fine,compact,1,0

# Intel OpenMP
export OMP_NUM_THREADS=< Cores to use >
export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so
# Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.
export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so
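# Keep freed pages cached for reuse: dirty_decay_ms:-1 disables dirty-page purging, and the long muzzy_decay_ms delays returning memory to the OS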
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000"

# Launch with the default scheduler
numactl --membind <node N> -C <cpu list> python inference_bf16.py
# Launch with DPMSolverMultistepScheduler
numactl --membind <node N> -C <cpu list> python inference_bf16.py --dpm

```
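One way to pick values for `<node N>` and `<cpu list>` is to inspect the NUMA topology of the machine first:

```bash
# List the NUMA nodes and the CPUs attached to each, then bind to a single node
numactl --hardware
lscpu | grep -i numa
```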

## Accelerating the inference for Stable Diffusion using INT8

67 changes: 37 additions & 30 deletions examples/research_projects/intel_opts/inference_bf16.py
```diff
@@ -1,49 +1,56 @@
+import argparse
+
 import intel_extension_for_pytorch as ipex
 import torch
-from PIL import Image
 
-from diffusers import StableDiffusionPipeline
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
 
 
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    grid_w, grid_h = grid.size
-
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-
-
-prompt = ["a lovely <dicoo> in red dress and hat, in the snowy and bright night, with many bright buildings"]
-batch_size = 8
-prompt = prompt * batch_size
+parser = argparse.ArgumentParser("Stable Diffusion script with intel optimization", add_help=False)
+parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not")
+parser.add_argument("--steps", default=None, type=int, help="Num inference steps")
+args = parser.parse_args()
 
 device = "cpu"
+prompt = "a lovely <dicoo> in red dress and hat, in the snowy and bright night, with many bright buildings"
 
 model_id = "path-to-your-trained-model"
-model = StableDiffusionPipeline.from_pretrained(model_id)
-model = model.to(device)
+pipe = StableDiffusionPipeline.from_pretrained(model_id)
+if args.dpm:
+    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to(device)
 
 # to channels last
-model.unet = model.unet.to(memory_format=torch.channels_last)
-model.vae = model.vae.to(memory_format=torch.channels_last)
-model.text_encoder = model.text_encoder.to(memory_format=torch.channels_last)
-model.safety_checker = model.safety_checker.to(memory_format=torch.channels_last)
+pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
+pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
+pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
+if pipe.requires_safety_checker:
+    pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)
 
 # optimize with ipex
-model.unet = ipex.optimize(model.unet.eval(), dtype=torch.bfloat16, inplace=True)
-model.vae = ipex.optimize(model.vae.eval(), dtype=torch.bfloat16, inplace=True)
-model.text_encoder = ipex.optimize(model.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
-model.safety_checker = ipex.optimize(model.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
+sample = torch.randn(2, 4, 64, 64)
+timestep = torch.rand(1) * 999
+encoder_hidden_status = torch.randn(2, 77, 768)
+input_example = (sample, timestep, encoder_hidden_status)
+try:
+    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
+except Exception:
+    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+if pipe.requires_safety_checker:
+    pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
 
 # compute
 seed = 666
 generator = torch.Generator(device).manual_seed(seed)
+generate_kwargs = {"generator": generator}
+if args.steps is not None:
+    generate_kwargs["num_inference_steps"] = args.steps
 
 with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
-    images = model(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator).images
+    image = pipe(prompt, **generate_kwargs).images[0]
 
 # save image
-grid = image_grid(images, rows=2, cols=4)
-grid.save(model_id + ".png")
+image.save("generated.png")
```
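The `sample_input` handed to `ipex.optimize` is a dummy UNet input whose hard-coded shapes match Stable Diffusion v1.x: a batch of 2 latents (conditional + unconditional for classifier-free guidance) with 4 channels at 64×64 (a 512×512 output image), a single timestep, and a 77×768 CLIP text embedding. As a minimal sketch, the shapes could instead be derived from the loaded pipeline; the helper `unet_example_inputs` below is illustrative, not part of the script:

```python
import torch


def unet_example_inputs(pipe, batch_size=1, dtype=torch.float32):
    """Build a dummy (sample, timestep, encoder_hidden_states) tuple shaped for pipe.unet."""
    cfg = pipe.unet.config
    # 2x the batch: classifier-free guidance runs the conditional and unconditional passes together
    sample = torch.randn(2 * batch_size, cfg.in_channels, cfg.sample_size, cfg.sample_size, dtype=dtype)
    timestep = torch.rand(1) * 999  # any timestep in [0, 1000) is fine for shape tracing
    # sequence length comes from the tokenizer, embedding width from the text encoder
    encoder_hidden_states = torch.randn(
        2 * batch_size, pipe.tokenizer.model_max_length, pipe.text_encoder.config.hidden_size, dtype=dtype
    )
    return sample, timestep, encoder_hidden_states
```

Calling `unet_example_inputs(pipe)` in place of the hard-coded tensors would keep the shapes correct for checkpoints whose latent resolution or text-embedding width differ, such as Stable Diffusion 2.x.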