In [52]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%env CUDA_VISIBLE_DEVICES=0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: CUDA_VISIBLE_DEVICES=0


In [53]:
import os, sys
from IPython.display import clear_output
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from pandas import DataFrame
import torch, torch.nn as nn
import numpy as np
import random
import networkx as nx
from utils_ac import random_walk, ReplayBuffer, PathsBuffer, get_states_emb, convert_to_walk, \
    is_valid_path, covering_walk, graph_isomorphism_algorithm_covers, relabel_graph
from problem_ac import GraphProblem, generate_erdos_renyi_problems, generate_regular_problems, convert_graph
from generators_ac import connect_graph, generate_anonymous_walks
from network_ac import ActorCriticAct
import time
from collections import defaultdict as ddict

In [54]:
sys.path.insert(0, '..')
moving_average = lambda x, **kw: DataFrame({'x':np.asarray(x)}).x.ewm(**kw).mean().values

In [55]:
#params
NUM_PROBLEMS = 15
NUM_EPISODES = 50
NUM_VERTICES = 15
DEGREE = 6
THRESHOLD = 0.75
PATHS_BUFFER_CAPACITY = 100
REPLAY_BUFFER_CAPACITY = 100

In [56]:
problem_maker = generate_regular_problems(num_vertices=NUM_VERTICES, degree=DEGREE)

In [57]:
agent = ActorCriticAct(hid_size=256, gcn_size=256, vertex_emb_size=64)

In [58]:
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-4)

In [59]:
#initialize buffers
path_buffer = PathsBuffer(capacity=PATHS_BUFFER_CAPACITY, threshold=THRESHOLD)
train_buffer = ReplayBuffer(capacity=REPLAY_BUFFER_CAPACITY)

In [60]:
actor_losses = []
critic_losses = []

In [61]:
problems = [next(problem_maker) for i in range(NUM_PROBLEMS)]

In [72]:
i = 0

for k in trange(len(problems)):
    
    problem = problems[k]
    
    edges = problem.get_edges()

    for vertex in problem.get_actions():

        path_buffer.flush()
        
        PATH_LENGTH = 2*problem.num_edges + 1        
        
        i += 1 
        
        for episode in range(NUM_EPISODES):
            
            problem.path = [vertex]
            
            source = problem.get_state()[0]
            
            states = []
            actions = []
            
            with torch.no_grad():
                graph_emb = agent.embed_graph(problem.edges)
            
            random_walk = [source]
            checked = ddict(list)
            stack = [source]
            visited = {source}
            ranks = {0: source} # to attempt to get maximal cover (possible to do without rank, but then no guarantees on maximality)
            revranks = {source: 0}
            
            states.append(random_walk[:])

            while len(stack) > 0:
                last = stack[-1]
                lastrank = revranks[last]
                maxrank = max(ranks.keys()) + 1
                with torch.no_grad():
                    probs, _ = agent.get_dist([random_walk[:]], graph_emb, edges)
                probs = probs[0].data.numpy()
                Nlast = [x for _,x in sorted(zip(probs, edges[random_walk[:][-1]]), reverse=True)]
                #print("Is valid", all(i in edges[random_walk[:][-1]] for i in Nlast))
                # going in depth
                for neighbor in Nlast:
                    if neighbor not in visited: # found new node, then add it to the walk
                        actions.append(neighbor)
                        random_walk.append(neighbor)
                        states.append(random_walk[:])
                        stack.append(neighbor)
                        checked[last].append(neighbor)
                        visited.add(neighbor)
                        ranks[maxrank] = neighbor
                        revranks[neighbor] = maxrank
                        break
                else: # we didn't find any new neighbor and rollback
                    stack.pop()
                    if len(stack) > 0:
                        random_walk.append(stack[-1])
                        checked[last].append(stack[-1])

                # interconnecting nodes that are already in walk
                for r in range(maxrank-1, lastrank+1, -1):
                    node = ranks[r]
                    if node not in checked[last] and node in Nlast:
                        checked[last].append(node)
                        random_walk.extend([node, last])
                        
                        
            print(is_valid_path(random_walk, problem))
            
            if len(path_buffer) >= 10:
                graph_emb = agent.embed_graph(problem.edges)
                reward = path_buffer.rank_path(random_walk[:])
                rewards = torch.FloatTensor([reward]*(len(states)-1))
                probs, values = agent.get_dist(states[:-1], graph_emb, edges)
                for i, dist in enumerate(probs):
                    valids = edges[states[i][-1]]
                    a = valids.index(actions[i])
                    m = torch.distributions.Categorical(dist)
                    log_prob = m.log_prob(torch.tensor(a))
                    log_probs.append(log_prob)

                advantage = rewards - values

                actor_loss  = -(log_probs * advantage.detach()).mean()
                critic_loss = advantage.pow(2).mean()
                                                          
                actor_losses.append(actor_loss.item())
                critic_losses.append(critic_loss.item())

                loss = actor_loss + critic_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            path_buffer.push(random_walk)
            
        if i % 25 == 0:   
            print(path_buffer.buffer)
        if i % 5 == 0:
            clear_output(True)
            plt.figure(figsize=[12, 6])
            plt.subplot(1,2,1)
            plt.title('Actor loss'); plt.grid()
            plt.scatter(np.arange(len(actor_losses)), actor_losses, alpha=0.1)
            plt.plot(moving_average(actor_losses, span=100, min_periods=100))

            plt.subplot(1,2,2)
            plt.title('Critic loss'); plt.grid()
            plt.scatter(np.arange(len(critic_losses)), critic_losses, alpha=0.1)
            plt.plot(moving_average(critic_losses, span=10, min_periods=10))
            plt.show()

  0%|          | 0/15 [00:00<?, ?it/s]

Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
False
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
False
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid True
Is valid




NameError: name 'log_probs' is not defined

In [67]:
problem.edges

defaultdict(set,
            {5: {4, 6, 7, 8, 9, 12},
             9: {1, 3, 5, 8, 12, 13},
             6: {0, 1, 5, 8, 10, 12},
             8: {3, 5, 6, 7, 9, 11},
             12: {2, 5, 6, 9, 11, 14},
             4: {1, 2, 5, 10, 13, 14},
             7: {0, 2, 3, 5, 8, 10},
             3: {1, 7, 8, 9, 11, 14},
             1: {0, 3, 4, 6, 9, 14},
             13: {0, 2, 4, 9, 11, 14},
             14: {1, 3, 4, 10, 12, 13},
             0: {1, 6, 7, 10, 11, 13},
             11: {0, 2, 3, 8, 12, 13},
             10: {0, 2, 4, 6, 7, 14},
             2: {4, 7, 10, 11, 12, 13}})

In [None]:
[0, 10, 1, 13, 9, 3, 8, 14, 5, 7, 12, 2, 12, 11, 4]

In [None]:
states[-1]

In [None]:
def test_agent(agent, problem, vertex):
    
    with torch.no_grad():
        graph_emb = agent.embed_graph(problem.edges)
        problem.path = [vertex]
        source = problem.get_state()[0]
        random_walk = [source]
        checked = ddict(list)
        stack = [source]
        visited = {source}
        ranks = {0: source} # to attempt to get maximal cover (possible to do without rank, but then no guarantees on maximality)
        revranks = {source: 0}
        
        while len(stack) > 0:
            last = stack[-1]
            lastrank = revranks[last]
            maxrank = max(ranks.keys()) + 1
            with torch.no_grad():
                probs, v = agent(get_states_emb([random_walk], graph_emb))    
            valids = problem.get_valid_actions(random_walk[-1])
            probs = probs*torch.FloatTensor(valids)
            probs_sum = torch.sum(probs)
            if probs_sum > 0:
                probs /= probs_sum
            else:
                print("All valid moves were masked, do workaround.")
                probs = probs + torch.FloatTensor(valids)
                probs /= torch.sum(probs)
                
            sorted, indices = torch.sort(probs, 1, descending=True)
                
            Nlast = list(indices[0].numpy())

            # going in depth
            for neighbor in Nlast:
                if neighbor not in visited: # found new node, then add it to the walk
                    actions.append(neighbor)
                    random_walk.append(neighbor)
                    states.append(random_walk[:])
                    stack.append(neighbor)
                    checked[last].append(neighbor)
                    visited.add(neighbor)
                    ranks[maxrank] = neighbor
                    revranks[neighbor] = maxrank
                    break
            else: # we didn't find any new neighbor and rollback
                stack.pop()
                if len(stack) > 0:
                    random_walk.append(stack[-1])
                    checked[last].append(stack[-1])

                # interconnecting nodes that are already in walk
            for r in range(maxrank-1, lastrank+1, -1):
                node = ranks[r]
                if node not in checked[last] and node in Nlast:
                    random_walk.extend([node, last])
                    checked[last].append(node)
        return random_walk

In [None]:
G = nx.random_regular_graph(6, 15)
G2 = relabel_graph(G)
print(15, graph_isomorphism_algorithm_covers(G, G2, agent, test_agent, 1))

In [None]:
x = [0, 9 , 6, 7] 

In [None]:
x.index(6)