[SDXL] [JAX] Txt2img Inference slower than Blog Article #2058

hsebik opened this issue May 13, 2024 · 0 comments
hsebik commented May 13, 2024

The SDXL JAX blog post provides example inference code and states that it generates 4 images in about 2 s on a Cloud TPU v5e-4. However, even though I am running the same code on the same Google Cloud TPU (Ubuntu 22.04 base software version, Python 3.10), generating the images takes 3.23 seconds. Could you help me figure out what I am missing to achieve the faster inference time?
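
A minimal check, assuming a stock jax[tpu] install on the v5e-4 VM, to confirm that all four chips are visible to JAX (the code below replicates parameters across jax.device_count() devices, so fewer visible devices directly increases per-batch latency):

import jax

# Confirm the TPU topology JAX actually sees; on a v5e-4 this should report
# 4 devices. Fewer devices means the replicated pipeline runs on fewer chips.
print("jax version:", jax.__version__)
print("device_count:", jax.device_count())            # expected: 4
print("local_device_count:", jax.local_device_count())
print("devices:", jax.devices())

The package versions installed on the VM are listed below.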

absl-py==2.1.0
attrs==21.2.0
Automat==20.2.0
Babel==2.8.0
bcrypt==3.2.0
blinker==1.4
certifi==2020.6.20
chardet==4.0.0
charset-normalizer==3.1.0
chex==0.1.86
click==8.0.3
cloud-init==23.1.2
colorama==0.4.4
command-not-found==0.3
configobj==5.0.6
constantly==15.1.0
cryptography==3.4.8
Cython==0.29.28
dbus-python==1.2.18
diffusers==0.27.2
distlib==0.3.6
distro==1.7.0
distro-info===1.1build1
etils==1.7.0
filelock==3.12.0
flax==0.8.3
fsspec==2024.3.1
httplib2==0.20.2
huggingface-hub==0.22.2
hyperlink==21.0.0
idna==3.3
importlib-metadata==4.6.4
importlib_resources==6.4.0
incremental==21.3.0
jax==0.4.26
jaxlib==0.4.26
jeepney==0.7.1
Jinja2==3.0.3
jsonpatch==1.32
jsonpointer==2.0
jsonschema==3.2.0
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
libtpu-nightly==0.1.dev20240403
markdown-it-py==3.0.0
MarkupSafe==2.0.1
mdurl==0.1.2
ml-dtypes==0.4.0
more-itertools==8.10.0
msgpack==1.0.8
nest-asyncio==1.6.0
netifaces==0.11.0
numpy==1.26.4
oauthlib==3.2.0
opt-einsum==3.3.0
optax==0.2.2
orbax-checkpoint==0.5.10
packaging==21.3
pexpect==4.8.0
pillow==10.3.0
platformdirs==3.5.0
protobuf==5.26.1
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.1
Pygments==2.17.2
PyGObject==3.42.1
PyHamcrest==2.0.2
PyJWT==2.3.0
pyOpenSSL==21.0.0
pyparsing==2.4.7
pyrsistent==0.18.1
pyserial==3.5
python-apt==2.4.0+ubuntu1
python-debian===0.1.43ubuntu1
python-magic==0.4.24
pytz==2022.1
PyYAML==5.4.1
regex==2024.4.28
requests==2.29.0
rich==13.7.1
safetensors==0.4.3
scipy==1.13.0
SecretStorage==3.3.1
service-identity==18.1.0
six==1.16.0
sos==4.4
ssh-import-id==5.11
systemd-python==234
tensorstore==0.1.58
tokenizers==0.19.1
toolz==0.12.1
tqdm==4.66.2
transformers==4.40.1
Twisted==22.1.0
typing_extensions==4.11.0
ubuntu-advantage-tools==8001
ufw==0.36.1
unattended-upgrades==0.1
urllib3==1.26.5
virtualenv==20.23.0
wadllib==1.3.6
zipp==1.0.0
zope.interface==5.4.0

(Screenshot attached: inference_time)

Code used (the SDXL JAX best-practices example from the blog post):

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
import time


pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True
)


scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state


default_prompt = "high-quality photo of a baby dolphin playing in a pool and wearing a party hat"
default_neg_prompt = "illustration, low-quality"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25


def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids


NUM_DEVICES = jax.device_count()

# Model parameters don't change during inference,
# so we only need to replicate them once.
p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng


def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))


start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")


for i in range(100):
    start = time.time()
    prompt = "llama in ancient Greece, oil on canvas"
    neg_prompt = "cartoon, illustration, animation"
    images = generate(prompt, neg_prompt)
    print(f"Inference in {time.time() - start}")