Skip to content

Commit

Permalink
Read me file updated and added cartpole gif
Browse files Browse the repository at this point in the history
  • Loading branch information
Charana Umesha committed Jun 7, 2023
1 parent f9c8ace commit 439c78e
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Reinforcement Learning with Stable Baselines3

This is a simple example of using Stable Baselines3, a library for reinforcement learning, to train an agent on the CartPole-v0 environment.

![CartPole](cartpole.gif)

## Dependencies
Make sure you have the following dependencies installed:
- stable-baselines3
- gym
- pyglet

You can install them using pip:
```
pip install stable-baselines3[extra]
pip install pyglet==1.5.27
```

## Load Environment
First, we import the necessary dependencies and create an instance of the CartPole-v0 environment:
```python
import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

environment_name = 'CartPole-v0'
env = gym.make(environment_name)
```

## Training
To train the agent, we initialize the PPO algorithm and pass in the environment. We then call the `learn` method to start the training process:
```python
log_path = os.path.join('Training', 'Logs')
env = DummyVecEnv([lambda: env])
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)
model.learn(total_timesteps=20000)
```

## Save and Load the Model
You can save the trained model to a file and load it later for evaluation or further training:
```python
PPO_Path = os.path.join('Training', 'Saved Models', 'PPO_Model_Cartpole')
model.save(PPO_Path)
model = PPO.load(PPO_Path, env=env)
```

## Evaluation
To evaluate the performance of the trained agent, you can use the `evaluate_policy` function:
```python
evaluate_policy(model, env, n_eval_episodes=10, render=True)
```

## Testing the Model
You can test the trained model by running episodes and observing its behavior:
```python
episodes = 5
for episode in range(1, episodes + 1):
obs = env.reset()
done = False
score = 0

while not done:
env.render()
action, _ = model.predict(obs)
obs, reward, done, info = env.step(action)
score += reward

print('Episode: {} Score: {}'.format(episode, score))

env.close()
```

## Viewing Training Logs
You can visualize the training progress using TensorBoard. First, specify the log directory and start TensorBoard:
```python
training_log_path = os.path.join(log_path, 'PPO_1')
!tensorboard --logdir={training_log_path}
```
Then, open `localhost:6006` in your browser to view the training logs.

## Conclusion
Reinforcement learning with Stable Baselines3 is a powerful tool for training agents in various environments. By following the steps in this example, you can train, save, and evaluate a reinforcement learning agent for the CartPole-v0 environment.
Binary file added cartpole.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 439c78e

Please sign in to comment.