Skip to content

Commit

Permalink
add torch expert (#578)
Browse files Browse the repository at this point in the history
* add torch expert and fix import chain

* format code

* add torch expert doc
  • Loading branch information
CarlDegio committed Dec 12, 2023
1 parent f247bdd commit 231c91e
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 5 deletions.
2 changes: 1 addition & 1 deletion metadrive/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from metadrive.constants import TerminationState
from metadrive.examples.ppo_expert.numpy_expert import expert
from metadrive.examples.ppo_expert import expert


def get_terminal_state(info):
Expand Down
6 changes: 5 additions & 1 deletion metadrive/examples/ppo_expert/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from metadrive.examples.ppo_expert.numpy_expert import expert
# BUG FIX: `import importlib` alone does not guarantee that the
# `importlib.util` submodule is bound — it only happens to work on some
# CPython versions as an implementation detail. Import the submodule
# explicitly before calling `find_spec`.
import importlib.util

if importlib.util.find_spec("torch") is not None:
    # Prefer the torch-based expert when torch is installed.
    from metadrive.examples.ppo_expert.torch_expert import torch_expert as expert
else:
    # Fall back to the pure-numpy implementation of the same policy.
    from metadrive.examples.ppo_expert.numpy_expert import expert
11 changes: 8 additions & 3 deletions metadrive/examples/ppo_expert/numpy_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,25 @@ def obs_correction(obs):
def expert(vehicle, deterministic=False, need_obs=False):
global _expert_weights
global _expert_observation
obs_cfg = dict(
expert_obs_cfg = dict(
lidar=dict(num_lasers=240, distance=50, num_others=4, gaussian_noise=0.0, dropout_prob=0.0),
random_agent_model=False
)
origin_obs_cfg = dict(
lidar=dict(num_lasers=240, distance=50, num_others=0, gaussian_noise=0.0, dropout_prob=0.0),
random_agent_model=False
)

if _expert_weights is None:
_expert_weights = np.load(ckpt_path)
config = get_global_config().copy()
config["vehicle_config"].update(obs_cfg)
config["vehicle_config"].update(expert_obs_cfg)
_expert_observation = LidarStateObservation(config)
assert _expert_observation.observation_space.shape[0] == 275, "Observation not match"

vehicle.config.update(obs_cfg)
vehicle.config.update(expert_obs_cfg)
obs = _expert_observation.observe(vehicle)
vehicle.config.update(origin_obs_cfg)
obs = obs_correction(obs)
weights = _expert_weights
obs = obs.reshape(1, -1)
Expand Down
101 changes: 101 additions & 0 deletions metadrive/examples/ppo_expert/torch_expert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import torch
import numpy as np
import os.path as osp
from metadrive.engine.engine_utils import get_global_config
from metadrive.obs.state_obs import LidarStateObservation

# Pretrained PPO checkpoint shipped alongside this module.
ckpt_path = osp.join(osp.dirname(__file__), "expert_weights.npz")
# Lazily populated on the first torch_expert() call: weight dict (torch tensors).
_expert_weights = None
# Cached LidarStateObservation built with the expert's observation config.
_expert_observation = None


def obs_correction(obs):
obs[15] = 1 - obs[15]
obs[10] = 1 - obs[10]
return obs


def numpy_to_torch(weights, device):
    """
    Convert numpy weights to torch tensors and move them to the specified device.
    :params:
        weights: numpy weights
        device: torch device
    :return:
        torch_weights: weights in torch tensor
    """
    return {name: torch.from_numpy(weights[name]).to(device) for name in weights.keys()}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def torch_expert(vehicle, deterministic=False, need_obs=False):
    """
    load weights by torch, use ppo actor to predict action
    :params:
        vehicle: vehicle instance
        deterministic: whether to use deterministic policy
        need_obs: whether to return observation
    :return:
        action: action predicted by expert
    """
    global _expert_weights
    global _expert_observation
    # The expert was trained with 4 "other vehicles" in the lidar observation...
    expert_obs_cfg = dict(
        lidar=dict(num_lasers=240, distance=50, num_others=4, gaussian_noise=0.0, dropout_prob=0.0),
        random_agent_model=False
    )
    # ...while the running environment uses num_others=0; the two configs are
    # swapped in and out around the observe() call below.
    origin_obs_cfg = dict(
        lidar=dict(num_lasers=240, distance=50, num_others=0, gaussian_noise=0.0, dropout_prob=0.0),
        random_agent_model=False
    )
    with torch.no_grad():  # Disable gradient computation
        if _expert_weights is None:
            # First call: load the .npz checkpoint and convert it to torch
            # tensors on `device`; also build the cached observation object.
            _expert_weights = numpy_to_torch(np.load(ckpt_path), device)
            config = get_global_config().copy()
            config["vehicle_config"].update(expert_obs_cfg)
            _expert_observation = LidarStateObservation(config)
            assert _expert_observation.observation_space.shape[0] == 275, "Observation not match"

        # Temporarily apply the expert's obs config, observe, then restore the
        # original config so the rest of the simulation is unaffected.
        vehicle.config.update(expert_obs_cfg)
        obs = _expert_observation.observe(vehicle)
        vehicle.config.update(origin_obs_cfg)
        obs = obs_correction(obs)
        obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)  # Convert to tensor and move to device
        weights = _expert_weights
        # PPO actor: two tanh hidden layers followed by a linear output head.
        x = torch.matmul(obs, weights["default_policy/fc_1/kernel"]) + weights["default_policy/fc_1/bias"]
        x = torch.tanh(x)
        x = torch.matmul(x, weights["default_policy/fc_2/kernel"]) + weights["default_policy/fc_2/bias"]
        x = torch.tanh(x)
        x = torch.matmul(x, weights["default_policy/fc_out/kernel"]) + weights["default_policy/fc_out/bias"]
        x = x.squeeze(0).cpu()  # Move back to CPU and remove batch dimension
        # Head output is [mean, log_std]; torch.split with size 2 yields two
        # chunks of 2 — NOTE(review): assumes a 2-dim action space, confirm.
        mean, log_std = torch.split(x, 2, dim=-1)
        if deterministic:
            # Deterministic policy: return the Gaussian mean directly.
            return (mean.numpy(), obs.cpu().numpy()) if need_obs else mean.numpy()
        std = torch.exp(log_std)
        action = torch.normal(mean, std).cpu()  # Move back to CPU
        return (action.numpy(), obs.cpu().numpy()) if need_obs else action.numpy()


def torch_value(obs, weights):
    """
    ppo critic to predict value
    :params:
        obs: 1-D numpy observation
        weights: mapping of layer name -> torch tensor (e.g. from numpy_to_torch)
    :return:
        value: value predicted by critic (numpy array, shape (1,))
    """
    # BUG FIX: the original immediately re-bound `weights = _expert_weights`,
    # silently discarding the caller-supplied argument. Use the argument, and
    # place the observation on the same device as the weights.
    dev = next(iter(weights.values())).device
    with torch.no_grad():  # Disable gradient computation
        x = torch.from_numpy(obs).float().unsqueeze(0).to(dev)  # add batch dim, move to device
        # PPO critic: two tanh hidden layers followed by a linear value head.
        x = torch.matmul(x, weights["default_policy/fc_value_1/kernel"]) + weights["default_policy/fc_value_1/bias"]
        x = torch.tanh(x)
        x = torch.matmul(x, weights["default_policy/fc_value_2/kernel"]) + weights["default_policy/fc_value_2/bias"]
        x = torch.tanh(x)
        x = torch.matmul(x, weights["default_policy/value_out/kernel"]) + weights["default_policy/value_out/bias"]
    return x.squeeze(0).cpu().numpy()  # Move back to CPU and remove batch dimension

0 comments on commit 231c91e

Please sign in to comment.