In [1]:
import numpy as np

from functools import partial
from kusanagi.shell.arduino import SerialPlant
from kusanagi.utils import print_with_stamp
from kusanagi.shell.cartpole import cartpole_loss
from kusanagi.shell.cost import build_loss_func

In [2]:
# Configuration: control rate, per-task cost functions and the serial plant.
# State dims are [cart_pos, cart_vel, pole_vel, pole_angle].
# NOTE(review): state_indices=[0, 2, 3, 1] presumably maps the raw readings
# coming over the serial link onto this ordering -- confirm against SerialPlant.

# control rate (seconds per control step)
dt = 0.1

# hardware settings
SERIAL_PORT = '/dev/ttyACM0'
MAX_U = [10]  # control limit passed to SerialPlant as maxU


def make_task_loss(cart_target):
    """Return a cartpole loss whose target places the cart at `cart_target`.

    The tasks differ only in the target cart position; the cart-velocity and
    pole-velocity targets are 0.0 and the pole-angle target is pi.
    """
    return partial(cartpole_loss, target=[cart_target, 0.0, 0.0, np.pi])


# cost functions, one per task (differing only in target cart position)
loss_task1 = make_task_loss(0.0)
loss_task2 = make_task_loss(0.5)
loss_task3 = make_task_loss(-0.5)

# NOTE(review): the meaning of the positional `False` argument depends on
# build_loss_func's signature -- confirm and pass it by keyword if possible.
loss_func = build_loss_func(loss_task1, False, 'cartpole_loss')

env = SerialPlant(state_indices=[0, 2, 3, 1], dt=dt, port=SERIAL_PORT,
                  maxU=MAX_U, loss_func=loss_func)

In [3]:
# Reset the plant without blocking. Calling env.reset() instead waits for the
# user to reset the plant and hit Enter before returning.
# NOTE(review): _reset is a private SerialPlant method; it is used here only
# because it exposes the wait_for_user flag.
# The returned state follows the documented ordering:
# [cart_pos, cart_vel, pole_vel, pole_angle].
cart_pos, cart_vel, pole_vel, pole_angle = env._reset(wait_for_user=False)

print((cart_pos, cart_vel, pole_vel, pole_angle))

(0.0, 0.0, 0.0, 0.0)


In [6]:
# Drive the plant with a 1 Hz sinusoidal control signal for a fixed duration,
# printing the per-step cost on a single, continuously overwritten line.
# env.reset() blocks until the user resets the plant and hits Enter.
env.reset()

excitation_duration = 4.0  # seconds
excitation_amplitude = 10  # same magnitude as the maxU=[10] limit of the plant
# derive the step count from the control rate (40 steps when dt = 0.1)
n_steps = int(round(excitation_duration / dt))

t = env.t
for _ in range(n_steps):
    u = np.array([excitation_amplitude]) * np.sin(2 * np.pi * t)
    obs, cost, done, info = env.step(u)
    t = info['t']
    # trailing spaces blank out leftovers from a longer previous value
    print_with_stamp('%f        ' % cost, same_line=True)
    # stop instead of stepping an episode the environment reports as finished
    if done:
        break

[2018-04-19 11:37:31.760382] SerialPlant > Please reset your plant to its initial state and hit Enter

[2K[2018-04-19 11:37:41.411807] 0.339259        