In [30]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%env CUDA_VISIBLE_DEVICES=0

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: CUDA_VISIBLE_DEVICES=0


In [31]:
import os, sys
from IPython.display import clear_output
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from pandas import DataFrame
import torch, torch.nn as nn
import numpy as np
import random
import networkx as nx
from utils_mcts import *
from MCTS_Act_LSTM import MCTS
from problem_mcts import GraphProblem, generate_erdos_renyi_problems, generate_regular_problems, convert_graph
from network_mcts import AgentActLSTM, AgentAct
import time
import nn_utils
from collections import defaultdict as ddict

In [32]:
sys.path.insert(0, '..')
moving_average = lambda x, **kw: DataFrame({'x':np.asarray(x)}).x.ewm(**kw).mean().values

In [33]:
def replace(P, source, target):
    '''Replace last occurrence of source with source-target-source.'''
    assert source in P
    ix = len(P) - P[::-1].index(source)
    return P[:ix] + [target, P[ix - 1]] + P[ix:]

In [34]:
def covering_walk(graph, source):
    P = [0]  # supporting walk
    S = [0]  # stack of nodes to check
    node2anon = {source: 0}
    anon2node = {0: source}
    checked = dict()  # nodes that has been checked for edge
    degrees = graph.degree()
    while len(S) > 0:  # grow supporting walk in DFS manner
        curr = S[-1]
        x = max(P) + 1  # next node to check

        # check if there is a node in the neighborhood that has not been explored yet
        Ncurr = list(nx.neighbors(graph, anon2node[curr]))
        if random.uniform(0, 1) < 0.99:
            random.shuffle(Ncurr)  # option 1: random order
        else:
            Ncurr = sorted(Ncurr, key=lambda v: degrees[v], reverse=True)  # option 2: top-degree
            # Ncurr = sorted(Ncurr, key=lambda v: degrees[v], reverse=False)  # option 3: low-degree
        # print(anon2node[curr], Ncurr)
        for neighbor in Ncurr:
            if neighbor in node2anon:
                continue  # already visited
            else:
                node2anon[neighbor] = x
                anon2node[x] = neighbor
                S.append(x)
                checked.setdefault(curr, set()).add(x)
                P = replace(P, curr, x)  # move to it
                break
        else:
            S.pop()  # move back in the stack

        for u in range(x-1, curr, -1):  # u is already in the supporting walk
            # check if there is connection to already discovered nodes
            if u not in checked[curr]:  # see if we already checked this edge
                if anon2node[u] in graph[anon2node[curr]]:
                    P = replace(P, curr, u)
                checked.setdefault(curr, set()).add(u)

    cover = [anon2node[v] for v in P]
    return cover, P

In [35]:
#params
NUM_PROBLEMS = 10
NUM_EPISODES = 10
BATCH_SIZE = 32
NUM_MCSIMS = 5
NUM_UPDATES = 5
NUM_VERTICES = 15
DEGREE = 6
CPUCT = 1.0
THRESHOLD = 0.75
PATHS_BUFFER_CAPACITY = 1000
REPLAY_BUFFER_CAPACITY = 10000

In [36]:
moving_average = lambda x, **kw: DataFrame({'x':np.asarray(x)}).x.ewm(**kw).mean().values

In [37]:
#generate regular train graphs (n=15, d=6)
problem_maker = generate_regular_problems(num_vertices=NUM_VERTICES, degree=DEGREE)

In [38]:
#initialize agent
agent = AgentAct(hid_size=256, gcn_size=256, vertex_emb_size=64)

In [39]:
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-4)

In [40]:
#initialize buffers
path_buffer = PathsBuffer(capacity=PATHS_BUFFER_CAPACITY, threshold=THRESHOLD)
train_buffer = ReplayBuffer(capacity=REPLAY_BUFFER_CAPACITY)

In [41]:
# loss stats
pi_losses_history = []
v_losses_history = []

In [42]:
problems = [next(problem_maker) for i in range(NUM_PROBLEMS)]

In [43]:
i = 0

for k in trange(len(problems)):
    
    problem = problems[k]
    
    edges = problem.get_edges()

    for vertex in problem.get_actions():

        path_buffer.flush()
    
        PATH_LENGTH = 2*problem.num_edges + 1        
        
        i += 1
    
        for episode in range(NUM_EPISODES):
            
            problem.path = [vertex]
        
            source = problem.get_state()[0]
            
            states = []
            actions = []
            
            with torch.no_grad():
                graph_emb = agent.embed_graph(problem.edges)
                
            mcts = MCTS(game=problem, nnet=agent, graph_emb=graph_emb,
                        numMCTSSims=NUM_MCSIMS, cpuct=CPUCT, edges=edges)
            
            trainExamples = []
                
            random_walk = [source]
            checked = ddict(list)
            stack = [source]
            visited = {source}
            ranks = {0: source} # to attempt to get maximal cover (possible to do without rank, but then no guarantees on maximality)
            revranks = {source: 0}
            
            while len(stack) > 0:
                last = stack[-1]
                lastrank = revranks[last]
                maxrank = max(ranks.keys()) + 1
                with torch.no_grad():
                    pi = mcts.getActionProb(random_walk, path_buffer)
                Nlast = [x for _,x in sorted(zip(pi, edges[random_walk[:][-1]]), reverse=True)]
                #print("Is valid", all(i in edges[random_walk[:][-1]] for i in Nlast))
                # going in depth
                flag = False
                for neighbor in Nlast:
                    if neighbor not in visited:
                        trainExamples.append([random_walk[:], pi, None])
                        random_walk.append(neighbor)
                        stack.append(neighbor)
                        checked[last].append(neighbor)
                        visited.add(neighbor)
                        ranks[maxrank] = neighbor
                        revranks[neighbor] = maxrank
                        flag = True
                        break

                # interconnecting nodes that are already in walk
                if not flag:
                    for r in range(maxrank-1, lastrank+1, -1):
                        node = ranks[r]
                        if node not in checked[last] and node in Nlast:
                            checked[last].append(node)
                            random_walk.extend([node, last])

                if not flag:
                    stack.pop()
                    if len(stack) > 0:
                        random_walk.append(stack[-1])
                        checked[last].append(stack[-1])
                        
            path_buffer.push(random_walk)
            if len(path_buffer) >= 10: 
                r = path_buffer.rank_path(random_walk)
                for x in trainExamples:
                    x[-1] = r
                train_buffer.push(trainExamples)
            
        if len(train_buffer) >= BATCH_SIZE:
            print("Start training!")
            for i in range(NUM_UPDATES):
                batch = train_buffer.sample(BATCH_SIZE)
                paths, pis, vs = zip(*batch)
                graph_emb = agent.embed_graph(problem.edges)
                out_pi, out_v = agent.get_dist(paths, graph_emb, edges)
                
                target_vs = torch.tensor(vs)
                
                losses_pi = []
                for i, p in enumerate(pis):
                    losses_pi.append(torch.tensor(p)*torch.log(out_pi[i]))
            
                loss_pi = -torch.sum(torch.stack(losses_pi))/len(pis)
                print(target_vs)
                print(out_v.view(-1))
                break
                loss_v = torch.sum((target_vs-out_v.view(-1))**2)/target_vs.size()[0]
                total_loss = loss_pi + loss_v

                pi_losses_history.append(loss_pi.item())
                v_losses_history.append(loss_v.item())

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                clear_output(True)
                plt.figure(figsize=[12, 6])
                plt.subplot(1,2,1)
                plt.title('Policy error'); plt.grid()
                plt.scatter(np.arange(len(pi_losses_history)), pi_losses_history, alpha=0.1)
                plt.plot(moving_average(pi_losses_history, span=100, min_periods=100))

                plt.subplot(1,2,2)
                plt.title('Value error'); plt.grid()
                plt.scatter(np.arange(len(v_losses_history)), v_losses_history, alpha=0.1)
                plt.plot(moving_average(v_losses_history, span=10, min_periods=10))
                plt.show()

  0%|          | 0/10 [00:00<?, ?it/s]

Start training!
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([-0.0912,  0.0102,  0.0235,  0.0278, -0.0479, -0.0618, -0.0709, -0.0642,
        -0.0289,  0.0129, -0.0615, -0.1083, -0.0521,  0.0248, -0.0509, -0.0787,
        -0.0677, -0.0204, -0.0495, -0.0024, -0.0610, -0.0630,  0.0248, -0.0607,
        -0.0618, -0.0271,  0.0038, -0.1153, -0.0653, -0.0811,  0.0110, -0.0842],
       grad_fn=<ViewBackward>)
Start training!
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([ 0.0172, -0.0787, -0.0677, -0.0615, -0.0630,  0.0278,  0.0055, -0.0755,
        -0.1083, -0.0543,  0.0129, -0.0544, -0.0713,  0.0287, -0.0757, -0.0024,
        -0.0495, -0.0509, -0.0780, -0.0539, -0.0521, -0.0289,  0.0248,  0.0151,
         0.0110, -0.0521, -0.0610,  0.0338,  0.0102,  0.0045, -0.0811,  0.0235],
     




KeyboardInterrupt: 

In [None]:
print(pi)

In [None]:
Nlast

In [None]:
pis

In [None]:
out_pi

In [None]:
def test_agent(agent, problem, vertex):
    problem.path = [vertex]
    edges = problem.get_edges()
    with torch.no_grad():
        graph_emb = agent.embed_graph(problem.edges)
        mcts = MCTS(game=problem, nnet=agent, graph_emb=graph_emb,
                        numMCTSSims=NUM_MCSIMS, cpuct=CPUCT, edges=edges)
        
        source = problem.get_state()[0]
        
        random_walk = [source]
        checked = ddict(list)
        stack = [source]
        visited = {source}
        ranks = {0: source} # to attempt to get maximal cover (possible to do without rank, but then no guarantees on maximality)
        revranks = {source: 0}
        
        while len(stack) > 0:
            last = stack[-1]
            lastrank = revranks[last]
            maxrank = max(ranks.keys()) + 1
            pi = mcts.getActionProb(random_walk, path_buffer)
            Nlast = [x for _,x in sorted(zip(pi, edges[random_walk[:][-1]]), reverse=True)]
            flag = False
            for neighbor in Nlast:
                if neighbor not in visited:
                    trainExamples.append([random_walk, pi, None])
                    random_walk.append(neighbor)
                    stack.append(neighbor)
                    checked[last].append(neighbor)
                    visited.add(neighbor)
                    ranks[maxrank] = neighbor
                    revranks[neighbor] = maxrank
                    flag = True
                    break

            if not flag:
                for r in range(maxrank-1, lastrank+1, -1):
                    node = ranks[r]
                    if node not in checked[last] and node in Nlast:
                        checked[last].append(node)
                        random_walk.extend([node, last])

            if not flag:
                stack.pop()
                if len(stack) > 0:
                    random_walk.append(stack[-1])
                    checked[last].append(stack[-1])
    return random_walk

In [None]:
p = next(problem_maker)

In [None]:
edges = p.get_edges()

In [None]:
path = test_agent(agent, p, 0)

In [None]:
is_valid_path_new(path, edges)

In [None]:
G = nx.random_regular_graph(4, 20)
G2 = relabel_graph(G)
print(15, graph_isomorphism_algorithm_covers(G, G2, agent, test_agent, 2))

In [None]:
x = {1, 4, 5}

In [None]:
list(x)