In [1]:
import matplotlib.pyplot as plt
import numpy as np
import itertools
import random 
import pickle
from matplotlib.pyplot import figure
import time
from tqdm import tqdm

from stoc_match import StochasticMatching, Aggregation

In [2]:
# Vertex labels of the matching graph. The identifier spelling is kept because
# other cells refer to it by this name.
verteces = list('ABCD')

# Adjacency lists of the diamond graph: vertex -> [(neighbour, edge value), ...].
# Every edge is listed from both endpoints with the same value.
graph = dict(
    A=[('B', 200), ('C', 10), ('D', 50)],
    B=[('A', 200), ('D', 20)],
    C=[('A', 10), ('D', 1)],
    D=[('A', 50), ('B', 20), ('C', 1)],
)

# Arrival rate attached to each vertex, in the same order as `verteces`.
arrival_rates = dict(zip(verteces, (0.9, 0.2, 0.5, 0.6)))

# Hyper-parameters handed to the aggregation model and its update call.
# Note that `learning_rate` is reassigned in a later cell before it is used.
eta, learning_rate, discount = 0.5, 0.00014, 0.8
H, k = 55, 10

# Names of the expert policies that are aggregated.
# Alternative names from a commented-out configuration:
# 'probability_match_05', 'probability_match_02', 'probability_match_08'.
experts = [
    'match_the_longest',
    'edge_priority_match_reward',
    'random_match',
]

# Per-expert containers: an independent reward list and an empty model slot.
rewards = {name: [] for name in experts}
models = dict.fromkeys(experts)

# Upper bound used for the queue lengths.
queue_max = 5

# One independent, initially empty queue per vertex ...
queues = {vertex: [] for vertex in verteces}

# ... and one such set of queues for every expert.
experts_queues = {
    name: {vertex: [] for vertex in verteces}
    for name in experts
}

The diamond graph is stable for every greedy policy provided that the following stability conditions are satisfied:

$\lambda_A < \lambda_B + \lambda_C + \lambda_D$ <br>
$\lambda_D < \lambda_C + \lambda_A + \lambda_B$ <br>
$\lambda_C + \lambda_B < \lambda_A + \lambda_D$

# Aggregation

In [3]:
# Load the pickled `Q_start` object from disk; it is later passed as the last
# argument to the Aggregation constructor in the training loop.
# NOTE(review): pickle.load can execute arbitrary code while deserialising, so
# only load files that come from a trusted source.
with open('Q/Q_start.pkl', "rb") as fp:
    Q_start = pickle.load(fp)

In [4]:
# Snapshot container plus one aggregation model and one plain matching model.
weights_dict = dict()
model_agg = Aggregation(graph, arrival_rates, experts, eta, queue_max)
model = StochasticMatching(graph, arrival_rates, queue_max)

# Every joint queue-length vector: one coordinate per vertex of `graph`,
# each coordinate ranging over 0..queue_max.
state_space = list(itertools.product(np.arange(queue_max + 1), repeat=len(graph)))
num_states = len(state_space)

# Lookup from the string form of a state's position to the state tuple.
state_space_ind = {str(idx): state for idx, state in enumerate(state_space)}

# Zero-initialised value for every state.
state_space_dict = dict.fromkeys(state_space, 0)

# Number of updates, snapshot spacing and number of estimation repeats.
len_updates = 500
res = len_updates // 20
num_repeat_est = 10

# One independent, initially empty list per vertex.
queues_lenght = {vertex: [] for vertex in verteces}

# Uniform initial weights over the K experts, stored per state index under '-1'.
K = len(experts)
weights_dict['-1'] = {state_key: np.full(K, 1 / K) for state_key in state_space_ind}

repeat_weights = 10
# Overrides the earlier learning_rate of 0.00014 (try even bigger).
learning_rate = 0.1

In [None]:
# Run `repeat_weights` independent training repetitions. Each repetition builds
# fresh models, performs `len_updates + 1` aggregation updates, stores weight
# snapshots at selected iterations and pickles them to a per-repetition file.
# (To time the whole cell instead, `%%time` would have to be the first line of
# the cell; it is left disabled.)
random.seed(6)
np.random.seed(6)

for n in range(repeat_weights):
    # Fresh snapshot store and models for this repetition.
    weights_dict = {}
    model_agg = Aggregation(graph, arrival_rates, experts, eta, queue_max, Q_start)
    model = StochasticMatching(graph, arrival_rates, queue_max)
    weights_dict['-1'] = np.ones([num_states, K]) / K

    time_expw = []  # cumulative update time recorded at each snapshot
    time_tot = 0    # cumulative time spent inside the update call

    for i in tqdm(range(len_updates + 1)):
        # The update runs on every iteration. perf_counter is monotonic, so
        # the measured duration cannot be distorted by system clock changes.
        start_time = time.perf_counter()
        weights = model_agg.aggregation_update_exp_NN(discount, learning_rate, model, H, k, num_repeat_est)
        time_tot += time.perf_counter() - start_time

        # Snapshot every 10th iteration up to 200, then every `res`-th one.
        take_snapshot = (i % 10 == 0) if i <= 200 else (i % res == 0)
        if take_snapshot:
            time_expw.append(time_tot)
            weights_dict[str(i)] = weights.copy()
            # `random_state` is not read anywhere in this cell; the draw is kept
            # so the `random` module's stream is consumed exactly as before.
            random_state = state_space[random.randint(0, len(state_space)-1)]

    with open('weights/weights_expw_NN6' + str(n) + '.pkl', 'wb') as output:
        pickle.dump(weights_dict, output)

In [6]:
# Display the rows of `model_agg.A` from index 1000 onward; as the last
# expression of the cell it is rendered as the cell output.
# NOTE(review): `A` is an attribute of the Aggregation model from stoc_match;
# presumably one row per state and one column per expert — confirm.
model_agg.A[1000:] 

array([[ 1.80748641e-01,  3.35094810e-01,  0.00000000e+00],
       [ 2.33622015e-01,  3.36320639e-01,  0.00000000e+00],
       [-9.18712919e-04,  3.40527806e-01,  1.50952070e-03],
       [ 4.26356941e-02,  3.45543459e-01, -6.95892170e-06],
       [ 9.58945748e-02,  3.66562009e-01, -2.91801916e-10],
       [ 1.51603341e-01,  3.87423277e-01, -5.32907052e-15],
       [ 2.10001588e-01,  4.14266586e-01,  0.00000000e+00],
       [ 2.52725482e-01,  4.09244061e-01,  0.00000000e+00],
       [ 1.12862873e-03,  8.80290418e-02, -7.66083719e-04],
       [ 2.09193327e-02,  4.25394275e-02, -2.80668559e-04],
       [ 5.36583903e-02,  1.46793964e-02, -6.40666241e-04],
       [ 9.20666461e-02,  2.11006451e-03, -1.09547328e-03],
       [ 1.33689875e-01,  6.71761698e-03, -1.30284363e-03],
       [ 1.77313357e-01,  1.88278844e-02, -3.71248475e-04],
       [ 2.31691401e-02,  1.42480473e-01, -1.94926934e-04],
       [ 5.73694271e-02,  1.18746952e-01, -4.60764632e-07],
       [ 7.78737565e-02,  9.35907861e-02