In [None]:
import gymnasium as gym

from collections import defaultdict
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import sys
import os

# Add project root to path
sys.path.append(os.path.abspath(".."))

from diablo_env.env.diablo_env import DiabloEnv
from diablo_env.env.utils import load_robot



In [None]:

class diabloAgent:
    """Tabular Q-learning agent with epsilon-greedy exploration.

    Continuous observations are mapped to discrete states by binning each
    feature with ``np.digitize``. Q-values are kept per discrete state in a
    ``defaultdict``, so a state seen for the first time starts with all-zero
    action values.
    """

    # Default bin edges, one array per observation feature, in order. They are
    # built once here rather than on every ``discretize`` call, which runs on
    # every step of training. The layout follows the original variable names
    # (cart position/velocity, pole angle/velocity).
    # NOTE(review): confirm these ranges match DiabloEnv's observation vector.
    DEFAULT_BINS = (
        np.linspace(-2.4, 2.4, 24),  # cart position
        np.linspace(-4.0, 4.0, 20),  # cart velocity
        np.linspace(-0.2, 0.2, 40),  # pole angle
        np.linspace(-4.0, 4.0, 20),  # pole angular velocity
    )

    def __init__(
        self,
        env: "DiabloEnv",
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.99,
        bins=None,
    ):
        """Create the agent.

        Args:
            env: Environment instance; its ``action_space`` must expose ``n``
                (number of actions) and ``sample()``.
            learning_rate: Step size applied to each temporal-difference error.
            initial_epsilon: Starting exploration probability.
            epsilon_decay: Amount SUBTRACTED from epsilon on each
                ``decay_epsilon`` call (a linear schedule, not a factor).
            final_epsilon: Lower bound for epsilon.
            discount_factor: Discount applied to the bootstrapped future value.
            bins: Optional sequence of bin-edge arrays, one per observation
                feature. Defaults to ``DEFAULT_BINS``.
        """
        self.env = env

        # Unseen states lazily get a zero vector with one entry per action.
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

        self.lr = learning_rate
        self.discount_factor = discount_factor

        # Exploration parameters.
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        # Per-feature bin edges used by ``discretize``.
        self.bins = tuple(self.DEFAULT_BINS if bins is None else bins)

        # One TD error is recorded per ``update`` call, for plotting progress.
        self.training_error = []

    def discretize(self, obs: np.ndarray) -> tuple:
        """Map a continuous observation to a hashable tuple of bin indices.

        Feature ``i`` of *obs* is digitized against ``self.bins[i]``. The
        resulting tuple is the key into ``self.q_values``.

        Raises:
            ValueError: If *obs* has more features than there are bin sets.
        """
        if len(obs) > len(self.bins):
            raise ValueError(
                f"observation has {len(obs)} features but only "
                f"{len(self.bins)} bin sets are defined"
            )
        return tuple(
            int(np.digitize(value, edges)) for value, edges in zip(obs, self.bins)
        )

    def get_action(self, obs: np.ndarray) -> int:
        """Choose an action epsilon-greedily.

        With probability ``self.epsilon`` a random action is sampled from the
        environment's action space (explore). Otherwise, the action with the
        highest Q-value for the discretized state is returned (exploit); ties
        go to the lowest index, as ``np.argmax`` does.
        """
        state = self.discretize(obs)
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        return int(np.argmax(self.q_values[state]))

    def update(
        self,
        obs: np.ndarray,
        action: int,
        reward: float,
        terminated: bool,
        next_obs: np.ndarray,
    ):
        """Apply one Q-learning update for the transition (obs, action) -> next_obs.

        ``Q[s][a] += lr * (reward + gamma * max_a' Q[s'][a'] - Q[s][a])``.
        The bootstrap term is zero when the episode *terminated*. Truncation
        still bootstraps, because the caller passes only ``terminated`` here.
        """
        state = self.discretize(obs)

        if terminated:
            # Skip the lookup so terminal next-states are not inserted into
            # the defaultdict-backed Q-table; their value is 0 by definition.
            future_q_value = 0.0
        else:
            next_state = self.discretize(next_obs)
            future_q_value = np.max(self.q_values[next_state])

        target = reward + self.discount_factor * future_q_value
        temporal_difference = target - self.q_values[state][action]

        self.q_values[state][action] += self.lr * temporal_difference

        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        """Linearly reduce exploration: subtract ``epsilon_decay``, floored at ``final_epsilon``."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

# Hyperparameters.
learning_rate = 0.1
n_episodes = 50000
start_epsilon = 1.0
# diabloAgent.decay_epsilon() SUBTRACTS epsilon_decay once per episode, so this
# must be a per-episode step, not a multiplicative factor like 0.99995. Such a
# factor would drop epsilon to final_epsilon after the very first episode. This
# step reaches final_epsilon about halfway through training.
epsilon_decay = start_epsilon / (n_episodes / 2)
final_epsilon = 0.01

# Set up the environment. Gymnasium convention is render_mode=None (not the
# string "None") to disable rendering.
env = DiabloEnv(render_mode=None)
env = gym.wrappers.RecordEpisodeStatistics(env, buffer_length=n_episodes)

# Initialize the agent with the environment *instance*: the agent reads
# env.action_space.n and calls env.action_space.sample().
agent = diabloAgent(
    env=env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)

# Training loop: run each episode to termination/truncation, updating the
# Q-table after every step, then decay exploration once per episode.
for _ in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    while not done:
        action = agent.get_action(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)
        agent.update(obs, action, reward, terminated, next_obs)
        done = terminated or truncated
        obs = next_obs

    agent.decay_epsilon()

# Release environment resources once training is done.
env.close()

