In [1]:
import torch
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

In [2]:
# Load the 2D pose list and the matching per-frame info list written under
# output/inner_mirror/. NOTE: pickle.load can execute arbitrary code, so only
# point this at files this pipeline produced itself.
def _read_pickle(path):
    """Open ``path`` in binary mode and return the single object unpickled from it."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


pose2d_list = _read_pickle('output/inner_mirror/train_pose2D.pkl')
frame_info_list = _read_pickle('output/inner_mirror/train_frame_info.pkl')

In [3]:
from modules.pose_extractor.pose3d_estimator.pose3d_estimator import Pose3DEstimator
from modules.pose_extractor.pose3d_estimator.jointformer.lit_jointformer import LitJointFormer
from modules.pose_extractor.pose3d_estimator.pose2d_dataset import Pose2DDataset

# Inference settings. This loader only feeds one forward pass over the train_*
# inputs loaded above, so shuffling buys nothing. Keeping the order fixed makes
# the saved results identical from run to run (the original used shuffle=True
# with no seed).
BATCH_SIZE = 256
NUM_WORKERS = 24
LIFTER_CHECKPOINT = (
    'saved_models/pose3d_estimator/all_actors/jointformer/'
    'lightning_logs/version_0/checkpoints/epoch=34-step=12285.ckpt'
)

pose2d_dataset = Pose2DDataset(frame_info_list=frame_info_list, pose2d_list=pose2d_list)
train_loader = torch.utils.data.DataLoader(
    pose2d_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,  # deterministic order for a pure inference pass
)

# 2D -> 3D lifter restored from the JointFormer Lightning checkpoint above.
pose3d_estimator = Pose3DEstimator(
    LitModel=LitJointFormer,
    lifter_saved_model=LIFTER_CHECKPOINT,
)

In [4]:
def _batch_to_records(batch, pose3d_batch):
    """Split one collated batch and its 3D predictions into per-sample records.

    Args:
        batch: dict yielded by ``train_loader``. Reads the tensors
            ``'frame_index'``, ``'pose_2d'``, ``'root_2d'`` and the
            two-element ``'scale_factor'``. Element 0 is stored as the width
            factor and element 1 as the height factor.
        pose3d_batch: tensor returned by ``pose3d_estimator.inference(batch)``.
            Assumed to have one row per sample along dim 0 (TODO: confirm).

    Returns:
        list[dict]: one dict per sample with keys ``index``, ``pose_2d``,
        ``root_2d``, ``scale_factor_w``, ``scale_factor_h`` and ``pose_3d``.
        Tensor values are converted to plain Python lists/numbers via
        ``.tolist()``.

    Raises:
        ValueError: if the number of 3D predictions differs from the number of
            samples. ``zip`` would otherwise silently drop the extras.
    """
    def to_list(tensor):
        return tensor.detach().cpu().numpy().tolist()

    indices = to_list(batch['frame_index'])
    poses_3d = to_list(pose3d_batch)
    if len(poses_3d) != len(indices):
        raise ValueError(
            f'expected {len(indices)} 3D poses for the batch, got {len(poses_3d)}'
        )

    columns = zip(
        indices,
        to_list(batch['pose_2d']),
        to_list(batch['root_2d']),
        to_list(batch['scale_factor'][0]),
        to_list(batch['scale_factor'][1]),
        poses_3d,
    )
    return [
        dict(
            index=index,
            pose_2d=pose_2d,
            root_2d=root_2d,
            scale_factor_w=scale_factor_w,
            scale_factor_h=scale_factor_h,
            pose_3d=pose_3d,
        )
        for index, pose_2d, root_2d, scale_factor_w, scale_factor_h, pose_3d in columns
    ]


# Run the lifter over every batch and collect one record per sample.
# no_grad: every output is detached before use, so autograd bookkeeping is
# pure overhead here. It is harmless if inference() already disables it.
results = []
with torch.no_grad():
    for batch in tqdm(train_loader):
        pose3d_batch = pose3d_estimator.inference(batch)
        results.extend(_batch_to_records(batch, pose3d_batch))

100%|█████████████████████████████████████████| 981/981 [00:14<00:00, 68.10it/s]


In [6]:
# Persist the per-sample records built by the inference loop (a list of dicts
# holding plain Python lists/numbers) next to the input pickles.
pose_info_path = 'output/inner_mirror/train_pose_info.pkl'
with open(pose_info_path, 'wb') as out_file:
    pickle.dump(results, out_file)