# Towards An <strong>E</strong>nd-to-<strong>E</strong>nd Framework for <strong>F</strong>low-<strong>G</strong>uided <strong>V</strong>ideo <strong>I</strong>npainting (CVPR 2022)

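In this demo, you can inpaint a video with our framework. The input frames and their per-frame masks are read from your Google Drive (see the paths in the *Inpainting* section).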

# Setup

In [None]:
#@title Setup environment and code (may take some time)

# Install Pytorch
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
# Install MMCV
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.5/index.html

# prepare code
import os
CODE_DIR = 'E2FGVI'
os.makedirs(f'./{CODE_DIR}')
!git clone https://github.com/MCG-NKU/E2FGVI.git $CODE_DIR
os.chdir(f'./{CODE_DIR}')


Looking in links: https://download.pytorch.org/whl/torch_stable.html
[31mERROR: Could not find a version that satisfies the requirement torch==1.5.1+cu101 (from versions: 1.11.0, 1.11.0+cpu, 1.11.0+cu102, 1.11.0+cu113, 1.11.0+cu115, 1.11.0+rocm4.3.1, 1.11.0+rocm4.5.2, 1.12.0, 1.12.0+cpu, 1.12.0+cu102, 1.12.0+cu113, 1.12.0+cu116, 1.12.0+rocm5.0, 1.12.0+rocm5.1.1, 1.12.1, 1.12.1+cpu, 1.12.1+cu102, 1.12.1+cu113, 1.12.1+cu116, 1.12.1+rocm5.0, 1.12.1+rocm5.1.1, 1.13.0, 1.13.0+cpu, 1.13.0+cu116, 1.13.0+cu117, 1.13.0+cu117.with.pypi.cudnn, 1.13.0+rocm5.1.1, 1.13.0+rocm5.2, 1.13.1, 1.13.1+cpu, 1.13.1+cu116, 1.13.1+cu117, 1.13.1+cu117.with.pypi.cudnn, 1.13.1+rocm5.1.1, 1.13.1+rocm5.2, 2.0.0, 2.0.0+cpu, 2.0.0+cpu.cxx11.abi, 2.0.0+cu117, 2.0.0+cu117.with.pypi.cudnn, 2.0.0+cu118, 2.0.0+rocm5.3, 2.0.0+rocm5.4.2, 2.0.1, 2.0.1+cpu, 2.0.1+cpu.cxx11.abi, 2.0.1+cu117, 2.0.1+cu117.with.pypi.cudnn, 2.0.1+cu118, 2.0.1+rocm5.3, 2.0.1+rocm5.4.2)[0m[31m
[0m[31mERROR: No matching distribution found for to

In [None]:
#@title Download model with PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os

# When False, Downloader does not authenticate with Google (no PyDrive client is created).
download_with_pydrive: bool = True

class Downloader(object):
    """Download released model files from Google Drive into ``<CODE_DIR>/release_model``.

    Args:
        use_pydrive: when True, authenticate the Colab user and create a PyDrive client
            used by ``download_file``. When False no client is created, and
            ``download_file`` can only report files that are already present.
    """

    def __init__(self, use_pydrive):
        self.use_pydrive = use_pydrive
        current_directory = os.getcwd()
        # NOTE(review): assumes the cwd is the cloned CODE_DIR (set by the setup cell), so
        # "parent/CODE_DIR" resolves back to the repository root — confirm if cells are reordered.
        self.save_dir = os.path.join(os.path.dirname(current_directory), CODE_DIR, "release_model")
        os.makedirs(self.save_dir, exist_ok=True)
        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        """Authenticate the Colab user and store a PyDrive client in ``self.drive``."""
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def download_file(self, file_id, file_name):
        """Download Drive file ``file_id`` to ``save_dir/file_name`` unless it already exists.

        Args:
            file_id: Google Drive file id.
            file_name: name of the destination file inside ``save_dir``.

        Raises:
            RuntimeError: if the file is missing and no PyDrive client was created
                (``use_pydrive=False``).
        """
        file_dst = os.path.join(self.save_dir, file_name)
        if os.path.exists(file_dst):
            print(f'{file_name} already exists!')
            return
        if getattr(self, 'drive', None) is None:
            raise RuntimeError(
                f'Cannot download {file_name}: Downloader was created with use_pydrive=False')
        downloaded = self.drive.CreateFile({'id': file_id})
        downloaded.FetchMetadata(fetch_all=True)
        # Write to a temporary name first so an interrupted download is not later mistaken
        # for a complete file by the "already exists" check above.
        partial_dst = file_dst + '.part'
        downloaded.GetContentFile(partial_dst)
        os.replace(partial_dst, file_dst)

downloader = Downloader(download_with_pydrive)
path = {"id": "1tNJMTJ2gmWdIXJoHVi5-H504uImUiJW9", "name": "E2FGVI_CVPR22_models.zip"}
downloader.download_file(file_id=path["id"], file_name=path["name"])

os.chdir(downloader.save_dir)
!unzip E2FGVI_CVPR22_models.zip
os.chdir('..')

Archive:  E2FGVI_CVPR22_models.zip
  inflating: E2FGVI-CVPR22.pth       
  inflating: i3d_rgb_imagenet.pt     


# Define Utility Functions


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import animation

import cv2
from PIL import Image
import numpy as np
import importlib
import os
import argparse
from tqdm import tqdm
import torch

# rc('animation', html='jshtml')

from core.utils import to_tensors


# global variables
# Working resolution: frames and masks are resized to (w, h) when they are read.
w = 432
h = 240
# Spacing between candidate reference frames sampled across the video (a.k.a. ref_step).
ref_length = 10
# Number of reference frames; -1 means every ref_length-th frame of the whole video.
num_ref = -1
# Step between inference windows; each window spans +/- neighbor_stride frames around its centre.
neighbor_stride = 5


# sample reference frames from the whole video
def get_ref_index(f, neighbor_ids, length, ref_step=None, n_ref=None):
    """Select non-local reference frame indices for the window centred on frame ``f``.

    Args:
        f: index of the current window's centre frame.
        neighbor_ids: indices already used as local neighbours; these are never returned.
        length: total number of frames; every returned index lies in ``[0, length)``.
        ref_step: spacing between candidate reference frames. Defaults to the global
            ``ref_length``.
        n_ref: maximum number of references. ``-1`` samples every ``ref_step``-th frame
            of the whole video. Defaults to the global ``num_ref``.

    Returns:
        list[int]: reference frame indices in ascending order.
    """
    if ref_step is None:
        ref_step = ref_length
    if n_ref is None:
        n_ref = num_ref
    neighbor_set = set(neighbor_ids)

    if n_ref == -1:
        return [i for i in range(0, length, ref_step) if i not in neighbor_set]

    half_span = ref_step * (n_ref // 2)
    start_idx = max(0, f - half_span)
    # Clamp to the last valid frame index so that range(..., end_idx + 1) never yields `length`,
    # which would be out of range for the frame tensor.
    end_idx = min(length - 1, f + half_span)
    ref_index = []
    for i in range(start_idx, end_idx + 1, ref_step):
        if i in neighbor_set:
            continue
        # Stop once exactly n_ref references have been collected.
        if len(ref_index) >= n_ref:
            break
        ref_index.append(i)
    return ref_index


# read frame-wise masks
def read_mask(mpath):
    """Load every mask image in ``mpath`` (sorted by file name) as a dilated binary mask.

    Each mask is resized to the global working size ``(w, h)`` with nearest-neighbour
    sampling and converted to grayscale. It is binarised (any non-zero pixel becomes
    foreground) and dilated 4 times with a 3x3 cross kernel. Each mask is returned as a
    PIL image with values 0/255.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    result = []
    for name in sorted(os.listdir(mpath)):
        resized = Image.open(os.path.join(mpath, name)).resize((w, h), Image.NEAREST)
        gray = np.array(resized.convert('L'))
        binary = (gray > 0).astype(np.uint8)
        dilated = cv2.dilate(binary, kernel, iterations=4)
        result.append(Image.fromarray(dilated * 255))
    return result


# read frames from a directory of video frame images
def read_frame_from_videos(video_path):
    """Read all frame images in ``video_path`` (sorted by file name) as RGB PIL images.

    Every frame is resized to the global working size ``(w, h)``.

    Args:
        video_path: directory containing one image file per frame.

    Returns:
        list[PIL.Image.Image]: the resized RGB frames, in file-name order.

    Raises:
        ValueError: if a directory entry cannot be decoded as an image by OpenCV.
    """
    frames = []
    for name in sorted(os.listdir(video_path)):
        frame_file = os.path.join(video_path, name)
        image = cv2.imread(frame_file)
        if image is None:
            # cv2.imread returns None for unreadable or non-image files. Without this check,
            # cvtColor fails with an opaque assertion error instead.
            raise ValueError(f'Could not read frame image: {frame_file}')
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        frames.append(image.resize((w, h)))
    return frames

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive so the input frame and mask folders under
# /content/drive/MyDrive can be read by the inpainting cell.
drive.mount('/content/drive')


Mounted at /content/drive


# Inpainting


In [None]:
# set up models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = importlib.import_module('model.e2fgvi')
model = net.InpaintGenerator().to(device)
ckpt_path = '/content/E2FGVI/release_model/E2FGVI-CVPR22.pth'
print(f'Loading model from: {ckpt_path}')
data = torch.load(ckpt_path, map_location=device)
model.load_state_dict(data)
model.eval()

# prepare dataset
video_path = '/content/drive/MyDrive/frames'
mask_path = '/content/drive/MyDrive/saved_masks'
print(f'Loading videos and masks from: {video_path}, {mask_path}')
frames = read_frame_from_videos(video_path)
masks = read_mask(mask_path)
video_length = len(frames)
# Fail early with a readable message. An empty frame folder leaves nothing to inpaint.
# A frame/mask count mismatch would otherwise fail or silently misalign inside the loop below.
if video_length == 0:
    raise ValueError(f'No frames found in {video_path}')
if len(masks) != video_length:
    raise ValueError(
        f'Expected one mask per frame, but found {video_length} frames in {video_path} '
        f'and {len(masks)} masks in {mask_path}')

# Network input is scaled to [-1, 1]; uint8 copies of the frames are kept for compositing.
imgs = to_tensors()(frames).unsqueeze(0) * 2 - 1
frames = [np.array(f).astype(np.uint8) for f in frames]

# (h, w, 1) {0, 1} masks used to paste predictions only into the masked region.
binary_masks = [np.expand_dims((np.array(m) != 0).astype(np.uint8), 2)
                for m in masks]
masks = to_tensors()(masks).unsqueeze(0)
imgs, masks = imgs.to(device), masks.to(device)
comp_frames = [None] * video_length

# completing holes by e2fgvi
print('Start test...')
for f in tqdm(range(0, video_length, neighbor_stride)):
    # Local window around frame f, plus sparsely sampled non-local reference frames.
    neighbor_ids = list(range(max(0, f - neighbor_stride),
                              min(video_length, f + neighbor_stride + 1)))
    ref_ids = get_ref_index(f, neighbor_ids, video_length)
    selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
    selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
    with torch.no_grad():
        masked_imgs = selected_imgs * (1 - selected_masks)
        # Presumably tells the model how many leading frames are local neighbours; one
        # prediction per neighbour frame is consumed below.
        pred_img, _ = model(masked_imgs, len(neighbor_ids))

        pred_img = (pred_img + 1) / 2
        pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
        for i, idx in enumerate(neighbor_ids):
            img = (np.array(pred_img[i]).astype(np.uint8) * binary_masks[idx]
                   + frames[idx] * (1 - binary_masks[idx]))
            if comp_frames[idx] is None:
                comp_frames[idx] = img
            else:
                # Frames covered by several overlapping windows are blended 50/50 with the
                # running result, so later windows weigh more than in a plain mean.
                comp_frames[idx] = (comp_frames[idx].astype(np.float32) * 0.5
                                    + img.astype(np.float32) * 0.5)

NameError: ignored

## Show the inpainted video

In [None]:
# Display size of the preview; the inpainted (w x h) frames are upscaled to this.
# NOTE(review): 1920 tall x 1080 wide is a portrait layout — confirm it matches the source video.
height, width = 1920, 1080
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
ax.axis('off')


def _to_display(frame):
    """Convert a composited frame to uint8 and resize it to the preview size."""
    return cv2.resize(frame.astype(np.uint8), (width, height))


imdata = ax.imshow(_to_display(comp_frames[0]), extent=[0, width, height, 0])


def update(idx):
    """Animation callback: show composited frame ``idx``."""
    imdata.set_data(_to_display(comp_frames[idx]))
    return (imdata,)


anim = animation.FuncAnimation(fig, update, frames=len(comp_frames), interval=50)
# Close the figure so the inline backend does not also render a static copy of frame 0
# underneath the HTML5 video.
plt.close(fig)
from IPython.display import HTML
HTML(anim.to_html5_video())