-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Read me file updated and added cartpole gif
- Loading branch information
Charana Umesha
committed
Jun 7, 2023
1 parent
f9c8ace
commit 439c78e
Showing
2 changed files
with
84 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Reinforcement Learning with Stable Baselines3 | ||
|
||
This is a simple example of using Stable Baselines3, a library for reinforcement learning, to train an agent on the CartPole-v0 environment. | ||
|
||
![CartPole](cartpole.gif) | ||
|
||
## Dependencies | ||
Make sure you have the following dependencies installed: | ||
- stable-baselines3 | ||
- gym | ||
- pyglet | ||
|
||
You can install them using pip: | ||
``` | ||
pip install stable-baselines3[extra] | ||
pip install pyglet==1.5.27 | ||
``` | ||
|
||
## Load Environment | ||
First, we import the necessary dependencies and create an instance of the CartPole-v0 environment: | ||
```python | ||
import os | ||
import gym | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.vec_env import DummyVecEnv | ||
from stable_baselines3.common.evaluation import evaluate_policy | ||
|
||
environment_name = 'CartPole-v0' | ||
env = gym.make(environment_name) | ||
``` | ||
|
||
## Training | ||
To train the agent, we initialize the PPO algorithm and pass in the environment. We then call the `learn` method to start the training process: | ||
```python | ||
log_path = os.path.join('Training', 'Logs') | ||
env = DummyVecEnv([lambda: env]) | ||
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path) | ||
model.learn(total_timesteps=20000) | ||
``` | ||
|
||
## Save and Load the Model | ||
You can save the trained model to a file and load it later for evaluation or further training: | ||
```python | ||
PPO_Path = os.path.join('Training', 'Saved Models', 'PPO_Model_Cartpole') | ||
model.save(PPO_Path) | ||
model = PPO.load(PPO_Path, env=env) | ||
``` | ||
|
||
## Evaluation | ||
To evaluate the performance of the trained agent, you can use the `evaluate_policy` function: | ||
```python | ||
evaluate_policy(model, env, n_eval_episodes=10, render=True) | ||
``` | ||
|
||
## Testing the Model | ||
You can test the trained model by running episodes and observing its behavior: | ||
```python | ||
episodes = 5 | ||
for episode in range(1, episodes + 1): | ||
obs = env.reset() | ||
done = False | ||
score = 0 | ||
|
||
while not done: | ||
env.render() | ||
action, _ = model.predict(obs) | ||
obs, reward, done, info = env.step(action) | ||
score += reward | ||
|
||
print('Episode: {} Score: {}'.format(episode, score)) | ||
|
||
env.close() | ||
``` | ||
|
||
## Viewing Training Logs | ||
You can visualize the training progress using TensorBoard. First, specify the log directory and start TensorBoard: | ||
```python | ||
training_log_path = os.path.join(log_path, 'PPO_1') | ||
!tensorboard --logdir={training_log_path} | ||
``` | ||
Then, open `localhost:6006` in your browser to view the training logs. | ||
|
||
## Conclusion | ||
Reinforcement learning with Stable Baselines3 is a powerful tool for training agents in various environments. By following the steps in this example, you can train, save, and evaluate a reinforcement learning agent for the CartPole-v0 environment. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.