Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PyCharm/VSCode static type checking for dummy objects #1596

Merged
merged 11 commits into from
Dec 8, 2022
74 changes: 52 additions & 22 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
from .utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_inflect_available,
is_k_diffusion_available,
is_librosa_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
Expand All @@ -15,7 +17,12 @@
)


if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
Expand All @@ -29,14 +36,12 @@
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
AudioDiffusionPipeline,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah it's further down

DanceDiffusionPipeline,
DDIMPipeline,
DDPMPipeline,
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
Mel,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
Expand All @@ -60,15 +65,22 @@
VQDiffusionScheduler,
)
from .training_utils import EMAModel
else:
from .utils.dummy_pt_objects import * # noqa F403

if is_torch_available() and is_scipy_available():
from .schedulers import LMSDiscreteScheduler
else:
try:
if not (is_torch_available() and is_scipy_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
else:
from .schedulers import LMSDiscreteScheduler

if is_torch_available() and is_transformers_available():

try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
Expand All @@ -88,26 +100,43 @@
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .pipelines import StableDiffusionKDiffusionPipeline
else:
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .pipelines import StableDiffusionKDiffusionPipeline

if is_torch_available() and is_transformers_available() and is_onnx_available():
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
else:
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)

try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
from .pipelines import AudioDiffusionPipeline, Mel

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
Expand All @@ -122,10 +151,11 @@
FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
else:
from .utils.dummy_flax_objects import * # noqa F403

if is_flax_available() and is_transformers_available():
from .pipelines import FlaxStableDiffusionPipeline
else:
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipelines import FlaxStableDiffusionPipeline
48 changes: 38 additions & 10 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ..utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
Expand All @@ -8,7 +9,12 @@
)


if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
Expand All @@ -18,15 +24,21 @@
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
else:
from ..utils.dummy_pt_objects import * # noqa F403

if is_torch_available() and is_librosa_available():
from .audio_diffusion import AudioDiffusionPipeline, Mel
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
else:
from ..utils.dummy_torch_and_librosa_objects import AudioDiffusionPipeline, Mel # noqa F403
from .audio_diffusion import AudioDiffusionPipeline, Mel

if is_torch_available() and is_transformers_available():
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
Expand All @@ -48,7 +60,12 @@
)
from .vq_diffusion import VQDiffusionPipeline

if is_transformers_available() and is_onnx_available():
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
else:
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
Expand All @@ -57,8 +74,19 @@
StableDiffusionOnnxPipeline,
)

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .stable_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_flax_available():

try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .stable_diffusion import FlaxStableDiffusionPipeline
17 changes: 13 additions & 4 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ...utils import (
BaseOutput,
OptionalDependencyNotAvailable,
is_flax_available,
is_k_diffusion_available,
is_onnx_available,
Expand Down Expand Up @@ -44,12 +45,20 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .safety_checker import StableDiffusionSafetyChecker

if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
else:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
else:
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline

if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_onnx_available():
Expand Down
24 changes: 16 additions & 8 deletions src/diffusers/pipelines/versatile_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)


if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
else:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
else:
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
29 changes: 19 additions & 10 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
# limitations under the License.


from ..utils import is_flax_available, is_scipy_available, is_torch_available
from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available


if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
Expand All @@ -34,10 +39,13 @@
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
else:
from ..utils.dummy_pt_objects import * # noqa F403

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_objects import * # noqa F403
else:
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
Expand All @@ -46,11 +54,12 @@
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
else:
from ..utils.dummy_flax_objects import * # noqa F403


if is_scipy_available() and is_torch_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
try:
if not (is_torch_available() and is_scipy_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
else:
from .scheduling_lms_discrete import LMSDiscreteScheduler
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
USE_TF,
USE_TORCH,
DummyObject,
OptionalDependencyNotAvailable,
is_accelerate_available,
is_flax_available,
is_inflect_available,
Expand Down
30 changes: 0 additions & 30 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,21 +152,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class AudioDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DanceDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]

Expand Down Expand Up @@ -257,21 +242,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class Mel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]

Expand Down