# Solving the HJB
The HJB equation, is used in dynamic programming to solve optimisation problem. Optimisation problems occur in all walks of life and some can even be solved. And some of those that can be solved are best solved with dynamic programming, a recursive evaluation of the best descission.

This notebook aims to explain by way of demonstration how to solve the HJB and fidn the optimal set of descissions, but also provide an easy to use base to formuate and solve your optimisation problems.

In [1]:
import numpy as np
import time
import matplotlib.pyplot as plt

$$
v(s) = \min_x(cost(x,s) + v(new state(x,s)))
$$

In [9]:
class DynamicProgram(object):
    """
    Generate a dynamic program to find a set of optimal descissions using the HJB.
    
    define the program by:
    
    Setting intial states via:            set_inital_state(list or int)
    Setting the number of steps via:      set_step_number(int) 
    
    Add a set of descissions:             add_decisions_set(set)
    Add a cost function:                  add_cost_function(function in terms of state )
    
    Add a state change equation:          add_state_eq(function(state))
    Add an expression for the last value: add_final_value_expression(function(state,settings))
    Add limits on the states:              add_state_limits(lower=list or int,upper = list or int)

    
    
    """
    def __init__(self):
        self.settings = {
                        'Lower state limits' : [],
                        'Upper state limits' : [],    
                        'x_set' : set(),
                        'cache' : {},}
        
    def add_state_eq(self,function):
        self.settings['State eq.'] = function
        
    def add_cost_function(self,function):
        self.settings['Cost eq.'] = function
    
    def add_final_value_expression(self, function,):
        self.settings['Final value'] = function
        
    def set_step_number(self,step_number):
        self.settings['T'] = step_number
        
    def set_inital_state(self,intial_values):
        if type(intial_values) is list:
            self.settings['Initial state'] = intial_values
            self.settings['Initial state'].insert(0,0)
        elif type(intial_values) is int:
            self.settings['Initial state'] = [intial_values]
            self.settings['Initial state'].insert(0,0)
           
        print self.settings['Initial state']
        self.settings['Initial state'] = tuple(self.settings['Initial state'])

        
    def add_state_limits(self,lower=[],upper=[]):
        """Add the limits on the state, leave empty if none"""
        if type(lower) is list:
            self.settings['Lower state limits'].extend(lower)
            self.settings['Upper state limits'].extend(upper)
        elif type(lower) is int:
            self.settings['Lower state limits'] = [lower]
            self.settings['Upper state limits'] = [upper]
        

        
    def solve(self):
        return self.hjb(self.settings['Initial state'])
    
    def retrieve_decisions(self):
        sched = np.ones(self.settings['T'])*np.nan
        cost_calc= 0
        states = []

        s = self.settings['Initial state']
        t = 0

        while t < self.settings['T']:
            print s
            sched[t] = self.settings['cache'][s][1]

            cost_calc += self.settings['Cost eq.'](sched[t],s)
            states.append(s[1:])

            s = self.settings['State eq.'](sched[t],s)
            t += 1

        states.append(s[1:])

        return cost_calc, sched, states
    
    def return_settings(self):
        return self.settings
    
    def return_cache(self):
        return self.settings['cache']
        
    def add_decisions_set(self,set_of_decisions):
        if set(set_of_decisions) != set_of_decisions:
            raise TypeError('Expected a set unique values, use set() to declare a set')
        self.settings['x_set'] = set(set_of_decisions)
       
    def hjb(self,s):
        if self.settings['cache'].has_key(s):
            return self.settings['cache'][s][0]

        # check state bounds
        for c,i in enumerate(s[1:]):
            if i < self.settings['Lower state limits'][c] or i > self.settings['Upper state limits'][c]:
                return float('inf')
            
        #Check if reached time step limit:
        if s[0] == self.settings['T']:            
            m = self.settings['Final value'](s,self.settings)
            self.settings['cache'][s] = [m, np.nan]
            return m

        # Else enter recursion
        else:
            p=[]
            for x in self.settings['x_set']:
                p.append(self.settings['Cost eq.'](x,s)+self.hjb(self.settings['State eq.'](x,s)))

            m = min(p)

            for x in self.settings['x_set']:
                if m == p[x]:
                    pp = x

            self.settings['cache'][s] = [m, pp]

            return m


We can solve a very simple pump optimsiation where the state of water in a tank is given by h and described by:
$$
s_{new} = \begin{cases} (t+1,h-1) & \text{if } x = 0  \\ (t+1,h+1) & \text{if } x = 1 \\ (t+1,h+1.5) & \text{if } x = 2\end{cases}
$$
The operating cost are described by:
$$
cost = tarrif(t)\times x
$$
where $x$ is the descission variable.
The final value is given by:
$$
V_T = \begin{cases} 0 & \text{if: } h_T \geq h_0 \\ Inf &\text{otherwise} \end{cases}
$$

    
    

In [10]:
def simple_cost(x,s):
    tariff = [19,  8,  20,  3, 12, 14,  0,  4,  3, 13, 11, 13, 13, 11, 16, 14, 16,
       19,  1,  8,  0,  4,   0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
       12,  3, 18, 15,  3, 10, 12,  6,  3,  5, 11,  0, 11,  8, 10, 11,  5,
       15,  8,  2,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        9, 10, 13,  7,  7,  1, 12,  2,  2,  1,  5,  8,  4,  0, 11,  2,  5,
       16,  8,  1, 17, 16,  3,  0,  4, 16,  0,  7]
    return tariff[s[0]]*x

def val_T(s,settings):
    if s[1] < settings['Initial state'][1]:
        return float('inf')
    else:
        return 0
    
def simple_state(x,s):
    #print s
    if x == 0:
        return (s[0]+1,s[1]-1)
    elif x == 1:
        return (s[0]+1,s[1]+1)
    elif x == 2:
        return (s[0]+1,s[1]+1.5)

In [19]:
pumping = DynamicProgram()
pumping.set_step_number(3)
pumping.add_decisions_set({0,1,2})
pumping.add_cost_function(simple_cost)
pumping.add_state_eq(simple_state)
pumping.add_final_value_expression(val_T)
pumping.add_state_limits(lower=0,upper = 200)
pumping.set_inital_state(100)
print pumping.return_settings()
print pumping.solve()
print pumping.retrieve_decisions()

[0, 100]
{'Lower state limits': [0], 'Cost eq.': <function simple_cost at 0x0000000006C1D898>, 'x_set': set([0, 1, 2]), 'Initial state': (0, 100), 'Upper state limits': [200], 'State eq.': <function simple_state at 0x0000000006C1D828>, 'cache': {}, 'T': 3, 'Final value': <function val_T at 0x0000000006A88A58>}
27
(0, 100)
(1, 101)
(2, 102)
(27.0, array([ 1.,  1.,  0.]), [(100,), (101,), (102,), (101,)])


In [12]:
pumping.set_inital_state(199)
print pumping.return_settings()
print pumping.solve()
print pumping.retrieve_decisions()

[0, 199]
{'Lower state limits': [0], 'Cost eq.': <function simple_cost at 0x0000000006C1D898>, 'x_set': set([0, 1, 2]), 'Initial state': (0, 199), 'Upper state limits': [200], 'State eq.': <function simple_state at 0x0000000006C1D828>, 'cache': {(2, 103.0): [0, 0], (0, 100): [27, 1], (1, 99): [28, 1], (3, 99): [inf, nan], (2, 100.5): [20, 1], (3, 99.5): [inf, nan], (3, 97): [inf, nan], (2, 102.5): [0, 0], (3, 103.5): [0, nan], (3, 104.0): [0, nan], (3, 101.5): [0, nan], (2, 102): [0, 0], (1, 101): [8, 1], (3, 102.0): [0, nan], (2, 98): [inf, 2], (3, 103): [0, nan], (1, 101.5): [8, 1], (3, 104.5): [0, nan], (2, 100): [20, 1], (3, 101): [0, nan]}, 'T': 3, 'Final value': <function val_T at 0x0000000006A88A58>}
28
(0, 199)
(1, 198)
(2, 199)
(28.0, array([ 0.,  1.,  1.]), [(199,), (198,), (199,), (200,)])


We can add a second tank and now pump to either of them:

$$
s_{new} = \begin{cases} (t+1,h_1-1,h_2-1) & \text{if } x = 0  \\ (t+1,h_1+1,h_2) & \text{if } x = 1 \\ (t+1,h_1,h_2+1.5) & \text{if } x = 2\end{cases}
$$

In [16]:
def simple_state2(x,s):
    if x == 0:
        return (s[0]+1,s[1]-1,s[2]-1)
    elif x == 1:
        return (s[0]+1,s[1]+1,s[2])
    elif x == 2:
        return (s[0]+1,s[1]  ,s[2]+2)
    
def val_T2(s,settings):
    if s[1] < settings['Initial state'][1] or s[2] < settings['Initial state'][2]:
        return float('inf') 
    else:
        return 0
    

print simple_state2(0,(0,2,2))
print simple_state2(1,(0,2,2))
print simple_state2(2,(0,2,2))

(1, 1, 1)
(1, 3, 2)
(1, 2, 4)


In [17]:
pumping2 = DynamicProgram()
pumping2.add_decisions_set({0,1,2})
pumping2.add_cost_function(simple_cost)
pumping2.add_state_eq(simple_state2)
pumping2.add_final_value_expression(val_T2)
pumping2.add_state_limits(lower=[0,0],upper = [200,200])
pumping2.set_inital_state([100,100])
pumping2.set_step_number(3)
print pumping2.settings
print pumping2.solve()  
print pumping2.return_cache()
print pumping2.retrieve_decisions()


[0, 100, 100]
{'Lower state limits': [0, 0], 'Cost eq.': <function simple_cost at 0x0000000006C1D898>, 'x_set': set([0, 1, 2]), 'Initial state': (0, 100, 100), 'Upper state limits': [200, 200], 'State eq.': <function simple_state2 at 0x0000000006C1DB38>, 'cache': {}, 'T': 3, 'Final value': <function val_T2 at 0x0000000006C1D978>}
35
{(2, 98, 98): [inf, 2], (0, 100, 100): [35, 1], (3, 97, 97): [inf, nan], (2, 101, 102): [0, 0], (3, 99, 98): [inf, nan], (2, 100, 99): [40, 2], (3, 99, 103): [inf, nan], (2, 102, 100): [20, 1], (3, 102, 102): [0, nan], (3, 101, 99): [inf, nan], (3, 100, 106): [0, nan], (2, 100, 104): [20, 1], (2, 99, 101): [20, 1], (3, 103, 100): [0, nan], (3, 100, 101): [0, nan], (1, 100, 102): [8, 1], (1, 101, 100): [16, 2], (3, 98, 100): [inf, nan], (1, 99, 99): [36, 2], (3, 101, 104): [0, nan]}
(0, 100, 100)
(1, 101, 100)
(2, 101, 102)
(35.0, array([ 1.,  2.,  0.]), [(100, 100), (101, 100), (101, 102), (100, 101)])


In [47]:
i1 = [99,]
i2 = [22, 33]
i3 = 99
t = 0
s = 
print s

TypeError: can only concatenate list (not "int") to list

In [24]:
print pumping2.return_settings()

{'Lower state limits': [[0, 0]], 'Cost eq.': <function simple_cost at 0x0000000006BE5CF8>, 'x_set': set([0, 1, 2]), 'Initial state': (0, [600, 190]), 'Upper state limits': [[200, 150]], 'State eq.': <function simple_state2 at 0x0000000006BE5E48>, 'cache': {}, 'T': 3, 'Final value': <function val_T2 at 0x0000000006BE5EB8>}


(0, 100, 80)
Time:  1299.57239572
35


In [56]:
print settings['Initial state']

(0, 100)


(7, 1, 2, 3)


In [54]:
time.clock()
settings = {
    'T' : 3,
    'H_init' : 100,    
    'H_max' : 200,
    'H_min' : 0,
    'x_set' : [0,1,2],
    'cache' : {},
    'State eq.' : stoch_simple_state,
    'Cost eq.' :err_corr_wind_power_cost, 
}
#settings['Initial state'] = (0,settings['H_init'])
# 'Cost eq.' : stoch_wind_power_cost,
settings['Initial state'] = (0,settings['H_init'],0)
settings['P'] = [[0.4,0.5,0.1],[0.2,0.6,0.2],[0.1,0.5,0.4]]
V = stoch_hjb(settings['Initial state'])
print "Time: ", time.clock() 
print V

Time:  7357.13816587
345.35


The cost of operating a pump with a given wind turbine power input is given by:
$$
cost(x,t,h) := \begin{cases} T(t) \times (x \times P_p - W(t)) & \text{if} +ve \\ 
E_{xp} \times (x \times P_p - W(t)) & \text{if} -ve\end{cases}
$$
where $x$ is the descision variable, $W(t)$ is the wind turbine output in time step $t$. $P_p$ is the pump power, $E_{xp}$ is the export price.

In [4]:
def wind_power_cost(x,s):  
    """Very simple cost function for a pump with wind turbine power"""
    Tariff = [5,5,5,5,5,8,8,8,8,8,12,12,12,12,12,50,50,50,50,20,20,6,5,5]
    Wind = [46,  1,  3, 36, 30, 19,  9, 26, 35,  5, 49,  3,  6, 36, 43, 36, 14,
       34,  2,  0,  0, 30, 13, 36]
    
    Export_price = 5.5
        
    power_con = x*60-Wind[s[0]]
    if power_con >= 0:
        return power_con*Tariff[s[0]]
    else:
        return power_con*Export_price

    

The cost of operating a pump with a given wind turbine power input given by a certain state is given by:
$$
cost(x,t,h,j) := \begin{cases} T(t) \times (x \times P_p - W(t,j)) & \text{if} +ve \\ 
E_{xp} \times (x \times P_p - W(t,j)) & \text{if} -ve\end{cases}
$$
where $W(t,j)$ is the wind power output at time $t$ with an error state $j$.

In [26]:
# Convention s = t,h,j
def err_corr_wind_power_cost(x,s):
    Tariff = [5,5,5,5,5,8,8,8,8,8,12,12,12,12,12,50,50,50,50,20,20,6,5,5]
    Wind = [46,  1,  3, 36, 30, 19,  9, 26, 35,  5, 49,  3,  6, 36, 43, 36, 14,
       34,  2,  0,  0, 30, 13, 36]
    
    diff = np.array([-1,0,1])*3
    
    Export_price = 5.5
    
    wind_out = Wind[s[0]]+diff[s[2]]
    if wind_out <= 0:
        wind_out = 0
    
    power_con = x*60-wind_out
    if power_con >= 0:
        return power_con*Tariff[s[0]]
    else:
        return power_con*Export_price

In [27]:
def stoch_wind_power_cost(x,s):
    P = [[0.4,0.5,0.1],[0.2,0.6,0.2],[0.1,0.5,0.4]]
    
    Tariff = [5,5,5,5,5,8,8,8,8,8,12,12,12,12,12,50,50,50,50,20,20,6,5,5]
    Wind = [46,  1,  3, 36, 30, 19,  9, 26, 35,  5, 49,  3,  6, 36, 43, 36, 14,
       34,  2,  0,  0, 30, 13, 36]
    
    Export_price = 5.5
    
    wind_out = sum(Wind[s[0]]*i for i in settings['P'][s[2]])
    
    power_con = x*60-wind_out
    if power_con >= 0:
        return power_con*Tariff[s[0]]
    else:
        return power_con*Export_price

$$
s_{new} = \begin{cases} (t+1,h-1,i) & \text{if } x = 0  \\ (t+1,h+1,i) & \text{if } x = 1 \\ (t+1,h+1.5,i) & \text{if } x = 2\end{cases}
$$

In [30]:
def stoch_simple_state(x,s):
    if x == 0:
        return (s[0]+1,s[1]-1,s[2])
    elif x == 1:
        return (s[0]+1,s[1]+1,s[2])
    elif x == 2:
        return (s[0]+1,s[1]+1.5,s[2])
    
assert(stoch_simple_state(0,(5,100,2)) == (6,99,2))
assert(stoch_simple_state(1,(6,100,2)) == (7,101,2))
assert(stoch_simple_state(2,(9,100,6)) == (10,101.5,6))

In [32]:
assert(hjb((settings['T'],settings['H_init']-1)) == 10000)
assert(hjb((settings['T'],settings['H_init'])) == 0)
assert(hjb((settings['T']-1,settings['H_min']-1)) == 10000)
assert(hjb((settings['T']-1,settings['H_max']+1)) == 10000)

state is given by $s = (t,h,i)$

$$
v(s) = \min_x(cost(x,s) + \sum_j p_{ij} v(new state(x,s)))
$$

In [49]:
def stoch_hjb(s):
    global settings
    if settings['cache'].has_key(s):
        return settings['cache'][s][0]
    

  
    
    if s[0] == settings['T'] and s[1] < settings['H_init']:
        return 10000
    
    elif s[0] == settings['T'] and s[1] >= settings['H_init']:
        return 0
    
    elif s[1] < settings['H_min'] or s[1] > settings['H_max']:
        return 10000
    
    else:
        p=[]
        for x in settings['x_set']:
            future = sum(stoch_hjb(settings['State eq.'](x,(s[0],s[1],i)))
                          *settings['P'][s[2]][i] for i in [0,1,2])
            
            p.append(settings['Cost eq.'](x,s) + future)
                    
        
        m = min(p)
        
        for x in settings['x_set']:
            if m == p[x]:
                pp = x
                
        settings['cache'][s] = [m, pp]
        
        return m

In [34]:
def make_schedule(settings):
    sched = np.ones(settings['T'])*np.nan
    cost_calc= 0
    elev = np.ones(settings['T']+1)*np.nan

    s = settings['Initial state']
    t = 0
    
    while t < settings['T']:
        sched[t] = settings['cache'][s][1]
        print sched[t]

        cost_calc += settings['Cost eq.'](sched[t],s)
        elev[t] = s[1]

        s = settings['State eq.'](sched[t],s)
        t += 1

    elev[settings['T']] = s[1]
    
    return cost_calc, sched, elev

In [112]:
def make_schedule2(settings):
    
    sched_stack =  []
    cost_summary = []
    string_stack = []
    elev = np.ones(settings['T']+1)*np.nan
    for ij in [0,1,2]:
        cost_calc = 0
        string_stack.insert(i,[])
        s = settings['Initial state']
        t = 0
        #string_stack[ij].insert(0,[])
        #string_stack[ij].insert(0,'{0:2} {1} {2}'.format(t,settings['cache'][s][1], s[1:]))
        string_stack[ij].insert(0,'{0}'.format(s))
        while t < settings['T']:
            
            state = tuple(sk if sk in s[:-1] else ij for sk in s  )
            print s, s[:-1], state
            x = settings['cache'][state][1]
            print x
            cost_calc += settings['Cost eq.'](x,s)
            elev[t] = s[1]

            s = settings['State eq.'](x,s)
            t += 1
            string_stack[ij].insert(t,'{0} {1} {2}'.format(x,s[1:],ij) )


    elev[settings['T']] = s[1]
    
    return string_stack, cost_calc

In [113]:
make_schedule2(settings)

(0, 100, 0) (0, 100) (0, 100, 0)
0
(1, 99, 0) (1, 99) (1, 99, 0)
1
(2, 100, 0) (2, 100) (2, 100, 0)
1
(0, 100, 0) (0, 100) (0, 100, 0)
0
(1, 99, 0) (1, 99) (1, 99, 1)
1
(2, 100, 0) (2, 100) (2, 100, 1)
1
(0, 100, 0) (0, 100) (0, 100, 0)
0
(1, 99, 0) (1, 99) (1, 99, 2)
1
(2, 100, 0) (2, 100) (2, 100, 2)
1


([['(0, 100, 0)', '0 (99, 0) 0', '1 (100, 0) 0', '1 (101, 0) 0'],
  ['(0, 100, 0)', '0 (99, 0) 1', '1 (100, 0) 1', '1 (101, 0) 1'],
  ['(0, 100, 0)', '0 (99, 0) 2', '1 (100, 0) 2', '1 (101, 0) 2']],
 363.5)

In [74]:
settings['cache']

{(0, 100, 80): [35, 1],
 (1, 99, 79): [36, 2],
 (1, 100, 81.5): [8, 1],
 (1, 101, 80): [16, 2],
 (2, 98, 78): [10000, 0],
 (2, 99, 80.5): [20, 1],
 (2, 100, 79): [40, 2],
 (2, 100, 83.0): [20, 1],
 (2, 101, 81.5): [0, 0],
 (2, 102, 80): [20, 1]}

In [75]:
A = []
for i in [0,1,2]:
    A.insert(i,[])
    print A
    for t in range(5):
        A[i].insert(t,i*t**2)

        
A

[[]]
[[0, 0, 0, 0, 0], []]
[[0, 0, 0, 0, 0], [0, 1, 4, 9, 16], []]


[[0, 0, 0, 0, 0], [0, 1, 4, 9, 16], [0, 2, 8, 18, 32]]

In [66]:
cost_calc, sched, elev = make_schedule(settings)
print sched
print elev
print cost_calc

1.0
1.0
1.0
1.0
1.0
1.0
0.0
1.0
1.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
1.0
1.0
[ 1.  1.  1.  1.  1.  1.  0.  1.  1.  0.  1.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  1.  1.  1.]
[ 100.  101.  102.  103.  104.  105.  106.  105.  106.  107.  106.  107.
  106.  105.  104.  103.  102.  101.  100.   99.   98.   97.   98.   99.
  100.]
1353.0


In [55]:
settings['cache'][settings['Initial state']]

[28, 0]

In [20]:
cost_calc, sched, elev = make_schedule(settings)
print sched
print elev
print cost_calc

[ 1.  1.  0.]
[ 100.  101.  102.  101.]
27.0


In [39]:
dic = {'blub': simple_cost
      }
dic['blub']

<function __main__.simple_cost>

In [40]:
dic['blub'](1,(0,1))

19

In [11]:
len([5,5,5,5,5,8,8,8,8,8,12,12,12,12,12,50,50,50,50,20,20,6,5,5])

24

In [12]:
np.random.randint(50,size=24)

array([46,  1,  3, 36, 30, 19,  9, 26, 35,  5, 49,  3,  6, 36, 43, 36, 14,
       34,  2,  0,  0, 30, 13, 36])

In [47]:
x = [100,7,90,787]
x_lim_min = [0,10,0,0]
x_lim_max = [100,10,100,1000]
for c,i in enumerate(x):
    if i < x_lim_min[c] or i > x_lim_max[c]:
        print 1000
    

1000
