In [95]:
import matplotlib.pyplot as plt
from cvxopt import matrix, solvers
import numpy as np
import random
from SoccerWorld import player, game

In [107]:
def foe_action(qtable, player_state, rand_rate, num_actions, pi):

    if np.random.rand() < rand_rate:
        action = np.random.randint(0, num_actions)
    else:
#         print (np.sum(pi))
        action = random.choices([0,1,2,3,4], weights=pi)[0]

    return action


def lp(qtable, state):
    num_actions = 5
    c = matrix([0., 0., 0., 0., 0., -1.])
    
    G = qtable[state, :, :]
    G = np.hstack((G, np.ones((num_actions, 1)) * -1.))
    G = -G
    q = np.ones(num_actions+1) * 1.
    q[-1] = 0
    G = np.vstack((G, q))
    q = np.ones(num_actions+1) * -1.
    q[-1] = 0
    G = np.vstack((G, q))
    q = np.hstack((np.eye(num_actions)*-1., np.zeros((num_actions,1))))
    G = np.vstack((G,q))
    G = matrix(G)
    c = np.zeros((num_actions+1))
    c[-1] = -1
    c = matrix(c)
    b = np.zeros((num_actions))
    b = np.hstack((b, [1,-1]))
    b = np.hstack((b, np.zeros(num_actions)))
    b = matrix(b)
    sol = solvers.lp(c,G,b, solver='glpk')
    probs = np.array(sol['x'][:5]).reshape(1,5)[0]
    val = sol['x'][5]
    return (probs, val)
    
    
    
    
def foe_q_learning():
    
    errors = []
    qs = []
    #hyperparameters
    num_actions = 5
    num_rows = 2
    num_cols = 4
    num_states = num_rows * num_cols * 2 #2 is for possession
#     n = 1000000
    n = 1000000
    gamma = 0.9
    alpha = 0.7
    alpha_dec = 0.95
    rar = 0.5
    radr = 0.99
    rand_rate = rar
    qtableA = np.zeros(shape=(num_states, num_actions, num_actions), dtype='float')
    qtableB = np.zeros(shape=(num_states, num_actions, num_actions), dtype='float')
    verbose = False
    
    
    V1 = np.ones(num_states) * 1.
    V2 = np.ones(num_states) * 1.
    
    pi1 = np.ones(shape=(num_states, num_actions)) * 1./5
    pi2 = np.ones(shape=(num_states, num_actions)) * 1./5
    
    i = 0
    errors.append(0)
    qs.append(0)
    while i < n:
        match = game(verbose=verbose, rows=2, columns=4, goalRstart=0, \
                     goalRend=1, aGoal=0, bGoal=3,aPosition=[0,2],\
                     bPosition=[0,1])
        
        initial_state = match.reset()
        
        state, score, done = match.move(4, 4)
        
        
        rand_rate *= radr
        
        if rand_rate < 0.01:
            rand_rate = 0.1
            


        while not done:

            old_state = state
            pA_act = foe_action(qtableA, state[0], rand_rate, num_actions, pi1[state[0]])
            pB_act = foe_action(qtableB, state[1], rand_rate, num_actions, pi2[state[1]])

            state, score, done = match.move(pA_act, pB_act)

   
            
                
            qtableA[old_state[0], pA_act, pB_act] = (1-alpha) * qtableA[old_state[0], pA_act, pB_act] + alpha * (score[0] + gamma * V1[state[0]])
            probs, val = lp(qtableA, old_state[0])
            pi1[old_state[0], :] = probs
            V1[old_state[0]] = val
            
            qtableB[old_state[1], pB_act, pA_act] = (1-alpha) * qtableB[old_state[1], pB_act, pA_act] + alpha * (score[1] + gamma * V2[state[1]])        
            probs, val = lp(qtableB, old_state[1])
            pi1[old_state[1], :] = probs
            V1[old_state[1]] = val
            
            if old_state[0] == 2 and pA_act == 1 and pB_act == 4:
                    
                error = abs(qtableA[2, 1, 4] - qs[-1])
                qs.append(qtableA[2,1,4])
                print (i, qtableA[2, 1, 4], error)
                errors.append(error)
                if alpha < 0.001:
                    alpha = 0.001
                alpha *= alpha_dec
            i += 1
        
        if i % 10000 == 0:
            print (i)
    return errors
        
        

In [108]:
errr = foe_q_learning()

689 0.51597 0.51597
745 0.17284995000000006 0.34312005
1426 0.5750576147834143 0.4022076647834143
9462 0.22992959905096347 0.3451280157324509
10347 0.5606592759600608 0.3307296769090974
12648 0.6957138454032479 0.13505456944318706
20599 0.7545214232215992 0.05880757781835133
20701 0.7816413547305898 0.027119931508990613
20803 0.7948109480033738 0.013169593272783997
21280 0.8015119717419789 0.006701023738605039
31433 0.8050694389835423 0.0035574672415634723
31696 0.8070325914763374 0.001963152492795106
34657 0.8081550198673213 0.001122428390983865
35270 0.8088179940287643 0.0006629741614430085
37966 0.8092214983449 0.00040350431613567217
37987 0.5467878419369511 0.2624336564079489
40928 0.6278805251308032 0.0810926831938521
43622 0.6811840242483775 0.05330349911757437
44421 0.7170013092586752 0.035817285010297684
45143 0.5276071903087701 0.1893941189499051
51335 0.5984708831133604 0.07086369280459026
51530 0.6488979731513023 0.05042709003794188
53488 0.6853833100454685 0.036485336894166

In [109]:
errr

[0,
 0.51597,
 0.34312005,
 0.4022076647834143,
 0.3451280157324509,
 0.3307296769090974,
 0.13505456944318706,
 0.05880757781835133,
 0.027119931508990613,
 0.013169593272783997,
 0.006701023738605039,
 0.0035574672415634723,
 0.001963152492795106,
 0.001122428390983865,
 0.0006629741614430085,
 0.00040350431613567217,
 0.2624336564079489,
 0.0810926831938521,
 0.05330349911757437,
 0.035817285010297684,
 0.1893941189499051,
 0.07086369280459026,
 0.05042709003794188,
 0.03648533689416622,
 0.14746008715660863,
 0.05561040288574359,
 0.04203186059912967,
 0.03217689965621573,
 0.02492934701478622,
 0.01953267139766368,
 0.015466850552297795,
 0.012369658045438103,
 0.009985594747841775,
 0.10170212374478393,
 0.020826751672544108,
 0.017236693902946465,
 0.014370949291287238,
 0.012065195860264333,
 0.010196017376400524,
 0.00866990789041533,
 10.28007583417251,
 0.9177356406926709,
 0.7934186347405596,
 0.7550939035060136,
 0.5966263684620179,
 0.5230791810035509,
 0.4605146322248599