In [None]:
### Mount google drive if available
# Try to mount Google Drive. On success, project files are read from the
# Drive folder below; otherwise fall back to paths relative to the working
# directory (empty prefix).
try:
    from google.colab import drive
    drive.mount('/content/drive')
    drive_path = '/content/drive/MyDrive/term_paper/'
    in_colab = True
except ImportError:
    # google.colab is not importable: not running inside Colab.
    drive_path = ''
    in_colab = False
except Exception as mount_error:
    # Mounting failed for some other reason. Keep the original best-effort
    # fallback, but report it instead of hiding it. Catching Exception
    # (rather than a bare except) lets KeyboardInterrupt/SystemExit propagate.
    print(f"Google Drive mount failed ({mount_error!r}); using local paths")
    drive_path = ''
    in_colab = False

In [None]:
### Install all dependencies

# pytorch3d
import os
import sys
import torch

need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True

if need_pytorch3d:
    if torch.__version__.startswith("1.9") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{torch.__version__[0:5:2]}"
        ])
        !pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'


# smplx
need_smplx=False
try:
    import smplx
except ModuleNotFoundError:
    need_smplx=True

if need_smplx:
    !pip install smplx
    !git clone https://github.com/vchoutas/smplx
    %cd smplx
    !python setup.py install
    %cd ..


# bps
need_bps=False
try:
    import bps
except ModuleNotFoundError:
    need_bps=True

if need_bps:
    !pip install git+https://github.com/sergeyprokudin/bps


# chamferdist
if torch.cuda.is_available():
    need_chamferdist=False
    try:
        import chamferdist
    except ModuleNotFoundError:
        need_chamferdist=True
    
    if need_chamferdist:
        !pip install chamferdist


# cleanup
# Remove the CUB tarball and its extracted directory that the PyTorch3D
# source-install branch downloads. Runs unconditionally; `rm -f` ignores
# paths that do not exist.
!rm -rf 1.10.0.tar.gz cub-1.10.0/

In [None]:
# Re-execute pointcloud_fitting on every run of this cell so that edits to
# the module (presumably a project-local file) are picked up without
# restarting the kernel.
import importlib
import pointcloud_fitting
importlib.reload(pointcloud_fitting)

In [None]:
import smplx
from utils import plot_structure
from pointcloud_fitting import pointcloud_list, fit_pointclouds

In [None]:
### Setup
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
### Construct and plot pointclouds for chosen subject

subject = 1

# select poses
poses = ['00000001', '00000017', '00000033', '00000041', '00000089']

# select all poses
# Set to True to replace the hand-picked list above with every pose
# directory found under the subject's body folder.
use_all_poses = False
if use_all_poses:
    poses_path = 'subject_%d/body/' % subject
    poses = sorted(
        pose for pose in os.listdir(poses_path)
        if os.path.isdir(os.path.join(poses_path, pose))
    )

# construct pointclouds
# Loading is expensive, so the result is cached across re-runs of this cell.
# The cache is keyed on (subject, poses): editing either one triggers a
# reload instead of silently reusing point clouds from the old selection.
pointclouds_key = (subject, tuple(poses))
try:
    pointclouds_are_current = humbi_pointclouds_key == pointclouds_key
except NameError:
    # First run in this kernel: nothing cached yet.
    pointclouds_are_current = False

if not pointclouds_are_current:
    print("Load pointclouds")
    humbi_pointclouds = pointcloud_list(subject, poses, device)
    humbi_pointclouds_key = pointclouds_key

# plotting
# Only the first five point clouds are shown to keep the figure small.
plot_structure(humbi_pointclouds[:5])

In [None]:
# The SMPL-X model files are expected in a 'smplx' folder below drive_path
# (or below the working directory when drive_path is empty).
smplx_model_path = drive_path + 'smplx'
# Pass the gender by keyword: SMPLXLayer forwards positional arguments to the
# SMPL-X constructor, whose second positional parameter is not the gender in
# the published signature (NOTE(review): confirm against the installed
# smplx version).
smplx_model = smplx.SMPLXLayer(smplx_model_path, gender='neutral').to(device)

In [None]:
# fit pointclouds
# Run the fitting routine on the loaded point clouds. The raw return value
# is kept in fitting_output and then unpacked into its five components.
fitting_output = fit_pointclouds(
    smplx_model,
    subject,
    poses,
    humbi_pointclouds,
    global_iterations=400,
    shape_iterations=300,
)
(
    displacements,
    smplx_visualizations,
    disps_visualizations,
    smplx_losses,
    disps_losses,
) = fitting_output

In [None]:
# Plot the first entry of the displacement visualizations returned by the
# fitting step (presumably the one belonging to the first pose; confirm
# against fit_pointclouds).
plot_structure(disps_visualizations[0])

In [None]:
# Collect the per-pose losses into one tensor each. A single torch.stack
# replaces the repeated torch.cat in a loop, which re-copied the growing
# tensor on every iteration. zip keeps the original pairing (and truncation
# to the shorter sequence, should the lengths ever differ).
loss_pairs = list(zip(smplx_losses, disps_losses))
smplx_losses_tensor = torch.stack([smplx_loss for smplx_loss, _ in loss_pairs]).to(device)
disps_losses_tensor = torch.stack([disps_loss for _, disps_loss in loss_pairs]).to(device)

# Acceptance window for displacement losses: median +/- one standard
# deviation. Note that with a single loss value the sample standard deviation
# is NaN, which makes both bounds NaN.
disps_loss_median = disps_losses_tensor.median().item()
disps_loss_std = disps_losses_tensor.std().item()
min_disps_loss = disps_loss_median - disps_loss_std
max_disps_loss = disps_loss_median + disps_loss_std

min_disps_loss, max_disps_loss

In [None]:
# Keep only the displacements whose loss lies strictly inside the
# (min_disps_loss, max_disps_loss) window.
selected_displacements = [
    disp
    for disp, loss in zip(displacements, disps_losses)
    if min_disps_loss < loss < max_disps_loss
]

# If nothing passes the filter (e.g. the bounds are NaN or equal), averaging
# an empty tensor would silently yield NaN. Fall back to all displacements
# and say so.
if not selected_displacements:
    print("Warning: no displacement passed the loss filter; averaging all displacements")
    selected_displacements = list(displacements)

# One stack call instead of repeated torch.cat inside the loop.
displacements_tensor = torch.stack(selected_displacements).to(device)

# Element-wise mean over the selected displacements.
avg_disp = torch.mean(displacements_tensor, dim=0)

In [None]:
from pytorch3d.structures import Meshes

# Face indices as an int32 tensor with a leading batch dimension. Converting
# the integer array directly avoids the previous detour through a float32
# tensor, which cannot represent every large integer exactly.
smplx_faces = torch.from_numpy(smplx_model.faces.astype('int32')).unsqueeze(0).to(device)
# Call the module itself rather than .forward() (the usual nn.Module idiom);
# with no arguments the model is evaluated with its default inputs.
verts = smplx_model()['vertices'].to(device)
init_mesh = Meshes(verts, smplx_faces)

# Move every vertex along its normal by the averaged displacement.
# assumes avg_disp holds one scalar per vertex, so unsqueeze(1) broadcasts it
# over the normal's components - TODO confirm against fit_pointclouds.
displaced_verts = verts + (init_mesh.verts_normals_packed() * avg_disp.unsqueeze(1)).unsqueeze(0)
displaced_mesh = Meshes(displaced_verts , smplx_faces)

In [None]:
# Plot the undisplaced template mesh together with the mesh displaced by the
# averaged per-vertex offsets.
plot_structure([init_mesh, displaced_mesh])