From 1e700facded75c0db61729c237500506cfd8981e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Oct 2023 12:06:03 +0200 Subject: [PATCH] fix: torch.compile() for lora conv --- src/diffusers/models/lora.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index cc8e3e231e2b..a777bb93e1c8 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -164,7 +164,10 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + original_outputs = F.conv2d( + hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + return original_outputs + (scale * self.lora_layer(hidden_states)) class LoRACompatibleLinear(nn.Linear):