Skip to content

Commit

Permalink
Add PatchModelAddDownscale (Kohya Deep Shrink) node.
Browse files Browse the repository at this point in the history
By adding a downscale to the unet in the first timesteps this node lets
you generate images at higher resolutions with less consistency issues.
  • Loading branch information
comfyanonymous committed Nov 16, 2023
1 parent 7ea6bb0 commit bd07ad1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
45 changes: 45 additions & 0 deletions comfy_extras/nodes_model_downscale.py
@@ -0,0 +1,45 @@
import torch

class PatchModelAddDownscale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "_for_testing"

def patch(self, model, block_number, downscale_factor, start_percent, end_percent):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item()
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item()

def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False)
return h

def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False)
return h, hsp

m = model.clone()
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )

NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}

NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
}
1 change: 1 addition & 0 deletions nodes.py
Expand Up @@ -1799,6 +1799,7 @@ def init_custom_nodes():
"nodes_custom_sampler.py",
"nodes_hypertile.py",
"nodes_model_advanced.py",
"nodes_model_downscale.py",
]

for node_file in extras_files:
Expand Down

4 comments on commit bd07ad1

@ATERUBER
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

\comfy_extras\nodes_model_downscale.py", line 34, in patch
m.set_model_input_block_patch(input_block_patch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ModelPatcher' object has no attribute 'set_model_input_block_patch'

@ltdrdata
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_model_input_block_patch

It seems your ComfyUI is outdated verison.
Did you manually download and place this .py file?

@SLAPaper
Copy link

@SLAPaper SLAPaper commented on bd07ad1 Nov 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the block_number be different for sd1.5 and sdxl model?

@traugdor
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the start and end percents are supposed to do but I am struggling getting this node to perform like a standard hires fix.

Please sign in to comment.