adodge committed Mar 7, 2023
1 parent 85597f8 commit 88b8744
Showing 8 changed files with 61 additions and 45 deletions.
backend/src/nodes/impl/stable_diffusion/types.py (22 additions, 4 deletions)
@@ -1,9 +1,27 @@
-from comfy import CLIPModel, Conditioning, LatentImage, StableDiffusionModel, VAEModel
+from comfy.clip import CLIPModel
+from comfy.conditioning import Conditioning
+from comfy.latent_image import CropMethod, LatentImage, UpscaleMethod
+from comfy.stable_diffusion import (
+    BuiltInCheckpointConfigName,
+    CheckpointConfig,
+    Sampler,
+    Scheduler,
+    StableDiffusionModel,
+    load_checkpoint,
+)
+from comfy.vae import VAEModel
 
 __all__ = [
-    "StableDiffusionModel",
-    "VAEModel",
     "CLIPModel",
-    "LatentImage",
     "Conditioning",
+    "CropMethod",
+    "LatentImage",
+    "UpscaleMethod",
+    "BuiltInCheckpointConfigName",
+    "CheckpointConfig",
+    "Sampler",
+    "Scheduler",
+    "StableDiffusionModel",
+    "load_checkpoint",
+    "VAEModel",
 ]
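With types.py now re-exporting the comfy wrapper's public surface, node modules can drop their direct `import comfy` and reach everything through a single internal path, which is the pattern every file below switches to. A minimal sketch of the new import style, assuming a hypothetical module sitting next to the node files under backend/src/nodes/nodes/stable_diffusion/ (the helper is illustrative only, not part of this commit):

from __future__ import annotations

# Enums are reached through the internal types module instead of `import comfy`.
from ...impl.stable_diffusion.types import Sampler, Scheduler


def default_sampling() -> tuple[Sampler, Scheduler]:
    # Same defaults the k_sampler nodes below pass to EnumInput.
    return Sampler.SAMPLE_EULER, Scheduler.NORMAL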
backend/src/nodes/nodes/stable_diffusion/k_sampler.py (8 additions, 8 deletions)
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
-import comfy
-
 from ...impl.stable_diffusion.types import (
     Conditioning,
     LatentImage,
+    Sampler,
+    Scheduler,
     StableDiffusionModel,
 )
 from ...node_base import NodeBase, group
@@ -43,12 +43,12 @@ def __init__(self):
             ),
             SliderInput("Steps", minimum=1, default=20, maximum=150),
             EnumInput(
-                comfy.Sampler,
-                default_value=comfy.Sampler.SAMPLE_EULER,
+                Sampler,
+                default_value=Sampler.SAMPLE_EULER,
             ),
             EnumInput(
-                comfy.Scheduler,
-                default_value=comfy.Scheduler.NORMAL,
+                Scheduler,
+                default_value=Scheduler.NORMAL,
             ),
             SliderInput(
                 "CFG Scale",
@@ -77,8 +77,8 @@ def run(
         denoising_strength: float,
         seed: int,
         steps: int,
-        sampler: comfy.Sampler,
-        scheduler: comfy.Scheduler,
+        sampler: Sampler,
+        scheduler: Scheduler,
         cfg_scale: float,
     ) -> LatentImage:
         img = model.sample(
backend/src/nodes/nodes/stable_diffusion/k_sampler_advanced.py (8 additions, 8 deletions)
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
-import comfy
-
 from ...impl.stable_diffusion.types import (
     Conditioning,
     LatentImage,
+    Sampler,
+    Scheduler,
     StableDiffusionModel,
 )
 from ...node_base import NodeBase, group
@@ -43,12 +43,12 @@ def __init__(self):
             ),
             SliderInput("Steps", minimum=1, default=20, maximum=150),
             EnumInput(
-                comfy.Sampler,
-                default_value=comfy.Sampler.SAMPLE_EULER,
+                Sampler,
+                default_value=Sampler.SAMPLE_EULER,
             ),
             EnumInput(
-                comfy.Scheduler,
-                default_value=comfy.Scheduler.NORMAL,
+                Scheduler,
+                default_value=Scheduler.NORMAL,
             ),
             SliderInput(
                 "CFG Scale",
@@ -91,8 +91,8 @@ def run(
         denoising_strength: float,
         seed: int,
         steps: int,
-        sampler: comfy.Sampler,
-        scheduler: comfy.Scheduler,
+        sampler: Sampler,
+        scheduler: Scheduler,
         cfg_scale: float,
         start_at: int,
         end_at: int,
backend/src/nodes/nodes/stable_diffusion/latent_upscale.py (7 additions, 9 deletions)
@@ -1,8 +1,6 @@
 from __future__ import annotations
 
-import comfy.latent_image
-
-from ...impl.stable_diffusion.types import LatentImage
+from ...impl.stable_diffusion.types import CropMethod, LatentImage, UpscaleMethod
 from ...node_base import NodeBase
 from ...node_factory import NodeFactory
 from ...properties.inputs import EnumInput, SliderInput
@@ -19,12 +17,12 @@ def __init__(self):
         self.inputs = [
             LatentImageInput(),
             EnumInput(
-                comfy.UpscaleMethod,
-                default_value=comfy.UpscaleMethod.BILINEAR,
+                UpscaleMethod,
+                default_value=UpscaleMethod.BILINEAR,
             ),
             EnumInput(
-                comfy.CropMethod,
-                default_value=comfy.CropMethod.DISABLED,
+                CropMethod,
+                default_value=CropMethod.DISABLED,
             ),
             SliderInput(
                 "width",
@@ -57,8 +55,8 @@ def __init__(self):
     def run(
         self,
         latent_image: LatentImage,
-        upscale_method: comfy.latent_image.UpscaleMethod,
-        crop_method: comfy.latent_image.CropMethod,
+        upscale_method: UpscaleMethod,
+        crop_method: CropMethod,
         width: int,
         height: int,
     ) -> LatentImage:
backend/src/nodes/nodes/stable_diffusion/load_model.py (10 additions, 7 deletions)
@@ -3,9 +3,14 @@
 import os
 from typing import Tuple
 
-import comfy
-
-from ...impl.stable_diffusion.types import CLIPModel, StableDiffusionModel, VAEModel
+from ...impl.stable_diffusion.types import (
+    BuiltInCheckpointConfigName,
+    CheckpointConfig,
+    CLIPModel,
+    StableDiffusionModel,
+    VAEModel,
+    load_checkpoint,
+)
 from ...node_base import NodeBase
 from ...node_factory import NodeFactory
 from ...properties.inputs import CkptFileInput
@@ -45,11 +50,9 @@ def run(
         assert os.path.isfile(path), f"Path {path} is not a file"
 
         # TODO load V2 models, maybe auto-detect
-        config = comfy.CheckpointConfig.from_built_in(
-            comfy.BuiltInCheckpointConfigName.V1
-        )
+        config = CheckpointConfig.from_built_in(BuiltInCheckpointConfigName.V1)
 
-        sd, clip, vae = comfy.load_checkpoint(
+        sd, clip, vae = load_checkpoint(
             config=config, checkpoint_filepath=path, embedding_directory=None
         )
 
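With the re-exports in place, checkpoint loading no longer touches the top-level comfy namespace. A minimal standalone sketch of the same flow, restricted to the comfy wrapper imports and call signatures visible in this diff; the checkpoint path is a placeholder:

from comfy.stable_diffusion import (
    BuiltInCheckpointConfigName,
    CheckpointConfig,
    load_checkpoint,
)

# Built-in V1 config, as in the node above; V2 handling is still a TODO there.
config = CheckpointConfig.from_built_in(BuiltInCheckpointConfigName.V1)

# One call returns the diffusion model, the CLIP model, and the VAE.
sd, clip, vae = load_checkpoint(
    config=config,
    checkpoint_filepath="models/checkpoints/v1-5.ckpt",  # placeholder path
    embedding_directory=None,
)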
backend/src/nodes/nodes/stable_diffusion/vae_decode.py (2 additions, 3 deletions)
@@ -1,10 +1,9 @@
 from __future__ import annotations
 
-import comfy
 import cv2
 import numpy as np
 
-from ...impl.stable_diffusion.types import LatentImage
+from ...impl.stable_diffusion.types import LatentImage, VAEModel
 from ...node_base import NodeBase
 from ...node_factory import NodeFactory
 from ...properties.inputs.stable_diffusion_inputs import LatentImageInput, VAEModelInput
@@ -30,7 +29,7 @@ def __init__(self):
         self.icon = "PyTorch"
         self.sub = "Input & Output"
 
-    def run(self, vae: comfy.VAEModel, latent_image: LatentImage) -> np.ndarray:
+    def run(self, vae: VAEModel, latent_image: LatentImage) -> np.ndarray:
         img = vae.decode(latent_image)
         arr = np.array(img)
         arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
backend/src/nodes/nodes/stable_diffusion/vae_encode.py (2 additions, 2 deletions)
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
-import comfy
 import numpy as np
 from PIL import Image
 
+from ...impl.stable_diffusion.types import LatentImage, VAEModel
 from ...node_base import NodeBase
 from ...node_factory import NodeFactory
 from ...properties.inputs import ImageInput
@@ -35,7 +35,7 @@ def __init__(self):
         self.icon = "PyTorch"
         self.sub = "Input & Output"
 
-    def run(self, vae: comfy.VAEModel, image: np.ndarray) -> np.ndarray:
+    def run(self, vae: VAEModel, image: np.ndarray) -> LatentImage:
         img = _array_to_image(image)
         latent = vae.encode(img)
         return latent
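The corrected return annotation (LatentImage rather than np.ndarray) makes the encode and decode nodes compose. A short round-trip sketch under the same assumptions as above: the checkpoint path and input image are placeholders, and vae.encode is assumed to accept a PIL image, which matches how this node calls it after converting the array:

import numpy as np
from PIL import Image

from comfy.stable_diffusion import BuiltInCheckpointConfigName, CheckpointConfig, load_checkpoint

config = CheckpointConfig.from_built_in(BuiltInCheckpointConfigName.V1)
_, _, vae = load_checkpoint(
    config=config,
    checkpoint_filepath="models/checkpoints/v1-5.ckpt",  # placeholder path
    embedding_directory=None,
)

img = Image.open("input.png").convert("RGB")  # placeholder input image
latent = vae.encode(img)                # a LatentImage, per the corrected annotation
decoded = np.array(vae.decode(latent))  # back to an RGB array, as in vae_decode.py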
backend/src/nodes/nodes/stable_diffusion/vae_masked_encode.py (2 additions, 4 deletions)
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
-import comfy
 import numpy as np
 from PIL import Image
 
+from ...impl.stable_diffusion.types import LatentImage, VAEModel
 from ...node_base import NodeBase
 from ...node_factory import NodeFactory
 from ...properties.inputs import ImageInput
@@ -31,9 +31,7 @@ def __init__(self):
         self.icon = "PyTorch"
         self.sub = "Input & Output"
 
-    def run(
-        self, vae: comfy.VAEModel, image: np.ndarray, mask: np.ndarray
-    ) -> np.ndarray:
+    def run(self, vae: VAEModel, image: np.ndarray, mask: np.ndarray) -> LatentImage:
         img = Image.fromarray(image)
         mask_img = Image.fromarray(mask)
         latent = vae.masked_encode(img, mask_img)
