@@ -59,6 +59,135 @@ def check_same_shape(tensor_list):
5959 return all (shape == shapes [0 ] for shape in shapes [1 :])
6060
6161
class PipelineEfficiencyFunctionTesterMixin:
    """
    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
    It provides a set of common tests for PyTorch pipeline that inherit from EfficiencyMixin, e.g. vae_slicing, vae_tiling, freeu, etc.
    """

    def test_vae_slicing(self):
        # CPU keeps the device-dependent torch.Generator deterministic across runs.
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        image_count = 4

        # Reference: full-batch VAE decode.
        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = [inputs["prompt"]] * image_count
        reference_output = pipe(**inputs)

        # Sliced VAE decode must yield (nearly) the same result.
        pipe.enable_vae_slicing()
        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = [inputs["prompt"]] * image_count
        sliced_output = pipe(**inputs)

        # A small discrepancy at image borders is tolerated vs. full-batch decode.
        max_diff = np.abs(sliced_output.images.flatten() - reference_output.images.flatten()).max()
        assert max_diff < 3e-3

    def test_vae_tiling(self):
        # CPU keeps the device-dependent torch.Generator deterministic across runs.
        device = "cpu"
        components = self.get_dummy_components()

        # Drop the safety checker so the tiled/non-tiled outputs are compared raw.
        components["safety_checker"] = None
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"

        # Reference: non-tiled decode with a fixed seed.
        generator = torch.Generator(device=device).manual_seed(0)
        untiled_output = pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

        # Tiled decode with the same seed should closely match the reference.
        pipe.enable_vae_tiling()
        generator = torch.Generator(device=device).manual_seed(0)
        tiled_output = pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

        assert np.abs(tiled_output.images.flatten() - untiled_output.images.flatten()).max() < 5e-1

        # Tiled decode must also accept odd, non-tile-aligned latent shapes.
        odd_shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
        for latent_shape in odd_shapes:
            latents = torch.zeros(latent_shape).to(device)
            pipe.vae.decode(latents)

    def test_freeu_enabled(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        baseline = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images

        # With FreeU on, the same seed must produce a visibly different image.
        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
        freeu_images = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images

        assert not np.allclose(
            baseline[0, -3:, -3:, -1], freeu_images[0, -3:, -3:, -1]
        ), "Enabling of FreeU should lead to different results."

    def test_freeu_disabled(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        baseline = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images

        # Enable then immediately disable; the pipeline should return to default behavior.
        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
        pipe.disable_freeu()

        # disable_freeu is expected to clear all FreeU attributes on the up blocks.
        freeu_keys = {"s1", "s2", "b1", "b2"}
        for upsample_block in pipe.unet.up_blocks:
            for key in freeu_keys:
                assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."

        restored = pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images

        assert np.allclose(
            baseline[0, -3:, -3:, -1], restored[0, -3:, -3:, -1]
        ), "Disabling of FreeU should lead to results similar to the default pipeline results."

    def test_fused_qkv_projections(self):
        # CPU keeps the device-dependent torch.Generator deterministic across runs.
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        # Reference slice from the unmodified pipeline.
        inputs = self.get_dummy_inputs(device)
        original_image_slice = pipe(**inputs).images[0, -3:, -3:, -1]

        # Fusing QKV projections is an optimization; outputs must be unchanged.
        pipe.fuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image_slice_fused = pipe(**inputs).images[0, -3:, -3:, -1]

        # Unfusing must restore the original computation path.
        pipe.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image_slice_disabled = pipe(**inputs).images[0, -3:, -3:, -1]

        assert np.allclose(
            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
        ), "Fusion of QKV projections shouldn't affect the outputs."
        assert np.allclose(
            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        assert np.allclose(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."
190+
62191class PipelineLatentTesterMixin :
63192 """
64193 This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
0 commit comments