In [181]:
from curses.ascii import ctrl
from ucl.common import byte_print, decode_version, decode_sn, getVoltage, pretty_print_obj, lib_version
from ucl.lowState import lowState
from ucl.lowCmd import lowCmd
from ucl.unitreeConnection import unitreeConnection, LOW_WIFI_DEFAULTS, LOW_WIRED_DEFAULTS
from ucl.enums import GaitType, SpeedLevel, MotorModeLow
from ucl.complex import motorCmd, motorCmdArray
import time
import sys
import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from datetime import datetime
import glob

import threading


# You can use one of the 3 Presets WIFI_DEFAULTS, LOW_CMD_DEFAULTS or HIGH_CMD_DEFAULTS.
# IF NONE OF THEM ARE WORKING YOU CAN DEFINE A CUSTOM ONE LIKE THIS:

class Robot :
    def __init__(self):
        ## Initialize connection ##
        print(f'Running lib version: {lib_version()}')
        self.conn = unitreeConnection(LOW_WIFI_DEFAULTS)
        self.conn.startRecv()
        ## == Initialize connection == ##

        ## instantiate lowlevel command and state objects ##
        self.lcmd = lowCmd()
        self.lstate = lowState()
        self.mCmdArr = motorCmdArray()


        # 로봇 제어에 사용되는 딕셔너리, lowcmd에서 이미 지정된 순서서
        self.d = {'FR_0':0, 'FR_1':1, 'FR_2':2,
                'FL_0':3, 'FL_1':4, 'FL_2':5,
                'RR_0':6, 'RR_1':7, 'RR_2':8,
                'RL_0':9, 'RL_1':10, 'RL_2':11 }

        # IsaacLab 에서 사용되는 Joint 순서, 제어시 사용함
        self.joint = [
            'FL_0', 'FL_1', 'FL_2', # FL_hip, FL_thigh, FL_calf
            'FR_0', 'FR_1', 'FR_2', # FR_hip, FR_thigh, FR_calf
            'RL_0', 'RL_1', 'RL_2', # RL_hip, RL_thigh, RL_calf
            'RR_0', 'RR_1', 'RR_2'  # RR_hip, RR_thigh, RR_calf
            ]
        
        ## Start Position
        #  REAL : 22,66,-160
        #  Sim : 30, 62, 180 
        # --> Sim -> Real : 0, 0, -340
        # 펴면 
        # Sim : 6, 46, -85
        # Real : 

        # IsaacLab 에서 사용되는 기본 관절 위치
        self.default_joint_pos = np.array([
            0.1,   # FL_hip
            0.8,   # FL_thigh
            -1.5,  # FL_calf
            -0.1,  # FR_hip
            0.8,   # FR_thigh
            -1.5,  # FR_calf
            0.1,   # RL_hip
            1.0,   # RL_thigh
            -1.5,  # RL_calf
            -0.1,  # RR_hip
            1.0,   # RR_thigh
            -1.5   # RR_calf
        ])

        ## get initial state & print Log ##
        # Send empty command to tell the dog the receive port and initialize the connection
        cmd_bytes = self.lcmd.buildCmd(debug=False)
        self.conn.send(cmd_bytes)


        ## Robot Control Parameters
        # control HZ : 50Hz
        self.ctrldt = 1/50 # Ctrl loop HZ : 50Hz
        # PD Parameters
        self.kp = 8
        self.kd = 1
        # Motor Torque Limit
        self.tau = [20,10,10,20,10,10,20,10,10,20,10,10]

        ## Define Default Robot Status values
        # robot status
        self.qlist = np.zeros(12) #joint pos
        self.dqlist = np.zeros(12) #joint vel
        self.rpy = np.zeros(3) #base rpy
        self.gyroscope = np.zeros(3) #base ang vel
        self.accelerometer = np.zeros(3) #base lin acc
        self.quaternion = np.zeros(4) #base quat

        # Calculat'd
        self.base_lin_vel = np.zeros(3) # base lin vel
        self.projected_gravity = np.zeros(3) # projected gravity
        self.filtered_accel = np.zeros(3)

        # Policy rel.
        self.actions = np.zeros(12) # actions(Current Target)
        self.last_actions = np.zeros(12) # last actions(Previous Target)

        # Control Target
        self.velocity_commands = np.zeros(3) # velocity commands

        # Timestamp
        self.timestamp = time.time() # Initialize timestamp
        


    ## define a function to print log data
    # 나중에 안쓰는 메서드일듯
    def printLog(self):
        print(f'Cycles:\t\t\t{self.lstate.bms.cycle}')
        data = self.conn.getData()
        packet = data[-1] #Get Last Data
        print('+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=')
        self.lstate.parseData(packet)

        print(f'SN [{byte_print(self.lstate.SN)}]:\t{decode_sn(self.lstate.SN)}')
        print(f'Ver [{byte_print(self.lstate.version)}]:\t{decode_version(self.lstate.version)}')
        print(f'SOC:\t\t\t{self.lstate.bms.SOC} %')
        print(f'Overall Voltage:\t{getVoltage(self.lstate.bms.cell_vol)} mv') #something is still wrong here ?!
        print(f'Current:\t\t{self.lstate.bms.current} mA')
        print(f'Temps BQ:\t\t{self.lstate.bms.BQ_NTC[0]} °C, {self.lstate.bms.BQ_NTC[1]}°C')
        print(f'Temps MCU:\t\t{self.lstate.bms.MCU_NTC[0]} °C, {self.lstate.bms.MCU_NTC[1]}°C')
        print(f'FootForce:\t\t{self.lstate.footForce}')
        print(f'FootForceEst:\t\t{self.lstate.footForceEst}')
        print(f'IMU Temp:\t\t{self.lstate.imu.temperature}')
        print(f'MotorState FR_0 MODE:\t\t{self.lstate.motorState[self.d["FR_0"]].mode}')
        print('+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=')

    def retrive_data(self):
        """
        conn에서 최신 패킷(가장 마지막)을 받아 lstate에 저장
        이후 값 파싱해서 self.* 에 저장
        """
        data = self.conn.getData()
        self.lstate.parseData(data[-1])
        self.q_list = np.array([self.lstate.motorState[self.d[joint]].q for joint in self.joint]) # 현재 관절위치 (절대위치!!!!)
        self.q_list_rel = np.array(self.default_joint_pos) - self.q_list # 현재 관절위치 상대위치

        self.dq_list = np.array([self.lstate.motorState[self.d[joint]].dq for joint in self.joint])

        # process IMU data
        self.rpy = np.array(self.lstate.imu.rpy)
        self.quaternion = np.array(self.lstate.imu.quaternion)
        self.gyroscope = np.array(self.lstate.imu.gyroscope)
        self.accelerometer = np.array(self.lstate.imu.accelerometer)

        # Calculate Extra Values 
        self.projected_gravity = self.estim_projected_gravity()
        self.estimate_lin_vel()

    def estim_projected_gravity(self):
        """
        projected_gravity 계산 - Quaternion 우선 사용!
        """
        # Quaternion 사용
        w, x, y, z = self.quaternion
        return np.array([
            2 * (x * z - w * y),
            2 * (w * x + y * z),
            w**2 - x**2 - y**2 + z**2
        ])
    
        
    def estimate_lin_vel(self):
        """
        가속도 적분으로 base_lin_vel 추정 (비추천)
        
        Args:
            accelerometer: (3,) 현재 가속도
            quaternion: (w, x, y, z)
            dt: 시간 간격 (예: 0.02초)
            prev_velocity: 이전 속도
            prev_filtered_accel: 이전 필터링된 가속도
            alpha: 필터 계수
            decay: 드리프트 감쇠
        
        Returns:
            velocity: (3,) 추정 속도
            filtered_accel: (3,) 필터링된 가속도
        """

        ## Set Default Values
        self.alpha = 0.9
        self.decay = 0.999
        
        # Set Previous Values
        prev_velocity = self.base_lin_vel * 1 # preveny var.linking
        prev_filtered_accel = self.filtered_accel * 1 # preveny var.linking

        # 1. projected_gravity
        pg = self.estim_projected_gravity()


        # 2. 중력 제거
        linear_accel = np.array(self.accelerometer) - pg * 9.81
        
        # 3. 필터링
        filtered_accel = self.alpha * np.array(prev_filtered_accel) + (1 - self.alpha) * linear_accel
        
        # 4. 적분
        velocity = prev_velocity + filtered_accel * self.ctrldt 
        
        # 5. 드리프트 보정
        velocity *= self.decay
        velocity[2] = 0 # Z축속도 0 으로 고정
        
        # Write output to class variables
        self.base_lin_vel = velocity
        self.filtered_accel = filtered_accel


    """
    아래 메서드들 이미 retrive_data 에서 처리됨, 없애도될듯?
    """
    def get_joint_pos(self):
        """
        self.lstate에서 최신 패킷을 받아 12개 관절의 각도(q, radian) 리스트를 반환
        --> **joint_pos** @ IsaacLab

        Returns:
            q_list (np.ndarray): 12개 관절의 각도 리스트 (radian, joint 순서는 global 'joint' 리스트와 동일)
                        오류시 None 반환
        """
        self.q_list = np.array([self.lstate.motorState[self.d[joint]].q for joint in self.joint])
        return self.q_list
    
    def get_joint_vel(self):
        """
        self.lstate에서 최신 패킷을 받아 12개 관절의 속도(dq, radian/s) 리스트를 반환
        --> **joint_vel** @ IsaacLab
        """
        self.dq_list = np.array([self.lstate.motorState[self.d[joint]].dq for joint in self.joint])
        return self.dq_list

    def get_rpy(self):
        """
        self.lstate에서 최신 패킷을 받아 로봇의 imu RPY 각도(roll, pitch, yaw, radian) 리스트를 반환
        --> **base_rpy** @ IsaacLab
        """
        self.rpy = np.array(self.lstate.imu.rpy)
        return self.rpy
    
    def get_gyroscope(self):
        """
        self.lstate에서 최신 패킷을 받아 로봇의 imu 각속도(roll, pitch, yaw, radian/s) 리스트를 반환
        --> **base_gyroscope** @ IsaacLab
        """
        self.gyroscope = np.array(self.lstate.imu.gyroscope)
        return self.gyroscope

    def move_motor(self,target_pos):
        """
        관절들들을 `target_pos` 로 이동
        """
        joint = self.joint
        for i in range(len(joint)):
            self.mCmdArr.setMotorCmd(joint[i],  motorCmd(mode=MotorModeLow.Servo, q=target_pos[i], dq = 0, Kp = self.kp, Kd = self.kd, tau = self.tau[i]))
        
        self.lcmd.motorCmd = self.mCmdArr
        cmd_bytes = self.lcmd.buildCmd(debug=False)
        self.conn.send(cmd_bytes)
    
    def do_motion(self,init_pos,target_pos,duration,debug=False):
        """
        특정 관절을 `init_pos` 에서 `target_pos` 로 선형 보간하여 이동

        Args:
            init_pos (list): 초기 관절 위치
            target_pos (list): 타겟 관절 위치
            duration (float): 이동 시간 (초)
        """
        hz = 1/self.ctrldt
        motion_step = int(duration*hz)
        print(f"do motion while : {motion_step/hz} seconds")
        for i in range(motion_step):
            # self.retrive_data()
            qDes = self.jointLinearInterpolation(init_pos,target_pos,i/motion_step)
            qDes = list(qDes)
            if debug:
                # print(f"target : {self.rad2deg(qDes,2)}")
                print(f"target : {qDes}")
                print(f"Step : {i}/{motion_step}")
            else:
                # print(f"target : {self.rad2deg(qDes,2)}")
                print(f"target : {qDes}")
                print(f"Step : {i}/{motion_step}")
                self.move_motor(qDes)

            time.sleep(self.ctrldt)

    ## 별도의 쓰레드에서 loop_func(루프문) 함수 실행
    def follow_target_pos(self):
        """
        별도의 쓰레드에서 loop_func(루프문) 함수 실행
        loop_func: 루프를 돌릴 함수. self를 인자로 받거나, 필요한 경우 *args, **kwargs로 인자 전달
        """

        def robot_control_loop():
            while True : 
                self.timestamp = time.time()
                self.retrive_data() # Refresh Data
                ## 
                # Do something HERE!
                ## 
                loop_delay = self.ctrldt - (time.time() - self.timestamp) # 루프 지연시간이 보상된 sleeptime
                time.sleep(loop_delay)


        self.control_thread = threading.Thread(target=robot_control_loop, daemon=True)
        self.control_thread.start()



    ## Utils
    def deg2rad(self,deg,digit=0):
        """
        numpy array 또는 list 형식의 각도를 radian으로 변환
        """
        deg = np.array(deg)
        if digit != 0:
            result = np.array([round(math.radians(d),digit) for d in deg])
        else:
            result = np.array([math.radians(d) for d in deg])
        return result
    
    def rad2deg(self,rad,digit=0):
        """
        numpy array 또는 list 형식의 각도를 도로 변환
        """
        rad = np.array(rad)
        if digit != 0:
            result = np.array([round(math.degrees(r),digit) for r in rad])
        else:
            result = np.array([math.degrees(r) for r in rad])
        return result


    ## Linear interpolation between two joint positions
    # Input : initPos(rate=0) ~~~~~~ targetPos(rate=1)
    # Output : interpolated_position =(p)
    def jointLinearInterpolation(self,initPos, targetPos, rate):
        rate = np.fmin(np.fmax(rate, 0.0), 1.0)
        p = initPos*(1-rate) + targetPos*rate
        return p

# 제어 인스턴스 생성
robot = Robot()
time.sleep(0.02)

robot.retrive_data() # 먼저 self.lstate 에 최신패킷 넣고

# 그다음 값 파싱해서 쓸것
print("default_joint_pos : ",robot.rad2deg(robot.default_joint_pos,2))
print("joint_pos : ",robot.rad2deg(robot.get_joint_pos(),2))
print("joint_vel : ",robot.rad2deg(robot.get_joint_vel(),2))
print("rpy : ",robot.rad2deg(robot.get_rpy(),2))
print("gyroscope : ",robot.rad2deg(robot.get_rpy(),2))

Running lib version: 0.2
default_joint_pos :  [  5.73  45.84 -85.94  -5.73  45.84 -85.94   5.73  57.3  -85.94  -5.73
  57.3  -85.94]
joint_pos :  [ -37.3    42.34 -156.41   34.97   38.76 -155.25  -32.75   15.89 -154.96
   50.01   30.59 -164.27]
joint_vel :  [-0.42  0.18  0.04 -0.21 -0.12 -0.19 -0.11 -0.13  0.13  0.84  0.32 -0.45]
rpy :  [ -1.91  -9.13 121.62]
gyroscope :  [ -1.91  -9.13 121.62]


In [99]:
## For State Logging 
# Default Power-on Joint Position (rad): [0.36284593,  1.17937028, -2.79138637, -0.34794939,  1.17319369, -2.74948239,  0.37132359,  1.18784809, -2.75150084, -0.27728164, 1.19438803, -2.88072538]
# Default Joint Position Position (deg): [  20.8   67.6 -159.9  -19.9   67.2 -157.5   21.3   68.1 -157.6  -15.9  68.4 -165.1]
# Calf Joint Range : -160(Folded) ~ -50(Unfolded) = 110 Deg
# Ankle Joint Range : -90(Folded) ~ 0(Unfolded) = 90 Deg 

robot.get_joint_pos()

array([ 0.891976  ,  1.82252562, -0.85875148,  0.89615434,  2.48100138,
       -1.36983621,  0.91032422,  2.56008625, -1.35510111,  0.90202814,
        2.52587271, -1.42893791])

In [178]:
robot.retrive_data() # 먼저 self.lstate 에 최신패킷 넣고

initpos = robot.get_joint_pos()
targetpos = np.copy(initpos)

targetpos[0:3] = [0.1,0.8,-1.5]  # FL_hip, FL_thigh, FL_calf
# print("initpos : ",robot.rad2deg(initpos,2))
# print("targetpos : ",robot.rad2deg(targetpos,2))
print("initpos : ",initpos)
print("targetpos : ",targetpos)
print("motion Diff (deg) : ",robot.rad2deg(targetpos - initpos,2))

initpos :  [-0.55177772  0.81180114 -2.71872044  0.7618432   0.72696346 -2.70184565
 -0.62171876  0.76135874 -2.63584065  0.8810761   0.88034946 -2.8669591 ]
targetpos :  [ 0.1         0.8        -1.5         0.7618432   0.72696346 -2.70184565
 -0.62171876  0.76135874 -2.63584065  0.8810761   0.88034946 -2.8669591 ]
motion Diff (deg) :  [37.34 -0.68 69.83  0.    0.    0.    0.    0.    0.    0.    0.    0.  ]


In [180]:
np.rad2deg(-2.72)

-155.84452027558393

In [179]:
robot.do_motion(initpos,targetpos,5,debug=False)

do motion while : 5.0 seconds
target : [-0.551777720451355, 0.8118011355400085, -2.7187204360961914, 0.761843204498291, 0.726963460445404, -2.701845645904541, -0.6217187643051147, 0.7613587379455566, -2.635840654373169, 0.8810760974884033, 0.8803494572639465, -2.8669590950012207]
Step : 0/250
target : [-0.5491706095695496, 0.8117539309978485, -2.7138455543518063, 0.761843204498291, 0.726963460445404, -2.701845645904541, -0.6217187643051147, 0.7613587379455566, -2.635840654373169, 0.8810760974884033, 0.8803494572639465, -2.8669590950012207]
Step : 1/250
target : [-0.5465634986877441, 0.8117067264556884, -2.708970672607422, 0.761843204498291, 0.726963460445404, -2.701845645904541, -0.6217187643051147, 0.7613587379455566, -2.635840654373169, 0.8810760974884033, 0.8803494572639465, -2.8669590950012207]
Step : 2/250
target : [-0.5439563878059387, 0.8116595219135285, -2.704095790863037, 0.761843204498291, 0.726963460445404, -2.701845645904541, -0.6217187643051147, 0.7613587379455566, -2.6358

KeyboardInterrupt: 

In [54]:
while True :
    robot.retrive_data() # 먼저 self.lstate 에 최신패킷 넣고
    print(robot.rad2deg(robot.get_joint_pos(),1))
    time.sleep(0.2)

[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.3 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1  -37.3   59.4 -155.8   49.3
   24.4 -159.5]
[ -22.2   44.9 -156.4   44.8   45.5 -155.1 

KeyboardInterrupt: 

In [48]:
np.rad2deg(-1.5)

-85.94366926962348

In [16]:
dummy_obs = [ 0.8249, -0.0841,  0.0609,  0.3209,  0.2600,  0.1070,  0.0558, -0.0461,
        -0.9974,  0.0462,  0.0558,  2.6461,  0.7924,  0.0000,  0.0595, -0.3614,
         0.3551, -0.5186,  0.4916,  0.0299, -0.4008, -0.9144, -0.3664, -0.1314,
         0.2164,  0.4206,  0.1846, -1.5743,  0.2659,  0.0053, -0.1578, -1.1460,
         1.1299,  3.0969,  1.6914, -1.2098,  1.2192, -2.4923,  0.7213, -1.5323,
         1.4793, -2.1506,  1.6954, -0.6748, -1.1294, -3.2514, -1.4343, -0.9111,
         2.1069,  2.2329,  1.4590,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,  1.7000,
         1.7000,  1.7000,  1.7000,  1.7000,  1.4263,  1.7000,  1.7000,  1.6781,
         1.5467,  1.5987,  1.5927,  1.4468,  1.4710,  1.4407,  1.3787,  1.3066,
         1.3767,  1.3718,  1.3515,  1.6321,  1.6398,  1.6475,  1.6545,  1.6276,
         1.6014,  1.5759,  1.3814,  1.5270,  1.4680,  1.4455,  1.4236,  1.4022,
         1.2308,  1.0367,  1.0207,  1.0051,  0.9898,  1.0260,  1.0756,  1.2795,
         1.2619,  1.2447,  1.2278,  1.1736,  1.1226,  1.0636,  1.0529,  1.0789,
         1.0707,  1.0503,  0.9984,  1.2508,  1.2338,  1.0780,  1.0633,  1.0488,
         1.0346,  1.0208,  1.0471,  0.9797,  0.9668,  0.9542,  0.9418,  0.9297,
         0.9178,  0.8552,  0.8989,  0.9392,  0.9768,  0.9650,  0.9534,  0.9421,
         0.9309,  0.9468,  0.9358,  0.9250,  0.9144,  0.9039,  0.8937,  0.8726,
         0.8400,  0.7919,  0.7570,  0.7741,  0.8677,  0.8945,  0.9220,  0.9115,
         0.9011,  0.8909,  0.8809,  0.7487,  0.6961,  0.6880,  0.6799,  0.6720,
         0.6642,  0.6566,  0.6490,  0.6416,  0.7624,  0.7542,  0.7462,  0.7383,
         0.7304,  0.6008,  0.5942,  0.5876,  0.5811,  0.5747,  0.7076,  0.7004,
         0.6933,  0.6863,  0.6794,  0.7514,  0.7436,  0.7359,  0.7283,  0.5664,
         0.5601,  0.5540,  0.5479,  0.5508,  0.5449,  0.5390,  0.5332,  0.5275,
         0.5218,  0.5162,  0.5107,  0.5053,  0.5000,  0.4947,  0.4907,  0.4856,
         0.4805,  0.4755,  0.4705,  0.4656,  0.4608,  0.4560,  0.4513,  0.4467,
         0.4421,  0.4376,  0.5506,  0.4679,  0.4630,  0.4581,  0.4533,  0.4486,
         0.4439,  0.4393,  0.4347,  0.4380,  0.4335,  0.4291,  0.4247,  0.4203,
         0.4160,  0.4118,  0.4076,  0.4272,  0.4527,  0.4800,  0.4565,  0.3947,
         0.3846,  0.3808,  0.3769,  0.3732,  0.3694,  0.3657,  0.3621,  0.3585,
         0.3549,  0.3513,  0.3478,  0.3739,  0.3701,  0.3664,  0.3627,  0.3590,
         0.3554,  0.3518,  0.3482,  0.3447,  0.3481,  0.3446,  0.3412,  0.3415,
         0.3612,  0.3823,  0.4047,  0.4072,  0.4035,  0.3999,  0.3963,  0.3928,
         0.3461,  0.3067,  0.3036,  0.3006,  0.2976,  0.2947,  0.2918,  0.2889,
         0.2860,  0.2832,  0.2804,  0.3004,  0.2974,  0.2944,  0.2915,  0.2885,
         0.2857,  0.2828,  0.2800,  0.2598,  0.2896,  0.3063,  0.3239,  0.3426,
         0.3448,  0.3418,  0.3388,  0.3358,  0.3328,  0.3299,  0.3270,  0.3241,
         0.3213,  0.3037,  0.2638,  0.2422,  0.2398,  0.2373,  0.2350,  0.2326,
         0.2303,  0.2279,  0.2256,  0.2414,  0.2389,  0.2365,  0.2341,  0.2157,
         0.2134,  0.2112,  0.2089,  0.2067,  0.2905,  0.2926,  0.2900,  0.2875,
         0.2849,  0.2824,  0.2800,  0.2775,  0.2751,  0.2727,  0.2703,  0.2679,
         0.2655,  0.2632,  0.2609,  0.2313,  0.2001,  0.1901,  0.1881,  0.1861,
         0.1842,  0.1823,  0.1803,  0.1929,  0.1761,  0.1742,  0.1722,  0.1703,
         0.1684,  0.1665,  0.1646,  0.1628,  0.2438,  0.2416,  0.2395,  0.2374,
         0.2352,  0.2332,  0.2311,  0.2290,  0.2270,  0.2249,  0.2229,  0.2209,
         0.2189,  0.2170,  0.2150,  0.2131,  0.2023,  0.1744,  0.1493,  0.1471,
         0.1455,  0.1438,  0.1422]
dummy_depth = dummy_obs[51:]

In [39]:
import torch
import sys
import os

from legged_loco_policy.simple_cnn_inference import (
    load_policy,
    infer,
    create_proprio_observation,
    create_example_observation,
    create_realistic_observation,
    ActorDepthCNN,
    get_activation
)

def make_obs_from_flat_tensor(flat):
    """
    51개는 proprio, 나머지는 24x32로 reshape 해서 depth로 dict로 반환.
    Args:
        flat (array-like or torch.Tensor): (51 + 768,) shape 또는 그 이상

    Returns:
        dict: {'proprio': torch.Tensor(51,), 'depth': torch.Tensor(24,32)}
    """
    flat = torch.as_tensor(flat, dtype=torch.float32)
    proprio = flat[:51]
    depth_raw = flat[51:]
    depth = depth_raw.reshape(24, 32)
    return {"proprio": proprio, "depth": depth}
    
print("✓ simple_cnn_inference.py import 완료!")

actor = load_policy("./legged_loco_policy/model_1500.pt")
print("✓ actor load 완료!")


✓ simple_cnn_inference.py import 완료!
✓ actor load 완료!


In [43]:

robot.retrive_data() # 먼저 self.lstate 에 최신패킷 넣고
# 데이터 구조화 
robot.velocity_commands = [0,0,0]


proprio = [robot.base_lin_vel,
           robot.gyroscope,
           robot.projected_gravity,
           robot.rpy,
           robot.velocity_commands,
           robot.qlist,
           robot.dqlist,
           robot.actions,
            ]
proprio_flat = torch.tensor(np.concatenate([np.array(x).flatten() for x in proprio]))

depth = torch.tensor(dummy_depth)

proprio.append(depth)


obs_flat = torch.tensor(np.concatenate([np.array(x).flatten() for x in proprio]))
obs = make_obs_from_flat_tensor(obs_flat)

In [46]:
## Finally... Inference
infer(obs,actor)

tensor([-0.0710,  0.3499, -0.4363,  0.2816, -0.0771, -0.3769, -0.2241, -0.3106,
        -0.8599,  0.1944,  0.2097, -0.2437])