diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index b28145ae561d..5b16d58961dc 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -119,11 +119,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5) @torch.no_grad() -def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False): +def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None): if image3 is None: image3 = image1 padder = InputPadder(image1.shape, padding_factor=8) - image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device)) results_dict = flow_model( image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True ) @@ -307,6 +307,7 @@ def __init__( feature_extractor: CLIPImageProcessor, image_encoder=None, requires_safety_checker: bool = True, + device=None, ): super().__init__( vae, @@ -320,6 +321,7 @@ def __init__( image_encoder, requires_safety_checker, ) + self.to(device) if safety_checker is None and requires_safety_checker: logger.warning( @@ -374,7 +376,7 @@ def __init__( attention_type="swin", ffn_dim_expansion=4, num_transformer_layers=6, - ).to("cuda") + ).to(self.device) checkpoint = torch.utils.model_zoo.load_url( "https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth", @@ -928,13 +930,13 @@ def __call__( prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32) warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask( - self.flow_model, first_image, image[0], first_result, False + self.flow_model, first_image, image[0], first_result, False, self.device ) blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4)) blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1) warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask( - self.flow_model, prev_image[0], image[0], prev_result, False + self.flow_model, prev_image[0], image[0], prev_result, False, self.device ) blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4)) blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)