223 changes: 222 additions & 1 deletion comfy_extras/nodes_hunyuan.py
@@ -1,12 +1,15 @@
import math
import nodes
import node_helpers
import torch
import re
import comfy.model_management
import comfy.patcher_extension


class CLIPTextEncodeHunyuanDiT:
@classmethod
- def INPUT_TYPES(s):
+ def INPUT_TYPES(cls):
return {"required": {
"clip": ("CLIP", ),
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +26,216 @@ def encode(self, clip, bert, mt5xl):

return (clip.encode_from_tokens_scheduled(tokens), )

class MomentumBuffer:
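"""Momentum-weighted running accumulation of the guidance delta across sampling steps."""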
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0

def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average

def normalized_guidance_apg(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
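"""Adaptive Projected Guidance: the (cond - uncond) delta is optionally momentum-accumulated and norm-clamped, then decomposed into components parallel and orthogonal to the conditional prediction; the parallel part is re-weighted by eta and the result is added back at guidance_scale."""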
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]

if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average

if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor

v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)

normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update

return pred

class AdaptiveProjectedGuidance:
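"""Stateful wrapper around normalized_guidance_apg. adaptive_projected_guidance_rescale is forwarded as its norm_threshold, and adaptive_projected_guidance_momentum seeds a fresh MomentumBuffer at the first sampling step."""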
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum=None,
adaptive_projected_guidance_rescale: float = 15.0,
# eta: float = 1.0,
eta: float = 0.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__()

self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None

def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, is_first_step=False) -> torch.Tensor:
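# Start each sampling run with a fresh momentum buffer so state does not carry over between runs.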

if is_first_step and self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)

pred = normalized_guidance_apg(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)

return pred

class HunyuanMixModeAPG:
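"""Patches the model's CFG function to apply Adaptive Projected Guidance with one set of parameters for general prompts and another for prompts containing quoted (OCR) text, falling back to standard CFG before each branch's start percent."""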

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL", ),
"has_quoted_text": ("BOOLEAN", ),

"guidance_scale": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),

"general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
"general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
"general_start_percent": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of general APG."}),

"ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
"ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
"ocr_start_percent": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of OCR APG."}),

}
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_mix_mode_apg"
CATEGORY = "sampling/custom_sampling/hunyuan"


def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_percent,
ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_percent):
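# Two independent APG controllers: one tuned for general prompts, one for quoted-text (OCR) prompts.
# Note that ocr_apg is constructed without guidance_scale, so it uses the AdaptiveProjectedGuidance default of 7.5.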

general_apg = AdaptiveProjectedGuidance(
guidance_scale=guidance_scale,
eta=general_eta,
adaptive_projected_guidance_rescale=general_norm_threshold,
adaptive_projected_guidance_momentum=general_momentum
)

ocr_apg = AdaptiveProjectedGuidance(
eta=ocr_eta,
adaptive_projected_guidance_rescale=ocr_norm_threshold,
adaptive_projected_guidance_momentum=ocr_momentum
)

m = model.clone()


model_sampling = m.model.model_sampling
general_start_t = model_sampling.percent_to_sigma(general_start_percent)
ocr_start_t = model_sampling.percent_to_sigma(ocr_start_percent)


def cfg_function(args):
sigma = args["sigma"].to(torch.float32)
is_first_step = math.isclose(sigma.item(), args['model_options']['transformer_options']['sample_sigmas'][0].item())
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]

sigma = sigma[:, None, None, None]
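# Prompts without quoted text use the general APG branch; quoted-text prompts use the OCR branch.
# Before a branch's start sigma, standard CFG is returned, but the APG controller is still called (when cond_scale > 1) purely to keep its momentum buffer up to date.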


if not has_quoted_text:
if sigma[0] <= general_start_t:
modified_cond = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
return modified_cond * sigma
else:
if cond_scale > 1:
_ = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step) # track momentum
return uncond + (cond - uncond) * cond_scale
else:
if sigma[0] <= ocr_start_t:
modified_cond = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
return modified_cond * sigma
else:
if cond_scale > 1:
_ = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step) # track momentum
return uncond + (cond - uncond) * cond_scale

return cond

m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
return (m,)

class CLIPTextEncodeHunyuanDiTWithTextDetection:
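"""Encodes the prompt with the Hunyuan CLIP and also reports whether it contains quoted text, which HunyuanMixModeAPG uses to choose between the general and OCR guidance branches."""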

@classmethod
def INPUT_TYPES(cls):
return {"required": {
"clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}

RETURN_TYPES = ("CONDITIONING", "BOOLEAN")
RETURN_NAMES = ("conditioning", "has_quoted_text")
FUNCTION = "encode"

CATEGORY = "advanced/conditioning/hunyuan"

def detect_quoted_text(self, text):
"""Detect quoted text in the prompt"""
text_prompt_texts = []

# Patterns to match different quote styles
pattern_quote_double = r'\"(.*?)\"'
pattern_quote_chinese_single = r'‘(.*?)’'
pattern_quote_chinese_double = r'“(.*?)”'

matches_quote_double = re.findall(pattern_quote_double, text)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)

text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)

return len(text_prompt_texts) > 0

def encode(self, clip, text):
tokens = clip.tokenize(text)
has_quoted_text = self.detect_quoted_text(text)

conditioning = clip.encode_from_tokens_scheduled(tokens)

return (conditioning, has_quoted_text)

class EmptyHunyuanLatentVideo:
@classmethod
def INPUT_TYPES(s):
@@ -151,8 +364,16 @@ def execute(self, positive, negative, latent, noise_augmentation):
return (positive, negative, out_latent)



NODE_DISPLAY_NAME_MAPPINGS = {
"HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
"HunyuanStepBasedAPG": "Hunyuan Step Based APG",
}

NODE_CLASS_MAPPINGS = {
"HunyuanMixModeAPG": HunyuanMixModeAPG,
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,