diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 78982d7959..2e746d8a9e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -9,16 +9,26 @@ from comfy.types import UnetWrapperFunction -def weight_decompose_scale(dora_scale, weight): +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength): + dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32) + lora_diff *= alpha + weight_calc = weight + lora_diff.type(weight.dtype) weight_norm = ( - weight.transpose(0, 1) - .reshape(weight.shape[1], -1) + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) .norm(dim=1, keepdim=True) - .reshape(weight.shape[1], *[1] * (weight.dim() - 1)) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) .transpose(0, 1) ) - return (dora_scale / weight_norm).type(weight.dtype) + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -328,7 +338,7 @@ def __call__(self, weight): def calculate_weight(self, patches, weight, key): for p in patches: - alpha = p[0] + strength = p[0] v = p[1] strength_model = p[2] @@ -346,26 +356,31 @@ def calculate_weight(self, patches, weight, key): if patch_type == "diff": w1 = v[0] - if alpha != 0.0: + if strength != 0.0: if w1.shape != weight.shape: logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: - weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) + weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) dora_scale = v[4] if v[2] is not None: - alpha *= v[2] / mat2.shape[0] + alpha = v[2] / mat2.shape[0] + else: + alpha = 1.0 + if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: - weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) if dora_scale is not None: - weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": @@ -402,19 +417,26 @@ def calculate_weight(self, patches, weight, key): if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) if v[2] is not None and dim is not None: - alpha *= v[2] / dim + alpha = v[2] / dim + else: + alpha = 1.0 try: - weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": w1a = v[0] w1b = v[1] if v[2] is not None: - alpha *= v[2] / w1b.shape[0] + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + w2a = v[3] w2b = v[4] dora_scale = v[7] @@ -437,14 +459,18 @@ def calculate_weight(self, patches, weight, key): comfy.model_management.cast_to_device(w2b, weight.device, torch.float32)) try: - weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": if v[4] is not None: - alpha *= v[4] / v[0].shape[0] + alpha = v[4] / v[0].shape[0] + else: + alpha = 1.0 dora_scale = v[5] @@ -454,9 +480,11 @@ def calculate_weight(self, patches, weight, key): b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) try: - weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape) if dora_scale is not None: - weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength) + else: + weight += ((strength * alpha) * lora_diff).type(weight.dtype) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: