## Load Interaction and Configs
The interaction class allows the environment object and the agent object to interface with one another
The interaction class is parameterized by a configuration file. This is where you define parameters for:

AGENT: # hidden layers, # nodes per layer, size of memory, etc. 

INTERACTION: how many training episodes, type of epsilon decay, number of testing episodes, etc. 

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# IMPORT DESIRED INTERACTION CLASS AND CONFIGURATION
import sys
import os

# Get the parent directory (where `configs/` and `interactions/` are located)
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Add it to sys.path
sys.path.append(parent_dir)

# Now you should be able to import
from interactions import dqn_interaction as dqn
from configs.dqn_configs import interaction_example_lunarlander, agent_example_lunarlander, env_lunar_lander_config
from configs.llm_dqn_configs import dqn_llm_agent_configs

In [4]:
# Initialize an interaction using this configuration
dqn_interaction = dqn.DQNInteraction(interaction_configs = interaction_example_lunarlander,
                                     agent_configs = agent_example_lunarlander,
                                     env_configs = env_lunar_lander_config,
                                     llm_configs = dqn_llm_agent_configs)

# Train an agent
train_scores, trained_agent = dqn_interaction.train()

OpenAIError: The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable

In [None]:
# Test the agent
test_scores = dqn_interaction.test(trained_agent)

In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))  # Adjust figure size if needed
fig.suptitle(f"Lunar Lander DQN Training - Discount: {dqn_interaction.config.gamma}")

# Training plot
ax1.plot(train_scores, color='r')
ax1.set_title('Training')
ax1.set_xlabel('Episodes')
ax1.set_ylabel('Score')
ax1.grid()

# Testing plot
ax2.plot(test_scores, color='b')
ax2.set_title('Testing')
ax2.set_xlabel('Episodes')
ax2.set_ylabel('Score')
ax2.grid()

plt.tight_layout()  # Adjust layout to prevent overlap
plt.show()


