In [1]:
import numpy as np
import collections

# Definition of the states, accions y parameters
alpha = 0.3 # 30%

down = -5
up = 5
# State space
X = [i for i in range (down, up+1)]

# Control space 
U = [1, 0]

# Latest time
T = 5

In [2]:
def AVAR(q, w, alpha):
    # we recive w = w_{t+1}

    # this is something like {w(1): 20%, w(): 60%}
    eval = [(w[max(down, min(k, up))], q[k])  for k in q.keys()]
    s = sorted(eval, reverse=True)
    res = 0
    i = 0
    a = alpha
    while(alpha > 0):
        if alpha >= s[i][1]:
            res += s[i][0]*s[i][1]
            alpha -= s[i][1]
        else:
            res += s[i][0]*alpha
            alpha = 0
        i += 1
    return res/a

In [3]:
def precal_Q(X):
    Q0 = dict()
    Q1 = dict()
    for x in X:
        Q0[x] = dict()
        Q0[x][0] = {x-1: 0.2, x-1: 0.2, x: 0.2, x+1: 0.2, x+2: 0.2}
        Q0[x][1] = {x+1: 0.4, x+2: 0.2, x+3: 0.4}
        
        Q1[x] = dict()
        Q1[x][0] = {x-1: 0.6, x: 0.2, x+1: 0.2}
        Q1[x][1] = {x-1: 0.2, x: 0.4, x+1: 0.4}
    return Q0,Q1

Q_0, Q_1 = precal_Q(X)

def Q0(x, u):
    return Q_0[x][u]

def Q1(x, u):
    return Q_1[x][u]

In [4]:
def cost0(x, a, t):
    # x = state
    # a = action of this node
    if a == 0:
        return 0
    if a == 1:
        return 1

def cost1(x, a, t):
    # x = state
    # a = action of this node
    return np.exp(-x) - t 

In [5]:
class Node:
    def __init__(self, t, id, parent = None, Q = None, cost = None):
        self.t = t # Time
        self.id = id
        self.parent = parent
        self.Q = Q
        self.cost = cost
        self.w = dict() # cost
        self.childs = []
        self.policy = dict() # One for each of the posible states 
        self.terminal = True # If it is a terminal node
        

    def new_child(self, node):
        self.childs.append(node)# Add the new child
        node.parent = self
        self.terminal = False 

    
    def get_w(self):
        if self.w.__len__() == 0:
            if self.terminal:
                for x in X: 
                    self.w[x] = self.cost(x, 1, self.t) # Cost of terminal nodes in this case
                    # This w is w_{t, s}
            else:
                self.calc_policy()
        
        return self.w


    def calc_policy(self):
        self.policy = {x: min(U, key=lambda u: self.cost(x, u, self.t) + sum([AVAR(self.Q(x, u), child.get_w(), alpha) for child in self.childs])) for x in X}
        self.w = {x: (self.cost(x, self.policy[x], self.t) + sum([AVAR(self.Q(x, self.policy[x]), child.get_w(), alpha) for child in self.childs])) for x in X}
        #print(f'id: {self.id}, u: {self.policy}')

    def print_tree(self, level = 0):
        print(" " * 4 * level + f'{self.id}, x_{self.t, self.id % 2}')
        for child in self.childs:
            child.print_tree(level = level+1)
            

In [6]:
# Create tree

# ad_list is symilar to: [(0, 1), (0, 2), (1, 3), ...]
# With the relation that if we have (0,1), then 0 is parent of 1
# To use this you have to have a node list and a list of which type they are
def create_tree(nodes, ad_list):
    root = nodes[0]
    for ad in ad_list:
        nodes[ad[0]].new_child(nodes[ad[1]])
    return root


#adyacent list
ad_list = []

#The root
root = Node(0, id=0)
root.Q = Q0
root.cost = cost0

# List of nodes
nodes = [root]

# We create the adyacent list for this case
for i in range(0, T):
    new1 = Node(i+1, id=2*i+1, Q=Q1, cost=cost1)
    new2 = Node(i+1, id=2*i+2, Q=Q0, cost=cost0)
    nodes.append(new1)
    nodes.append(new2)
    ad_list.append((2*i, 2*i+1))
    ad_list.append((2*i, 2*i+2))


# Add the last node
new = Node(T+1, id = 2*T+1, Q=Q1, cost=cost1)
nodes.append(new)


l = len(ad_list)
ad_list.append((l, l+1))

print(ad_list)


[(0, 1), (0, 2), (2, 3), (2, 4), (4, 5), (4, 6), (6, 7), (6, 8), (8, 9), (8, 10), (10, 11)]


In [7]:
create_tree(nodes, ad_list)

<__main__.Node at 0x22e7fde6850>

In [8]:
root.print_tree()

0, x_(0, 0)
    1, x_(1, 1)
    2, x_(1, 0)
        3, x_(2, 1)
        4, x_(2, 0)
            5, x_(3, 1)
            6, x_(3, 0)
                7, x_(4, 1)
                8, x_(4, 0)
                    9, x_(5, 1)
                    10, x_(5, 0)
                        11, x_(6, 1)


In [9]:
print(root.get_w())

id: 10, u: {-5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
id: 8, u: {-5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 1, 2: 0, 3: 0, 4: 0, 5: 0}
id: 6, u: {-5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 1, 2: 1, 3: 0, 4: 0, 5: 0}
id: 4, u: {-5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 1, 2: 1, 3: 0, 4: 0, 5: 0}
id: 2, u: {-5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 1, 2: 1, 3: 1, 4: 0, 5: 0}
id: 0, u: {-5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 1, 2: 1, 3: 1, 4: 0, 5: 0}
{-5: 71.15890432489304, -4: 16.35004743880595, -3: -4.099082626342417, -2: -11.86515505328169, -1: -15.110996110499265, 0: -16.619279367467968, 1: -17.508292713468002, 2: -18.239287892958995, 3: -18.840461346023297, 4: -19.43967904398043, 5: -20.037685494650486}


In [10]:
print(root.policy)

{-5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 1, 2: 1, 3: 1, 4: 0, 5: 0}
