diff --git a/examples/pytorch/diffusion_model/diffusers/flux/run_quant.sh b/examples/pytorch/diffusion_model/diffusers/flux/run_quant.sh index d13c3bcf470..8b72db4d77c 100644 --- a/examples/pytorch/diffusion_model/diffusers/flux/run_quant.sh +++ b/examples/pytorch/diffusion_model/diffusers/flux/run_quant.sh @@ -36,12 +36,13 @@ function init_params { # run_tuning function run_tuning { + dataset_location=${dataset_location:="captions_source.tsv"} tuned_checkpoint=${tuned_checkpoint:="saved_results"} if [ "${topology}" = "flux_fp8" ]; then - extra_cmd="--scheme FP8 --iters 0 --dataset captions_source.tsv --quantize" + extra_cmd="--scheme FP8 --iters 0 --dataset ${dataset_location} --quantize" elif [ "${topology}" = "flux_mxfp8" ]; then - extra_cmd="--scheme MXFP8 --iters 1000 --dataset captions_source.tsv --quantize" + extra_cmd="--scheme MXFP8 --iters 1000 --dataset ${dataset_location} --quantize" fi python3 main.py \