In [None]:
# Reload edited modules (e.g. rl_equation_solver) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import rl_equation_solver
from rl_equation_solver.environment.algebraic import Env
from rl_equation_solver.agent.dqn import Agent as AgentDQN
from rl_equation_solver.agent.gcn import Agent as AgentGCN
from rl_equation_solver.agent.lstm import Agent as AgentLSTM
from rl_equation_solver.utilities import utilities
from rl_equation_solver.utilities.utilities import GraphEmbedding
import networkx as nx
import numpy as np
from torch_geometric.utils.convert import from_networkx
import matplotlib.pyplot as plt
from rex import init_logger
from stable_baselines3 import DQN, A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from gymnasium import spaces
from gymnasium.vector.utils.spaces import batch_space
from stable_baselines3.common.env_checker import check_env
from sympy import symbols

In [None]:
# Turn on DEBUG-level logging both for this notebook's logger and for the
# rl_equation_solver package.
init_logger(__name__, log_level="DEBUG")
init_logger("rl_equation_solver", log_level="DEBUG")

In [None]:
# Algebraic-equation environment shared by training, evaluation and plotting.
env = Env(
    order=2,
)

In [None]:
# Sanity-check that `env` follows the interface Stable-Baselines3 expects;
# warn=True additionally reports non-fatal compatibility warnings.
check_env(env=env, warn=True)

In [None]:
# Seed SB3's random number generators so training is reproducible across
# kernel restarts (bit-exact determinism on GPU is still not guaranteed).
model = A2C("MlpPolicy", env, verbose=1, seed=42)

In [None]:
# Re-attach the same environment the model was built with, then train for
# 100k environment steps. The returned model is the cell's displayed output.
model.set_env(env)
model.learn(total_timesteps=100_000)

In [None]:
# Roll out the trained policy for a fixed number of steps.
# Gymnasium API (imported above; required by SB3 2.x's check_env):
#   reset() -> (obs, info)
#   step()  -> (obs, reward, terminated, truncated, info)
env.reset_history()
obs, info = env.reset()
for _step in range(1000):
    action, _states = model.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    print(info)
    if terminated or truncated:
        # Begin a new episode instead of stepping one that has already ended.
        obs, info = env.reset()

In [None]:
def make_hist_plot(env, start=0, bins=None):
    """Plot a histogram of the per-episode mean reward stored in ``env.history``.

    Parameters
    ----------
    env : object
        Object with a ``history`` mapping whose values are dicts containing a
        ``'reward'`` sequence for that episode.
    start : int, optional
        Skip the first ``start`` episodes, by position in ``env.history``
        iteration order (list-slice semantics).
    bins : int, sequence or str, optional
        Forwarded to ``Axes.hist``; ``None`` uses matplotlib's default.
    """
    episodes = list(env.history)[start:]
    # Only rewards are plotted, so only rewards are aggregated.
    avg_reward = [np.mean(env.history[episode]['reward']) for episode in episodes]

    fig, ax = plt.subplots(1, 1, figsize=(12, 5))
    ax.hist(avg_reward, bins=bins)
    ax.set_xlabel('Reward')
    ax.set_ylabel('Count')
    ax.set_title('Mean reward per episode')

def moving_avg(x, w):
    """Return the simple moving average of ``x`` over windows of ``w`` samples.

    Only full windows are averaged (``'valid'`` convolution), so the result has
    ``len(x) - w + 1`` elements, and is empty when ``x`` has fewer than ``w``
    elements.

    Parameters
    ----------
    x : array_like
        1-D sequence of numbers.
    w : int
        Window length, at least 1.

    Returns
    -------
    numpy.ndarray
        Float array of window means.

    Raises
    ------
    ValueError
        If ``w`` is smaller than 1.
    """
    values = np.asarray(x, dtype=float)
    if w < 1:
        raise ValueError(f"window must be >= 1, got {w}")
    if values.size < w:
        # np.convolve swaps its arguments when the kernel is longer than the
        # data, which would silently return a meaningless result.
        return np.empty(0)
    return np.convolve(values, np.ones(w), 'valid') / w

def make_plot(env, start=0, window=1):
    """Plot per-episode mean complexity, loss and reward with linear trend lines.

    Parameters
    ----------
    env : object
        Object with a ``history`` mapping whose values are dicts containing
        ``'complexity'``, ``'loss'`` and ``'reward'`` sequences per episode.
    start : int, optional
        Skip the first ``start`` episodes, by position in ``env.history``
        iteration order (list-slice semantics).
    window : int, optional
        Moving-average window applied to each series before plotting. The
        default of 1 plots the raw per-episode means.

    Notes
    -----
    Non-finite points (e.g. the NaN ``np.nanmean`` yields for an episode whose
    losses are all NaN) are excluded from the trend fit. The trend line is
    omitted when fewer than two finite points remain, since ``np.polyfit``
    fails on such input.
    """
    all_episodes = list(env.history)
    episodes = all_episodes[start:]
    # Positional index of the first plotted episode, so the x axis keeps
    # episode positions instead of restarting at 0 when ``start`` is used.
    offset = len(all_episodes) - len(episodes)

    series = {
        'Complexity': [np.mean(env.history[ep]['complexity']) for ep in episodes],
        'Loss': [np.nanmean(env.history[ep]['loss']) for ep in episodes],
        'Reward': [np.mean(env.history[ep]['reward']) for ep in episodes],
    }

    fig, axes = plt.subplots(1, 3, figsize=(12, 5))
    for ax, (title, values) in zip(axes, series.items()):
        y = moving_avg(values, window)
        # Each smoothed value is placed at the last episode of its window.
        x = offset + window - 1 + np.arange(len(y))
        ax.scatter(x, y)
        finite = np.isfinite(y)
        if np.count_nonzero(finite) >= 2:
            a, b = np.polyfit(x[finite], y[finite], 1)
            ax.plot(x, a * x + b, color='red')
        ax.set_title(title)
    fig.text(0.5, 0.01, 'Episode', ha='center')



In [None]:
make_hist_plot(env, start=0);

In [None]:
make_plot(env, start=0);