# LO - reinforcing functional thinking

- remember the functional approach: look for language features that let a concept be modeled functionally without, in the given problem context, compromising complexity or safety

# LO - iterators

Using iterators in place of a data structure (or to iterate over one) in conjunction with currying and objects is a very powerful paradigm - why ?

- iterators remove the need to store the actual data, which is beneficial because
    - less storage is needed
    - the data stays immutable
- as an object via a closure - benefits
    - can be consumed by client objects multiple times (each call returns a fresh iterator)
- can be curried - benefits
    - the specifics of the data to be iterated over can be parameterised
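A minimal sketch of the closure and currying points above in plain Python. `make_range_iterator` is a hypothetical name; its argument plays the role of the parameter fixed by partial application, much as `piglet_states(N)` does later in the notebook.

```python
def make_range_iterator(n):
    """Hypothetical example: fixing n returns a closure over the data to iterate.

    Each call to the closure builds a fresh generator, so no list of states is
    stored and the same closure can be consumed any number of times.
    """
    def states():
        for i in range(n):
            yield i
    return states

S = make_range_iterator(3)   # parameterise the data once
first_pass = list(S())       # consume it...
second_pass = list(S())      # ...and consume it again

g = S()                      # a single generator, by contrast, is exhausted after one pass
once = list(g)
twice = list(g)
```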

# Assessment notes

- Draw a diagram of the DMC
- define P and R mathematically
- define S mathematically
- define R mathematically
- should undefined states (those outside S) be handled by P and R ? Or should P and R assume every state they receive is in S ?
- can you keep rolling after you have passed the winning post, or is victory assumed ?
- do you need a max iteration variable ?
- how do I know why it stopped ?
- how closely does your code resemble the pseudocode for the "beautiful in its simplicity" algorithm described in (ref goes here) ?
- what features does your implementation have that are not in the pseudocode ? Do they obscure the purpose of your code ?
- if you were an author of (ref goes here), would you modify the pseudocode in any way ? If so, how and why ? If not, why not ?
- When you were testing your code to see if it works (assuming you did), how did you "know" if it worked or not ? What benchmark did you use to test it ?
- Did you test your implementation of value iteration on any simple test cases ? Could these be made into unit tests ?
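The questions about a maximum iteration count and about knowing why the loop stopped can both be answered by returning the stopping reason along with the result. A minimal sketch, assuming a generic `step` function and a `converged(old, new)` predicate; `iterate_until_converged` is a hypothetical name, not part of the notebook.

```python
def iterate_until_converged(step, converged, x0, max_iterations=1000):
    """Apply step repeatedly from x0 and report why the loop stopped.

    Returns (x, reason, iterations) where reason is "converged" if the
    convergence test passed and "max_iterations" if the cap was hit first.
    """
    x = x0
    for iteration in range(1, max_iterations + 1):
        x_next = step(x)
        if converged(x, x_next):
            return x_next, "converged", iteration
        x = x_next
    return x, "max_iterations", max_iterations

# halving converges; the oscillating map x -> -x never does, so the cap ends the loop
close = lambda a, b: abs(a - b) < 1e-6
halving = iterate_until_converged(lambda x: x / 2, close, 1.0)
flipping = iterate_until_converged(lambda x: -x, close, 1.0, max_iterations=50)
print(halving[1:], flipping[1:])
```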

### Implement a mapping $S \rightarrow 2^{S}$ associating a state $s$ with its set of reachable states $S^{'} \subseteq S$ (equivalently $S^{'} \in 2^{S}$).

For piglet, a valid state $(i,j,k)$ maps to $(i,j,k) \rightarrow \left\{(i,j,k+1),(i+k,j,0),(i,j,0) \right\}$
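The mapping above can be sketched in the notebook's closure style: given $s$, return a function that yields $S^{'}(s)$ on demand. `piglet_reachable` is a hypothetical name, and the comments labelling each successor (roll succeeds, stick, roll fails) are an interpretation of the three states, not something the original states.

```python
def piglet_reachable(s):
    """Map s = (i, j, k) to a closure yielding its reachable states S'(s), a subset of S."""
    i, j, k = s
    def impl():
        yield i, j, k + 1   # roll succeeds: the turn total grows by one
        yield i + k, j, 0   # stick: bank the turn total
        yield i, j, 0       # roll fails: the turn total is lost
    return impl

reachable = set(piglet_reachable((1, 2, 3))())
```

Collecting into a `set` matters for states such as $(0,0,0)$, where sticking and failing a roll lead to the same successor.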

In [1]:
from pymonad.tools import curry
from collections import defaultdict
from copy import deepcopy


### Model actions

### Model the probability $P(s^{'}|s,a)$

### Model Reward 

### Implement value iteration

In [3]:
@curry(6)
def value_iteration(S,A,P,R,gamma,V) :
    # one synchronous Bellman backup; each entry of V is a (value, action) pair
    V_dash = deepcopy(V)
    for s,S_dash in S() :
        V_dash[s] = max([(sum([P(s,s_dash,a)*(R(s,s_dash,a) + gamma*V[s_dash][0]) for s_dash in S_dash()]),a) for a in A()])
    return V_dash
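The cell above performs one synchronous Bellman backup. Written out with the document's symbols, reading `P(s,s_dash,a)` as $P(s^{'}|s,a)$ and `S_dash` as the reachable set $S^{'}(s)$, each state's entry becomes

$$V_{k+1}(s) = \max_{a \in A} \sum_{s^{'} \in S^{'}(s)} P(s^{'}|s,a)\left[R(s,s^{'},a) + \gamma V_{k}(s^{'})\right]$$

with the maximising action stored alongside the value. Because `max` compares `(value, action)` tuples lexicographically, ties between actions are broken by comparing the action names as strings.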
    

#### methods to test for convergence 

In [12]:
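@curry(3)
def bounding_box_convergence(epsilon,V,V_dash) :
    # converged when every state's value changed by less than epsilon;
    # iterate over the keys of both tables so an empty V is not vacuously converged
    delta = 0.0
    for s in set(V) | set(V_dash) :
        delta = max(delta,abs(V_dash[s][0] - V[s][0]))
    return delta < epsilon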
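The bounding-box test is a sup-norm criterion. A short derivation, assuming the backup $T$ is a $\gamma$-contraction in the sup-norm (true for $0 \le \gamma < 1$ with bounded rewards), shows what it guarantees about the distance to the optimal values $V^{*} = TV^{*}$:

$$\Delta_k = \max_{s}\left|V_{k+1}(s) - V_{k}(s)\right|$$

$$\|V_{k+1} - V^{*}\|_\infty = \|TV_k - TV^{*}\|_\infty \le \gamma\|V_k - V^{*}\|_\infty \le \gamma\left(\Delta_k + \|V_{k+1} - V^{*}\|_\infty\right)$$

$$\Rightarrow \|V_{k+1} - V^{*}\|_\infty \le \frac{\gamma}{1-\gamma}\,\Delta_k$$

With $\gamma = 0.8$ and $\epsilon = 0.001$ this bounds the error by $0.004$. With $\gamma = 1$, as in the piglet runs, the bound does not apply, so $\Delta_k < \epsilon$ alone says nothing about how far the values are from optimal.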
        

### Check the value iteration implementation

See [this article](https://artint.info/2e/html2e/ArtInt2e.Ch9.S5.SS2.html#Ch9.F16) for details of the test problems

#### actions

In [5]:
def lifestyle_actions() :
    for action in {"party","relax"} :
        yield action

#### state transitions

In [6]:
def lifestyle_transitions() :
    def codomain() :
        for s_dash in {"healthy","sick"} :
            yield s_dash 
    for s,S_dash in {("healthy",codomain),("sick",codomain)} :
        yield s,S_dash
    return 

#### state transition probabilities

In [7]:
def lifestyle_transition_probabilities(s,s_dash,action) :
    match (s,s_dash,action) :
        case ("healthy","healthy","party") : return 0.7
        case ("healthy","sick","party") : return 0.3
        case ("healthy","healthy","relax") : return 0.95
        case ("healthy","sick","relax") : return 0.05
        case ("sick","healthy","party") : return 0.1
        case ("sick","sick","party") : return 0.9
        case ("sick","healthy","relax") : return 0.5
        case ("sick","sick","relax") : return 0.5
        case _ : return 0.0   # any combination outside the model

#### state transition rewards

In [8]:
def lifestyle_transition_rewards(s,s_dash,action) :
    match (s,action) :
        case ("healthy","relax") : return 7.0
        case ("healthy","party") : return 10.0
        case ("sick","relax") : return 0.0
        case ("sick","party") : return 2.0
        

#### solution

In [9]:
lifestyle_value_iteration = value_iteration(lifestyle_transitions,
                                            lifestyle_actions,
                                            lifestyle_transition_probabilities,
                                            lifestyle_transition_rewards,
                                            0.8)
lifestyle_converged = bounding_box_convergence(0.001)

V = defaultdict(lambda : (0.0,None))
V_dash = lifestyle_value_iteration(V)
while not lifestyle_converged(V,V_dash) :
    V = V_dash
    V_dash = lifestyle_value_iteration(V)
V_dash

defaultdict(<function __main__.<lambda>()>,
            {'healthy': (35.711062336395926, 'party'),
             'sick': (23.80630043163403, 'relax')})
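One way to answer "how did you know it worked?" for this problem is to fix the policy value iteration returned (party when healthy, relax when sick) and solve its Bellman equations $V = r_\pi + \gamma P_\pi V$ exactly. The sketch below does this for two states with Cramer's rule; `evaluate_policy_2x2` is a hypothetical helper, and the transition rows and rewards are read off the lifestyle tables above.

```python
def evaluate_policy_2x2(P, r, gamma):
    """Solve (I - gamma P) V = r for a two-state chain by Cramer's rule."""
    a, b = 1 - gamma * P[0][0], -gamma * P[0][1]
    c, d = -gamma * P[1][0], 1 - gamma * P[1][1]
    det = a * d - b * c
    v0 = (r[0] * d - b * r[1]) / det
    v1 = (a * r[1] - c * r[0]) / det
    return v0, v1

# rows: healthy under "party", sick under "relax"; columns: -> healthy, -> sick
P_pi = [[0.7, 0.3],
        [0.5, 0.5]]
r_pi = [10.0, 0.0]   # reward for party when healthy, relax when sick
v_healthy, v_sick = evaluate_policy_2x2(P_pi, r_pi, 0.8)
print(v_healthy, v_sick)
```

This gives $V(\text{healthy}) = 10/0.28 \approx 35.714$ and $V(\text{sick}) = 4/0.168 \approx 23.810$; the value-iteration output above differs from both by roughly 0.003.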

### model ${\rm piglet}_{1}$

#### actions

In [10]:
piglet_actions = lambda : iter({"roll","stick"}) 

#### transitions

In [11]:
piglet_1_transitions = lambda : iter({((0,0),lambda : iter({(0,1),(1,0),(0,0)}))})

#### state transition probabilities

In [12]:
def piglet_1_probabilities(s,s_dash,action) :
    match (s,s_dash,action) :
        case ((0,0),(0,1),"roll") : return 0.5 
        case ((0,0),(1,0),"roll") : return 0.25
        case ((0,0),(0,0),"roll") : return 0.25
        case ((0,0),(0,1),"stick") : return 0.5
        case ((0,0),(0,0),"stick") : return 0.5
    return 0.0
    

#### state transition rewards

In [14]:
piglet_1_rewards = lambda s,s_dash,action : 1.0 if s == (0,0) and s_dash == (0,1) and action == "roll" else 0.0

In [15]:
piglet_1_value_iteration = value_iteration(piglet_1_transitions,
                                           piglet_actions,
                                           piglet_1_probabilities,
                                           piglet_1_rewards,
                                           1.0)
piglet_1_converged = bounding_box_convergence(0.001)
V = defaultdict(lambda : (0.0,None))


V_dash = piglet_1_value_iteration(V)
while not piglet_1_converged(V_dash,V) :
    V = V_dash
    V_dash = piglet_1_value_iteration(V)
    
V_dash

defaultdict(<function __main__.<lambda>()>, {(0, 0): (0.66650390625, 'roll')})
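A hand calculation gives a benchmark for ${\rm piglet}_{1}$. Only $(0,0)$ is enumerated by `piglet_1_transitions`, so $(0,1)$ and $(1,0)$ keep the `defaultdict` value $0$, and with $\gamma = 1$:

$$Q\big((0,0),\text{roll}\big) = 0.5\,(1 + 0) + 0.25\cdot 0 + 0.25\,V(0,0), \qquad Q\big((0,0),\text{stick}\big) = 0.5\cdot 0 + 0.5\,V(0,0)$$

Roll dominates whenever $V(0,0) \le 2$, so the fixed point satisfies $V(0,0) = 0.5 + 0.25\,V(0,0)$, i.e. $V(0,0) = 2/3 \approx 0.6667$. Starting from $0$ the iterates are $V_n = \tfrac{2}{3}\left(1 - 4^{-n}\right)$: $0.5, 0.625, 0.65625, \dots$, and the printed $0.66650390625$ is $V_6$.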

## Remodel the piglet problem - see notebook for details

In [3]:
# piglet_actions() returns the action factory A; each call A() yields a fresh iterator
piglet_actions = lambda : lambda :  iter(("roll","stick"))

In [4]:
@curry(1)
def piglet_states(N) :
    def impl() :
        for i in range(N) :
            for j in range(N) :
                for k in range(N - i + 1) :
                    yield i,j,k
    return impl

In [5]:
@curry(4)
def piglet_reward(N,s,s_dash,a) :
    i,_,k = s_dash
    return 1.0 if i + k >= N else 0.0

In [70]:
for l in set(range(1,1)) :
    print(l)

In [88]:
# maybe able to tidy this up later
@curry(4)
def piglet_transitions(N,d,a,s) :
    def impl() :
        i,j,k = s
        if a == "roll" :
            for l in range(1,d) :
                if i + k + l <= N :
                    yield i,j,k+l
            if i + k < N :
                yield j,i,0
        if a == "stick" and i + k < N :
            yield (j,i+k,0)
    return impl

In [7]:
# assumes a legitimate transition s -> s_dash
piglet_probability = curry(5,lambda N,d,s,s_dash,a : 1.0 if a == "stick" else 1.0/d)
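Following the question of whether simple test cases could become unit tests, one cheap check on a transition model is that, for every state and action with at least one successor, the successor probabilities sum to 1. The sketch below copies the rules of `piglet_states`, `piglet_transitions` and `piglet_probability` into plain uncurried functions (so it runs without pymonad) and adds a hypothetical checker, `probability_mass_defects`.

```python
def states(N):
    # same enumeration as piglet_states: i, j < N and 0 <= k <= N - i
    def impl():
        for i in range(N):
            for j in range(N):
                for k in range(N - i + 1):
                    yield i, j, k
    return impl

def transitions(N, d, a, s):
    # same successor rules as piglet_transitions
    def impl():
        i, j, k = s
        if a == "roll":
            for l in range(1, d):
                if i + k + l <= N:
                    yield i, j, k + l
            if i + k < N:
                yield j, i, 0
        if a == "stick" and i + k < N:
            yield j, i + k, 0
    return impl

def probability(d, s, s_dash, a):
    # same rule as piglet_probability: 1 for stick, 1/d for each roll outcome
    return 1.0 if a == "stick" else 1.0 / d

def probability_mass_defects(S, A, T, P, tol=1e-9):
    """Return (s, a, total) wherever a non-empty successor set does not sum to 1.
    Pairs with no successors (terminal states) are skipped."""
    defects = []
    for s in S():
        for a in A():
            successors = list(T(a, s)())
            if not successors:
                continue
            total = sum(P(s, s_dash, a) for s_dash in successors)
            if abs(total - 1.0) > tol:
                defects.append((s, a, total))
    return defects

actions = lambda: iter(("roll", "stick"))

def check(N, d):
    return probability_mass_defects(
        states(N), actions,
        lambda a, s: transitions(N, d, a, s),
        lambda s, s_dash, a: probability(d, s, s_dash, a))

coin = check(10, 2)    # the piglet coin
three = check(10, 3)   # the d = 3 model used in the N = 10 run
print(len(coin), len(three), three[:2])
```

With d = 2 no defects are reported. With d = 3, every state with $i + k = N - 1$ drops the roll outcome $l = 2$ (it would overshoot N), so its roll probabilities sum to $2/3$; whether that is intended depends on how overshooting the goal is meant to be modelled.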

### new version of value iteration that employs the transition function T

In [52]:
@curry(7)
def value_iteration(S,T,A,P,R,gamma,V) :
    V_dash = deepcopy(V)
    for s in S() :
        V_dash[s] = max([(sum([P(s,s_dash,a)*(R(s,s_dash,a) + gamma*V[s_dash][0]) for s_dash in T(a,s)()]),a) for a in A()])
    return V_dash

### Try solving ${\rm piglet}_{2}^{1}$

In [10]:
@curry(2)
def piglet_adaptation(S,V) :
    for s in S() :
        _,_,k = s
        if k == 0 :
            v,a = V[s]
            V[s] = 1.0-v,a        
    return V
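The adapter flips the stored value of every $k = 0$ state between sweeps, so the table alternates between the two players' perspectives. An alternative, sketched below under the assumption that $V(i,j,k)$ always means the win probability of the player about to act, applies the flip inside the backup instead: a transition that passes the turn to $(j,i,0)$ or $(j,i+k,0)$ contributes $1 - V$ of that state. States with $i + k \ge N$ are treated as certain wins rather than stored. `piglet_win_probabilities` is a hypothetical name; the sketch covers the coin version (d = 2) only and updates values in place rather than synchronously.

```python
def piglet_win_probabilities(N, epsilon=1e-9, max_iterations=10_000):
    """Value iteration for piglet with the perspective swap inside the backup.

    win[(i, j, k)] is the probability that the player about to act wins, where i is
    that player's banked score, j the opponent's and k the turn total. A roll
    succeeds (k + 1) or fails (turn passes) with probability 1/2 each.
    Returns (win, policy, iterations).
    """
    # non-terminal states only: i + k < N
    states = [(i, j, k) for i in range(N) for j in range(N) for k in range(N - i)]
    win = {s: 0.0 for s in states}

    def value(i, j, k):
        # the player to act has already reached the goal: a certain win
        return 1.0 if i + k >= N else win[(i, j, k)]

    policy = {}
    for iteration in range(1, max_iterations + 1):
        delta = 0.0
        for (i, j, k) in states:
            # when the turn passes the opponent acts next, so our value is 1 - theirs
            roll = 0.5 * (1.0 - value(j, i, 0)) + 0.5 * value(i, j, k + 1)
            stick = 1.0 - value(j, i + k, 0)
            best, action = max((roll, "roll"), (stick, "stick"))
            delta = max(delta, abs(best - win[(i, j, k)]))
            win[(i, j, k)] = best   # in-place update
            policy[(i, j, k)] = action
        if delta < epsilon:
            return win, policy, iteration
    return win, policy, max_iterations

win, policy, iterations = piglet_win_probabilities(2)
for s in sorted(win):
    print(s, round(win[s], 4), policy[s])
```

For $N = 2$ the fixed point can be solved by hand: $V(0,0,0)=4/7$, $V(0,0,1)=5/7$, $V(0,1,0)=2/5$, $V(0,1,1)=3/5$, $V(1,0,0)=4/5$, $V(1,1,0)=2/3$, with roll optimal everywhere. These agree with the values printed by `VI(V)`, before the adapter is applied, in the N = 2 cell further down, which makes them a usable unit-test benchmark.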
    

In [38]:
def delta(V,V_dash) :
    delta = 0.0
    for s in V :
        delta = max(delta,abs(V_dash[s][0] - V[s][0]))
    return delta

In [93]:

A = piglet_actions()
P = piglet_probability(10,3)
T = piglet_transitions(10,3)
S = piglet_states(10)
R = piglet_reward(10)
gamma = 1.0



VI  = value_iteration(S,T,A,P,R,gamma)
adapter = piglet_adaptation(S)
inside_bounding_box = bounding_box_convergence(0.001)

# initialise
V = defaultdict(lambda : (0.0,None))
V_dash = VI(V)
V_dash = adapter(V_dash)

its = 0
while not inside_bounding_box(V_dash,V) :
    V = V_dash
    V_dash = VI(V)
    V_dash = adapter(V_dash)
    its += 1
    # print(delta(V,V_dash))
    if(its > 200) : break
    
#V_dash = VI(V)


#while not piglet_1_converged(V_dash,V) :
#    V = V_dash
#    V_dash = piglet_1_value_iteration(V)

V_dash = VI(V_dash)
print(its)
V_dash 

31


defaultdict(<function __main__.<lambda>()>,
            {(0, 0, 0): (0.5165944283620394, 'roll'),
             (0, 0, 1): (0.5258848083018531, 'roll'),
             (0, 0, 2): (0.5409760848926743, 'roll'),
             (0, 0, 3): (0.5539523344848591, 'roll'),
             (0, 0, 4): (0.5866044834949059, 'stick'),
             (0, 0, 5): (0.5930063921789286, 'stick'),
             (0, 0, 6): (0.6563314037277326, 'stick'),
             (0, 0, 7): (0.6057716426611192, 'stick'),
             (0, 0, 8): (0.7442472433832268, 'stick'),
             (0, 0, 9): (0.4945295628857177, 'roll'),
             (0, 0, 10): (0, 'stick'),
             (0, 1, 0): (0.4975046684205596, 'roll'),
             (0, 1, 1): (0.5068984157778048, 'roll'),
             (0, 1, 2): (0.521726798399961, 'roll'),
             (0, 1, 3): (0.5352738972109503, 'roll'),
             (0, 1, 4): (0.5665101641654551, 'stick'),
             (0, 1, 5): (0.5761465915475966, 'stick'),
             (0, 1, 6): (0.6347427980007572, 's

In [151]:
print("************")
print("************")
V = V_dash
print(V)
print("************")
V_dash = VI(V)
print(V_dash)
print("************")
V_dash = piglet_adaptation(S,V_dash)
print(V_dash)
# v,a = V_dash[(0,0,0)]
# V_dash[(0,0,0)] = (1.0-v,a)

#print(V)
#print("************")
#print(V_dash)




************
************
defaultdict(<function <lambda> at 0x7892e4002200>, {(0, 0, 0): (0.3330078125, 'roll'), (0, 0, 1): (0, 'stick')})
************
defaultdict(<function <lambda> at 0x7892e4002200>, {(0, 0, 0): (0.66650390625, 'roll'), (0, 0, 1): (0, 'stick')})
************
(0, 0, 0)
(0, 0, 1)
defaultdict(<function <lambda> at 0x7892e4002200>, {(0, 0, 0): (0.33349609375, 'roll'), (0, 0, 1): (0, 'stick')})


In [58]:

A = piglet_actions()
P = piglet_probability(2,2)
T = piglet_transitions(2,2)
S = piglet_states(2)
R = piglet_reward(2)
gamma = 1.0



VI  = value_iteration(S,T,A,P,R,gamma)


piglet_1_converged = bounding_box_convergence(0.001)

# initialise
V = defaultdict(lambda : (0.0,None))
V_dash = V
#V_dash = VI(V)


#while not piglet_1_converged(V_dash,V) :
#    V = V_dash
#    V_dash = piglet_1_value_iteration(V)
    
#V_dash

In [188]:
print("************")
print("************")
V = V_dash
print(V)
print("************")
V_dash = VI(V)
print(V_dash)
print("************")
V_dash = piglet_adaptation(S,V_dash)
print(V_dash)
# v,a = V_dash[(0,0,0)]
# V_dash[(0,0,0)] = (1.0-v,a)

#print(V)
#print("************")
#print(V_dash)

************
************
defaultdict(<function <lambda> at 0x7892e419a200>, {(0, 0, 0): (0.4285714285215363, 'roll'), (0, 0, 1): (0.7142857143189758, 'roll'), (0, 0, 2): (0, 'stick'), (0, 1, 0): (0.6000000265485141, 'roll'), (0, 1, 1): (0.5999999848718289, 'roll'), (0, 1, 2): (0, 'stick'), (1, 0, 0): (0.19999997995910235, 'roll'), (1, 0, 1): (0, 'stick'), (1, 1, 0): (0.3333333333430346, 'roll'), (1, 1, 1): (0, 'stick')})
************
defaultdict(<function <lambda> at 0x7892e419a200>, {(0, 0, 0): (0.571428571420256, 'roll'), (0, 0, 1): (0.7142857142607681, 'roll'), (0, 0, 2): (0, 'stick'), (0, 1, 0): (0.39999998241546564, 'roll'), (0, 1, 1): (0.5999999899795512, 'roll'), (0, 1, 2): (0, 'stick'), (1, 0, 0): (0.8000000132742571, 'roll'), (1, 0, 1): (0, 'stick'), (1, 1, 0): (0.6666666666715173, 'roll'), (1, 1, 1): (0, 'stick')})
************
(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(0, 1, 0)
(0, 1, 1)
(0, 1, 2)
(1, 0, 0)
(1, 0, 1)
(1, 1, 0)
(1, 1, 1)
defaultdict(<function <lambda> at 0x7892e419a200>

### tests

In [31]:
A = piglet_actions()

for a in A() :
    print(a)

roll
stick


In [90]:
P = piglet_probability(1,2)
T = piglet_transitions(1,2)
S = piglet_states(1)
R = piglet_reward(1)
for s in S() :
    print("-----")
    print(s) 
    print("-----")
    for a in piglet_actions()() :
        print("[" + a + "]")
        for s_dash in T(a,s)() :
            print("    " + str(s_dash) + " -> " + str(P(s,s_dash,a)) + " -> " + str(R(s,s_dash,a)))
    print("-----")

-----
(0, 0, 0)
-----
[roll]
    (0, 0, 1) -> 0.5 -> 1.0
    (0, 0, 0) -> 0.5 -> 0.0
[stick]
    (0, 0, 0) -> 1.0 -> 0.0
-----
-----
(0, 0, 1)
-----
[roll]
[stick]
-----


In [49]:
T = piglet_transitions(1,2)
S = piglet_states(1)
R = piglet_reward(1)
for s in S() :
    print(s) 
    print("-----")
    for a in piglet_actions()() :
        print("[" + a + "]")
        for s_dash in T(a,s)() :
            print("    " + str(s_dash) + " -> " + str(R(s,s_dash,a)))


(0, 0, 0)
-----
[roll]


TypeError: unsupported operand type(s) for +: 'range' and 'int'

In [285]:
T = piglet_transitions(2,2)
S = piglet_states(2)
for a in piglet_actions()() :
    for s in S() :
        for t in T(a,s)() :
            print(t)

(0, 0, 1)
(0, 0, 0)
(0, 0, 2)
(0, 0, 0)
(0, 0, 3)
(0, 0, 0)
(0, 1, 1)
(1, 0, 0)
(0, 1, 2)
(1, 0, 0)
(0, 1, 3)
(1, 0, 0)
(1, 0, 1)
(0, 1, 0)
(1, 0, 2)
(0, 1, 0)
(1, 1, 1)
(1, 1, 0)
(1, 1, 2)
(1, 1, 0)
(0, 0, 0)
(0, 1, 0)
(0, 2, 0)
(1, 0, 0)
(1, 1, 0)
(1, 2, 0)
(0, 1, 0)
(0, 2, 0)
(1, 1, 0)
(1, 2, 0)


In [23]:
S = piglet_states(2)
for s in S() :
    print(s)

(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(0, 1, 0)
(0, 1, 1)
(0, 1, 2)
(1, 0, 0)
(1, 0, 1)
(1, 1, 0)
(1, 1, 1)


In [24]:
S = piglet_states(2)
R = piglet_reward(2)
for s in S() : 
    for s_dash in S() : 
        print(str(s) + " : " + str(s_dash) + " -> " + str(R(s,s_dash,"roll")))

(0, 0, 0) : (0, 0, 0) -> 0.0
(0, 0, 0) : (0, 0, 1) -> 0.0
(0, 0, 0) : (0, 0, 2) -> 1.0
(0, 0, 0) : (0, 1, 0) -> 0.0
(0, 0, 0) : (0, 1, 1) -> 0.0
(0, 0, 0) : (0, 1, 2) -> 1.0
(0, 0, 0) : (1, 0, 0) -> 0.0
(0, 0, 0) : (1, 0, 1) -> 1.0
(0, 0, 0) : (1, 1, 0) -> 0.0
(0, 0, 0) : (1, 1, 1) -> 1.0
(0, 0, 1) : (0, 0, 0) -> 0.0
(0, 0, 1) : (0, 0, 1) -> 0.0
(0, 0, 1) : (0, 0, 2) -> 1.0
(0, 0, 1) : (0, 1, 0) -> 0.0
(0, 0, 1) : (0, 1, 1) -> 0.0
(0, 0, 1) : (0, 1, 2) -> 1.0
(0, 0, 1) : (1, 0, 0) -> 0.0
(0, 0, 1) : (1, 0, 1) -> 1.0
(0, 0, 1) : (1, 1, 0) -> 0.0
(0, 0, 1) : (1, 1, 1) -> 1.0
(0, 0, 2) : (0, 0, 0) -> 0.0
(0, 0, 2) : (0, 0, 1) -> 0.0
(0, 0, 2) : (0, 0, 2) -> 1.0
(0, 0, 2) : (0, 1, 0) -> 0.0
(0, 0, 2) : (0, 1, 1) -> 0.0
(0, 0, 2) : (0, 1, 2) -> 1.0
(0, 0, 2) : (1, 0, 0) -> 0.0
(0, 0, 2) : (1, 0, 1) -> 1.0
(0, 0, 2) : (1, 1, 0) -> 0.0
(0, 0, 2) : (1, 1, 1) -> 1.0
(0, 1, 0) : (0, 0, 0) -> 0.0
(0, 1, 0) : (0, 0, 1) -> 0.0
(0, 1, 0) : (0, 0, 2) -> 1.0
(0, 1, 0) : (0, 1, 0) -> 0.0
(0, 1, 0) : (0

In [25]:
for s in piglet_states(2)() :
    print(s)

(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(0, 1, 0)
(0, 1, 1)
(0, 1, 2)
(1, 0, 0)
(1, 0, 1)
(1, 1, 0)
(1, 1, 1)
