Skip to content

Commit

Permalink
Introduce policy classes and policy manager (#475)
Browse files Browse the repository at this point in the history
* init policy manager

* init test script

* fix a bug

* add some todo

* no bug

* runnable

* finish the step

* bug free

* format

* add some scripts

* runnable

* WIP (I can't set the target speed right now!)

* bug free

* WIP Runnable but some bugs exist

* WIP

* WIP

* format

* fix many tests

* format
  • Loading branch information
PENG Zhenghao committed Jul 29, 2021
1 parent a21c684 commit 577f8c8
Show file tree
Hide file tree
Showing 28 changed files with 1,157 additions and 549 deletions.
2 changes: 1 addition & 1 deletion pgdrive/engine/asset_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from pgdrive.utils import is_win
from pgdrive.utils.utils import is_win
import os
import pathlib
import sys
Expand Down
4 changes: 2 additions & 2 deletions pgdrive/engine/core/pg_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from direct.showbase import ShowBase
from panda3d.bullet import BulletDebugNode
from panda3d.core import AntialiasAttrib, loadPrcFileData, LineSegs, PythonCallbackObject

from pgdrive.constants import RENDER_MODE_OFFSCREEN, RENDER_MODE_NONE, RENDER_MODE_ONSCREEN, PG_EDITION, CamMask, \
BKG_COLOR

from pgdrive.engine.asset_loader import AssetLoader, initialize_asset_loader, close_asset_loader
from pgdrive.engine.core.collision_callback import pg_collision_callback
from pgdrive.engine.core.force_fps import ForceFPS
Expand All @@ -18,7 +18,7 @@
from pgdrive.engine.core.pg_physics_world import PGPhysicsWorld
from pgdrive.engine.core.sky_box import SkyBox
from pgdrive.engine.core.terrain import Terrain
from pgdrive.utils import is_mac, setup_logger
from pgdrive.utils.utils import is_mac, setup_logger


def _suppress_warning():
Expand Down
2 changes: 1 addition & 1 deletion pgdrive/engine/core/sky_box.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from panda3d.core import SamplerState, Shader, NodePath, ConfigVariableString

from pgdrive.constants import CamMask
from pgdrive.utils import is_mac
from pgdrive.engine.asset_loader import AssetLoader
from pgdrive.utils.object import Object
from pgdrive.utils.utils import is_mac


class SkyBox(Object):
Expand Down
9 changes: 7 additions & 2 deletions pgdrive/engine/pgdrive_engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from pgdrive.scene_managers.base_manager import BaseManager
from typing import Dict, AnyStr
from pgdrive.engine.pgdrive_scene_cull import PGDriveSceneCull

import numpy as np

from pgdrive.engine.core.pg_world import PGWorld
from pgdrive.engine.pgdrive_scene_cull import PGDriveSceneCull
from pgdrive.scene_managers.base_manager import BaseManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -162,6 +164,9 @@ def after_step(self) -> Dict:
return step_infos

def update_state_for_all_target_vehicles(self):

# TODO(pzh): What does this function do? Do we need to call it at every step?

if self.detector_mask is not None:
is_target_vehicle_dict = {
v_obj.name: self.agent_manager.is_active_object(v_obj.name)
Expand Down
11 changes: 8 additions & 3 deletions pgdrive/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from pgdrive.scene_managers.map_manager import MapManager
from pgdrive.scene_managers.object_manager import TrafficSignManager
from pgdrive.scene_managers.traffic_manager import TrafficManager
from pgdrive.utils import PGConfig, merge_dicts
from pgdrive.utils import get_np_random
from pgdrive.scene_managers.policy_manager import PolicyManager
from pgdrive.utils import PGConfig, merge_dicts, get_np_random
from pgdrive.utils.engine_utils import get_pgdrive_engine, initialize_pgdrive_engine, close_pgdrive_engine, \
pgdrive_engine_initialized, set_global_random_seed

Expand Down Expand Up @@ -53,6 +53,7 @@

# ===== Vehicle =====
vehicle_config=dict(
increment_steering=False,
show_navi_mark=True,
wheel_friction=0.6,
max_engine_force=500,
Expand Down Expand Up @@ -347,7 +348,10 @@ def for_each_vehicle(self, func, *args, **kwargs):
@property
def vehicle(self):
"""A helper to return the vehicle only in the single-agent environment!"""
assert len(self.vehicles) == 1, "env.vehicle is only supported in single-agent environment!"
assert len(self.vehicles) == 1, (
"env.vehicle is only supported in single-agent environment!"
if len(self.vehicles) > 1 else "Please initialize the environment first!"
)
ego_v = self.vehicles[DEFAULT_AGENT]
return ego_v

Expand Down Expand Up @@ -424,6 +428,7 @@ def setup_engine(self):
self.pgdrive_engine.register_manager("traffic_manager", TrafficManager())
self.pgdrive_engine.register_manager("map_manager", MapManager())
self.pgdrive_engine.register_manager("object_manager", TrafficSignManager())
self.pgdrive_engine.register_manager("policy_manager", PolicyManager())

@property
def current_map(self):
Expand Down
1 change: 0 additions & 1 deletion pgdrive/envs/pgdrive_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
mini_map=(84, 84, 250), # buffer length, width
rgb_cam=(84, 84), # buffer length, width
depth_cam=(84, 84, True), # buffer length, width, view_ground
increment_steering=False,
side_detector=dict(num_lasers=0, distance=50), # laser num, distance
show_side_detector=False,
lane_line_detector=dict(num_lasers=0, distance=20), # laser num, distance
Expand Down
25 changes: 21 additions & 4 deletions pgdrive/envs/pgdrive_env_v2_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import gym
import numpy as np

from pgdrive.envs.pgdrive_env_v2 import PGDriveEnvV2
from pgdrive.obs.state_obs import LidarStateObservation
from pgdrive.obs.observation_base import ObservationBase
from pgdrive.obs.state_obs import LidarStateObservation
from pgdrive.utils import PGConfig
from pgdrive.utils.engine_utils import get_pgdrive_engine
from pgdrive.utils.math_utils import norm, clip

DISTANCE = 50
Expand Down Expand Up @@ -155,9 +157,24 @@ def traffic_vehicle_state(self, vehicle):
s.append(state['vy'] / vehicle.MAX_SPEED)
s.append(state["cos_h"])
s.append(state["sin_h"])
s.append(state["cos_d"])
s.append(state["sin_d"])
s.append(vehicle.target_speed / vehicle.MAX_SPEED)

# TODO(pzh): This is stupid here!!
pm = get_pgdrive_engine().policy_manager
p = pm.get_policy(vehicle.name)
# s.append(state["cos_d"])
# s.append(state["sin_d"])

# TODO(pzh): This is a workaround!!
if p is None:
s.append(0.0)
s.append(0.0)
s.append(0.0)
else:
s.append(p.destination[0])
s.append(p.destination[1])
target_speed = p.target_speed
s.append(target_speed / vehicle.MAX_SPEED)

s.append(vehicle.speed / vehicle.MAX_SPEED)
s.append(math.cos(vehicle.heading))
s.append(math.sin(vehicle.heading))
Expand Down
Empty file added pgdrive/policy/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions pgdrive/policy/base_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pgdrive.utils.object import Object


class BasePolicy(Object):
    """Base class for policies in which every lifecycle hook is a no-op.

    Construction is delegated to ``Object``. The hook methods below do nothing
    themselves; presumably concrete policies override the ones they need.
    """

    def __init__(self, name=None, random_seed=None):
        """Forward ``name`` and ``random_seed`` unchanged to ``Object.__init__``."""
        super().__init__(name=name, random_seed=random_seed)

    def destroy(self):
        """Do nothing.

        NOTE(review): if ``Object`` defines its own ``destroy``, it is not
        invoked from here -- confirm that this is intended.
        """

    def reset(self):
        """Do nothing."""

    def before_step(self, *args, **kwargs):
        """Accept any arguments and do nothing."""

    def after_step(self, *args, **kwargs):
        """Accept any arguments and do nothing."""

    def step(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
Loading

0 comments on commit 577f8c8

Please sign in to comment.