In [None]:
# Enable IPython's autoreload extension; mode 2 reloads modules before each
# cell runs, so edits to rl_equation_solver are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import rl_equation_solver
from rl_equation_solver.environment.algebraic import Env
from rl_equation_solver.agent.dqn import Agent as AgentDQN
from rl_equation_solver.agent.gcn import Agent as AgentGCN
from rl_equation_solver.agent.lstm import Agent as AgentLSTM
from rl_equation_solver.utilities import utilities
from rl_equation_solver.utilities.utilities import GraphEmbedding
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from rex import init_logger
from sympy import symbols, sqrt, simplify, expand

In [None]:
# Configure logging for this notebook and for the rl_equation_solver package at
# the same level. Running the calls in a loop also stops any value returned by
# init_logger from being echoed as the cell's Out[] result.
LOG_LEVEL = 'DEBUG'
for logger_name in (__name__, 'rl_equation_solver'):
    init_logger(logger_name, log_level=LOG_LEVEL)

In [None]:
def moving_avg(x, w):
    """Return the simple moving average of ``x`` over a window of ``w`` samples.

    Uses ``np.convolve(..., 'valid')``, so the result has ``len(x) - w + 1``
    entries, each the mean of ``w`` consecutive values of ``x``. With
    ``w == 1`` the values of ``x`` come back unchanged, as a float array.

    Parameters
    ----------
    x : array_like
        1-D sequence of values.
    w : int
        Window length; must satisfy ``1 <= w <= len(x)``.

    Raises
    ------
    ValueError
        If ``w`` is outside ``[1, len(x)]``. Without this check,
        ``np.convolve`` silently swaps its arguments when ``w > len(x)`` and
        returns ``w - len(x) + 1`` meaningless values.
    """
    n = np.size(x)
    if not 1 <= w <= n:
        raise ValueError(
            f'moving_avg window w={w} must satisfy 1 <= w <= len(x)={n}')
    return np.convolve(x, np.ones(w), 'valid') / w

def make_plot(agent, round=0, start=0, window=1):
    """Plot per-episode training curves of ``agent`` and save them to disk.

    Draws three side-by-side panels: mean complexity, mean loss and mean
    reward per episode. Each panel is a scatter of the per-episode means
    (smoothed with ``moving_avg``) plus a least-squares linear trend line.
    The figure is saved to ``./figs/round_{round}.png``; ``./figs`` is
    created first if it does not exist.

    Parameters
    ----------
    agent : object
        Agent whose ``history`` maps each episode key to a dict holding
        ``'complexity'``, ``'loss'`` and ``'reward'`` sequences.
    round : int or str, optional
        Tag used only in the output file name. (It shadows the builtin
        ``round`` inside this function; the name is kept so existing keyword
        callers keep working.)
    start : int, optional
        Number of leading episodes of ``agent.history`` to skip.
    window : int, optional
        ``moving_avg`` window applied to every curve. The default of 1 plots
        the raw per-episode means, matching the previous behavior.

    Raises
    ------
    ValueError
        If no episodes remain after skipping the first ``start``.
    """
    import os

    history = agent.history
    episodes = list(history.keys())[start:]
    if not episodes:
        raise ValueError(
            f'agent.history has no episodes left after skipping start={start}')

    # (per-episode record key, panel title, per-episode reducer)
    panels = [
        ('complexity', 'Complexity', np.mean),
        # nanmean so individual NaN loss entries don't turn the episode mean into NaN.
        ('loss', 'Loss', np.nanmean),
        ('reward', 'Reward', np.mean),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 5))
    for ax, (key, title, reducer) in zip(axes, panels):
        y = moving_avg([reducer(history[ep][key]) for ep in episodes], window)
        x = np.arange(len(y))
        ax.scatter(x, y)
        # An episode whose losses are all NaN (or empty) still averages to NaN,
        # and np.polyfit fails on NaN input, so fit the trend on finite points only.
        finite = np.isfinite(y)
        if np.count_nonzero(finite) >= 2:
            slope, intercept = np.polyfit(x[finite], y[finite], 1)
            ax.plot(x, slope * x + intercept, color='red')
        ax.set_title(title)
        ax.set_xlabel('Episode')

    # savefig raises FileNotFoundError if the target directory is missing.
    os.makedirs('./figs', exist_ok=True)
    fig.savefig(f'./figs/round_{round}.png', dpi=300)


def make_hist_plot(agent, round=0, start=0, bins=None):
    """Show a histogram of the per-episode mean reward of ``agent``.

    Parameters
    ----------
    agent : object
        Agent whose ``history`` maps each episode key to a dict holding a
        ``'reward'`` sequence.
    round : int, optional
        Unused; accepted so the signature mirrors ``make_plot``.
    start : int, optional
        Number of leading episodes of ``agent.history`` to skip.
    bins : int or sequence, optional
        Passed to ``Axes.hist``. ``None`` keeps matplotlib's default binning,
        which is what the plot used before this parameter existed.
    """
    history = agent.history
    # Only rewards are plotted, so complexity/loss means are not computed
    # (the loss nanmean could also emit warnings for all-NaN episodes).
    avg_reward = [np.mean(history[episode]['reward'])
                  for episode in list(history.keys())[start:]]

    # Draw on the Axes we created instead of through the pyplot state machine.
    fig, ax = plt.subplots(1, 1, figsize=(12, 5))
    ax.hist(avg_reward, bins=bins)
    ax.set_title('Mean reward per episode')
    ax.set_xlabel('Reward')
    ax.set_ylabel('Count')


In [None]:

# Order-2 environment plus a GCN agent placed on device 'cuda:0'.
# (x, a0 and a1 are defined here but not used by this cell.)
x, a0, a1 = symbols('x a0 a1')
# NOTE(review): symbols('1') creates a Symbol *named* "1", not the integer 1
# (that would be sympy.Integer(1)) -- confirm this is the intended init state.
order2_init_state = symbols('1')
env = Env(order=2, init_state=order2_init_state)
agent_gcn = AgentGCN(env, device='cuda:0')

In [None]:
# Train the order-2 GCN agent; each round runs N_TRAIN_EPISODES episodes.
N_TRAIN_ROUNDS = 1
N_TRAIN_EPISODES = 100
for i in range(N_TRAIN_ROUNDS):
    agent_gcn.train(N_TRAIN_EPISODES)
    # Optional: save this round's curves. The agent trained here is agent_gcn,
    # and the tag keeps the file name distinct from other make_plot calls.
    # make_plot(agent_gcn, round=f'gcn_order2_train_{i}', start=2)

In [None]:

# Run the order-2 GCN agent with eval=True (presumably evaluation mode --
# confirm in Agent.train); each round runs N_EVAL_EPISODES episodes.
N_EVAL_ROUNDS = 1
N_EVAL_EPISODES = 10
for i in range(N_EVAL_ROUNDS):
    agent_gcn.train(N_EVAL_EPISODES, eval=True)
    # Optional: save this round's curves under a tag distinct from training's.
    # make_plot(agent_gcn, round=f'gcn_order2_eval_{i}', start=2)

In [None]:
# Histogram of per-episode mean reward for the order-2 GCN agent trained and
# evaluated above (agent_dqn is never created in this notebook).
make_hist_plot(agent_gcn)

In [None]:
# Draw the expression graph of the order-2 environment `env` that agent_gcn
# was built with (dqn_env is never created in this notebook).
nx.draw(env.state_graph, labels=env.node_labels)

In [None]:
from sympy import symbols, sqrt
# Order-3 environment whose initial state is the "+" root of
# a0*x**2 + a1*x + a2 = 0, i.e. (-a1 + sqrt(a1**2 - 4*a0*a2)) / (2*a0).
# Note: this rebinds agent_gcn, replacing the order-2 agent created earlier.
x, a0, a1, a2 = symbols('x a0 a1 a2')
quadratic_root = (-a1 + sqrt(a1**2 - 4*a0*a2))/2/a0
gcn_env = Env(order=3, init_state=quadratic_root)
agent_gcn = AgentGCN(gcn_env, device='cuda:0')

In [None]:
# Inspect the environment's state string before the order-3 agent is trained
# (training happens in the next cell).
agent_gcn.env.state_string

In [None]:
# Train the order-3 GCN agent and save its curves. The 'gcn_' tag gives this
# figure its own file (./figs/round_gcn_<i>.png) instead of sharing
# ./figs/round_<i>.png with other make_plot calls in this notebook.
# start=3 skips the first three episodes in the plot.
for i in range(1):
    agent_gcn.train(num_episodes=10)
    make_plot(agent_gcn, round=f'gcn_{i}', start=3)

In [None]:
# Draw the expression graph of the order-3 GCN environment, labelled with its
# node labels.
gcn_graph, gcn_labels = gcn_env.state_graph, gcn_env.node_labels
nx.draw(gcn_graph, labels=gcn_labels)

In [None]:
# Order-3 environment (no init_state passed, so Env's default is used) and an
# LSTM agent placed on device 'cuda:0'.
lstm_env = Env(order=3)
agent_lstm = AgentLSTM(lstm_env, device='cuda:0')

In [None]:
# Train the LSTM agent and save its curves. The 'lstm_' tag gives this figure
# its own file (./figs/round_lstm_<i>.png) instead of sharing
# ./figs/round_<i>.png with other make_plot calls in this notebook.
# start=3 skips the first three episodes in the plot.
for i in range(1):
    agent_lstm.train(num_episodes=1000)
    make_plot(agent_lstm, round=f'lstm_{i}', start=3)

In [None]:
# Draw the expression graph of the LSTM agent's environment, labelled with its
# node labels.
lstm_graph, lstm_labels = lstm_env.state_graph, lstm_env.node_labels
nx.draw(lstm_graph, labels=lstm_labels)