論文<br>
https://arxiv.org/abs/2303.09875<br>
<br>
GitHub<br>
https://github.com/megvii-research/CVPR2023-DMVFN<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/DMVFN_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# setup environment

## git clone

In [None]:
# Work from the /content directory.
%cd /content

# Clone the DMVFN repository into ./DMVFN.
!git clone https://github.com/megvii-research/CVPR2023-DMVFN.git ./DMVFN

%cd /content/DMVFN
# Check out one fixed commit so the notebook runs against a known revision.
# Commits on Mar 24, 2023
!git checkout 2ffe0399ecb82e77ef5f386d0be75c8ce5bcef2f

## install libraries

In [None]:
%cd /content/DMVFN

# Install the packages listed in the repository's requirements.txt.
!pip install -r requirements.txt
# Upgrade gdown, which later cells use to download files from Google Drive.
!pip install --upgrade gdown

## import libraries

In [None]:
%cd /content/DMVFN

import os
import gdown
import random
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import glob

import torch
# Use the GPU when PyTorch reports that CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = "cpu"
print("using device is", device)

from utils.util import *
from model.model import Model

# download pretrained models

In [None]:
%cd /content/DMVFN

# Download each pretrained checkpoint from Google Drive unless the file is
# already present locally.
os.makedirs('./pretrained_models', exist_ok=True)

# (local destination, Google Drive URL) pairs, fetched in this order.
checkpoint_sources = (
  ("./pretrained_models/dmvfn_city.pkl", 'https://drive.google.com/uc?id=1jILbS8Gm4E5Xx4tDCPZh_7rId0eo8r9W'),
  ("./pretrained_models/dmvfn_kitti.pkl", 'https://drive.google.com/uc?id=1WrV30prRiS4hWOQBnVPUxdaTlp9XxmVK'),
  ("./pretrained_models/dmvfn_vimeo.pkl", 'https://drive.google.com/uc?id=14_xQ3Yl3mO89hr28hbcQW3h63lLrcYY0'),
)
for checkpoint_path, drive_url in checkpoint_sources:
  if os.path.exists(checkpoint_path):
    continue  # already downloaded
  gdown.download(drive_url, checkpoint_path, quiet=False)


# download test dataset

In [None]:
%cd /content/DMVFN

# Make sure the directory that will hold the test data exists.
os.makedirs('./data/cityscapes/test', exist_ok=True)

# Download and unpack the archive only when the zip file is not already present.
# NOTE(review): the check is on the zip file, not on the extracted contents, so
# if the zip exists but extraction did not finish, this cell will not unzip again.
if not os.path.exists("./data/cityscapes/test/test.zip"):
  gdown.download('https://drive.google.com/uc?id=10zCt-uZFOqgF3tpdhluRqbs-4aScvGR4&confirm=t', "./data/cityscapes/test/test.zip", quiet=False)
  # Extract inside the test directory so the archive contents land next to the zip.
  %cd /content/DMVFN/data/cityscapes/test
  !unzip -q ./test.zip

# Return to the repository root for the following cells.
%cd /content/DMVFN

# Inference

## select model

In [None]:
%cd /content/DMVFN

# Which pretrained checkpoint to use; the trailing #@param annotation lists the
# selectable values.
pretrained_weights = 'city' #@param ['city', 'kitti', 'vimeo']

# Path of the matching checkpoint file under ./pretrained_models.
model_path = f'./pretrained_models/dmvfn_{pretrained_weights}.pkl'

## set seed

In [None]:
# Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) with
# the same value.
seed = 12
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
  seed_rng(seed)

# Enable cuDNN benchmark mode.
# NOTE(review): PyTorch's reproducibility notes suggest benchmark = False when
# deterministic runs are the goal — confirm that True is intended here.
torch.backends.cudnn.benchmark = True

## Inference

In [None]:
# input image dir
image_dir = "./data/cityscapes/test/000000"
# Sorted so that frames are used in file-name order.
image_list = sorted(glob.glob(os.path.join(image_dir, "*.png")))
# The prediction loop starts from image_list[0] and image_list[1]; fail here
# with a clear message instead of an IndexError later on.
if len(image_list) < 2:
  raise FileNotFoundError(
      f"need at least 2 PNG frames in {image_dir!r}, found {len(image_list)}")

# output dir
out_dir = "./output"
# Create the directory named by out_dir (not a separate hard-coded path), so
# changing out_dir keeps the created directory and the write target in sync.
os.makedirs(out_dir, exist_ok=True)

# load model
model = Model(load_path=model_path, training=False)

In [None]:
# Number of frames to predict. The result cell displays pred_num + 2 original
# frames, hence: max->len(image_list) - 2
pred_num = 3


def _read_frame(path):
  """Read an image with OpenCV, raising instead of returning None on failure."""
  frame = cv2.imread(path)
  if frame is None:
    raise FileNotFoundError(f"cv2.imread could not read {path!r}")
  return frame


def _to_chw_tensor(frame):
  """Convert an HWC image array into a float32 CHW tensor."""
  return torch.tensor(frame.transpose(2, 0, 1).astype('float32'))


# Inference only, so gradient tracking is disabled.
with torch.no_grad():
  for i in range(pred_num):
    if i == 0:
      # Start from the first two frames on disk.
      cvimg_0 = _read_frame(image_list[0])
      cvimg_1 = _read_frame(image_list[1])
    else:
      # Slide the window forward: the previous prediction becomes the newest
      # input frame.
      cvimg_0 = cvimg_1
      cvimg_1 = pred
    # preprocess image: stack both frames along the channel axis
    img = torch.cat([_to_chw_tensor(cvimg_0), _to_chw_tensor(cvimg_1)], dim=0)
    # Add two leading singleton axes -> (1, 1, 2*C, H, W), move to the device
    # once, and divide by 255.
    img = img.unsqueeze(0).unsqueeze(0).to(device, non_blocking=True) / 255.
    # inference
    pred = model.eval(img, 'single_test') # 1CHW per the original note
    # post process: scale back by 255 and reorder CHW -> HWC
    pred = np.array(pred.cpu().squeeze() * 255).transpose(1, 2, 0)
    # save result image; cv2.imwrite reports failure through its return value,
    # so check it instead of silently producing no file.
    out_path = os.path.join(out_dir, f'pred_{i:06}.png')
    if not cv2.imwrite(out_path, pred):
      raise OSError(f"cv2.imwrite failed to write {out_path!r}")

## show result

In [None]:
# Each row shows the two starting frames plus pred_num further frames.
show_num = pred_num + 2


def _show_image_row(paths):
  """Draw the given image files side by side in one figure, titled by file name."""
  fig = plt.figure(figsize=(15, 10))
  for column, path in enumerate(paths, start=1):
    ax = fig.add_subplot(1, show_num, column)
    plt.title(os.path.basename(path), fontsize=16)
    ax.axis('off')
    ax.imshow( Image.open(path) )
  plt.show()


# show original images
_show_image_row(image_list[k] for k in range(show_num))

# show predict images: the first two input frames followed by the saved
# predictions, in file-name order
pred_list = sorted(glob.glob(os.path.join(out_dir, '*.png')))
_show_image_row(
    image_list[k] if k < 2 else pred_list[k - 2] for k in range(show_num))