In [4]:

import os, sys
import cv2
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from glob import glob
import time
import datetime
import imageio

# Enable cuDNN's benchmark mode for the whole kernel session (process-wide PyTorch setting).
torch.backends.cudnn.benchmark = True

### Run the following landmark- and mask-generation cells in the `smart_class` environment

In [1]:
import Gen_Landmark

# Folder holding the input face images, and the folder handed to the
# landmark generator as its output location.
img_folder_dir = "/data/share/FLAME/facescape_fit/images"
out_dir = '/data/share/FLAME/facescape_fit/landmarks'

# Build the 2D landmark generator and run it over the image folder.
tt = Gen_Landmark.Gen2DLandmarks()
tt.main_process(img_dir=img_folder_dir, out_dir=out_dir)

Generate facial landmarks:  84%|████████▍ | 88/105 [00:17<00:00, 21.26it/s]



Generate facial landmarks: 100%|██████████| 105/105 [00:18<00:00,  5.71it/s]


In [2]:
# NOTE(review): the star import hides where `gen_face_mask` comes from; it is
# kept as-is because other cells may rely on names it brings in.
from Gen_facial_mask import *

# Same image folder as the landmark step; masks go to their own folder.
img_folder_dir = "/data/share/FLAME/facescape_fit/images"
mask_out_dir = "/data/share/FLAME/facescape_fit/masks"

# Run the face-mask generator over the image folder.
gen_face_mask(img_folder_dir=img_folder_dir, out_dir=mask_out_dir)

CUDA is available. Device:  NVIDIA RTX A6000


100%|██████████| 105/105 [00:14<00:00,  7.26it/s]


### From this point on, run the cells in the `flame` environment

In [9]:
import sys
sys.path.append('/data/share/FLAME/Photometric_fitting/photometric_optimization')
import cv2
import os
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from glob import glob
from tqdm import tqdm
import datetime
import imageio
# Switch the working directory to the photometric_optimization project
# (presumably so relative paths used inside that project resolve — confirm).
os.chdir("/data/share/FLAME/Photometric_fitting/photometric_optimization")
# Put the project's models/ directory on the import path so FLAME can be
# imported as a top-level module on the next line.
sys.path.append('/data/share/FLAME/Photometric_fitting/photometric_optimization/models')
from FLAME import FLAME, FLAMETex
from renderer import Renderer
import util
# Enable cuDNN's benchmark mode for the whole kernel session.
torch.backends.cudnn.benchmark = True

### Remember to update the hard-coded directories in the class and config below if the data or model locations change.

In [2]:

class PhotometricFitting(object):
    """Fit FLAME parameters to one face image from landmarks and a photometric loss.

    The fit runs in two stages (see ``optimize``): a rigid stage that only
    updates pose and camera from the landmarks, then a non-rigid stage that
    updates shape, expression, pose, camera, texture and lighting using the
    landmarks, the masked photometric error and regularisation terms.

    All settings are read from ``self.config`` (never from a notebook-level
    global), so an instance behaves the same wherever it is created.
    """

    # Template mesh used by the renderer unless ``config.mesh_file`` overrides it.
    DEFAULT_MESH_FILE = '/data/share/FLAME/Photometric_fitting/photometric_optimization/data/head_template_mesh.obj'
    # Iterations [0, RIGID_ITERS) are the rigid stage; [RIGID_ITERS, TOTAL_ITERS) the non-rigid stage.
    RIGID_ITERS = 200
    TOTAL_ITERS = 1000
    # Losses are printed and a visualisation is written every LOG_EVERY iterations.
    LOG_EVERY = 10

    def __init__(self, config, device='cuda'):
        """Build the FLAME geometry/texture models and the renderer.

        Args:
            config: object exposing the settings as attributes (model paths,
                parameter counts, learning rate, loss weights, ``savefolder`` ...).
            device: torch device string the models and tensors are moved to.
        """
        self.batch_size = config.batch_size
        self.image_size = config.image_size
        self.config = config
        self.device = device
        #
        self.flame = FLAME(self.config).to(self.device)
        self.flametex = FLAMETex(self.config).to(self.device)

        self._setup_renderer()

    def _setup_renderer(self):
        """Create the renderer from the template mesh (``config.mesh_file`` if set)."""
        mesh_file = getattr(self.config, 'mesh_file', self.DEFAULT_MESH_FILE)
        self.render = Renderer(self.image_size, obj_filename=mesh_file).to(self.device)

    @staticmethod
    def _image_stem(imagepath):
        """Return the file name of ``imagepath`` without directory and extension."""
        return os.path.splitext(os.path.basename(imagepath))[0]

    @staticmethod
    def _sum_losses(losses):
        """Add up every value of the ``losses`` dict and return the total."""
        total = 0.
        for value in losses.values():
            total = total + value
        return total

    @staticmethod
    def _project(points, cam):
        """Orthographically project ``points`` with ``cam`` and negate every coordinate after the first."""
        projected = util.batch_orth_proj(points, cam)
        projected[..., 1:] = - projected[..., 1:]
        return projected

    @staticmethod
    def _log_losses(k, losses):
        """Print the iteration index, a timestamp and every loss value."""
        loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
        for key in losses.keys():
            loss_info = loss_info + '{}: {}, '.format(key, float(losses[key]))
        print(loss_info)

    @staticmethod
    def _landmark_grids(images, landmarks, landmarks2d, landmarks3d, visind):
        """Build the image grids shared by both stages (input, GT / 2D / 3D landmarks)."""
        grids = {}
        grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu()
        grids['landmarks_gt'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks[visind]))
        grids['landmarks2d'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks2d[visind]))
        grids['landmarks3d'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks3d[visind]))
        return grids

    @staticmethod
    def _save_grid(grids, savefolder, k):
        """Stack the grids vertically and write them to ``<savefolder>/<k>.jpg``."""
        grid = torch.cat(list(grids.values()), 1)
        # CHW RGB float in [0, 1] -> HWC BGR uint8, the layout cv2.imwrite expects.
        grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]]
        grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8)
        cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image)

    def optimize(self, images, landmarks, image_masks, savefolder=None):
        """Run the two-stage fit and return the optimised parameters.

        Args:
            images: batch of target images; the first dimension is the batch size.
            landmarks: ground-truth landmarks, indexed as ``[batch, landmark, xy]``.
            image_masks: per-pixel weights multiplied into the photometric error.
            savefolder: directory receiving ``<iteration>.jpg`` visualisations;
                when ``None`` no visualisation is rendered or written.

        Returns:
            dict of numpy arrays with keys ``shape``, ``exp``, ``pose``, ``cam``,
            ``verts``, ``albedos``, ``tex`` and ``lit``, taken from the last iteration.
        """
        cfg = self.config
        bz = images.shape[0]
        shape = nn.Parameter(torch.zeros(bz, cfg.shape_params).float().to(self.device))
        tex = nn.Parameter(torch.zeros(bz, cfg.tex_params).float().to(self.device))
        exp = nn.Parameter(torch.zeros(bz, cfg.expression_params).float().to(self.device))
        pose = nn.Parameter(torch.zeros(bz, cfg.pose_params).float().to(self.device))
        # Camera starts at zero except its first component, which starts at 5.
        cam = torch.zeros(bz, cfg.camera_params); cam[:, 0] = 5.
        cam = nn.Parameter(cam.float().to(self.device))
        lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device))
        e_opt = torch.optim.Adam(
            [shape, exp, pose, cam, tex, lights],
            lr=cfg.e_lr,
            weight_decay=cfg.e_wd
        )
        e_opt_rigid = torch.optim.Adam(
            [pose, cam],
            lr=cfg.e_lr,
            weight_decay=cfg.e_wd
        )

        gt_landmark = landmarks
        visind = range(bz)  # [0]

        # rigid fitting of pose and camera with 51 static face landmarks,
        # this is due to the non-differentiable attribute of contour landmarks trajectory
        for k in range(self.RIGID_ITERS):
            # Vertices are not needed in this stage: only landmarks enter the loss.
            _, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose)
            landmarks2d = self._project(landmarks2d, cam)
            landmarks3d = self._project(landmarks3d, cam)

            losses = {}
            # Landmarks 0-16 are skipped here; only indices 17+ drive the rigid stage.
            losses['landmark'] = util.l2_distance(landmarks2d[:, 17:, :2], gt_landmark[:, 17:, :2]) * cfg.w_lmks

            all_loss = self._sum_losses(losses)
            losses['all_loss'] = all_loss
            e_opt_rigid.zero_grad()
            all_loss.backward()
            e_opt_rigid.step()

            # Converting losses to float synchronises with the device, so only do it when printing.
            if k % self.LOG_EVERY == 0:
                self._log_losses(k, losses)
                if savefolder is not None:
                    grids = self._landmark_grids(images, landmarks, landmarks2d, landmarks3d, visind)
                    self._save_grid(grids, savefolder, k)

        # non-rigid fitting of all the parameters with 68 face landmarks, photometric loss and regularization terms.
        for k in range(self.RIGID_ITERS, self.TOTAL_ITERS):
            vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = self._project(vertices, cam)
            landmarks2d = self._project(landmarks2d, cam)
            landmarks3d = self._project(landmarks3d, cam)

            losses = {}
            losses['landmark'] = util.l2_distance(landmarks2d[:, :, :2], gt_landmark[:, :, :2]) * cfg.w_lmks
            losses['shape_reg'] = (torch.sum(shape ** 2) / 2) * cfg.w_shape_reg  # *1e-4
            losses['expression_reg'] = (torch.sum(exp ** 2) / 2) * cfg.w_expr_reg  # *1e-4
            losses['pose_reg'] = (torch.sum(pose ** 2) / 2) * cfg.w_pose_reg

            ## render
            albedos = self.flametex(tex) / 255.
            ops = self.render(vertices, trans_vertices, albedos, lights)
            predicted_images = ops['images']
            # Mean absolute colour error, weighted per pixel by the mask.
            losses['photometric_texture'] = (image_masks * (predicted_images - images).abs()).mean() * cfg.w_pho

            all_loss = self._sum_losses(losses)
            losses['all_loss'] = all_loss
            e_opt.zero_grad()
            all_loss.backward()
            e_opt.step()

            if k % self.LOG_EVERY == 0:
                self._log_losses(k, losses)

                # visualize
                if savefolder is not None:
                    grids = self._landmark_grids(images, landmarks, landmarks2d, landmarks3d, visind)
                    grids['albedoimage'] = torchvision.utils.make_grid(
                        (ops['albedo_images'])[visind].detach().cpu())
                    grids['render'] = torchvision.utils.make_grid(predicted_images[visind].detach().float().cpu())
                    shape_images = self.render.render_shape(vertices, trans_vertices, images)
                    grids['shape'] = torchvision.utils.make_grid(
                        F.interpolate(shape_images[visind], [224, 224])).detach().float().cpu()
                    self._save_grid(grids, savefolder, k)

        single_params = {
            'shape': shape.detach().cpu().numpy(),
            'exp': exp.detach().cpu().numpy(),
            'pose': pose.detach().cpu().numpy(),
            'cam': cam.detach().cpu().numpy(),
            'verts': trans_vertices.detach().cpu().numpy(),
            'albedos': albedos.detach().cpu().numpy(),
            'tex': tex.detach().cpu().numpy(),
            'lit': lights.detach().cpu().numpy()
        }
        return single_params

    def run(self, imagepath, landmarkpath, image_mask_folder):
        """Fit a single image and write the results under ``config.savefolder``.

        Writes ``<savefolder>/<name>/<name>.npy`` (the parameter dict) and
        ``<savefolder>/<name>/<name>.obj`` (the textured mesh), where ``<name>``
        is the image file name without its extension.

        Args:
            imagepath: path of the image to fit.
            landmarkpath: path of the ``.npy`` file holding that image's landmarks.
            image_mask_folder: folder containing ``<name>.npy`` masks.

        Raises:
            FileNotFoundError: if the image cannot be read by OpenCV.
        """
        # The implementation is potentially able to optimize with images(batch_size>1),
        # here we show the example with a single image fitting
        cfg = self.config
        image_name = self._image_stem(imagepath)
        savefolder = os.path.join(cfg.savefolder, image_name)
        savefile = os.path.join(savefolder, image_name + '.npy')

        # photometric optimization is sensitive to the hair or glass occlusions,
        # therefore we use a face segmentation network to mask the skin region out.
        image_mask_path = os.path.join(image_mask_folder, image_name + '.npy')

        raw_image = cv2.imread(imagepath)
        if raw_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('could not read image: {}'.format(imagepath))
        image = cv2.resize(raw_image, (cfg.cropped_size, cfg.cropped_size)).astype(np.float32) / 255.
        # BGR HWC -> RGB CHW
        image = image[:, :, [2, 1, 0]].transpose(2, 0, 1)
        images = torch.from_numpy(image[None, :, :, :]).to(self.device)

        # NOTE(review): allow_pickle=True can execute code from a crafted file;
        # only load masks produced by a trusted pipeline.
        image_mask = np.load(image_mask_path, allow_pickle=True)
        image_mask = image_mask[..., None].astype('float32')
        image_mask = image_mask.transpose(2, 0, 1)
        # Binarise: every non-zero mask value counts as foreground.
        image_mask_bn = (image_mask != 0).astype(np.float32)
        image_masks = torch.from_numpy(image_mask_bn[None, :, :, :]).to(self.device)

        landmark = np.load(landmarkpath).astype(np.float32)
        # Map pixel coordinates to [-1, 1] using the resized image's width / height.
        # assumes the stored landmarks are in the cropped_size pixel frame — TODO confirm
        landmark[:, 0] = landmark[:, 0] / float(image.shape[2]) * 2 - 1
        landmark[:, 1] = landmark[:, 1] / float(image.shape[1]) * 2 - 1
        landmarks = torch.from_numpy(landmark)[None, :, :].float().to(self.device)

        images = F.interpolate(images, [self.image_size, self.image_size])
        image_masks = F.interpolate(image_masks, [self.image_size, self.image_size])

        util.check_mkdir(savefolder)
        # optimize
        single_params = self.optimize(images, landmarks, image_masks, savefolder)
        self.render.save_obj(filename=os.path.join(savefolder, image_name + '.obj'),
                             vertices=torch.from_numpy(single_params['verts'][0]).to(self.device),
                             textures=torch.from_numpy(single_params['albedos'][0]).to(self.device)
                             )
        np.save(savefile, single_params)
# All settings consumed by PhotometricFitting, converted straight into an
# attribute-style object via util.dict2obj.
config = util.dict2obj(dict(
    # FLAME
    flame_model_path='/data/share/FLAME/Photometric_fitting/photometric_optimization/data/generic_model.pkl',  # acquire it from FLAME project page
    flame_lmk_embedding_path='/data/share/FLAME/Photometric_fitting/photometric_optimization/data/landmark_embedding.npy',
    tex_space_path='/data/share/FLAME/Photometric_fitting/photometric_optimization/data/FLAME_texture.npz',  # acquire it from FLAME project page
    camera_params=3,
    shape_params=100,
    expression_params=50,
    pose_params=6,
    tex_params=50,
    use_face_contour=True,

    # image sizes, batch size and optimiser settings
    cropped_size=256,
    batch_size=1,
    image_size=224,
    e_lr=0.005,
    e_wd=0.0001,
    savefolder='/data/share/FLAME/facescape_fit/fit_results',
    # weights of losses and reg terms
    w_pho=8,
    w_lmks=1,
    w_shape_reg=1e-4,
    w_expr_reg=1e-4,
    w_pose_reg=0,
))

# Make sure the results folder exists before any fit writes into it.
util.check_mkdir(config.savefolder)
device_name = "cuda"
config.batch_size = 1
fitting = PhotometricFitting(config, device=device_name)

creating the FLAME Decoder


  self.register_buffer('dynamic_lmk_faces_idx', torch.tensor(lmk_embeddings['dynamic_lmk_faces_idx'], dtype=torch.long))
  self.register_buffer('dynamic_lmk_bary_coords', torch.tensor(lmk_embeddings['dynamic_lmk_bary_coords'], dtype=self.dtype))


In [11]:
input_folder = '/data/share/FLAME/facescape_fit/images'
landmark_folder = '/data/share/FLAME/facescape_fit/landmarks'
image_mask_folder = '/data/share/FLAME/facescape_fit/masks'

# Fit every file found in input_folder, in a deterministic (sorted) order.
# The landmark file is looked up by the image's name without its extension.
for file_name in tqdm(sorted(os.listdir(input_folder))):
    image_name = os.path.splitext(file_name)[0]
    imagepath = os.path.join(input_folder, file_name)
    landmarkpath = os.path.join(landmark_folder, image_name + '.npy')
    try:
        fitting.run(imagepath, landmarkpath, image_mask_folder)
    except Exception as err:
        # Best effort: report the failure (with its cause) and move on to the
        # next image. Catching Exception rather than using a bare except lets
        # KeyboardInterrupt still stop the loop.
        print("error in fitting {}: {!r}".format(image_name, err))
        continue

----iter: 0, time: 2023-07-03-16:22:54
landmark: 4.442091941833496, all_loss: 4.442091941833496, 
----iter: 10, time: 2023-07-03-16:22:54
landmark: 4.025257110595703, all_loss: 4.025257110595703, 
----iter: 20, time: 2023-07-03-16:22:54
landmark: 3.6009929180145264, all_loss: 3.6009929180145264, 
----iter: 30, time: 2023-07-03-16:22:54
landmark: 3.1705591678619385, all_loss: 3.1705591678619385, 
----iter: 40, time: 2023-07-03-16:22:54
landmark: 2.737680673599243, all_loss: 2.737680673599243, 
----iter: 50, time: 2023-07-03-16:22:54
landmark: 2.3095719814300537, all_loss: 2.3095719814300537, 
----iter: 60, time: 2023-07-03-16:22:54
landmark: 1.9003210067749023, all_loss: 1.9003210067749023, 
----iter: 70, time: 2023-07-03-16:22:54
landmark: 1.536821961402893, all_loss: 1.536821961402893, 
----iter: 80, time: 2023-07-03-16:22:54
landmark: 1.2615820169448853, all_loss: 1.2615820169448853, 
----iter: 90, time: 2023-07-03-16:22:55
landmark: 1.1077897548675537, all_loss: 1.1077897548675537, 