In [48]:
class BaseEnv:
    """Minimal interface for a discrete environment with ``n_states`` states and ``n_actions`` actions."""

    def __init__(self, n_states: int, n_actions: int):
        self.n_states, self.n_actions = n_states, n_actions

    def step(self, action: int) -> tuple[int, float, bool, bool]:
        """Apply ``action``; subclasses return ``(next_state, reward, terminated, truncated)``."""
        raise NotImplementedError

    def reset(self, seed: int = 0) -> int:
        """Reset the step counter.

        NOTE(review): the base implementation returns None despite the ``-> int`` hint;
        subclasses are expected to override it and return the initial state.
        """
        self.steps = 0

class BaseAgent:
    """Abstract agent bound to a single environment; subclasses provide ``act`` and ``train``."""

    def __init__(self, env: BaseEnv):
        # The environment this agent interacts with.
        self.env = env

    def act(self, state: int) -> int:
        """Choose an action for ``state``; must be overridden."""
        raise NotImplementedError

    def train(self, steps: int):
        """Train the agent; must be overridden."""
        raise NotImplementedError

    def reset(self):
        """Hook for clearing per-run state; a no-op by default."""
        return None

from numpy.random import MT19937, Generator

def random_generator(seed: int | None = None):
    bg = MT19937(seed)
    rg = Generator(bg)
    return rg

In [49]:
import numpy as np

class BranchingEnv(BaseEnv):
    """Random "branching" task used for the trajectory-sampling experiment.

    States ``0 .. n_states-2`` are non-terminal; ``n_states-1`` is the terminal state
    and ``0`` is the start state. There are always two actions. Every (state, action)
    pair has ``b`` possible outcomes, each a fixed ``(next_state, reward)`` pair drawn
    once at construction time:

    * the next state is any non-terminal state other than the current one;
    * the reward is drawn from a normal distribution with mean ``mean`` and
      standard deviation ``deviation``.

    On every step one of the ``b`` outcomes is chosen uniformly. Independently, the
    episode terminates with probability ``terminal_prob``, in which case the agent is
    moved to the terminal state.
    """

    def __init__(self, n_states: int, b=1, mean=0, deviation=1, seed: int | None = None, terminal_prob: float = 0.1):
        assert b >= 1, f"The branching ({b}) must be 1 or higher"
        assert n_states >= b + 1, f"The number of different states ({n_states}) must be at least b + 1 ({b + 1})"
        # Each non-terminal state needs at least one *other* non-terminal state to move to.
        assert n_states >= 3, f"The number of different states ({n_states}) must be at least 3"
        assert 0 <= terminal_prob <= 1, f"The termination probability ({terminal_prob}) must be in [0, 1]"

        actions = [0, 1]
        rg = random_generator(seed)

        super().__init__(
            n_states=n_states,
            n_actions=len(actions))

        self.b = b
        self.steps = 0
        self.all_steps = 0
        self.state: int | None = None
        self.rg = rg
        self.mean = mean
        self.deviation = deviation
        self.terminal_prob = terminal_prob
        n_non_terminal = self.n_states - 1
        # transitions[s][a] is the list of b possible (next_state, reward) outcomes of (s, a).
        self.transitions = [
            [
                [
                    (
                        # An offset in [1, n_non_terminal - 1] modulo the number of non-terminal
                        # states lands on any non-terminal state except s itself.
                        # integers() samples that range directly instead of building
                        # an O(n_states) array from range(...) on every draw.
                        (s + rg.integers(1, n_non_terminal)) % n_non_terminal,
                        rg.normal(loc=self.mean, scale=self.deviation),
                    )
                    for _ in range(self.b)
                ]
                for _ in range(self.n_actions)
            ]
            for s in range(self.n_states)
        ]

    def reset(self, seed: int | None = None) -> int:
        """Move back to the start state ``0`` and return it.

        When ``seed`` is given the step RNG is re-seeded. Otherwise the current RNG keeps
        running, so a run seeded once stays reproducible across later episodes instead of
        switching to fresh, unseeded randomness.
        """
        if seed is not None:
            self.rg = random_generator(seed)
        self.steps = 0
        self.state = 0
        return self.state

    def step(self, action: int) -> tuple[int, float, bool, bool]:
        """Take ``action`` and return ``(next_state, reward, terminated, truncated)``.

        ``truncated`` is always False for this environment.
        """
        state = self.state

        assert state is not None, "The environment was not initialized"
        assert state != (self.n_states - 1), "The environment is in a terminal state"

        b_chosen = self.rg.integers(self.b)
        next_state, reward = self.transitions[state][action][b_chosen]
        terminated = self.rg.random() < self.terminal_prob
        if terminated:
            next_state = self.n_states - 1
        truncated = False

        # Counters are only advanced once the step has been validated.
        self.steps += 1
        self.state = next_state
        self.all_steps += 1

        return next_state, reward, terminated, truncated

In [50]:
from collections import defaultdict

class BaseAgentParams:
    def __init__(
        self,
        state: int,
        action: int,
        reward: float,
        next_state: int,
        next_action: int,
        terminated: bool,
        truncated: bool,
    ):
        self.state = state
        self.action = action
        self.next_state = next_state
        self.next_action = next_action
        self.reward = reward
        self.terminated = terminated
        self.truncated = truncated

class BaseDynaAgent(BaseAgent):
    """Tabular Dyna-style agent: learns a count-based model from real steps and plans with it.

    Model tables, keyed by states/actions seen in real experience:
      * ``TA[(S, A)]``     - number of real transitions taken from (S, A)
      * ``T[(S, A, S')]``  - number of those that ended in S'
      * ``SP[(S, A)][S']`` - the same counts grouped by (S, A)
      * ``M[(S, A)][S']``  - empirical probability ``T / TA`` of (S, A) -> S'
      * ``R[(S, A, S')]``  - running mean of the rewards observed for (S, A) -> S'

    Value updates are sample updates when ``alpha`` is set, and expected updates over
    ``M`` when ``alpha is None``. Subclasses must implement ``initial_state`` and
    ``is_terminal``.
    """

    def __init__(
        self,
        env: BaseEnv,
        n_plan: int | None = None, # planning updates per real step; None = stop after one sweep (plan_all) or one simulated trajectory
        plan_all: bool = False, # True: sweep over all known state-action pairs, False: follow simulated on-policy trajectories
        alpha: float | None = None, # None for expected updates
        gamma: float = 1, # 1 for undiscounted task
        epsilon: float = 0.1,
        q_learning: bool = True, # True to update based on the maximum value of Q
        max_updates: int | None = None, # planning stops once this many value updates have been made in total
        seed: int | None = None,
    ):
        super().__init__(env=env)
        self.n_plan = n_plan
        self.plan_all = plan_all
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.q_learning = q_learning
        self.Q = np.zeros((env.n_states, env.n_actions), dtype=float)
        self.TA: dict[tuple[int, int], int] = defaultdict(int) # amount of transitions from (S, A)
        self.SP: dict[tuple[int, int], dict[int, int]] = defaultdict(lambda: defaultdict(int)) # amount of times in which (S, A) -> S', grouped by (S, A)
        self.T: dict[tuple[int, int, int], int] = defaultdict(int) # amount of times in which (S, A) -> S'
        self.M: dict[tuple[int, int], dict[int, float]] = defaultdict(dict) # probability of (S, A) -> S'
        self.R: dict[tuple[int, int, int], float] = defaultdict(float) # mean reward of (S, A, S')
        self.max_updates = max_updates
        self.seed = seed
        self.rg = random_generator(seed)
        self.updates = 0 # value updates made so far (real + planning)

    def reset(self):
        """Reset the update counter (Q and the model are kept)."""
        self.updates = 0

    def initial_state(self) -> int:
        """State from which simulated on-policy trajectories start; must be overridden."""
        raise NotImplementedError

    def is_terminal(self, state: int) -> bool:
        """Whether ``state`` is terminal; must be overridden."""
        raise NotImplementedError

    def act(self, state: int) -> int:
        """Epsilon-greedy action over all env actions (greedy ties go to the lowest index)."""
        n_actions = self.env.n_actions
        epsilon = self.epsilon
        qs = self.Q[state]
        probs = np.ones(n_actions, dtype=float) * epsilon / n_actions
        probs[np.argmax(qs)] += 1 - epsilon
        action = self.rg.choice(len(probs), p=probs)
        return action

    def act_plan(self, state: int) -> int | None:
        """Epsilon-greedy action restricted to actions already tried from ``state``; None if there are none."""
        n_actions = self.env.n_actions
        epsilon = self.epsilon
        allowed_actions = [a for a in range(n_actions) if self.TA[(state, a)]]

        if not allowed_actions:
            return None

        qs = [self.Q[state, a] for a in allowed_actions]
        n_actions = len(qs)
        probs = np.ones(n_actions, dtype=float) * epsilon / n_actions
        probs[np.argmax(qs)] += 1 - epsilon
        action = self.rg.choice(allowed_actions, p=probs)
        return action

    # called after a real action
    def update(self, params: BaseAgentParams) -> None:
        """Learn from one real transition: update the model, the value, then plan."""
        self.update_model(params)
        self.update_value(params)
        self.plan()

    # called after a real action
    def update_model(self, params: BaseAgentParams):
        """Record a real transition in the counts, the running mean reward and the successor probabilities."""
        state = params.state
        action = params.action
        next_state = params.next_state
        reward = params.reward
        state_action = (state, action)
        san = (state, action, next_state)
        self.TA[state_action] += 1
        self.T[san] += 1
        # incremental mean of the rewards seen for (S, A, S')
        self.R[san] = ((self.T[san] - 1) * self.R[san] + reward) / self.T[san]
        self.SP[state_action][next_state] += 1
        # TA changed, so every successor probability of (S, A) must be renormalized
        for sp in self.SP[state_action]:
            self.M[state_action][sp] = self.T[(state, action, sp)] / self.TA[state_action]

    def sample_cause(self) -> tuple[int, int]:
        """Pick a uniformly random (S, A) pair among those taken at least once."""
        keys = [key for key in self.TA if self.TA[key]]
        idx = self.rg.choice(len(keys))
        state, action = keys[idx]
        return state, action

    def sample_effect(self, state_action: tuple[int, int]) -> tuple[int, float] | None:
        """Sample ``(S', mean reward)`` from the model for ``state_action``.

        Returns None when the pair has never been observed.
        """
        # .get avoids inserting an empty entry into the defaultdict
        probs_dict = self.M.get(state_action)
        if not probs_dict:
            return None
        probs_states = list(probs_dict)
        probs = [probs_dict[s] for s in probs_states]
        idx = self.rg.choice(len(probs_states), p=probs)
        next_state = probs_states[idx]
        state, action = state_action
        expected_r = self.R[(state, action, next_state)]
        return next_state, expected_r

    def next_value(self, params: BaseAgentParams, q_learning: bool) -> float:
        """Bootstrap value of the successor: 0 if terminated, else max Q (Q-learning) or Q(S', A') (SARSA)."""
        if params.terminated:
            return 0

        next_q = self.Q[params.next_state]
        # Without a known next action (e.g. a truncated simulated trajectory) fall back to
        # the greedy value; indexing Q with None would return an array, not a scalar.
        if q_learning or params.next_action is None:
            return max(next_q)
        return next_q[params.next_action]

    def single_update_value(self) -> None:
        """Count one value update; subclasses hook in here to record statistics."""
        self.updates += 1

    # called both in real actions and simulated actions
    def update_value(self, params: BaseAgentParams) -> None:
        """Apply one value update to Q(S, A): a sample update if ``alpha`` is set, else an expected update."""
        state = params.state
        action = params.action

        Q = self.Q
        alpha = self.alpha
        gamma = self.gamma

        if alpha is not None:
            next_value = self.next_value(params=params, q_learning=self.q_learning)
            Q[state, action] += alpha * (params.reward + gamma * next_value - Q[state, action])
        else:
            # Expected update: Q(s, a) <- sum_{s'} p(s'|s, a) * (r(s, a, s') + gamma * max_a' Q(s', a')).
            # Accumulate into a local so each successor is weighted only by its own probability,
            # and Q(s, a) is read at its pre-update value if some s' == s.
            expected = 0.0
            for sp, p in self.M[(state, action)].items():
                reward = self.R[(state, action, sp)]
                next_value = self.next_value(
                    params=BaseAgentParams(
                        state=state,
                        action=action,
                        reward=reward,
                        next_state=sp,
                        next_action=None,
                        terminated=self.is_terminal(sp),
                        truncated=False,
                    ),
                    q_learning=True,
                )
                expected += p * (reward + gamma * next_value)
            Q[state, action] = expected
        self.single_update_value()

    def plan(self):
        """Run simulated value updates using the learned model.

        With ``plan_all`` the updates sweep over every state-action pair seen so far
        (uniform distribution). Otherwise they follow epsilon-greedy trajectories
        simulated in the model from ``initial_state()`` (on-policy distribution).
        Planning stops after ``n_plan`` updates when that is set, otherwise after one
        sweep / one trajectory. It never starts an update once ``max_updates`` total
        updates have been reached.
        """
        plan_all = self.plan_all
        n_plan = self.n_plan
        idx_plan_all = 0
        ta_keys = [key for key in self.TA if self.TA[key] > 0]
        count = 0

        if plan_all:
            if not ta_keys:
                return
            s, a = ta_keys[idx_plan_all]
        else:
            s = self.initial_state()
            a = self.act_plan(s)

        while True:
            if self.max_updates is not None and self.updates >= self.max_updates:
                break
            # a is None when the simulated start state has no modelled action: nothing to simulate.
            effect = None if a is None else self.sample_effect((s, a))
            if effect is None:
                break
            sp, r = effect
            ap = self.act_plan(sp)
            terminated = self.is_terminal(sp)
            # the model knows no action from sp, so the simulated trajectory cannot continue
            truncated = (ap is None)
            self.update_value(BaseAgentParams(
                state=s,
                action=a,
                reward=r,
                next_state=sp,
                next_action=ap,
                terminated=terminated,
                truncated=truncated,
            ))

            if n_plan is not None:
                count += 1
                if count >= n_plan:
                    break

            if plan_all:
                idx_plan_all += 1
                if idx_plan_all >= len(ta_keys):
                    if n_plan is None:
                        break
                    # keep cycling over the known pairs until n_plan updates are done
                    idx_plan_all = 0
                s, a = ta_keys[idx_plan_all]
            elif terminated or truncated:
                if n_plan is None:
                    break
                # start a new simulated trajectory from the start state
                s = self.initial_state()
                a = self.act_plan(s)
            else:
                s = sp
                a = ap

In [51]:
class BranchingDynaAgent(BaseDynaAgent):
    """Dyna agent for ``BranchingEnv`` that records the start-state value after every value update.

    The start state is 0 and the terminal state is ``env.n_states - 1``.
    """

    def __init__(
        self,
        env: BranchingEnv,
        plan_all: bool = False, # True: uniform sweeps over known pairs, False: on-policy simulated trajectories
        alpha: float | None = None, # None for expected updates
        gamma: float = 1, # 1 for undiscounted task
        epsilon: float = 0.1,
        q_learning: bool = True, # True to update based on the maximum value of Q
        max_updates: int | None = None, # required by train(): total number of value updates to perform
        seed: int | None = None,
        n_plan: int | None = None, # optional cap on planning updates per real step
    ):
        super().__init__(
            env=env,
            n_plan=n_plan,
            plan_all=plan_all,
            alpha=alpha,
            gamma=gamma,
            epsilon=epsilon,
            q_learning=q_learning,
            max_updates=max_updates,
            seed=seed)
        self.values_history: list[float] = []

    def reset(self) -> None:
        """Reset the update counter and restart the value history at the current start-state value."""
        super().reset()
        self.values_history = [self.initial_value()]

    def initial_value(self) -> float:
        """Value of the start state under the greedy policy w.r.t. Q, i.e. max_a Q(start, a)."""
        return np.max(self.Q[self.initial_state()])

    def initial_state(self) -> int:
        return 0

    def is_terminal(self, state: int) -> bool:
        return state == self.env.n_states - 1

    def single_update_value(self) -> None:
        super().single_update_value()
        self.values_history.append(self.initial_value())

    def train(self) -> list[float]:
        """Act in the environment and learn until ``max_updates`` value updates have been made.

        Each real step updates the model and Q and then plans (see ``BaseDynaAgent.update``).

        Returns:
            The start-state value before training followed by its value after every update.

        Raises:
            ValueError: if ``max_updates`` was not set.
        """
        if self.max_updates is None:
            raise ValueError("BranchingDynaAgent.train requires max_updates to be set")

        self.reset()

        state = self.env.reset(self.seed)
        action = self.act(state)

        while self.updates < self.max_updates:
            next_state, r, terminated, truncated = self.env.step(action)
            next_action = self.act(next_state)

            self.update(BaseAgentParams(
                state=state,
                action=action,
                reward=r,
                next_state=next_state,
                next_action=next_action,
                terminated=terminated,
                truncated=truncated,
            ))

            if terminated or truncated:
                state = self.env.reset()
                action = self.act(state)
            else:
                state = next_state
                action = next_action

        return self.values_history

In [52]:
def test_case(n_states: int, plan_all: bool, b: int, max_updates: int, seed: int | None) -> list[float]:
    """Build a fresh branching environment and Dyna agent, train, and return the start-state value history."""
    agent = BranchingDynaAgent(
        env=BranchingEnv(n_states=n_states, b=b, seed=seed),
        plan_all=plan_all,
        max_updates=max_updates,
        seed=seed,
    )
    return agent.train()

In [53]:
import typing
import matplotlib.pyplot as plt

def show_branching(title: str, cases: list[tuple[str, typing.Callable[[], list[float]]]]):
    """Run every ``(name, fn)`` case, print each final start-state value, and plot all value curves.

    Each ``fn`` is called with no arguments and must return the value history to plot.
    """
    separator = '-' * 80
    print(separator)
    print('Last value of start state')
    results: list[tuple[str, list[float]]] = []
    for name, fn in cases:
        history = fn()
        print(f'{name}: {history[-1]}')
        results.append((name, history))
    print(separator)

    _, ax = plt.subplots(figsize=(8, 8))
    for name, history in results:
        ax.plot(history, label=name)
    ax.set_xlabel('Computation time, in expected updates')
    ax.set_ylabel('Value of start state under greedy policy')
    ax.set_title(title)
    # legend outside the axes so it does not hide the curves
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

In [None]:
seed = 1
max_updates = 20000
n_states = 1000
# Every (update distribution, branching factor) combination; the default arguments
# freeze the loop variables so each lambda runs its own configuration.
show_branching(title=f'{n_states} states', cases=[
    (
        f'{distribution}, b={branching}',
        lambda plan_all=plan_all, b=branching: test_case(
            n_states=n_states, plan_all=plan_all, b=b, max_updates=max_updates, seed=seed),
    )
    for distribution, plan_all in (('on-policy', False), ('uniform', True))
    for branching in (1, 3, 10)
])

In [None]:
seed = 1
max_updates = 200000
n_states = 10000
# Every (update distribution, branching factor) combination; the default arguments
# freeze the loop variables so each lambda runs its own configuration.
show_branching(title=f'{n_states} states', cases=[
    (
        f'{distribution}, b={branching}',
        lambda plan_all=plan_all, b=branching: test_case(
            n_states=n_states, plan_all=plan_all, b=b, max_updates=max_updates, seed=seed),
    )
    for distribution, plan_all in (('on-policy', False), ('uniform', True))
    for branching in (1, 3)
])