論文<br>
https://arxiv.org/abs/2204.02663<br>
<br>
GitHub<br>
https://github.com/MCG-NKU/E2FGVI<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/E2FGVI_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 環境セットアップ

## GPU確認

In [None]:
# Show the GPU attached to this runtime (and its driver/CUDA version).
!nvidia-smi

## GitHubからコード取得

In [None]:
%cd /content

# Clone the official E2FGVI repository; later cells work inside /content/E2FGVI.
!git clone https://github.com/MCG-NKU/E2FGVI.git

## ライブラリのインストール

In [None]:
%cd /content

# Install PyTorch 1.5.1 / torchvision 0.6.1 wheels built for CUDA 10.1.
# NOTE(review): these pins are old; confirm the wheels still install on the
# runtime's current Python version.
!pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html 
# Install MMCV (full build) from the index matching CUDA 10.1 / torch 1.5.
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.5/index.html
# Install/upgrade gdown, used to download the pretrained weights from Google Drive.
!pip install --upgrade gdown

## 学習済みモデルのダウンロード

In [None]:
%cd /content/E2FGVI

!gdown 'https://drive.google.com/uc?id=1tNJMTJ2gmWdIXJoHVi5-H504uImUiJW9'
!unzip E2FGVI_CVPR22_models.zip

# ライブラリのインポート

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation

import cv2
from PIL import Image
import numpy as np
import importlib
import os
import argparse
from tqdm import tqdm
import torch

from core.utils import to_tensors

# 関数定義

In [None]:
# Working resolution: every frame and mask is resized to w x h pixels.
w = 432
h = 240
# Step (in frames) between sampled non-local reference frames.
ref_length = 10
# How many reference frames to take around the current frame;
# -1 means sample them across the whole video instead.
num_ref = -1
# Step between successive windows and half-width of each local window.
neighbor_stride = 5


# sample reference frames from the whole video 
def get_ref_index(f, neighbor_ids, length):
    """Return indices of the non-local reference frames for the window at ``f``.

    Uses the module-level ``num_ref`` and ``ref_length`` settings. With
    ``num_ref == -1`` every ``ref_length``-th frame of the whole video is
    taken. Otherwise frames are taken every ``ref_length`` frames, no further
    than ``ref_length * (num_ref // 2)`` frames on either side of ``f``.
    Frames already present in ``neighbor_ids`` are skipped.

    Args:
        f: index of the current (window centre) frame.
        neighbor_ids: indices of the local frames already fed to the model.
        length: total number of frames in the video.

    Returns:
        Increasing list of frame indices, all within ``[0, length)``.
    """
    neighbors = set(neighbor_ids)  # O(1) membership tests
    ref_index = []
    if num_ref == -1:
        for i in range(0, length, ref_length):
            if i not in neighbors:
                ref_index.append(i)
    else:
        start_idx = max(0, f - ref_length * (num_ref // 2))
        # The range below includes end_idx, so clamp to the last valid frame
        # index; clamping to `length` could emit the out-of-bounds index `length`.
        end_idx = min(length - 1, f + ref_length * (num_ref // 2))
        for i in range(start_idx, end_idx + 1, ref_length):
            if i not in neighbors:
                # NOTE(review): `>` allows up to num_ref + 1 references (kept
                # as in upstream E2FGVI) — confirm whether `>=` was intended.
                if len(ref_index) > num_ref:
                    break
                ref_index.append(i)
    return ref_index


# read frame-wise masks
def read_mask(mpath):
    """Load every mask image in directory ``mpath`` as a dilated binary mask.

    Files are processed in sorted filename order. Each mask is:
    - resized to the global ``(w, h)`` with nearest-neighbour sampling;
    - converted to greyscale and binarised, so any non-zero pixel becomes
      foreground;
    - dilated with a 3x3 cross kernel for 4 iterations.

    Returns:
        List of PIL images whose pixels are 0 or 255, one per file.
    """
    masks = []
    for mp in sorted(os.listdir(mpath)):
        # Use a context manager so each file handle is released immediately
        # instead of lingering until garbage collection. resize() loads the
        # pixel data, so the resized image stays valid after the file closes.
        with Image.open(os.path.join(mpath, mp)) as src:
            m = src.resize((w, h), Image.NEAREST)
        m = np.array(m.convert('L'))
        m = (m > 0).astype(np.uint8)
        # Grow the mask slightly (presumably so the hole also covers the
        # object's boundary pixels).
        m = cv2.dilate(m, cv2.getStructuringElement(
            cv2.MORPH_CROSS, (3, 3)), iterations=4)
        masks.append(Image.fromarray(m * 255))
    return masks


#  read frames from video
def read_frame_from_videos(video_path):
    """Load every image in directory ``video_path`` as an RGB video frame.

    Files are processed in sorted filename order. Each one is:
    - read with OpenCV;
    - converted from BGR to RGB;
    - wrapped in a PIL image and resized to the global ``(w, h)``.

    Raises:
        ValueError: if a file in the directory cannot be decoded as an image.
    """
    frames = []
    for name in sorted(os.listdir(video_path)):
        fr = os.path.join(video_path, name)
        image = cv2.imread(fr)
        # cv2.imread reports failure by returning None instead of raising;
        # surface it here rather than as an opaque cvtColor error.
        if image is None:
            raise ValueError(f'Could not read image file: {fr}')
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        frames.append(image.resize((w, h)))
    return frames

# Load Model

In [None]:
# Build the E2FGVI generator on the GPU if available and load the checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = importlib.import_module('model.e2fgvi')
model = net.InpaintGenerator().to(device)
ckpt_path = 'E2FGVI-CVPR22.pth'
# Announce the checkpoint before loading, so a missing/corrupt file is easy
# to attribute from the cell output.
print(f'Loading model from: {ckpt_path}')
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
data = torch.load(ckpt_path, map_location=device)
model.load_state_dict(data)
# Switch to inference mode; as the cell's last expression the module is displayed.
model.eval()

# Load test data

In [None]:
%cd /content/E2FGVI/
!mkdir -p /content/E2FGVI/examples/schoolgirls

!ffmpeg -i examples/schoolgirls.mp4 examples/schoolgirls/%05d.png

In [None]:
# prepare dataset
video_path = 'examples/schoolgirls'
mask_path = 'examples/schoolgirls_mask'
print(f'Loading videos and masks from: {video_path}')
frames = read_frame_from_videos(video_path)
video_length = len(frames)
# Apply x * 2 - 1 to the frame tensor (maps [0, 1] to [-1, 1], assuming
# to_tensors yields [0, 1] values — confirm in core.utils); the inference
# cell undoes this with (x + 1) / 2.
imgs = to_tensors()(frames).unsqueeze(0) * 2 - 1
# Keep uint8 numpy copies of the frames for compositing the final output.
frames = [np.array(f).astype(np.uint8) for f in frames]

masks = read_mask(mask_path)
# The inference loop indexes masks by frame id, so every frame needs a mask;
# fail early with a clear message instead of an indexing error mid-inference.
if len(masks) < video_length:
    raise ValueError(
        f'{mask_path} has {len(masks)} masks but {video_path} has '
        f'{video_length} frames; every frame needs a mask.')
# Per-frame (h, w, 1) arrays of 0/1, used to composite predictions with the input.
binary_masks = [np.expand_dims((np.array(m) != 0).astype(np.uint8), 2)
                for m in masks]
masks = to_tensors()(masks).unsqueeze(0)
imgs, masks = imgs.to(device), masks.to(device)
# Filled by the inference cell: one composited frame per frame index.
comp_frames = [None] * video_length

# Flow-Guided Video Inpainting

In [None]:
# completing holes by e2fgvi
# Slide over the video in steps of neighbor_stride. Each step feeds the model
# the local window around frame f plus the sparse reference frames from
# get_ref_index, then writes predictions back for the local frames only.
print(f'Start test...')
for f in tqdm(range(0, video_length, neighbor_stride)):
    # Local window: frames within neighbor_stride of f, clipped to the video.
    neighbor_ids = [i for i in range(max(0, f-neighbor_stride), min(video_length, f+neighbor_stride+1))]
    ref_ids = get_ref_index(f, neighbor_ids, video_length)
    # Batch of one clip: local frames first, then reference frames; the model
    # is told how many leading frames are local via len(neighbor_ids).
    selected_imgs = imgs[:1, neighbor_ids+ref_ids, :, :, :]
    selected_masks = masks[:1, neighbor_ids+ref_ids, :, :, :]
    with torch.no_grad():
        # Blank out the hole region (assumes mask values are 0/1 — confirm).
        masked_imgs = selected_imgs*(1-selected_masks)
        pred_img, _ = model(masked_imgs, len(neighbor_ids))

        # Undo the x * 2 - 1 input scaling, then move channels last and scale
        # to 0..255 (presumably (N, C, H, W) -> (N, H, W, C)).
        pred_img = (pred_img + 1) / 2
        pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
        # assumes pred_img holds one prediction per local frame, in
        # neighbor_ids order — TODO confirm against model.e2fgvi
        for i in range(len(neighbor_ids)):
            idx = neighbor_ids[i]
            # Composite: predicted pixels inside the mask, original pixels elsewhere.
            img = np.array(pred_img[i]).astype(
                np.uint8)*binary_masks[idx] + frames[idx] * (1-binary_masks[idx])
            # Frames covered by several overlapping windows are blended 50/50
            # with the value accumulated so far (a running blend that weights
            # later windows more); the result becomes float32.
            if comp_frames[idx] is None:
                comp_frames[idx] = img
            else:
                comp_frames[idx] = comp_frames[idx].astype(
                    np.float32)*0.5 + img.astype(np.float32)*0.5

# 推論結果を画像に出力

In [None]:
%cd /content/E2FGVI
!mkdir results

import matplotlib.pyplot as plt

# 推論結果出力
for i, frame in enumerate(comp_frames):
  plt.imsave('results/frames_%06d.png'%(i), frame.astype(np.uint8))

# 画像を動画に変換して表示

In [None]:
from moviepy.editor import *
from moviepy.video.fx.resize import resize

frames_path = "results/frames_%06d.png"
result_video = "results/result.mp4"

!ffmpeg -i {frames_path} -c:v libx264 -vf "fps=25,format=yuv420p" {result_video}

clip = VideoFileClip(result_video)
resize_clip = resize(clip, height=400)
resize_clip.ipython_display()

# 入力動画

In [None]:
frames_path = "examples/schoolgirls/%05d.png"
result_video = "results/original.mp4"

!ffmpeg -i {frames_path} -c:v libx264 -vf "fps=25,format=yuv420p" {result_video}

clip = VideoFileClip(result_video)
resize_clip = resize(clip, height=400)
resize_clip.ipython_display()