In [1]:
import torch
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
from modules.definition import (
    bone_connections
)
bone_connections = np.array(bone_connections) - 1

%load_ext autoreload
%autoreload 2

In [2]:
# Load the 2D poses and the per-frame info from the output directory.
# pickle.load can execute arbitrary code: only point these paths at files
# produced by this project.
with open('output/inner_mirror_without_normalization/pose2D.pkl', 'rb') as pose2d_file:
    pose2d_list = pickle.load(pose2d_file)

with open('output/inner_mirror_without_normalization/frame_info.pkl', 'rb') as frame_info_file:
    frame_info_list = pickle.load(frame_info_file)

In [3]:
from modules.pose_extractor.pose3d_estimator.pose3d_estimator import Pose3DEstimator
from modules.pose_extractor.pose3d_estimator.jointformer.lit_jointformer import LitJointFormer
from modules.pose_extractor.pose3d_estimator.pose2d_dataset import Pose2DDataset


# Loader / lifter configuration for the inference pass below.
BATCH_SIZE = 256
NUM_WORKERS = 24  # NOTE(review): tuned for the original machine; lower on smaller hosts
LIFTER_CHECKPOINT = 'saved_models/pose3d_estimator_without_normalization/all_actors/all_actors/lightning_logs/version_0/checkpoints/epoch=19-step=7020.ckpt'

pose2d_dataset = Pose2DDataset(frame_info_list=frame_info_list, pose2d_list=pose2d_list)

# This loader is only used for inference. Shuffling adds nothing, because the
# results are sorted by frame index afterwards. With no seed set, it would also
# make the batch composition differ between runs. The name `train_loader` is
# kept because later cells refer to it.
train_loader = torch.utils.data.DataLoader(
    pose2d_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)

pose3d_estimator = Pose3DEstimator(
    LitModel=LitJointFormer,
    lifter_saved_model=LIFTER_CHECKPOINT,
)

In [4]:
def _batch_to_records(batch, pose3d_batch):
    """Split one collated batch and its 3D predictions into per-frame records.

    Args:
        batch: collated batch from ``train_loader``; must contain the keys
            ``frame_index``, ``pose_2d``, ``root_2d`` and ``scale_factor``.
            ``scale_factor[0]`` / ``scale_factor[1]`` are read as width / height.
        pose3d_batch: tensor returned by ``pose3d_estimator.inference(batch)``,
            with one entry per sample along its first dimension.

    Returns:
        list of dicts, one per sample. Each dict has the keys ``index``,
        ``pose_2d``, ``root_2d``, ``scale_factor_w``, ``scale_factor_h`` and
        ``pose_3d``. Array fields are built with ``np.array`` from the tensors'
        ``.tolist()`` output, so float tensors become float64 NumPy arrays.
    """
    def to_list(tensor):
        return tensor.detach().cpu().numpy().tolist()

    # NOTE(review): scale_factor appears to be a collated (w, h) pair, with each
    # element holding one value per sample. TODO: confirm in Pose2DDataset.
    scale_w = batch['scale_factor'][0]
    scale_h = batch['scale_factor'][1]

    per_sample = zip(
        to_list(batch['frame_index']),
        to_list(batch['pose_2d']),
        to_list(batch['root_2d']),
        to_list(scale_w),
        to_list(scale_h),
        to_list(pose3d_batch),
    )
    return [
        dict(
            index=index,
            pose_2d=np.array(pose_2d),
            root_2d=np.array(root_2d),
            scale_factor_w=scale_factor_w,
            scale_factor_h=scale_factor_h,
            pose_3d=np.array(pose_3d),
        )
        for index, pose_2d, root_2d, scale_factor_w, scale_factor_h, pose_3d in per_sample
    ]


results = []
# Gradients are never used here (every output is detached), so skip building
# autograd graphs during the forward passes.
with torch.no_grad():
    for batch in tqdm(train_loader):
        pose3d = pose3d_estimator.inference(batch)
        results.extend(_batch_to_records(batch, pose3d))

100%|██████████████████████████████████████████████████████████████████████████| 2738/2738 [00:30<00:00, 89.60it/s]


In [5]:
# Per-frame metadata, keyed and ordered by frame index.
frame_info_df = (
    pd.DataFrame(frame_info_list)
    .set_index('frame_index')
    .sort_index()
)

# Inference results, keyed and ordered by the same frame index (named 'index').
pose_info_df = (
    pd.DataFrame(results)
    .set_index('index')
    .sort_index()
)

# left_on / right_on name the index levels set above, so rows are matched on
# frame index (merge defaults to an inner join).
pose_df = frame_info_df.merge(
    pose_info_df,
    left_on='frame_index',
    right_on='index',
)

In [6]:
# Persist the merged frame-info + 2D/3D pose table.
with open('output/inner_mirror_without_normalization/pose_info.pkl', 'wb') as pose_info_file:
    pickle.dump(pose_df, pose_info_file)