In [1]:
import os
import h5py
import tqdm
from matplotlib import pyplot
import numpy as np
from importlib import reload
import torch
import pickle
import cv2
from PIL import Image
from scipy.spatial.transform import Rotation
import copy
import itertools
import functools
import tqdm
import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from typing import Union, Tuple, Any, Dict, List
from copy import copy
import torchvision.transforms
import ipywidgets as widgets
from IPython.display import display, HTML
from collections import defaultdict
import multiprocessing
from torch.utils.data import Subset

In [2]:
from trackertraincode import utils
from trackertraincode import vis
from trackertraincode.neuralnets.affine2d import Affine2d
from trackertraincode.neuralnets.modelcomponents import DeformableHeadKeypoints, rigid_transformation_25d
from trackertraincode.datasets.dshdf5pose import Hdf5PoseDataset
import trackertraincode.datatransformation as dtr
import trackertraincode.datatransformation.affinetrafo as dtraffine
from trackertraincode.datasets.batch import Batch

from trackertraincode.neuralnets.gmm import GaussianMixture, unpickle_scipy_gmm
# https://github.com/rfeinman/pytorch-minimize
import torchmin

In [3]:
from trackertraincode.facemodel.keypoints68 import (
    chin_left, chin_right, 
    eyecorners_left, eyecorners_right, 
    eye_left_bottom, eye_right_bottom, 
    eye_left_top, eye_right_top, 
    brows_left, brows_right, 
    nose_left, nose_right,
    uppermouth_left, uppermouth_right,
    lowermouth_left, lowermouth_right)

# Landmark index groups covering one side of the face each. They are used
# further down to down-weight the landmarks on the side that is turned away
# from the camera. The operands are combined with `+` in the order listed.
# NOTE(review): presumably these are plain Python lists, so `+` concatenates — confirm.
face_left = (
    chin_left
    + eyecorners_left
    + eye_left_top
    + eye_left_bottom
    + uppermouth_left
    + lowermouth_left
    + brows_left
    + nose_left
)
face_right = (
    chin_right
    + eyecorners_right
    + eye_right_top
    + eye_right_bottom
    + uppermouth_right
    + lowermouth_right
    + brows_right
    + nose_right
)

In [4]:
# Interactive matplotlib backend; the ipywidgets sliders below redraw into a live figure.
# NOTE(review): the `notebook` backend belongs to the classic Notebook front-end;
# `%matplotlib widget` (ipympl) is the usual replacement elsewhere — confirm the target front-end.
%matplotlib notebook

In [5]:
# HDF5 file that is read for fitting and later re-opened in append mode to store the results.
# Raises KeyError if the DATADIR environment variable is not set.
filename = os.path.join(os.environ['DATADIR'],'wflw_train.h5')

In [6]:
# Preprocessing applied to every sample before fitting: FocusRoi followed by
# batch normalization. `insert_backtransform=True` is requested because the
# fitting code later reads smpl['image_backtransform'] to map results back to
# the original image frame.
# (Fixed: a stray ")" after `dtr.normalize_batch` made this cell a syntax error.)
tr = torchvision.transforms.Compose([
    dtr.FocusRoi(224, extent_factor=1.2, insert_backtransform=True),
    dtr.normalize_batch,
])
# Dataset restricted to the fields needed for fitting.
ds = Hdf5PoseDataset(filename, monochrome=False, transform=tr, whitelist=['/images','/rois','/pt2d_68','/pseudolabels/*'])

# Bare figure (no axis ticks) that the interactive preview below draws into.
fig, ax = pyplot.subplots(1,1)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
pyplot.tight_layout()

@widgets.interact(idx=(0,20))
def plot_image(idx):
    """Preview callback: draw dataset sample `idx` into the module-level `ax`.

    The sample is un-normalized and converted to numpy before being handed to
    `vis.draw_dataset_sample`. Note that a later cell defines another callback
    with the same name; whichever cell ran last owns the name `plot_image`.
    """
    sample_np = dtr.to_numpy(dtr.unnormalize_batch(ds[idx]))
    overlay = vis.draw_dataset_sample(sample_np)
    # Replace whatever was shown before with the freshly drawn sample.
    ax.clear()
    ax.imshow(overlay)
    pyplot.draw()

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=10, description='idx', max=20), Output()), _dom_classes=('widget-interac…

In [7]:
class PosedDeformableHead(torch.nn.Module):
    """Deformable head keypoints placed into the image by a rigid transform.

    Combines `DeformableHeadKeypoints` (shape parameters -> local keypoints)
    with `rigid_transformation_25d` (orientation + coordinates -> posed keypoints).
    """

    def __init__(self):
        super().__init__()
        self.deformablekeypoints = DeformableHeadKeypoints()

    def forward(self, quats, coords, params):
        """Return the posed keypoints.

        Args:
            quats: Orientation quaternion(s), passed through unchanged.
            coords: Tensor whose last axis is split into `[..., :2]` and
                `[..., 2:]` before being handed to `rigid_transformation_25d`.
            params: Shape parameters for the deformable keypoint model.

        Returns:
            Tensor with trailing shape (68, 3), enforced by the assertion below.
        """
        # Call the sub-module itself rather than `.forward` so registered
        # hooks run, as recommended for nn.Module.
        # NOTE(review): assumes DeformableHeadKeypoints is an nn.Module — confirm.
        local_keypts = self.deformablekeypoints(params)
        pt3d_68 = rigid_transformation_25d(
            quats,
            coords[...,:2],
            coords[...,2:],
            local_keypts)
        assert pt3d_68.shape[-1] == 3 and pt3d_68.shape[-2] == 68
        return pt3d_68


class DeformableHeadFitting:
    """Fits pose, position/size and shape parameters of the head model to landmarks.

    The optimized parameter vector is laid out along its last axis as
    ``[quaternion (4) | coord (3) | shape parameters (50)]``.
    """

    # Number of leading entries (quaternion + coord) and trailing shape entries.
    NUM_POSE_PARAMS = 7
    NUM_SHAPE_PARAMS = 50
    DEFAULT_SHAPEMODEL_PATH = '../trackertraincode/facemodel/shapeparams_gmm.pkl'

    def __init__(self, fit_3d_projections, shapemodel_path=DEFAULT_SHAPEMODEL_PATH):
        """
        Args:
            fit_3d_projections: If true, fit to the x/y components of
                smpl['pt3d_68'] with uniform weights; otherwise fit to
                smpl['pt2d_68'] with reduced weights for landmarks on the
                side of the face turned away.
            shapemodel_path: Pickled shape-parameter mixture model that is
                loaded with `unpickle_scipy_gmm`.
        """
        self.fit_3d_projections = fit_3d_projections
        self.headmodel = PosedDeformableHead()
        self.shapemodel = unpickle_scipy_gmm(shapemodel_path)

    def _extract_params(self, params):
        """Split a parameter vector into (unit quaternion, coords, shape parameters)."""
        quats = params[...,:4]
        # The optimizer is free to change the quaternion length; renormalize before use.
        quats = quats / torch.linalg.norm(quats,dim=-1,keepdim=True)
        coords = params[...,4:self.NUM_POSE_PARAMS]
        shapeparams = params[...,self.NUM_POSE_PARAMS:]
        return quats, coords, shapeparams

    def evalheadmodel(self, params):
        """Evaluate the posed head model for a full parameter vector."""
        quats, coords, shapeparams = self._extract_params(params)
        points = self.headmodel(quats, coords, shapeparams)
        return points

    def lossfunc(self, xvec, yvec, pointweights):
        """Objective minimized by `fit`.

        Args:
            xvec: Full parameter vector (see class docstring for the layout).
            yvec: Target x/y landmark positions.
            pointweights: One weight per landmark.

        Returns:
            Weighted smooth-L1 landmark error plus three regularizers: a weak
            pull of the quaternion norm towards 1, the negative score of the
            shape parameters under the shape model, and an exponential penalty
            that grows as the third coord component approaches or drops below zero.
        """
        points = self.evalheadmodel(xvec)
        pointwise = torch.nn.functional.smooth_l1_loss(points[...,:2], yvec, beta=0.1, reduction='none')
        errs = (pointweights[...,None]*pointwise).mean(dim=(-2,-1))
        # Reduce over the last axis only, consistent with the `...` indexing
        # used throughout; identical to the full norm for a 1-D xvec.
        norm_constraint = torch.square(1. - torch.linalg.norm(xvec[...,:4], dim=-1))
        shape_plausibility = -(1./150.)*self.shapemodel.score_samples(xvec[...,-self.NUM_SHAPE_PARAMS:])
        # Index 6 is the third coord component; `...` keeps this in line with
        # the other terms (same element as xvec[6] for a 1-D xvec).
        size_constraint = 10.*torch.exp(-xvec[...,6]/0.05)
        return errs + 1.e-6*norm_constraint + 0.01*shape_plausibility + size_constraint

    def _make_initial_guess(self, smpl):
        """Start from the sample's existing pose and coord with all-zero shape parameters."""
        xvec = torch.concat([smpl['pose'], smpl['coord'], torch.zeros((self.NUM_SHAPE_PARAMS,))])
        return xvec

    def _make_target_points_and_weights(self, smpl : Batch):
        """Return (target x/y landmarks, per-landmark weights) for one sample."""
        pointweights = torch.ones((68,))

        if self.fit_3d_projections:
            # Projected 3D landmarks: every point is used with full weight.
            return smpl['pt3d_68'][...,:2], pointweights

        # 2D annotations: the chin line gets a reduced base weight.
        pointweights[chin_left] *= 0.1
        pointweights[chin_right] *= 0.1

        # Cutoff angles, converted from degrees to radians.
        angle_cutoff_for_fitting_jaw = 20. * np.pi/180.
        angle_cutoff_for_fitting_face_side = 70. * np.pi/180.

        # NOTE(review): `h` is presumably the heading angle in radians — confirm against utils.as_hpb.
        h,_,_ = utils.as_hpb(Rotation.from_quat(smpl['pose']))
        # Linear falloff from 1 at h == 0 down to 0 at the respective cutoff.
        backside_weight_jaw = max(0., float(1. - np.abs(h)/angle_cutoff_for_fitting_jaw))
        backside_weight_face_side = max(0., float(1. - np.abs(h)/angle_cutoff_for_fitting_face_side))

        # Which side is overridden depends on the sign of h. The chin indices
        # are assigned last because face_left/face_right contain them too.
        if h < 0:
            pointweights[face_right] = backside_weight_face_side
            pointweights[chin_right] = backside_weight_jaw
        elif h > 0:
            pointweights[face_left] = backside_weight_face_side
            pointweights[chin_left] = backside_weight_jaw

        return smpl['pt2d_68'][...], pointweights

    def fit(self, smpl : Batch):
        """Fit the head model to one sample and return the result as a `Batch`.

        Runs two BFGS stages: the first optimizes only the 7 pose/coord
        parameters with the shape held at the initial guess, the second
        optimizes all parameters starting from that result. A message is
        printed if the second stage does not report success.

        Returns:
            Batch with numpy entries 'pt3d_68', 'pose', 'coord', 'shapeparam',
            'index' and 'image_backtransform'.
        """
        ytarget, pointweights = self._make_target_points_and_weights(smpl)

        x0 = self._make_initial_guess(smpl)
        n_pose = self.NUM_POSE_PARAMS

        def lossfunc_wo_shape(x):
            return self.lossfunc(torch.concat([x,x0[...,n_pose:]],dim=-1), ytarget, pointweights)

        def actual_lossfunc(x):
            return self.lossfunc(x, ytarget, pointweights)

        # Stage 1: pose and coord only.
        ret = torchmin.minimize(lossfunc_wo_shape, x0[...,:n_pose], method='bfgs', max_iter=50)

        # Stage 2: all parameters. x0 is rebound only after stage 1 finished,
        # so the closure above saw the original shape parameters throughout.
        x0 = torch.concat([ret['x'],x0[...,n_pose:]],dim=-1)
        ret = torchmin.minimize(actual_lossfunc, x0, method='bfgs', max_iter=100, tol=0.0005)

        if not ret['success']:
            print (f"fit failure at index {smpl['index']} because {ret['message']}")

        # Detach and evaluate without autograd so the numpy conversions below
        # cannot fail on tensors that still track gradients.
        bestfitparams = ret['x'].detach()
        with torch.no_grad():
            fitted_points = self.evalheadmodel(bestfitparams)
        quats, coords, shapeparams = self._extract_params(bestfitparams.cpu())
        pred = Batch(smpl.meta, {
            'pt3d_68' : fitted_points.cpu().numpy(),
            'pose' : quats.numpy(),
            'coord' : coords.numpy(),
            'shapeparam' : shapeparams.numpy(),
            'index' : smpl['index'].numpy(),
            'image_backtransform' : smpl['image_backtransform'].numpy(),
        })

        return pred

# fit_3d_projections=False: fit against the 2D annotations (smpl['pt2d_68'])
# with the side-dependent landmark weighting.
headmodel = DeformableHeadFitting(False)
# Smoke test on a single sample before running the fit over the whole dataset.
headmodel.fit(ds[2])

Batch()

In [8]:
# Fit every sample. For quick experiments, swap in a small subset instead:
#   fitds = Subset(ds, list(range(40)))
fitds = ds

# One fit result per sample, in dataset order, with a progress bar.
preds = list(map(headmodel.fit, tqdm.tqdm(fitds)))


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2629/2629 [02:47<00:00, 15.74it/s]


In [9]:
# Map each fit result through the affine transform stored under
# 'image_backtransform'. Shape parameters are carried over unchanged.
back_transformed = []
for pred in preds:
    pred = dtr.to_tensor(pred)
    backtrafo = Affine2d(pred['image_backtransform'])
    back_pred = {
        'pose': dtraffine.transform_rot(backtrafo, pred['pose']),
        'coord': dtraffine.transform_coord(backtrafo, pred['coord']),
        'pt3d_68': dtraffine.transform_points(backtrafo, pred['pt3d_68']),
        'shapeparam': pred['shapeparam'],
    }
    back_transformed.append(back_pred)

In [10]:
# Turn the list of per-sample dicts into one dict of stacked arrays
# (new leading axis = sample index), ready to be written to the HDF5 file.
zipped_preds = defaultdict(list)
for entry in back_transformed:
    for key, value in entry.items():
        zipped_preds[key].append(value)
# Iterate over a snapshot of the keys while replacing each list by its stacked array.
for key in list(zipped_preds):
    zipped_preds[key] = np.stack(zipped_preds[key], axis=0)

In [11]:
# Renderer used by the interactive viewer below to draw the fitted head over the image.
facerender = vis.FaceRender()

In [12]:
# Separate head model instance used by the viewer to recompute keypoints
# from the back-transformed pose, coord and shape parameters.
keypointmodel = PosedDeformableHead()

In [13]:
# Fresh figure without axis ticks for the fit viewer.
fig, ax = pyplot.subplots(1,1)
for hidden_axis in (ax.get_xaxis(), ax.get_yaxis()):
    hidden_axis.set_visible(False)
pyplot.tight_layout()

# Second view of the same file opened without the preprocessing transform,
# so the back-transformed fits are shown on the untransformed images.
dsplot = Hdf5PoseDataset(
    filename,
    monochrome=False,
    whitelist=['/images','/rois','/pt2d_68'],
)

@widgets.interact(idx=(0,len(fitds)-1))
def plot_image(idx):
    """Viewer callback: show the fit for sample `idx` next to a rendering of it.

    Draws into the module-level `ax`. The displayed image is the concatenation
    (along axis 1) of the dataset-sample drawing of the prediction and a
    rendering of the fitted head blended with the input image, on which the
    annotated and the model keypoints are drawn. This definition reuses the
    name of the earlier preview callback and replaces it.
    """
    smpl = copy(dsplot[idx])
    pred = copy(back_transformed[idx])
    # Take the annotated landmarks out of the sample; they are drawn separately below.
    annotated_pts = smpl['pt2d_68']
    del smpl['pt2d_68']

    with torch.no_grad():
        model_pts = keypointmodel(pred['pose'], pred['coord'], pred['shapeparam'])

    # NOTE(review): the key 'pt68_3d' differs from the 'pt3d_68' used everywhere
    # else in this notebook — confirm which key vis.draw_dataset_sample expects.
    pred['pt68_3d'] = model_pts
    pred['image'] = smpl['image']
    pred = dtr.to_numpy(pred)

    coord = pred['coord']
    rendering = np.ascontiguousarray(
        facerender.set(
            coord[:2],
            coord[2],
            Rotation.from_quat(pred['pose']),
            pred['shapeparam'],
            pred['image'].shape[:2],
        )
    )

    # Annotated landmarks first (with brightness=128), then the model keypoints
    # with the default brightness. Both calls draw into `rendering`.
    for points, extra_args in ((annotated_pts, {'brightness': 128}), (model_pts, {})):
        vis.draw_points3d(rendering, points.numpy(), labels=False, **extra_args)

    # 60/40 blend of the rendering and the input image.
    blended = (0.6*rendering + 0.4*pred['image']).astype(np.uint8)

    sample_view = vis.draw_dataset_sample(pred)
    combined = np.concatenate([sample_view, blended],axis=1)

    ax.clear()
    ax.imshow(combined)
    pyplot.draw()

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=1314, description='idx', max=2628), Output()), _dom_classes=('widget-int…

In [15]:
# Close the fitting dataset before the same file is re-opened in append mode further down.
ds.close()

In [16]:
# NOTE(review): `fitds` is assigned as an alias of `ds` above, so this closes the
# same object a second time. If the commented-out `Subset(...)` alternative is
# enabled instead, `fitds` is a torch Subset — confirm it offers `close()` at all.
fitds.close()

In [17]:
# Close the viewer's dataset as well; the interactive viewer reads from `dsplot`,
# so its slider should not be used after this point.
dsplot.close()

In [18]:
# Write the stacked, back-transformed fit results into the source HDF5 file
# as a new group "2dfit_v2.1".
from trackertraincode.datasets.dshdf5pose import create_pose_dataset, FieldCategory

# Short alias for the field categories passed to create_pose_dataset.
C = FieldCategory

# Mode 'a' opens the existing file for reading and writing.
with h5py.File(filename,'a') as f:
    # h5py's create_group raises if the group already exists, so re-running this
    # cell requires deleting the old group first. The line below is a manual
    # helper for that; note that it names '2dfit_v2', not the '2dfit_v2.1' created here.
    #del f['2dfit_v2']
    g = f.create_group("2dfit_v2.1")
    # One dataset per fitted quantity; the leading axis is the sample index.
    ds_quats = create_pose_dataset(g, C.quat, data=zipped_preds['pose'])
    ds_coords = create_pose_dataset(g, C.xys, data=zipped_preds['coord'])
    ds_pt3d_68 = create_pose_dataset(g, C.points, name='pt3d_68', data=zipped_preds['pt3d_68'])
    # Shape parameters are stored with dtype float16.
    ds_shapeparams = create_pose_dataset(g, C.general, name='shapeparams', dtype=np.float16, data=zipped_preds['shapeparam'])