In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm as tq
from PIL import Image
import PIL.ImageDraw as ImageDraw
import imageio
import os
import json
import pickle

In [4]:
class MCLearning:
    def __init__(self, eps=0.1, use_pretrained=True):
        self.x = np.round(np.linspace(-0.3, 0.3, num=30), decimals=3)
        self.x_dot = np.round(np.linspace(-1.5, 1.5, num=15), decimals=3)
        self.th = np.round(np.linspace(-0.21, 0.21, num=21), decimals=3)
        self.th_dot = np.round(np.linspace(-2.5, 2.5, num=25), decimals=3)
        self.env = gym.make('CartPole-v1')
        self.actions = np.array([0, 1])
        self.run_hist = []
        self.eps = eps
        self.name = "Monte_carlo"
        if use_pretrained:
            self.policy = np.load('./MCLearning/custom_reward_1/policy.npy')
            self.st_act_val = np.load(
                './MCLearning/custom_reward_1/st_act_val.npy')
            with open('./MCLearning/custom_reward_1/st_dict_idx.pkl', 'rb') as f:
                self.st_dict_idx = pickle.load(f)
            with open('./MCLearning/custom_reward_1/st_idx.pkl', 'rb') as f:
                self.st_idx_dict = pickle.load(f)
            with open('./MCLearning/custom_reward_1/run_hist.pkl', 'rb') as f:
                self.run_hist = pickle.load(f)
            self.state_count = self.policy.shape[0]
            print("Resuming training from", len(self.run_hist), "episode!")

    def generate_states(self):
        self.states = []
        for x in self.x:
            for x_dot in self.x_dot:
                for th in self.th:
                    for th_dot in self.th_dot:
                        self.states.append((x, x_dot, th, th_dot))
        '''
        The structure of self.state_dict_idx: key is the state tuple, value is its index
        The structure of self.state_idx_dict
        '''
        self.st_dict_idx = {state: i for i, state in enumerate(self.states)}
        self.st_idx_dict = {i: state for i, state in enumerate(self.states)}
        '''
        The variable self.policy states the best action for each state, taking the state's index
        '''
        self.state_count = len(self.states)
        self.policy = np.random.randint(0, 2, (self.state_count))

        '''
        The structure of self.st_act_val is as follows:
        1. The length of this array depicts the config of each state.
        2. The two vectors in each state correspond to the two actions - 0, 1
        3. The first value for each action vector is the action value for that state
        4. The second value for each action vector is the freq with which that action is taken
        '''
        self.st_act_val = np.zeros((self.state_count, 2, 2))
        del self.states

    def get_state(self, desc):
        def nearest(v, x):
            if x <= v[0]:
                return v[0]
            elif x >= v[-1]:
                return v[-1]
            n = len(v)
            lo, hi, mid = 0, n-1, 0
            while (lo < hi):
                mid = lo+int((hi-lo)/2)
                if v[mid] == x:
                    return v[mid]
                if x < v[mid]:
                    if (mid > 0 and x > v[mid - 1]):
                        return v[mid] if v[mid]-x >= v[mid-1]-x else v[mid-1]
                    hi = mid
                else:
                    if (mid < n - 1 and x < v[mid + 1]):
                        return v[mid] if v[mid]-x >= v[mid+1]-x else v[mid+1]
                    lo = mid + 1
            return v[mid]
        x = nearest(self.x, desc[0])
        x_dot = nearest(self.x_dot, desc[1])
        th = nearest(self.th, desc[2])
        th_dot = nearest(self.th_dot, desc[3])
        return self.st_dict_idx[(x, x_dot, th, th_dot)]

    # Using e-greedy to evaluate policy
    def get_action(self, state):
        rn = np.random.uniform(0, 1)
        act_best = self.policy[self.get_state(state)]
        if rn > self.eps:
            return act_best
        else:
            return np.random.choice([0, 1])

    def save_agent(self):
        np.save('./MCLearning/custom_reward_1/policy.npy', self.policy)
        np.save('./MCLearning/custom_reward_1/st_act_val.npy', self.st_act_val)
        with open('./MCLearning/custom_reward_1/st_dict_idx.pkl', 'wb') as f:
            pickle.dump(self.st_dict_idx, f)
        with open('./MCLearning/custom_reward_1/st_idx.pkl', 'wb') as f:
            pickle.dump(self.st_idx_dict, f)
        with open('./MCLearning/custom_reward_1/run_hist.pkl', 'wb') as f:
            pickle.dump(self.run_hist, f)
        plt.plot(self.run_hist)
        plt.savefig('./MCLearning/custom_reward_1/run_hist.png')

    # Using every time visit
    def train(self, episodes=20):
        self.state_freq = np.zeros(self.state_count)
        for episode in tq(range(episodes)):
            history = []
            obs, _ = self.env.reset()
            while (1):
                action = self.get_action(obs)
                state = self.get_state(obs)
                obs, reward, term, trunc, _ = self.env.step(action)
                reward += (1-obs[0]**2-obs[2]**2)
                history.append((state, action, reward))
                if term or trunc:
                    break

            self.run_hist.append(len(history))

            cum_reward = 0
            for step in reversed(history):
                st, act, rew = step
                cum_reward += rew
                self.st_act_val[st][act][1] += 1
                self.st_act_val[st][act][0] += (
                    cum_reward - self.st_act_val[st][act][0])/self.st_act_val[st][act][1]

            ### Policy Improvement ###
            for st in range(self.state_count):
                tmp = np.array([self.st_act_val[st][0][0],
                               self.st_act_val[st][1][0]])
                self.policy[st] = np.argmax(tmp)
            if episode % 100 == 99:
                self.save_agent()
        self.save_agent()

In [5]:
agent = MCLearning(eps=-1, use_pretrained=True)
agent.generate_states()
agent.train(episodes=20000)

Resuming training from 2 episode!


  0%|          | 47/20000 [00:41<4:52:32,  1.14it/s]


KeyboardInterrupt: 