In [13]:
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
import graphviz
import numpy as np
import random
from pprint import pprint
from typing import (Callable, Dict, Iterable, Generic, Sequence, Tuple,
                    Mapping, TypeVar, Set)

from distribution import (Categorical, Distribution, FiniteDistribution,
                             SampledDistribution, Gaussian, Choose)

import itertools

## State

State Space: 100 locations on a 10x10 game board with these exceptions:

Transition Probabilities: Equal probability of moving {1,2,3,4,5,6} steps ahead, with exceptions for certain states with a ladder or snake. 

In [14]:
S = TypeVar('S')
X = TypeVar('X')

class State(ABC, Generic[S]):
    state: S

    def on_non_terminal(
        self,
        f: Callable[[NonTerminal[S]], X],
        default: X
    ) -> X:
        if isinstance(self, NonTerminal):
            return f(self)
        else:
            return default


@dataclass(frozen=True)
class Terminal(State[S]):
    state: S


@dataclass(frozen=True)
class NonTerminal(State[S]):
    state: S

In [15]:
class MarkovProcess(ABC, Generic[S]):
    '''A Markov process with states of type S.
    '''
    @abstractmethod
    def transition(self, state: NonTerminal[S]) -> Distribution[State[S]]:
        '''Given a state of the process, returns a distribution of
        the next states.  Returning None means we are in a terminal state.
        '''

    def simulate(self, start_state_distribution: Distribution[NonTerminal[S]]) -> Iterable[State[S]]:
        '''Run a simulation trace of this Markov process, generating the
        states visited during the trace.

        This yields the start state first, then continues yielding
        subsequent states forever or until we hit a terminal state.
        '''

        state: State[S] = start_state_distribution.sample()
        yield state

        while isinstance(state, NonTerminal):
            state = self.transition(state).sample()
            yield state

    def traces(self,start_state_distribution: Distribution[NonTerminal[S]]) -> Iterable[Iterable[State[S]]]:
        '''Yield simulation traces (the output of `simulate'), sampling a
        start state from the given distribution each time.

        '''
        while True:
            yield self.simulate(start_state_distribution)


Transition = Mapping[NonTerminal[S], FiniteDistribution[State[S]]]

In [16]:
@dataclass(frozen=True)
class Die(Distribution[int]):
    sides: int
    def sample(self) -> int:
        return random.randint(1, self.sides)
    
    def expectation(self) -> int:
        return (int + 1) / 2

@dataclass
class StateSaL:
    position: int
    
    def current_state(self) -> int:
        return position
    
@dataclass
class SnakesandLadders(MarkovProcess[StateSaL]):
    # Create a distribution (a 6-sided die)
    six_sided = Die(6)
    gauss = Gaussian(0,1)
    snakeorladder = Choose([-10,-5,-3,2,5,7])
    
    # Method to get die roll
    def roll_dice():
        return six_sided.sample()
    
    def transition(Self, state: NonTerminal[StateSaL]) -> State[StateSaL]:
        die_roll = self.roll_dice()
        SorL = gauss.sample() > 1
        SorLdistance = snakeorladder.sample()
        
        return State(StateSaL(state.state.position + die_roll + (SorL * SorLdistance)))
    
    
def simulation(process, start_state):
    state = start_state
    while True:
        yield state
        state = process.next_state(state)

## Test

In [17]:
s = SnakesandLadders()

In [22]:
from distribution import Constant
import numpy as np
def price_traces(time_steps: int, num_traces: int) -> np.ndarray:
    mp = SnakesandLadders()
    start_state = StateSaL(0)
    return np.vstack([np.fromiter(
        (s.state.position for s in itertools.islice(mp.simulate(start_state), time_steps + 1)), 
        float) for _ in range(num_traces)])

In [23]:
price_traces(10,10)

AttributeError: 'StateSaL' object has no attribute 'sample'

## Finite Markov Process

In [None]:
class FiniteMarkovProcess(MarkovProcess[S]):
    '''A Markov Process with a finite state space.

    Having a finite state space lets us use tabular methods to work
    with the process (ie dynamic programming).

    '''

    non_terminal_states: Sequence[NonTerminal[S]]
    transition_map: Transition[S]

    def __init__(self, transition_map: Mapping[S, FiniteDistribution[S]]):
        non_terminals: Set[S] = set(transition_map.keys())
        self.transition_map = {
            NonTerminal(s): Categorical(
                {(NonTerminal(s1) if s1 in non_terminals else Terminal(s1)): p
                 for s1, p in v.table().items()}
            ) for s, v in transition_map.items()
        }
        self.non_terminal_states = list(self.transition_map.keys())

    def __repr__(self) -> str:
        display = ""

        for s, d in self.transition_map.items():
            display += f"From State {s.state}:\n"
            for s1, p in d:
                opt = "Terminal " if isinstance(s1, Terminal) else ""
                display += f"  To {opt}State {s1.state} with Probability {p:.3f}\n"

        return display

    def get_transition_matrix(self) -> np.ndarray:
        sz = len(self.non_terminal_states)
        mat = np.zeros((sz, sz))

        for i, s1 in enumerate(self.non_terminal_states):
            for j, s2 in enumerate(self.non_terminal_states):
                mat[i, j] = self.transition(s1).probability(s2)

        return mat

    def transition(self, state: NonTerminal[S])\
            -> FiniteDistribution[State[S]]:
        return self.transition_map[state]

    def get_stationary_distribution(self) -> FiniteDistribution[S]:
        eig_vals, eig_vecs = np.linalg.eig(self.get_transition_matrix().T)
        index_of_first_unit_eig_val = np.where(
            np.abs(eig_vals - 1) < 1e-8)[0][0]
        eig_vec_of_unit_eig_val = np.real(
            eig_vecs[:, index_of_first_unit_eig_val])
        return Categorical({
            self.non_terminal_states[i].state: ev
            for i, ev in enumerate(eig_vec_of_unit_eig_val /
                                   sum(eig_vec_of_unit_eig_val))
        })

    def display_stationary_distribution(self):
        pprint({
            s: round(p, 3)
            for s, p in self.get_stationary_distribution()
        })

    def generate_image(self) -> graphviz.Digraph:
        d = graphviz.Digraph()

        for s in self.transition_map.keys():
            d.node(str(s))

        for s, v in self.transition_map.items():
            for s1, p in v:
                d.edge(str(s), str(s1), label=str(p))

        return d


# Reward processes
@dataclass(frozen=True)
class TransitionStep(Generic[S]):
    state: NonTerminal[S]
    next_state: State[S]
    reward: float

    def add_return(self, γ: float, return_: float) -> ReturnStep[S]:
        '''Given a γ and the return from 'next_state', this annotates the
        transition with a return for 'state'.

        '''
        return ReturnStep(
            self.state,
            self.next_state,
            self.reward,
            return_=self.reward + γ * return_
        )


@dataclass(frozen=True)
class ReturnStep(TransitionStep[S]):
    return_: float