In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# initializing transition matrix
p0 = np.diag([1,1,1,1])
for i in range(4):
    p0[np.mod(i+1,4),i] = 1
p0 = 0.5*p0
p1 = np.transpose(p0)
R = np.transpose([[-1,0,0,1],[-1,0,0,1]]) # reward [R0,R1]
n,m = 100,20  # total arm, active arm
X = np.random.randint(0,4,100) # initializing states
w = np.zeros(4) # initializing whittle index
Q =np.array([R for r in R], dtype=float)  # initializing Q(x,i,u) = R(i,u)

In [3]:
class pm:
    P,R,w,X,n,m,Q = np.array([p0,p1]),R,w,X,n,m,Q
    
class Whittle():
    def __init__(self):
        self.P, self.R, self.w,self.X,self.n,self.m,self.Q = pm.P,pm.R,pm.w,pm.X,pm.n,pm.m,pm.Q
        self.t, self.e, self.dim= [1,1], 0.1,self.P[0].shape[0]
        self.v = np.ones((self.n,2)) # local clock, v(i,u)
        self.A = np.random.randint(0,1,self.n) # list of action for arms
        self.dQ = self.Q*0  # delta for Q 
        self.X_ = self.X # previous state, bookeeping
        self.choose_matrix = np.tril(np.ones((self.dim,self.dim)))-0.5*np.identity(self.dim)
    def action(self):
        self.A = np.random.randint(0,1,self.n) # list of action for arms
        self.toss = np.random.rand()
        self.index = np.arange(self.n)
        if self.toss < self.e:
            np.random.shuffle(self.index)
            self.A[self.index[:self.m]] = 1
        else:
            self.W = [(self.w[x],i) for x,i in zip(self.X,range(self.n))] # [(w0, 0),(w1,1)..]
            self.W = np.array(self.W,dtype=[('whittle',float),('index',int)])
            self.W = np.sort(self.W)[-self.m:]
            self.index = self.W['index']   #sorted whittle index
            self.A[self.index] = 1
        
    
    def evolve(self):
        self.X_ = self.X  # storing previous state
        self.X = np.array([np.random.choice([0,1,2,3],p=self.P[a][x]) for a,x in zip(self.A,self.X)])
        
    def update_w(self):
        self.dw = np.diag(self.Q[:,:,1]) - np.diag(self.Q[:,:,0])  # Q^x(x,1)-Q^x(x,0)
        self.w += 0.001*self.dw/self.t[0]  # w --> w + y(t)[Q^x(x,1)-Q^x(x,0)]
        self.t[0] += 1    
    
    def update_Q(self):
        self.f = np.mean(self.Q,axis=(1,2))
        for x_,x,a in zip(self.X_,self.X,self.A): 
            self.dQ[x_,x,a]+=(1-a)*self.w[x_]+self.R[x,a]+np.max(self.Q[x_,x])-self.f[x_]-self.Q[x_,x,a] 
        self.Q += 0.001*self.dQ
        self.t[1] += 1
        

In [10]:
W = Whittle()
W.w

array([-0.34552499, -0.0018606 , -0.07842993,  0.03974153])

In [11]:
for i in range(200):
    W.action()
    W.evolve()
    W.update_Q()
    W.update_w()
    print(W.w)

[-0.37318515 -0.00055574 -0.07724178  0.03089109]
[-3.87029174e-01  6.80304427e-05 -7.66617994e-02  2.64838522e-02]
[-0.39627591  0.000458   -0.0762855   0.02357069]
[-0.4032308   0.00072536 -0.07601244  0.02142169]
[-0.40881442  0.00091555 -0.07580295  0.01973783]
[-0.41348592  0.00105206 -0.07563802  0.01837695]
[-0.4175076   0.00114926 -0.07550668  0.01725735]
[-0.42104378  0.00121552 -0.07540168  0.01632842]
[-0.42420408  0.00125617 -0.07531777  0.01555192]
[-0.42706476  0.00127592 -0.0752513   0.0148983 ]
[-0.42968155  0.00127831 -0.0751995   0.01435249]
[-0.43209596  0.0012662  -0.07516022  0.01390162]
[-0.43433986  0.00124166 -0.07513212  0.01353666]
[-0.43643811  0.00120636 -0.07511394  0.01324629]
[-0.43841041  0.00116163 -0.07510465  0.01302161]
[-0.44027268  0.00110856 -0.07510335  0.01285478]
[-0.44203788  0.00104805 -0.0751092   0.01273782]
[-0.44371665  0.00098088 -0.07512142  0.01266656]
[-0.44531785  0.00090772 -0.07513937  0.01263577]
[-0.44684888  0.00082917 -0.075162

[-0.47577784 -0.0071186  -0.07820938  0.0163496 ]
[-0.47609157 -0.00712939 -0.07821047  0.0164021 ]
[-0.47640747 -0.00714067 -0.07821117  0.01645536]
[-0.47672507 -0.00715234 -0.07821144  0.01650909]
[-0.47704397 -0.0071644  -0.07821127  0.01656279]
[-0.47736381 -0.0071768  -0.07821064  0.01661577]
[-0.47768433 -0.00718951 -0.07820951  0.01666749]
[-0.47800528 -0.00720246 -0.07820784  0.01671748]
[-0.47832643 -0.00721556 -0.07820559  0.01676525]
[-0.47864764 -0.00722868 -0.07820273  0.01681047]
[-0.47896865 -0.0072417  -0.07819926  0.01685264]
[-0.47928764 -0.00725457 -0.07819518  0.01689177]
[-0.4796023  -0.00726724 -0.07819051  0.01692797]
[-0.47991233 -0.00727961 -0.07818526  0.01696093]
[-0.48021751 -0.00729161 -0.07817945  0.01699028]
[-0.48051756 -0.00730319 -0.07817309  0.01701568]
[-0.48081213 -0.00731429 -0.0781662   0.01703701]
[-0.48110092 -0.00732491 -0.07815881  0.01705391]
[-0.48138371 -0.00733499 -0.07815098  0.01706609]
[-0.48166029 -0.00734451 -0.07814274  0.01707338]


In [12]:
W.X_,W.X,W.A

(array([1, 1, 1, 1, 2, 1, 3, 2, 2, 1, 2, 1, 3, 2, 1, 3, 0, 1, 1, 3, 1, 1,
        3, 1, 2, 3, 1, 2, 0, 2, 1, 2, 3, 1, 2, 1, 0, 0, 1, 3, 1, 1, 1, 2,
        0, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 3, 3, 3, 0, 3, 0,
        0, 0, 3, 3, 0, 3, 0, 3, 0, 3, 3, 3, 0, 3, 0, 0, 3, 3, 0, 3, 0, 0,
        3, 0, 0, 3, 3, 0, 0, 0, 3, 0, 3, 0]),
 array([0, 1, 1, 1, 2, 0, 2, 1, 2, 0, 2, 1, 3, 1, 0, 2, 0, 0, 1, 2, 0, 0,
        3, 0, 1, 3, 0, 1, 3, 2, 1, 2, 2, 1, 2, 0, 3, 0, 1, 2, 0, 0, 1, 1,
        3, 3, 0, 3, 3, 2, 3, 2, 3, 2, 2, 3, 3, 2, 1, 0, 0, 3, 0, 0, 0, 0,
        0, 0, 0, 3, 3, 3, 3, 0, 3, 0, 3, 0, 3, 3, 3, 3, 3, 0, 3, 0, 3, 3,
        3, 0, 0, 0, 0, 3, 3, 3, 0, 0, 3, 0]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0,
        0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0,
        1, 0, 0, 1, 

In [16]:
W.Q,W.w

(array([[[-8.68713049e+01, -1.21314252e+02],
         [ 0.00000000e+00,  3.30705939e+01],
         [ 0.00000000e+00,  0.00000000e+00],
         [ 1.82070781e+02,  1.00000000e+00]],
 
        [[-5.37815329e+00, -1.00000000e+00],
         [ 6.70199998e+00,  7.29440772e+00],
         [ 0.00000000e+00,  9.79055382e-01],
         [ 1.00000000e+00,  1.00000000e+00]],
 
        [[-1.00000000e+00, -1.00000000e+00],
         [ 1.12092928e-01,  0.00000000e+00],
         [-7.88433459e+01, -7.68385328e+01],
         [ 1.00000000e+00,  1.64735923e+02]],
 
        [[-1.00000000e+00, -1.13844211e+01],
         [ 0.00000000e+00,  0.00000000e+00],
         [-6.81344931e+00,  0.00000000e+00],
         [ 4.19843355e+01,  3.26782209e+01]]]),
 array([-0.48456672, -0.00739023, -0.07801537,  0.01675545]))

In [8]:

l = [(np.random.rand(), np.random.rand()) for i in range(10)]
L = np.array(l, dtype=[('a',float),('b',float)])
np.sort(L, order=('b','a'))

array([(0.75680739, 0.09481884), (0.09338645, 0.13626824),
       (0.37327465, 0.19624388), (0.47076533, 0.27571411),
       (0.33163505, 0.35486515), (0.83120935, 0.72459937),
       (0.75145607, 0.78948953), (0.55759471, 0.85157688),
       (0.60237792, 0.86379351), (0.98701374, 0.94648353)],
      dtype=[('a', '<f8'), ('b', '<f8')])

In [17]:
L

NameError: name 'L' is not defined

In [25]:
l =W.w
L = np.array(list(zip(l,l)), dtype=[('a',float),('b',float)])
M= list(zip(L['a'],L['b'],l))
M

[(-0.4845667242018684, -0.4845667242018684, -0.4845667242018684),
 (-0.007390231897007059, -0.007390231897007059, -0.007390231897007059),
 (-0.07801536609768217, -0.07801536609768217, -0.07801536609768217),
 (0.01675545234490244, 0.01675545234490244, 0.01675545234490244)]

In [26]:
L

array([(-0.48456672, -0.48456672), (-0.00739023, -0.00739023),
       (-0.07801537, -0.07801537), ( 0.01675545,  0.01675545)],
      dtype=[('a', '<f8'), ('b', '<f8')])

In [105]:
np.eye(4)[0,2]

0.0

In [106]:
x_,x,a =0,0,0
(1-a)*W.w[x_]+W.R[a,x_]+np.max(W.Q[x_,x])-W.f[x_]-W.Q[x_,x,a]

-1.0000000000000002

In [107]:
W.w[x_]

0.0

In [108]:
W.R[a,x_]

-1

In [109]:
W.Q[x_,x,a]

-1.24

In [110]:
np.array(W.R)[0,0]

-1

In [111]:
[r for r in R]

[array([-1, -1]), array([0, 0]), array([0, 0]), array([1, 1])]

In [112]:
Q = np.zeros((4,4,2))

In [113]:
np.array([R for i in range(4)])

array([[[-1, -1],
        [ 0,  0],
        [ 0,  0],
        [ 1,  1]],

       [[-1, -1],
        [ 0,  0],
        [ 0,  0],
        [ 1,  1]],

       [[-1, -1],
        [ 0,  0],
        [ 0,  0],
        [ 1,  1]],

       [[-1, -1],
        [ 0,  0],
        [ 0,  0],
        [ 1,  1]]])

In [207]:
Q +=Q

In [208]:
Q

array([[[-2, -2],
        [ 0,  0],
        [ 0,  0],
        [ 2,  2]],

       [[-2, -2],
        [ 0,  0],
        [ 0,  0],
        [ 2,  2]],

       [[-2, -2],
        [ 0,  0],
        [ 0,  0],
        [ 2,  2]],

       [[-2, -2],
        [ 0,  0],
        [ 0,  0],
        [ 2,  2]]])

In [9]:
T = np.random.randint(0,4,(4,3,2))
T

array([[[0, 3],
        [0, 0],
        [0, 0]],

       [[3, 1],
        [2, 0],
        [0, 1]],

       [[1, 3],
        [0, 3],
        [2, 0]],

       [[0, 3],
        [0, 3],
        [1, 1]]])

In [16]:
np.max([T[:,1,0] ,T[:,1,1]],axis=0)

array([0, 2, 3, 3])

In [12]:
T[:,1,0],T[:,1,1]

(array([0, 2, 0, 0]), array([0, 0, 3, 3]))

In [20]:
np.max(T[:,1],axis=1)

array([0, 2, 3, 3])

In [25]:
np.zeros(T.shape,dtype=float)

array([[[0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 0.],
        [0., 0.]]])