In [16]:
import numpy as np

# Definition of the states, actions and parameters
alpha = 0.3 # 30%; passed to AVAR as its `alpha` argument when policies are computed

# Lower / upper bound of the state space; AVAR clamps successor states to [down, up]
down = -50
up = 50
# State space: every integer from `down` to `up`, both inclusive
X = np.arange(down, up+1)

# Control space. Order matters: Node.calc_policy picks the action with min(),
# which keeps the first listed action on ties.
U = [1, 0]

# Latest time: number of stages generated by the tree-building loop
T = 100

In [17]:
def AVAR(q, w, alpha, lo=None, hi=None):
    """Probability-weighted average of the largest values in the top ``alpha`` mass.

    Every successor state ``k`` in ``q`` is clamped to ``[lo, hi]``, its value is
    looked up in ``w`` (this is ``w_{t+1}``), and the (value, probability) pairs
    are sorted from the largest value down. Probability mass is then consumed in
    that order until ``alpha`` has been used up; the accumulated
    ``value * mass`` is divided by ``alpha``. (The name suggests Average
    Value-at-Risk at level ``alpha``.)

    Parameters
    ----------
    q : dict
        Maps successor state -> probability.
    w : mapping
        Maps state -> value; must contain every clamped successor state.
    alpha : float
        Amount of probability mass to average over; must be positive and not
        larger than the total mass in ``q``.
    lo, hi : optional
        Clamp bounds for successor states. When omitted, the module-level
        ``down`` / ``up`` are used, which is the original behaviour.

    Returns
    -------
    float
        The tail average described above.

    Raises
    ------
    ValueError
        If ``alpha`` is not positive, or if it exceeds the total probability
        mass in ``q`` by more than rounding error.
    """
    if alpha <= 0:
        raise ValueError("alpha must be positive")

    # Fall back to the notebook-level state bounds when none are given.
    lo = down if lo is None else lo
    hi = up if hi is None else hi

    # Largest values first; ties on the value are ordered by probability.
    outcomes = sorted(
        ((w[max(lo, min(k, hi))], prob) for k, prob in q.items()),
        reverse=True,
    )

    remaining = alpha
    total = 0
    for value, prob in outcomes:
        if remaining <= 0:
            break
        # Take the whole atom if it fits, otherwise only what is left of alpha.
        mass = prob if remaining >= prob else remaining
        total += value * mass
        remaining -= mass

    # Repeated float subtraction can leave a tiny positive residue after every
    # atom has been consumed (e.g. alpha == 1.0 with five 0.2 atoms). That is
    # rounding noise, not missing mass, so it is tolerated instead of indexing
    # past the end of the list as the while-loop version did.
    tolerance = 1e-9
    if remaining > tolerance:
        raise ValueError("alpha exceeds the total probability mass of q")

    return total / alpha

In [18]:
def precal_Q(X):
    """Precompute two transition tables for every state in ``X``.

    Each returned table is indexed as ``table[x][u]`` with ``u`` in ``{0, 1}``
    and yields a dict mapping successor state -> probability. Successor states
    are expressed relative to ``x`` and are not clamped here.

    Returns
    -------
    tuple(dict, dict)
        The first and the second transition table, in that order.
    """
    table_a = {}
    table_b = {}
    for x in X:
        # First table: action 0 spreads uniformly over x-2..x+2,
        # action 1 only moves upwards (x+1..x+3).
        table_a[x] = {
            0: {x-2: 0.2, x-1: 0.2, x: 0.2, x+1: 0.2, x+2: 0.2},
            1: {x+1: 0.4, x+2: 0.2, x+3: 0.4},
        }
        # Second table: both actions stay within one step of x,
        # with different weights.
        table_b[x] = {
            0: {x-1: 0.6, x: 0.2, x+1: 0.2},
            1: {x-1: 0.2, x: 0.4, x+1: 0.4},
        }
    return table_a, table_b

# Build both transition tables once for every state in X; Q0 / Q1 below only look them up.
Q_0, Q_1 = precal_Q(X)

def Q0(x, u):
    """Return the precomputed ``Q_0`` distribution for state ``x`` and action ``u``."""
    per_action = Q_0[x]
    return per_action[u]

def Q1(x, u):
    """Return the precomputed ``Q_1`` distribution for state ``x`` and action ``u``."""
    per_action = Q_1[x]
    return per_action[u]

In [19]:
def cost0(x, u, t):
    """Stage cost that depends only on the action.

    x (state) and t (time) are accepted for signature compatibility but are
    not used. Returns 0 for action 0, 1 for action 1, and None for anything
    else.
    """
    return 0 if u == 0 else (1 if u == 1 else None)

def cost1(x, u, t):
    """Stage cost that depends on the state and the time only.

    Returns ``exp(-x/20) - t``; the action ``u`` is accepted for signature
    compatibility but is not used.
    """
    state_term = np.exp(-x/20)
    return state_term - t

In [20]:
class Node:
    """A node of the decision tree on which cost-to-go values are computed.

    Each node carries a time index, a transition function ``Q(x, u)`` that
    returns a dict of successor state -> probability, and a stage cost
    ``cost(x, u, t)``. ``get_w`` lazily computes and caches the node's
    cost-to-go ``w`` for every state in the module-level ``X``; for
    non-terminal nodes this also fills ``policy`` with the minimising action
    from the module-level ``U``.
    """

    def __init__(self, t, id, parent = None, Q = None, cost = None):
        self.t = t # Time
        self.id = id
        self.parent = parent
        self.Q = Q # transition function: Q(x, u) -> {successor state: probability}
        self.cost = cost # stage cost: cost(x, u, t)
        self.w = dict() # cached cost-to-go per state; empty means "not computed yet"
        self.childs = []
        self.policy = dict() # One action for each of the possible states
        self.terminal = True # A node is terminal until it receives a child

    def new_child(self, node):
        """Attach ``node`` as a child of this node and mark this node non-terminal.

        Linking the same child twice is a no-op for the child list: a duplicate
        entry would be counted twice in ``calc_policy`` (e.g. after re-running
        the tree-building cell on the same nodes).
        """
        if node not in self.childs:
            self.childs.append(node)
        node.parent = self
        self.terminal = False

    def reset_w(self):
        """Drop the cached cost-to-go so the next ``get_w`` recomputes it."""
        self.w = {}

    # for value iteration
    def change_w(self, new_w, gamma):
        """Overwrite entries of ``w`` with ``new_w`` scaled by ``gamma ** self.t``."""
        discount = gamma ** self.t
        for key in new_w.keys():
            self.w[key] = new_w[key] * discount

    # consider a node to be its own descendant
    def hasDescendant(self, nodes):
        """Return True if this node, or any node below it, is contained in ``nodes``."""
        if self in nodes:
            return True
        for child in self.childs:
            if child.hasDescendant(nodes):
                return True
        return False

    def get_w(self):
        """Return the cost-to-go dict, computing it first if it is still empty."""
        if not self.w:
            if self.terminal:
                for x in X:
                    # Cost of terminal nodes in this case: evaluated with action 1.
                    # This w is w_{t, s}
                    self.w[x] = self.cost(x, 1, self.t)
            else:
                self.calc_policy()

        return self.w

    def _cost_to_go(self, x, u):
        """Stage cost of (x, u) plus the sum over children of AVAR of the child's w."""
        return self.cost(x, u, self.t) + sum([AVAR(self.Q(x, u), child.get_w(), alpha) for child in self.childs])

    def calc_policy(self):
        """Fill ``policy`` and ``w`` with the minimising action and its value per state.

        Each (state, action) pair is evaluated exactly once; the value of the
        winning action is reused for ``w`` instead of being recomputed.
        """
        policy = {}
        values = {}
        for x in X:
            scored = [(self._cost_to_go(x, u), u) for u in U]
            # min() keeps the first minimal entry, so ties go to the action
            # listed first in U.
            best_value, best_u = min(scored, key=lambda pair: pair[0])
            policy[x] = best_u
            values[x] = best_value
        self.policy = policy
        self.w = values

    def print_tree(self, level = 0):
        """Print this node and its subtree, indenting four spaces per level."""
        print(" " * 4 * level + f'{self.id}, x_{self.t, self.id % 2}')
        for child in self.childs:
            child.print_tree(level = level+1)
            

In [21]:
# Create tree

# ad_list is similar to: [(0, 1), (0, 2), (1, 3), ...]
# Each pair is (parent index, child index): if we have (0, 1), then node 0 is the parent of node 1.
# To use this you need a list of nodes, each already created with its own type (Q and cost).
def create_tree(nodes, ad_list):
    """Link ``nodes`` into a tree and return its root, ``nodes[0]``.

    ``ad_list`` is a sequence of edges; for each edge, the node at index
    ``edge[0]`` gets the node at index ``edge[1]`` attached through
    ``new_child``.
    """
    top = nodes[0]
    for edge in ad_list:
        parent_node = nodes[edge[0]]
        child_node = nodes[edge[1]]
        parent_node.new_child(child_node)
    return top


# Adjacency list: (parent index, child index) pairs into `nodes`
ad_list = []

# The root: time 0, using the Q0 transitions and the cost0 stage cost
root = Node(0, id=0)
root.Q = Q0
root.cost = cost0

# List of nodes; a node's position in this list equals its id
nodes = [root]

# We create the adjacency list for this case: at every stage the even-indexed
# node 2*i gets two children, node 2*i+1 (Q1/cost1) and node 2*i+2 (Q0/cost0).
# Only the even-indexed child is extended in the next iteration.
for i in range(0, T):
    new1 = Node(i+1, id=2*i+1, Q=Q1, cost=cost1)
    new2 = Node(i+1, id=2*i+2, Q=Q0, cost=cost0)
    nodes.append(new1)
    nodes.append(new2)
    ad_list.append((2*i, 2*i+1))
    ad_list.append((2*i, 2*i+2))


# Add the last node: a single Q1/cost1 node at time T+1
new = Node(T+1, id = 2*T+1, Q=Q1, cost=cost1)
nodes.append(new)


# The loop added two edges per stage, so len(ad_list) == 2*T here: this is the
# index of the last even node, which becomes the parent of the node just added.
l = len(ad_list)
ad_list.append((l, l+1))

print(ad_list)


[(0, 1), (0, 2), (2, 3), (2, 4), (4, 5), (4, 6), (6, 7), (6, 8), (8, 9), (8, 10), (10, 11), (10, 12), (12, 13), (12, 14), (14, 15), (14, 16), (16, 17), (16, 18), (18, 19), (18, 20), (20, 21), (20, 22), (22, 23), (22, 24), (24, 25), (24, 26), (26, 27), (26, 28), (28, 29), (28, 30), (30, 31), (30, 32), (32, 33), (32, 34), (34, 35), (34, 36), (36, 37), (36, 38), (38, 39), (38, 40), (40, 41), (40, 42), (42, 43), (42, 44), (44, 45), (44, 46), (46, 47), (46, 48), (48, 49), (48, 50), (50, 51), (50, 52), (52, 53), (52, 54), (54, 55), (54, 56), (56, 57), (56, 58), (58, 59), (58, 60), (60, 61), (60, 62), (62, 63), (62, 64), (64, 65), (64, 66), (66, 67), (66, 68), (68, 69), (68, 70), (70, 71), (70, 72), (72, 73), (72, 74), (74, 75), (74, 76), (76, 77), (76, 78), (78, 79), (78, 80), (80, 81), (80, 82), (82, 83), (82, 84), (84, 85), (84, 86), (86, 87), (86, 88), (88, 89), (88, 90), (90, 91), (90, 92), (92, 93), (92, 94), (94, 95), (94, 96), (96, 97), (96, 98), (98, 99), (98, 100), (100, 101), (100,

In [22]:
# Link the nodes according to ad_list; the call returns the root (nodes[0]).
create_tree(nodes, ad_list)

<__main__.Node at 0x193ce6c6230>

In [23]:
# Print the tree, one node per line, indented four spaces per level.
root.print_tree() 

0, x_(0, 0)
    1, x_(1, 1)
    2, x_(1, 0)
        3, x_(2, 1)
        4, x_(2, 0)
            5, x_(3, 1)
            6, x_(3, 0)
                7, x_(4, 1)
                8, x_(4, 0)
                    9, x_(5, 1)
                    10, x_(5, 0)
                        11, x_(6, 1)
                        12, x_(6, 0)
                            13, x_(7, 1)
                            14, x_(7, 0)
                                15, x_(8, 1)
                                16, x_(8, 0)
                                    17, x_(9, 1)
                                    18, x_(9, 0)
                                        19, x_(10, 1)
                                        20, x_(10, 0)
                                            21, x_(11, 1)
                                            22, x_(11, 0)
                                                23, x_(12, 1)
                                                24, x_(12, 0)
                                                    25,

In [24]:
# Compute (on the first call) and print the root's cost-to-go for every state in X.
print(root.get_w())

{-50: -4825.7958283765465, -49: -4837.561363574834, -48: -4848.79998849056, -47: -4859.490499201325, -46: -4869.659627552347, -45: -4879.377404818307, -44: -4888.625407912985, -43: -4897.42238057452, -42: -4905.831109024217, -41: -4913.837720316184, -40: -4921.453844567646, -39: -4928.737404304046, -38: -4935.675632968161, -37: -4942.275480227383, -36: -4948.592113285536, -35: -4954.610786941753, -34: -4960.33592642001, -33: -4965.821808352411, -32: -4971.048950440037, -31: -4976.021161799831, -30: -4980.793475761833, -29: -4985.339211589216, -28: -4989.665113971418, -27: -4993.822831897348, -26: -4997.781889103117, -25: -5001.553419371624, -24: -5005.170318070568, -23: -5008.615413776039, -22: -5011.905965266083, -21: -5015.057105166878, -20: -5018.062117347892, -19: -5020.936580245518, -18: -5023.687419329344, -17: -5026.314470829074, -16: -5028.829592573996, -15: -5031.236551454995, -14: -5033.538434199812, -13: -5035.743681861242, -12: -5037.855149917003, -11: -5039.877058570571, -

In [25]:
# Show the policy stored on every non-terminal node.
for node in nodes:
    if node.terminal:
        continue
    print(f'id: {node.id}, u: {node.policy}')

id: 0, u: {-50: 1, -49: 1, -48: 1, -47: 1, -46: 1, -45: 1, -44: 1, -43: 1, -42: 1, -41: 1, -40: 1, -39: 1, -38: 1, -37: 1, -36: 1, -35: 1, -34: 1, -33: 1, -32: 1, -31: 1, -30: 1, -29: 1, -28: 1, -27: 1, -26: 1, -25: 1, -24: 1, -23: 1, -22: 1, -21: 1, -20: 1, -19: 1, -18: 1, -17: 1, -16: 1, -15: 1, -14: 1, -13: 1, -12: 1, -11: 1, -10: 1, -9: 1, -8: 1, -7: 1, -6: 1, -5: 1, -4: 1, -3: 1, -2: 1, -1: 1, 0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1, 14: 1, 15: 1, 16: 1, 17: 1, 18: 1, 19: 1, 20: 1, 21: 1, 22: 1, 23: 1, 24: 1, 25: 1, 26: 1, 27: 1, 28: 1, 29: 1, 30: 1, 31: 1, 32: 1, 33: 1, 34: 1, 35: 1, 36: 1, 37: 1, 38: 1, 39: 1, 40: 1, 41: 1, 42: 1, 43: 1, 44: 1, 45: 1, 46: 1, 47: 1, 48: 1, 49: 1, 50: 0}
id: 2, u: {-50: 1, -49: 1, -48: 1, -47: 1, -46: 1, -45: 1, -44: 1, -43: 1, -42: 1, -41: 1, -40: 1, -39: 1, -38: 1, -37: 1, -36: 1, -35: 1, -34: 1, -33: 1, -32: 1, -31: 1, -30: 1, -29: 1, -28: 1, -27: 1, -26: 1, -25: 1, -24: 1, -23: 1, -22: 1, -21: 1, 

In [26]:
import plotly.graph_objects as go

# Create the figure: one line per node, showing the chosen action for each state
fig = go.Figure()

for node in nodes:
    # Nodes whose policy is empty (terminal nodes never get one) would only
    # add empty traces to the legend, so they are skipped.
    if not node.policy:
        continue
    x = list(node.policy.keys())    # states
    y = list(node.policy.values())  # action chosen in each state
    fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name=f'Time {node.t}'))

# The axis/legend titles describe what is actually plotted: states on x,
# actions on y, and traces named by node time.
fig.update_layout(
    title='Policies by Node',
    xaxis_title='State',
    yaxis_title='Action',
    legend_title='Time',
)

fig.show()


In [27]:
# For every non-terminal node, record under its time the last state at which
# the policy differs from the previous state's action (the scan starts from
# action 1). Terminal nodes have no policy; their id is printed instead.
change = {}
for node in nodes:
    if not node.terminal:
        prev = 1
        for x in X:
            if prev != node.policy[x]:
                change[node.t] = x
            prev = node.policy[x]
    else:
        print(node.id)

1
3
5
7
9
11
13
15
17
19
21
23
25
27
29
31
33
35
37
39
41
43
45
47
49
51
53
55
57
59
61
63
65
67
69
71
73
75
77
79
81
83
85
87
89
91
93
95
97
99
101
103
105
107
109
111
113
115
117
119
121
123
125
127
129
131
133
135
137
139
141
143
145
147
149
151
153
155
157
159
161
163
165
167
169
171
173
175
177
179
181
183
185
187
189
191
193
195
197
199
201


In [28]:
# Display the switching state recorded for each node time.
change

{0: 50,
 1: 50,
 2: 50,
 3: 50,
 4: 50,
 5: 50,
 6: 50,
 7: 50,
 8: 50,
 9: 50,
 10: 50,
 11: 50,
 12: 50,
 13: 50,
 14: 50,
 15: 50,
 16: 50,
 17: 50,
 18: 50,
 19: 50,
 20: 50,
 21: 50,
 22: 50,
 23: 50,
 24: 50,
 25: 50,
 26: 50,
 27: 50,
 28: 50,
 29: 50,
 30: 50,
 31: 50,
 32: 50,
 33: 50,
 34: 50,
 35: 50,
 36: 50,
 37: 50,
 38: 50,
 39: 50,
 40: 50,
 41: 50,
 42: 50,
 43: 50,
 44: 50,
 45: 50,
 46: 50,
 47: 50,
 48: 50,
 49: 50,
 50: 50,
 51: 50,
 52: 50,
 53: 50,
 54: 50,
 55: 50,
 56: 50,
 57: 50,
 58: 50,
 59: 50,
 60: 50,
 61: 50,
 62: 50,
 63: 50,
 64: 50,
 65: 50,
 66: 50,
 67: 50,
 68: 50,
 69: 50,
 70: 50,
 71: 50,
 72: 50,
 73: 50,
 74: 50,
 75: 50,
 76: 48,
 77: 47,
 78: 45,
 79: 43,
 80: 41,
 81: 39,
 82: 37,
 83: 34,
 84: 32,
 85: 30,
 86: 28,
 87: 25,
 88: 23,
 89: 20,
 90: 18,
 91: 15,
 92: 12,
 93: 8,
 94: 5,
 95: 1,
 96: -4,
 97: -9,
 98: -16,
 99: -25,
 100: -39}

In [29]:
import plotly.graph_objects as go

# Create the figure: scatter of the recorded switching state against node time
fig = go.Figure()

x = list(change.keys())    # node times
y = list(change.values())  # last state at which the policy changes at that time
fig.add_trace(go.Scatter(x=x, y=y, mode='markers'))

# NOTE(review): the title is the same as the previous figure's although this
# one plots switching states over time — confirm whether that is intended.
fig.update_layout(
    title='Policies by Node',
    xaxis_title='Time',
    yaxis_title='Value of change'
)

fig.show()


In [30]:



# non_begs: all the nodes which aren't in the equivalence class of root.
# begs: the set of all nodes equivalent to root.
# Finally, the cost functions given to the nodes should themselves be based on the same gamma.


def valueIteration(root, non_begs, begs, gamma,count):
    """Run ``count`` sweeps of value iteration on a tree whose ``begs`` nodes mirror the root.

    In every sweep the root's current cost-to-go is read, the cached values of
    the nodes that depend on a ``begs`` node are dropped, and every ``begs``
    node receives the root's values through ``change_w`` (which applies the
    ``gamma`` scaling). The next ``root.get_w()`` then recomputes the dropped
    values on top of the updated ``begs`` nodes.

    Parameters
    ----------
    root : node
        Node whose cost-to-go is iterated.
    non_begs : iterable of nodes
        Nodes that are not in the root's equivalence class.
    begs : collection of nodes
        Nodes treated as equivalent to the root.
    gamma : float
        Factor forwarded to ``change_w``.
    count : int
        Number of sweeps.

    Returns
    -------
    The root's cost-to-go after the last sweep.
    """
    # Only nodes with a `begs` node in their subtree (a node counts as its own
    # descendant) can change between sweeps. The filter has to be taken
    # against `begs`: testing against the full node list is true for every
    # node in it and ties the function to a notebook-level global.
    need_to_reset = [node for node in non_begs if node.hasDescendant(begs)]
    for _ in range(count):
        old_w = root.get_w()
        for node in need_to_reset:
            node.reset_w()
        for node in begs:
            node.change_w(old_w, gamma)
    return root.get_w()

# Nodes 0 (the root) and 1 are passed as the nodes outside the root's class ...
non_begs= [nodes[0],nodes[1]]
# ... and node 2 as the node treated as equivalent to the root.
begs = [nodes[2]]

# 1000 sweeps with gamma = 0.9; the cell displays the resulting root cost-to-go.
valueIteration(root, non_begs,begs,.9,1000)

{-50: 80.53080649912395,
 -49: 76.6027330887784,
 -48: 72.86617412015198,
 -47: 69.31178266158267,
 -46: 65.93066689640882,
 -45: 62.71436784450034,
 -44: 59.654838161184685,
 -43: 56.74442196000754,
 -42: 53.97583560826654,
 -41: 51.34214944662197,
 -40: 48.836770386323685,
 -39: 46.45342533970488,
 -38: 44.18614544158401,
 -37: 42.029251021091305,
 -36: 39.977337285198175,
 -35: 38.02526067688049,
 -34: 36.168125872392544,
 -33: 34.40127338356966,
 -32: 32.72026773241616,
 -31: 31.120886166472683,
 -30: 29.59910788459402,
 -29: 28.15110374380578,
 -28: 26.77322641884567,
 -27: 25.46200098683188,
 -26: 24.214115910236263,
 -25: 23.026414391971578,
 -24: 21.89588607692781,
 -23: 20.819659074708934,
 -22: 19.794992278625003,
 -21: 18.81926795617982,
 -20: 17.88998458635642,
 -19: 17.004749918933975,
 -18: 16.16127423086336,
 -17: 15.357363754374854,
 -16: 14.590914250980429,
 -15: 13.85990470485306,
 -14: 13.16239110820287,
 -13: 12.49650031020997,
 -12: 11.860423899799402,
 -11: 11.252