Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)


def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
offload_stream = None
xfer_dest = None

Expand Down Expand Up @@ -170,10 +170,10 @@ def to_dequant(tensor, dtype):
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
(want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0:
if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
Expand All @@ -194,7 +194,7 @@ def to_dequant(tensor, dtype):
return weight, bias, (offload_stream, device if signature is not None else None, None)


def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
Expand All @@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)

if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)

if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
Expand Down Expand Up @@ -850,8 +850,8 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)

def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
Expand Down Expand Up @@ -881,8 +881,7 @@ def forward(self, input, *args, **kwargs):
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)


output = self.forward_comfy_cast_weights(input, compute_dtype)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))

# Reshape output back to 3D if input was 3D
if reshaped_3d:
Expand Down
45 changes: 37 additions & 8 deletions comfy_api_nodes/apis/recraft.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,6 @@ def get_v3_substyles(style_v3: str, include_none=True) -> list[str]:
}


class RecraftModel(str, Enum):
recraftv3 = 'recraftv3'
recraftv2 = 'recraftv2'


class RecraftImageSize(str, Enum):
res_1024x1024 = '1024x1024'
res_1365x1024 = '1365x1024'
Expand All @@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
res_1707x1024 = '1707x1024'


RECRAFT_V4_SIZES = [
"1024x1024",
"1536x768",
"768x1536",
"1280x832",
"832x1280",
"1216x896",
"896x1216",
"1152x896",
"896x1152",
"832x1344",
"1280x896",
"896x1280",
"1344x768",
"768x1344",
]

RECRAFT_V4_PRO_SIZES = [
"2048x2048",
"3072x1536",
"1536x3072",
"2560x1664",
"1664x2560",
"2432x1792",
"1792x2432",
"2304x1792",
"1792x2304",
"1664x2688",
"1434x1024",
"1024x1434",
"2560x1792",
"1792x2560",
]


class RecraftColorObject(BaseModel):
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')

Expand All @@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):

class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate')
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: int = Field(..., description='The number of images to generate')
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
model: str = Field(...)
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: int | None = Field(None, description="Seed for video generation")
# text_layout


class RecraftReturnedObject(BaseModel):
Expand Down
Loading
Loading