Skip to content

Commit

Permalink
Add art model of deoldify (PaddlePaddle#71)
Browse files Browse the repository at this point in the history
* add art model of deoldify
  • Loading branch information
LielinJiang committed Nov 4, 2020
1 parent f02d52a commit 71f3755
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 45 deletions.
5 changes: 5 additions & 0 deletions applications/tools/video-enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
default=360,
help='Length of minimum image edges')
# DeOldify args
parser.add_argument('--artistic',
action='store_true',
default=False,
help='whether to use artistic DeOldify Model')
parser.add_argument('--render_factor',
type=int,
default=32,
Expand Down Expand Up @@ -107,6 +111,7 @@
elif order == 'DeOldify':
predictor = DeOldifyPredictor(args.output,
weight_path=args.DeOldify_weight,
artistic=args.artistic,
render_factor=args.render_factor)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'RealSR':
Expand Down
1 change: 1 addition & 0 deletions docs/en_US/tutorials/video_restore.md
1 change: 1 addition & 0 deletions docs/zh_CN/apis/apps.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32
>
> > - output (str): 设置输出图片的保存路径,默认是output。注意,保存路径为设置output/DeOldify。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
> > - artistic (bool): 是否使用偏"艺术性"的模型。"艺术性"的模型有可能产生一些有趣的颜色,但是毛刺比较多。
> > - render_factor (int): 图片渲染上色时的缩放因子,图片会缩放到边长为16xrender_factor的正方形, 再上色,例如render_factor默认值为32,输入图片先缩放到(16x32=512) 512x512大小的图片。通常来说,render_factor越小,计算速度越快,颜色看起来也更鲜活。较旧和较低质量的图像通常会因降低渲染因子而受益。渲染因子越高,图像质量越好,但颜色可能会稍微褪色。
### run
Expand Down
1 change: 1 addition & 0 deletions docs/zh_CN/tutorials/video_restore.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32

- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `artistic (bool)`: 是否使用偏"艺术性"的模型。"艺术性"的模型有可能产生一些有趣的颜色,但是毛刺比较多。
- `render_factor (int)`: 会将该参数乘以16后作为输入帧的resize的值,如果该值设置为32,
则输入帧会resize到(32 * 16, 32 * 16)的尺寸再输入到网络中。

Expand Down
14 changes: 11 additions & 3 deletions ppgan/apps/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
import cv2
import numpy as np
from PIL import Image
import paddle

Expand Down Expand Up @@ -64,9 +65,16 @@ def base_forward(self, inputs):

def is_image(self, input):
try:
img = Image.open(input)
_ = img.size
return True
if isinstance(input, (np.ndarray, Image.Image)):
return True
elif isinstance(input, str):
if not os.path.isfile(input):
raise ValueError('input must be a file')
img = Image.open(input)
_ = img.size
return True
else:
return False
except:
return False

Expand Down
25 changes: 19 additions & 6 deletions ppgan/apps/deoldify_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,27 @@

from .base_predictor import BasePredictor

DEOLDIFY_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
DEOLDIFY_STABLE_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
DEOLDIFY_ART_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_art.pdparams'


class DeOldifyPredictor(BasePredictor):
def __init__(self, output='output', weight_path=None, render_factor=32):
# self.input = input
def __init__(self,
output='output',
weight_path=None,
artistic=False,
render_factor=32):
self.output = os.path.join(output, 'DeOldify')
if not os.path.exists(self.output):
os.makedirs(self.output)
self.render_factor = render_factor
self.model = build_model()
self.model = build_model(
model_type='artistic' if artistic else 'stable')
if weight_path is None:
weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL)
if artistic:
weight_path = get_path_from_url(DEOLDIFY_ART_WEIGHT_URL)
else:
weight_path = get_path_from_url(DEOLDIFY_STABLE_WEIGHT_URL)

state_dict = paddle.load(weight_path)
self.model.load_dict(state_dict)
Expand Down Expand Up @@ -134,7 +144,10 @@ def run(self, input):

out_path = None
if self.output:
base_name = os.path.splitext(os.path.basename(input))[0]
try:
base_name = os.path.splitext(os.path.basename(input))[0]
except:
base_name = 'result'
out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path)

Expand Down
5 changes: 4 additions & 1 deletion ppgan/apps/realsr_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ def run(self, input):

out_path = None
if self.output:
base_name = os.path.splitext(os.path.basename(input))[0]
try:
base_name = os.path.splitext(os.path.basename(input))[0]
except:
base_name = 'result'
out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path)

Expand Down
87 changes: 57 additions & 30 deletions ppgan/models/generators/deoldify.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import resnet101
from paddle.vision.models import resnet34, resnet101

from .hook import hook_outputs, model_sizes, dummy_eval
from ...modules.nn import Spectralnorm
Expand Down Expand Up @@ -57,6 +57,7 @@ class Deoldify(SequentialEx):
def __init__(self,
encoder,
n_classes,
model_type='stable',
blur=False,
blur_final=True,
self_attention=False,
Expand Down Expand Up @@ -95,18 +96,34 @@ def __init__(self,
do_blur = blur and (not_final or blur_final)
sa = self_attention and (i == len(sfs_idxs) - 3)

n_out = nf if not_final else nf // 2

unet_block = UnetBlockWide(up_in_c,
x_in_c,
n_out,
self.sfs[i],
final_div=not_final,
blur=blur,
self_attention=sa,
norm_type=norm_type,
extra_bn=extra_bn,
**kwargs)
if model_type == 'stable':
n_out = nf if not_final else nf // 2
unet_block = UnetBlockWide(up_in_c,
x_in_c,
n_out,
self.sfs[i],
final_div=not_final,
blur=blur,
self_attention=sa,
norm_type=norm_type,
extra_bn=extra_bn,
**kwargs)
elif model_type == 'artistic':
unet_block = UnetBlockDeep(up_in_c,
x_in_c,
self.sfs[i],
final_div=not_final,
blur=blur,
self_attention=sa,
norm_type=norm_type,
extra_bn=extra_bn,
nf_factor=nf_factor,
**kwargs)
else:
raise ValueError(
'Expected model_type in [stable, artistic], but got {}'.
format(model_type))

unet_block.eval()
layers.append(unet_block)
x = unet_block(x)
Expand Down Expand Up @@ -151,7 +168,7 @@ def custom_conv_layer(ni: int,
bn = norm_type in ('Batch', 'Batchzero') or extra_bn == True
if bias is None:
bias = not bn
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D

conv = conv_func(ni,
nf,
Expand Down Expand Up @@ -222,19 +239,18 @@ def forward(self, up_in):
class UnetBlockDeep(nn.Layer):
"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."

def __init__(
self,
up_in_c: int,
x_in_c: int,
# hook: Hook,
final_div: bool = True,
blur: bool = False,
leaky: float = None,
self_attention: bool = False,
nf_factor: float = 1.0,
**kwargs):
def __init__(self,
up_in_c: int,
x_in_c: int,
hook,
final_div: bool = True,
blur: bool = False,
leaky: float = None,
self_attention: bool = False,
nf_factor: float = 1.0,
**kwargs):
super().__init__()

self.hook = hook
self.shuf = CustomPixelShuffle_ICNR(up_in_c,
up_in_c // 2,
blur=blur,
Expand Down Expand Up @@ -312,7 +328,7 @@ def conv_layer(ni: int,
if padding is None: padding = (ks - 1) // 2 if not transpose else 0
bn = norm_type in ('Batch', 'BatchZero')
if bias is None: bias = not bn
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1D if is_1d else nn.Conv2D

conv = conv_func(ni,
nf,
Expand Down Expand Up @@ -472,16 +488,27 @@ def _get_sfs_idxs(sizes):
return sfs_idxs


def build_model():
backbone = resnet101()
def build_model(model_type='stable'):
if model_type == 'stable':
backbone = resnet101()
nf_factor = 2
elif model_type == 'artistic':
backbone = resnet34()
nf_factor = 1.5
else:
raise ValueError(
'Expected model_type in [stable, artistic], but got {}'.format(
model_type))

cut = -2
encoder = nn.Sequential(*list(backbone.children())[:cut])

model = Deoldify(encoder,
3,
model_type=model_type,
blur=True,
y_range=(-3, 3),
norm_type='Spectral',
self_attention=True,
nf_factor=2)
nf_factor=nf_factor)
return model
7 changes: 2 additions & 5 deletions ppgan/models/pix2pix_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,8 @@ def set_input(self, input):

AtoB = self.cfg.dataset.train.direction == 'AtoB'

# TODO: replace to_varialbe with to_tensor
self.real_A = paddle.fluid.dygraph.to_variable(
input['A' if AtoB else 'B'])
self.real_B = paddle.fluid.dygraph.to_variable(
input['B' if AtoB else 'A'])
self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])

self.image_paths = input['A_paths' if AtoB else 'B_paths']

Expand Down

0 comments on commit 71f3755

Please sign in to comment.