In [48]:
import numpy as np
from copy import deepcopy
from tqdm import tqdm

In [67]:
##Environment
class GridWorld:
    """Grid world with collectable objects and a single terminal goal cell.

    Positions are ``(x, y)`` tuples with ``0 <= x < width`` and
    ``0 <= y < height``.  The agent starts at ``(0, 0)`` and an episode ends
    when it reaches the opposite corner ``(width - 1, height - 1)``.

    ``n_o`` objects are placed on distinct random cells and each is assigned
    one of ``n_c`` classes.  Every class carries a reward drawn uniformly
    from ``self.reward``; an object pays its class reward the first time it
    is visited in an episode.

    Parameters
    ----------
    height, width : int
        Grid dimensions.
    n_c : int
        Number of object classes.
    n_o : int
        Number of objects; must fit on the grid (``0 <= n_o <= height * width``).

    Raises
    ------
    ValueError
        If ``n_o`` is negative or larger than the number of cells (the
        rejection sampler in ``generate_objects`` could never terminate).
    """

    # action id -> (dx, dy).  UP: 0, RIGHT: 1, DOWN: 2, LEFT: 3
    _MOVES = {0: (0, 1), 1: (1, 0), 2: (0, -1), 3: (-1, 0)}

    def __init__(self, height = 50, width = 50, n_c = 1, n_o = 1):
        if not 0 <= n_o <= height * width:
            raise ValueError(
                "n_o must be between 0 and height * width (%d), got %d"
                % (height * width, n_o)
            )

        self.wd = width
        self.ht = height
        self.n_c = n_c            # no. of classes
        self.n_o = n_o            # no. of objects  (n_c <= n_o)
        self.reward = [-1, 1]     # range of class rewards
        self.start_coord = (0, 0)
        # Positions are (x, y) with x bounded by width and y by height, so the
        # far corner is (width - 1, height - 1).
        self.goal_coord = (width - 1, height - 1)
        self.x_g, self.y_g = self.goal_coord   # goal x / y coordinates
        self.r_g = 1                           # goal reward
        # State-feature size: 9 position features (the 3x3 radial basis that
        # SFQL builds) plus one "already collected" flag per object.
        self.D = 9 + n_o
        # Reward-feature size: one entry per class plus one for the goal.
        self.d = n_c + 1
        self.actions = [0, 1, 2, 3]

        self.object_coord = []
        self.object_class = []
        self.class_reward = {}
        self.objects_so_far = [0] * self.n_o
        # Give the agent a valid position so step() works even if the caller
        # has not invoked reset()/restart() yet.
        self.cur_pos_x, self.cur_pos_y = self.start_coord

        self.generate_objects()    # Done once
        self.assign_classes()
        self.generate_class_reward()

    def generate_objects(self):
        """Append ``n_o`` distinct random ``(x, y)`` cells to ``object_coord``."""
        remaining = self.n_o
        while remaining > 0:
            candidate = (np.random.randint(self.wd), np.random.randint(self.ht))
            if candidate not in self.object_coord:
                self.object_coord.append(candidate)
                remaining -= 1

    def assign_classes (self):
        """Draw a class id in ``[0, n_c)`` for every object."""
        self.object_class = np.random.randint(0, self.n_c, size = self.n_o)

    def generate_class_reward(self):
        """Draw a fresh reward for every class, uniformly from ``self.reward``."""
        low, high = self.reward
        for class_id in range(self.n_c):
            self.class_reward[class_id] = np.random.uniform(low, high)

    def reset(self):
        """Start a new task: new class rewards, agent at start, nothing collected.

        Returns the start coordinate.
        """
        self.restart()
        self.generate_class_reward()
        return self.start_coord

    def render(self):
        """Print the grid: one text line per x index, one character per y index.

        Object cells show the object's class id; every other cell shows ``*``.
        """
        for i in range(self.wd):
            for j in range(self.ht):
                if (i, j) in self.object_coord:
                    print(self.object_class[self.object_coord.index((i, j))], end = "")
                else:
                    print("*",end = "")
            print("")

    def show_class_rewards(self):
        """Print the current class-id -> reward mapping."""
        print(self.class_reward)

    def restart(self):
        """Start a new episode of the same task (class rewards are kept).

        Returns the start coordinate.
        """
        self.cur_pos_x, self.cur_pos_y = self.start_coord
        self.objects_so_far = [0] * self.n_o
        return self.start_coord

    def step(self, action):
        """Apply ``action`` and return ``(pos, phi, reward, done)``.

        A move that would leave the grid (or an unknown action id) leaves the
        agent in place.  ``phi`` is an array of length ``n_c + 1``: the entry
        of the collected object's class is 1 on first pickup, and the last
        entry holds ``r_g`` when the goal is reached.  ``done`` is True only
        at the goal.
        """
        dx, dy = self._MOVES.get(action, (0, 0))
        new_x, new_y = self.cur_pos_x + dx, self.cur_pos_y + dy
        if 0 <= new_x < self.wd and 0 <= new_y < self.ht:
            self.cur_pos_x, self.cur_pos_y = new_x, new_y

        pos = (self.cur_pos_x, self.cur_pos_y)
        phi_at_step = [0] * (self.n_c + 1)
        reward = 0
        done = False

        if pos == self.goal_coord:
            # The goal takes precedence over any object sharing its cell.
            phi_at_step[-1] = self.r_g
            reward += self.r_g
            done = True
        elif pos in self.object_coord:
            object_index = self.object_coord.index(pos)
            if self.objects_so_far[object_index] == 0:
                class_id = int(self.object_class[object_index])
                phi_at_step[class_id] = 1
                self.objects_so_far[object_index] = 1
                reward += self.class_reward[class_id]

        return pos, np.array(phi_at_step), reward, done
               

In [5]:
# 10x10 grid with 2 object classes and 10 objects.
env = GridWorld(height=10, width=10, n_c=2, n_o=10)

In [6]:
# Roll out a uniformly random policy until the goal is reached, printing
# every transition that produced a non-zero reward.
state = env.restart()
env.render()
# env.show_class_rewards()
done = False
while not done:
    a = np.random.choice([0, 1, 2, 3])
    state, phi, r, done = env.step(a)
    if r == 0:
        continue
    print("State", state, "PHI", phi, "Reward", r, end = " ")

**********
1******1**
***0******
**********
*1********
0*********
******00*1
*1********
**********
****1*****
State (1, 0) PHI [0, 1, 0] Reward -0.5561393869552584 State (5, 0) PHI [1, 0, 0] Reward -0.5280895023193286 State (4, 1) PHI [0, 1, 0] Reward -0.5561393869552584 State (7, 1) PHI [0, 1, 0] Reward -0.5561393869552584 State (6, 6) PHI [1, 0, 0] Reward -0.5280895023193286 State (6, 7) PHI [1, 0, 0] Reward -0.5280895023193286 State (1, 7) PHI [0, 1, 0] Reward -0.5561393869552584 State (6, 9) PHI [0, 1, 0] Reward -0.5561393869552584 State (9, 9) PHI [0, 0, 1] Reward 1 

In [83]:
class RadialBasis():
    """Gaussian radial-basis encoding of a 2-D grid position.

    Centres form a ``basis_x`` by ``basis_y`` lattice spanning ``[0, x_dim]``
    by ``[0, y_dim]``; both centres and query points are divided by
    ``x_dim`` / ``y_dim`` before distances are taken.

    Parameters
    ----------
    x_dim, y_dim : number
        Extent of the grid along x and y; must be positive.
    basis_x, basis_y : int
        Number of centres along each axis.
    bandwidth : float, optional
        Divisor applied to the squared distance inside the exponential
        (default 0.1, the value that used to be hard-coded).

    Raises
    ------
    ValueError
        If ``x_dim``, ``y_dim`` or ``bandwidth`` is not positive.
    """

    def __init__(self, x_dim, y_dim, basis_x, basis_y, bandwidth=0.1):
        if x_dim <= 0 or y_dim <= 0:
            raise ValueError("x_dim and y_dim must be positive")
        if bandwidth <= 0:
            raise ValueError("bandwidth must be positive")

        self.centres = []
        self.x_dim, self.y_dim = x_dim, y_dim
        self.bandwidth = bandwidth
        # x varies slowest, y fastest: centre index = ix * basis_y + iy.
        for x in np.linspace(0, x_dim, basis_x):
            for y in np.linspace(0, y_dim, basis_y):
                self.centres.append((x/x_dim, y/y_dim))

    def getPositionVector(self, x, y):
        """Return the activation of every basis function at ``(x, y)``.

        The result is a 1-D array with one entry per centre, each equal to
        ``exp(-squared_distance / bandwidth)`` in normalised coordinates.
        """
        # reshape keeps the (n, 2) layout even when there are no centres.
        centres = np.asarray(self.centres, dtype=float).reshape(-1, 2)
        point = np.array([x / self.x_dim, y / self.y_dim], dtype=float)
        sq_dist = np.sum((centres - point) ** 2, axis=1)
        return np.exp(-sq_dist / self.bandwidth)

In [87]:
class SFQL:
    """Successor-feature Q-learning over a sequence of tasks.

    For every task the agent keeps a reward-weight vector ``w`` (length
    ``env.d``) and, per action, a ``(env.D, env.d)`` matrix ``Z`` such that
    ``psi(s, a) = features(s) @ Z[a]``.  Action values are ``psi(s, a) @ w``.
    While learning a task, behaviour is chosen from whichever stored policy
    currently promises the highest value under the task's own ``w``.

    Parameters
    ----------
    env
        Environment exposing ``wd``, ``ht``, ``D``, ``d``, ``actions``,
        ``objects_so_far``, ``reset()``, ``restart()`` and ``step(action)``.
    """

    def __init__(self, env):
        self.env = env
        self.reward_weight_list = []   # learned w, one entry per task
        self.Z_list = []               # per task: list of (D, d) arrays, one per action
        self.eps_greedy = 0.4
        self.gamma = 0.95
        self.w_alpha = 0.01
        self.z_alpha = 0.01
        self.w_err_th = 0.01
        self.rdb = RadialBasis(env.wd, env.ht, 3, 3)

    def featurize_state(self, state):
        """Return the state features: position encoding + collected-object flags.

        The flags are read from ``env.objects_so_far`` at call time, so this
        must be called while the environment is still in ``state``.
        """
        return np.hstack((self.rdb.getPositionVector(state[0], state[1]), np.array(self.env.objects_so_far)))

    def find_best_psi(self, w_t, state):
        """Return the index of the stored policy with the best action value.

        Values are evaluated in ``state`` under reward weights ``w_t``.  Ties
        go to the lowest index; 0 is returned when no policy is stored.
        """
        features = self.featurize_state(state)
        best_k, best_val = 0, -np.inf
        for k, Z_k in enumerate(self.Z_list):
            val_k = max(np.dot(np.dot(features, Z_k[a]), w_t) for a in self.env.actions)
            if val_k > best_val:
                best_k, best_val = k, val_k
        return best_k

    def get_action(self, state, c, w_t):
        """Epsilon-greedy action with respect to policy ``c`` under ``w_t``.

        With probability ``eps_greedy`` a uniformly random action is returned;
        otherwise the first action maximising ``psi_c(state, a) @ w_t``.
        """
        if np.random.uniform(0,1) < self.eps_greedy:   # In the paper, Bernoulli is considered instead of uniform
            return np.random.choice(self.env.actions)
        features = self.featurize_state(state)
        values = [np.dot(np.dot(features, self.Z_list[c][a]), w_t) for a in self.env.actions]
        return self.env.actions[int(np.argmax(values))]

    def _update_successor_features(self, Z, features, action, phi, next_features, next_action):
        """Apply one TD(0) update to ``Z[action]`` in place.

        ``next_action`` is None on a terminal transition, in which case the
        target is ``phi`` alone (no bootstrap).
        """
        psi = np.dot(features, Z[action])
        target = phi
        if next_action is not None:
            target = phi + self.gamma * np.dot(next_features, Z[next_action])
        Z[action] += self.z_alpha * np.outer(features, target - psi)

    def algorithm(self, num_tasks, num_episodes=100):
        """Train on ``num_tasks`` consecutive tasks, ``num_episodes`` each.

        Each task calls ``env.reset()`` once and ``env.restart()`` per
        episode.  The learned ``w`` and ``Z`` of every task are appended to
        ``reward_weight_list`` / ``Z_list``.  The running mean episode return
        is printed after every episode.
        """
        D = self.env.D
        d = self.env.d

        #TODO: How to initialize the list? add before train or after train (only for the first task)
        Z = [np.random.rand(D,d) for _ in range(len(self.env.actions))]

        for t in range(0, num_tasks):
            print("Task: ",t)

            # Each task starts from a copy of the previous task's Z.  w_t is
            # stored by reference and updated in place below, so the list
            # always holds the learned weights.
            w_t = np.random.rand(d)
            self.Z_list.append(deepcopy(Z))
            self.reward_weight_list.append(w_t)
            Z = self.Z_list[-1]
            # Position of this task in the lists; differs from t when
            # algorithm() is called more than once on the same agent.
            task_idx = len(self.Z_list) - 1
            self.env.reset()

            mavg = 0
            for ep in range(num_episodes):
                state = self.env.restart()
                done = False
                creward = 0
                while not done:
                    # Features of `state` must be taken before stepping:
                    # env.step() mutates objects_so_far.
                    features = self.featurize_state(state)
                    c = self.find_best_psi(w_t, state)
                    action = self.get_action(state, c, w_t)
                    s_prime, phi_at_step, reward, done = self.env.step(action)
                    phi_at_step = np.asarray(phi_at_step, dtype=float)
                    creward += reward

                    if done:
                        next_features, a_prime = None, None
                    else:
                        next_features = self.featurize_state(s_prime)
                        c_prime = self.find_best_psi(w_t, s_prime)
                        # NOTE(review): a' is drawn with the same epsilon-greedy
                        # rule as the behaviour action (unchanged from the
                        # original); confirm whether a purely greedy a' was intended.
                        a_prime = self.get_action(s_prime, c_prime, w_t)

                    w_t += self.w_alpha * (reward - np.dot(phi_at_step, w_t)) * phi_at_step

                    self._update_successor_features(
                        Z, features, action, phi_at_step, next_features, a_prime)

                    if c != task_idx:
                        # The behaviour came from an older policy: refine it
                        # too, bootstrapping with its own reward weights.
                        if not done:
                            a_prime = self.get_action(s_prime, c, self.reward_weight_list[c])
                        self._update_successor_features(
                            self.Z_list[c], features, action, phi_at_step, next_features, a_prime)

                    state = s_prime
                mavg += creward
                print(round(mavg/(ep + 1),4), end = " ")

In [88]:
# Train the SFQL agent on 10 consecutive tasks in a fresh 10x10 grid
# (2 object classes, 10 objects).
env = GridWorld(height=10, width=10, n_c=2, n_o=10)
a = SFQL(env)
a.algorithm(num_tasks=10)

Task:  0
-0.2523 0.3002 0.116 0.0239 0.2412 0.3524 0.4028 0.4337 0.3575 0.3168 0.3571 0.3401 0.3526 0.3344 0.3149 0.3324 0.3305 0.3288 0.3428 0.3306 0.3554 0.3504 0.3482 0.3631 0.3584 0.3774 0.3854 0.3804 0.3758 0.3867 0.3839 0.3905 0.3878 0.3939 0.3954 0.401 0.4132 0.402 0.4017 0.3913 0.3815 0.3844 0.3871 0.378 0.3841 0.3754 0.3672 0.3731 0.3693 0.3639 0.3706 0.3653 0.3679 0.3646 0.3708 0.3748 0.3716 0.3729 0.3767 0.3788 0.3832 0.3768 0.3765 0.3848 0.3787 0.3783 0.3725 0.3728 0.373 0.3675 0.3674 0.3741 0.3743 0.3773 0.3802 0.3804 0.3832 0.3859 0.3783 0.3831 0.3863 0.3815 0.3847 0.3885 0.3886 0.3886 0.3887 0.387 0.3871 0.3878 0.39 0.3885 0.3907 0.3885 0.3806 0.3786 0.3746 0.3748 0.3729 0.3751 Task:  1
2.4828 2.4828 2.4828 2.7299 2.5816 2.5651 2.5534 2.6063 2.5926 2.6805 2.6625 2.769 2.785 2.8992 2.9044 2.878 2.8257 2.8341 2.8636 2.8693 2.8944 2.897 2.9219 2.9417 2.9431 2.9444 2.9639 2.9644 2.9648 2.9652 2.9656 2.9798 2.9498 2.9651 2.9506 2.9513 2.952 2.926 2.914 2.9026 2.9044 2.9055 2.

KeyboardInterrupt: 