In [1]:
from itertools import islice

import compiler_gym
import ray
from compiler_gym.wrappers import (
    ConstrainedCommandline,
    TimeLimit,
    CycleOverBenchmarks,
)
from matplotlib import pyplot as plt
from ray import tune
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from sklearn.model_selection import train_test_split

# import wandb
from train import config
# from ray.rllib.env.wrappers.multi_agent_env_compatibility import MultiAgentEnvCompatibility

from ray.tune.logger import pretty_print

In [2]:
def make_env() -> compiler_gym.envs.CompilerEnv:
    """Create a wrapped CompilerGym environment from the notebook ``config``.

    The base environment is built with ``compiler_gym.make`` using the
    ``compiler_gym_env``, ``observation_space`` and ``reward_space`` entries of
    ``config``. It is then wrapped, in order, by ``ConstrainedCommandline``
    (with ``config["actions"]`` as the flags) and ``TimeLimit`` (with
    ``config["episode_length"]`` as the maximum number of steps per episode).

    Returns:
        The outermost (``TimeLimit``) wrapper around the new environment.
    """
    base_env = compiler_gym.make(
        config["compiler_gym_env"],
        observation_space=config["observation_space"],
        reward_space=config["reward_space"],
    )
    flag_limited_env = ConstrainedCommandline(base_env, flags=config["actions"])
    return TimeLimit(flag_limited_env, max_episode_steps=config["episode_length"])

In [3]:
def prepare_datasets(
    env: compiler_gym.envs.CompilerEnv,
    max_train_benchmarks: int = 10000,
    val_fraction: float = 0.15,
) -> tuple:
    """Build the train / validation / test benchmark lists.

    Training candidates are the first ``max_train_benchmarks`` benchmarks
    yielded by the dataset named in ``config["train_benchmarks"]``. They are
    split into a training and a validation part with ``train_test_split``,
    seeded by ``config["random_state"]``. The test list holds every benchmark
    of the dataset named in ``config["test_benchmarks"]``.

    Args:
        env: Environment whose ``datasets`` mapping is queried.
        max_train_benchmarks: Upper bound on how many benchmarks are pulled
            from the training dataset before splitting (default 10000, the
            value that used to be hard-coded).
        val_fraction: Forwarded to ``train_test_split`` as ``test_size``
            (default 0.15, the value that used to be hard-coded).

    Returns:
        ``(train_benchmarks, val_benchmarks, test_benchmarks)``.
    """
    # islice bounds the enumeration so the dataset is never fully materialized.
    candidate_benchmarks = list(
        islice(
            env.datasets[config["train_benchmarks"]].benchmarks(),
            max_train_benchmarks,
        )
    )
    train_benchmarks, val_benchmarks = train_test_split(
        candidate_benchmarks,
        test_size=val_fraction,
        random_state=config["random_state"],
    )
    test_benchmarks = list(env.datasets[config["test_benchmarks"]].benchmarks())
    return train_benchmarks, val_benchmarks, test_benchmarks

In [4]:
def make_training_env(*args) -> compiler_gym.envs.CompilerEnv:
    """Environment factory used for training.

    Builds a fresh environment with ``make_env`` and wraps it in
    ``CycleOverBenchmarks`` over the notebook-level ``train_benchmarks`` list,
    which must therefore exist before this factory is called.

    Any positional arguments are accepted and discarded; presumably the caller
    that this factory is registered with passes an env-config argument that is
    not needed here (TODO confirm).
    """
    # The arguments are intentionally unused.
    del args
    fresh_env = make_env()
    return CycleOverBenchmarks(fresh_env, train_benchmarks)

In [5]:
def run_agent_on_benchmarks(benchmarks, trained_agent=None):
    """Run one episode per benchmark and collect the episode rewards.

    Args:
        benchmarks: Iterable of benchmarks to evaluate. It is materialized
            into a list, so generators are accepted as well.
        trained_agent: Object exposing ``compute_single_action(observation)``.
            When omitted, the notebook-level ``agent`` is used, which keeps
            existing calls working.

    Returns:
        A list with ``env.episode_reward`` for each benchmark, in input order.
        An empty input yields an empty list without starting an environment.

    Raises:
        NameError: If ``trained_agent`` is omitted and no notebook-level
            ``agent`` has been defined.
    """
    # len() is needed for the progress line, so pin the iterable down first.
    benchmarks = list(benchmarks)
    if not benchmarks:
        return []

    # Resolve the agent before an environment is started so a missing agent
    # fails fast instead of after the environment has been created.
    active_agent = agent if trained_agent is None else trained_agent

    rewards = []
    with make_env() as env:
        for i, benchmark in enumerate(benchmarks, start=1):
            observation, done = env.reset(benchmark=benchmark), False
            while not done:
                action = active_agent.compute_single_action(observation)
                observation, _, done, _ = env.step(action)
            rewards.append(env.episode_reward)
            print(f"[{i}/{len(benchmarks)}] {env.state}")

    return rewards

In [6]:
def plot_results(x, y, name, ax):
    """Draw a bar chart of ``y`` on ``ax`` with ``x`` as the tick labels.

    Args:
        x: Labels for the x-axis ticks; one tick is placed per label.
        y: Bar heights; one bar is drawn per value.
        name: Name inserted into the plot title.
        ax: Matplotlib axes to draw on. It is also made the current pyplot
            axes, as a side effect.
    """
    # Keep ``ax`` as the current pyplot axes for any pyplot calls that follow.
    plt.sca(ax)
    bar_positions = range(len(y))
    tick_positions = range(len(x))
    ax.bar(bar_positions, y)
    ax.set_ylabel("Reward (higher is better)")
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(x, rotation=90)
    ax.set_title(f"Performance on {name} set")

In [7]:
# Use a temporary environment only to enumerate the datasets. The three lists
# are notebook-level names; make_training_env reads `train_benchmarks`.
with make_env() as env:
    train_benchmarks, val_benchmarks, test_benchmarks = prepare_datasets(env)

In [8]:
# Shut down any Ray instance already running in this kernel before starting a
# new one — presumably so that re-running the cell applies the options below.
if ray.is_initialized():
    ray.shutdown()
ray.init(
    include_dashboard=True,
    ignore_reinit_error=True,
    # NOTE(review): hard-coded GPU count; presumably assumes one local GPU — confirm.
    num_gpus=1,
)
# Register the training-environment factory under the name "compiler_gym",
# which the algorithm configuration refers to.
tune.register_env("compiler_gym", make_training_env)

2024-01-28 14:43:38,159	INFO worker.py:1529 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8266 [39m[22m


In [16]:
# NOTE(review): this `env` is not referenced by the builder below, which uses
# the environment registered as "compiler_gym", and it is never closed in this
# cell — confirm whether it is still needed.
env = make_env()

# Build a PPO algorithm on the registered "compiler_gym" environment using the
# torch framework.
algo = (
    PPOConfig()
    # No separate rollout workers; the environment is created on the local worker.
    .rollouts(num_rollout_workers=0, create_env_on_local_worker=True)
    .resources(num_gpus=1)
    .environment(env="compiler_gym")
    .framework("torch")
    .build()
)



In [17]:
# Train for a fixed number of iterations, printing each result and saving a
# checkpoint every CHECKPOINT_EVERY completed iterations.
#
# NOTE(review): the output recorded for this cell ends in
# "ValueError: Observation (...) outside given space (Box(...))": one observation
# component is 1.3333334 while the printed Box bounds are [0, 1]. The declared
# observation space and the values the environment produces disagree; that has
# to be resolved in the environment setup before this loop can run to
# completion — TODO.
NUM_TRAINING_ITERATIONS = 10
CHECKPOINT_EVERY = 5

for i in range(NUM_TRAINING_ITERATIONS):
    result = algo.train()
    print(pretty_print(result))

    # `i` is zero-based: test the number of completed iterations (i + 1) so
    # checkpoints land after iterations 5 and 10. Testing `i` itself saved after
    # iterations 1 and 6 and never captured the final trained state.
    if (i + 1) % CHECKPOINT_EVERY == 0:
        checkpoint_dir = algo.save()
        print(f"Checkpoint saved in directory {checkpoint_dir}")

agent_timesteps_total: 4000
counters:
  num_agent_steps_sampled: 4000
  num_agent_steps_trained: 4000
  num_env_steps_sampled: 4000
  num_env_steps_trained: 4000
custom_metrics: {}
date: 2024-01-28_14-51-02
done: false
episode_len_mean: 100.0
episode_media: {}
episode_reward_max: 1.0
episode_reward_mean: 0.6780326716641432
episode_reward_min: -0.1978021978021981
episodes_this_iter: 40
episodes_total: 40
experiment_id: f9b58545707349a68330965ecef1ea79
hostname: debian
info:
  learner:
    default_policy:
      custom_metrics: {}
      diff_num_grad_updates_vs_sampler_policy: 464.5
      learner_stats:
        cur_kl_coeff: 0.20000000000000004
        cur_lr: 5.0000000000000016e-05
        entropy: 2.700178833674359
        entropy_coeff: 0.0
        kl: 0.007871751682426363
        policy_loss: -0.009317691399846026
        total_loss: 0.15157360241370355
        vf_explained_var: 0.3591258937953621
        vf_loss: 0.15931694290089993
      model: {}
      num_grad_updates_lifetime: 46

ValueError: Observation ([0.33333334 1.3333334  0.33333334 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.33333334 0.         0.         0.
 0.         0.         0.         0.33333334 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.        ] dtype=float32) outside given space (Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], (69,), float32))!

> [0;32m/home/flint/diplom/experiments/venv/lib/python3.9/site-packages/ray/rllib/models/preprocessors.py[0m(74)[0;36mcheck_shape[0;34m()[0m
[0;32m     72 [0;31m            [0;32mtry[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m                [0;32mif[0m [0;32mnot[0m [0mself[0m[0;34m.[0m[0m_obs_space[0m[0;34m.[0m[0mcontains[0m[0;34m([0m[0mobservation[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 74 [0;31m                    raise ValueError(
[0m[0;32m     75 [0;31m                        "Observation ({} dtype={}) outside given space ({})!".format(
[0m[0;32m     76 [0;31m                            [0mobservation[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m
array([0.33333334, 1.3333334 , 0.33333334, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.     

In [13]:
%pdb

Automatic pdb calling has been turned ON
