From 4c13037f9d8a44878158b131334fffc91edc1dc2 Mon Sep 17 00:00:00 2001 From: pravdomil <2387356+pravdomil@users.noreply.github.com> Date: Tue, 16 Jan 2024 13:41:46 +0100 Subject: [PATCH 1/4] use self.device --- examples/community/rerender_a_video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index b28145ae561d..5372adcb8c3b 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -374,7 +374,7 @@ def __init__( attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6, - ).to("cuda") + ).to(self.device) checkpoint = torch.utils.model_zoo.load_url( "https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth", From 9d0a6cd7959a3cb852b077e93d91606dcd252cd8 Mon Sep 17 00:00:00 2001 From: pravdomil <2387356+pravdomil@users.noreply.github.com> Date: Tue, 16 Jan 2024 21:34:47 +0100 Subject: [PATCH 2/4] use device --- examples/community/rerender_a_video.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index 5372adcb8c3b..5b16d58961dc 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -119,11 +119,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5) @torch.no_grad() -def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False): +def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None): if image3 is None: image3 = image1 padder = InputPadder(image1.shape, padding_factor=8) - image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device)) results_dict = flow_model( image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True ) @@ -307,6 +307,7 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder=None, requires_safety_checker: bool = True, + device=None, ): super().__init__( vae, @@ -320,6 +321,7 @@ def __init__( image_encoder, requires_safety_checker, ) + self.to(device) if safety_checker is None and requires_safety_checker: logger.warning( @@ -928,13 +930,13 @@ def __call__( prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32) warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask( - self.flow_model, first_image, image[0], first_result, False + self.flow_model, first_image, image[0], first_result, False, self.device ) blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4)) blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1) warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask( - self.flow_model, prev_image[0], image[0], prev_result, False + self.flow_model, prev_image[0], image[0], prev_result, False, self.device ) blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4)) blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1) From 2c68043a86d321087fe6027e4180f2989d8eb4ba Mon Sep 17 00:00:00 2001 From: pravdomil <2387356+pravdomil@users.noreply.github.com> Date: Fri, 16 Feb 2024 12:13:24 +0100 Subject: [PATCH 3/4] fix --- examples/community/rerender_a_video.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index 5b16d58961dc..bc71a76fd23d 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -307,7 +307,6 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder=None, requires_safety_checker: bool = True, - device=None, ): super().__init__( vae, From 23c1face8e0f91472e6eab17bdb70ea553de6b7a Mon Sep 17 00:00:00 2001 From: pravdomil <2387356+pravdomil@users.noreply.github.com> Date: Thu, 7 Mar 2024 21:57:01 +0100 Subject: [PATCH 4/4] fix --- examples/community/rerender_a_video.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index bc71a76fd23d..5b16d58961dc 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -307,6 +307,7 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder=None, requires_safety_checker: bool = True, + device=None, ): super().__init__( vae,