# Q-Learning

In this notebook, an agent is trained to solve a grid world task via a Q-Learning algorithm.

The code is adapted from NYU's 'Computational Cognitive Modeling' course to accomplish the homework assignment. 


# Outline
- [ 1 - Import Packages <img align="Right" src="https://raw.githubusercontent.com/jojozyang/CCM-site/master/homeworks/homework-RL/images/gridworld.png" width = 60%>](#1)
- [ 2 - Load the Environment](#2)
- [ 3 - Main Functions](#3)
- [ 4 - Helper Functions](#4)
- [ 5 - Train the Agent](#5)

<a name="1"></a>
## 1 - Import Packages

In [1]:
import numpy as np
import random
import math
import statistics
from copy import deepcopy
from IPython.display import display, Markdown, Latex, HTML
from gridworld import GridWorld, random_policy

<a name="2"></a>
## 2 - Load the Environment

In [2]:
gridworld = [
       [ 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'x', 'g'],
       [ 'o', 'x', 'x', 'o', 'x', 'x', 'o', 'x', 'o'],
       [ 'o', 'x', 'x', 'o', 'x', 'x', 'o', 'x', 'o'],
       [ 'o', 'x', 'x', 'o', 'x', 'x', 'o', 'o', 'o'],
       [ 'o', 'x', 'x', 'o', 'x', 'x', 'x', 'o', 'o'],
       [ 's', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'x']
    ] # the problem described above, 'x' is a wall, 's' is start, 'g' is goal, and 'o' is a normal room

mygrid = GridWorld(gridworld)
mygrid.raw_print()  # print out the grid world
mygrid.index_print() # print out the indicies of each state
mygrid.coord_print() # print out the coordinates of each state (helpful in your code)

# define the rewards as a hash table (dictionary)
rewards={}

# mygrid.transitions contains all the pairwise state-state transitions allowed in the grid
# for each state transition intialize the reward to zero
for start_state in mygrid.transitions:
    for action in mygrid.transitions[start_state].keys():
        next_state = mygrid.transitions[start_state][action]
        rewards[str([start_state, action, next_state])] = 0.0

# now set the reward for moving up into state 8 (the goal state) to +10
rewards[str([17, 'up', 8])] = 10

# now set the penalty for walking off the edge of the grid and returning to state 45 (the start state)
for i in [0,1,2,3,4,5,6,7]:
    rewards[str([i, 'up', 45])] = -1
for i in [0,9,18,27,36,45]:
    rewards[str([i, 'left', 45])] = -1
for i in [45,46,47,48,49,50,51,52,53]:
    rewards[str([i, 'down', 45])] = -1
for i in [8,17,26,35,44,53]:
    rewards[str([i, 'right', 45])] = -1

## Welcome to your new Grid World!

**Raw World Layout**

0,1,2,3,4,5,6,7,8
o,o,o,o,o,o,o,x,g
o,x,x,o,x,x,o,x,o
o,x,x,o,x,x,o,x,o
o,x,x,o,x,x,o,o,o
o,x,x,o,x,x,x,o,o
s,o,o,o,o,o,o,o,x


**Indexes of each grid location as an id number**

0,1,2,3,4,5,6,7,8
0,1,2,3,4,5,6,7,8
9,10,11,12,13,14,15,16,17
18,19,20,21,22,23,24,25,26
27,28,29,30,31,32,33,34,35
36,37,38,39,40,41,42,43,44
45,46,47,48,49,50,51,52,53


**Indexes of each grid location as a tuple**

0,1,2,3,4,5,6,7,8
"(0,0)","(0,1)","(0,2)","(0,3)","(0,4)","(0,5)","(0,6)","(0,7)","(0,8)"
"(1,0)","(1,1)","(1,2)","(1,3)","(1,4)","(1,5)","(1,6)","(1,7)","(1,8)"
"(2,0)","(2,1)","(2,2)","(2,3)","(2,4)","(2,5)","(2,6)","(2,7)","(2,8)"
"(3,0)","(3,1)","(3,2)","(3,3)","(3,4)","(3,5)","(3,6)","(3,7)","(3,8)"
"(4,0)","(4,1)","(4,2)","(4,3)","(4,4)","(4,5)","(4,6)","(4,7)","(4,8)"
"(5,0)","(5,1)","(5,2)","(5,3)","(5,4)","(5,5)","(5,6)","(5,7)","(5,8)"


In [37]:
# set up initial data strucutres 
# q(s,a) - the q-values for each action in each state
def zero_q_values():
    qvals = {"up": 0.0, "right": 0.0, "down": 0.0, "left": 0.0}
    return qvals
q_value_table = [[zero_q_values() for i in range(mygrid.ncols)] for j in range(mygrid.nrows)]

# pi - the policy table
policy_table = [[random_policy() for i in range(mygrid.ncols)] for j in range(mygrid.nrows)]
display(Markdown("**Initial (randomized) policy**"))
mygrid.pretty_print_policy_table(policy_table)

**Initial (randomized) policy**

0,1,2,3,4,5,6,7,8
↓,→,→,→,↑,←,←,▉,↓
→,▉,▉,↓,▉,▉,↑,▉,←
→,▉,▉,↑,▉,▉,↑,▉,←
↓,▉,▉,↑,▉,▉,↑,→,→
→,▉,▉,↑,▉,▉,▉,→,←
↑,←,←,↑,→,←,↓,←,▉


<a name="3"></a>
## 3 - Main Functions

In [26]:
def be_greedy(q_values):
    if len(q_values)==0:
        return {}
    
    keys = list(q_values.keys())
    vals = [q_values[i] for i in keys]    
    maxqs = [i for i,x in enumerate(vals) if x==max(vals)]
    if len(maxqs)>1:
        pos = random.choice(maxqs)
    else:
        pos = maxqs[0]
    policy = deepcopy(q_values)
    for i in policy.keys():
        policy[i]=0.0
    policy[keys[pos]]=1.0
    return policy
    
def epsilon_greedy(actions, epsilon):
    if random.random() < epsilon:
        return random.choice(list(actions.keys()))
    else:
        if actions['up'] == 1.0:
            return 'up'
        elif actions['right'] == 1.0:
            return 'right'
        elif actions['down'] == 1.0:
            return 'down'
        elif actions['left'] == 1.0:
            return 'left'

<a name="4"></a>
## 4 - Helper Functions

<a name="5"></a>
## 5 - Train the Agent

In [43]:
starting_state = 45
goal_state = 8 # terminate the MC roll out when getting to this state
GAMMA = 0.95
epsilon = 0.1
ITERATIONS = 10000 
max_depth = 120
alpha = 1e-3 
start_itr = 130 # improve policy only after certain iterations
random.seed(5000) 

returns = {} ## key is (state, action), values are N, the number of visits for a (s,a) pair
for i in range(ITERATIONS):  
    # "exploring start": Agent starts in a random valid state
    current_state = random.choice(list(mygrid.valid_states.keys())) 
    depth = 0 
    if current_state != goal_state and depth < max_depth:
        sx, sy = mygrid.index_to_coord(current_state)
        action = epsilon_greedy(policy_table[sx][sy],epsilon)
        if action == 'up':
            new_state = mygrid.up(current_state)
        elif action == 'right':
            new_state = mygrid.right(current_state)
        elif action == 'down':
            new_state = mygrid.down(current_state)
        elif action == 'left':
            new_state = mygrid.left(current_state)
        r = rewards[str([current_state,action,new_state])]       
   
        sx, sy = mygrid.index_to_coord(current_state)
        sx_new, sy_new = mygrid.index_to_coord(new_state)
        
        # find max q values for the new state 
        for idx, n_action in enumerate(list(policy_table[sx_new][sy_new].keys())): 
            q_new =  q_value_table[sx_new][sy_new][n_action]
            if idx == 0:
                max_q = q_new
            else:
                if q_new > max_q: 
                    max_q = q_new
                    
        # update q value table 
        q_value_table[sx][sy][action] += alpha * (r + GAMMA * max_q - q_value_table[sx][sy][action]) 
            
        # improve policy only after some iterations:
        if i >= start_itr: 
            for sx in range(len(q_value_table)):
                for sy in range(len(q_value_table[sx])):
                    policy_table[sx][sy] = be_greedy(q_value_table[sx][sy])
                
        current_state = new_state
        depth += 1
    
display(Markdown("**Improved policy**"))
mygrid.pretty_print_policy_table(policy_table)

**Improved policy**

0,1,2,3,4,5,6,7,8
→,→,→,→,→,→,↓,▉,←
↑,▉,▉,↑,▉,▉,↓,▉,↑
↓,▉,▉,↓,▉,▉,↓,▉,↑
↓,▉,▉,↓,▉,▉,→,→,↑
↓,▉,▉,↓,▉,▉,▉,→,↑
→,→,→,→,→,→,→,↑,▉
