# PIFuHD: Hands-on Practice


#### Reference
- Saito, Shunsuke, et al. "PIFuHD: Multi-Level Pixel-Aligned Implicit Function for High-Resolution 3D Human Digitization." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). 2020.
- https://github.com/facebookresearch/pifuhd

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
from IPython.display import Image
# Display data/pifuhd_overview.png; the last expression of a cell is what gets rendered.
Image('data/pifuhd_overview.png')

In [None]:
# Display data/lowres.png inline (uses the `Image` class imported in the first cell).
Image('data/lowres.png')

In [None]:
# Display data/highres.png inline (uses the `Image` class imported in the first cell).
Image('data/highres.png')

## Dataset

In [None]:
import matplotlib.pyplot as plt
import os

# Folder holding the input images; the same folder is later passed to the
# pipeline as --dataroot.
input_path = "./sample_images"
img_path = os.path.join(input_path, "test.png")

# Load the sample image and draw it, as a visual check of the input before reconstruction.
img = plt.imread(img_path)
plt.imshow(img)

## Overall Pipeline

In [None]:
import sys
import os
import torch
from tqdm import tqdm

from lib.data import EvalWPoseDataset
from lib.model import PIFuCoarse, PIFuFine
from lib.mesh_util import save_obj_mesh, reconstruction

def _load_checkpoint(opt, device):
    """Load the PIFuHD checkpoint and merge its stored options with ``opt``.

    The options object saved in the checkpoint (``state_dict["opt"]``) replaces
    ``opt``. The run-specific fields ``dataroot``, ``resolution``,
    ``results_path`` and ``loadSize`` are carried over from the caller's ``opt``.

    Args:
        opt: parsed options; ``opt.load_netMR_checkpoint_path`` names the file.
        device: ``torch.device`` the checkpoint tensors are mapped onto.

    Returns:
        ``(opt, state_dict)``: the merged options and the loaded checkpoint dict.

    Raises:
        FileNotFoundError: if the checkpoint path is unset or does not exist.
    """
    import inspect

    state_dict_path = opt.load_netMR_checkpoint_path
    if state_dict_path is None or not os.path.exists(state_dict_path):
        raise FileNotFoundError(f"failed loading state dict! {state_dict_path!r}")

    print("Resuming from ", state_dict_path)
    load_kwargs = {"map_location": device}
    # The checkpoint stores a non-tensor options object next to the weights.
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # such objects, so opt out explicitly wherever the argument exists (it
    # was introduced in torch 1.13). weights_only=False runs the full pickle
    # loader, so only load checkpoints from a trusted source.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    state_dict = torch.load(state_dict_path, **load_kwargs)

    print("Warning: opt is overwritten.")
    kept = {
        name: getattr(opt, name)
        for name in ("dataroot", "resolution", "results_path", "loadSize")
    }
    opt = state_dict["opt"]
    for name, value in kept.items():
        setattr(opt, name, value)
    return opt, state_dict


def run_pifuhd(opt):
    """Reconstruct and save a mesh for every sample of ``EvalWPoseDataset(opt)``.

    Runs on ``cuda:<opt.gpu_id>`` when CUDA is available, otherwise on the CPU.
    The checkpoint at ``opt.load_netMR_checkpoint_path`` supplies the network
    weights and the remaining options. Each mesh is written to
    ``<results_path>/<name>/recon/result_<sample name>_<resolution>.obj``.

    Args:
        opt: parsed options. Must provide ``gpu_id``,
            ``load_netMR_checkpoint_path``, ``dataroot``, ``resolution``,
            ``results_path`` and ``loadSize``.

    Raises:
        FileNotFoundError: if the checkpoint path is unset or does not exist.
    """
    cuda = torch.device("cuda:%d" % opt.gpu_id if torch.cuda.is_available() else "cpu")
    opt, state_dict = _load_checkpoint(opt, cuda)

    ## Prepare dataset
    test_dataset = EvalWPoseDataset(opt)
    print("test data size: ", len(test_dataset))
    projection_mode = test_dataset.projection_mode

    ## Prepare model: the coarse net uses its own stored options ("opt_netG").
    ## The fine net wraps the coarse net, and the checkpoint's
    ## "model_state_dict" is loaded into the fine net.
    net_coarse = PIFuCoarse(state_dict["opt_netG"], projection_mode).to(device=cuda)
    net_fine = PIFuFine(opt, net_coarse, projection_mode).to(device=cuda)
    net_fine.load_state_dict(state_dict["model_state_dict"])
    net_coarse.eval()
    net_fine.eval()

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    # makedirs creates missing parents, so this also creates opt.results_path.
    os.makedirs("%s/%s/recon" % (opt.results_path, opt.name), exist_ok=True)

    ## inference
    with torch.no_grad():
        print("generate mesh (test) ...")
        for i in tqdm(range(len(test_dataset))):
            test_data = test_dataset[i]
            save_path = "%s/%s/recon/result_%s_%d.obj" % (
                opt.results_path,
                opt.name,
                test_data["name"],
                opt.resolution,
            )
            ## prepare input
            image_tensor_global = test_data["img_512"].to(device=cuda)
            image_tensor = test_data["img"].to(device=cuda)
            calib_tensor = test_data["calib"].to(device=cuda)

            ## extract image features: global features from "img_512", local
            ## features from "img". The [:, None] inserts a new axis 1;
            ## presumably the layout extract_feature_local expects (confirm).
            net_fine.extract_feature_global(image_tensor_global)
            net_fine.extract_feature_local(image_tensor[:, None])

            ## reconstruct with Pixel Aligned Implicit Functions
            verts, faces, _, _ = reconstruction(
                net_fine,
                cuda,
                calib_tensor,
                opt.resolution,
                thresh=0.5,
                num_samples=50000,
            )
            save_obj_mesh(save_path, verts, faces)
            # tqdm.write keeps the progress bar intact instead of interleaving with print.
            tqdm.write(f"saved the result at {save_path}")


## Run PIFuHD

### 1. Setup default arguments

In [None]:
# Default arguments for the PIFuHD run below.
# Trained weights, plus the value forwarded on the command line as --resolution.
ckpt_path = "./checkpoints/pifuhd.pt"
resolution = 512

# Input image folder (--dataroot) and output folder for meshes (--results_path).
input_path = "./sample_images"
out_path = "./results"

### 2. Run

In [None]:
from lib.options import BaseOptions

parser = BaseOptions()

# Command-line-style arguments for BaseOptions, built from the defaults set in
# an earlier cell. Implied continuation inside the brackets replaces the
# backslashes, one flag/value pair per line.
cmd = [
    '--dataroot', input_path,
    '--results_path', out_path,
    '--loadSize', '1024',
    '--resolution', str(resolution),
    '--load_netMR_checkpoint_path', ckpt_path,
]


def run_wrapper(args=None):
    """Parse ``args`` with the notebook's ``BaseOptions`` parser and run PIFuHD on them."""
    opt = parser.parse(args)
    run_pifuhd(opt)


run_wrapper(cmd)