From d5078fe4a8c9f823a680c660fa1fac2292899ce4 Mon Sep 17 00:00:00 2001 From: standardAI Date: Thu, 14 Mar 2024 22:16:34 +0300 Subject: [PATCH] Use PyTorch's conventional inplace functions --- tests/pipelines/controlnet/test_controlnet.py | 4 ++-- tests/pipelines/controlnet/test_controlnet_img2img.py | 2 +- tests/pipelines/controlnet/test_controlnet_inpaint.py | 2 +- tests/pipelines/controlnet/test_controlnet_sdxl.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 114a36b37f74..cc0696bae99b 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -302,7 +302,7 @@ def get_dummy_components(self): def init_weights(m): if isinstance(m, torch.nn.Conv2d): - torch.nn.init.normal(m.weight) + torch.nn.init.normal_(m.weight) m.bias.data.fill_(1.0) controlnet1 = ControlNetModel( @@ -519,7 +519,7 @@ def get_dummy_components(self): def init_weights(m): if isinstance(m, torch.nn.Conv2d): - torch.nn.init.normal(m.weight) + torch.nn.init.normal_(m.weight) m.bias.data.fill_(1.0) controlnet = ControlNetModel( diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 89e2b3803dee..46821a51b70c 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -210,7 +210,7 @@ def get_dummy_components(self): def init_weights(m): if isinstance(m, torch.nn.Conv2d): - torch.nn.init.normal(m.weight) + torch.nn.init.normal_(m.weight) m.bias.data.fill_(1.0) controlnet1 = ControlNetModel( diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 67e0da4de9cd..32ae8d125ab1 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -273,7 +273,7 @@ def get_dummy_components(self): def init_weights(m): if isinstance(m, torch.nn.Conv2d): - torch.nn.init.normal(m.weight) + torch.nn.init.normal_(m.weight) m.bias.data.fill_(1.0) controlnet1 = ControlNetModel( diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index c82ce6c39cca..e06f228a7d30 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -503,7 +503,7 @@ def get_dummy_components(self): def init_weights(m): if isinstance(m, torch.nn.Conv2d): - torch.nn.init.normal(m.weight) + torch.nn.init.normal_(m.weight) m.bias.data.fill_(1.0) controlnet1 = ControlNetModel( @@ -708,7 +708,7 @@ def get_dummy_components(self): def init_weights(m): if isinstance(m, torch.nn.Conv2d): - torch.nn.init.normal(m.weight) + torch.nn.init.normal_(m.weight) m.bias.data.fill_(1.0) controlnet = ControlNetModel(