In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so that edits to imported source files (e.g. the
# rl_equation_solver package) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import rl_equation_solver
from rl_equation_solver.environment.algebraic import Env
from rl_equation_solver.agent.dqn import Agent as AgentDQN
from rl_equation_solver.agent.gcn import Agent as AgentGCN
from rl_equation_solver.agent.lstm import Agent as AgentLSTM
from rl_equation_solver.utilities import utilities
from rl_equation_solver.utilities.utilities import GraphEmbedding
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from rex import init_logger
from sympy import (
    symbols,
    sqrt,
    simplify,
    expand,
    nsimplify,
    parse_expr,
    sympify,
)
import sympy
import cProfile, pstats, io
from pstats import SortKey

In [None]:
# Initialise INFO-level logging for two logger names: this notebook's own
# namespace (__name__) and the "rl_equation_solver" package.
# NOTE(review): presumably this surfaces training progress messages in the
# cell output - confirm against rex.init_logger's handler defaults.
init_logger(__name__, log_level="INFO")
init_logger("rl_equation_solver", log_level="INFO")

## Run the agent until it finds a solution once, for each set of hyperparameters ##

In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import itertools

# Hyperparameter grid: every combination of the values below becomes one config.
taus = [0.7, 0.85, 1.0]
gammas = [0.7, 0.8, 0.9]
eps_start = [0.9, 0.7, 0.5]
eps_decay_steps = [100, 500, 1000]
eps_end = [0.2, 0.1, 0.05, 0.01]
batch_size = [16, 32, 64]

# Number of independent runs per config, and the thread-pool size used by the
# sweep.
run_number = 20
max_workers = 64

# Cartesian product of the grid; the position of each value in a tuple matches
# the order of the hyperparameter names used to build `configs` below.
combos = list(
    itertools.product(
        taus, gammas, eps_start, eps_decay_steps, eps_end, batch_size
    )
)

# configs[i] maps hyperparameter name -> value for the i-th combination.
configs = {
    index: dict(
        zip(
            (
                "tau",
                "gamma",
                "eps_start",
                "eps_decay_steps",
                "eps_end",
                "batch_size",
            ),
            values,
        )
    )
    for index, values in enumerate(combos)
}

# soln_steps[i][j] receives the step count of run j for config i; every config
# gets its own zero-filled list so runs never share storage.
soln_steps = {index: [0] * run_number for index in configs}


def find_soln(i, j):
    """Train a fresh GCN agent with config ``i`` and record the result as run ``j``.

    A new ``Env`` (order 2, ``diff_loss_reward`` reward function) and a new
    ``AgentGCN`` on device ``cuda:0`` are created for every call, the agent's
    ``train(1)`` is invoked, and the environment's ``loop_step`` afterwards is
    written to ``soln_steps[i][j]``.

    NOTE(review): ``train(1)`` presumably runs until one solution is found and
    ``loop_step`` is presumably the number of steps that took - confirm against
    the Agent / Env implementations.

    Parameters
    ----------
    i : int
        Key into the module-level ``configs`` and ``soln_steps`` dicts.
    j : int
        Run index, i.e. the position inside ``soln_steps[i]`` to fill in.

    Returns
    -------
    None
        The result is stored in ``soln_steps`` as a side effect.
    """
    hyperparams = configs[i]
    solver_env = Env(order=2, config={"reward_function": "diff_loss_reward"})
    AgentGCN(solver_env, device="cuda:0", config=hyperparams).train(1)
    soln_steps[i][j] = solver_env.loop_step


def run_pool():
    """Run ``find_soln`` for every (config, run) pair on a thread pool.

    One task is submitted per combination of config index
    ``0..len(configs)-1`` and run index ``0..run_number-1``, using a
    ``ThreadPoolExecutor`` with ``max_workers`` threads. As each task finishes
    its ``(i, j)`` pair is printed. An exception raised inside a task is
    re-raised here when that task's result is collected.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to the (config, run) pair it was started for.
        futures = {
            pool.submit(find_soln, i, j): (i, j)
            for i in range(len(configs))
            for j in range(run_number)
        }

        for finished in as_completed(futures):
            # result() is called only to surface worker exceptions; find_soln
            # stores its output in soln_steps rather than returning it.
            finished.result()
            i, j = futures[finished]
            print(f"future {(i, j)} completed")

In [None]:
# Launch the full sweep: run_pool submits one find_soln task per
# (config, run) pair, i.e. len(configs) * run_number training runs in total.
# NOTE(review): expect this cell to be long-running.
run_pool()

In [None]:
# One line per config: steps-to-solution (y) against run index (x).
fig, ax = plt.subplots(figsize=(10, 8))
for i, config in configs.items():
    ax.plot(soln_steps[i], label=str(config))
ax.set(xlabel="Run number", ylabel="Steps to solution")
# NOTE(review): the legend gets one entry per config, so it grows with the
# size of the hyperparameter grid.
ax.legend()

In [None]:
# Mean steps-to-solution across all runs, keyed by config index.
avgs = {index: np.mean(steps) for index, steps in soln_steps.items()}

In [None]:
# Show the config with the lowest mean steps-to-solution.
# The winner is selected by its key in `avgs` instead of by its position in
# list(avgs.values()); a positional index only equals the config key while the
# keys happen to be exactly 0..n-1 in insertion order.
configs[min(avgs, key=avgs.get)]