In [None]:
import warnings
import json

from omegaconf import DictConfig
from phalp.utils import get_pylogger
from PHALP_MOD import PHALP
from HMAR_MOD import HMAR

# Suppress every warning for the remainder of the session.
warnings.filterwarnings(action="ignore")
log = get_pylogger(__name__)

# Location of the JSON configuration handed to PHALP.
cfg_fp = "/mnt/arc/levlevi/nba-positions-videos-dataset/4d-pose-extraction/PHALP/scripts/config.json"

# Parse the JSON file first, then wrap the parsed content in a DictConfig.
with open(cfg_fp, "r") as f:
    raw_cfg = json.load(f)
cfg = DictConfig(raw_cfg)

# Build the tracker object from the loaded configuration.
phalp_tracker = PHALP(cfg)

In [None]:
import torch
import time
import numpy as np
from typing import List, Dict

from phalp.external.deep_sort_.detection import Detection
from phalp.utils.utils import (
    smpl_to_pose_camera_vector,
)


def get_human_features(
    phalp: PHALP,
    image,
    seg_mask,
    bbox,
    bbox_pad,
    score,
    frame_name,
    cls_id,
    frame_idx: int,
    measurments,
    gt=1,
    ann=None,
    extra_data=None,
) -> List[Detection]:
    """Run HMAR on every usable detection of one image and wrap the results.

    For each detection whose box is at least ``cfg.phalp.small_w`` wide and
    ``cfg.phalp.small_h`` tall, the person is cropped with
    ``phalp.get_croped_image``, the crops are pushed through ``phalp.HMAR`` as
    a single batch, and one ``Detection`` is built per kept person.

    TODO: consider extending this function to process multiple images per call.

    Args:
        phalp: Tracker providing ``cfg``, ``get_croped_image`` and ``HMAR``.
        image: Source image passed through to ``phalp.get_croped_image``.
        seg_mask: Per-person segmentation masks, indexed like ``score``.
        bbox: Per-person boxes; indices 0/1 and 2/3 are used as opposite
            corners (width is ``bbox[p][2] - bbox[p][0]``).
        bbox_pad: Per-person padded boxes, passed to ``get_croped_image``.
        score: Per-person confidences; its length defines the person count.
        frame_name: Frame path; when it is a ``str`` its last ``/`` component
            is stored as ``img_name``.
        cls_id: Per-person class identifiers.
        frame_idx: Index of the frame, stored under ``"time"``.
        measurments: 5-tuple ``(img_height, img_width, new_image_size, left,
            top)``. (Parameter name kept as-is for caller compatibility.)
        gt: Per-person ground-truth values, or a single scalar that is applied
            to every person.
        ann: Per-person annotations, or ``None``.
        extra_data: Per-person extra payload, or ``None``.

    Returns:
        A list of ``Detection`` objects, one per kept person, in the original
        detection order. Empty when there are no detections or every box is
        below the size thresholds.

    Raises:
        ValueError: If ``cfg.phalp.pose_distance`` is neither ``"joints"`` nor
            ``"smpl"``.
    """

    n_people = len(score)
    if n_people == 0:
        return []

    img_height, img_width, new_image_size, left, top = measurments
    render_res = phalp.cfg.render.res
    # resize ratio between the working image size and the render resolution
    ratio = 1.0 / int(new_image_size) * render_res

    masked_image_list = []
    center_list = []
    scale_list = []
    rles_list = []
    selected_ids = []

    # crop images, skipping boxes narrower/shorter than the configured minimum
    for p_ in range(n_people):
        box = bbox[p_]
        if (
            box[2] - box[0] < phalp.cfg.phalp.small_w
            or box[3] - box[1] < phalp.cfg.phalp.small_h
        ):
            continue
        # center_/scale_ of the tight box are returned but only the padded
        # variants are used downstream.
        masked_image, _center, _scale, rles, center_pad, scale_pad = (
            phalp.get_croped_image(image, box, bbox_pad[p_], seg_mask[p_])
        )

        masked_image_list.append(masked_image)
        center_list.append(center_pad)
        scale_list.append(scale_pad)
        rles_list.append(rles)
        selected_ids.append(p_)

    log.info("PHALP: masked_image_list {}".format(len(masked_image_list)))

    if not masked_image_list:
        return []

    masked_images = torch.stack(masked_image_list, dim=0)
    batch_size = masked_images.size(0)

    with torch.no_grad():
        # forward pass, bulk of computation occurs here
        # NOTE(review): CUDA work is launched asynchronously, so this
        # wall-clock figure may under-report unless the device is
        # synchronized first -- confirm before relying on it.
        log.debug("Calculating HMAR forward pass")
        start = time.time()
        hmar_out = phalp.HMAR(masked_images.cuda())
        log.info("PHALP: HMAR forward pass took {} seconds".format(time.time() - start))

        start = time.time()

        log.debug("PHALP: hmar_out {}".format(hmar_out.keys()))

        uv_vector = hmar_out["uv_vector"]

        # Encode the UV output with HMAR's autoencoder (en=True) and flatten
        # it to one appearance vector per person.
        appe_embedding = phalp.HMAR.autoencoder_hmar(uv_vector, en=True)
        appe_embedding = appe_embedding.view(appe_embedding.shape[0], -1)

        # Crop centers are shifted by (left, top) and, together with the
        # per-person max scale, rescaled by `ratio` to the render resolution.
        pred_smpl_params, pred_joints_2d, pred_joints, pred_cam = (
            phalp.HMAR.get_3d_parameters(
                hmar_out["pose_smpl"],
                hmar_out["pred_cam"],
                center=(np.array(center_list) + np.array([left, top])) * ratio,
                img_size=render_res,
                scale=np.max(np.array(scale_list), axis=1, keepdims=True) * ratio,
            )
        )
        # Split the batched parameter dict into one numpy dict per person.
        pred_smpl_params = [
            {k: v[i].cpu().numpy() for k, v in pred_smpl_params.items()}
            for i in range(batch_size)
        ]

        pose_distance = phalp.cfg.phalp.pose_distance
        if pose_distance == "joints":
            pose_embedding = pred_joints.cpu().view(batch_size, -1)
        elif pose_distance == "smpl":
            pose_embedding = torch.stack(
                [
                    torch.from_numpy(
                        smpl_to_pose_camera_vector(pred_smpl_params[i], pred_cam[i])[0]
                    )
                    for i in range(batch_size)
                ],
                dim=0,
            )
        else:
            raise ValueError("Unknown pose distance")

        # 2D joints normalised by the render resolution, flattened per person.
        pred_joints_2d_ = pred_joints_2d.reshape(batch_size, -1) / render_res
        pred_cam_ = pred_cam.view(batch_size, -1)
        # Location embedding: normalised 2D joints followed by the camera
        # vector repeated three times.
        loca_embedding = torch.cat(
            (pred_joints_2d_, pred_cam_, pred_cam_, pred_cam_), 1
        )

    # keeping it here for legacy reasons (T3DP), but it is not used.
    full_embedding = torch.cat(
        (appe_embedding.cpu(), pose_embedding, loca_embedding.cpu()), 1
    )

    # The default `gt=1` is a scalar and cannot be indexed per person, so a
    # scalar is applied to every detection instead.
    gt_is_scalar = np.isscalar(gt)

    detection_data_list = []
    for i, p_ in enumerate(selected_ids):
        box = bbox[p_]
        detection_data = {
            # stored as [x, y, width, height]
            "bbox": np.array(
                [
                    box[0],
                    box[1],
                    (box[2] - box[0]),
                    (box[3] - box[1]),
                ]
            ),
            "mask": rles_list[i],
            "conf": score[p_],
            "appe": appe_embedding[i].cpu().numpy(),
            "pose": pose_embedding[i].numpy(),
            "loca": loca_embedding[i].cpu().numpy(),
            "uv": uv_vector[i].cpu().numpy(),
            "embedding": full_embedding[i],
            "center": center_list[i],
            "scale": scale_list[i],
            "smpl": pred_smpl_params[i],
            "camera": pred_cam_[i].cpu().numpy(),
            "camera_bbox": hmar_out["pred_cam"][i].cpu().numpy(),
            "3d_joints": pred_joints[i].cpu().numpy(),
            "2d_joints": pred_joints_2d_[i].cpu().numpy(),
            "size": [img_height, img_width],
            "img_path": frame_name,
            "img_name": (
                frame_name.split("/")[-1] if isinstance(frame_name, str) else None
            ),
            "class_name": cls_id[p_],
            "time": frame_idx,
            "ground_truth": gt if gt_is_scalar else gt[p_],
            # `ann` defaults to None; guard it the same way as `extra_data`.
            "annotations": ann[p_] if ann is not None else None,
            "extra_data": extra_data[p_] if extra_data is not None else None,
        }
        detection_data_list.append(Detection(detection_data))

    log.info(
        "PHALP: the rest of the forward pass took {} seconds".format(
            time.time() - start
        )
    )

    return detection_data_list