Skip to content

Commit

Permalink
fix ipynb (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanyiLi committed Nov 30, 2023
1 parent bb145cf commit b7b5b1e
Showing 1 changed file with 16 additions and 25 deletions.
41 changes: 16 additions & 25 deletions metadrive/examples/Basic_MetaDrive_Usages.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,32 +139,20 @@
"\n",
"ep_reward = 0.0\n",
"obs, info = env.reset()\n",
"frames = []\n",
"for i in range(1000):\n",
" obs, reward, terminated, truncated, info = env.step(expert(env.vehicle))\n",
" ep_reward += reward\n",
" frame = env.render(mode=\"top_down\", film_size=(800, 800), track_target_vehicle=True, screen_size=(500, 500))\n",
" frames.append(frame)\n",
" env.render(mode=\"top_down\", screen_record=True, track_target_vehicle=True, screen_size=(500, 500))\n",
" if terminated or truncated:\n",
" print(\"Arriving Destination: {}\".format(info[\"arrive_dest\"]))\n",
" print(\"\\nEpisode reward: \", ep_reward)\n",
" break\n",
"\n",
"print(\"\\nThe last returned information: {}\".format(info))\n",
"\n",
"env.top_down_renderer.generate_gif()\n",
"env.close()\n",
"print(\"\\nMetaDrive successfully run!\")\n",
"\n",
"# render image\n",
"print(\"\\nGenerate gif...\")\n",
"import pygame\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"imgs = [pygame.surfarray.array3d(frame) for frame in frames]\n",
"imgs = [Image.fromarray(img) for img in imgs]\n",
"imgs[0].save(\"demo.gif\", save_all=True, append_images=imgs[1:], duration=50, loop=0)\n",
"print(\"\\nOpen gif...\")\n",
"from IPython.display import Image\n",
"Image(open(\"demo.gif\", 'rb').read())\n"
]
Expand Down Expand Up @@ -206,7 +194,7 @@
" obs, reward, terminated, truncated, info = env.step(expert(env.vehicle))\n",
" ep_reward += reward\n",
" ep_cost += info[\"cost\"]\n",
" frame = env.render(mode=\"top_down\", film_size=(1500, 1500), track_target_vehicle=True, screen_size=(500, 500))\n",
" frame = env.render(mode=\"top_down\", no_window=True, track_target_vehicle=True, screen_size=(500, 500))\n",
" frames.append(frame)\n",
" if terminated or truncated:\n",
" print(\"Arriving Destination: {}\".format(info[\"arrive_dest\"]))\n",
Expand All @@ -223,7 +211,7 @@
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"imgs = [pygame.surfarray.array3d(frame) for frame in frames]\n",
"imgs = [frame for frame in frames]\n",
"imgs = [Image.fromarray(img) for img in imgs]\n",
"imgs[0].save(\"demo.gif\", save_all=True, append_images=imgs[1:], duration=50, loop=0)\n",
"print(\"\\nOpen gif...\")\n",
Expand Down Expand Up @@ -269,7 +257,10 @@
" for a in action.values(): \n",
" a[-1] = 1.0\n",
" o,r,tm,tc,i = env.step(action)\n",
" frame = env.render(mode=\"top_down\", film_size=(500, 500), track_target_vehicle=False, screen_size=(500, 500))\n",
" frame = env.render(mode=\"top_down\", \n",
" scaling=4, # 4 pixels per meter\n",
" camera_position=env.current_map.get_center_point(), \n",
" screen_size=(500, 500))\n",
" frames.append(frame)\n",
" env.close()\n",
"\n",
Expand All @@ -279,7 +270,7 @@
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"imgs = [pygame.surfarray.array3d(frame) for frame in frames]\n",
"imgs = [frame for frame in frames]\n",
"imgs = [Image.fromarray(img) for img in imgs]\n",
"imgs[0].save(\"demo.gif\", save_all=True, append_images=imgs[1:], duration=50, loop=0)\n",
"print(\"\\nOpen gif...\")\n",
Expand All @@ -303,23 +294,23 @@
"from metadrive.policy.replay_policy import ReplayEgoCarPolicy\n",
"from metadrive.constants import HELP_MESSAGE\n",
"from metadrive.engine.asset_loader import AssetLoader\n",
"from metadrive.envs.real_data_envs.waymo_env import WaymoEnv\n",
"from metadrive.envs import ScenarioEnv\n",
"\n",
"\n",
"class DemoWaymoEnv(WaymoEnv):\n",
"class DemoEnv(ScenarioEnv):\n",
" def reset(self, seed=None):\n",
" if self.engine is not None:\n",
" seeds = [i for i in range(self.config[\"num_scenarios\"])]\n",
" seeds.remove(self.current_seed)\n",
" seed = random.choice(seeds)\n",
" return super(DemoWaymoEnv, self).reset(seed=seed)\n",
" return super(DemoEnv, self).reset(seed=seed)\n",
"\n",
"\n",
"extra_args = dict(film_size=(1200, 1200))\n",
"asset_path = AssetLoader.asset_path\n",
"\n",
"try:\n",
" env = DemoWaymoEnv(\n",
" env = DemoEnv(\n",
" {\n",
" \"manual_control\": False,\n",
" \"reactive_traffic\": False,\n",
Expand Down Expand Up @@ -353,7 +344,7 @@
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"imgs = [pygame.surfarray.array3d(frame) for frame in frames]\n",
"imgs = [frame for frame in frames]\n",
"imgs = [Image.fromarray(img) for img in imgs]\n",
"imgs[0].save(\"demo.gif\", save_all=True, append_images=imgs[1:], duration=50, loop=0)\n",
"print(\"\\nOpen gif...\")\n",
Expand Down Expand Up @@ -397,7 +388,7 @@
"from metadrive.utils.draw_top_down_map import draw_top_down_map\n",
"\n",
"env = MetaDriveEnv(config=dict(\n",
" environment_num=100,\n",
" num_scenarios=100,\n",
" map=7,\n",
" start_seed=random.randint(0, 1000)\n",
"))\n",
Expand Down Expand Up @@ -443,7 +434,7 @@
"from metadrive import MetaDriveEnv\n",
"\n",
"env = MetaDriveEnv(config=dict(\n",
" environment_num=100,\n",
" num_scenarios=100,\n",
" map=\"CrTRXOS\",\n",
" start_seed=random.randint(0, 1000)\n",
"))\n",
Expand Down

0 comments on commit b7b5b1e

Please sign in to comment.