# Training an RL agent with a standard environment
In this notebook, we show how to train an RL agent with the stable-baselines3 library on an environment provided by CyclesGym.

In [None]:
import gym
import cyclesgym
import numpy as np
import wandb

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor, VecNormalize
from wandb.integration.sb3 import WandbCallback

from cyclesgym.utils.paths import PROJECT_PATH

First, we initialize a wandb session to track all the parameters of interest, as well as the statistics recorded during training.

In [None]:
# Track PPO parameters with wandb.
# Every hyperparameter lives in this one dict so it is logged with the run and
# later cells can read it back through `config`.
config = dict(
    total_timesteps=1000,  # total environment steps passed to model.learn
    n_steps=80,            # PPO rollout length per environment per update
    batch_size=80,         # PPO minibatch size
    n_epochs=10,           # PPO optimization epochs per rollout
    run_id=0,              # NOTE(review): not read by any cell visible here
    verbose=1,             # stable-baselines3 logging verbosity
    n_process=1,           # number of parallel environment worker processes
    device='cpu',          # torch device used by PPO
    seed=0,                # random seed for the run; pass to PPO(seed=...) for reproducible training
)

wandb.init(
    config=config,
    sync_tensorboard=True,  # upload the TensorBoard logs written during training
    project='notebook_experiments',
    monitor_gym=True,
    save_code=True,
    dir=PROJECT_PATH,
)

# From here on, read hyperparameters from wandb's copy of the config, so any
# values wandb injects (e.g. during a sweep) are the ones actually used.
config = wandb.config

Next, we initialize the vectorized environment.

In [None]:
def env_maker():
    """Build a single CyclesGym environment for the vectorized wrapper.

    The environment is the 1-year fertilization task with fixed weather in
    Rock Springs ('CornShortRockSpringsFW-v1'). It is wrapped in gym's
    RecordEpisodeStatistics.

    Returns:
        The wrapped gym environment.
    """
    base_env = gym.make('CornShortRockSpringsFW-v1')
    wrapped_env = gym.wrappers.RecordEpisodeStatistics(base_env)
    return wrapped_env

# Vectorize the environment: one env_maker factory per worker process.
# NOTE(review): 'fork' is presumably chosen so workers inherit the env
# registrations made by `import cyclesgym`; it is not available on Windows.
env = SubprocVecEnv(
    [env_maker] * config['n_process'],
    start_method='fork',
)

# Monitor episode statistics, written under the 'runs' path. This wrapper sits
# below VecNormalize, so it sees rewards before they are normalized.
env = VecMonitor(env, filename='runs')

# Normalize observations and rewards. The clipping bound is deliberately huge
# so that, in practice, clipping never kicks in.
env = VecNormalize(
    env,
    norm_obs=True,
    norm_reward=True,
    clip_obs=5000.0,
    clip_reward=5000.0,
)


Finally, we train the model. We can monitor training in wandb using the link printed above when the wandb session was initialized.

In [None]:
# Build a PPO agent with an MLP policy on the normalized vectorized env,
# taking every hyperparameter from the wandb-tracked config.
model = PPO(
    'MlpPolicy',
    env,
    n_steps=config['n_steps'],
    batch_size=config['batch_size'],
    n_epochs=config['n_epochs'],
    verbose=config['verbose'],
    # Write TensorBoard logs into the wandb run directory so sync_tensorboard uploads them.
    tensorboard_log=wandb.run.dir,
    device=config['device'],
    # Seed the policy init, action sampling and the vectorized env for
    # reproducible runs. If the config has no 'seed' this is None, which keeps
    # the unseeded behavior.
    seed=config.get('seed'),
)

# Train; WandbCallback forwards training progress to the active wandb run.
model.learn(total_timesteps=config["total_timesteps"], callback=[WandbCallback()])