In [24]:
import numpy as np
import plotly.graph_objects as go
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from functools import reduce
from math import ceil, floor
from typing import TypeVar, Callable, Dict
import graphviz
import itertools
# Global seaborn plot style for any figures drawn in this notebook.
sns.set_style('darkgrid')

In [25]:
TAB_SIZE: int = 4        # spaces per depth level when Node.print_tree indents names
EPS: float = 1e-9        # absolute tolerance AVAR uses to treat the remaining tail mass as zero
INF: float = 1e30        # large sentinel value (no use is visible in this notebook)
PRECISION: int = 6       # decimal places kept when binom_distribution rounds probabilities

In [26]:
from scipy.stats import binom
from iteround import saferound

def binom_distribution(n, p):
    """
        Binomial(n, p) probability mass function as a dict {k: P[K = k]}, k = 0..n.

        The probabilities are rounded to PRECISION decimal places with
        iteround.saferound, which rounds the list while keeping its total intact.
    """
    pmf = [binom.pmf(successes, n, p) for successes in range(n + 1)]
    rounded = saferound(pmf, PRECISION)
    return {k: prob for k, prob in enumerate(rounded)}

In [27]:
binom_distribution(10, .2)

{0: 0.107374,
 1: 0.268436,
 2: 0.30199,
 3: 0.201327,
 4: 0.08808,
 5: 0.026424,
 6: 0.005505,
 7: 0.000786,
 8: 7.4e-05,
 9: 4e-06,
 10: 0.0}

In [28]:
def AVAR(q, w, alpha, atol=None):
    """
        Average value-at-risk (AVaR) of a discrete cost distribution.

        Outcomes are ranked by cost, largest first, and the costs are averaged over
        the worst `alpha` of probability mass. Only part of the last outcome is taken
        when its mass overshoots what is left of `alpha`.

        q     : dict outcome -> probability mass
        w     : dict outcome -> cost; must contain every key of q
        alpha : tail mass to average over; must be > 0
        atol  : tolerance below which the remaining tail mass counts as zero;
                defaults to the module constant EPS

        Returns sum(cost * mass taken) / alpha.
        Raises ValueError if alpha <= 0 (the average is undefined / would divide by zero).
    """
    if alpha <= 0:
        raise ValueError(f"AVAR needs alpha > 0, got {alpha!r}")
    if atol is None:
        atol = EPS
    # (cost, mass) pairs, worst cost first
    ranked = sorted(((w[k], q[k]) for k in q), reverse=True)
    remaining = alpha
    tail_sum = 0
    for wk, qk in ranked:
        if np.isclose(remaining, 0, atol=atol):
            break  # the alpha tail is exhausted (up to floating-point error)
        if remaining >= qk:
            tail_sum += wk * qk
            remaining -= qk
        else:
            tail_sum += wk * remaining  # only part of this outcome lies in the tail
            remaining = 0
    return tail_sum / alpha

def _as_key_tuple(x):
    """Return x unchanged if it is already a tuple of outcome labels, else wrap it as (x,)."""
    return x if isinstance(x, tuple) else (x,)


def add_qw(qw1, qw2):
    """
        Combine two independent (q, w) pairs into the (q, w) pair of their sum.

        qw1, qw2 : (q, w) where q maps outcome -> probability and w maps outcome -> cost.
                   Outcome labels may be scalars of any type, or tuples of labels as
                   produced by an earlier add_qw call (e.g. inside functools.reduce).

        Returns (Q_sum, w_sum), keyed by the concatenated label tuple x1 + x2:
            Q_sum[x1 + x2] = q1[x1] * q2[x2]   (independence is assumed)
            w_sum[x1 + x2] = w1[x1] + w2[x2]
    """
    q_first, w_first = qw1[0], qw1[1]
    q_second, w_second = qw2[0], qw2[1]
    Q_sum = {}
    w_sum = {}
    for x1, q1 in q_first.items():
        # Build the tuple label once without rebinding x1: x1 must stay a valid
        # key into w_first on every pass of the inner loop.
        key1 = _as_key_tuple(x1)
        for x2, q2 in q_second.items():
            xs = key1 + _as_key_tuple(x2)
            Q_sum[xs] = q1 * q2
            w_sum[xs] = w_first[x1] + w_second[x2]
    return (Q_sum, w_sum)

def AVAR_of_sum(list_qw, alpha):
    """
        AVaR of the total cost of several independent components.

        list_qw : non-empty sequence of (q, w) pairs, folded left to right with add_qw
        alpha   : tail probability mass forwarded to AVAR
    """
    combined = reduce(add_qw, list_qw)
    return AVAR(*combined, alpha)

def sum_of_AVAR(list_qw, alpha):
    """
        Sum of the individual AVaRs of every (q, w) pair in list_qw at level alpha.
        An empty list gives 0.
    """
    per_component = (AVAR(*qw, alpha) for qw in list_qw)
    return sum(per_component)

In [29]:
class Node:
    """
        One vertex of the decision tree.

        A terminal node (no children) values each state x as cost(x). An internal
        node chooses, for every state x, the action u in allowed_U[x] minimising
            cost(x, u) + one_step([(child.Q(x, u), child.get_w()) for child in children])
        and stores the minimiser in `policy` and the minimum in `w`.
    """

    def __init__(self, id: int, model, Q: Callable = None, cost: Callable = None, parent=None):
        """
            id     : integer identifier; str(id) is the initial display name
            model  : object providing the state space `X` and the default action set `U`
            Q      : (x_pa, u_pa) -> distribution of this node's state given the parent's
                     state and action; the parent passes it to its one_step
            cost   : x -> cost on terminal nodes, (x, u) -> cost on internal nodes
            parent : parent Node; normally filled in by the parent's add_child
        """
        self.id = id
        self.parent = parent
        self.Q = Q
        self.cost = cost
        self.w = {}           # value per state; filled lazily by get_w / calc_policy
        self.policy = {}      # chosen action per state; only internal nodes fill it
        self.children = []
        self.terminal = True  # flips to False once add_child is called
        self.model = model
        self.name = str(id)
        self.one_step = None  # list of (child Q, child w) pairs -> aggregated child risk; assigned externally
        # actions available in each state; defaults to the full model.U everywhere
        self.allowed_U = {x : self.model.U for x in self.model.X}

    def add_child(self, node):
        """Attach `node` below this node; this node then stops being terminal."""
        self.children.append(node)
        node.parent = self
        self.terminal = False

    def get_w(self):
        """
            Return the value function {x: w(x)} for this node, computing it on first use.
            Terminal nodes evaluate cost(x); internal nodes solve calc_policy().
        """
        if self.w:
            return self.w
        if self.terminal:
            for x in self.model.X:
                self.w[x] = self.cost(x)
            return self.w
        self.calc_policy()
        return self.w

    def calc_policy(self):
        """
            Calculate the optimal policy and value for the maximal subtree rooted here.
            Children's values are obtained (and, if needed, computed) via get_w().

            Raises RuntimeError when called on a terminal node, which has no actions.
        """
        if self.terminal:
            raise RuntimeError("calc_policy ran on terminal nodes!")

        def net_one_step(x, u):
            # immediate cost of action u in state x plus the aggregated risk of the children
            os = self.one_step([(child.Q(x, u), child.get_w()) for child in self.children])
            return self.cost(x, u) + os

        print(f'Calculating policy for {self.name}')
        # min() keeps the first minimiser, so the order of allowed_U[x] breaks ties
        self.policy = {x: min(self.allowed_U[x], key = lambda u: net_one_step(x, u)) for x in self.model.X}
        print(f'Finished calculating policy for {self.name}')
        self.w = {x: net_one_step(x, self.policy[x]) for x in self.model.X}

    def print_tree(self, level = 0):
        """Print this subtree's node names, indenting TAB_SIZE spaces per depth level."""
        print(" " * TAB_SIZE * level + self.name)
        for child in self.children:
            child.print_tree(level+1)

In [30]:
class Model:
    """
        Base class for a tree-shaped decision problem over integer states.

        Subclasses implement construct_graph (create nodes and edges) and
        construct_risks (attach Q, cost and one_step to the nodes); both are
        invoked at the end of __init__.
    """

    def __init__(self, lo, hi, U, alpha):
        """
            lo, hi : inclusive bounds of the state space X = lo, lo+1, ..., hi
            U      : action space (every Node's default allowed action set)
            alpha  : risk level, stored for subclasses to use
            Node 0 is created here and is the root of the tree.
        """
        # X and U must exist before any Node is built: Node.__init__ reads both.
        self.lo = lo
        self.hi = hi
        self.X = range(lo, hi + 1)
        self.U = U
        self.alpha = alpha
        self.root = Node(0, self)
        self.nodes = [self.root]
        self.edge_list = []
        self.construct_graph()
        self.construct_risks()

    def construct_graph(self):
        raise NotImplementedError("construct_graph has not been properly implemented!")

    def construct_risks(self):
        raise NotImplementedError("construct_risks has not been properly implemented!")

    def bound(self, q):
        """Clamp every outcome of q into [lo, hi], merging the mass of clipped outcomes."""
        clamped = dict.fromkeys(self.X, 0)
        for outcome, mass in q.items():
            target = min(max(outcome, self.lo), self.hi)
            clamped[target] += mass
        return clamped

    def draw_edge(self, parent_i, child_i):
        """Link node parent_i -> child_i, first creating missing nodes so ids match list indices."""
        highest = max(parent_i, child_i)
        for new_id in range(len(self.nodes), highest + 1):
            self.nodes.append(Node(new_id, self))
        self.nodes[parent_i].add_child(self.nodes[child_i])

    def visualize_graph(self):
        """Render self.edge_list with graphviz to files named 'tree' (PNG output) and open a viewer."""
        graph = graphviz.Graph()
        for parent_id, child_id in self.edge_list:
            parent_label = str(self.nodes[parent_id].id)
            child_label = str(self.nodes[child_id].id)
            graph.edge(parent_label, child_label)
        graph.render('tree', format='png', view=True)
        

In [31]:
MAX_RESOURCES = 50  # total resource budget; also the top of the state range given to DRModel
MAX_CH = 3  # length of every action vector: one allocation per child, at most 3 children
# Cartesian product {0..MAX_RESOURCES}^MAX_CH in lexicographic order (not permutations, despite the name)
permutations = list(itertools.product(range(MAX_RESOURCES + 1), repeat=MAX_CH))
# Feasible allocations: total spend within the budget, order preserved
ALL_U = list(filter(lambda alloc: sum(alloc) <= MAX_RESOURCES, permutations))
print(len(ALL_U))
# For each amount on hand, the allocations spending at most that amount. The order
# follows ALL_U, which matters: min() in Node.calc_policy keeps the first minimiser.
ALLOWED_U = {
    budget: [alloc for alloc in ALL_U if sum(alloc) <= budget]
    for budget in range(MAX_RESOURCES + 1)
}

23426


In [32]:
class DRModel(Model):
    """
        Three-level allocation tree: root 0 ("R") feeds nodes 1-3 ("D"), which feed
        leaves 4-11 ("L"). Every node's allowed actions come from the global
        ALLOWED_U table.
    """

    def __init__(self, lo, hi, U, alpha):
        super().__init__(lo, hi, U, alpha)

    def construct_graph(self):
        """
            customize graph structure here
        """
        self.edge_list = [
            (0, 1), (0, 2), (0, 3),
            (1, 4), (1, 5),
            (2, 6), (2, 7), (2, 8),
            (3, 9), (3, 10), (3, 11),
        ]
        for parent_id, child_id in self.edge_list:
            self.draw_edge(parent_id, child_id)
        for node in self.nodes:
            if node.id == 0:
                suffix = "R"
            elif 1 <= node.id <= 3:
                suffix = "D"
            else:
                suffix = "L"
            node.name += suffix

    def construct_risks(self):
        """
            Attach state transitions (Q), stage costs and the one-step risk operator
            to every node, and give every node ALLOWED_U as its action table.
        """
        def q_into_d(x_pa, u_pa, ch_idx):
            # D node's state ~ Binomial(units allocated to it, 0.9), clamped to [lo, hi]
            return self.bound(binom_distribution(u_pa[ch_idx], .9))

        def q_into_l(x_pa, u_pa, ch_idx, necessary):
            # leaf's state ~ Binomial(shortfall of the allocation below `necessary`, 0.2)
            n_dead = max(necessary - u_pa[ch_idx], 0)
            return self.bound(binom_distribution(n_dead, .2))

        def c_term(x):
            return 1000 * x

        def c(x, u, n_ch):
            # cost is the total handed to the node's real children; entries past n_ch are free
            return sum(u[:n_ch])

        def d_transition(ch_idx):
            return lambda x_pa, u_pa: q_into_d(x_pa, u_pa, ch_idx)

        def l_transition(ch_idx, necessary):
            return lambda x_pa, u_pa: q_into_l(x_pa, u_pa, ch_idx, necessary=necessary)

        def stage_cost(owner):
            # len(owner.children) is evaluated at call time
            return lambda x, u: c(x, u, len(owner.children))

        def risk_of_children(list_qw):
            return sum_of_AVAR(list_qw, self.alpha)

        for node in self.nodes:
            node.allowed_U = ALLOWED_U
            if 'R' in node.name:
                for ch_idx, child in enumerate(node.children):
                    child.Q = d_transition(ch_idx)
            elif 'D' in node.name:
                for ch_idx, child in enumerate(node.children):
                    # leaf with id i needs i - 3 units (leaves 4..11 need 1..8)
                    child.Q = l_transition(ch_idx, child.id - 3)
            if node.terminal:
                node.cost = c_term
            else:
                node.cost = stage_cost(node)
                node.one_step = risk_of_children

In [33]:
# Build the tree over states 0..MAX_RESOURCES with AVaR level alpha = 0.3.
# ALL_U is only the default action set; construct_risks replaces each node's allowed_U with ALLOWED_U.
drmodel = DRModel(0, MAX_RESOURCES, ALL_U, 0.3)

In [34]:
# Solve the tree from the root; each child's values are computed recursively through get_w().
drmodel.root.calc_policy()

Calculating policy for 0R
Calculating policy for 1D
Finished calculating policy for 1D
Calculating policy for 2D
Finished calculating policy for 2D
Calculating policy for 3D
Finished calculating policy for 3D
Finished calculating policy for 0R


In [35]:
def simplify(policy, n_ch, total):
    """
        Collapse an allocation vector to (alloc_0, ..., alloc_{n_ch-1}, rest):
        the first n_ch entries (one per real child) followed by `total` minus what
        those children receive.
    """
    head = tuple(policy[:n_ch])
    unassigned = total - sum(policy) + sum(policy[n_ch:])
    return head + (unassigned,)

In [36]:
# Consolidate each internal node's policy into (allocation per child..., MAX_RESOURCES minus allocations).
# Internal nodes are taken from the tree itself rather than assuming they are ids 0-3.
simplified_policies = {}
for node in drmodel.nodes:
    if node.terminal:
        continue  # leaves have no actions, hence no policy to summarise
    i = node.id
    simplified_policies[i] = {
        k: simplify(node.policy[k], len(node.children), MAX_RESOURCES)
        for k in node.policy
    }
    print(f'The consolidated policy for node {i} is:')
    print(simplified_policies[i])

The consolidated policy for node 0 is:
{0: (0, 0, 0, 50), 1: (1, 0, 0, 49), 2: (2, 0, 0, 48), 3: (3, 0, 0, 47), 4: (4, 0, 0, 46), 5: (4, 1, 0, 45), 6: (4, 2, 0, 44), 7: (2, 5, 0, 43), 8: (4, 4, 0, 42), 9: (4, 5, 0, 41), 10: (4, 6, 0, 40), 11: (4, 7, 0, 39), 12: (4, 8, 0, 38), 13: (4, 9, 0, 37), 14: (4, 10, 0, 36), 15: (4, 11, 0, 35), 16: (4, 12, 0, 34), 17: (4, 13, 0, 33), 18: (4, 14, 0, 32), 19: (4, 15, 0, 31), 20: (4, 15, 1, 30), 21: (4, 15, 2, 29), 22: (4, 10, 8, 28), 23: (4, 11, 8, 27), 24: (4, 11, 9, 26), 25: (4, 14, 7, 25), 26: (4, 14, 8, 24), 27: (4, 15, 8, 23), 28: (4, 15, 9, 22), 29: (4, 9, 16, 21), 30: (4, 10, 16, 20), 31: (4, 10, 17, 19), 32: (4, 11, 17, 18), 33: (4, 14, 15, 17), 34: (4, 14, 16, 16), 35: (4, 14, 17, 15), 36: (4, 15, 17, 14), 37: (4, 15, 18, 13), 38: (4, 14, 20, 12), 39: (4, 14, 21, 11), 40: (4, 15, 21, 10), 41: (4, 14, 23, 9), 42: (4, 14, 24, 8), 43: (4, 14, 25, 7), 44: (4, 15, 25, 6), 45: (4, 15, 26, 5), 46: (5, 15, 26, 4), 47: (5, 16, 26, 3), 48: (5, 16, 2

In [37]:
# Draw the tree with graphviz (renders to files named 'tree' and opens a viewer).
drmodel.visualize_graph()

In [50]:
# Root node's consolidated policy: state -> (allocation per child..., remainder).
simplified_policies[0]

{0: (0, 0, 0, 50),
 1: (1, 0, 0, 49),
 2: (2, 0, 0, 48),
 3: (3, 0, 0, 47),
 4: (4, 0, 0, 46),
 5: (4, 1, 0, 45),
 6: (4, 2, 0, 44),
 7: (2, 5, 0, 43),
 8: (4, 4, 0, 42),
 9: (4, 5, 0, 41),
 10: (4, 6, 0, 40),
 11: (4, 7, 0, 39),
 12: (4, 8, 0, 38),
 13: (4, 9, 0, 37),
 14: (4, 10, 0, 36),
 15: (4, 11, 0, 35),
 16: (4, 12, 0, 34),
 17: (4, 13, 0, 33),
 18: (4, 14, 0, 32),
 19: (4, 15, 0, 31),
 20: (4, 15, 1, 30),
 21: (4, 15, 2, 29),
 22: (4, 10, 8, 28),
 23: (4, 11, 8, 27),
 24: (4, 11, 9, 26),
 25: (4, 14, 7, 25),
 26: (4, 14, 8, 24),
 27: (4, 15, 8, 23),
 28: (4, 15, 9, 22),
 29: (4, 9, 16, 21),
 30: (4, 10, 16, 20),
 31: (4, 10, 17, 19),
 32: (4, 11, 17, 18),
 33: (4, 14, 15, 17),
 34: (4, 14, 16, 16),
 35: (4, 14, 17, 15),
 36: (4, 15, 17, 14),
 37: (4, 15, 18, 13),
 38: (4, 14, 20, 12),
 39: (4, 14, 21, 11),
 40: (4, 15, 21, 10),
 41: (4, 14, 23, 9),
 42: (4, 14, 24, 8),
 43: (4, 14, 25, 7),
 44: (4, 15, 25, 6),
 45: (4, 15, 26, 5),
 46: (5, 15, 26, 4),
 47: (5, 16, 26, 3),
 48: 

In [54]:
# Node 3's consolidated policy: state -> (allocation per child..., remainder).
simplified_policies[3]

{0: (0, 0, 0, 50),
 1: (1, 0, 0, 49),
 2: (2, 0, 0, 48),
 3: (3, 0, 0, 47),
 4: (3, 0, 1, 46),
 5: (5, 0, 0, 45),
 6: (6, 0, 0, 44),
 7: (6, 0, 1, 43),
 8: (1, 7, 0, 42),
 9: (2, 7, 0, 41),
 10: (3, 7, 0, 40),
 11: (3, 0, 8, 39),
 12: (5, 7, 0, 38),
 13: (6, 7, 0, 37),
 14: (6, 0, 8, 36),
 15: (0, 7, 8, 35),
 16: (1, 7, 8, 34),
 17: (2, 7, 8, 33),
 18: (3, 7, 8, 32),
 19: (4, 7, 8, 31),
 20: (5, 7, 8, 30),
 21: (6, 7, 8, 29),
 22: (6, 7, 8, 29),
 23: (6, 7, 8, 29),
 24: (6, 7, 8, 29),
 25: (6, 7, 8, 29),
 26: (6, 7, 8, 29),
 27: (6, 7, 8, 29),
 28: (6, 7, 8, 29),
 29: (6, 7, 8, 29),
 30: (6, 7, 8, 29),
 31: (6, 7, 8, 29),
 32: (6, 7, 8, 29),
 33: (6, 7, 8, 29),
 34: (6, 7, 8, 29),
 35: (6, 7, 8, 29),
 36: (6, 7, 8, 29),
 37: (6, 7, 8, 29),
 38: (6, 7, 8, 29),
 39: (6, 7, 8, 29),
 40: (6, 7, 8, 29),
 41: (6, 7, 8, 29),
 42: (6, 7, 8, 29),
 43: (6, 7, 8, 29),
 44: (6, 7, 8, 29),
 45: (6, 7, 8, 29),
 46: (6, 7, 8, 29),
 47: (6, 7, 8, 29),
 48: (6, 7, 8, 29),
 49: (6, 7, 8, 29),
 50: (6, 7

In [40]:
import pickle
import time
import os
# Snapshot every node's computed policy and value table under a timestamped directory.
current_time = time.strftime("%m%d-%H%M%S")
directory = f"pickles/{current_time}"
# exist_ok replaces the separate exists() check, avoiding a check-then-create race
os.makedirs(directory, exist_ok=True)
for node in drmodel.nodes:
    # Leaves never compute a policy (calc_policy refuses terminal nodes), so only
    # non-empty tables are written.
    if node.policy:
        policy_pickle = f"{directory}/policy_{node.id}.pickle"
        with open(policy_pickle, "wb") as f:
            pickle.dump(node.policy, f)

    if node.w:
        w_pickle = f"{directory}/w_{node.id}.pickle"
        with open(w_pickle, "wb") as f:
            pickle.dump(node.w, f)

In [51]:
import pickle

# Reload the root's raw policy from one specific earlier run (the timestamped directory is hard-coded).
# NOTE(review): pickle.load can execute arbitrary code; only load files this notebook wrote itself.
policy_pickle = "pickles/0725-093751/policy_0.pickle"

with open(policy_pickle, "rb") as f:
    policy_0 = pickle.load(f)

policy_0

{0: (0, 0, 0),
 1: (1, 0, 0),
 2: (2, 0, 0),
 3: (3, 0, 0),
 4: (4, 0, 0),
 5: (4, 1, 0),
 6: (4, 2, 0),
 7: (2, 5, 0),
 8: (4, 4, 0),
 9: (4, 5, 0),
 10: (4, 6, 0),
 11: (4, 7, 0),
 12: (4, 8, 0),
 13: (4, 9, 0),
 14: (4, 10, 0),
 15: (4, 11, 0),
 16: (4, 12, 0),
 17: (4, 13, 0),
 18: (4, 14, 0),
 19: (4, 15, 0),
 20: (4, 15, 1),
 21: (4, 15, 2),
 22: (4, 10, 8),
 23: (4, 11, 8),
 24: (4, 11, 9),
 25: (4, 14, 7),
 26: (4, 14, 8),
 27: (4, 15, 8),
 28: (4, 15, 9),
 29: (4, 9, 16),
 30: (4, 10, 16),
 31: (4, 10, 17),
 32: (4, 11, 17),
 33: (4, 14, 15),
 34: (4, 14, 16),
 35: (4, 14, 17),
 36: (4, 15, 17),
 37: (4, 15, 18),
 38: (4, 14, 20),
 39: (4, 14, 21),
 40: (4, 15, 21),
 41: (4, 14, 23),
 42: (4, 14, 24),
 43: (4, 14, 25),
 44: (4, 15, 25),
 45: (4, 15, 26),
 46: (5, 15, 26),
 47: (5, 16, 26),
 48: (5, 16, 27),
 49: (5, 17, 27),
 50: (5, 17, 28)}