Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import numpy as np
import PIL.Image
import pytest
import requests_mock
import safetensors.torch
import torch
Expand Down Expand Up @@ -62,10 +63,7 @@
)
from diffusers.pipelines.pipeline_utils import _get_pipeline_class
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
WEIGHTS_NAME,
)
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, is_transformers_version
from diffusers.utils.torch_utils import is_compiled_module

from ..testing_utils import (
Expand Down Expand Up @@ -584,6 +582,7 @@ def test_download_variants_with_sharded_checkpoints(self):
assert not any(f.endswith(unexpected_ext) for f in files)
assert all(variant in f for f in model_files if f.endswith(model_ext) and variant is not None)

@pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
Expand Down Expand Up @@ -690,6 +689,7 @@ def test_download_bin_variant_does_not_exist_for_model(self):
)
assert "Error no file name" in str(error_context.exception)

@pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_local_save_load_index(self):
prompt = "hello"
for variant in [None, "fp16"]:
Expand Down Expand Up @@ -1584,6 +1584,7 @@ def test_save_safe_serialization(self):
assert pipeline.scheduler is not None
assert pipeline.feature_extractor is not None

@pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_no_pytorch_download_when_doing_safetensors(self):
# by default we don't download
with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -1603,6 +1604,7 @@ def test_no_pytorch_download_when_doing_safetensors(self):
# pytorch does not
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))

@pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_no_safetensors_download_when_doing_pytorch(self):
use_safetensors = False

Expand Down
Loading