In [1]:
import matplotlib.pyplot as plt
import numpy as np
import itertools
import random 
import pickle
from matplotlib.pyplot import figure
import time
from tqdm import tqdm

from stoc_match import StochasticMatching, Aggregation

In [2]:
# Vertices (customer/supply classes) of the diamond matching graph.
verteces = list('ABCD')

# Adjacency list: vertex -> [(neighbour, edge weight), ...].
# NOTE(review): the edge weight is presumably the reward earned when matching
# along that edge — confirm against StochasticMatching.
graph = {
    'A': [('B', 200), ('C', 10), ('D', 50)],
    'B': [('A', 200), ('D', 20)],
    'C': [('A', 10), ('D', 1)],
    'D': [('A', 50), ('B', 20), ('C', 1)],
}

# Arrival rate (lambda) of each vertex class.
arrival_rates = dict(A=0.9, B=0.2, C=0.5, D=0.6)

# Hyper-parameters forwarded to the aggregation update.
eta = 0.5
learning_rate = 0.00014  # NOTE: reassigned to 0.1 in the setup cell further down
discount = 0.8
H = 55
k = 10

# Matching policies being aggregated.
# (Other experts tried: probability_match_05 / _02 / _08.)
experts = ['match_the_longest', 'edge_priority_match_reward', 'random_match']

# Per-expert bookkeeping; every expert gets its own fresh list.
rewards = {expert_name: [] for expert_name in experts}
models = {expert_name: None for expert_name in experts}

# Maximum queue length tracked per vertex.
queue_max = 5

# Queue contents per vertex, globally and separately for each expert.
queues = {vertex: [] for vertex in verteces}
experts_queues = {
    expert_name: {vertex: [] for vertex in verteces}
    for expert_name in experts
}

The diamond graph is stable under every greedy policy provided that the following stability conditions are satisfied:

$\lambda_A < \lambda_B + \lambda_C + \lambda_D$ <br>
$\lambda_D < \lambda_C + \lambda_A + \lambda_B$ <br>
$\lambda_C + \lambda_B < \lambda_A + \lambda_D$

The arrival rates chosen above satisfy all three: $0.9 < 1.3$, $0.6 < 1.6$ and $0.7 < 1.5$.

# Aggregation

In [3]:
# Initial values handed to Aggregation(...) in the training loop.
# NOTE: pickle.load can execute arbitrary code — only load trusted, locally produced files.
with open('Q/Q_start.pkl', mode="rb") as q_file:
    Q_start = pickle.load(q_file)

In [4]:
# --- Experiment setup ---------------------------------------------------------
# NOTE(review): the training loop re-creates `weights_dict`, `model_agg` and
# `model` for every repetition (and passes Q_start to Aggregation there), so the
# instances built here act only as placeholders.
weights_dict = {}
model_agg = Aggregation(graph, arrival_rates, experts, eta, queue_max)
model = StochasticMatching(graph, arrival_rates, queue_max)

# Every joint queue-length configuration: one length in 0..queue_max per vertex.
state_space = list(itertools.product(np.arange(queue_max + 1), repeat=len(graph)))
num_states = len(state_space)
# String index -> state tuple, following the enumeration order of `state_space`.
state_space_ind = {str(idx): state for idx, state in enumerate(state_space)}

state_space_dict = {s: 0 for s in state_space}
len_updates = 50  # the training loop performs len_updates + 1 updates per repetition
# Snapshot stride for the sparse phase of the training loop. Clamped to 1 so that
# `i % res` can never divide by zero when len_updates < 10.
res = max(1, len_updates // 10)
num_repeat_est = 10

queues_lenght = {v: [] for v in verteces}

K = len(experts)
# Uniform initial weights over the K experts, in the same (num_states, K) array
# layout that the training loop stores under this key.
weights_dict['-1'] = np.ones([num_states, K]) / K

repeat_weights = 10  # number of independent training repetitions
# Deliberately overrides the learning_rate = 0.00014 set in the configuration cell.
learning_rate = 0.1  # try even bigger

In [5]:
%%time 
# Seed both RNGs used by the experiment so the run is reproducible.
random.seed(6)
np.random.seed(6)

# Snapshot schedule: every DENSE_EVERY updates while i <= DENSE_UNTIL, then every
# `res` updates afterwards. With len_updates = 50 only the dense phase is reached.
DENSE_UNTIL = 200
DENSE_EVERY = 10

for n in range(repeat_weights):
    # Fresh aggregator, environment and weight history for each repetition.
    weights_dict = {}
    model_agg = Aggregation(graph, arrival_rates, experts, eta, queue_max, Q_start)
    model = StochasticMatching(graph, arrival_rates, queue_max)
    weights_dict['-1'] = np.ones([num_states, K]) / K  # uniform initial weights
    time_expw = []  # cumulative update time recorded at each snapshot
    time_tot = 0    # cumulative wall-clock time spent inside the update calls
    for i in tqdm(range(len_updates + 1)):
        start_time = time.time()
        weights = model_agg.aggregation_update_exp_NN(discount, learning_rate, model, H, k, num_repeat_est)
        end_time = time.time()
        time_tot += end_time - start_time

        # `i > DENSE_UNTIL` is tested before `i % res`, so a zero stride cannot
        # raise ZeroDivisionError during the dense phase.
        is_snapshot = (i <= DENSE_UNTIL and i % DENSE_EVERY == 0) or (i > DENSE_UNTIL and i % res == 0)
        if is_snapshot:
            time_expw.append(time_tot)
            weights_dict[str(i)] = weights.copy()
            # The value is unused, but the draw advances the seeded `random` state;
            # it is kept so results stay identical to earlier runs.
            random_state = state_space[random.randint(0, len(state_space)-1)]
    with open('weights/weights_expw_NN7' + str(n) + '.pkl', 'wb') as output:
        pickle.dump(weights_dict, output)

100%|███████████████████████████████████████████| 51/51 [41:25<00:00, 48.74s/it]
100%|██████████████████████████████████████| 51/51 [18:22:56<00:00, 1297.57s/it]
100%|█████████████████████████████████████████| 51/51 [1:14:13<00:00, 87.32s/it]
100%|███████████████████████████████████████████| 51/51 [44:10<00:00, 51.98s/it]
100%|███████████████████████████████████████████| 51/51 [43:50<00:00, 51.57s/it]
100%|███████████████████████████████████████████| 51/51 [43:22<00:00, 51.03s/it]
100%|███████████████████████████████████████████| 51/51 [42:32<00:00, 50.05s/it]
100%|███████████████████████████████████████████| 51/51 [42:34<00:00, 50.10s/it]
100%|███████████████████████████████████████████| 51/51 [42:17<00:00, 49.75s/it]
100%|███████████████████████████████████████████| 51/51 [42:23<00:00, 49.87s/it]

CPU times: user 8h 30min 55s, sys: 14h 44min 26s, total: 23h 15min 22s
Wall time: 1d 1h 19min 46s





In [6]:
# Inspect the attribute `A` of the aggregator from the last training repetition,
# from row 1000 onward (the state space has (queue_max + 1) ** 4 = 1296 states).
# NOTE(review): presumably one row per state and one column per expert (the
# output shows K = 3 columns) — confirm against Aggregation.
model_agg.A[1000:] 

array([[ 3.58822545e-01,  6.87531701e-01, -4.55767203e-05],
       [ 3.45669218e-01,  6.91564509e-01, -6.09670874e-05],
       [ 4.33044783e-01,  7.02217116e-01, -8.62097184e-06],
       [ 4.27180729e-01,  7.06207364e-01, -9.84281359e-06],
       [ 4.12351918e-01,  7.07811427e-01, -1.37570786e-05],
       [ 4.07490884e-01,  7.27320944e-01, -1.53439196e-05],
       [ 3.89716043e-01,  7.37377449e-01, -2.28744038e-05],
       [ 3.85433234e-01,  7.50343628e-01, -2.51756055e-05],
       [ 2.79553880e-01,  4.42507465e-01, -2.62985999e-04],
       [ 2.49582742e-01,  4.14714251e-01, -4.93760591e-04],
       [ 2.09393163e-01,  4.00575926e-01, -1.10269609e-03],
       [ 1.97143446e-01,  4.20789341e-01, -1.38895690e-03],
       [ 2.04537499e-01,  4.46073222e-01, -1.20021098e-03],
       [ 2.14291356e-01,  4.68469493e-01, -9.89176690e-04],
       [ 3.25240505e-01,  4.96272398e-01, -9.75006644e-05],
       [ 2.99174212e-01,  4.91701416e-01, -1.70447149e-04],
       [ 2.62674958e-01,  4.68690456e-01