In [None]:
import warnings
import json

from omegaconf import DictConfig
from phalp.utils import get_pylogger
from PHALP_MOD import PHALP
from HMAR_MOD import HMAR

warnings.filterwarnings("ignore")
log = get_pylogger(__name__)

# load config
cfg_fp = "/mnt/arc/levlevi/nba-positions-videos-dataset/4d-pose-extraction/PHALP/scripts/config.json"
with open(cfg_fp, "r") as f:
    cfg = DictConfig(json.load(f))

# load PHALP obj
phalp_tracker = PHALP(cfg)

In [15]:
import torch
import numpy as np

from typing import List, Dict
from phalp.utils.utils import (
    smpl_to_pose_camera_vector,
)

BBX_FRAME_IDX_IDX = 0
BBX_BBX_IDX = 1


def get_measurements(img_frame: np.array):
    img_height, img_width, _ = img_frame.shape
    new_image_size = max(img_height, img_width)
    top, left = (
        (new_image_size - img_height) // 2,
        (new_image_size - img_width) // 2,
    )
    measurments = [img_height, img_width, new_image_size, left, top]
    return measurments


def get_segmentation_mask(img: np.array, bbx: List[float]) -> np.array:
    """
    Return a segmentation mask that is `True` where ever a bbx is contained within an image.

    Params
    : img: np.array     (H, W, 3)
    """

    height, width = img.shape[0], img.shape[1]
    assert width > height, f"failed a sanity check, height > width"

    x1, y1, x2, y2 = bbx
    all_false_img_arr = np.zeros((height, width))

    # set all values within bounding box region to 1.0
    all_false_img_arr[y1:y2, x1:x2] = 1.0

    # cast to bool type (t/f)
    segmentation_mask = all_false_img_arr.astype(bool)
    return segmentation_mask


def pre_process_bounding_boxes(
    phalp: PHALP, num_bbxs: int, image: np.array, bbxs: np.array
):

    masked_image_list = []
    center_list = []
    scale_list = []
    rles_list = []
    bounding_box_ids = []

    # pre-process batch of bounding boxes for an img
    for bbx_idx in range(num_bbxs):
        # TODO: verify that this indexing works
        bbox = bbxs[bbx_idx]
        # by default we do no padding, so these two objs are identical
        bbox_pad = bbox
        seg_mask = get_segmentation_mask(image, bbox)
        # min/max bbx size threshold | throw out small bbxs
        if (
            bbox[2] - bbox[0] < phalp.cfg.phalp.small_w
            or bbox[3] - bbox[1] < phalp.cfg.phalp.small_h
        ):
            continue
        # crop the entire image about the bounding box
        masked_image, _, _, rles, center_pad, scale_pad = phalp.get_croped_image(
            image, bbox, bbox_pad, seg_mask
        )
        masked_image_list.append(masked_image)
        center_list.append(center_pad)
        scale_list.append(scale_pad)
        rles_list.append(rles)
        bounding_box_ids.append(bbx_idx)

    return masked_image_list, center_list, scale_list, rles_list, bounding_box_ids


def post_process_hmar_results(
    hmar_out,
    phalp: PHALP,
    batch_size: int,
    center_list: List,
    scale_list: List,
    left,
    top,
    ratio,
):

    uv_vector = hmar_out["uv_vector"]

    # quickly calculate the appearance embedding
    appe_embedding = phalp.HMAR.autoencoder_hmar(uv_vector, en=True)
    appe_embedding = appe_embedding.view(appe_embedding.shape[0], -1)

    # simple data transform method, very fast
    pred_smpl_params, pred_joints_2d, pred_joints, pred_cam = (
        phalp.HMAR.get_3d_parameters(
            hmar_out["pose_smpl"],
            hmar_out["pred_cam"],
            center=(np.array(center_list) + np.array([left, top])) * ratio,
            img_size=phalp.cfg.render.res,
            scale=np.max(np.array(scale_list), axis=1, keepdims=True) * ratio,
        )
    )
    pred_smpl_params = [
        {k: v[i].cpu().numpy() for k, v in pred_smpl_params.items()}
        for i in range(batch_size)
    ]

    if phalp.cfg.phalp.pose_distance == "joints":
        pose_embedding = pred_joints.cpu().view(batch_size, -1)
    elif phalp.cfg.phalp.pose_distance == "smpl":
        pose_embedding = []
        for i in range(batch_size):
            pose_embedding_ = smpl_to_pose_camera_vector(
                pred_smpl_params[i], pred_cam[i]
            )
            pose_embedding.append(torch.from_numpy(pose_embedding_[0]))
        pose_embedding = torch.stack(pose_embedding, dim=0)
    else:
        raise ValueError("Unknown pose distance")
    pred_joints_2d_ = pred_joints_2d.reshape(batch_size, -1) / phalp.cfg.render.res
    pred_cam_ = pred_cam.view(batch_size, -1)
    pred_joints_2d_.contiguous()
    pred_cam_.contiguous()
    loca_embedding = torch.cat((pred_joints_2d_, pred_cam_, pred_cam_, pred_cam_), 1)

    # keeping it here for legacy reasons (T3DP), but it is not used.
    full_embedding = torch.cat(
        (appe_embedding.cpu(), pose_embedding, loca_embedding.cpu()), 1
    )
    return (
        appe_embedding,
        pose_embedding,
        loca_embedding,
        uv_vector,
        full_embedding,
        pred_smpl_params,
        pred_cam,
        pred_joints,
        pred_joints_2d,
    )


def format_results(
    bbxs: np.array,
    bounding_box_ids,
    rles_list,
    appe_embedding,
    pose_embedding,
    loca_embedding,
    uv_vector,
    full_embedding,
    center_list,
    scale_list,
    pred_smpl_params,
    pred_cam_,
    hmar_out,
    pred_joints,
    pred_joints_2d_,
) -> List[Dict]:
    detection_data_list = []
    for i, bbx_idx in enumerate(bounding_box_ids):
        detection_data = {
            "bbox": np.array(
                [
                    bbxs[bbx_idx][0],
                    bbxs[bbx_idx][1],
                    (bbxs[bbx_idx][2] - bbxs[bbx_idx][0]),
                    (bbxs[bbx_idx][3] - bbxs[bbx_idx][1]),
                ]
            ),
            "mask": rles_list[i],
            "appe": appe_embedding[i].cpu().numpy(),
            "pose": pose_embedding[i].numpy(),
            "loca": loca_embedding[i].cpu().numpy(),
            "uv": uv_vector[i].cpu().numpy(),
            "embedding": full_embedding[i],
            "center": center_list[i],
            "scale": scale_list[i],
            "smpl": pred_smpl_params[i],
            "camera": pred_cam_[i].cpu().numpy(),
            "camera_bbox": hmar_out["pred_cam"][i].cpu().numpy(),
            "3d_joints": pred_joints[i].cpu().numpy(),
            "2d_joints": pred_joints_2d_[i].cpu().numpy(),
        }
        detection_data_list.append(detection_data)
    return detection_data_list

In [16]:
import torch
import time
import numpy as np
from typing import List, Dict

from PHALP_MOD import PHALP


@torch.no_grad()
def predict_3d_poses(
    phalp: PHALP,
    img: np.array,
    bbxs: np.array,
) -> List[Dict]:
    """
    Params
    N: # bounding-boxes
    :img: np.array  (H, W, 3)
    :bbxs: np.array (N, 4) | [[ x1, y1, x2, y2]]
    """

    NUM_BBXS = bbxs.shape[0]
    if NUM_BBXS == 0:
        log.warn(f"A bbxs with dim 0 == 0 passed to `predict_3d_poses`")
        return []

    _, _, new_img_size, img_left, img_top = get_measurements(img)

    # used later for calculating some other vars
    img_ratio = 1.0 / int(new_img_size) * phalp.cfg.render.res

    # pre-process images and bouding boxes
    masked_image_list, center_list, scale_list, rles_list, bounding_box_ids = (
        pre_process_bounding_boxes(phalp, NUM_BBXS, img, bbxs)
    )

    log.info("PHALP: masked_image_list {}".format(len(masked_image_list)))

    if len(masked_image_list) == 0:
        log.error(f"No masked images generated for a non-empty input set")
        raise Exception
        return []

    # tensor of shape (N, H, W, 3)
    masked_image_list = torch.stack(masked_image_list, dim=0)
    batch_size = masked_image_list.size(0)

    log.debug("HMAR forward pass... ")

    # TODO: HMAR forward pass
    # this function appears to accept ONE IMAGE but UNLIMITED BBXS per forward pass (N, H, W, 3)
    start = time.time()
    hmar_out = phalp.HMAR(masked_image_list.cuda(), **{})
    log.info("PHALP: HMAR forward pass took {} seconds".format(time.time() - start))

    start = time.time()
    log.debug("PHALP: hmar_out {}".format(hmar_out.keys()))

    # post-process results of the HMAR forward pass
    (
        appe_embedding,
        pose_embedding,
        loca_embedding,
        uv_vector,
        full_embedding,
        pred_smpl_params,
        pred_cam,
        pred_joints,
        pred_joints_2d,
    ) = post_process_hmar_results(
        hmar_out,
        phalp,
        batch_size,
        center_list,
        scale_list,
        img_left,
        img_top,
        img_ratio,
    )

    results = format_results(
        bbxs,
        bounding_box_ids,
        rles_list,
        appe_embedding,
        pose_embedding,
        loca_embedding,
        uv_vector,
        full_embedding,
        center_list,
        scale_list,
        pred_smpl_params,
        pred_cam,
        hmar_out,
        pred_joints,
        pred_joints_2d,
    )
    log.info(
        "PHALP: the rest of the forward pass took {} seconds".format(
            time.time() - start
        )
    )
    return results

In [17]:
FRAME_IDX = 1
NUM_BBXS = 1
seg_mask = np.ones((NUM_BBXS, 720, 1280)).astype(bool)

results = predict_3d_poses(
    phalp=phalp_tracker,
    img=np.random.random((720, 1280, 3)),
    bbxs=np.array([[0, 100, 100, 200]] * NUM_BBXS),
)

In [19]:
results

[{'bbox': array([  0, 100, 100, 100]),
  'mask': [{'size': [720, 1280],
    'counts': 'T3T3\\c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000lbmi0'}],
  'appe': array([ 7.948,  7.852,  8.018, ...,  12.584,  12.803,  11.134],
        dtype=float32),
  'pose': array([ 0.901, -0.119, -0.417, -0.035, -0.978,  0.204, -0.432, -0.169,
         -0.886,  0.982, -0.188,  0.028,  0.150,  0.859,  0.490, -0.116,
         -0.477,  0.871,  0.993,  0.113,  0.041, -0.118,  0.856,  0.503,
          0.022, -0.504,  0.863,  0.999, -0.026, -0.040, -0.000,  0.839,
         -0.544,  0.048,  0.544,  0.838,  0.997,  0.071, -0.041, -0.082,
          0.853, -0.515, -0.001,  0.517,  0.856,  0.995, -0.087,  0.055,
          0.103,  0.864, -0.493, -0.005,  0.496,  0.868,  1.000, -0.022,
          0.001,  0.022,  0.984, -0.179,  0.003,  0.179,  0.984,  0.988,
        

In [None]:
# TODO: call forward pass on a dummy input

# TODO: new output format
#    pred_bbox,  # [[x1, y1, x2, y2]] -- (NUM_BBXS, 4)
#         pred_bbox_pad,  # IDENTICAL TO `pred_bbox`
#         pred_masks,  # (NUM_BBXS, H, W) [False if out of BBX, True if in BBX]
#         pred_scores,  # [ 1. ] * NUM_BBXS
#         pred_classes,  # [0] * NUM_BBXS
#         gt_tids,  # [1] * NUM_BBXS
#         gt_annots,  # [[]] * NUM_BBXS

# TODO: COMPLETE OUTPUT FORMAT
# how long does a single forward pass take
# i.e. do we need batch processing?

# N: # BBXS
#     image_frame,       (N, H, W, 3) # TODO: how do we handle duplicate frames efficently, using a dict (duh!)
#     pred_masks,        (N, H, W)
#     pred_bbox,         (N, 4)
#     pred_bbox_pad,     (N, 4)
#     pred_scores,       (1.0) * N
#     frame_name,        (None) * N
#     pred_classes,      (0) * N
#     frame_idx,         (INT) * N
#     measurments,       (N, 5)
#     gt_tids,           (1) * N
#     gt_annots,         (()) * N
#     extra_data,        list(range(len(pred_scores))) * N