In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import os, sys
sys.path.append('/home/hrai/codes/hpe_library')
from lib_import import *
from my_utils import *

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

from IPython.display import display
from ipywidgets import interact, interactive, fixed, interact_manual, interactive_output
import threading
from ipywidgets import GridspecLayout
from ipywidgets import TwoByTwoLayout

os.chdir('/home/hrai/codes/MotionBERT')
from custom_codes.Inference_and_evaluation.utils_result_analysis.button import get_inference_button, get_visualize_button, get_root_rel_button, get_reset_button, get_go_to_max_frame_button, get_procrustes_button, get_analysis_error_button, get_analysis_dh_button
from custom_codes.Inference_and_evaluation.utils_result_analysis.progress import get_inference_progress
from custom_codes.Inference_and_evaluation.utils_result_analysis.text import get_inputs_all_shape_text, get_gts_all_shape_text, get_results_all_shape_text, get_error_max_frame_text, get_batch_num_text, get_frame_num_text
from custom_codes.Inference_and_evaluation.utils_result_analysis.slider import get_azim_slider, get_elev_slider, get_zoom_slider, get_delay_slider, get_frame_slider, get_trans_slider
from custom_codes.Inference_and_evaluation.utils_result_analysis.dropdown import get_model_list_dropdown, get_dataset_list_dropdown
from custom_codes.Inference_and_evaluation.utils_result_analysis.play import get_play_vis_button
from custom_codes.Inference_and_evaluation.utils_result_analysis.select import get_subject_select, get_action_select, get_cam_select, get_batch_select, get_part_select

# Red-bordered frame shared by the plot output and every container box below.
box_layout = widgets.Layout(
    padding='5px 5px 5px 5px',
    margin='0px 10px 10px 0px',
    border='solid 1px red',
)

# Clear the current matplotlib figure, then the current axes, before the tool
# creates its own figure (presumably to drop leftovers from an earlier run of
# this cell).
plt.clf()
plt.cla()

class analysis_tool():
    """Interactive notebook tool for inspecting a 3D pose and its 2D projection.

    The tool loads pose sequences through the module-level ``load_h36m`` helper,
    lets the user pick a subject, an action and a frame, translates the selected
    pose in front of a virtual camera and draws three views: the 3D pose, its
    normalized 2D projection, and the same projection made relative to joint 0.
    """

    def __init__(self):
        """Load the data, build the widgets and display the UI.

        Any exception raised during set-up is printed to the cell output and to
        ``self.test_out`` instead of propagating to the caller.
        """
        # variables -------------------------------------------------------------------
        self.joint_names = ['pelvis', 'right_hip', 'right_knee', 'right_ankle', 'left_hip', 'left_knee', 'left_ankle', 'torso', 'neck', 'nose', 'head', 'left_shoulder', 'left_elbow', 'left_wrist', 'right_shoulder', 'right_elbow', 'right_wrist']
        self.part_list = ['R_UPPER_ARM', 'R_UNDER_ARM', 'L_UPPER_ARM', 'L_UNDER_ARM', 'R_UPPER_LEG', 'R_UNDER_LEG', 'L_UPPER_LEG', 'L_UNDER_LEG']
        self.frame_num = 0
        # Fixed offset added to the pose before the slider translation.
        self.init_dx, self.init_dy, self.init_dz = 3, 0, 0
        # Translation controlled by the three "trans" sliders.
        self.dx, self.dy, self.dz = 0, 0, 0
        self.prev_part_ids = ['']
        # Filled by update_ref_pose(); None means "nothing to draw yet".
        self.ref_pose = None
        # Output widget used as an in-UI log for error messages.
        self.test_out = widgets.Output()
        try:
            # load data -----------------------------------------------------
            self.load_h36m()
            # init panel ---------------------------------------------------------------------
            self.init_panel()
            # init interactive -----------------------------------------------------------------
            self.interactive()
            # init layout ----------------------------------------------------------------------
            self.layout()
            # init display ---------------------------------------------------------------------
            display(self.ui)
        except Exception as e:
            print(e)
            with self.test_out:
                print(e)

    def load_h36m(self):
        """Load the dataset, select the first subject/action and set up the camera.

        Also computes the initial reference pose so that the translation sliders
        can be used before any subject, action or frame has been changed.
        """
        print('load_h36m')
        # The bare name resolves to the module-level load_h36m() helper, not to
        # this method (methods are only reachable through ``self``).
        self.h36m_3d_world, self.h36m_cam_param = load_h36m()
        self.subject_list = natsorted(list(self.h36m_3d_world._data.keys()))
        self.action_list = natsorted(list(self.h36m_3d_world._data[self.subject_list[0]].keys()))
        self.subject = self.subject_list[0]
        self.action = self.action_list[0]
        self.pose_3d_list, self.cam_param = get_pose_seq_and_cam_param(self.h36m_3d_world, self.h36m_cam_param, self.subject, self.action)

        # Camera entry whose image size and calibration matrix seed the virtual camera.
        target = '54138969'
        calib_mat = self.cam_param[target]['int']['calibration_matrix']
        self.init_camera(W=self.cam_param[target]['W'], H=self.cam_param[target]['H'],
                         fx=calib_mat[0][0], fy=calib_mat[1][1], cx=calib_mat[0][2], cy=calib_mat[1][2])

        # Make self.ref_pose valid from the start (frame_num is 0 here).
        self.update_ref_pose()

    def init_panel(self):
        """Create the selection widgets, sliders, plot output and the figure."""
        print('init panel')
        # select
        self.select_subject = get_subject_select(self.subject_list)
        self.select_subject.layout.width = 'max-content'
        self.select_action = get_action_select(self.action_list)
        self.select_action.layout.width = 'max-content'
        # text
        self.frame_text = get_frame_num_text()
        # slider
        self.frame_slider = get_frame_slider(1)
        # Match the slider range to the sequence loaded in load_h36m(); without
        # this the range is only synced after the first selection change.
        # NOTE(review): assumes the slider's lower bound is 0 - confirm in get_frame_slider.
        self.frame_slider.max = max(len(self.pose_3d_list) - 1, 0)
        self.trans_x_slider = get_trans_slider(description='trans x')
        self.trans_y_slider = get_trans_slider(description='trans y')
        self.trans_z_slider = get_trans_slider(description='trans z')
        # plot
        self.plot = widgets.Output(layout=box_layout)

        self.init_plot()

    def init_camera(self, W=1000, H=1000, cam_height=1.0,
                    fx=1.0, fy=1.0, cx=500.0, cy=500.0,
                    init_roll_angle=0, init_pitch_angle=0, init_yaw_angle=0):
        """Create the virtual camera and cache its matrices on the instance.

        Args:
            W, H: image width and height, stored as ``self.W`` / ``self.H``.
            cam_height: z coordinate of the camera origin.
            fx, fy, cx, cy: entries of the 3x3 calibration matrix
                ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.
            init_roll_angle, init_pitch_angle, init_yaw_angle: forwarded to
                ``Camera`` as ``roll``, ``pitch`` and ``yaw``.
        """
        print('init_camera')
        self.W, self.H = W, H
        self.calib_mat = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])

        # camera parameter
        self.cam_height = cam_height
        self.cam_origin = np.array([0, 0, self.cam_height])

        # Rows of the default orientation matrix handed to Camera.
        forward = [1, 0, 0]
        left = [0, -1, 0]
        up = np.cross(forward, left)  # evaluates to [0, 0, -1]
        cam_default_R = np.array([left, up, forward])  # default camera orientation

        self.camera = Camera(origin=self.cam_origin,
                             calib_mat=self.calib_mat,
                             cam_default_R=cam_default_R,
                             roll=init_roll_angle,
                             pitch=init_pitch_angle,
                             yaw=init_yaw_angle,
                             IMAGE_HEIGHT=self.H,
                             IMAGE_WIDTH=self.W)

        self.cam_ext = self.camera.extrinsic
        self.cam_int = self.camera.intrinsic
        self.cam_proj = self.camera.cam_proj

    def interactive(self):
        """Bind the widgets to their callbacks through ``widgets.interactive``."""
        print('interactive')
        self.select_interact = widgets.interactive(self.update_select, subject=self.select_subject, action=self.select_action)
        self.frame_interact = widgets.interactive(self.update_frame, frame=self.frame_slider)
        self.trans_interact = widgets.interactive(self.update_pose, trans_x=self.trans_x_slider, trans_y=self.trans_y_slider, trans_z=self.trans_z_slider)

    def layout(self):
        """Assemble the widgets into ``self.ui``.

        The top row holds the selectors and the sliders, the second row holds
        the plot, and the log output ``self.test_out`` sits at the bottom.
        """
        print('layout')
        ui_select = widgets.HBox([self.select_subject, self.select_action], layout=box_layout)
        ui_trans = widgets.VBox([self.frame_slider, self.trans_x_slider, self.trans_y_slider, self.trans_z_slider], layout=box_layout)
        ui_plot = widgets.HBox([self.plot], layout=box_layout)
        ui_layer1 = widgets.HBox([ui_select, ui_trans], layout=box_layout)
        ui_layer2 = widgets.HBox([ui_plot], layout=box_layout)
        self.ui = widgets.VBox([ui_layer1, ui_layer2, self.test_out])

    def update_select(self, subject, action):
        """Callback for the subject/action selectors.

        Switches to the requested subject or action, reloads the pose sequence,
        resizes the frame slider and redraws. Errors are printed to
        ``self.test_out`` rather than raised.

        Args:
            subject: currently selected subject key.
            action: currently selected action key. Ignored when the subject
                changed, because the action list is then replaced and its first
                entry is selected instead.
        """
        try:
            if not self.h36m_3d_world._data.keys():
                return
            if self.subject != subject:
                self.select_subject.value = subject
                self.select_action.options = natsorted(list(self.h36m_3d_world._data[subject].keys()))
                self.select_action.value = self.select_action.options[0]
                self.subject = subject
                self.action = self.select_action.value
            elif self.action != action:
                self.select_action.value = action
                self.action = action
            # Load with the state settled above, not the raw callback arguments:
            # after a subject change ``action`` still names the previous
            # subject's action, which may not exist for the new subject.
            self.pose_3d_list, self.cam_param = get_pose_seq_and_cam_param(self.h36m_3d_world, self.h36m_cam_param, self.subject, self.action)
            self.frame_slider.max = max(len(self.pose_3d_list) - 1, 0)
            self.update_ref_pose()
            self.visualize_data()
        except Exception as e:
            with self.test_out:
                print('select error', e)
                print()

    def update_frame(self, frame):
        """Callback for the frame slider; redraws only when the frame changed."""
        if self.frame_num == frame:
            return
        self.frame_num = frame
        self.update_ref_pose()
        self.visualize_data()

    def update_ref_pose(self):
        """Recompute ``self.ref_pose`` from the current frame.

        The frame's pose is copied, passed through ``rotate_torso_by_R`` with a
        rotation of -pi/2 about the z axis, and shifted so that joint 0 has
        x = y = 0 (z is left unchanged).
        """
        ref_pose = self.pose_3d_list[self.frame_num].copy()
        ref_pose = rotate_torso_by_R(ref_pose, Rotation.from_rotvec([0, 0, -np.pi/2]).as_matrix())
        ref_pose[:, :2] -= ref_pose[0, :2]
        self.ref_pose = ref_pose

    def update_pose(self, trans_x, trans_y, trans_z):
        """Callback for the translation sliders; stores the offsets and redraws."""
        self.dx = trans_x
        self.dy = trans_y
        self.dz = trans_z
        self.visualize_data()

    def init_plot(self):
        """Create figure 1 with one 3D axes and two 2D axes inside ``self.plot``."""
        with self.plot:
            self.fig = plt.figure(1, figsize=(15, 5),  layout='tight')
            self.fig.clear()
            self.ax_3d = axes_3d(self.fig, loc=131, xlim=(-3,3), ylim=(-3,3), zlim=(0,2), view=(25,180), show_axis=True)
            self.ax_input = axes_2d(self.fig, loc=132, normalize=True, show_axis=False)
            self.ax_canonical = axes_2d(self.fig, loc=133, normalize=True, show_axis=False)
            plt.show()

    def visualize_data(self):
        """Translate, project and draw the current reference pose.

        Does nothing while no reference pose is available, so slider callbacks
        that fire before a pose has been prepared do not raise.
        """
        if getattr(self, 'ref_pose', None) is None:
            return
        # Place the pose relative to the camera's x position, then apply the
        # fixed initial offset and the slider-controlled offset.
        self.pose_3d = self.ref_pose.copy() + np.array([self.cam_origin[0]+self.init_dx, self.init_dy, self.init_dz]) + np.array([self.dx, self.dy, self.dz])
        self.pose_2d = projection(self.pose_3d, self.cam_proj)[..., :2]
        self.pose_2d_norm = normalize_screen_coordinates(self.pose_2d, self.W, self.H)
        # Same 2D pose expressed relative to joint 0.
        self.pose_2d_norm_canonical = self.pose_2d_norm - self.pose_2d_norm[0]
        with self.plot:
            clear_axes([self.ax_3d, self.ax_input, self.ax_canonical])
            draw_3d_pose(self.ax_3d, self.pose_3d)
            draw_2d_pose(self.ax_input, self.pose_2d_norm, normalize=True)
            draw_2d_pose(self.ax_canonical, self.pose_2d_norm_canonical, normalize=True)
    
# Instantiate the tool; its constructor loads the data and displays the widget UI.
at = analysis_tool()
    

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
load_h36m
==> Loading 3D data wrt World CS...
init_camera
init panel


interactive
layout


VBox(children=(HBox(children=(HBox(children=(Select(description='Subject:', layout=Layout(width='max-content')…