autoquant experiments
Summary: SDXL latency (batch size 1, 30 steps) goes from 2.928 s with int8dynamic to 2.887 s with autoquant.

Test Plan: sh run_sd.sh

HDCharles committed Mar 5, 2024
1 parent 8a163ef commit 176e85f
Showing 4 changed files with 67 additions and 50 deletions.
6 changes: 6 additions & 0 deletions collated_results.csv
@@ -0,0 +1,6 @@
+pipeline_cls,ckpt_id,bf16,sdpa,fused_qkv_projections,upcast_vae,batch_size,num_inference_steps,compile_unet,compile_vae,compile_mode,change_comp_config,do_quant,time (secs),memory (gbs),actual_gpu_memory (gbs),tag
+StableDiffusionXLPipeline,stabilityai/stable-diffusion-xl-base-1.0,False,False,False,False,1,30,False,False,,False,,8.605,16.827,79.154,
+StableDiffusionXLPipeline,stabilityai/stable-diffusion-xl-base-1.0,True,True,False,False,1,30,True,True,max-autotune,True,,3.038,8.328,79.154,
+StableDiffusionXLPipeline,stabilityai/stable-diffusion-xl-base-1.0,True,True,True,False,1,30,True,True,max-autotune,True,,2.990,10.529,79.154,
+StableDiffusionXLPipeline,stabilityai/stable-diffusion-xl-base-1.0,True,True,True,False,1,30,True,True,max-autotune,True,autoquant,2.887,9.665,79.154,
+StableDiffusionXLPipeline,stabilityai/stable-diffusion-xl-base-1.0,True,True,True,False,1,30,True,True,max-autotune,True,int8dynamic,2.928,7.866,79.154,
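The last two rows are the comparison quoted in the commit summary. A minimal sketch of pulling that comparison out of the collated CSV (assumes pandas is available; column names as in the header above):

```python
import pandas as pd

# Load the collated benchmark results added above.
df = pd.read_csv("collated_results.csv")

# Empty do_quant fields parse as NaN, so this keeps only the quantized runs.
quant = df[df["do_quant"].notna()]
print(quant[["do_quant", "time (secs)", "memory (gbs)"]])
# From the rows above: autoquant runs at 2.887 s / 9.665 GB,
# int8dynamic at 2.928 s / 7.866 GB.
```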
9 changes: 5 additions & 4 deletions experiment-scripts/run_sd.sh
@@ -2,8 +2,9 @@
 
 # From diffusion-fast source directory.
 
-python run_benchmark.py --no_sdpa --no_bf16 && \
-python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --change_comp_config && \
-python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --change_comp_config --enable_fused_projections && \
-python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --enable_fused_projections --do_quant "int8dynamic" --change_comp_config && \
+# python run_benchmark.py --no_sdpa --no_bf16 && \
+# python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --change_comp_config && \
+# python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --change_comp_config --enable_fused_projections && \
+# python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --enable_fused_projections --do_quant "int8dynamic" --change_comp_config && \
+# python run_benchmark.py --compile_unet --compile_mode=max-autotune --compile_vae --do_quant "autoquant" --change_comp_config --enable_fused_projections && \
 python prepare_results.py --plot_title "SDXL, Batch Size: 1, Steps: 30" --final_csv_filename "collated_results.csv"
5 changes: 4 additions & 1 deletion run_benchmark.py
@@ -37,6 +37,9 @@ def main(args) -> dict:
         do_quant=args.do_quant,
         compile_mode=args.compile_mode,
         change_comp_config=args.change_comp_config,
+        prompt=args.prompt,
+        num_inference_steps=args.num_inference_steps,
+        num_images_per_prompt=args.batch_size
     )
 
     # Warmup.
@@ -46,7 +49,7 @@
 
     time = benchmark_fn(run_inference, pipeline, args) # in seconds.
     memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
-
+    print(f"time {time}, memory {memory}")
     data_dict = generate_csv_dict(
         pipeline_cls=str(pipeline.__class__.__name__),
         args=args,
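For context, `benchmark_fn` and `bytes_to_giga_bytes` are repo helpers that this diff does not touch; a plausible sketch of what they do, assuming the usual `torch.utils.benchmark` pattern (the Timer handles CUDA synchronization around each measured call):

```python
import torch
import torch.utils.benchmark as benchmark

def benchmark_fn(f, *args, **kwargs):
    # Time the callable; blocked_autorange picks a repetition count
    # automatically and returns aggregate statistics.
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    return timer.blocked_autorange().mean  # seconds per call

def bytes_to_giga_bytes(num_bytes):
    # torch.cuda.max_memory_allocated() reports bytes; convert to GB.
    return num_bytes / 1024 / 1024 / 1024
```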
97 changes: 52 additions & 45 deletions utils/pipeline_utils.py
@@ -4,13 +4,14 @@
     change_linear_weights_to_int4_woqtensors,
     change_linear_weights_to_int8_woqtensors,
     swap_conv2d_1x1_to_linear,
+    change_linears_to_autoquantizable,
+    change_autoquantizable_to_quantized,
 )
 
 from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
 
-
 PROMPT = "ghibli style, a fantasy landscape with castles"
-
+torch._dynamo.config.cache_size_limit = 100000
 
 def dynamic_quant_filter_fn(mod, *args):
     return (
@@ -58,6 +59,9 @@ def load_pipeline(
     do_quant: bool,
     compile_mode: str,
     change_comp_config: bool,
+    prompt: str="",
+    num_inference_steps: int=1,
+    num_images_per_prompt: int=1,
 ):
     """Loads the SDXL pipeline."""
Expand Down Expand Up @@ -95,57 +99,60 @@ def load_pipeline(
pipe.vae.set_default_attn_processor()

pipe = pipe.to("cuda")
pipe.set_progress_bar_config(disable=True)

if change_comp_config:
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

if compile_unet:
if do_quant:
pipe.unet.to(memory_format=torch.channels_last)
print("Compile UNet.")
pipe.vae.to(memory_format=torch.channels_last)
print("Applying Quantization")
swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)
if compile_mode == "max-autotune" and change_comp_config:
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

if do_quant:
print("Apply quantization to UNet.")
if do_quant == "int4weightonly":
change_linear_weights_to_int4_woqtensors(pipe.unet)
elif do_quant == "int8weightonly":
change_linear_weights_to_int8_woqtensors(pipe.unet)
elif do_quant == "int8dynamic":
apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
else:
raise ValueError(f"Unknown do_quant value: {do_quant}.")
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

if do_quant == "autoquant":
with torch.no_grad():
hold = torch._dynamo.config.automatic_dynamic_shapes
torch._dynamo.config.automatic_dynamic_shapes = False
change_linears_to_autoquantizable(pipe.unet)
change_linears_to_autoquantizable(pipe.vae)
# run model to record shapes
pipe(
prompt=prompt,
num_inference_steps=num_inference_steps,
num_images_per_prompt=num_images_per_prompt,
)
change_autoquantizable_to_quantized(pipe.unet, error_on_unseen=False)
change_autoquantizable_to_quantized(pipe.vae, error_on_unseen=False)
torch._dynamo.config.automatic_dynamic_shapes = hold
torch._dynamo.reset()
elif do_quant == "int4weightonly":
change_linear_weights_to_int4_woqtensors(pipe.unet)
change_linear_weights_to_int4_woqtensors(pipe.vae)
elif do_quant == "int8weightonly":
change_linear_weights_to_int8_woqtensors(pipe.unet)
change_linear_weights_to_int8_woqtensors(pipe.vae)
elif do_quant == "int8dynamic":
apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
else:
raise ValueError(f"Unknown do_quant value: {do_quant}.")

if compile_unet or do_quant:
pipe.unet.to(memory_format=torch.channels_last)
print("Compile UNet.")
pipe.unet = torch.compile(pipe.unet, mode=compile_mode, fullgraph=True)

if compile_vae:
if compile_vae or do_quant:
pipe.vae.to(memory_format=torch.channels_last)
print("Compile VAE.")
swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)

if compile_mode == "max-autotune" and change_comp_config:
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

if do_quant:
print("Apply quantization to VAE.")
if do_quant == "int4weightonly":
change_linear_weights_to_int4_woqtensors(pipe.vae)
elif do_quant == "int8weightonly":
change_linear_weights_to_int8_woqtensors(pipe.vae)
elif do_quant == "int8dynamic":
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
else:
raise ValueError(f"Unknown do_quant value: {do_quant}.")
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

pipe.vae.decode = torch.compile(pipe.vae.decode, mode=compile_mode, fullgraph=True)

pipe.set_progress_bar_config(disable=True)
return pipe
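The autoquant path added above is the interesting part: linears are first swapped to shape-recording autoquantizable tensors, the pipeline is run once so every linear sees its real input shapes, and only then are the recorded shapes used to benchmark and commit a quantized kernel per layer. A standalone sketch of the same two-step flow on a toy module (the model, input, and `torchao.quantization` import path are illustrative assumptions; the calls and the dynamic-shape guard mirror load_pipeline above):

```python
import torch
import torch.nn as nn
# Assumed import path; load_pipeline imports these helpers from torchao.
from torchao.quantization import (
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
)

model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512)).cuda()
example = torch.randn(8, 512, device="cuda")

with torch.no_grad():
    # Pin shapes while recording, exactly as load_pipeline does.
    hold = torch._dynamo.config.automatic_dynamic_shapes
    torch._dynamo.config.automatic_dynamic_shapes = False
    # Step 1: swap every nn.Linear weight for an autoquantizable tensor.
    change_linears_to_autoquantizable(model)
    # Step 2: one forward pass records the shapes each linear sees.
    model(example)
    # Step 3: pick and commit the fastest quantized kernel per layer based
    # on the recorded shapes; tolerate layers that saw no input.
    change_autoquantizable_to_quantized(model, error_on_unseen=False)
    torch._dynamo.config.automatic_dynamic_shapes = hold
    torch._dynamo.reset()

# The quantized model is then compiled as usual.
model = torch.compile(model, mode="max-autotune")
```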
