In [17]:
import torch
from problem_mcts import generate_regular_problems
from utils_mcts import *
from MCTS_Act_LSTM import MCTS
from collections import defaultdict as ddict

In [2]:
# Load the trained agent from a checkpoint path relative to the working directory.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
agent = torch.load("./nets/regular_10_gen_regularazied.pth")

In [3]:
# Display the loaded object's repr (its layer summary) as the cell output.
agent

AgentActLSTM(
  (gcn): GraphConvolutionBlock(
    (convs): ModuleList(
      (0): GraphConvolution (64 -> 256)
      (1): GraphConvolution (256 -> 256)
    )
    (activation): ELU(alpha=1.0)
    (dense): Linear(in_features=256, out_features=256, bias=True)
  )
  (lstm): LSTM(256, 256, batch_first=True)
  (critic): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=1, bias=True)
    (3): Tanh()
  )
  (actor): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=1, bias=True)
    (3): Softmax()
  )
)

In [18]:
def get_path(agent, problem, vertex, num_mcts_sims=5, cpuct=1):
    """Build a walk over ``problem``'s graph with an MCTS-guided depth-first search.

    The walk starts at the first element of ``problem.get_state()`` (taken
    after ``problem.path`` has been reset to ``[vertex]``). At every step the
    neighbours of the current vertex are ordered by the scores returned by
    ``mcts.getActionProb`` and the best not-yet-visited neighbour is entered.
    When no unvisited neighbour remains, the walk detours over every remaining
    edge to a later-discovered neighbour (``last -> node -> last``) and then
    steps back to the parent on the DFS stack.

    Args:
        agent: network object; must provide ``embed_graph(edges)`` and is
            handed to ``MCTS`` as ``nnet``.
        problem: problem instance; must provide ``get_edges()``, ``edges``,
            ``num_edges``, ``get_state()`` and a writable ``path`` attribute.
            NOTE: ``problem.path`` is overwritten with ``[vertex]``.
        vertex: start vertex written into ``problem.path``.
            NOTE(review): the walk's actual first vertex is
            ``problem.get_state()[0]``; presumably that equals ``vertex`` --
            confirm against the problem class.
        num_mcts_sims: value forwarded to ``MCTS`` as ``numMCTSSims``
            (default 5, the previously hard-coded value).
        cpuct: value forwarded to ``MCTS`` as ``cpuct`` (default 1, the
            previously hard-coded value).

    Returns:
        list: the sequence of visited vertices, beginning with the source.
    """
    with torch.no_grad():
        edges = problem.get_edges()
        path_length = 2 * problem.num_edges + 1
        problem.path = [vertex]
        path_buffer = []
        source = problem.get_state()[0]
        graph_emb = agent.embed_graph(problem.edges)
        mcts = MCTS(game=problem, nnet=agent, graph_emb=graph_emb,
                    numMCTSSims=num_mcts_sims, cpuct=cpuct, edges=edges,
                    path_length=path_length)

        random_walk = [source]
        # checked[v]: neighbours of v whose edge to v has already been walked
        # (or will be walked by the tree step recorded at the same time).
        checked = ddict(set)
        stack = [source]
        visited = {source}
        # Discovery order, kept to attempt to get a maximal cover (possible to
        # do without ranks, but then there are no guarantees on maximality).
        # ranks[r] is the vertex discovered r-th; revranks is the inverse map.
        ranks = [source]
        revranks = {source: 0}

        while stack:
            last = stack[-1]  # always equal to random_walk[-1]
            lastrank = revranks[last]
            maxrank = len(ranks)  # rank the next discovered vertex would get

            # A copy of the walk is passed so the search cannot alter ours.
            # Assumes getActionProb yields one score per entry of edges[last],
            # in the same order -- TODO confirm against MCTS_Act_LSTM.
            pi = mcts.getActionProb(random_walk[:], path_buffer)
            # Highest score first; ties are broken by the larger neighbour
            # because the (score, neighbour) tuples are sorted in reverse.
            Nlast = [x for _, x in sorted(zip(pi, edges[last]), reverse=True)]

            for neighbor in Nlast:
                if neighbor not in visited:
                    # Tree step: descend into the best unvisited neighbour.
                    random_walk.append(neighbor)
                    stack.append(neighbor)
                    checked[last].add(neighbor)
                    visited.add(neighbor)
                    revranks[neighbor] = maxrank
                    ranks.append(neighbor)
                    break
            else:
                # Every neighbour is visited. Walk each remaining edge to a
                # later-discovered neighbour, most recent first. Rank
                # lastrank + 1 is left out: a vertex holding that rank while
                # `last` is still on the stack was entered from `last`, so
                # that edge is already recorded.
                neighbor_set = set(Nlast)
                for r in range(maxrank - 1, lastrank + 1, -1):
                    node = ranks[r]
                    if node in neighbor_set and node not in checked[last]:
                        checked[last].add(node)
                        random_walk.extend([node, last])

                # Backtrack to the parent, if there is one.
                stack.pop()
                if stack:
                    random_walk.append(stack[-1])
                    checked[last].add(stack[-1])
    return random_walk

In [19]:
# Problem source consumed with next() in the following cell; the arguments
# request 15 vertices and degree 6.
problem_maker = generate_regular_problems(num_vertices=15, degree=6)

In [20]:
# Draw one problem instance from problem_maker.
p = next(problem_maker)

In [21]:
# Run the MCTS-guided search on p; 5 is passed as get_path's `vertex` argument.
# Note that get_path overwrites p.path.
path = get_path(agent, p, 5)

In [22]:
# Check the walk against the problem's edges; the helper is not imported by
# name, so it is expected to come from the `utils_mcts` star import.
is_valid_path_new(path, p.edges)

True

In [23]:
# Coverage check of the walk on p; the helper is not imported by name, so it
# is expected to come from the `utils_mcts` star import.
check_that_cover(path, p)

True