Permalink
Please
sign in to comment.
Browse files
added FacesetRelighter:
Synthesize new faces from existing ones by relighting them using DeepPortraitRelighter network. With the relighted faces neural network will better reproduce face shadows. Therefore you can synthsize shadowed faces from fully lit faceset. https://i.imgur.com/wxcmQoi.jpg as a result, better fakes on dark faces: https://i.imgur.com/5xXIbz5.jpg in OpenCL build Relighter runs on CPU, install pytorch directly via pip install, look at requirements
- Loading branch information
Showing
with
402 additions
and 17 deletions.
- +27 −7 main.py
- +81 −0 mainscripts/FacesetRelighter.py
- +223 −0 nnlib/DeepPortraitRelighting.py
- BIN nnlib/DeepPortraitRelighting.t7
- +2 −1 nnlib/__init__.py
- +26 −3 nnlib/nnlib.py
- +7 −1 requirements-colab.txt
- +6 −0 requirements-cpu.txt
- +6 −0 requirements-cuda.txt
- +6 −0 requirements-opencl.txt
- +8 −3 utils/DFLJPG.py
- +10 −2 utils/DFLPNG.py
@@ -0,0 +1,81 @@ | ||
import traceback | ||
from pathlib import Path | ||
|
||
from interact import interact as io | ||
from nnlib import DeepPortraitRelighting | ||
from utils import Path_utils | ||
from utils.cv2_utils import * | ||
from utils.DFLJPG import DFLJPG | ||
from utils.DFLPNG import DFLPNG | ||
|
||
|
||
def relight(input_dir, lighten=None, random_one=None): | ||
if lighten is None: | ||
lighten = io.input_bool ("Lighten the faces? ( y/n default:n ) : ", False) | ||
|
||
if random_one is None: | ||
random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ) : ", True) | ||
|
||
input_path = Path(input_dir) | ||
|
||
image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)] | ||
|
||
dpr = DeepPortraitRelighting() | ||
|
||
for filepath in io.progress_bar_generator(image_paths, "Relighting"): | ||
try: | ||
if filepath.suffix == '.png': | ||
dflimg = DFLPNG.load( str(filepath) ) | ||
elif filepath.suffix == '.jpg': | ||
dflimg = DFLJPG.load ( str(filepath) ) | ||
else: | ||
dflimg = None | ||
|
||
if dflimg is None: | ||
io.log_err ("%s is not a dfl image file" % (filepath.name) ) | ||
continue | ||
else: | ||
if dflimg.get_relighted(): | ||
io.log_info (f"Skipping already relighted face [{filepath.name}]") | ||
continue | ||
img = cv2_imread (str(filepath)) | ||
|
||
if random_one: | ||
relighted_imgs = dpr.relight_random(img,lighten=lighten) | ||
else: | ||
relighted_imgs = dpr.relight_all(img,lighten=lighten) | ||
|
||
for i,relighted_img in enumerate(relighted_imgs): | ||
im_flags = [] | ||
if filepath.suffix == '.jpg': | ||
im_flags += [int(cv2.IMWRITE_JPEG_QUALITY), 100] | ||
|
||
relighted_filename = filepath.parent / (filepath.stem+f'_relighted_{i}'+filepath.suffix) | ||
|
||
cv2_imwrite (relighted_filename, relighted_img ) | ||
dflimg.embed_and_set (relighted_filename, source_filename="_", relighted=True ) | ||
except: | ||
io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}") | ||
|
||
def delete_relighted(input_dir): | ||
input_path = Path(input_dir) | ||
image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)] | ||
|
||
files_to_delete = [] | ||
for filepath in io.progress_bar_generator(image_paths, "Loading"): | ||
if filepath.suffix == '.png': | ||
dflimg = DFLPNG.load( str(filepath) ) | ||
elif filepath.suffix == '.jpg': | ||
dflimg = DFLJPG.load ( str(filepath) ) | ||
else: | ||
dflimg = None | ||
|
||
if dflimg is None: | ||
io.log_err ("%s is not a dfl image file" % (filepath.name) ) | ||
continue | ||
else: | ||
if dflimg.get_relighted(): | ||
files_to_delete += [filepath] | ||
|
||
for file in io.progress_bar_generator(files_to_delete, "Deleting"): | ||
file.unlink() |
@@ -0,0 +1,223 @@ | ||
from pathlib import Path | ||
import numpy as np | ||
import cv2 | ||
|
||
class DeepPortraitRelighting(object): | ||
|
||
def __init__(self): | ||
from nnlib import nnlib | ||
nnlib.import_torch() | ||
|
||
self.torch = nnlib.torch | ||
self.torch_device = nnlib.torch_device | ||
|
||
self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device) | ||
|
||
self.shs = [ | ||
[1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01], | ||
[1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01], | ||
[1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02], | ||
[1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01], | ||
[1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01], | ||
[1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02], | ||
[1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02] | ||
] | ||
|
||
#n = [0..8] | ||
def relight(self, img, n, lighten=False): | ||
torch = self.torch | ||
|
||
sh = (np.array (self.shs[np.clip(n, 0,8)]).reshape( (1,9,1,1) )*0.7).astype(np.float32) | ||
sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device)) | ||
|
||
row, col, _ = img.shape | ||
img = cv2.resize(img, (512, 512)) | ||
Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) | ||
|
||
inputL = Lab[:,:,0] | ||
outputImg, outputSH = self.model(torch.autograd.Variable(torch.from_numpy(inputL[None,None,...].astype(np.float32)/255.0).to(self.torch_device)), | ||
sh, 0) | ||
|
||
outputImg = outputImg[0].cpu().data.numpy() | ||
outputImg = outputImg.transpose((1,2,0)) | ||
outputImg = np.squeeze(outputImg) | ||
outputImg = np.clip (outputImg, 0.0, 1.0) | ||
outputImg = cv2.blur(outputImg, (3,3) ) | ||
|
||
if not lighten: | ||
outputImg = inputL* outputImg | ||
else: | ||
outputImg = outputImg*255.0 | ||
outputImg = np.clip(outputImg, 0,255).astype(np.uint8) | ||
|
||
Lab[:,:,0] = outputImg | ||
result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) | ||
result = cv2.resize(result, (col, row)) | ||
return result | ||
|
||
def relight_all(self, img, lighten=False): | ||
return [ self.relight(img, n, lighten=lighten) for n in range( len(self.shs) ) ] | ||
|
||
def relight_random(self, img, lighten=False): | ||
return [ self.relight(img, np.random.randint(len(self.shs)), lighten=lighten ) ] | ||
|
||
@staticmethod | ||
def build_model(torch, torch_device): | ||
nn = torch.nn | ||
F = torch.nn.functional | ||
|
||
def conv3X3(in_planes, out_planes, stride=1): | ||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
|
||
# define the network | ||
class BasicBlock(nn.Module): | ||
def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None): | ||
super(BasicBlock, self).__init__() | ||
# batchNorm_type 0 means batchnormalization | ||
# 1 means instance normalization | ||
self.inplanes = inplanes | ||
self.outplanes = outplanes | ||
self.conv1 = conv3X3(inplanes, outplanes, 1) | ||
self.conv2 = conv3X3(outplanes, outplanes, 1) | ||
if batchNorm_type == 0: | ||
self.bn1 = nn.BatchNorm2d(outplanes) | ||
self.bn2 = nn.BatchNorm2d(outplanes) | ||
else: | ||
self.bn1 = nn.InstanceNorm2d(outplanes) | ||
self.bn2 = nn.InstanceNorm2d(outplanes) | ||
|
||
self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False) | ||
|
||
def forward(self, x): | ||
out = self.conv1(x) | ||
out = self.bn1(out) | ||
out = F.relu(out) | ||
out = self.conv2(out) | ||
out = self.bn2(out) | ||
|
||
if self.inplanes != self.outplanes: | ||
out += self.shortcuts(x) | ||
else: | ||
out += x | ||
|
||
out = F.relu(out) | ||
return out | ||
|
||
class HourglassBlock(nn.Module): | ||
def __init__(self, inplane, mid_plane, middleNet, skipLayer=True): | ||
super(HourglassBlock, self).__init__() | ||
# upper branch | ||
self.skipLayer = True | ||
self.upper = BasicBlock(inplane, inplane, batchNorm_type=1) | ||
|
||
# lower branch | ||
self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) | ||
self.upSample = nn.Upsample(scale_factor=2, mode='nearest') | ||
self.low1 = BasicBlock(inplane, mid_plane) | ||
self.middle = middleNet | ||
self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1) | ||
|
||
def forward(self, x, light, count, skip_count): | ||
# we use count to indicate wich layer we are in | ||
# max_count indicates the from which layer, we would use skip connections | ||
out_upper = self.upper(x) | ||
out_lower = self.downSample(x) | ||
out_lower = self.low1(out_lower) | ||
out_lower, out_middle = self.middle(out_lower, light, count+1, skip_count) | ||
out_lower = self.low2(out_lower) | ||
out_lower = self.upSample(out_lower) | ||
if count >= skip_count and self.skipLayer: | ||
out = out_lower + out_upper | ||
else: | ||
out = out_lower | ||
return out, out_middle | ||
|
||
class lightingNet(nn.Module): | ||
def __init__(self, ncInput, ncOutput, ncMiddle): | ||
super(lightingNet, self).__init__() | ||
self.ncInput = ncInput | ||
self.ncOutput = ncOutput | ||
self.ncMiddle = ncMiddle | ||
self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False) | ||
self.predict_relu1 = nn.PReLU() | ||
self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False) | ||
|
||
self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False) | ||
self.post_relu1 = nn.PReLU() | ||
self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False) | ||
self.post_relu2 = nn.ReLU() # to be consistance with the original feature | ||
|
||
def forward(self, innerFeat, target_light, count, skip_count): | ||
x = innerFeat[:,0:self.ncInput,:,:] # lighting feature | ||
_, _, row, col = x.shape | ||
# predict lighting | ||
feat = x.mean(dim=(2,3), keepdim=True) | ||
light = self.predict_relu1(self.predict_FC1(feat)) | ||
light = self.predict_FC2(light) | ||
upFeat = self.post_relu1(self.post_FC1(target_light)) | ||
upFeat = self.post_relu2(self.post_FC2(upFeat)) | ||
upFeat = upFeat.repeat((1,1,row, col)) | ||
innerFeat[:,0:self.ncInput,:,:] = upFeat | ||
return innerFeat, light#light | ||
|
||
|
||
class HourglassNet(nn.Module): | ||
def __init__(self, baseFilter = 16, gray=True): | ||
super(HourglassNet, self).__init__() | ||
|
||
self.ncLight = 27 # number of channels for input to lighting network | ||
self.baseFilter = baseFilter | ||
|
||
# number of channles for output of lighting network | ||
if gray: | ||
self.ncOutLight = 9 # gray: channel is 1 | ||
else: | ||
self.ncOutLight = 27 # color: channel is 3 | ||
|
||
self.ncPre = self.baseFilter # number of channels for pre-convolution | ||
|
||
# number of channels | ||
self.ncHG3 = self.baseFilter | ||
self.ncHG2 = 2*self.baseFilter | ||
self.ncHG1 = 4*self.baseFilter | ||
self.ncHG0 = 8*self.baseFilter + self.ncLight | ||
|
||
self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2) | ||
self.pre_bn = nn.BatchNorm2d(self.ncPre) | ||
|
||
self.light = lightingNet(self.ncLight, self.ncOutLight, 128) | ||
self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light) | ||
self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0) | ||
self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1) | ||
self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2) | ||
|
||
self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1) | ||
self.bn_1 = nn.BatchNorm2d(self.ncPre) | ||
self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) | ||
self.bn_2 = nn.BatchNorm2d(self.ncPre) | ||
self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) | ||
self.bn_3 = nn.BatchNorm2d(self.ncPre) | ||
|
||
self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0) | ||
|
||
def forward(self, x, target_light, skip_count): | ||
feat = self.pre_conv(x) | ||
|
||
feat = F.relu(self.pre_bn(feat)) | ||
# get the inner most features | ||
feat, out_light = self.HG3(feat, target_light, 0, skip_count) | ||
#return feat, out_light | ||
|
||
feat = F.relu(self.bn_1(self.conv_1(feat))) | ||
feat = F.relu(self.bn_2(self.conv_2(feat))) | ||
feat = F.relu(self.bn_3(self.conv_3(feat))) | ||
out_img = self.output(feat) | ||
out_img = torch.sigmoid(out_img) | ||
return out_img, out_light | ||
|
||
model = HourglassNet() | ||
t_dict = torch.load( Path(__file__).parent / 'DeepPortraitRelighting.t7' ) | ||
model.load_state_dict(t_dict) | ||
model.to( torch_device ) | ||
model.train(False) | ||
return model |
Binary file not shown.
@@ -1,4 +1,5 @@ | ||
from .nnlib import nnlib | ||
from .FUNIT import FUNIT | ||
from .TernausNet import TernausNet | ||
from .VGGFace import VGGFace | ||
from .VGGFace import VGGFace | ||
from .DeepPortraitRelighting import DeepPortraitRelighting |

Oops, something went wrong.
0 comments on commit
fe58459