added FacesetRelighter:

Synthesize new faces from existing ones by relighting them with the DeepPortraitRelighting network.
With the relighted faces, the neural network will better reproduce face shadows.

You can therefore synthesize shadowed faces from a fully lit faceset:
https://i.imgur.com/wxcmQoi.jpg

As a result, better fakes on dark faces:
https://i.imgur.com/5xXIbz5.jpg

In the OpenCL build, the Relighter runs on the CPU.

Install PyTorch directly via pip install; see the requirements.
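Example usage of the new facesettool subcommands (the aligned-faces path below is a placeholder, not part of this commit):

    python main.py facesettool relight --input-dir workspace/data_src/aligned --random-one
    python main.py facesettool delete_relighted --input-dir workspace/data_src/aligned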
iperov committed Nov 11, 2019
1 parent b9c0815 commit fe58459f369719cce0fff8cf564fce16f725fae9
34 main.py
@@ -63,7 +63,7 @@ def process_dev_extract_vggface2_dataset(arguments):
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
p.set_defaults (func=process_dev_extract_vggface2_dataset)

def process_dev_extract_umd_csv(arguments):
    os_utils.set_process_lowest_prio()
    from mainscripts import Extractor
@@ -78,8 +78,8 @@ def process_dev_extract_umd_csv(arguments):
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
p.set_defaults (func=process_dev_extract_umd_csv)


def process_dev_apply_celebamaskhq(arguments):
    os_utils.set_process_lowest_prio()
    from mainscripts import dev_misc
@@ -130,10 +130,10 @@ def process_util(arguments):

    #if arguments.remove_fanseg:
    #    Util.remove_fanseg_folder (input_path=arguments.input_dir)

    if arguments.remove_ie_polys:
        Util.remove_ie_polys_folder (input_path=arguments.input_dir)

p = subparsers.add_parser( "util", help="Utilities.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
@@ -190,7 +190,7 @@ def process_convert(arguments):
    Converter.main (args, device_args)

p = subparsers.add_parser( "convert", help="Converter")
p.add_argument('--training-data-src-dir', action=fixPathAction, dest="training_data_src_dir", help="(optional, may be required by some models) Dir of extracted SRC faceset.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the converted files will be stored.")
p.add_argument('--aligned-dir', action=fixPathAction, dest="aligned_dir", help="Aligned directory. This is where the extracted dst faces are stored.")
@@ -270,9 +270,29 @@ def process_labelingtool_edit_mask(arguments):
p.add_argument('--confirmed-dir', required=True, action=fixPathAction, dest="confirmed_dir", help="This is where the labeled faces will be stored.")
p.add_argument('--skipped-dir', required=True, action=fixPathAction, dest="skipped_dir", help="This is where the skipped faces will be stored.")
p.add_argument('--no-default-mask', action="store_true", dest="no_default_mask", default=False, help="Don't use default mask.")

p.set_defaults(func=process_labelingtool_edit_mask)

def process_relight_faceset(arguments):
    from mainscripts import FacesetRelighter
    FacesetRelighter.relight (arguments.input_dir, arguments.lighten, arguments.random_one)

def process_delete_relighted(arguments):
    from mainscripts import FacesetRelighter
    FacesetRelighter.delete_relighted (arguments.input_dir)

facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers()

p = facesettool_parser.add_parser ("relight", help="Synthesize new faces from existing ones by relighting them. With the relighted faces neural network will better reproduce face shadows.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
p.add_argument('--lighten', action="store_true", dest="lighten", default=None, help="Lighten the faces.")
p.add_argument('--random-one', action="store_true", dest="random_one", default=None, help="Relight the faces with only one random direction; otherwise relight with all directions.")
p.set_defaults(func=process_relight_faceset)

p = facesettool_parser.add_parser ("delete_relighted", help="Delete relighted faces.")
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
p.set_defaults(func=process_delete_relighted)

def bad_args(arguments):
    parser.print_help()
    exit(0)
81 mainscripts/FacesetRelighter.py
@@ -0,0 +1,81 @@
import traceback
from pathlib import Path

import cv2  # cv2 constants are referenced below; imported explicitly rather than relying on the star import

from interact import interact as io
from nnlib import DeepPortraitRelighting
from utils import Path_utils
from utils.cv2_utils import *
from utils.DFLJPG import DFLJPG
from utils.DFLPNG import DFLPNG


def relight(input_dir, lighten=None, random_one=None):
    if lighten is None:
        lighten = io.input_bool ("Lighten the faces? ( y/n default:n ) : ", False)

    if random_one is None:
        random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ) : ", True)

    input_path = Path(input_dir)

    image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)]

    dpr = DeepPortraitRelighting()

    for filepath in io.progress_bar_generator(image_paths, "Relighting"):
        try:
            if filepath.suffix == '.png':
                dflimg = DFLPNG.load( str(filepath) )
            elif filepath.suffix == '.jpg':
                dflimg = DFLJPG.load ( str(filepath) )
            else:
                dflimg = None

            if dflimg is None:
                io.log_err ("%s is not a dfl image file" % (filepath.name) )
                continue

            if dflimg.get_relighted():
                io.log_info (f"Skipping already relighted face [{filepath.name}]")
                continue

            img = cv2_imread (str(filepath))

            if random_one:
                relighted_imgs = dpr.relight_random(img, lighten=lighten)
            else:
                relighted_imgs = dpr.relight_all(img, lighten=lighten)

            for i, relighted_img in enumerate(relighted_imgs):
                im_flags = []
                if filepath.suffix == '.jpg':
                    im_flags += [int(cv2.IMWRITE_JPEG_QUALITY), 100]

                relighted_filename = filepath.parent / (filepath.stem + f'_relighted_{i}' + filepath.suffix)

                # im_flags carries the JPEG quality setting through to the encoder
                cv2_imwrite (str(relighted_filename), relighted_img, im_flags)
                dflimg.embed_and_set (str(relighted_filename), source_filename="_", relighted=True)
        except:
            io.log_err (f"Exception occurred while processing file {filepath.name}. Error: {traceback.format_exc()}")

def delete_relighted(input_dir):
    input_path = Path(input_dir)
    image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)]

    files_to_delete = []
    for filepath in io.progress_bar_generator(image_paths, "Loading"):
        if filepath.suffix == '.png':
            dflimg = DFLPNG.load( str(filepath) )
        elif filepath.suffix == '.jpg':
            dflimg = DFLJPG.load ( str(filepath) )
        else:
            dflimg = None

        if dflimg is None:
            io.log_err ("%s is not a dfl image file" % (filepath.name) )
            continue

        if dflimg.get_relighted():
            files_to_delete += [filepath]

    for file in io.progress_bar_generator(files_to_delete, "Deleting"):
        file.unlink()
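The two entry points above can also be driven programmatically; a minimal sketch, assuming the same placeholder directory of aligned faces:

    from mainscripts import FacesetRelighter

    # synthesize one randomly lit copy of each aligned face
    FacesetRelighter.relight('workspace/data_src/aligned', lighten=False, random_one=True)

    # remove every face previously tagged with relighted=True
    FacesetRelighter.delete_relighted('workspace/data_src/aligned')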
223 nnlib/DeepPortraitRelighting.py
@@ -0,0 +1,223 @@
from pathlib import Path
import numpy as np
import cv2

class DeepPortraitRelighting(object):

    def __init__(self):
        from nnlib import nnlib
        nnlib.import_torch()

        self.torch = nnlib.torch
        self.torch_device = nnlib.torch_device

        self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device)

        # 7 precomputed spherical-harmonics (SH) lighting coefficient sets,
        # 9 coefficients each, one set per preset light direction
        self.shs = [
            [1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01],
            [1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01],
            [1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02],
            [1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01],
            [1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01],
            [1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02],
            [1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02]
        ]

    # n selects one of the len(self.shs) preset light directions
    def relight(self, img, n, lighten=False):
        torch = self.torch

        sh = (np.array (self.shs[np.clip(n, 0, len(self.shs)-1)]).reshape( (1,9,1,1) ) * 0.7).astype(np.float32)
        sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device))

        row, col, _ = img.shape
        img = cv2.resize(img, (512, 512))
        Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

        inputL = Lab[:,:,0]
        outputImg, outputSH = self.model(torch.autograd.Variable(torch.from_numpy(inputL[None,None,...].astype(np.float32)/255.0).to(self.torch_device)),
                                         sh, 0)

        outputImg = outputImg[0].cpu().data.numpy()
        outputImg = outputImg.transpose((1,2,0))
        outputImg = np.squeeze(outputImg)
        outputImg = np.clip (outputImg, 0.0, 1.0)
        outputImg = cv2.blur(outputImg, (3,3))

        if not lighten:
            # treat the network output as a shading ratio applied to the original L channel
            outputImg = inputL * outputImg
        else:
            outputImg = outputImg * 255.0
        outputImg = np.clip(outputImg, 0, 255).astype(np.uint8)

        Lab[:,:,0] = outputImg
        result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR)
        result = cv2.resize(result, (col, row))
        return result

    def relight_all(self, img, lighten=False):
        return [ self.relight(img, n, lighten=lighten) for n in range( len(self.shs) ) ]

    def relight_random(self, img, lighten=False):
        return [ self.relight(img, np.random.randint(len(self.shs)), lighten=lighten) ]

    @staticmethod
    def build_model(torch, torch_device):
        nn = torch.nn
        F = torch.nn.functional

        def conv3X3(in_planes, out_planes, stride=1):
            return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

        # define the network
        class BasicBlock(nn.Module):
            def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None):
                super(BasicBlock, self).__init__()
                # batchNorm_type 0 means batch normalization,
                #                1 means instance normalization
                self.inplanes = inplanes
                self.outplanes = outplanes
                self.conv1 = conv3X3(inplanes, outplanes, 1)
                self.conv2 = conv3X3(outplanes, outplanes, 1)
                if batchNorm_type == 0:
                    self.bn1 = nn.BatchNorm2d(outplanes)
                    self.bn2 = nn.BatchNorm2d(outplanes)
                else:
                    self.bn1 = nn.InstanceNorm2d(outplanes)
                    self.bn2 = nn.InstanceNorm2d(outplanes)

                self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False)

            def forward(self, x):
                out = self.conv1(x)
                out = self.bn1(out)
                out = F.relu(out)
                out = self.conv2(out)
                out = self.bn2(out)

                if self.inplanes != self.outplanes:
                    out += self.shortcuts(x)
                else:
                    out += x

                out = F.relu(out)
                return out

        class HourglassBlock(nn.Module):
            def __init__(self, inplane, mid_plane, middleNet, skipLayer=True):
                super(HourglassBlock, self).__init__()
                # upper branch
                self.skipLayer = True
                self.upper = BasicBlock(inplane, inplane, batchNorm_type=1)

                # lower branch
                self.downSample = nn.MaxPool2d(kernel_size=2, stride=2)
                self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
                self.low1 = BasicBlock(inplane, mid_plane)
                self.middle = middleNet
                self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1)

            def forward(self, x, light, count, skip_count):
                # count indicates which layer we are in;
                # skip_count indicates from which layer on the skip connections are used
                out_upper = self.upper(x)
                out_lower = self.downSample(x)
                out_lower = self.low1(out_lower)
                out_lower, out_middle = self.middle(out_lower, light, count+1, skip_count)
                out_lower = self.low2(out_lower)
                out_lower = self.upSample(out_lower)
                if count >= skip_count and self.skipLayer:
                    out = out_lower + out_upper
                else:
                    out = out_lower
                return out, out_middle

        class lightingNet(nn.Module):
            def __init__(self, ncInput, ncOutput, ncMiddle):
                super(lightingNet, self).__init__()
                self.ncInput = ncInput
                self.ncOutput = ncOutput
                self.ncMiddle = ncMiddle
                self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
                self.predict_relu1 = nn.PReLU()
                self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False)

                self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
                self.post_relu1 = nn.PReLU()
                self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False)
                self.post_relu2 = nn.ReLU()  # to be consistent with the original features

            def forward(self, innerFeat, target_light, count, skip_count):
                x = innerFeat[:, 0:self.ncInput, :, :]  # lighting feature
                _, _, row, col = x.shape
                # predict lighting from the innermost features
                feat = x.mean(dim=(2,3), keepdim=True)
                light = self.predict_relu1(self.predict_FC1(feat))
                light = self.predict_FC2(light)
                # replace the lighting feature with the target lighting
                upFeat = self.post_relu1(self.post_FC1(target_light))
                upFeat = self.post_relu2(self.post_FC2(upFeat))
                upFeat = upFeat.repeat((1, 1, row, col))
                innerFeat[:, 0:self.ncInput, :, :] = upFeat
                return innerFeat, light


        class HourglassNet(nn.Module):
            def __init__(self, baseFilter=16, gray=True):
                super(HourglassNet, self).__init__()

                self.ncLight = 27  # number of channels for input to lighting network
                self.baseFilter = baseFilter

                # number of channels for output of lighting network
                if gray:
                    self.ncOutLight = 9   # gray: channel is 1
                else:
                    self.ncOutLight = 27  # color: channel is 3

                self.ncPre = self.baseFilter  # number of channels for pre-convolution

                # number of channels
                self.ncHG3 = self.baseFilter
                self.ncHG2 = 2*self.baseFilter
                self.ncHG1 = 4*self.baseFilter
                self.ncHG0 = 8*self.baseFilter + self.ncLight

                self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2)
                self.pre_bn = nn.BatchNorm2d(self.ncPre)

                self.light = lightingNet(self.ncLight, self.ncOutLight, 128)
                self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light)
                self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0)
                self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1)
                self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2)

                self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1)
                self.bn_1 = nn.BatchNorm2d(self.ncPre)
                self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
                self.bn_2 = nn.BatchNorm2d(self.ncPre)
                self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
                self.bn_3 = nn.BatchNorm2d(self.ncPre)

                self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0)

            def forward(self, x, target_light, skip_count):
                feat = self.pre_conv(x)
                feat = F.relu(self.pre_bn(feat))
                # get the innermost features
                feat, out_light = self.HG3(feat, target_light, 0, skip_count)

                feat = F.relu(self.bn_1(self.conv_1(feat)))
                feat = F.relu(self.bn_2(self.conv_2(feat)))
                feat = F.relu(self.bn_3(self.conv_3(feat)))
                out_img = self.output(feat)
                out_img = torch.sigmoid(out_img)
                return out_img, out_light

        model = HourglassNet()
        t_dict = torch.load( Path(__file__).parent / 'DeepPortraitRelighting.t7' )
        model.load_state_dict(t_dict)
        model.to( torch_device )
        model.train(False)
        return model
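For reference, a minimal sketch of using this class directly (file names are placeholders; the input is assumed to be an aligned BGR face image, as produced by the extractor):

    import cv2
    from nnlib import DeepPortraitRelighting

    dpr = DeepPortraitRelighting()
    img = cv2.imread('aligned_face.jpg')            # BGR uint8; resized to 512x512 internally
    for n, out in enumerate(dpr.relight_all(img)):  # one result per preset SH light direction
        cv2.imwrite(f'aligned_face_relighted_{n}.jpg', out)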
BIN nnlib/DeepPortraitRelighting.t7
Binary file not shown.
1 nnlib/__init__.py
@@ -1,4 +1,5 @@
from .nnlib import nnlib
from .FUNIT import FUNIT
from .TernausNet import TernausNet
from .VGGFace import VGGFace
from .DeepPortraitRelighting import DeepPortraitRelighting
