In [73]:
import gymnasium as gym
import highway_env
from matplotlib import pyplot as plt
import numpy as np
import pprint
import math


In [74]:

class observation_shape:
        def __init__(self, obs, info, num_history):

            self.ego = obs[0]
            self.npc = obs[1]
            r, theta_degrees = self.lidar_proc(self.npc)
            obs = self.obs_edit(self.ego, self.npc, r, theta_degrees)
            self.num_history = num_history
            self.observation_shape = obs.shape
            infos = np.array([info["speed"], info["action"][0]], dtype=np.float16)
            self.info_shape = infos.shape
            
        def lidar_proc(self, npc):
            
            dx = npc[1]
            dy = npc[2]

            r = math.sqrt(dx**2 + dy**2)
            theta = math.atan2(dy, dx)
            theta_degrees = math.degrees(theta)

            return r, theta_degrees

        def obs_edit(self, ego, npc, r, theta_degrees):
            ego = np.concatenate([ego,[r, theta_degrees]])

            obs = np.concatenate([ego,npc])
            print(obs)
            print(obs.shape)
            return obs 
        
        def info_edit(self, info):
            return np.array([info["speed"], info["action"][0]])
                            
        def reset(self):
            
            self.last_observations = [np.zeros(self.observation_shape)] * self.num_history
            self.last_info = [np.zeros(self.info_shape)] * self.num_history

        def update_input(self, obs, info):

            self.ego = obs[0]
            self.npc = obs[1]
            r, theta_degrees = self.lidar_proc(self.npc)
            obs = self.obs_edit(self.ego, self.npc, r, theta_degrees)

            info  = self.info_edit(info)

            self.last_observations.append(obs)
            self.last_observations.pop(0)

            self.last_info.append(info)
            self.last_info.pop(0)

        def get_input(self):
            # obs_stack = np.stack(self.last_observations)
            # info_stack = np.stack(self.last_info)

            input = np.concatenate([self.last_observations[0].flatten(),
                                    self.last_observations[1].flatten(),
                                    self.last_info[0].flatten(),
                                    self.last_info[1].flatten()])


            return input

In [91]:
# envirenment config

env = gym.make('racetrack-v0', render_mode = 'rgb_array')
env.configure({ # type: ignore
    'action': {'lateral': True,
            'longitudinal': False,
            'type': 'ContinuousAction'},
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 2,
        "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h",
                     "heading", "long_off", "lat_off", "ang_off"],
        "features_range": {
            "x": [-100, 100],
            "y": [-100, 100],
            "vx": [-20, 20],
            "vy": [-20, 20]}
    },
    'show_trajectories': True,
    'other_vehicles': 1,
    'duration': 100,
    'collision_reward': -5,
})

# prints env configs
#* obs is flattened to 1D array for nn

pprint.pprint(env.config) # type: ignore
(obs, info), done = env.reset(), False
# obs = np.array(obs.flatten())
print("Environment is setted up.")

{'action': {'lateral': True, 'longitudinal': False, 'type': 'ContinuousAction'},
 'action_reward': -0.3,
 'centering_position': [0.5, 0.5],
 'collision_reward': -5,
 'controlled_vehicles': 1,
 'duration': 100,
 'lane_centering_cost': 4,
 'lane_centering_reward': 1,
 'manual_control': False,
 'observation': {'features': ['presence',
                              'x',
                              'y',
                              'vx',
                              'vy',
                              'cos_h',
                              'sin_h',
                              'heading',
                              'long_off',
                              'lat_off',
                              'ang_off'],
                 'features_range': {'vx': [-20, 20],
                                    'vy': [-20, 20],
                                    'x': [-100, 100],
                                    'y': [-100, 100]},
                 'type': 'Kinematics',
                 'vehicles

  logger.warn(
  logger.warn(


In [98]:
print(obs[0])
print(obs[1])


[ 1.          0.80936617  0.05        0.5         0.          1.
  0.          0.         38.936615    0.          0.        ]
[ 1.          0.25822103 -0.00930938 -0.20981804 -0.08148462  0.9627625
 -0.27034873 -0.27375525  6.843881   -0.          0.        ]


In [99]:
proc = observation_shape(obs,info,2)
proc.reset()
input = proc.get_input()
print(input)

[ 1.00000000e+00  8.09366167e-01  5.00000007e-02  5.00000000e-01
  0.00000000e+00  1.00000000e+00  0.00000000e+00  0.00000000e+00
  3.89366150e+01  0.00000000e+00  0.00000000e+00  2.58388787e-01
 -2.06473309e+00  1.00000000e+00  2.58221030e-01 -9.30938404e-03
 -2.09818035e-01 -8.14846158e-02  9.62762475e-01 -2.70348728e-01
 -2.73755252e-01  6.84388113e+00 -0.00000000e+00  0.00000000e+00]
(24,)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]


In [101]:
proc.update_input(obs,info)
input = proc.get_input()
print(input)

[ 1.00000000e+00  8.09366167e-01  5.00000007e-02  5.00000000e-01
  0.00000000e+00  1.00000000e+00  0.00000000e+00  0.00000000e+00
  3.89366150e+01  0.00000000e+00  0.00000000e+00  2.58388787e-01
 -2.06473309e+00  1.00000000e+00  2.58221030e-01 -9.30938404e-03
 -2.09818035e-01 -8.14846158e-02  9.62762475e-01 -2.70348728e-01
 -2.73755252e-01  6.84388113e+00 -0.00000000e+00  0.00000000e+00]
(24,)
[ 1.00000000e+00  8.09366167e-01  5.00000007e-02  5.00000000e-01
  0.00000000e+00  1.00000000e+00  0.00000000e+00  0.00000000e+00
  3.89366150e+01  0.00000000e+00  0.00000000e+00  2.58388787e-01
 -2.06473309e+00  1.00000000e+00  2.58221030e-01 -9.30938404e-03
 -2.09818035e-01 -8.14846158e-02  9.62762475e-01 -2.70348728e-01
 -2.73755252e-01  6.84388113e+00 -0.00000000e+00  0.00000000e+00
  1.00000000e+00  8.09366167e-01  5.00000007e-02  5.00000000e-01
  0.00000000e+00  1.00000000e+00  0.00000000e+00  0.00000000e+00
  3.89366150e+01  0.00000000e+00  0.00000000e+00  2.58388787e-01
 -2.06473309e+00  

In [103]:
proc.reset()
input = proc.get_input()
print(input.shape)

(52,)
