In [None]:
### Mount Google Drive if available (i.e. when running inside Colab)
# Outside Colab the `google.colab` import fails and paths fall back to being
# relative to the current working directory. Only the import failure is
# treated as "not in Colab"; a failing mount is a real error and is raised.
try:
    from google.colab import drive
except ImportError:
    drive_path = ''
    in_colab = False
else:
    drive.mount('/content/drive')
    drive_path = '/content/drive/MyDrive/term_paper/'
    in_colab = True

In [None]:
### Install all dependecies

# pytorch3d
import os
import sys
import torch

need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True

if need_pytorch3d:
    if torch.__version__.startswith("1.9") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{torch.__version__[0:5:2]}"
        ])
        !pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'


# smpl-x
need_smplx=False
try:
    import smplx
except ModuleNotFoundError:
    need_smplx=True

if need_smplx:
    !pip install smplx
    !git clone https://github.com/vchoutas/smplx
    %cd smplx
    !python setup.py install
    %cd ..


# bps
need_bps=False
try:
    import bps
except ModuleNotFoundError:
    need_bps=True

if need_bps:
    !pip install git+https://github.com/sergeyprokudin/bps


# cleanup
!rm -rf 1.10.0.tar.gz cub-1.10.0/

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate
from torchvision.io import read_image
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes

In [None]:
# Re-import the local helper modules so that edits to them are picked up
# without restarting the kernel.
import importlib
import utils.plot_structures
import utils.smpl_to_smplx

for _helper_module in (utils.plot_structures, utils.smpl_to_smplx):
    importlib.reload(_helper_module)

In [None]:
import smplx
from utils.plot_structures import plot_structure
from utils.smpl_to_smplx import humbi_smpl_mesh, construct_smplx_mesh, smpl2smplx

In [None]:
### Setup
# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Subject / pose identifiers passed to the HUMBI helpers below.
subject = 1
pose = '00000001'

smplx_model_path = drive_path + 'smplx'
# `gender` must be given by keyword: the second positional parameter of the
# smplx body-model constructors is not `gender`, so passing 'neutral'
# positionally binds it to an unrelated argument.
smplx_model = smplx.SMPLXLayer(smplx_model_path, gender='neutral').to(device)

In [None]:
# Fit SMPL-X parameters to the HUMBI SMPL mesh of (subject, pose).
(global_orient, transl, body_pose,
 betas, scale, pose_loss, shape_loss) = smpl2smplx(
    smplx_model,
    subject,
    pose,
    pose_iterations=200,
    shape_iterations=100,
)

In [None]:
# Plot the HUMBI SMPL mesh together with the mesh built from the fitted
# SMPL-X parameters.
smplx_mesh = construct_smplx_mesh(
    smplx_model, global_orient, transl, body_pose, betas, scale
)
smpl_mesh = humbi_smpl_mesh(subject, pose)
plot_structure([smpl_mesh, smplx_mesh])

In [None]:
def displacement_from_smplx_param(smplx_model, betas, scale):
    """Per-vertex normal displacement induced by the SMPL-X shape parameters.

    Builds the default mesh (model called without arguments) and the mesh
    with the given `betas`, both multiplied by `scale`. Each vertex offset is
    then projected onto the vertex normal of the default mesh.

    Args:
        smplx_model: SMPL-X layer; called with no arguments and with
            `betas=betas`, and must expose a `faces` array.
        betas: shape coefficients accepted by `smplx_model`.
        scale: uniform scale factor; a Python number or single-element tensor.

    Returns:
        1-D tensor with one signed displacement per vertex, divided by `scale`
        (i.e. expressed at unit scale). It is not tracked by autograd.
    """
    # Meshes expects int64 face indices; convert directly instead of going
    # through a float32 tensor.
    faces = torch.as_tensor(smplx_model.faces.astype('int64'), device=device).unsqueeze(0)
    scale_value = float(scale)

    # No gradients are needed. Without no_grad, a `betas` that requires grad
    # would leave an autograd graph on the result (and `.numpy()` would fail).
    with torch.no_grad():
        init_verts = smplx_model()['vertices'].to(device) * scale
        init_mesh = Meshes(init_verts, faces)

        displaced_verts = smplx_model(betas=betas)['vertices'].to(device) * scale
        displaced_mesh = Meshes(displaced_verts, faces)

        displacements = displaced_mesh.verts_packed() - init_mesh.verts_packed()
        # Signed length of each offset along the default mesh's vertex normal.
        displacements_along_nrm = (displacements * init_mesh.verts_normals_packed()).sum(dim=1)

    return displacements_along_nrm / scale_value

In [None]:
# Visual check: move the default SMPL-X vertices along their normals by the
# computed displacements and plot the resulting mesh.
shape_displacements = displacement_from_smplx_param(smplx_model, betas, scale)

smplx_faces = torch.Tensor(smplx_model.faces.astype('int')).type(torch.int32).unsqueeze(0).to(device)
init_verts = smplx_model.forward()['vertices'].to(device)
init_mesh = Meshes(init_verts, smplx_faces)

normal_offsets = init_mesh.verts_normals_packed() * shape_displacements.unsqueeze(1)
displaced_mesh = Meshes(init_verts + normal_offsets, smplx_faces)

plot_structure(displaced_mesh)

In [None]:
### Extract vertex uv pixel positions on a 2D square map
# See https://github.com/facebookresearch/pytorch3d/discussions/588

def verts_uvs_positions(smplx_uv_path:str, map_size:int=1024):
    """Pixel position of every mesh vertex in a square UV map.

    Args:
        smplx_uv_path: path to an .obj file with UV coordinates (read with
            `load_obj`, textures not loaded).
        map_size: side length of the square map in pixels.

    Returns:
        Float tensor of shape (num_verts, 2) on `device` with rounded (x, y)
        pixel coordinates. y is measured from the top of the map (uv v flipped).
    """
    verts, faces, aux = load_obj(smplx_uv_path, load_textures=False)
    nb_verts = verts.shape[0]

    # Index bookkeeping stays on the CPU, where load_obj returns its tensors by
    # default; a CUDA index into the CPU `verts_uvs` raises a device error.
    flatten_verts_idx = faces.verts_idx.flatten()
    flatten_textures_idx = faces.textures_idx.flatten()

    # Vertices on UV seams occur with several uv indices; this scatter keeps one
    # of them (which one is unspecified for duplicate indices). Vertices not used
    # by any face keep uv index 0.
    verts_to_uv_index = torch.zeros(nb_verts, dtype=torch.int64)
    verts_to_uv_index[flatten_verts_idx] = flatten_textures_idx
    verts_to_uvs = aux.verts_uvs[verts_to_uv_index]

    uv_x = float(map_size) * verts_to_uvs[:, 0]
    uv_y = float(map_size) * (1.0 - verts_to_uvs[:, 1])
    return torch.stack((uv_x, uv_y), dim=1).round().to(device)

In [None]:
### Create a displacement map: place each vertex value at its UV pixel and
### linearly interpolate (inpaint) between the vertex values

def inpainted_displacements(subject:int, displacements:torch.Tensor, smplx_uv_path:str, path_to_textures:str):
    """Rasterise per-vertex displacements into a texture-sized 2D map.

    Args:
        subject: subject id, used to read 'median_subject_<id>.png'.
        displacements: one displacement value per mesh vertex (any device).
        smplx_uv_path: .obj file holding the mesh UV layout.
        path_to_textures: directory prefix containing the median textures.

    Returns:
        (inpainted, valid_mask, texture):
        inpainted -- float tensor (H, W) on `device`. Pixels outside the convex
            hull of the vertex UVs, and pixels whose texture is pure black, are
            set to 255/2.
        valid_mask -- bool tensor (H, W), True where the texture is not black.
        texture -- the texture image moved to (H, W, C) layout.
    """
    texture = read_image(path_to_textures + 'median_subject_%d.png' % subject)
    texture = torch.moveaxis(texture, 0, 2)
    map_size = tuple(texture.shape[:2])

    # verts_uvs_positions yields (x, y); flip to (row, col) to match the pixel
    # grid below. NOTE(review): only map_size[0] is used, so a square texture
    # is assumed.
    verts_uvs = verts_uvs_positions(smplx_uv_path, map_size[0]).flip(1).cpu().numpy()

    # Pure-black texels are treated as lying outside the texture atlas.
    mask = (texture[:,:,0] == 0) & (texture[:,:,1] == 0) & (texture[:,:,2] == 0)
    # NOTE(review): negative displacements do not fit into uint8 (the cast is
    # not well defined for them). The 255/2 fill value suggests zero
    # displacement was meant to map to mid-grey; confirm the intended encoding.
    displacements_uint = (displacements * 255).round().type(torch.uint8)

    # scipy works on NumPy arrays, which need CPU tensors.
    interp = interpolate.LinearNDInterpolator(
        points=verts_uvs,
        values=displacements_uint.detach().cpu().numpy(),
        fill_value=255/2,
    )
    # All (row, col) pixel coordinates in row-major order (same order as
    # np.ndindex), built without a Python loop over H*W tuples.
    pixel_coords = np.indices(map_size).reshape(2, -1).T
    inpainted_map = interp(pixel_coords).reshape(map_size)
    inpainted_map[mask.numpy()] = 255/2

    return torch.Tensor(inpainted_map).to(device), ~mask, texture

In [None]:
### Test displacements inpainting
# NOTE(review): the .obj is read relative to drive_path, but the textures are
# read relative to the working directory; confirm this is intended.
texture_path = 'humbi_maps/humbi_body_texture/body_texture_medians/'
obj_path = drive_path + 'smplx/smplx_uv.obj'

inpainted, mask, texture = inpainted_displacements(
    subject,
    shape_displacements,
    smplx_uv_path=obj_path,
    path_to_textures=texture_path,
)

In [None]:
# `inpainted` is returned on `device`; matplotlib can only convert CPU tensors
# to arrays, so move it to the CPU before plotting.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(inpainted.cpu(), cmap='gray')
ax.set_title('Inpainted displacement map');

In [None]:
# White = texels where the median texture is not pure black.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(mask, cmap='gray')

In [None]:
# Median texture image of the subject, as loaded by inpainted_displacements.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(texture)

In [None]:
### Code displacement mapping to tensor (and check whether it is equal to the source tensor)