```
Copyright 2024 The HIVEX Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# Example Training using ML-Agents: Reproducing Paper Results

Note:
1. Install the dependencies as described in the README.md.
2. Download or clone the hivex-environments repository.
3. Select the correct kernel for this Jupyter notebook using the kernel picker at the top right.

## Import Libraries

In [1]:
from pathlib import Path
import subprocess
from tqdm import tqdm
from pathlib import Path
from hivex.training.baseline.ml_agents.utils import (
    load_hivex_config,
    clean_temp_configs,
    construct_command,
    create_training_config_files,
    create_test_config_files,
)

## Initialize Functions

In [2]:
def train(
    train_config_path: Path,
    experiment_name: str,
    train_run_count: int = 1,
    port: str = "5005",
):
    """Launch one training subprocess per generated training config file.

    Args:
        train_config_path: Training configuration forwarded to
            ``create_training_config_files``.
        experiment_name: Experiment name forwarded to both the config
            generator and ``construct_command``.
        train_run_count: Run count forwarded to
            ``create_training_config_files``.
        port: Value appended to every command as ``--base-port``. A string
            is expected; an integer is formatted the same way.

    Note:
        Each command is executed through the shell and awaited before the
        next one starts. The exit status is not checked, so a failing run
        does not stop the remaining runs.
    """
    batch_config_files = create_training_config_files(
        experiment_name=experiment_name,
        train_config_path=train_config_path,
        train_run_count=train_run_count,
    )

    for batch_config_file in tqdm(batch_config_files, desc="train"):
        cmd = construct_command(
            config_path=batch_config_file, force=False, experiment_name=experiment_name
        )
        # The leading space keeps the flag a separate shell token whether or
        # not the constructed command ends in whitespace; extra whitespace is
        # harmless to the shell. The f-string also tolerates an integer port.
        cmd += f" --base-port {port}"
        subprocess.run(cmd, shell=True)


def test(
    test_config_path: Path,
    experiment_name: str,
    test_run_count: int = 1,
    port: str = "5005",
):
    """Launch one test subprocess per generated test config file.

    Args:
        test_config_path: Test configuration forwarded to
            ``create_test_config_files``.
        experiment_name: Experiment name forwarded to both the config
            generator and ``construct_command``.
        test_run_count: Run count forwarded to ``create_test_config_files``.
        port: Value appended to every command as ``--base-port``. A string
            is expected; an integer is formatted the same way.

    Note:
        Commands are built with ``train=False``. Each one is executed
        through the shell and awaited before the next one starts. The exit
        status is not checked, so a failing run does not stop the remaining
        runs.
    """
    batch_config_files = create_test_config_files(
        test_config_path=test_config_path,
        experiment_name=experiment_name,
        test_run_count=test_run_count,
    )

    for batch_config_file in tqdm(batch_config_files, desc="test"):
        cmd = construct_command(
            config_path=batch_config_file,
            experiment_name=experiment_name,
            force=False,
            train=False,
        )
        # The leading space keeps the flag a separate shell token whether or
        # not the constructed command ends in whitespace; extra whitespace is
        # harmless to the shell. The f-string also tolerates an integer port.
        cmd += f" --base-port {port}"
        subprocess.run(cmd, shell=True)


def run_pipeline(config: Path, port: str):
    """Run training followed by testing for one HIVEX experiment config.

    Args:
        config: Path handed to ``load_hivex_config``. The loaded mapping
            must provide the keys ``train_config_path``,
            ``test_config_path``, ``experiment_name``, ``train_run_count``
            and ``test_run_count``.
        port: Base port forwarded unchanged to both ``train`` and ``test``.
    """
    # Loading stays outside the try block: if it fails, nothing has been
    # generated yet and there is nothing to clean up.
    hivex_config = load_hivex_config(config)
    try:
        train(
            train_config_path=hivex_config["train_config_path"],
            experiment_name=hivex_config["experiment_name"],
            train_run_count=hivex_config["train_run_count"],
            port=port,
        )
        test(
            test_config_path=hivex_config["test_config_path"],
            experiment_name=hivex_config["experiment_name"],
            test_run_count=hivex_config["test_run_count"],
            port=port,
        )
    finally:
        # Clean up even when a stage raises or the run is interrupted.
        # NOTE(review): presumably this removes the per-run config files
        # generated by train/test; confirm against the helper.
        clean_temp_configs()

## Run Training

In [None]:
# Experiment configuration files, one per HIVEX environment.
wind_farm_control_config = Path(
    "../src/hivex/training/baseline/ml_agents/configs/experiments/WindFarmControl_hivex.yaml"
)  # Wind Farm Control
wildfire_resource_management_config = Path(
    "../src/hivex/training/baseline/ml_agents/configs/experiments/WildfireResourceManagement_hivex.yaml"
)  # Wildfire Resource Management
drone_based_reforestation_config = Path(
    "../src/hivex/training/baseline/ml_agents/configs/experiments/DroneBasedReforestation_hivex.yaml"
)  # Drone-Based Reforestation
ocean_plastic_collection_config = Path(
    "../src/hivex/training/baseline/ml_agents/configs/experiments/OceanPlasticCollection_hivex.yaml"
)  # Ocean Plastic Collection
aerial_wildfire_suppression_config = Path(
    "../src/hivex/training/baseline/ml_agents/configs/experiments/AerialWildfireSuppression_hivex.yaml"
)  # Aerial Wildfire Suppression

# Run the pipeline for Wind Farm Control; pass one of the other configs
# defined above to run a different environment.
run_pipeline(config=wind_farm_control_config, port="5005")