In [1]:
import matplotlib.pyplot as plt
from cvxopt import matrix, solvers
import numpy as np
import random
from SoccerWorld import player, game

In [10]:
def foe_action(qtable, player_state, rand_rate, num_actions, pi):

    if np.random.rand() < rand_rate:
        action = np.random.randint(0, num_actions)
    else:
#         print (np.sum(pi))
        action = random.choices([0,1,2,3,4], weights=pi)[0]

    return action


def lp(qtable, state):
    num_actions = 5
    c = matrix([0., 0., 0., 0., 0., -1.])
    
    G = qtable[state, :, :]
    G = np.hstack((G, np.ones((num_actions, 1)) * -1.))
    G = -G
    q = np.ones(num_actions+1) * 1.
    q[-1] = 0
    G = np.vstack((G, q))
    q = np.ones(num_actions+1) * -1.
    q[-1] = 0
    G = np.vstack((G, q))
    q = np.hstack((np.eye(num_actions)*-1., np.zeros((num_actions,1))))
    G = np.vstack((G,q))
    G = matrix(G)
    c = np.zeros((num_actions+1))
    c[-1] = -1
    c = matrix(c)
    b = np.zeros((num_actions))
    b = np.hstack((b, [1,-1]))
    b = np.hstack((b, np.zeros(num_actions)))
    b = matrix(b)
    sol = solvers.lp(c,G,b, solver='glpk')
    probs = np.array(sol['x'][:5]).reshape(1,5)[0]
    val = sol['x'][5]
    return (probs, val)
    
    
    
    
def foe_q_learning():
    
    errors = []
    qs = []
    #hyperparameters
    num_actions = 5
    num_rows = 2
    num_cols = 4
    num_states = num_rows * num_cols * 2 #2 is for possession
#     n = 1000000
    n = 1000000
    gamma = 0.9
    alpha = 0.5
    alpha_dec = 0.99
    rar = 0.5
    radr = 0.99
    rand_rate = rar
    qtableA = np.zeros(shape=(num_states, num_actions, num_actions), dtype='float')
    qtableB = np.zeros(shape=(num_states, num_actions, num_actions), dtype='float')
    verbose = False
    
    
    V1 = np.ones(num_states) * 1.
    V2 = np.ones(num_states) * 1.
    
    pi1 = np.ones(shape=(num_states, num_actions)) * 1./5
    pi2 = np.ones(shape=(num_states, num_actions)) * 1./5
    
    i = 0
    errors.append(0)
    qs.append(0)
    ind = []
    ind.append(0)
    while i < n:
        match = game(verbose=verbose, rows=2, columns=4, goalRstart=0, \
                     goalRend=1, aGoal=0, bGoal=3,aPosition=[0,2],\
                     bPosition=[0,1])
        
        initial_state = match.reset()
        
        state, score, done = match.move(4, 4)
        
        
        rand_rate *= radr
        
        if rand_rate < 0.05:
            rand_rate = 0.05
            


        while not done:

            old_state = state
            pA_act = foe_action(qtableA, state[0], rand_rate, num_actions, pi1[state[0]])
            pB_act = foe_action(qtableB, state[1], rand_rate, num_actions, pi2[state[1]])

            state, score, done = match.move(pA_act, pB_act)

   
            
                
            qtableA[old_state[0], pA_act, pB_act] = (1-alpha) * qtableA[old_state[0], pA_act, pB_act] + alpha * (score[0] + gamma * V1[state[0]])
            probs, val = lp(qtableA, old_state[0])
            pi1[old_state[0], :] = probs
            V1[old_state[0]] = val
            
            qtableB[old_state[1], pB_act, pA_act] = (1-alpha) * qtableB[old_state[1], pB_act, pA_act] + alpha * (score[1] + gamma * V2[state[1]])        
            probs, val = lp(qtableB, old_state[1])
            pi1[old_state[1], :] = probs
            V1[old_state[1]] = val
            
            if old_state[0] == 2 and pA_act == 1 and pB_act == 4:
                    
                error = abs(qtableA[2, 1, 4] - qs[-1])
                qs.append(qtableA[2,1,4])
                print (i, qtableA[2, 1, 4], error)
                errors.append(error)
                ind.append(i)
                if alpha < 0.001:
                    alpha = 0.001
                alpha *= alpha_dec
            i += 1
        
        if i % 10000 == 0:
            print (i)
    return errors, ind
        
        

In [None]:
errr, ind = foe_q_learning()

121 0.0 0.0
202 0.0 0.0
281 0.0 0.0
284 0.0 0.0
289 0.0 0.0
339 0.0 0.0
351 0.0 0.0
360 0.0 0.0
394 0.0 0.0
523 0.0 0.0
541 0.0 0.0
557 0.0 0.0
570 0.0 0.0
605 0.0 0.0
689 0.08955421041312171 0.08955421041312171
692 0.13973386472190089 0.05017965430877917
716 0.0802451222206096 0.059488742501291283
721 0.046424082431547525 0.03382103978906208
723 0.027053314605637705 0.01937076782590982
766 0.015878014756672266 0.01117529984896544
773 0.00938464554429554 0.006493369212376726
775 0.0055851410289982605 0.0037995045152972792
796 0.003346531081130966 0.0022386099478672945
837 0.0020186036478094765 0.0013279274333214894
849 0.001225617267290667 0.0007929863805188095
866 0.0007489616228171202 0.0004766556444735469
945 0.00046059524075640554 0.00028836638206071463
1035 0.06597167249200034 0.06551107725124393
1201 0.04107662567260335 0.024895046819396992
1243 0.09011045022594806 0.049033824553344714
1256 0.05678308338678815 0.03332736683915991
1258 0.035991861735032354 0.020791221651755794
126

85838 25.53743006318352 0.6556778865844528
97811 24.963695624981654 0.5737344382018676
98006 24.43652188443909 0.5271737405425618
109811 25.4555264412901 1.0190045568510087
110000
126939 25.265165142942774 0.19036129834732662
130781 26.20668164166562 0.9415164987228479
130993 25.67943742671513 0.5272442149504926
131021 26.574185390607088 0.8947479638919589
131317 26.03281431913798 0.5413710714691078
131598 32.050767991246516 6.017953672108536
137139 30.400651988913765 1.6501160023327515
137590 28.848487272792266 1.5521647161214993
137810 27.411726268028065 1.4367610047642003
142867 26.280545473192582 1.1311807948354833
144942 25.16386680957908 1.1166786636135022
151116 24.351239776660062 0.8126270329190177
151150 25.48103507374091 1.1297952970808467
152022 26.530393044682743 1.0493579709418341
152709 27.50568270427823 0.9752896595954859
153903 26.543826022171615 0.9618566821066139
154109 27.498909783746836 0.9550837615752208
154293 26.556592195684395 0.9423175880624406
154424 27.491926

494763 28.082479475733265 0.1153016108115068
495156 28.309508884911324 0.22702940917805847
496037 28.195225441933154 0.11428344297817006
498058 29.21027264675922 1.0150472048260646
499597 29.08817824538657 0.12209440137264949
500707 28.966508605913972 0.12166963947259646
502524 28.843246887868936 0.12326171804503616
503517 28.720559509687938 0.12268737818099851
503638 28.600418519969 0.12014098971893716
503940 28.482758146891786 0.11766037307721433
504994 29.433741398505084 0.9509832516132981
506695 29.30972619255528 0.12401520594980298
507902 29.18823237739361 0.12149381516167068
509302 29.069196137964017 0.11903623942959385
511055 28.952555587432883 0.11664055053113387
515828 29.846130553460256 0.8935749660273729
521106 30.023886251851593 0.1777556983913371
521208 30.19811794942191 0.17423169757031687


In [109]:
errr

[0,
 0.51597,
 0.34312005,
 0.4022076647834143,
 0.3451280157324509,
 0.3307296769090974,
 0.13505456944318706,
 0.05880757781835133,
 0.027119931508990613,
 0.013169593272783997,
 0.006701023738605039,
 0.0035574672415634723,
 0.001963152492795106,
 0.001122428390983865,
 0.0006629741614430085,
 0.00040350431613567217,
 0.2624336564079489,
 0.0810926831938521,
 0.05330349911757437,
 0.035817285010297684,
 0.1893941189499051,
 0.07086369280459026,
 0.05042709003794188,
 0.03648533689416622,
 0.14746008715660863,
 0.05561040288574359,
 0.04203186059912967,
 0.03217689965621573,
 0.02492934701478622,
 0.01953267139766368,
 0.015466850552297795,
 0.012369658045438103,
 0.009985594747841775,
 0.10170212374478393,
 0.020826751672544108,
 0.017236693902946465,
 0.014370949291287238,
 0.012065195860264333,
 0.010196017376400524,
 0.00866990789041533,
 10.28007583417251,
 0.9177356406926709,
 0.7934186347405596,
 0.7550939035060136,
 0.5966263684620179,
 0.5230791810035509,
 0.4605146322248599