diff --git a/docker/diffusers-flax-tpu/Dockerfile b/docker/diffusers-flax-tpu/Dockerfile index 4b38dd13e461..c0722139ca55 100644 --- a/docker/diffusers-flax-tpu/Dockerfile +++ b/docker/diffusers-flax-tpu/Dockerfile @@ -24,7 +24,7 @@ ENV PATH="/opt/venv/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) # follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container RUN python3 -m pip install --no-cache-dir --upgrade pip uv && \ - python3 -m uv pip install --no-cache-dir \ + python3 -m pip install --no-cache-dir \ "jax[tpu]>=0.2.16,!=0.3.2" \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \ python3 -m uv pip install --upgrade --no-cache-dir \