In [350]:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
import time

In [351]:
# Assumptions:
# 1) When a node is chosen, all of its children are added automatically; if any of those children is the goal, the episode ends.
# 2) There is always at least one edge (leaf) node left to expand, i.e. the search frontier never runs out of nodes with unexplored links.

In [352]:
# Constants
seed = 0  # Seeds NumPy's global RNG so fake_goal and the node features drawn later are reproducible
np.random.seed(seed)
ignore_internal_nodes = True  # If True, only leaf ("edge") nodes are candidates when picking the next node to expand
num_nodes = 2000
max_ep_steps = 50
prop_steps = 3  # Message-passing rounds per episode step
num_eps = 1
# The graph's node ids are 0..num_nodes-1, so this goal id is never in the graph and can never be
# reached (apparently intentional, to let episodes run to max_ep_steps for timing).
goal_node = num_nodes + 1
# How the goal enters the model (goal_opt):
# 1) added as an extra input row to the update function (uF)
# 2) combined with the updated node state in the output function (oF)
# 3) not fed to the model at all; candidates are ranked by cosine similarity between their output and the goal
goal_opt = 2
feat_size = 800
fake_goal = np.random.uniform(0, 1, feat_size)  # Stand-in goal embedding
hidden_size = feat_size
goal_size = hidden_size

In [353]:
# Random feature vector per node, keyed by node id 0..num_nodes-1
node_dict = {i: np.random.uniform(0, 1, feat_size) for i in range(num_nodes)}
# Directed ring over those ids (0 -> 1 -> ... -> num_nodes-1 -> 0); each node stores its vector as 'v'
G = nx.cycle_graph(list(node_dict), create_using=nx.DiGraph())
nx.set_node_attributes(G, node_dict, 'v')
# nx.draw(G, with_labels=True)

In [354]:
# Graph
# feat_size = 3
# G = nx.DiGraph()
# A = np.array([3, 1, 2])
# B = np.array([5, 3, 1])
# C = np.array([2, 3, 1])
# D = np.array([5, 7, 2])
# E = np.array([6, 4, 6])
# F = np.array([50,70, 45])
# H = np.array([3, 1, 1])
# G.add_node(0, v=A)
# G.add_node(1, v=B)
# G.add_node(2, v=C)
# G.add_node(3, v=D)
# G.add_node(4, v=E)
# G.add_node(5, v=F)
# G.add_node(6, v=H)
# G.add_edges_from([(3, 4), (1, 0), (2, 1), (4, 3), (2, 4), (4, 0), (2, 3), (3, 0), (4, 1), (1, 4),
#                   (4, 5), (5, 6), (3, 5), (0, 5)])

In [355]:
# nx.draw(G, with_labels=True, font_weight='bold')
# plt.show()

In [356]:
# Lookup tables over the full graph G, used when expanding nodes (add_children):
#   out_node_dict:  node -> list of its successor ids
#   feat_node_dict: node -> its feature vector (attribute 'v')
out_node_dict = {node: list(G.successors(node)) for node in G.nodes}
feat_node_dict = {node: attrs['v'] for node, attrs in G.nodes(data=True)}
# print(out_node_dict)
# print(feat_node_dict)

In [357]:
# Message, aggregate, update and output functions (one of each), applied in that order to every node
def mF(X):
    """Message function: scale every predecessor's features by 2.

    X is an array of predecessor (in-node) features, shape (N_in, feats);
    the result has the same shape as X.
    """
    scaled = 2 * X
    return scaled
# Aggregate: mean of the in-node messages over the in-node axis
def aF(X):
    """Aggregate function: average the messages across in-nodes.

    X has shape (N_in, feats); averaging over axis 0 yields shape (feats,),
    i.e. one value per feature, not one per in-node.
    """
    in_node_axis = 0
    return np.mean(X, axis=in_node_axis)
# Update: column-wise sum of the rows it is given (the node's own features, plus the aggregate
# when the node has predecessors); with goal_opt == 1 the goal embedding is prepended as one more row
def uF(X):
    """Update function.

    X: array of shape (2, feat) (current node features stacked with the aggregate)
       or (1, feat) for a node with no predecessors, as called from propogate.
    When goal_opt == 1, fake_goal is stacked on top as an extra row before summing.
    Returns the sum over rows, shape (feat,).
    """
    rows = np.concatenate((fake_goal.reshape(1, -1), X), axis=0) if goal_opt == 1 else X
    return np.sum(rows, axis=0)
# Output: shifts the update by +2; with goal_opt 1 or 2 the result is collapsed into one scalar score,
# with goal_opt 3 it stays a vector (later compared to the goal by cosine similarity)
def oF(X):
    """Output function applied to a node's update vector X.

    goal_opt == 1: scalar sum(X + 2) (the goal already entered through uF).
    goal_opt == 2: scalar sum(mean([fake_goal, X]) + 2), i.e. X is first averaged
                   element-wise with the goal.
    otherwise:     the vector X + 2, same shape as X.
    """
    if goal_opt == 1:
        return np.sum(X + 2.)
    if goal_opt == 2:
        with_goal = np.stack((fake_goal, X))
        return np.sum(np.mean(with_goal, axis=0) + 2.)
    return X + 2.

In [358]:
# Expand a node: insert its children, and the node -> child edges, into G in place.
# Reports whether the goal node was among the children newly inserted.
def add_children(node, G, out_nodes, node_feat, goal_node):
    """Add every child of `node` (according to `out_nodes`) to G, plus the connecting edges.

    G is mutated in place; the graph itself is not returned.

    Args:
        node: id of the node being expanded (looked up in `out_nodes`).
        G: graph to grow.
        out_nodes: mapping node id -> list of child ids.
        node_feat: mapping node id -> feature vector, stored on newly added nodes as 'v'.
        goal_node: id of the goal node.

    Returns:
        True if `goal_node` was one of the children newly inserted by this call,
        otherwise False (a goal child already present in G does not count).
    """
    kids = out_nodes[node]
    # Every child's features are looked up before G is touched.
    kid_feats = [node_feat[kid] for kid in kids]
    hit_goal = False
    for kid, feat in zip(kids, kid_feats):
        newly_added = kid not in G
        if newly_added:
            G.add_node(kid, v=feat)
        if newly_added and kid == goal_node:
            hit_goal = True
        # A child may already be in G while this particular edge is not.
        if not G.has_edge(node, kid):
            G.add_edge(node, kid)
    return hit_goal

In [359]:
# Starting point of the search (a random start is kept commented out for reference)
# start_node = np.random.choice(list(G.nodes))
start_node = 1
start_node_v = G.nodes[start_node]['v']
# The initial graph is the start node plus its direct children
G_init = nx.DiGraph()
G_init.add_node(start_node, v=start_node_v)
got_goal = add_children(start_node, G_init, out_node_dict, feat_node_dict, goal_node)
assert not got_goal

In [360]:
# nx.draw(G_init, with_labels=True, font_weight='bold')
# plt.show()

In [361]:
# Reward: -1 per step; on the terminal step an extra +5 if the goal was reached, otherwise -5.
# Future ideas for rewarding how close the search got:
# 1) For all nodes (or just edge nodes) compute the shortest-path distance to the goal and use the smallest
# 2) If (1) is too slow, embed all found nodes and use the one closest to the goal embedding
def reward_func(terminal, reach_goal):
    """Return the reward for one step.

    Non-terminal step: -1.
    Terminal step: -1 + 5 = 4 if the goal was reached, otherwise -1 - 5 = -6.
    """
    step_penalty = -1
    if not terminal:
        return step_penalty
    return step_penalty + (5 if reach_goal else -5)

In [362]:
def cos_sim(a, b):
    """Cosine similarity of every row of `a` with the first row of `b`.

    Args:
        a: 2-D array-like, shape (n_rows, dim).
        b: 2-D array-like whose first row is the reference vector; any further
           rows do not affect the result.

    Returns:
        1-D array of length n_rows.
    """
    # Only column 0 of the pairwise matrix is ever used, so compare against b's first
    # row alone: identical values, but O(n_rows * dim) work instead of materialising an
    # n_rows x len(b) matrix (N x N when callers pass N stacked copies of the goal).
    return cosine_similarity(a, b[:1])[:, 0]

In [363]:
def propogate(G_curr, last_prop):
    """Run one synchronous message-passing round over every node of G_curr.

    Every update is computed from the features currently stored in G_curr
    (attribute 'v'). G_curr itself is not modified; the caller decides when to
    write the returned features back.

    Node with predecessors:    update = uF(stack(own_feat, aF(mF(pred_feats))))
    Node without predecessors: update = uF(own_feat as a single row)

    Args:
        G_curr: directed graph whose nodes carry a feature vector under 'v'.
        last_prop: when True, also apply the output function oF to every update
            (only needed on the final round).

    Returns:
        (outputs_dict, new_node_dict): node -> oF(update) (empty unless last_prop),
        and node -> update.
    """
    updates = {}
    outputs = {}
    for node, attrs in G_curr.nodes(data=True):
        own_feat = attrs['v']
        preds_feats = np.array([G_curr.nodes[p]['v'] for p in G_curr.predecessors(node)])
        if preds_feats.size == 0:
            # No in-neighbours: the update only sees the node's own features.
            update = uF(own_feat.reshape(1, -1))
        else:
            aggregate = aF(mF(preds_feats))
            update = uF(np.stack((own_feat, aggregate)))
        updates[node] = update
        if last_prop:
            outputs[node] = oF(update)
    return outputs, updates

In [364]:
# Action function (for now just greedy)
def select_node(node_outputs, goal):
    """Greedily pick the index of the best candidate.

    Args:
        node_outputs: with goal_opt 1 or 2, a 1-D array of scalar scores (one per
            candidate); otherwise a 2-D array (N_nodes, output_len) of output vectors.
        goal: 1-D goal embedding of length output_len; only used when goal_opt is
            neither 1 nor 2.

    Returns:
        Index into node_outputs of the highest score, or of the output with the
        highest cosine similarity to the goal. np.argmax breaks ties toward the
        first index.
    """
    if goal_opt == 1 or goal_opt == 2:
        return np.argmax(node_outputs)
    # cos_sim compares every row of its first argument with the first row of its
    # second, so a single (1, output_len) goal row suffices. N stacked copies of the
    # goal would give the same scores but cost an N x N similarity matrix.
    goal_row = np.asarray(goal).reshape(1, -1)
    cos_sims = cos_sim(node_outputs, goal_row)  # shape (N_nodes,)
    return np.argmax(cos_sims)

In [365]:
def _choose_node(G_curr, outputs_dict):
    """Greedily pick the next graph node to expand, or return None if there is no candidate.

    Candidates are the current leaf ("edge") nodes, i.e. nodes with no out-edges in
    G_curr, when ignore_internal_nodes is True; otherwise every node in outputs_dict.
    """
    if ignore_internal_nodes:
        candidates = [n for n in G_curr.nodes if G_curr.out_degree(n) == 0]
    else:
        candidates = list(outputs_dict)
    if not candidates:
        return None
    outputs = [outputs_dict[n] for n in candidates]
    if goal_opt == 3:
        outputs = np.stack(outputs)  # (num candidates, output size)
        assert outputs.shape == (len(candidates), goal_size)
    else:
        outputs = np.array(outputs)  # (num candidates,): one scalar score each
    # select_node returns a position in `candidates`, not a graph node id.
    return candidates[select_node(outputs, fake_goal)]


def run_episode(G_init):
    """Run one greedy graph-expansion episode starting from a copy of G_init.

    Each step:
      1. runs `prop_steps` message-passing rounds over the current graph, writing the
         propagated features back into it;
      2. picks a node to expand with `_choose_node`;
      3. expands it with add_children;
      4. scores the step with reward_func.

    The episode ends when add_children reports the goal, after `max_ep_steps` steps,
    or when there is no candidate node left to expand. That last case adds no terminal
    reward, because no step is taken.

    Returns:
        total_rew: summed reward over the completed steps.
        steps: number of completed steps (0 if none were completed).
        achieved_goal: whether the goal was reached.
        G_curr: the expanded graph at the end of the episode.
        step_time: wall-clock seconds for each completed step.
    """
    G_curr = G_init.copy()
    total_rew = 0
    step_time = []
    # Defined up front so the return is valid even when no step completes
    # (max_ep_steps <= 0, or no candidate node on the very first step).
    achieved_goal = False
    steps_taken = 0
    for step in range(max_ep_steps):
        ts = time.time()
        # Message passing; only the final round computes node outputs.
        # NOTE(review): assumes prop_steps >= 1, otherwise outputs_dict is never bound.
        # NOTE(review): propagated features are written back into G_curr and become the
        # starting point of the next step's propagation, so they compound over the
        # episode; confirm this is intended rather than re-propagating from raw features.
        for p in range(prop_steps):
            outputs_dict, new_node_dict = propogate(G_curr, p == prop_steps - 1)
            nx.set_node_attributes(G_curr, new_node_dict, 'v')
        chosen_node = _choose_node(G_curr, outputs_dict)
        if chosen_node is None:
            break
        # Add children (returns whether any newly added child is the goal)
        achieved_goal = add_children(chosen_node, G_curr, out_node_dict, feat_node_dict, goal_node)
        done = step == max_ep_steps - 1 or achieved_goal
        total_rew += reward_func(done, achieved_goal)
        step_time.append(time.time() - ts)
        steps_taken = step + 1
        if done:
            break
    return total_rew, steps_taken, achieved_goal, G_curr, step_time

In [366]:
for ep in range(num_eps):
    # Start the timer inside the loop so EPISODE TIME is per episode, not cumulative.
    t0 = time.time()
    ep_rew, ep_step, got_goal, G_ep, step_time = run_episode(G_init)
    print('EPISODE TIME (s): {}'.format(time.time() - t0))
    # An episode can end before completing any step, leaving step_time empty.
    avg_step_time = np.mean(step_time) if step_time else float('nan')
    print('AVG STEP TIME (s): {}'.format(avg_step_time))
    print('Ep: {}  Reward: {}  Reached Goal: {}  Num Steps: {}'.format(ep + 1, ep_rew, got_goal, ep_step))
#     nx.draw(G_ep, with_labels=True, font_weight='bold')
#     plt.show()

EPISODE TIME (s): 0.029824018478393555
AVG STEP TIME (s): 0.000591740608215332
Ep: 1  Reward: -55  Reached Goal: False  Num Steps: 50
