Skip to content

Commit

Permalink
[Utils] Add deprecate function and move testing_utils under utils (#659)
Browse files Browse the repository at this point in the history
* [Utils] Add deprecate function

* up

* up

* uP

* up

* up

* up

* up

* uP

* up

* fix

* up

* move to deprecation utils file

* fix

* fix

* fix more
  • Loading branch information
patrickvonplaten committed Oct 3, 2022
1 parent 1070e1a commit f1484b8
Show file tree
Hide file tree
Showing 26 changed files with 305 additions and 121 deletions.
4 changes: 2 additions & 2 deletions examples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@


def pytest_addoption(parser):
from diffusers.testing_utils import pytest_addoption_shared
from diffusers.utils.testing_utils import pytest_addoption_shared

pytest_addoption_shared(parser)


def pytest_terminal_summary(terminalreporter):
from diffusers.testing_utils import pytest_terminal_summary_main
from diffusers.utils.testing_utils import pytest_terminal_summary_main

make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
Expand Down
2 changes: 1 addition & 1 deletion examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import List

from accelerate.utils import write_basic_config
from diffusers.testing_utils import slow
from diffusers.utils import slow


logging.basicConfig(level=logging.DEBUG)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import warnings
from typing import Callable, List, Optional, Union

import torch
Expand All @@ -10,7 +9,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

Expand Down Expand Up @@ -59,15 +58,15 @@ def __init__(
super().__init__()

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
Expand All @@ -12,7 +11,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

Expand Down Expand Up @@ -71,15 +70,15 @@ def __init__(
super().__init__()

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
Expand All @@ -13,7 +12,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

Expand Down Expand Up @@ -86,15 +85,15 @@ def __init__(
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
Expand Down
30 changes: 11 additions & 19 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
# and https://github.com/hojonathanho/diffusion

import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin


Expand Down Expand Up @@ -122,12 +121,12 @@ def __init__(
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
Expand Down Expand Up @@ -175,17 +174,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""

offset = self.config.steps_offset

if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead.",
DeprecationWarning,
)

offset = kwargs["offset"]
deprecated_offset = deprecate(
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
)
offset = deprecated_offset or self.config.steps_offset

self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin


Expand Down Expand Up @@ -115,12 +114,12 @@ def __init__(
clip_sample: bool = True,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/schedulers/scheduling_karras_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# limitations under the License.


import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin


Expand Down Expand Up @@ -89,12 +88,12 @@ def __init__(
s_max: float = 50,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

# setable values
self.num_inference_steps: int = None
Expand Down
15 changes: 7 additions & 8 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Expand All @@ -22,7 +21,7 @@
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin


Expand Down Expand Up @@ -77,12 +76,12 @@ def __init__(
trained_betas: Optional[np.ndarray] = None,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
Expand Down
28 changes: 11 additions & 17 deletions src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput


Expand Down Expand Up @@ -102,12 +102,12 @@ def __init__(
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
Expand Down Expand Up @@ -155,16 +155,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""

offset = self.config.steps_offset

if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead."
)

offset = kwargs["offset"]
deprecated_offset = deprecate(
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
)
offset = deprecated_offset or self.config.steps_offset

self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
Expand Down
21 changes: 8 additions & 13 deletions src/diffusers/schedulers/scheduling_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput


Expand Down Expand Up @@ -78,12 +77,12 @@ def __init__(
correct_steps: int = 1,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
deprecate(
"tensor_format",
"0.5.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

# setable values
self.timesteps = None
Expand Down Expand Up @@ -139,11 +138,7 @@ def get_adjacent_sigma(self, timesteps, t):
)

def set_seed(self, seed):
warnings.warn(
"The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
" generator instead.",
DeprecationWarning,
)
deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.")
torch.manual_seed(seed)

def step_pred(
Expand Down
Loading

0 comments on commit f1484b8

Please sign in to comment.