In [None]:
# Load IPython's autoreload extension and use mode 2, which re-imports all
# modules before each cell executes so that edits to source files on disk are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import rl_equation_solver
from rl_equation_solver.environment.algebraic import Env
from rl_equation_solver.agent.dqn import Agent as AgentDQN
from rl_equation_solver.agent.gcn import Agent as AgentGCN
from rl_equation_solver.agent.lstm import Agent as AgentLSTM
from rl_equation_solver.utilities import utilities
from rl_equation_solver.utilities.utilities import GraphEmbedding
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from rex import init_logger
from sympy import symbols, sqrt, simplify, expand, nsimplify, parse_expr, sympify
import sympy

In [None]:
# Set up DEBUG-level logging for this notebook's own logger (__name__) and for
# the 'rl_equation_solver' logger. DEBUG is verbose; raise the level here if
# the cell output becomes too noisy.
init_logger(__name__, log_level='DEBUG')
init_logger('rl_equation_solver', log_level="DEBUG")

In [None]:
# plot reward distribution
def plot_reward_dist(env):
    """Show a histogram of the rewards stored in ``env.avg_history``.

    Parameters
    ----------
    env : object
        Any object exposing an ``avg_history`` mapping with a ``'reward'``
        entry; the values of that entry are histogrammed.
    """
    _, reward_axis = plt.subplots(nrows=1, ncols=1)
    rewards = env.avg_history['reward']
    reward_axis.hist(rewards)
    reward_axis.set_xlabel('Reward')
    reward_axis.set_ylabel('Count')
    plt.show()

# plot complexity, loss, reward
def plot_trends(env):
    """Plot complexity, loss and reward histories with linear trend lines.

    Draws three side-by-side panels. Each panel scatters one entry of
    ``env.avg_history`` against its index and overlays a red least-squares
    straight line fitted to the non-NaN points.

    Parameters
    ----------
    env : object
        Any object exposing an ``avg_history`` mapping with the entries
        ``'ep'``, ``'complexity'``, ``'loss'`` and ``'reward'``. The length
        of ``'ep'`` defines the x axis; the other three entries must have
        the same length.
    """
    history = env.avg_history
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    x = np.arange(len(history['ep']))

    panels = (('complexity', 'Complexity'),
              ('loss', 'Loss'),
              ('reward', 'Reward'))
    for ax, (key, title) in zip(axes, panels):
        _scatter_with_trend(ax, x, history[key])
        ax.set_title(title)
    plt.show()


def _scatter_with_trend(ax, x, y):
    """Scatter the non-NaN points of ``y`` against ``x`` and overlay a linear fit."""
    y = np.asarray(y, dtype=float)
    # Drop every NaN, not just leading/trailing ones: a single NaN left in
    # the data makes np.polyfit return NaN coefficients.
    valid = ~np.isnan(y)
    x_valid, y_valid = x[valid], y[valid]

    ax.scatter(x_valid, y_valid)
    # A straight-line fit needs at least two points; with fewer (e.g. no loss
    # recorded yet) only the scatter is drawn instead of raising.
    if x_valid.size < 2:
        return
    slope, intercept = np.polyfit(x_valid, y_valid, 1)
    # Pass x explicitly: with a single positional argument Axes.plot places
    # the values at 0..N-1, which misaligns the line with the scatter
    # whenever the first valid point is not at index 0.
    ax.plot(x_valid, slope * x_valid + intercept, color='r')
    

## Initialize Env and Agent ##
### The agent can be `AgentGCN`, `AgentLSTM`, or `AgentDQN` ###

In [None]:
# NOTE(review): symbols('0') creates a sympy Symbol whose *name* is "0", not
# the integer zero — confirm this is the initial state Env expects.
env = Env(order=2, init_state=symbols('0'))
# NOTE(review): the device is hard-coded to the first CUDA GPU; presumably this
# has to be changed on a machine without one — confirm what the Agent accepts.
agent = AgentGCN(env, device='cuda:0')

## Train Agent ##

In [None]:
# Call agent.train(5) ten times in a row.
# NOTE(review): presumably the argument is an episode count, which would give
# 50 training episodes in total — confirm against Agent.train.
for _ in range(10):
    agent.train(5)

## Plot Reward Distribution ##

In [None]:
# Histogram of env.avg_history['reward'] as recorded so far.
plot_reward_dist(env)

## Plot complexity, loss, and reward trends ##

In [None]:
# Scatter plots with a linear fit for the complexity, loss and reward entries
# of env.avg_history.
plot_trends(env)


## Run trained agent in eval mode ##

In [None]:
# Run the already-trained agent again, this time with eval=True.
# NOTE(review): presumably eval=True turns off learning/exploration — confirm
# against Agent.train.
agent.train(10, eval=True)

## Plot reward distribution for trained agent ##

In [None]:
# NOTE(review): env is the same object that was used for training above; if
# avg_history is not reset between runs, this histogram mixes training and
# eval rewards — confirm.
plot_reward_dist(env)

## Render final state graph ##

In [None]:
# Draw the environment's current state graph, labelling nodes with
# env.node_labels.
nx.draw(env.state_graph, labels=env.node_labels)

## Run a new, untrained agent in eval mode ##

In [None]:
# Rebinds env and agent to fresh instances; the trained pair created earlier
# is no longer reachable through these names after this cell runs.
# NOTE(review): this Env is built without the init_state passed to the first
# Env, so the two runs may not start from the same state — confirm that the
# comparison is intended to be like-for-like.
env = Env(order=2)
agent = AgentGCN(env, device='cuda:0')
agent.train(10, eval=True)

## Plot reward distribution for untrained agent ##

In [None]:
# Histogram of env.avg_history['reward'] for whichever Env is currently bound
# to `env`.
plot_reward_dist(env)