In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import os, sys
sys.path.append('/home/hrai/codes/hpe_library')
from lib_import import *
from my_utils import *

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

from IPython.display import display
from ipywidgets import interact, interactive, fixed, interact_manual, interactive_output
import threading
from ipywidgets import GridspecLayout
from ipywidgets import TwoByTwoLayout

os.chdir('/home/hrai/codes/MotionBERT')
from custom_codes.Inference_and_evaluation.utils_result_analysis.button import get_inference_button, get_visualize_button, get_root_rel_button, get_reset_button, get_go_to_max_frame_button, get_procrustes_button, get_analysis_error_button, get_analysis_dh_button, get_toggle_button
from custom_codes.Inference_and_evaluation.utils_result_analysis.progress import get_inference_progress
from custom_codes.Inference_and_evaluation.utils_result_analysis.text import get_inputs_all_shape_text, get_gts_all_shape_text, get_results_all_shape_text, get_error_max_frame_text, get_batch_num_text, get_frame_num_text, get_float_text, get_str_text
from custom_codes.Inference_and_evaluation.utils_result_analysis.slider import get_azim_slider, get_elev_slider, get_zoom_slider, get_delay_slider, get_frame_slider, get_trans_slider
from custom_codes.Inference_and_evaluation.utils_result_analysis.dropdown import get_model_list_dropdown, get_dataset_list_dropdown
from custom_codes.Inference_and_evaluation.utils_result_analysis.play import get_play_vis_button
from custom_codes.Inference_and_evaluation.utils_result_analysis.select import get_subject_select, get_action_select, get_cam_select, get_batch_select, get_part_select

# Shared layout for the main panels: a thin red frame with a little spacing
# around and inside each boxed group of widgets.
box_layout = widgets.Layout(
    border='solid 1px red',
    padding='5px 5px 5px 5px',
    margin='0px 10px 10px 0px',
)

# Reset matplotlib's current figure and axes before the tool creates its own.
plt.clf()
plt.cla()

class analysis_tool():
    """Interactive notebook tool that projects a 3D pose into a virtual camera.

    A Human3.6M pose (or, in "line mode", a 3-point reference segment) is placed
    in front of a configurable `Camera`. Sliders translate/rotate the pose and
    rotate/raise the camera. Each update redraws two 3D views, the projected 2D
    pose, its pelvis-relative ("canonical") 2D pose, and up to two captured
    canonical poses for comparison. Constructing an instance loads the data,
    builds the widgets and displays the UI.
    """

    def __init__(self):
        # variables -------------------------------------------------------------------
        self.joint_names = ['pelvis', 'right_hip', 'right_knee', 'right_ankle', 'left_hip', 'left_knee', 'left_ankle', 'torso', 'neck', 'nose', 'head', 'left_shoulder', 'left_elbow', 'left_wrist', 'right_shoulder', 'right_elbow', 'right_wrist']
        self.part_list = ['R_UPPER_ARM', 'R_UNDER_ARM', 'L_UPPER_ARM', 'L_UNDER_ARM', 'R_UPPER_LEG', 'R_UNDER_LEG', 'L_UPPER_LEG', 'L_UNDER_LEG']
        self.frame_num = 0
        # False: project the H36M pose; True: project the 3-point reference line.
        self.line_mode = False
        # Base offset of the subject; init_dx is added on top of the camera's x position.
        self.init_dx, self.init_dy, self.init_dz = 3, 0, 0
        # Slider-driven offsets; rz is in degrees (converted with radians()).
        self.dx, self.dy, self.dz, self.rz = 0, 0, 0, 0
        self.compare1_pose = None
        self.compare2_pose = None
        # Dataset type each compare pose was captured with, so it is drawn with
        # the matching skeleton even after line mode is toggled.
        self.compare1_dataset = None
        self.compare2_dataset = None
        self.prev_part_ids = ['']
        self.test_out = widgets.Output()
        try:
            self.load_h36m()
            self.init_panel()
            self.init_plot()
            self.interactive()
            self.init_layout()
            display(self.ui)
        except Exception:
            # Report the full traceback: a bare str(e) hides where setup failed.
            import traceback
            err = traceback.format_exc()
            print(err)
            with self.test_out:
                print(err)

    ### init ##########################################################################################################
    def load_h36m(self):
        """Load H36M world-frame poses and camera params, select the first subject/action,
        and build the virtual camera from camera '54138969' image size and intrinsics."""
        print('load_h36m')
        self.h36m_3d_world, self.h36m_cam_param = load_h36m()
        self.subject_list = natsorted(list(self.h36m_3d_world._data.keys()))
        self.action_list = natsorted(list(self.h36m_3d_world._data[self.subject_list[0]].keys()))
        self.subject = self.subject_list[0]
        self.action = self.action_list[0]
        self.pose_3d_list, self.cam_param = get_pose_seq_and_cam_param(self.h36m_3d_world, self.h36m_cam_param, self.subject, self.action)

        target = '54138969'
        calib_mat = self.cam_param[target]['int']['calibration_matrix']
        self.init_camera(W=self.cam_param[target]['W'], H=self.cam_param[target]['H'],
                         fx=calib_mat[0][0], fy=calib_mat[1][1], cx=calib_mat[0][2], cy=calib_mat[1][2])

    def init_camera(self, W=1000, H=1000, cam_height=1.0,
                    fx=1.0, fy=1.0, cx=500.0, cy=500.0,
                    init_roll_angle=0, init_pitch_angle=0, init_yaw_angle=0):
        """Create the virtual `Camera` at (0, 0, cam_height) with a W x H image.

        The intrinsic matrix is built from fx, fy, cx, cy. The default orientation
        rows are (left, up, forward) with forward along +x. The init_*_angle values
        are passed to `Camera` as roll/pitch/yaw (units as `Camera` expects; the
        UI sliders span +/-45, presumably degrees — TODO confirm).
        Caches the camera's extrinsic, intrinsic and projection matrices.
        """
        print('init_camera')
        self.W, self.H = W, H
        self.calib_mat = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])

        # camera parameter
        self.cam_height = cam_height
        self.cam_origin = np.array([0, 0, self.cam_height])

        forward = [1, 0, 0]
        left = [0, -1, 0]
        up = np.cross(forward, left)
        cam_default_R = np.array([left, up, forward])  # default camera orientation

        self.camera = Camera(origin=self.cam_origin,
                             calib_mat=self.calib_mat,
                             cam_default_R=cam_default_R,
                             roll=init_roll_angle,
                             pitch=init_pitch_angle,
                             yaw=init_yaw_angle,
                             IMAGE_HEIGHT=self.H,
                             IMAGE_WIDTH=self.W)

        self.cam_ext = self.camera.ext_mat
        self.cam_int = self.camera.intrinsic
        self.cam_proj = self.camera.cam_proj

    def init_panel(self):
        """Create every control/read-out widget (selects, texts, sliders, buttons, plot outputs)."""
        print('init panel')
        # select
        self.select_subject = get_subject_select(self.subject_list, rows=7)
        self.select_subject.layout.width = 'max-content'
        self.select_action = get_action_select(self.action_list, rows=7)
        self.select_action.layout.width = 'max-content'

        # text read-outs filled by calculate_error()
        self.vec_cam_origin_to_pelvis = get_str_text('vec_cam2pelvis')
        self.dist_cam_origin_to_pelvis = get_str_text('dist_cam2pelvis')
        self.text_vec_to_lhip = get_str_text('vec_to_lhip')
        self.text_vec_to_rhip = get_str_text('vec_to_rhip')
        self.text_dist_to_lhip = get_str_text('dist_to_lhip')
        self.text_dist_to_rhip = get_str_text('dist_to_rhip')
        self.text_total_length = get_str_text('total_length')
        self.text_l_ratio = get_str_text('l_ratio')
        self.text_r_ratio = get_str_text('r_ratio')

        # slider
        self.frame_slider = get_frame_slider(1)
        self.trans_x_slider = get_trans_slider(min=-2.0, max=2.0, description='trans x')
        self.trans_y_slider = get_trans_slider(min=-2.0, max=2.0, description='trans y')
        self.trans_z_slider = get_trans_slider(min=-2.0, max=2.0, description='trans z')
        self.rot_z_slider = get_trans_slider(min=-90, max=90, description='rot z')
        self.cam_roll_slider = get_trans_slider(min=-45, max=45, description='cam roll')
        self.cam_pitch_slider = get_trans_slider(min=-45, max=45, description='cam pitch')
        self.cam_yaw_slider = get_trans_slider(min=-45, max=45, description='cam yaw')
        self.cam_height_slider = get_trans_slider(value=self.cam_height, min=0.0, max=2.0, description='cam height')

        # button
        self.frame_reset_button = get_reset_button('reset')
        self.total_reset_button = get_reset_button('total reset')
        self.trans_x_reset_button = get_reset_button('reset')
        self.trans_y_reset_button = get_reset_button('reset')
        self.trans_z_reset_button = get_reset_button('reset')
        self.rot_z_reset_button = get_reset_button('reset')
        self.cam_roll_reset_button = get_reset_button('reset')
        self.cam_pitch_reset_button = get_reset_button('reset')
        self.cam_yaw_reset_button = get_reset_button('reset')
        self.cam_height_reset_button = get_reset_button('reset')
        self.set_compare1_button = get_reset_button('set compare1')
        self.set_compare2_button = get_reset_button('set compare2')
        self.set_line_mode_button = get_reset_button('set line mode')
        # plot
        self.plot3d = widgets.Output()
        self.plot2d = widgets.Output()

    def init_plot(self):
        """Create figure 1 (two 3D views) and figure 2 (four 2D panels) inside the output widgets."""
        with self.plot3d:
            fig = plt.figure(1, figsize=(15, 5), layout='tight')
            fig.clear()
            self.ax_3d_1 = axes_3d(fig, loc=121, xlim=(-1, 5), ylim=(-3, 3), zlim=(0, 2), view=(90, 180), show_axis=True)
            self.ax_3d_2 = axes_3d(fig, loc=122, xlim=(-1, 5), ylim=(-3, 3), zlim=(0, 2), view=(0, 90), show_axis=True)
            plt.show()

        with self.plot2d:
            fig = plt.figure(2, figsize=(10, 5), layout='tight')
            fig.clear()
            self.ax_input = axes_2d(fig, loc=221, normalize=True, show_axis=False)
            self.ax_canonical = axes_2d(fig, loc=222, normalize=True, show_axis=False)
            self.ax_compare1 = axes_2d(fig, loc=223, normalize=True, show_axis=False)
            self.ax_compare2 = axes_2d(fig, loc=224, normalize=True, show_axis=False)
            plt.show()

    def init_layout(self):
        """Assemble the widgets into `self.ui` (plots, buttons, controls/read-outs, debug output)."""
        print('layout')

        ui_select = widgets.HBox([self.select_subject, self.select_action])
        block_frame = widgets.HBox([self.frame_slider, self.frame_reset_button])
        block_trans_x = widgets.HBox([self.trans_x_slider, self.trans_x_reset_button])
        block_trans_y = widgets.HBox([self.trans_y_slider, self.trans_y_reset_button])
        block_trans_z = widgets.HBox([self.trans_z_slider, self.trans_z_reset_button])
        block_rot_z = widgets.HBox([self.rot_z_slider, self.rot_z_reset_button])
        block_cam_roll = widgets.HBox([self.cam_roll_slider, self.cam_roll_reset_button])
        block_cam_pitch = widgets.HBox([self.cam_pitch_slider, self.cam_pitch_reset_button])
        block_cam_yaw = widgets.HBox([self.cam_yaw_slider, self.cam_yaw_reset_button])
        block_cam_height = widgets.HBox([self.cam_height_slider, self.cam_height_reset_button])

        ui_control_frame = widgets.HBox([block_frame, self.total_reset_button])
        ui_control_pose = widgets.VBox([block_trans_x, block_trans_y, block_trans_z, block_rot_z])
        ui_control_cam = widgets.VBox([block_cam_roll, block_cam_pitch, block_cam_yaw, block_cam_height])
        ui_control = widgets.HBox([ui_control_pose, ui_control_cam])

        ui_monitor = widgets.VBox([self.vec_cam_origin_to_pelvis, self.dist_cam_origin_to_pelvis, self.text_vec_to_lhip, self.text_vec_to_rhip, self.text_dist_to_lhip, self.text_dist_to_rhip, self.text_total_length, self.text_l_ratio, self.text_r_ratio])

        # Integration
        ui_layer1 = widgets.HBox([self.plot3d, self.plot2d], layout=box_layout)
        ui_layer2 = widgets.HBox([ui_control_frame, self.set_compare1_button, self.set_compare2_button, self.set_line_mode_button])
        ui_layer3 = widgets.HBox([ui_select, ui_control, ui_monitor], layout=box_layout)
        self.ui = widgets.VBox([ui_layer1, ui_layer2, ui_layer3, self.test_out])

    ### interactive ###################################################################################################
    def interactive(self):
        """Bind the selects/sliders to update callbacks and the buttons to their click handlers."""
        print('interactive')
        self.select_interact = widgets.interactive(self.update_select, subject=self.select_subject, action=self.select_action)
        self.frame_interact = widgets.interactive(self.update_frame, frame=self.frame_slider)
        self.trans_interact = widgets.interactive(self.update_pose_diff, trans_x=self.trans_x_slider, trans_y=self.trans_y_slider, trans_z=self.trans_z_slider, rot_z=self.rot_z_slider)
        self.cam_interact = widgets.interactive(self.update_cam_pose, roll=self.cam_roll_slider, pitch=self.cam_pitch_slider, yaw=self.cam_yaw_slider, height=self.cam_height_slider)

        # reset
        self.frame_reset_button.on_click(self.reset_frame)
        self.total_reset_button.on_click(self.reset_total)
        self.trans_x_reset_button.on_click(self.reset_trans_x)
        self.trans_y_reset_button.on_click(self.reset_trans_y)
        self.trans_z_reset_button.on_click(self.reset_trans_z)
        self.rot_z_reset_button.on_click(self.reset_rot_z)
        self.cam_roll_reset_button.on_click(self.reset_cam_roll)
        self.cam_pitch_reset_button.on_click(self.reset_cam_pitch)
        self.cam_yaw_reset_button.on_click(self.reset_cam_yaw)
        self.cam_height_reset_button.on_click(self.reset_cam_height)

        # compare
        self.set_compare1_button.on_click(self.set_compare1)
        self.set_compare2_button.on_click(self.set_compare2)

        # set mode
        self.set_line_mode_button.on_click(self.set_line_mode)

    def _dataset_type(self):
        """Return the drawing `dataset` key matching the current mode."""
        return 'base' if self.line_mode else 'h36m'

    def set_line_mode(self, b):
        """Toggle between the H36M pose and the 3-point reference line, then redraw."""
        self.line_mode = not self.line_mode
        self.visualize_data()

    def reset_frame(self, b):
        self.frame_slider.value = 0

    def reset_total(self, b):
        """Reset all pose/camera sliders and clear both compare poses, then redraw."""
        self.trans_x_slider.value = 0
        self.trans_y_slider.value = 0
        self.trans_z_slider.value = 0
        self.rot_z_slider.value = 0
        self.cam_roll_slider.value = 0
        self.cam_pitch_slider.value = 0
        self.cam_yaw_slider.value = 0
        self.cam_height_slider.value = self.cam_height
        self.compare1_pose = None
        self.compare2_pose = None
        self.compare1_dataset = None
        self.compare2_dataset = None

        self.visualize_data()

    def reset_trans_x(self, b):
        self.trans_x_slider.value = 0

    def reset_trans_y(self, b):
        self.trans_y_slider.value = 0

    def reset_trans_z(self, b):
        self.trans_z_slider.value = 0

    def reset_rot_z(self, b):
        self.rot_z_slider.value = 0

    def reset_cam_roll(self, b):
        self.cam_roll_slider.value = 0

    def reset_cam_pitch(self, b):
        self.cam_pitch_slider.value = 0

    def reset_cam_yaw(self, b):
        self.cam_yaw_slider.value = 0

    def reset_cam_height(self, b):
        # Restore the height the slider was created with (set in init_camera).
        self.cam_height_slider.value = self.cam_height

    def set_compare1(self, b):
        """Capture the current canonical 2D pose (and its dataset type) into compare slot 1."""
        self.compare1_pose = self.pose_2d_norm_canonical.copy()
        self.compare1_dataset = self._dataset_type()
        self.visualize_data()

    def set_compare2(self, b):
        """Capture the current canonical 2D pose (and its dataset type) into compare slot 2."""
        self.compare2_pose = self.pose_2d_norm_canonical.copy()
        self.compare2_dataset = self._dataset_type()
        self.visualize_data()

    def update_select(self, subject, action):
        """React to the subject/action selects: load that sequence, reset frame bounds, redraw.

        On a subject change the action list is repopulated and its first entry selected.
        Errors are printed to the debug output instead of propagating.
        """
        try:
            if len(self.h36m_3d_world._data.keys()) > 0:
                if self.subject != subject:
                    self.select_subject.value = subject
                    self.select_action.options = natsorted(list(self.h36m_3d_world._data[subject].keys()))
                    self.select_action.value = self.select_action.options[0]
                    self.subject = subject
                    self.action = self.select_action.value
                elif self.action != action:
                    self.select_action.value = action
                    self.action = action
                # Load from the tracked selection, not the `action` argument: after a
                # subject change the argument still holds the previous subject's action.
                self.pose_3d_list, self.cam_param = get_pose_seq_and_cam_param(self.h36m_3d_world, self.h36m_cam_param, self.subject, self.action)
                self.frame_slider.max = len(self.pose_3d_list)-1
                self.update_ref_pose()
                self.visualize_data()
        except Exception as e:
            with self.test_out:
                print('select error', e)
                print()

    def update_frame(self, frame):
        """React to the frame slider; redraw only when the frame actually changed."""
        if self.frame_num != frame:
            self.frame_num = frame
            self.update_ref_pose()
            self.visualize_data()

    def update_cam_pose(self, roll, pitch, yaw, height):
        """Apply camera slider values (keeping the camera's x/y origin), refresh matrices, redraw."""
        print('update cam pose')
        self.camera.update_camera_parameter(origin=np.array([self.camera.origin[0], self.camera.origin[1], height]), roll=roll, pitch=pitch, yaw=yaw)
        self.cam_ext = self.camera.ext_mat
        self.cam_int = self.camera.intrinsic
        self.cam_proj = self.camera.cam_proj
        self.visualize_data()

    def update_ref_pose(self):
        """Set `self.ref_pose` from the current frame: rotate by -90 deg about z, then
        shift in x/y so the pelvis (joint 0) sits at the origin (z is kept)."""
        # Clamp in case the stored frame index exceeds a newly selected, shorter sequence.
        frame_num = min(self.frame_num, len(self.pose_3d_list) - 1)
        ref_pose = self.pose_3d_list[frame_num].copy()
        ref_pose = rotate_torso_by_R(ref_pose, Rotation.from_rotvec([0, 0, -np.pi/2]).as_matrix())
        ref_pose[:, :2] -= ref_pose[0, :2]
        self.ref_pose = ref_pose

    def update_pose_diff(self, trans_x, trans_y, trans_z, rot_z):
        """Store the pose translation/rotation slider values and redraw."""
        self.dx = trans_x
        self.dy = trans_y
        self.dz = trans_z
        self.rz = rot_z
        self.visualize_data()

    ### functions #####################################################################################################
    def _project_pose(self):
        """Project `self.pose_3d` with the current camera and derive the 2D variants.

        Sets pose_2d (pixel coords), pose_2d_norm (normalized by W/H), the canonical
        pose relative to point 0, and cam_3d (points transformed by the extrinsics).
        """
        self.pose_2d = projection(self.pose_3d, self.cam_proj)[..., :2]
        self.pose_2d_norm = normalize_screen_coordinates(self.pose_2d, self.W, self.H)
        self.pose_2d_norm_canonical = self.pose_2d_norm - self.pose_2d_norm[0]
        self.cam_3d = projection(self.pose_3d.copy(), self.cam_ext)

    def generate_pose(self):
        """Place the reference pose in front of the camera (base + slider offsets, z-rotation) and project it."""
        self.pose_3d = self.ref_pose.copy() + np.array([self.cam_origin[0]+self.init_dx, self.init_dy, self.init_dz]) + np.array([self.dx, self.dy, self.dz])
        self.pose_3d = rotate_torso_by_R(self.pose_3d, Rotation.from_rotvec([0, 0, radians(self.rz)]).as_matrix())
        self._project_pose()

    def generate_ref_line(self):
        """Build a 3-point segment (center, y-0.2, y+0.2) at camera height, offset and rotated like the pose, and project it."""
        self.pose_3d = np.array([[0, 0, 0], [0, -0.2, 0], [0, 0.2, 0]]) + np.array([self.cam_origin[0]+self.init_dx, self.init_dy, self.init_dz+self.cam_origin[2]]) + np.array([self.dx, self.dy, self.dz])
        self.pose_3d = rotate_torso_by_R(self.pose_3d, Rotation.from_rotvec([0, 0, radians(self.rz)]).as_matrix())
        self._project_pose()

    def calculate_error(self):
        """Refresh the read-out texts.

        Always: the camera-to-point-0 distance and that distance split by the camera
        yaw slider ("z" = d*cos(yaw), "x" = -d*sin(yaw)). In line mode only: the
        canonical 2D vectors/lengths from the center to points 1 ("lhip") and 2
        ("rhip"), their sum, and each length's share of the sum.
        """
        cam_origin = self.camera.origin
        pelvis = self.pose_3d[0]
        mag = np.linalg.norm(pelvis - cam_origin)
        cam_yaw = self.cam_yaw_slider.value

        self.dist_cam_origin_to_pelvis.value = str(mag)
        self.vec_cam_origin_to_pelvis.value = f"z: {mag * cos(radians(cam_yaw)):.2f} x: {-mag * sin(radians(cam_yaw)):.2f}"

        # Was `self.set_line_mode` (a bound method, always truthy); these read-outs
        # only make sense for the 3-point reference line.
        if self.line_mode:
            self.vec_to_lhip = self.pose_2d_norm_canonical[1] - self.pose_2d_norm_canonical[0]
            self.vec_to_rhip = self.pose_2d_norm_canonical[2] - self.pose_2d_norm_canonical[0]
            self.dist_to_lhip = np.linalg.norm(self.vec_to_lhip)
            self.dist_to_rhip = np.linalg.norm(self.vec_to_rhip)
            self.total_length = self.dist_to_lhip + self.dist_to_rhip
            self.l_ratio = self.dist_to_lhip / self.total_length
            self.r_ratio = self.dist_to_rhip / self.total_length

            self.text_vec_to_lhip.value = f"x: {self.vec_to_lhip[0]:.2f} y: {self.vec_to_lhip[1]:.2f}"
            self.text_vec_to_rhip.value = f"x: {self.vec_to_rhip[0]:.2f} y: {self.vec_to_rhip[1]:.2f}"
            self.text_dist_to_lhip.value = f"{self.dist_to_lhip:.2f}"
            self.text_dist_to_rhip.value = f"{self.dist_to_rhip:.2f}"
            self.text_total_length.value = f"{self.total_length:.2f}"
            self.text_l_ratio.value = f"{self.l_ratio:.2f}"
            self.text_r_ratio.value = f"{self.r_ratio:.2f}"

    def visualize_data(self):
        """Regenerate the current pose/line, update read-outs, and redraw all 3D and 2D panels."""
        if self.line_mode:
            self.generate_ref_line()
        else:
            self.generate_pose()
        dataset_type = self._dataset_type()
        self.calculate_error()

        with self.plot3d:
            clear_axes([self.ax_3d_1, self.ax_3d_2])
            plt.sca(self.ax_3d_1)
            self.camera.cam_frame.draw3d()
            plt.sca(self.ax_3d_2)
            self.camera.cam_frame.draw3d()
            draw_3d_pose(self.ax_3d_1, self.pose_3d, dataset=dataset_type)
            draw_3d_pose(self.ax_3d_2, self.pose_3d, dataset=dataset_type)

        with self.plot2d:
            clear_axes([self.ax_input, self.ax_canonical, self.ax_compare1, self.ax_compare2])
            draw_2d_pose(self.ax_input, self.pose_2d_norm, normalize=True, dataset=dataset_type)
            draw_2d_pose(self.ax_canonical, self.pose_2d_norm_canonical, normalize=True, dataset=dataset_type)
            # Draw each compare pose with the skeleton type it was captured with.
            if self.compare1_pose is not None:
                draw_2d_pose(self.ax_compare1, self.compare1_pose, normalize=True, dataset=self.compare1_dataset or dataset_type)
            if self.compare2_pose is not None:
                draw_2d_pose(self.ax_compare2, self.compare2_pose, normalize=True, dataset=self.compare2_dataset or dataset_type)
    
# Instantiate the tool: its __init__ loads the H36M data, builds the widgets and plots,
# wires the callbacks, and displays the UI.
at = analysis_tool()
    

load_h36m
==> Loading 3D data wrt World CS...
init_camera
init panel


interactive
layout


VBox(children=(HBox(children=(Output(), Output()), layout=Layout(border_bottom='solid 1px red', border_left='s…