Merge pull request #121 from nateraw/diffusers-0.9.0
Diffusers 0.9.0
nateraw committed Dec 5, 2022
2 parents ffbea6d + c11d154 commit c1871c1
Showing 6 changed files with 90 additions and 24 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -132,4 +132,5 @@ dmypy.json
dreams
images
run.py
test_outputs
test_outputs
examples/music
8 changes: 2 additions & 6 deletions README.md
@@ -137,13 +137,9 @@ Enjoy 🤗

You can also 4x upsample your images with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)!

First, you'll need to install it...
It's included when you pip install the latest version of `stable-diffusion-videos`!

```bash
pip install realesrgan
```

Then, you'll be able to use `upsample=True` in the `walk` function, like this:
You'll be able to use `upsample=True` in the `walk` function, like this:

```python
pipeline.walk(['a cat', 'a dog'], [234, 345], upsample=True)
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,7 +1,8 @@
transformers
diffusers==0.6.0
transformers>=4.21.0
diffusers==0.9.0
scipy
fire
gradio
librosa
av<10.0.0
realesrgan==0.2.5.0
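Not part of the commit: a minimal sketch of a runtime guard matching the new pins above, assuming `diffusers` and `packaging` are importable in the target environment (the check is illustrative only; this repo does not ship it).

```python
# Minimal environment check. Assumes stable-diffusion-videos 0.7.0 pins
# diffusers==0.9.0, as shown in the requirements.txt diff above.
from packaging import version

import diffusers

if version.parse(diffusers.__version__) < version.parse("0.9.0"):
    raise RuntimeError(
        f"Found diffusers {diffusers.__version__}; this release expects 0.9.0. "
        "Upgrade with `pip install -U stable-diffusion-videos`."
    )
```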
4 changes: 0 additions & 4 deletions setup.py
@@ -14,9 +14,6 @@ def get_version() -> str:
with open("requirements.txt", "r") as f:
requirements = f.read().splitlines()

extras = {}
extras['realesrgan'] = ['realesrgan==0.2.5.0']

setup(
name="stable_diffusion_videos",
version=get_version(),
@@ -29,6 +26,5 @@ def get_version() -> str:
long_description_content_type="text/markdown",
license="Apache",
install_requires=requirements,
extras_require=extras,
packages=find_packages(),
)
2 changes: 1 addition & 1 deletion stable_diffusion_videos/__init__.py
@@ -114,4 +114,4 @@ def __dir__():
},
)

__version__ = "0.6.2"
__version__ = "0.7.0"
92 changes: 82 additions & 10 deletions stable_diffusion_videos/stable_diffusion_pipeline.py
@@ -10,12 +10,20 @@
import json

import torch
from packaging import version
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -166,9 +174,17 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()

@@ -186,8 +202,21 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)

if safety_checker is None:
logger.warn(
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)

if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -196,6 +225,33 @@ def __init__(
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)

is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -205,6 +261,9 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)


def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
@@ -218,9 +277,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
@@ -361,7 +425,7 @@ def __call__(
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
elif text_embeddings is None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
@@ -524,6 +588,7 @@ def make_clip_frames(
image_file_ext: str = ".png",
T: np.ndarray = None,
skip: int = 0,
negative_prompt: str = None,
):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
@@ -559,6 +624,7 @@ def make_clip_frames(
eta=eta,
num_inference_steps=num_inference_steps,
output_type="pil" if not upsample else "numpy",
negative_prompt=negative_prompt,
)["images"]

for image in outputs:
@@ -588,6 +654,7 @@ def walk(
audio_start_sec: Optional[Union[int, float]] = None,
margin: Optional[float] = 1.0,
smooth: Optional[float] = 0.0,
negative_prompt: Optional[str] = None,
):
"""Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
video to interpolate to the intensity of the audio.
@@ -638,6 +705,8 @@ def walk(
Margin from librosa hpss to use for audio interpolation.
smooth (Optional[float], *optional*, defaults to 0.0):
Smoothness of the audio interpolation. 1.0 means linear interpolation.
negative_prompt (Optional[str], *optional*, defaults to None):
Optional negative prompt to use. Same across all prompts.
This function will create sub directories for each prompt and seed pair.
@@ -710,6 +779,7 @@ def walk(
width=width,
audio_filepath=audio_filepath,
audio_start_sec=audio_start_sec,
negative_prompt=negative_prompt,
),
indent=2,
sort_keys=False,
Expand All @@ -729,6 +799,7 @@ def walk(
width = data["width"]
audio_filepath = data["audio_filepath"]
audio_start_sec = data["audio_start_sec"]
negative_prompt = data.get("negative_prompt", None)

for i, (prompt_a, prompt_b, seed_a, seed_b, num_step) in enumerate(
zip(prompts, prompts[1:], seeds, seeds[1:], num_interpolation_steps)
@@ -771,7 +842,6 @@ def walk(
width=width,
upsample=upsample,
batch_size=batch_size,
skip=skip,
T=get_timesteps_arr(
audio_filepath,
offset=audio_offset,
Expand All @@ -782,6 +852,8 @@ def walk(
)
if audio_filepath
else None,
skip=skip,
negative_prompt=negative_prompt,
)
make_video_pyav(
save_path,
@@ -805,7 +877,7 @@ def walk(
sr=44100,
)

def embed_text(self, text):
def embed_text(self, text, negative_prompt=None):
"""Helper to embed some text"""
with torch.autocast("cuda"):
text_input = self.tokenizer(
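Taken together, the pipeline changes track the diffusers 0.9.0 API: more scheduler types are accepted by `__init__`, and `negative_prompt` is threaded through `walk`, `make_clip_frames`, and `embed_text`. Below is a rough usage sketch, assuming the package still exposes `StableDiffusionWalkPipeline` and that `walk` keeps the argument names shown in the docstring above; none of this code is taken verbatim from the commit.

```python
# Hedged sketch: class and argument names are assumptions based on the
# README example and the walk() docstring in this diff.
import torch
from diffusers import DPMSolverMultistepScheduler
from stable_diffusion_videos import StableDiffusionWalkPipeline

pipeline = StableDiffusionWalkPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# Swap in one of the scheduler classes newly accepted by the pipeline.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

pipeline.walk(
    prompts=["a cat", "a dog"],
    seeds=[234, 345],
    negative_prompt="blurry, low quality",  # new: one negative prompt shared by all frames
    upsample=True,
)
```

Note that `negative_prompt` is also written into the run's JSON alongside the other `walk` settings (see the hunk around `@@ -710,6 +779,7 @@`), so resumed runs recover it via `data.get("negative_prompt", None)`.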
