# aquacrop-gym: PPO example

This notebook shows the process used to train a PPO agent to make irrigation decisions within AquaCrop-OSPy.



Import libraries and functions

In [1]:
from aquacrop.classes import *
from aquacrop.core import *

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from aquacropgym.utils import calc_eto_faopm
from aquacropgym.envs import CropEnv, tunis_maize_config
from aquacropgym.utils import evaluate_agent, evaluate_agent_single_year
from aquacropgym.utils import calc_eto_faopm

import copy

import ray

In [2]:
# Shut down any Ray instance left over from an earlier run of this cell,
# then start a fresh one with the resources listed below.
ray.shutdown()

ray_resources = dict(num_cpus=1, num_gpus=0)  # set number of cpus and gpus available
ray.init(**ray_resources)

RayContext(dashboard_url='', python_version='3.8.16', ray_version='1.13.0', ray_commit='e4ce38d001dbbe09cd21c497fedd03d692b2be3e', address_info={'node_ip_address': '172.17.2.195', 'raylet_ip_address': '172.17.2.195', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2025-09-17_07-30-02_521222_1714616/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2025-09-17_07-30-02_521222_1714616/sockets/raylet', 'webui_url': '', 'session_dir': '/tmp/ray/session_2025-09-17_07-30-02_521222_1714616', 'metrics_export_port': 63111, 'gcs_address': '172.17.2.195:51411', 'address': '172.17.2.195:51411', 'node_id': '84ee59cc6746e414a023348ba04337eacc713c962cbe67592cfcb7f2'})



[2m[36m(RolloutWorker pid=1714873)[0m Starting new season in year  1990
[2m[36m(RolloutWorker pid=1714873)[0m the:  [0.05673147 0.06       0.06       0.06       0.06       0.06
[2m[36m(RolloutWorker pid=1714873)[0m  0.06       0.06       0.06       0.06       0.06       0.06      ]
[2m[36m(RolloutWorker pid=1714873)[0m dz:  0     0.1
[2m[36m(RolloutWorker pid=1714873)[0m 1     0.1
[2m[36m(RolloutWorker pid=1714873)[0m 2     0.1
[2m[36m(RolloutWorker pid=1714873)[0m 3     0.1
[2m[36m(RolloutWorker pid=1714873)[0m 4     0.1
[2m[36m(RolloutWorker pid=1714873)[0m 5     0.1
[2m[36m(RolloutWorker pid=1714873)[0m 6     0.3
[2m[36m(RolloutWorker pid=1714873)[0m 7     0.3
[2m[36m(RolloutWorker pid=1714873)[0m 8     0.3
[2m[36m(RolloutWorker pid=1714873)[0m 9     0.3
[2m[36m(RolloutWorker pid=1714873)[0m 10    0.3
[2m[36m(RolloutWorker pid=1714873)[0m 11    0.3
[2m[36m(RolloutWorker pid=1714873)[0m Name: dz, dtype: float64
[2m[36m(RolloutWorker p

In [3]:
# Resolve the Tunis climate file, load it, and overwrite the precipitation
# column with zeros so that irrigation is the only water input.
weather_path = get_filepath("tunis_climate.txt")
weather_df = prepare_weather(weather_path).assign(Precipitation=0)  # force no rain

## Define crop simulation config options

In [None]:
# --- experiment settings -----------------------------------------------------
IRR_CAP = 100         # max amount of irrigation (mm/ha) that can be applied in a single season
ACTION_SET = 'depth'  # action sets, alternatives are: 'depth', 'binary', 'smt4'
DAYS_TO_IRR = 5       # number of days between irrigation decisions (e.g., 1, 3, 5, 7)

# --- environment configuration -----------------------------------------------
# Start from a copy of the default config dictionary and override the entries
# that define this experiment.
envconfig = tunis_maize_config.copy()
envconfig.update({
    'weather_df': weather_df,       # weather data prepared in the previous cell
    'include_rain': False,
    'action_set': ACTION_SET,
    'days_to_irr': DAYS_TO_IRR,
    'max_irr_season': IRR_CAP,
})

env = CropEnv(envconfig)

In [5]:
from ray.rllib.agents.ppo import ppo

# Deep-copy the defaults. A shallow `.copy()` would leave the nested 'model'
# dict shared with `ppo.DEFAULT_CONFIG`, so the `config['model'][...]`
# assignments below would silently rewrite the library-wide defaults for every
# trainer created later in this kernel.
config = copy.deepcopy(ppo.DEFAULT_CONFIG)

# --- resources ---
config['num_workers'] = 1
config['num_gpus'] = 0
# NOTE(review): presumably 0 so the rollout worker fits next to the driver in
# the single CPU passed to ray.init — confirm.
config['num_cpus_per_worker'] = 0

# --- sampling / optimisation ---
config['observation_filter'] = 'MeanStdFilter'  # normalize observations
config['rollout_fragment_length'] = 160
config['train_batch_size'] = 512
config['gamma'] = 1.  # discount factor of 1: future rewards are not discounted

# --- model ---
config['framework'] = 'torch'
config['model']['fcnet_hiddens'] = [64] * 3  # three hidden layers of 64 units
config['model']['vf_share_layers'] = False

# --- environment ---
config['env_config'] = envconfig

  from .autonotebook import tqdm as notebook_tqdm


## Create PPO agent

In [6]:
# Build the PPO trainer from the config above; RLlib receives the CropEnv
# class (not an instance) together with the config.
agent = ppo.PPOTrainer(config=config, env=CropEnv)

2025-09-17 07:30:07,632	INFO ppo.py:414 -- In multi-agent mode, policies will be optimized sequentially by the multi-GPU optimizer. Consider setting simple_optimizer=True if this doesn't work for you.
2025-09-17 07:30:07,633	INFO trainer.py:903 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


## Train and evaluate agent

In [7]:
if False:  # guard is the literal False, so this cell body never executes
    # Single-year training loop: train, and on every 5th iteration evaluate
    # the agent on year 1980, record the result and write a checkpoint.
    rewards, timesteps, caps = [], [], []

    for i in range(1, 500001):
        result = agent.train()
        ts = result['timesteps_total']

        if i % 5 != 0:
            continue

        # --- evaluation ---
        print('eval')
        for irr_cap in [IRR_CAP]:
            # evaluate on a private copy so the training config is untouched
            test_env_config = copy.deepcopy(envconfig)
            test_env_config['manual_year'] = 1980

            reward = evaluate_agent_single_year(agent, CropEnv, test_env_config)

            rewards.append(reward)
            timesteps.append(ts)
            caps.append(irr_cap)

            print("Reward after ", i, " epoches: ", reward)

        # --- checkpoint + tabulated history ---
        checkpoint_path = agent.save()
        print(checkpoint_path)

        # one row per evaluation: timesteps, reward, irrigation cap
        result_df = pd.DataFrame([timesteps, rewards, caps]).T

In [None]:
val_years = [1995, 1996, 1997, 1998, 1999]

# history of evaluations, kept for plotting / CSV export
rewards, timesteps, caps = [], [], []

max_val_yield = float('-inf')  # best mean validation result seen so far

for i in range(1, 500001):
    result = agent.train()
    ts = result['timesteps_total']

    if i % 5 != 0:  # evaluate only on every 5th training iteration
        continue

    print(f'\n=== Evaluation nach {i} Epochen ===')

    # Score the current policy on every validation year, each time on a
    # private copy of the training config pinned to that year.
    val_yields = []
    for year in val_years:
        test_env_config = copy.deepcopy(envconfig)
        test_env_config['manual_year'] = year
        val_yields.append(
            evaluate_agent_single_year(agent, CropEnv, test_env_config)
        )

    mean_val_yield = sum(val_yields) / len(val_yields)
    print(f"Mean validation reward: {mean_val_yield:.3f}")

    # Checkpoint only when the validation mean improves on the best so far.
    if mean_val_yield > max_val_yield:
        max_val_yield = mean_val_yield
        checkpoint_path = agent.save()
        print(f"New max validation yield reached! Agent saved at: {checkpoint_path}")

        # additionally report the result on the test year 1980
        test_env_config = copy.deepcopy(envconfig)
        test_env_config['manual_year'] = 1980
        test_yield = evaluate_agent_single_year(agent, CropEnv, test_env_config)
        print(f"Test year 1980 yield: {test_yield:.3f}")

    # record this evaluation
    rewards.append(mean_val_yield)
    timesteps.append(ts)
    caps.append(IRR_CAP)





=== Evaluation nach 5 Epochen ===
Starting new season in year  1995
the:  [0.05694529 0.06       0.06       0.06       0.06       0.06
 0.06       0.06       0.06       0.06       0.06       0.06      ]
dz:  0     0.1
1     0.1
2     0.1
3     0.1
4     0.1
5     0.1
6     0.3
7     0.3
8     0.3
9     0.3
10    0.3
11    0.3
Name: dz, dtype: float64
Zroot:  0.3
Root zone water content (mm):  17.694528670631122
the:  [0.0812371  0.10237369 0.06       0.06       0.06       0.06
 0.06       0.06       0.06       0.06       0.06       0.06      ]
dz:  0     0.1
1     0.1
2     0.1
3     0.1
4     0.1
5     0.1
6     0.3
7     0.3
8     0.3
9     0.3
10    0.3
11    0.3
Name: dz, dtype: float64
Zroot:  0.3
Root zone water content (mm):  24.361079494269106
the:  [0.04993538 0.10124675 0.06       0.06       0.06       0.06
 0.06       0.06       0.06       0.06       0.06       0.06      ]
dz:  0     0.1
1     0.1
2     0.1
3     0.1
4     0.1
5     0.1
6     0.3
7     0.3
8     0.3
9     0