### Import modules

In [None]:
import networkx as nx
import numpy as np
np.random.seed(1)
# from known_rewards_helper_functions import get_Q_table

from matplotlib import pyplot as plt

%load_ext autoreload
%autoreload 2

In [None]:
def return_graph(graph_type='fully_connected', n_nodes=6, n_children=None):
    """
    Returns specified graph type.

    Every node also gets a self-loop (the "stay" action) in addition to the
    edges that define the topology. Nodes are labelled 0..n_nodes-1.

    param graph_type: string. fully_connected, line, circle, star, or tree
    param n_nodes: Number of nodes in graph.
    param n_children: Number of children per node in the tree graph (only applicable
        for tree graph). The tree is filled breadth-first: node 0 takes the first
        n_children new nodes as children, node 1 the next n_children, and so on.

    raises ValueError: for an unknown graph_type, or for a tree graph whose
        n_children is missing or not positive.
    """
    G = nx.Graph()

    if graph_type == 'fully_connected':
        # Edges are undirected, so j >= i suffices (j == i gives the self-loop).
        for i in range(n_nodes):
            for j in range(i, n_nodes):
                G.add_edge(i, j)
    elif graph_type in ('line', 'circle'):
        for i in range(n_nodes):
            G.add_edge(i, i)
            if i < n_nodes - 1:
                G.add_edge(i, i + 1)
        # Guard n_nodes == 0: closing the ring would otherwise create nodes 0 and -1.
        if graph_type == 'circle' and n_nodes > 0:
            G.add_edge(0, n_nodes - 1)
    elif graph_type == 'star':
        G.add_edge(0, 0)
        for i in range(1, n_nodes):
            G.add_edge(i, i)
            G.add_edge(0, i)
    elif graph_type == 'tree':
        if n_children is None or n_children <= 0:
            raise ValueError("n_children must be a positive number for a tree graph.")
        # A node accepts children while its child count is below n_children,
        # so it ends up with ceil(n_children) children.
        per_node = int(np.ceil(n_children))
        G.add_edge(0, 0)
        for i in range(1, n_nodes):
            G.add_edge(i, i)
            # Breadth-first filling: the lowest-index node that still has a free
            # slot is node (i - 1) // per_node.
            G.add_edge((i - 1) // per_node, i)
    else:
        raise ValueError("Invalid graph type. Must be fully_connected, line, circle, star, or tree.")
    return G
        

In [None]:

def get_Q_table(G, means, T=100):
    """
    Build the Q-table by propagating values backwards from the best node.

    param G: networkx graph (undirected); nodes in G = 0,1,2,3,...
    param means: sequence of mean rewards; means[v] is the mean reward of node v
    param T: number of times the agent plays the graph bandit "game"; propagation
             stops once more than T rounds have been run.

    returns: (Q, k, n_calls)
        Q       -- (n_nodes, n_nodes) array; Q[s, a] is the value of moving from
                   node s to neighbour a, -inf where never updated
        k       -- number of propagation rounds performed
        n_calls -- total number of q-value calculations

    The value function is the sum of all rewards over the T time steps
    (starting at the initial node).
    """
    num_states = len(means)
    best = np.argmax(means)
    best_mean = means[best]

    Q = np.full((num_states, num_states), -np.inf)
    # Staying at the best node for all T steps.
    Q[best, best] = T * best_mean

    rounds = 0
    n_calls = 0
    frontier = {best}
    # `not rounds > T` (rather than `rounds <= T`) keeps a NaN horizon unbounded.
    while frontier and not rounds > T:
        current = frontier.copy()
        frontier = set()
        for target in current:
            for source in G.neighbors(target):
                n_calls += 1
                candidate = Q[target].max() - best_mean + means[source]
                # A node whose best value improved must be revisited next round.
                if candidate > Q[source].max():
                    frontier.add(source)
                Q[source, target] = candidate
        rounds += 1

    return Q, rounds, n_calls

def graph_get_policy(G_in,means):
    """
    Greedy policy of the Q-Graph algorithm.

    Runs get_Q_table with horizon T = D * mu_star / Delta, where D is the diameter
    of the graph, mu_star the largest mean and Delta the gap between the two
    largest distinct means. Each node then moves to the neighbour with the
    largest Q-value.

    param G_in: networkx graph with nodes 0..n-1 (nx.diameter needs it connected)
    param means: sequence of mean rewards, means[v] for node v
    returns: (policy, n_calls, n_iter) -- policy maps node -> chosen neighbour;
             n_calls / n_iter are the call count and round count of get_Q_table
    raises ValueError: if means has fewer than two distinct values (Delta undefined)
    """
    sorted_mean = np.sort(list(set(means)))
    if len(sorted_mean) < 2:
        raise ValueError("means must contain at least two distinct values to define the gap Delta.")

    # NOTE(review): kept on purpose -- Graph.copy() may rebuild adjacency in a
    # different neighbour order (DP_get_policy copies too), which affects the
    # argmax tie-breaking below; confirm before removing.
    G = G_in.copy()

    Delta = sorted_mean[-1] - sorted_mean[-2]
    D = nx.diameter(G)
    mu_star = np.max(means)
    Q, n_iter, n_calls = get_Q_table(G, means, T=D * mu_star / Delta)

    policy = {}
    for s in G:
        neighbors = list(G.neighbors(s))
        # np.argmax returns the first maximum, so ties (including rows still at
        # -inf) go to the earliest neighbour in G's adjacency order.
        policy[s] = neighbors[np.argmax([Q[s, nb] for nb in neighbors])]
    return policy, n_calls, n_iter

In [None]:
def DP_get_policy(G_cyc,means):
    """
    Dynamic-programming baseline: optimal policy from value iteration on path costs.

    Self-loops are removed. Each move onto node w costs g(w) = mu_star - mu(w), and
    C(s) is the cheapest total cost of reaching the best node s* from s.

    param G_cyc: networkx graph with nodes 0..n-1 (self-loops allowed; they are ignored)
    param means: array-like of mean rewards, means[v] for node v
    returns: (policy, n_calls, n_tables) -- policy maps node -> next node (s* maps
             to itself); n_calls counts neighbour evaluations; n_tables is the
             number of distinct cost tables C_0, C_1, ... (the last is the fixed point).
    """
    G = G_cyc.copy()
    G.remove_edges_from(nx.selfloop_edges(G_cyc))

    # Assume no self-loops in this setting.
    # cost at s: g(s)= mu*-mu(s)
    # C(s,T) = min_{length<=T simple path(no loops) starting at s ending s*} \sum_t g(s_t) =  \min{C(s,T-1), \min_{w in N_s} g(w)+C(w,T-1)}
    # C(s*,0) = 0, C(v,0) = +inf if v!=s*.

    mu_star = np.max(means)
    s_star = np.argmax(means)

    g = mu_star - means

    n_nodes = G.number_of_nodes()

    C_prev = np.ones(n_nodes) * np.inf
    C_prev[s_star] = 0
    # Only the latest table is needed, so count the tables instead of storing them.
    n_tables = 1

    # Value iteration for acyclic all-to-all weighted shortest path.
    n_calls = 0
    while True:
        C = np.zeros(n_nodes)
        for s in G:
            n_calls += len(G[s])
            # Either keep the previous cost or step to a neighbour w first.
            C[s] = np.min([C_prev[w] + g[w] for w in G.neighbors(s)] + [C_prev[s]])

        if np.all(C == C_prev):
            break
        C_prev = C
        n_tables += 1

    C = C_prev

    # Compute the optimal policy from the cost table.
    policy = {}
    for s in G:
        neighbors = list(G.neighbors(s))
        if s == s_star or not neighbors:
            # The best node stays put. A node with no other neighbour has no move
            # (np.argmin would fail on an empty list), so it stays put too.
            policy[s] = s
        else:
            neighbor_total_cost = [g[nb] + C[nb] for nb in neighbors]
            policy[s] = neighbors[np.argmin(neighbor_total_cost)]
    return policy, n_calls, n_tables

In [None]:
def dict_to_array(dic):
    """
    Convert a dict keyed by array positions into a float NumPy array.

    param dic: mapping position -> value; positions index an array of length
               len(dic) (normally the keys are 0..len(dic)-1)
    returns: float array with out[key] = value; unfilled slots stay 0.0
    """
    out = np.zeros(len(dic))
    for idx in dic:
        out[idx] = dic[idx]
    return out

In [None]:
def plot(VE,G_compute,DP_compute,legend=False,ylabel=False,savefig=None):
    """
    Plot the compute cost of Q-Graph and DP against |S|*|E| and show the figure.

    param VE: x-values, |S|*|E| for each graph
    param G_compute: Q-Graph call counts, one per entry of VE
    param DP_compute: DP call counts, one per entry of VE
    param legend: if true, draw the legend above the axes
    param ylabel: if true, label the y-axis 'Compute'
    param savefig: if given, also save to 'Figures/<savefig>' (missing
                   directories are created)
    """
    import os

    plt.figure(dpi=100)
    plt.plot(VE, G_compute, color='red', label='Q-Graph', linewidth=5)
    plt.plot(VE, DP_compute, color='green', label='DP', linestyle='dotted', linewidth=5)

    plt.xlabel(r'$|S|\cdot|E|$', fontsize=25)
    plt.tick_params(labelsize=25)
    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
    # 3.8 removed the old names, so the old name alone raises on recent versions.
    # The style is applied after the lines are drawn (as before), so it presumably
    # only influences figures created afterwards.
    palette_style = 'seaborn-v0_8-dark-palette'
    if palette_style not in plt.style.available:
        palette_style = 'seaborn-dark-palette'
    plt.style.use(palette_style)

    if legend:
        plt.legend(loc='upper center', ncol=2, bbox_to_anchor=(0.5, 1.3), fontsize='25')

    if ylabel:
        plt.ylabel('Compute', fontsize=25)

    if savefig is not None:
        path = 'Figures/{}'.format(savefig)
        # savefig does not create missing directories such as 'Figures/compute'.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        plt.savefig(path, bbox_inches='tight')
    plt.show()

In [None]:
def run_experiment(Gs):
    """
    Run Q-Graph and DP on every graph and record their compute cost.

    For each graph the global NumPy seed is reset to 773 and fresh uniform means
    are drawn. Node 0 is then pinned to 0.99 and the last node to 1, so the last
    node is the best arm and the gap Delta is at most 0.01. Both policies must
    agree at every node; an AssertionError is raised otherwise.

    param Gs: iterable of networkx graphs with nodes 0..n-1
    returns: (VE, G_compute, DP_compute) -- VE is a list of |E|*|S| per graph;
             G_compute and DP_compute are arrays of the Q-Graph and DP call counts.
    """
    sizes = []
    graph_calls = []
    dp_calls = []

    for graph in Gs:
        num = graph.number_of_nodes()
        np.random.seed(773)
        means = np.random.rand(num)
        means[0] = 0.99
        means[-1] = 1

        policy_q, calls_q, _ = graph_get_policy(graph, means)
        graph_calls.append(calls_q)

        policy_dp, calls_dp, _ = DP_get_policy(graph, means)
        # Sanity check: both algorithms pick the same action at every node.
        assert(np.all(dict_to_array(policy_q) - dict_to_array(policy_dp) == 0))
        dp_calls.append(calls_dp)

        sizes.append(graph.number_of_edges() * graph.number_of_nodes())

    return sizes, np.array(graph_calls), np.array(dp_calls)

In [None]:
# Graph families to benchmark; Gs maps each family name to its list of graphs.
graphs = ['line','fully_connected','star','tree','circle']

# Graph sizes 2, 4, ..., 18 nodes.
num_nodes = range(2,20,2)
# Branching factor for the tree family; return_graph only uses it for 'tree'.
n_tree_children = 4

Gs = {g:[return_graph(g,n,n_tree_children) for n in num_nodes] for g in graphs}


# Preparing the grid graphs needs some extra work.
grids = []
for j in range(3,13):
    # floor(j/2) x ceil(j/2) grids: 1x2, 2x2, 2x3, ..., 6x6.
    G = nx.grid_2d_graph(int(np.floor(j/2)),int(np.ceil(j/2))) 
    # Map every node, in G.nodes order, to an integer label 0..n-1 so that node
    # labels can index the means vector and the Q / cost tables.
    l = {a:b for a,b in zip(G.nodes,range(G.number_of_nodes()))}
    G = nx.relabel_nodes(G,l)
    # Add a self-loop on every node, as return_graph does for the other families.
    for n in G:
        G.add_edge(n,n)
    grids.append(G)

Gs['grid'] = grids
graphs.append('grid')
# Finish preparing the grid graphs

# Per family: |S|*|E| per graph, Q-Graph call counts and DP call counts.
SE = {}
G_comp = {}
DP_comp = {}

for g in graphs:
    # se: |E|*|S| per graph; gc / dpc: Q-Graph / DP call counts (from run_experiment).
    se,gc, dpc = run_experiment(Gs[g])
    
    SE[g]=se
    G_comp[g]=gc
    DP_comp[g]=dpc

In [None]:
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# matplotlib 3.8 removed the old names; use whichever this installation provides.
whitegrid_style = 'seaborn-v0_8-whitegrid'
if whitegrid_style not in plt.style.available:
    whitegrid_style = 'seaborn-whitegrid'
plt.style.use(whitegrid_style)
# One figure per graph family. Only the 'line' figure carries the legend, and only
# the 'line' and 'tree' figures carry the y-axis label.
for g in graphs:
    plot(SE[g],G_comp[g],DP_comp[g],legend = (g=='line'), ylabel=(g in ['line','tree']), savefig = 'compute/{}-compute.png'.format(g))
# The Q-graph algorithm has an advantage when the graph is not densely connected.