In [1]:
from regelum.system import System
import sys
sys.path.append('../')

from src.simulator import SciPy
from regelum.scenario import Scenario
from regelum.policy import Policy
from regelum import callback

from regelum.system import System
import numpy as np

class KinematicPoint(System):
    """Planar point whose velocity is commanded directly.

    The state is the position ``(x, y)`` and the input is the velocity
    ``(v_x, v_y)``, so the state derivative is simply the input.
    """

    _name = "kinematic-point"
    _system_type = "diff_eqn"
    _dim_state = 2
    _dim_inputs = 2
    _dim_observation = 2
    # Two distinct list objects on purpose: binding both names to one list
    # would let an in-place change to one (e.g. ``.extend``) silently alter
    # the other as well.
    _state_naming = ["x", "y"]
    _observation_naming = ["x", "y"]
    _inputs_naming = ["v_x", "v_y"]
    _action_bounds = [[-10.0, 10.0], [-10.0, 10.0]]

    def _compute_state_dynamics(self, time, state, inputs):
        """Return the state derivative, which equals the velocity ``inputs``.

        ``time`` and ``state`` are accepted but not used.
        """
        return inputs

    def compute_closed_loop_rhs(self, time, state):
        """Evaluate the dynamics with the inputs currently stored in ``self.inputs``."""
        return self._compute_state_dynamics(time, state, self.inputs)
    
class PDController(Policy):
    """Feedback policy that steers the kinematic point towards the origin.

    ``get_action`` returns the negated first two entries of the first
    observation row. The ``pd_coefs`` list is stored on the instance but is
    not read by ``get_action``.
    """

    def __init__(self, system: KinematicPoint, sampling_time: float):
        """Store the controlled system, the sampling time and the coefficients."""
        super().__init__()
        self.system = system
        self.sampling_time = sampling_time
        self.pd_coefs: list[float] = [1, 0.1]

    def get_action(self, observation):
        """Return a ``1 x 2`` array ``[[-observation[0, 0], -observation[0, 1]]]``."""
        first, second = observation[0, 0], observation[0, 1]
        return np.array([[-first, -second]])
    
# Starting position of the point: (x, y) = (2, 2).
initial_state = np.array([2.0, 2.0])

# The plant to be controlled.
kinematic_point = KinematicPoint()

# Simulator for the plant, started from `initial_state`.
simulator = SciPy(
    system=kinematic_point,
    state_init=initial_state,
    time_final=4,
    max_step=0.1,
)

# Tie the policy and the simulator together for one episode of one iteration.
scenario = Scenario(
    policy=PDController(system=kinematic_point, sampling_time=0.01),
    simulator=simulator,
    sampling_time=0.01,
    N_episodes=1,
    N_iterations=1,
)

scenario.run()

history no exist
create hist


In [8]:
# Drop the last two records of the simulator's state history.
history = simulator.state_history[:-2]

# Copy the naming lists: the in-place extend below must not modify the class
# attributes of KinematicPoint, and re-running this cell must not keep
# appending the action names again and again.
states_name = list(kinematic_point._state_naming)
actions_name = list(kinematic_point._inputs_naming)

# NOTE: features_name is the same list object as states_name, so after the
# extend both names refer to the state names followed by the action names.
# The alias is kept so that code indexing either name by feature column works.
features_name = states_name
features_name.extend(actions_name)

# Each history record is indexed as (state, time, action).
states = np.array([record[0] for record in history])
actions = np.array([record[2] for record in history])
times = np.array([record[1] for record in history])

# One row per record: state columns followed by action columns.
features = np.concatenate((states, actions), axis=1)


In [9]:
# Display the feature matrix (state columns followed by action columns).
features

In [10]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation
import matplotlib
# Render with the Qt5 backend (switch to 'Agg' for off-screen rendering).
matplotlib.use('Qt5Agg')

# One stacked subplot per feature column. squeeze=False makes subplots()
# always return a 2-D array of Axes; ravel() flattens it so that `axes` is an
# indexable 1-D array even when there is only a single feature column.
fig, axes = plt.subplots(features.shape[1], squeeze=False)
axes = axes.ravel()


# ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
# line, = ax.plot([], [], lw=2)



# initialization function: plot the background of each frame
def init():
    """Reset every subplot to an empty, correctly scaled state.

    For feature column ``j`` the matching axis is cleared, its x-range is set
    to the full time span, its y-range to the min/max of that column, and its
    y-label to ``features_name[j]``. The last axis also gets the shared
    ``time`` x-label.

    Returns:
        The module-level ``axes`` collection.
    """
    for j, ax in enumerate(axes):
        ax.clear()
        ax.set_xlim(times[0], times[-1])
        ax.set_ylim(features[:, j].min(), features[:, j].max())
        # There is one axis per feature column, so the label has to come from
        # the list of feature names rather than from the state names alone.
        ax.set(ylabel=features_name[j])

    axes[-1].set(xlabel='time')
    return axes

# animation function.  This is called sequentially
def animate(i):
    """Draw the first ``i`` samples of every feature column.

    If an axis already holds a line, its data is replaced instead of plotting
    a new one, so the number of line artists per axis stays at one no matter
    how many frames have been rendered.

    Args:
        i: Frame index; the slice ``[:i]`` of ``times`` and of each feature
            column is shown.

    Returns:
        The module-level ``axes`` collection.
    """
    for j, ax in enumerate(axes):
        if ax.lines:
            # Reuse the existing line rather than stacking a new one per frame.
            ax.lines[0].set_data(times[:i], features[:i, j])
        else:
            ax.plot(times[:i], features[:i, j], c='red')

    return axes


    # x = np.linspace(0, 2, 1000)
    # y = np.sin(2 * np.pi * (x - 0.01 * i))
    # line.set_data(x, y)
    # return line,

# Build the animation. Frame i shows the slice [:i] of the data, so
# len(history) + 1 frames are needed for the last frame to contain every
# sample. Blitting stays at its default (off): animate() returns the Axes,
# not the individual artists that blitting would need.
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=len(history) + 1, interval=20)

# To export the animation as an mp4 (requires ffmpeg to be installed):
#     anim.save('basic_animation.mp4', fps=30)
# See https://matplotlib.org/stable/api/animation_api.html for details.

plt.show()