Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions diffsynth_engine/pipelines/wan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def convert(self, state_dict):
return state_dict


class WanLowNoiseLoRAConverter(WanLoRAConverter):
    """Converter that retargets converted LoRA weights at the low-noise DiT.

    Runs the standard Wan LoRA conversion, then re-labels the resulting
    "dit" bucket as "dit2" so the weights are applied to the low-noise model.
    """

    def convert(self, state_dict):
        converted = super().convert(state_dict)
        return {"dit2": converted["dit"]}


class WanVideoPipeline(BasePipeline):
lora_converter = WanLoRAConverter()

Expand Down Expand Up @@ -144,8 +149,24 @@ def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, sav
)
super().load_loras(lora_list, fused, save_original_weight)

def load_loras_low_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
    """Load LoRA weights into the low-noise model (dit2).

    Temporarily installs a converter that routes converted LoRA keys to
    "dit2" instead of "dit", then delegates to the regular load path.

    Args:
        lora_list: (path, scale) pairs of LoRA checkpoints to load.
        fused: whether to fuse LoRA weights into the base weights.
        save_original_weight: whether to keep original weights for unloading.
    """
    assert self.dit2 is not None, "low noise LoRA can only be loaded when low noise model (dit2) is initialized"
    assert self.config.tp_degree is None or self.config.tp_degree == 1, (
        "load LoRA is not allowed when tensor parallel is enabled; "
        "set tp_degree=None or tp_degree=1 during pipeline initialization"
    )
    assert not (self.config.use_fsdp and fused), (
        "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
        "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
    )
    # Swap the converter so LoRA keys target dit2; restore it in a finally
    # block so a failure during loading cannot leave the low-noise converter
    # permanently installed on the pipeline.
    self.lora_converter = WanLowNoiseLoRAConverter()
    try:
        super().load_loras(lora_list, fused, save_original_weight)
    finally:
        self.lora_converter = WanLoRAConverter()

def unload_loras(self):
    """Remove loaded LoRA weights from the DiT model(s) and the text encoder."""
    targets = [self.dit]
    # The low-noise model is optional; only unload from it when present.
    if self.dit2 is not None:
        targets.append(self.dit2)
    targets.append(self.text_encoder)
    for module in targets:
        module.unload_loras()

def get_default_fps(self) -> int:
Expand Down