In [58]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%env CUDA_VISIBLE_DEVICES=0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: CUDA_VISIBLE_DEVICES=0


In [59]:
import os, sys
from IPython.display import clear_output
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from pandas import DataFrame
import torch, torch.nn as nn
import numpy as np
import random
import networkx as nx
from utils_ac import random_walk, ReplayBuffer, PathsBuffer, get_states_emb, convert_to_walk, \
    is_valid_path, covering_walk, graph_isomorphism_algorithm_covers, relabel_graph, is_valid_path_new
from problem_ac import GraphProblem, generate_erdos_renyi_problems, generate_regular_problems, convert_graph
from generators_ac import connect_graph, generate_anonymous_walks
from network_ac import ActorCriticAct
import time
from collections import defaultdict as ddict

In [60]:
sys.path.insert(0, '..')
moving_average = lambda x, **kw: DataFrame({'x':np.asarray(x)}).x.ewm(**kw).mean().values

In [61]:
#params
NUM_PROBLEMS = 15
NUM_EPISODES = 50
NUM_VERTICES = 15
DEGREE = 6
THRESHOLD = 0.75
PATHS_BUFFER_CAPACITY = 100
REPLAY_BUFFER_CAPACITY = 100

In [62]:
problem_maker = generate_regular_problems(num_vertices=NUM_VERTICES, degree=DEGREE)

In [63]:
agent = ActorCriticAct(hid_size=256, gcn_size=256, vertex_emb_size=64)

In [64]:
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-4)

In [65]:
#initialize buffers
path_buffer = PathsBuffer(capacity=PATHS_BUFFER_CAPACITY, threshold=THRESHOLD)
train_buffer = ReplayBuffer(capacity=REPLAY_BUFFER_CAPACITY)

In [66]:
actor_losses = []
critic_losses = []

In [67]:
problems = [next(problem_maker) for i in range(NUM_PROBLEMS)]

In [69]:
i = 0

for k in trange(len(problems)):
    
    problem = problems[k]
    
    edges = problem.get_edges()
    print(edges)

    for vertex in problem.get_actions():

        path_buffer.flush()
        
        PATH_LENGTH = 2*problem.num_edges + 1        
        
        i += 1 
        
        for episode in range(NUM_EPISODES):
            
            problem.path = [vertex]
            
            source = problem.get_state()[0]
            
            states = []
            actions = []
            
            with torch.no_grad():
                graph_emb = agent.embed_graph(problem.edges)
            
            random_walk = [source]
            checked = ddict(list)
            stack = [source]
            visited = {source}
            ranks = {0: source} # to attempt to get maximal cover (possible to do without rank, but then no guarantees on maximality)
            revranks = {source: 0}
            
            states.append(random_walk[:])

            while len(stack) > 0:
                last = stack[-1]
                lastrank = revranks[last]
                maxrank = max(ranks.keys()) + 1
                with torch.no_grad():
                    probs, _ = agent.get_dist([random_walk[:]], graph_emb, edges)
                    print(probs)
                probs = probs[0].data.numpy()
                Nlast = [x for _,x in sorted(zip(probs, edges[random_walk[:][-1]]), reverse=True)]
                #print("Is valid", all(i in edges[random_walk[:][-1]] for i in Nlast))
                # going in depth
                for neighbor in Nlast:
                    if neighbor not in visited:
                        print("neigh ok")# found new node, then add it to the walk
                        actions.append(neighbor)
                        random_walk.append(neighbor)
                        print("rw", random_walk)
                        print("aw", convert_to_walk(random_walk))
                        print("Nlast", Nlast)
                        print("Stack", stack)
                        print("Checked", checked[last])
                        print("Visited", visited)
                        if not is_valid_path_new(random_walk, edges):
                            print("path not valid")
                        states.append(random_walk[:])
                        stack.append(neighbor)
                        checked[last].append(neighbor)
                        visited.add(neighbor)
                        ranks[maxrank] = neighbor
                        revranks[neighbor] = maxrank
                        break
                else: # we didn't find any new neighbor and rollback
                    stack.pop()
                    if len(stack) > 0:
                        print("pop ok")
                        random_walk.append(stack[-1])
                        checked[last].append(stack[-1])
                        print("rw", random_walk)
                        print("aw", convert_to_walk(random_walk))
                        print("Nlast", Nlast)
                        print("Stack", stack)
                        print("Checked", checked[last])
                        if not is_valid_path_new(random_walk, edges):
                            print("path not valid")
                            break
                # interconnecting nodes that are already in walk
                for r in range(maxrank-1, lastrank+1, -1):
                    node = ranks[r]
                    if node not in checked[last] and node in Nlast:
                        print("last ok")
                        checked[last].append(node)
                        random_walk.extend([node, last])
                        print("rw", random_walk)
                        print("aw", convert_to_walk(random_walk))
                        print("node", node)
                        print("last", last)
                        print("Nlast", Nlast)
                        print("Stack", stack)
                        print("Checked", checked[last])
                        if not is_valid_path_new(random_walk, edges):
                            print("path not valid")
                            break
            break
        break
    
"""           
            print(is_valid_path_new(random_walk, edges))
            
            if len(path_buffer) >= 10:
                graph_emb = agent.embed_graph(problem.edges)
                reward = path_buffer.rank_path(random_walk[:])
                rewards = torch.FloatTensor([reward]*(len(states)-1))
                probs, values = agent.get_dist(states[:-1], graph_emb, edges)
                for i, dist in enumerate(probs):
                    valids = edges[states[i][-1]]
                    a = valids.index(actions[i])
                    m = torch.distributions.Categorical(dist)
                    log_prob = m.log_prob(torch.tensor(a))
                    log_probs.append(log_prob)

                advantage = rewards - values

                actor_loss  = -(log_probs * advantage.detach()).mean()
                critic_loss = advantage.pow(2).mean()
                                                          
                actor_losses.append(actor_loss.item())
                critic_losses.append(critic_loss.item())

                loss = actor_loss + critic_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            path_buffer.push(random_walk)
            
        if i % 25 == 0:   
            print(path_buffer.buffer)
        if i % 5 == 0:
            clear_output(True)
            plt.figure(figsize=[12, 6])
            plt.subplot(1,2,1)
            plt.title('Actor loss'); plt.grid()
            plt.scatter(np.arange(len(actor_losses)), actor_losses, alpha=0.1)
            plt.plot(moving_average(actor_losses, span=100, min_periods=100))

            plt.subplot(1,2,2)
            plt.title('Critic loss'); plt.grid()
            plt.scatter(np.arange(len(critic_losses)), critic_losses, alpha=0.1)
            plt.plot(moving_average(critic_losses, span=10, min_periods=10))
            plt.show()
"""

  7%|▋         | 1/15 [00:00<00:01,  8.24it/s]

{6: [3, 5, 8, 9, 12, 14], 9: [2, 3, 4, 6, 12, 13], 5: [0, 1, 6, 8, 10, 13], 14: [0, 1, 2, 6, 11, 13], 3: [2, 6, 7, 9, 10, 13], 12: [0, 1, 2, 4, 6, 9], 8: [1, 5, 6, 7, 10, 11], 4: [0, 2, 9, 11, 12, 13], 2: [3, 4, 7, 9, 12, 14], 13: [3, 4, 5, 9, 10, 14], 1: [5, 8, 10, 11, 12, 14], 0: [4, 5, 7, 11, 12, 14], 10: [1, 3, 5, 7, 8, 13], 7: [0, 2, 3, 8, 10, 11], 11: [0, 1, 4, 7, 8, 14]}
[tensor([0.1657, 0.1660, 0.1679, 0.1707, 0.1648, 0.1650])]
0
neigh ok
rw [0, 11]
aw [0, 1]
Nlast [11, 7, 5, 4, 14, 12]
Stack [0]
Checked []
Visited {0}
[tensor([0.1699, 0.1693, 0.1640, 0.1658, 0.1679, 0.1631])]
11
neigh ok
rw [0, 11, 1]
aw [0, 1, 2]
Nlast [0, 1, 8, 7, 4, 14]
Stack [0, 11]
Checked []
Visited {0, 11}
[tensor([0.1648, 0.1687, 0.1697, 0.1694, 0.1636, 0.1637])]
1
neigh ok
rw [0, 11, 1, 10]
aw [0, 1, 2, 3]
Nlast [10, 11, 8, 5, 14, 12]
Stack [0, 11, 1]
Checked []
Visited {0, 1, 11}
[tensor([0.1699, 0.1638, 0.1643, 0.1661, 0.1682, 0.1677])]
10
neigh ok
rw [0, 11, 1, 10, 8]
aw [0, 1, 2, 3, 4]
Nlast [1, 8

 20%|██        | 3/15 [00:00<00:01,  8.68it/s]

[tensor([0.1671, 0.1655, 0.1668, 0.1683, 0.1679, 0.1644])]
0
neigh ok
rw [0, 6]
aw [0, 1]
Nlast [6, 7, 1, 5, 3, 14]
Stack [0]
Checked []
Visited {0}
[tensor([0.1685, 0.1665, 0.1649, 0.1662, 0.1677, 0.1662])]
6
neigh ok
rw [0, 6, 10]
aw [0, 1, 2]
Nlast [0, 10, 1, 12, 5, 3]
Stack [0, 6]
Checked []
Visited {0, 6}
[tensor([0.1649, 0.1673, 0.1667, 0.1676, 0.1663, 0.1672])]
10
neigh ok
rw [0, 6, 10, 11]
aw [0, 1, 2, 3]
Nlast [11, 6, 13, 7, 12, 2]
Stack [0, 6, 10]
Checked []
Visited {0, 10, 6}
[tensor([0.1670, 0.1656, 0.1658, 0.1666, 0.1681, 0.1668])]
11
neigh ok
rw [0, 6, 10, 11, 1]
aw [0, 1, 2, 3, 4]
Nlast [10, 1, 12, 5, 4, 2]
Stack [0, 6, 10, 11]
Checked []
Visited {0, 10, 11, 6}
[tensor([0.1690, 0.1655, 0.1680, 0.1654, 0.1682, 0.1640])]
1
neigh ok
rw [0, 6, 10, 11, 1, 4]
aw [0, 1, 2, 3, 4, 5]
Nlast [0, 11, 6, 4, 9, 14]
Stack [0, 6, 10, 11, 1]
Checked []
Visited {0, 1, 6, 10, 11}
[tensor([0.1667, 0.1657, 0.1691, 0.1657, 0.1685, 0.1643])]
4
neigh ok
rw [0, 6, 10, 11, 1, 4, 8]
aw [0, 1, 2, 3

 27%|██▋       | 4/15 [00:00<00:01,  7.70it/s]

[tensor([0.1669, 0.1644, 0.1679, 0.1666, 0.1672, 0.1669])]
9
neigh ok
rw [0, 12, 6, 2, 11, 10, 7, 9, 5]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8]
Nlast [6, 10, 2, 11, 7, 5]
Stack [0, 12, 6, 2, 11, 10, 7, 9]
Checked []
Visited {0, 2, 6, 7, 9, 10, 11, 12}
[tensor([0.1667, 0.1666, 0.1680, 0.1667, 0.1662, 0.1659])]
5
neigh ok
rw [0, 12, 6, 2, 11, 10, 7, 9, 5, 4]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Nlast [6, 7, 0, 4, 9, 12]
Stack [0, 12, 6, 2, 11, 10, 7, 9, 5]
Checked []
Visited {0, 2, 5, 6, 7, 9, 10, 11, 12}
[tensor([0.1675, 0.1675, 0.1673, 0.1685, 0.1646, 0.1646])]
4
neigh ok
rw [0, 12, 6, 2, 11, 10, 7, 9, 5, 4, 8]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Nlast [8, 3, 1, 5, 14, 13]
Stack [0, 12, 6, 2, 11, 10, 7, 9, 5, 4]
Checked []
Visited {0, 2, 4, 5, 6, 7, 9, 10, 11, 12}
[tensor([0.1679, 0.1657, 0.1677, 0.1678, 0.1680, 0.1628])]
8
neigh ok
rw [0, 12, 6, 2, 11, 10, 7, 9, 5, 4, 8, 3]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
Nlast [11, 0, 7, 4, 3, 13]
Stack [0, 12, 6, 2, 11, 10, 7, 9, 5, 4, 8]
Checked []

 47%|████▋     | 7/15 [00:00<00:01,  7.74it/s]

[tensor([0.1659, 0.1665, 0.1671, 0.1671, 0.1685, 0.1649])]
11
neigh ok
rw [0, 10, 8, 6, 1, 7, 9, 3, 13, 2, 5, 11, 4]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
Nlast [10, 9, 7, 5, 4, 14]
Stack [0, 10, 8, 6, 1, 7, 9, 3, 13, 2, 5, 11]
Checked []
Visited {0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 13}
[tensor([0.1660, 0.1681, 0.1666, 0.1673, 0.1669, 0.1651])]
4
neigh ok
rw [0, 10, 8, 6, 1, 7, 9, 3, 13, 2, 5, 11, 4, 14]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Nlast [2, 7, 11, 5, 0, 14]
Stack [0, 10, 8, 6, 1, 7, 9, 3, 13, 2, 5, 11, 4]
Checked []
Visited {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13}
[tensor([0.1677, 0.1665, 0.1653, 0.1665, 0.1679, 0.1660])]
14
pop ok
rw [0, 10, 8, 6, 1, 7, 9, 3, 13, 2, 5, 11, 4, 14, 4]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 12]
Nlast [10, 2, 3, 7, 11, 4]
Stack [0, 10, 8, 6, 1, 7, 9, 3, 13, 2, 5, 11, 4]
Checked [4]
[tensor([0.1660, 0.1681, 0.1666, 0.1673, 0.1669, 0.1651])]
4
pop ok
rw [0, 10, 8, 6, 1, 7, 9, 3, 13, 2, 5, 11, 4, 14, 4, 11]
aw [0, 1

 53%|█████▎    | 8/15 [00:00<00:00,  7.42it/s]

{1: [0, 3, 7, 10, 12, 14], 3: [1, 2, 10, 11, 12, 13], 14: [1, 2, 4, 6, 9, 12], 10: [1, 3, 5, 7, 8, 9], 0: [1, 5, 8, 9, 11, 13], 12: [1, 3, 4, 7, 13, 14], 7: [1, 2, 8, 10, 11, 12], 11: [0, 3, 5, 6, 7, 13], 13: [0, 3, 4, 9, 11, 12], 2: [3, 5, 6, 7, 8, 14], 5: [0, 2, 4, 6, 10, 11], 6: [2, 4, 5, 8, 11, 14], 4: [5, 6, 9, 12, 13, 14], 8: [0, 2, 6, 7, 9, 10], 9: [0, 4, 8, 10, 13, 14]}
[tensor([0.1673, 0.1647, 0.1666, 0.1691, 0.1679, 0.1645])]
0
neigh ok
rw [0, 9]
aw [0, 1]
Nlast [9, 11, 1, 8, 5, 13]
Stack [0]
Checked []
Visited {0}
[tensor([0.1696, 0.1661, 0.1671, 0.1695, 0.1649, 0.1628])]
9
neigh ok
rw [0, 9, 10]
aw [0, 1, 2]
Nlast [0, 10, 8, 4, 13, 14]
Stack [0, 9]
Checked []
Visited {0, 9}
[tensor([0.1675, 0.1656, 0.1647, 0.1663, 0.1667, 0.1692])]
10
neigh ok
rw [0, 9, 10, 1]
aw [0, 1, 2, 3]
Nlast [9, 1, 8, 7, 3, 5]
Stack [0, 9, 10]
Checked []
Visited {0, 9, 10}
[tensor([0.1692, 0.1657, 0.1663, 0.1693, 0.1669, 0.1626])]
1
neigh ok
rw [0, 9, 10, 1, 12]
aw [0, 1, 2, 3, 4]
Nlast [10, 0, 12, 7

 67%|██████▋   | 10/15 [00:01<00:00,  7.81it/s]

[tensor([0.1683, 0.1680, 0.1633, 0.1674, 0.1695, 0.1635])]
4
neigh ok
rw [0, 1, 11, 7, 10, 6, 5, 12, 9, 4, 13]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Nlast [11, 0, 1, 9, 13, 3]
Stack [0, 1, 11, 7, 10, 6, 5, 12, 9, 4]
Checked []
Visited {0, 1, 4, 5, 6, 7, 9, 10, 11, 12}
[tensor([0.1674, 0.1670, 0.1647, 0.1669, 0.1676, 0.1664])]
13
pop ok
rw [0, 1, 11, 7, 10, 6, 5, 12, 9, 4, 13, 4]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9]
Nlast [7, 0, 1, 6, 9, 4]
Stack [0, 1, 11, 7, 10, 6, 5, 12, 9, 4]
Checked [4]
[tensor([0.1683, 0.1680, 0.1633, 0.1674, 0.1695, 0.1635])]
4
neigh ok
rw [0, 1, 11, 7, 10, 6, 5, 12, 9, 4, 13, 4, 3]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 11]
Nlast [11, 0, 1, 9, 13, 3]
Stack [0, 1, 11, 7, 10, 6, 5, 12, 9, 4]
Checked [13]
Visited {0, 1, 4, 5, 6, 7, 9, 10, 11, 12, 13}
[tensor([0.1655, 0.1659, 0.1677, 0.1672, 0.1665, 0.1672])]
3
pop ok
rw [0, 1, 11, 7, 10, 6, 5, 12, 9, 4, 13, 4, 3, 4]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 11, 9]
Nlast [6, 12, 9, 10, 5, 4]
Stack [0, 1, 11, 7,

 73%|███████▎  | 11/15 [00:01<00:00,  7.87it/s]

[tensor([0.1663, 0.1659, 0.1660, 0.1698, 0.1677, 0.1643])]
6
neigh ok
rw [0, 12, 10, 1, 9, 6, 3]
aw [0, 1, 2, 3, 4, 5, 6]
Nlast [9, 10, 3, 7, 4, 13]
Stack [0, 12, 10, 1, 9, 6]
Checked []
Visited {0, 1, 6, 9, 10, 12}
[tensor([0.1657, 0.1662, 0.1666, 0.1688, 0.1670, 0.1657])]
3
neigh ok
rw [0, 12, 10, 1, 9, 6, 3, 8]
aw [0, 1, 2, 3, 4, 5, 6, 7]
Nlast [9, 10, 8, 6, 2, 12]
Stack [0, 12, 10, 1, 9, 6, 3]
Checked []
Visited {0, 1, 3, 6, 9, 10, 12}
[tensor([0.1679, 0.1670, 0.1665, 0.1658, 0.1674, 0.1654])]
8
neigh ok
rw [0, 12, 10, 1, 9, 6, 3, 8, 5]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8]
Nlast [1, 5, 2, 3, 4, 7]
Stack [0, 12, 10, 1, 9, 6, 3, 8]
Checked []
Visited {0, 1, 3, 6, 8, 9, 10, 12}
[tensor([0.1667, 0.1681, 0.1684, 0.1671, 0.1649, 0.1647])]
5
neigh ok
rw [0, 12, 10, 1, 9, 6, 3, 8, 5, 7]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Nlast [10, 8, 12, 7, 13, 14]
Stack [0, 12, 10, 1, 9, 6, 3, 8, 5]
Checked []
Visited {0, 1, 3, 5, 6, 8, 9, 10, 12}
[tensor([0.1677, 0.1676, 0.1670, 0.1672, 0.1670, 0.1636])]
7
nei

 80%|████████  | 12/15 [00:01<00:00,  7.41it/s]

[tensor([0.1679, 0.1637, 0.1664, 0.1675, 0.1674, 0.1671])]
13
pop ok
rw [0, 9, 6, 7, 2, 1, 11, 8, 10, 12, 3, 12, 4, 13, 5, 14, 5, 13, 4]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 11, 12, 13, 14, 13, 12, 11]
Nlast [1, 6, 8, 12, 5, 4]
Stack [0, 9, 6, 7, 2, 1, 11, 8, 10, 12, 4]
Checked [5, 4]
[tensor([0.1680, 0.1659, 0.1692, 0.1664, 0.1671, 0.1634])]
4
pop ok
rw [0, 9, 6, 7, 2, 1, 11, 8, 10, 12, 3, 12, 4, 13, 5, 14, 5, 13, 4, 12]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 11, 12, 13, 14, 13, 12, 11, 9]
Nlast [9, 1, 12, 10, 2, 13]
Stack [0, 9, 6, 7, 2, 1, 11, 8, 10, 12]
Checked [13, 12]
[tensor([0.1656, 0.1654, 0.1707, 0.1676, 0.1661, 0.1647])]
12
pop ok
rw [0, 9, 6, 7, 2, 1, 11, 8, 10, 12, 3, 12, 4, 13, 5, 14, 5, 13, 4, 12, 10]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 11, 12, 13, 14, 13, 12, 11, 9, 8]
Nlast [9, 10, 11, 3, 4, 13]
Stack [0, 9, 6, 7, 2, 1, 11, 8, 10]
Checked [3, 4, 10]
last ok
rw [0, 9, 6, 7, 2, 1, 11, 8, 10, 12, 3, 12, 4, 13, 5, 14, 5, 13, 4, 12, 10, 13, 12]
aw [0, 1, 2, 3, 4,

100%|██████████| 15/15 [00:01<00:00,  7.16it/s]

[tensor([0.1714, 0.1632, 0.1664, 0.1691, 0.1676, 0.1622])]
7
neigh ok
rw [0, 10, 8, 4, 1, 2, 12, 9, 7, 5, 6, 3, 6, 5, 7, 13]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 10, 9, 8, 12]
Nlast [0, 9, 10, 5, 3, 13]
Stack [0, 10, 8, 4, 1, 2, 12, 9, 7]
Checked [5]
Visited {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12}
last ok
rw [0, 10, 8, 4, 1, 2, 12, 9, 7, 5, 6, 3, 6, 5, 7, 13, 3, 7]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 10, 9, 8, 12, 11, 8]
node 3
last 7
Nlast [0, 9, 10, 5, 3, 13]
Stack [0, 10, 8, 4, 1, 2, 12, 9, 7, 13]
Checked [5, 13, 3]
path not valid
[tensor([0.1714, 0.1632, 0.1664, 0.1691, 0.1676, 0.1622])]
7
pop ok
rw [0, 10, 8, 4, 1, 2, 12, 9, 7, 5, 6, 3, 6, 5, 7, 13, 3, 7, 7]
aw [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 10, 9, 8, 12, 11, 8, 8]
Nlast [0, 9, 10, 5, 3, 13]
Stack [0, 10, 8, 4, 1, 2, 12, 9, 7]
Checked [7]
path not valid
{4: [2, 3, 7, 8, 9, 11], 7: [0, 4, 6, 10, 12, 14], 8: [0, 3, 4, 5, 11, 14], 2: [0, 1, 4, 6, 12, 13], 9: [1, 3, 4, 10, 11, 12], 11: [4, 5, 6, 8, 9, 10], 3: [4, 6, 




"           \n            print(is_valid_path_new(random_walk, edges))\n            \n            if len(path_buffer) >= 10:\n                graph_emb = agent.embed_graph(problem.edges)\n                reward = path_buffer.rank_path(random_walk[:])\n                rewards = torch.FloatTensor([reward]*(len(states)-1))\n                probs, values = agent.get_dist(states[:-1], graph_emb, edges)\n                for i, dist in enumerate(probs):\n                    valids = edges[states[i][-1]]\n                    a = valids.index(actions[i])\n                    m = torch.distributions.Categorical(dist)\n                    log_prob = m.log_prob(torch.tensor(a))\n                    log_probs.append(log_prob)\n\n                advantage = rewards - values\n\n                actor_loss  = -(log_probs * advantage.detach()).mean()\n                critic_loss = advantage.pow(2).mean()\n                                                          \n                actor_losses.append(act

In [None]:
print(random_walk)

In [None]:
edges

In [None]:
[0, 10, 1, 13, 9, 3, 8, 14, 5, 7, 12, 2, 12, 11, 4]

In [None]:
states[-1]

In [None]:
def test_agent(agent, problem, vertex):
    
    with torch.no_grad():
        graph_emb = agent.embed_graph(problem.edges)
        problem.path = [vertex]
        source = problem.get_state()[0]
        random_walk = [source]
        checked = ddict(list)
        stack = [source]
        visited = {source}
        ranks = {0: source} # to attempt to get maximal cover (possible to do without rank, but then no guarantees on maximality)
        revranks = {source: 0}
        
        while len(stack) > 0:
            last = stack[-1]
            lastrank = revranks[last]
            maxrank = max(ranks.keys()) + 1
            with torch.no_grad():
                probs, v = agent(get_states_emb([random_walk], graph_emb))    
            valids = problem.get_valid_actions(random_walk[-1])
            probs = probs*torch.FloatTensor(valids)
            probs_sum = torch.sum(probs)
            if probs_sum > 0:
                probs /= probs_sum
            else:
                print("All valid moves were masked, do workaround.")
                probs = probs + torch.FloatTensor(valids)
                probs /= torch.sum(probs)
                
            sorted, indices = torch.sort(probs, 1, descending=True)
                
            Nlast = list(indices[0].numpy())

            # going in depth
            for neighbor in Nlast:
                if neighbor not in visited: # found new node, then add it to the walk
                    actions.append(neighbor)
                    random_walk.append(neighbor)
                    states.append(random_walk[:])
                    stack.append(neighbor)
                    checked[last].append(neighbor)
                    visited.add(neighbor)
                    ranks[maxrank] = neighbor
                    revranks[neighbor] = maxrank
                    break
            else: # we didn't find any new neighbor and rollback
                stack.pop()
                if len(stack) > 0:
                    random_walk.append(stack[-1])
                    checked[last].append(stack[-1])

                # interconnecting nodes that are already in walk
            for r in range(maxrank-1, lastrank+1, -1):
                node = ranks[r]
                if node not in checked[last] and node in Nlast:
                    random_walk.extend([node, last])
                    checked[last].append(node)
        return random_walk

In [None]:
G = nx.random_regular_graph(6, 15)
G2 = relabel_graph(G)
print(15, graph_isomorphism_algorithm_covers(G, G2, agent, test_agent, 1))

In [None]:
x = [0, 9 , 6, 7] 

In [None]:
x.index(6)