In [1]:
import sys

sys.path.append("/Users/kanishkjain/opt/anaconda3/envs/gym/lib/python3.9/site-packages")

import random
import collections
from pprint import pprint

import gym
import gym_toytext
import numpy as np
import matplotlib.pyplot as plt

In [14]:
class Agent:
    """Tabular reinforcement-learning agent for discrete Gym environments.

    The agent wraps a Gym environment with discrete observation and action
    spaces and offers Monte Carlo control (on- and off-policy) as well as
    the temporal-difference methods SARSA and Q-learning.

    Two value-table layouts are used:

    * the Monte Carlo methods keep ``Q`` as a dict keyed by ``(state, action)``
      and return a policy matrix of shape ``(Num_S, Num_A)``;
    * ``q_learning`` / ``sarsa`` keep ``Q`` as a dict keyed by ``state`` whose
      values are arrays of length ``Num_A``, and return that table.

    The environment is driven through the 4-tuple ``step`` API
    (``obs, reward, done, info``) and a ``reset`` that returns the observation.
    """

    def __init__(
        self, environment="Roulette-v0", gamma=0.1, theta=1e-6, epsilon=1.0, alpha=0.1
    ) -> None:
        """Create the environment and store the hyperparameters.

        Args:
            environment: Gym environment id passed to ``gym.make``.
            gamma: Discount factor used by every learning method.
            theta: Stored on the instance; not read by any method of this class.
            epsilon: Initial exploration rate (each learning method resets it).
            alpha: Step size for the SARSA and Q-learning updates.
        """
        self.env = gym.make(environment)
        self.env.reset()

        self.gamma = gamma
        self.theta = theta
        self.epsilon = epsilon

        self.alpha = alpha

        self.A_space = self.env.action_space
        self.S_space = self.env.observation_space
        self.R_range = self.env.reward_range

        self.Num_A = self.A_space.n
        self.Num_S = self.S_space.n

        if environment == 'CliffWalking-v0':
            # Patch the transition table so that every terminating transition
            # pays reward 0 instead of the reward the environment ships with.
            for s in range(self.Num_S):
                for a in range(self.Num_A):
                    P, S_, R_, T = self.env.P[s][a][0]
                    if T:
                        self.env.P[s][a] = [(P, S_, 0, T)]

    def soft_policy(self):
        """Return the uniform random policy as a ``(Num_S, Num_A)`` matrix."""
        Pi = np.ones((self.Num_S, self.Num_A)) / self.Num_A
        return Pi

    def greedy_policy(self, Q):
        """Return the deterministic policy that is greedy with respect to ``Q``.

        Args:
            Q: Mapping keyed by ``(state, action)`` tuples.

        Returns:
            A ``(Num_S, Num_A)`` one-hot matrix; ties go to the lowest action
            index because ``np.argmax`` returns the first maximum.
        """
        Pi = np.zeros((self.Num_S, self.Num_A))
        for s in range(self.Num_S):
            a_star = np.argmax([Q[(s, a)] for a in range(self.Num_A)])
            Pi[s, a_star] = 1.0
        return Pi

    def epsilon_greedy_policy(self, Q, s):
        """Sample an action for state ``s``.

        With probability ``self.epsilon`` a uniformly random action is
        returned. Otherwise the action is drawn from a softmax over ``Q[s]``
        (note: this is a softmax draw, not an argmax).

        Args:
            Q: Mapping from state to an array of ``Num_A`` action values.
            s: Current state.

        Returns:
            The sampled action index.
        """
        if random.random() < self.epsilon:
            return np.random.choice(self.Num_A)

        q_values = np.asarray(Q[s], dtype=float)
        # Subtracting the maximum leaves the softmax unchanged but keeps
        # np.exp from overflowing to inf (which would yield NaN probabilities).
        weights = np.exp(q_values - np.max(q_values))
        prob = weights / weights.sum()
        return np.random.choice(self.Num_A, p=prob)

    def on_policy_monte_carlo(self, num_iter=101):
        """First-visit on-policy Monte Carlo control with an epsilon-soft policy.

        ``self.epsilon`` is reset to 1 and multiplied by 0.99 after every
        episode, with a floor of 0.01.

        Args:
            num_iter: Number of episodes to generate.

        Returns:
            The learned epsilon-soft policy as a ``(Num_S, Num_A)`` matrix.
        """
        self.epsilon = 1

        Q = collections.defaultdict(float)
        returns = collections.defaultdict(float)
        S_A_count = collections.defaultdict(int)

        Pi = self.soft_policy()
        print("Starting Policy:, ", Pi)

        for it in range(num_iter):
            episode = self.generate_episode(Pi)
            if it % 50 == 0:
                print(f"Generating Episode Number: {it}")

            self.epsilon = max(self.epsilon * 0.99, 0.01)

            # One backward sweep accumulates the discounted return from every
            # time step. Later writes come from earlier time steps, so each
            # (state, action) key ends up holding its first-visit return.
            G = 0.0
            first_visit_return = {}
            for s, a, r in reversed(episode):
                G = r + self.gamma * G
                first_visit_return[(s, a)] = G

            for key, G_first in first_visit_return.items():
                returns[key] += G_first
                S_A_count[key] += 1
                Q[key] = returns[key] / S_A_count[key]

            # Improve the policy only in states seen during this episode.
            for s in {state for state, _ in first_visit_return}:
                a_star = np.argmax([Q[(s, a)] for a in range(self.Num_A)])
                for a in range(self.Num_A):
                    if a == a_star:
                        Pi[s][a] = 1 - self.epsilon + self.epsilon / self.Num_A
                    else:
                        Pi[s][a] = self.epsilon / self.Num_A

        return Pi

    def off_policy_monte_carlo(self, num_iter=101):
        """Off-policy Monte Carlo control with weighted importance sampling.

        Episodes are generated by the uniform behaviour policy ``Mu`` while
        the target policy ``Pi`` is kept greedy with respect to ``Q``.

        Args:
            num_iter: Number of episodes to generate.

        Returns:
            The greedy target policy as a ``(Num_S, Num_A)`` one-hot matrix.
        """
        self.epsilon = 1

        Q = collections.defaultdict(float)
        C = collections.defaultdict(float)

        Pi = self.greedy_policy(Q)
        Mu = self.soft_policy()

        print("Starting Policy:, ", Pi)

        for it in range(num_iter):
            episode = self.generate_episode(Mu)
            if it % 50 == 0:
                print(f"Generating Episode Number: {it}")

            # Kept for its side effect on self.epsilon; the behaviour policy
            # here is the fixed uniform Mu and does not read it.
            self.epsilon = max(self.epsilon * 0.99, 0.1)

            G = 0
            W = 1

            for (S, A, R) in reversed(episode):
                G = self.gamma * G + R

                C[(S, A)] += W
                Q[(S, A)] += (W / C[(S, A)]) * (G - Q[(S, A)])

                # Only Q[(S, .)] changed, so only row S of Pi can change.
                a_star = int(np.argmax([Q[(S, a)] for a in range(self.Num_A)]))
                Pi[S, :] = 0.0
                Pi[S, a_star] = 1.0

                # The greedy target policy gives probability 0 to any other
                # action, so the importance weight of all earlier steps is 0
                # and the remainder of the episode carries no information.
                if A != a_star:
                    break

                W = W / Mu[S, A]

        return Pi

    def q_learning(self, num_iter=301):
        """Tabular Q-learning.

        ``self.epsilon`` is reset to 1 and multiplied by 0.999 at the start of
        every episode, with a floor of 0.1.

        Args:
            num_iter: Number of training episodes.

        Returns:
            ``defaultdict`` mapping each visited state to an array of
            ``Num_A`` action values.
        """
        self.epsilon = 1.

        Q = collections.defaultdict(lambda: np.zeros(self.Num_A))

        for it in range(num_iter):

            if it % 50 == 0:
                print(f"Generating Episode Number: {it}")

            S = self.env.reset()

            self.epsilon = max(self.epsilon * 0.999, 0.1)

            while True:
                A = self.epsilon_greedy_policy(Q, S)
                S_, R, terminal, _ = self.env.step(A)

                # A terminal transition has no future value. Bootstrapping
                # from Q[S_] would be wrong whenever the terminal observation
                # is also a regular state (e.g. single-state environments).
                if terminal:
                    target = R
                else:
                    target = R + self.gamma * np.max(Q[S_])
                Q[S][A] += self.alpha * (target - Q[S][A])

                S = S_

                if terminal:
                    break

        return Q

    def sarsa(self, num_iter=301):
        """Tabular SARSA (on-policy TD control).

        ``self.epsilon`` is reset to 1 and multiplied by 0.999 once per
        episode (after the first action is chosen), with a floor of 0.1.

        Args:
            num_iter: Number of training episodes.

        Returns:
            ``defaultdict`` mapping each visited state to an array of
            ``Num_A`` action values.
        """
        self.epsilon = 1.

        Q = collections.defaultdict(lambda: np.zeros(self.Num_A))

        for it in range(num_iter):
            if it % 50 == 0:
                print(f"Generating Episode Number: {it}")

            S = self.env.reset()
            A = self.epsilon_greedy_policy(Q, S)

            self.epsilon = max(self.epsilon * 0.999, 0.1)

            while True:
                S_, R, terminal, _ = self.env.step(A)

                if terminal:
                    # No successor action exists; the target is the reward alone.
                    Q[S][A] += self.alpha * (R - Q[S][A])
                    break

                A_ = self.epsilon_greedy_policy(Q, S_)
                Q[S][A] += self.alpha * (R + (self.gamma * Q[S_][A_]) - Q[S][A])

                S = S_
                A = A_

        return Q

    def generate_episode(self, Pi):
        """Roll out one episode by sampling actions from the policy ``Pi``.

        Args:
            Pi: ``(Num_S, Num_A)`` matrix whose row ``s`` is the action
                distribution in state ``s``.

        Returns:
            List of ``(state, action, reward)`` tuples in time order.
        """
        episode = []

        S = self.env.reset()
        while True:
            A = np.random.choice(np.arange(self.Num_A), p=Pi[S])
            S_, R, terminal, _ = self.env.step(A)
            episode.append((S, A, R))
            S = S_
            if terminal:
                break
        return episode

    def show_policy(self, Q, max_steps=500):
        """Render a greedy rollout of ``Q`` in the environment.

        At every step the action ``np.argmax(Q[S])`` is taken, the environment
        is rendered and the transition is printed. The rollout stops when the
        environment reports ``done`` or after ``max_steps`` steps.

        Args:
            Q: Anything indexable by state that yields per-action scores.
            max_steps: Upper bound on the number of steps to roll out.
        """
        S = self.env.reset()
        print(f"Starting state: {S}")
        self.env.render()

        # Pre-initialised so the summary prints below are well defined even
        # when the loop body never executes (max_steps <= 0).
        A = R = None
        done = False

        step = 0
        while step < max_steps:
            A = np.argmax(Q[S])
            S_, R, done, _ = self.env.step(A)
            self.env.render()
            if done:
                break
            print(
                f"Current State: {S}, action: {A}, reward: {R}, done: {done}, step: {step}"
            )
            S = S_
            step += 1
        print(
            f"Current State: {S}, action: {A}, reward: {R}, done: {done}, step: {step}"
        )

        self.env.close()
        print("Finished", done)

In [15]:
# Build the agent on CliffWalking-v0; gamma, theta, epsilon and alpha keep
# the defaults declared in Agent.__init__.
agent = Agent(environment='CliffWalking-v0')

In [16]:
# on_policy = agent.on_policy_monte_carlo()

In [17]:
# off_policy = agent.off_policy_monte_carlo()

In [18]:
# Train with Q-learning for the default number of episodes. The result is the
# learned Q-table (state -> array of per-action values), not a policy matrix.
q_policy = agent.q_learning()

Generating Episode Number: 0
Generating Episode Number: 50
Generating Episode Number: 100
Generating Episode Number: 150
Generating Episode Number: 200
Generating Episode Number: 250
Generating Episode Number: 300


In [19]:
# Roll out the argmax action of the Q-learning table, rendering the grid and
# printing the transition after every step.
agent.show_policy(q_policy)

Starting state: 36
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
x  C  C  C  C  C  C  C  C  C  C  T

o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
x  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 36, action: 0, reward: -1, done: False, step: 0
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  x  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 24, action: 1, reward: -1, done: False, step: 1
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  x  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 25, action: 1, reward: -1, done: False, step: 2
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  x  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 26, action: 1, reward: -1, done: False, step: 3
o  o  o  o  o  o  o  o  

In [20]:
# Train with SARSA for the default number of episodes. As with Q-learning, the
# result is a Q-table (state -> array of per-action values).
sarsa_policy = agent.sarsa()

Generating Episode Number: 0
Generating Episode Number: 50
Generating Episode Number: 100
Generating Episode Number: 150
Generating Episode Number: 200
Generating Episode Number: 250
Generating Episode Number: 300


In [21]:
# Roll out the argmax action of the SARSA table.
# NOTE(review): in the saved output this rollout sits in state 0 repeating
# action 0 for hundreds of steps without finishing -- confirm whether the
# learned values separate the actions enough for an argmax rollout.
agent.show_policy(sarsa_policy)

Starting state: 36
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
x  C  C  C  C  C  C  C  C  C  C  T

o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
x  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 36, action: 0, reward: -1, done: False, step: 0
o  o  o  o  o  o  o  o  o  o  o  o
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 24, action: 0, reward: -1, done: False, step: 1
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 12, action: 0, reward: -1, done: False, step: 2
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 3
x  o  o  o  o  o  o  o  o

x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 53
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 54
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 55
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 56
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, do

o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 111
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 112
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 113
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 114
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 115
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o

o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 170
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 171
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 172
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 173
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 174
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o

o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 229
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 230
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 231
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 232
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 233
x  o  o  o

x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 288
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 289
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 290
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 291
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1

Current State: 0, action: 0, reward: -1, done: False, step: 346
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 347
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 348
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 349
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 350
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C 

o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 405
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 406
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 407
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 408
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 409
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o

o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 464
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 465
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 466
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 467
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o  o  o  o  o  o  o  o  o
o  C  C  C  C  C  C  C  C  C  C  T

Current State: 0, action: 0, reward: -1, done: False, step: 468
x  o  o  o  o  o  o  o  o  o  o  o
o  o  o  o

In [10]:
# dict(sarsa_policy)

{36: array([  -1.19821971, -104.00426301,   -2.42350452,   -4.41292296]),
 24: array([-1.11672507, -3.72861322, -4.90775164, -1.21429508]),
 12: array([-1.1112574 , -1.18439219, -1.25288724, -1.11567311]),
 13: array([-1.11261157, -1.14452355, -4.41630484, -1.11583283]),
 1: array([-1.11240048, -1.11150912, -1.1875972 , -1.11120539]),
 0: array([-1.11128022, -1.11154102, -1.11583771, -1.11132275]),
 2: array([-1.11224148, -1.1123978 , -1.16004533, -1.11219003]),
 3: array([-1.11248408, -1.11149868, -1.13561811, -1.11189805]),
 4: array([-1.1117427 , -1.11159211, -1.11965555, -1.11278631]),
 5: array([-1.11171478, -1.11135097, -1.12563704, -1.11172049]),
 17: array([-1.11177663, -1.1456313 , -2.33980521, -1.14000473]),
 16: array([-1.11156567, -1.12726188, -3.74716837, -1.13777789]),
 28: array([  -1.13653587,   -2.1157496 , -101.82177274,   -3.5293764 ]),
 27: array([  -1.13603221,   -4.09186946, -101.49785294,   -2.48178807]),
 25: array([  -1.20334672,   -2.13770102, -101.46116794,  

In [11]:
# dict(q_policy)

In [12]:
# sarsa_policy[0]

array([-1.11128022, -1.11154102, -1.11583771, -1.11132275])

In [13]:
# prob = np.array([np.exp(x) for x in sarsa_policy[0]])
# prob = prob/(sum(prob))
# print(prob)

[0.25030352 0.25023825 0.24916536 0.25029287]
