In [1]:
import itertools
import os
import pickle
import random
import time

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import figure
from tqdm import tqdm

from stoc_match import StochasticMatching, Aggregation

In [2]:
# Vertex classes of the diamond matching graph.
verteces = ['A', 'B', 'C', 'D']

# Adjacency list: each vertex maps to (neighbour, weight) pairs. The weights are
# presumably the matching rewards used by reward-based experts such as
# 'edge_priority_match_reward' — TODO confirm against stoc_match.
graph = {
    'A': [('B', 200), ('C', 10), ('D', 50)],
    'B': [('A', 200), ('D', 20)],
    'C': [('A', 10), ('D', 1)],
    'D': [('A', 50), ('B', 20), ('C', 1)]
}

# Arrival rate (lambda) of each vertex class.
arrival_rates = {
    'A': 0.9,
    'B': 0.2,
    'C': 0.5,
    'D': 0.6
}

# Every vertex must appear in the graph and have an arrival rate.
assert set(verteces) == set(graph) == set(arrival_rates), "vertex sets disagree"

# Hyper-parameters forwarded to stoc_match (Aggregation / aggregation_update_exp);
# their precise meaning is defined there.
eta = 0.5
discount = 0.8
H = 55
k = 10

# Expert policies being aggregated. An alternative set tried earlier was:
# ['match_the_longest', 'probability_match_05', 'probability_match_02', 'probability_match_08']
experts = ['match_the_longest', 'edge_priority_match_reward', 'random_match']

rewards = {exp: [] for exp in experts}
models = {exp: None for exp in experts}

# Largest queue length considered: the state space enumerates 0..queue_max per vertex.
queue_max = 5

queues = {v: [] for v in verteces}

# One independent set of per-vertex queues for each expert. The keys are derived from
# `verteces` so they cannot drift out of sync with the vertex list.
experts_queues = {exp: {v: [] for v in verteces} for exp in experts}

The diamond graph (edges A–B, A–C, A–D, B–D and C–D) is stable under every greedy matching policy provided that the following stability conditions are satisfied:

$\lambda_A < \lambda_B + \lambda_C + \lambda_D$ <br>
$\lambda_D < \lambda_C + \lambda_A + \lambda_B$ <br>
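$\lambda_C + \lambda_B < \lambda_A + \lambda_D$

With the arrival rates used here ($\lambda_A = 0.9$, $\lambda_B = 0.2$, $\lambda_C = 0.5$, $\lambda_D = 0.6$) all three hold: $0.9 < 1.3$, $0.6 < 1.6$ and $0.7 < 1.5$.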

# Aggregation

In [7]:
# Load the initial Q-table that warm-starts Aggregation in the training loop.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open('Q/Q_start.pkl', mode="rb") as q_file:
    Q_start = pickle.load(q_file)

In [8]:
# Display the warm-start Q-table loaded from Q/Q_start.pkl (it is passed to Aggregation in the training loop).
Q_start

array([[118.76412123, 118.76412123, 118.76412123],
       [124.41642537, 124.41642537, 124.41642537],
       [127.49595968, 127.49595968, 127.49595968],
       ...,
       [528.99549619, 528.99549619, 528.99549619],
       [523.41556902, 523.41556902, 523.41556902],
       [514.06185029, 514.06185029, 514.06185029]])

In [9]:
# Shared setup for the aggregation experiments: enumerate the state space and fix
# the training hyper-parameters consumed by the training loop.
weights_dict = {}
# NOTE(review): both objects are rebuilt (Aggregation with Q_start) inside the training
# loop, so these instances only matter if some cell uses them before that loop runs.
model_agg = Aggregation(graph, arrival_rates, experts, eta, queue_max)
model = StochasticMatching(graph, arrival_rates, queue_max)

# Every vector of per-vertex queue lengths in {0, ..., queue_max}^|V|,
# in the order itertools.product yields them.
state_space = list(itertools.product(np.arange(queue_max + 1), repeat=len(graph)))
num_states = len(state_space)
# String position of each state -> the state itself (no counter leaks into the namespace).
state_space_ind = {str(idx): s for idx, s in enumerate(state_space)}

state_space_dict = {s: 0 for s in state_space}
len_updates = 3000
res = len_updates // 100  # snapshot spacing used by the training loop after its dense phase
num_repeat_est = 1

queues_lenght = {v: [] for v in verteces}

K = len(experts)
# Uniform initial weights over the K experts, one row per state; same
# (num_states, K) layout the training loop uses for its '-1' entry.
weights_dict['-1'] = np.ones([num_states, K]) / K

repeat_weights = 10  # number of independent training runs
M = 210  # NOTE(review): not used in the visible cells — TODO confirm before removing
lr = 0.005  # base learning rate, decayed as lr / sqrt(i + 1) in the training loop

In [10]:
%%time
random.seed(51)
for n in range(repeat_weights):
    weights_dict = {}
    model_agg = Aggregation(graph, arrival_rates, experts, eta, queue_max, Q_start)
    model = StochasticMatching(graph, arrival_rates, queue_max)
    weights_dict['-1'] =  np.ones([num_states, K]) / K 
    time_expw_variable_eta = []
    time_tot = 0
    for i in tqdm(range(len_updates + 1)):
        learning_rate = lr / np.sqrt(i + 1)
        if (i <= 200) and (i%10 == 0):
            start_time = time.time()
            weights = model_agg.aggregation_update_exp(discount, learning_rate, model, H, k, num_repeat_est)
            end_time = time.time()
            time_tot += end_time - start_time
            time_expw_variable_eta.append(time_tot)
            weights_dict[str(i)] = weights.copy()
            #if i % 50 == 0:
            #    print('update --> ', n, i)
        elif (i % res == 0) and (i > 200):
            start_time = time.time()
            weights = model_agg.aggregation_update_exp(discount, learning_rate, model, H, k, num_repeat_est)
            end_time = time.time()
            time_tot += end_time - start_time
            time_expw_variable_eta.append(time_tot)
            weights_dict[str(i)] = weights.copy()
            random_state = state_space[random.randint(0, len(state_space)-1)]
            #if i % (res*5) == 0:
            #    print('update --> ', n, i)
        else:
            start_time = time.time()
            weights = model_agg.aggregation_update_exp(discount, learning_rate, model, H, k, num_repeat_est)
            end_time = time.time()
            time_tot += end_time - start_time
        
    with open('weights/weights_expw_variable_eta_' + str(n) + '.pkl', 'wb') as output:
        pickle.dump(weights_dict, output)
    with open('weights/time_expw_variable_eta_' + str(n) + '.pkl', 'wb') as output:
        pickle.dump(time_expw_variable_eta, output)

update -->  0 0
update -->  0 50
update -->  0 100
update -->  0 150
update -->  0 200
update -->  0 300
update -->  0 450
update -->  0 600
update -->  0 750
update -->  0 900
update -->  0 1050
update -->  0 1200
update -->  0 1350
update -->  0 1500
update -->  0 1650
update -->  0 1800
update -->  0 1950
update -->  0 2100
update -->  0 2250
update -->  0 2400
update -->  0 2550
update -->  0 2700
update -->  0 2850
update -->  0 3000
update -->  1 0
update -->  1 50
update -->  1 100
update -->  1 150
update -->  1 200
update -->  1 300
update -->  1 450
update -->  1 600
update -->  1 750
update -->  1 900
update -->  1 1050
update -->  1 1200
update -->  1 1350
update -->  1 1500
update -->  1 1650
update -->  1 1800
update -->  1 1950
update -->  1 2100
update -->  1 2250
update -->  1 2400
update -->  1 2550
update -->  1 2700
update -->  1 2850
update -->  1 3000
update -->  2 0
update -->  2 50
update -->  2 100
update -->  2 150
update -->  2 200
update -->  2 300
update --