Imports and helper functions

In [1]:
import time
from rich.live import Live
from rich.table import Table
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, MofNCompleteColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn
from rich.console import Group
from rich.panel import Panel
from gridworld1d import GridWorld1D
from q_agent import QLearningAgent
# Debug marker: printed once the imports above have run without raising.
print ("hello")
def initDisplayTable():
    """Create an empty rich Table for the Q-value view.

    Returns:
        A Table with four columns, in order: State, Left, Right, Decision.
        No rows are added here.
    """
    display = Table()
    for header in ("State", "Left", "Right", "Decision"):
        display.add_column(header)
    return display
#end def initDisplayTable
def updateDisplayTableFromQTable(display_table, qtable):
    """Write each state's Q-values and greedy decision into display_table.

    Args:
        display_table: table whose first four columns are State, Left, Right
            and Decision (as built by initDisplayTable). It must already have a
            row for every state index present in qtable, because the state
            value is used directly as the row index.
        qtable: mapping of state -> mapping of action -> Q-value. The first
            inner value is shown as "Left" and the second as "Right", so this
            relies on the inner mapping's insertion order.

    Returns:
        The same display_table, mutated in place.
    """
    # NOTE(review): cells are written through each column's private _cells
    # list, which is not a public rich API and may change between versions.
    state_col, left_col, right_col, decision_col = display_table.columns[:4]
    for state, actions in qtable.items():
        values = list(actions.values())  # convert once instead of per lookup
        leftValue = float(values[0])
        rightValue = float(values[1])
        if leftValue == rightValue:
            decision = "stay"
        elif leftValue > rightValue:
            decision = "<<Left<<"
        else:
            decision = ">>Right>>"
        state_col._cells[state] = str(state)
        left_col._cells[state] = str(round(leftValue, 4))
        right_col._cells[state] = str(round(rightValue, 4))
        decision_col._cells[state] = decision
    # end for state, actions loop
    return display_table
#end def updateDisplayTableFromQTable

def isTrainingComplete(qtable, goal_state=None):
    """Check whether the greedy policy prefers "right" in every non-goal state.

    Args:
        qtable: mapping of state -> mapping of action -> Q-value. The first
            inner value is treated as "left" and the second as "right"
            (insertion order).
        goal_state: state excluded from the check. When None (the default),
            the notebook-level ``goalState`` is looked up at call time, which
            preserves the original behaviour.

    Returns:
        True if, for every state other than goal_state, the right value is
        strictly greater than the left value; otherwise False. An empty
        qtable yields True.
    """
    if goal_state is None:
        goal_state = goalState
    for state, actions in qtable.items():
        if state == goal_state:
            continue
        values = list(actions.values())
        # Ties count as "not learned yet".
        if float(values[1]) <= float(values[0]):
            return False
    return True
#end def isTrainingComplete

def isTrainingCompleteByStepCount(stepCounterTable, optimal_steps=None):
    """Report whether any recorded episode reached the goal in the optimal step count.

    Args:
        stepCounterTable: table built by initStepCounterTable, whose second
            column ("Steps Taken") holds step counts as strings.
        optimal_steps: step count to look for. When None (the default), the
            notebook-level ``grid1DSize - 1`` is used, matching the original
            behaviour.

    Returns:
        True if any Steps Taken cell equals str(optimal_steps), else False.
    """
    if optimal_steps is None:
        optimal_steps = grid1DSize - 1
    target = str(optimal_steps)
    # rich's Row objects carry only row metadata (style / end_section), not
    # cell values, so read the values from the Steps Taken column's private
    # _cells list instead.
    # NOTE(review): _cells is private rich API.
    steps_cells = stepCounterTable.columns[1]._cells
    return any(cell == target for cell in steps_cells)
#end def isTrainingCompleteByStepCount

def initStepCounterTable():
    """Create an empty rich Table for per-episode step counts.

    Returns:
        A Table with two columns, "Episode" and "Steps Taken", both styled
        "bold cyan". No rows are added here.
    """
    counter_table = Table()
    for header in ("Episode", "Steps Taken"):
        counter_table.add_column(header, style="bold cyan")
    return counter_table
#end def initStepCounterTable

# Insert or overwrite the step count recorded for one episode
def updateStepCounter(stepTable, episode, steps):
    """Record `steps` for `episode` in stepTable.

    If a row's Episode cell equals str(episode), that row's Steps Taken cell
    is overwritten; otherwise a new (episode, steps) row is appended. Both
    values are stored as strings.
    """
    # Only the first len(rows) cells of the Episode column are scanned.
    match = next(
        (
            idx
            for idx in range(len(stepTable.rows))
            if stepTable.columns[0]._cells[idx] == str(episode)
        ),
        None,
    )

    if match is None:
        # No row yet for this episode: append one.
        stepTable.add_row(str(episode), str(steps))
        return

    # Row already present: replace its step count in place.
    stepTable.columns[1]._cells[match] = str(steps)
#end def updateStepCounter

# Configuration Variables
# Grid layout: states are the indices 0 .. grid1DSize-1.
grid1DSize = 30                 # number of cells in the 1D grid
startState = 0                  # cell every episode starts from
goalState = grid1DSize - 1      # goal is the last cell of the grid

# Training schedule.
num_episodes = 1000             # how many training episodes to run
sleep_time = 0                  # seconds to pause after each episode (0 = none)

print ("init done a")

hello
init done a


Environment and agent configuration

In [2]:
# Build the 1D grid environment and the tabular Q-learning agent.
step_count = 0
env = GridWorld1D(size=grid1DSize, start_state=startState, goal_state=goalState)
agent = QLearningAgent(
    actions=["left", "right"],
    learning_rate=0.1,
    discount_factor=0.99,
    epsilon=1.0,
    epsilon_decay=0.95,
    epsilon_min=0.01,
)
# Per-episode history appended to by the training loop.
episode_data, step_data, epsilon_data = [], [], []

In [3]:
# Initialize the display tables and progress bars rendered by the training cell.
table = initDisplayTable()
stepCounterTable = initStepCounterTable()
progress = Progress(
    SpinnerColumn(),  # Shows a spinning animation
    # Include {task.description} so the per-episode description set by
    # progress.update(..., description=...) in the training loop is displayed;
    # a fixed label alone would ignore it.
    TextColumn("Training progress: {task.description}"),
    BarColumn(complete_style="blue", finished_style="green"),  # Progress bar
    MofNCompleteColumn(),  # Shows "M of N complete"
    TimeElapsedColumn(),  # Time elapsed
    TimeRemainingColumn(),  # Estimated time remaining
    TaskProgressColumn(),  # Task-specific progress
    transient=True
)
task = progress.add_task("Training Q-learning agent", total=num_episodes)

# State progress bar: the training loop sets completed=<current state index>,
# so the total is the goal index, making the bar full exactly at the goal.
stateProgressBar = Progress(
    SpinnerColumn(),
    TextColumn("Current state:"),
    BarColumn(complete_style="yellow", finished_style="green"),
    MofNCompleteColumn(),
    transient=True
)
stateTask = stateProgressBar.add_task("State tracking", total=goalState)

# Initialize step counter variable
current_steps = 0

# Everything the Live view shows, top to bottom.
display_group = Group(progress, stepCounterTable, stateProgressBar, table)


In [4]:
# Train the agent inside a rich Live view that shows the progress bars, the
# step counter table and the Q-table.
with Live(display_group, refresh_per_second=50) as live:
    # Seed one placeholder row per state. Guarded so that re-running this cell
    # without re-running the setup cell does not append a duplicate set of rows.
    if not table.rows:
        for i in range(grid1DSize):
            table.add_row(str(i), "0.0", "0.0", "stay")
    # end for i loop to init the display table
    # Show the whole group; updating with the bare table would replace the
    # group and hide the progress bars.
    live.update(display_group)

    # 1-based episode at which isTrainingComplete first returned True (None if never).
    first_complete_episode = None

    # Training loop
    start_time = time.time()
    for episode in range(num_episodes):
        state = env.reset()        # reset environment to starting state
        done = False
        step_count = 0

        while not done:
            action = agent.choose_action(state)             # choose action (epsilon-greedy)
            next_state, reward, done = env.step(action)     # take action, observe reward and next state
            agent.update_q_value(state, action, reward, next_state, done)  # update Q-table
            state = next_state        # move to the next state
            step_count += 1

            stateProgressBar.update(stateTask, completed=state) # Update state progress bar
            #updateStepCounter(stepCounterTable, episode, step_count) # todo: add this back in
            live.update(display_group)
        # end while loop

        # Store episode and step data
        episode_data.append(episode)
        step_data.append(step_count)
        epsilon_data.append(agent.epsilon)

        # Decay exploration rate at end of episode
        agent.decay_epsilon()

        # Update displays: the progress bar, then the Q-table. The table is
        # mutated in place, so display_group already references it and does
        # not need rebuilding.
        progress.update(task, description=f"Episode {episode+1} of {num_episodes}", advance=1)
        qtable = agent.getQTable()
        updateDisplayTableFromQTable(table, qtable)
        #updateStepCounter(stepCounterTable, episode, step_count) # todo: add this back in
        live.update(display_group)

        # Record (without stopping) the first episode whose greedy policy moves
        # right in every non-goal state.
        if first_complete_episode is None and isTrainingComplete(qtable):
            first_complete_episode = episode + 1
        # end if isTrainingComplete
        time.sleep(sleep_time)
    # end for episode loop
    end_time = time.time()

#end with Live loop


Output()