In [1]:
import torch
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

In [2]:
with open('output/inner_mirror_with_padding/pose2D.pkl', 'rb') as f:
    pose2d_list = pickle.load(f)

with open('output/inner_mirror_with_padding/frame_info.pkl', 'rb') as f:
    frame_info_list = pickle.load(f)

In [3]:
from modules.pose_extractor.pose3d_estimator.pose3d_estimator import Pose3DEstimator
from modules.pose_extractor.pose3d_estimator.jointformer.lit_jointformer import LitJointFormer
from modules.pose_extractor.pose3d_estimator.pose2d_dataset import Pose2DDataset


pose2d_dataset = Pose2DDataset(frame_info_list=frame_info_list, pose2d_list=pose2d_list)
train_loader = torch.utils.data.DataLoader(
        pose2d_dataset, batch_size=256, num_workers=24, shuffle=True
    )

pose3d_estimator = Pose3DEstimator(
    LitModel=LitJointFormer,
    lifter_saved_model='saved_models/pose3d_estimator/all_actors/jointformer/lightning_logs/version_0/checkpoints/epoch=34-step=12285.ckpt'
)

In [4]:
results = []
for batch in tqdm(iter(train_loader)):
    pose3d = pose3d_estimator.inference(batch)
    dict_results = []
    for index, pose_2d, root_2d, scale_factor_w, scale_factor_h, pose3d in list(
        zip(
            batch['frame_index'].detach().cpu().numpy().tolist(),
            batch['pose_2d'].detach().cpu().numpy().tolist(),
            batch['root_2d'].detach().cpu().numpy().tolist(),
            batch['scale_factor'][0].detach().cpu().numpy().tolist(),
            batch['scale_factor'][1].detach().cpu().numpy().tolist(),
            pose3d.detach().cpu().numpy().tolist()
        )
    ):
        dict_results.append(dict(
            index=index,
            pose_2d=pose_2d,
            root_2d=root_2d,
            scale_factor_w=scale_factor_w,
            scale_factor_h=scale_factor_h,
            pose_3d=pose3d))
    results += dict_results

100%|███████████████████████████████████████| 2738/2738 [00:36<00:00, 74.12it/s]


In [5]:
with open('output/inner_mirror_with_padding/pose_info.pkl', 'wb') as f:
    pickle.dump(results, f)

In [6]:
tmp_df = pd.DataFrame(results)

In [7]:
tmp_df = tmp_df.sort_values('index')

In [8]:
len(results)

700843

In [9]:
len(frame_info_list)

700843

In [10]:
tmp_df

Unnamed: 0,index,pose_2d,root_2d,scale_factor_w,scale_factor_h,pose_3d
98726,0,"[[-0.2419912566721685, 0.23354965254237767], [...","[678.8348508470718, 293.5010224218897]",370.804039,464.364574,"[[0.024415314197540283, 0.02220449596643448, 0..."
594504,1,"[[0.7851466624717325, -0.13325511573045093], [...","[675.787124873307, 296.429846223951]",363.684897,567.013209,"[[0.13263992965221405, -0.040789805352687836, ..."
207838,2,"[[-0.24467456126438158, 0.22606250361437633], ...","[682.8855701779067, 291.02653253478445]",369.774037,477.809359,"[[0.02548731118440628, 0.023366957902908325, 0..."
637198,3,"[[-0.24405026438726524, 0.16948461187091535], ...","[681.9270605339568, 296.5707343573449]",359.798944,583.737259,"[[0.040584903210401535, -0.013084925711154938,..."
21199,4,"[[0.7653742295095523, -0.13845419246896554], [...","[684.5113896012731, 290.19153639614143]",377.109710,482.990141,"[[0.12188244611024857, -0.037278518080711365, ..."
...,...,...,...,...,...,...
38146,700838,"[[-0.013138014711380793, -0.21142890512127982]...","[975.5589306300742, 393.94255846017285]",370.941299,591.957234,"[[0.0051663899794220924, -0.14868715405464172,..."
421544,700839,"[[0.2508698922011951, -0.35552890805092086], [...","[883.7946016769315, 376.12863137771524]",354.830550,371.025236,"[[0.09904636442661285, -0.16549000144004822, -..."
148053,700840,"[[0.21851344704983414, -0.20983741132972916], ...","[879.803681187412, 374.6336703669537]",418.756658,609.870701,"[[0.1464637964963913, -0.10237132012844086, 0...."
617981,700841,"[[-0.013321627402003785, -0.22008265062328541]...","[973.9085571928587, 392.24742493844644]",347.025294,609.480813,"[[-0.03899464011192322, -0.11350341886281967, ..."
