In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import os, sys
from IPython.display import clear_output
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from pandas import DataFrame
import torch, torch.nn as nn
import numpy as np
import random
import networkx as nx
from utils_mcts import ReplayBuffer, PathsBuffer, get_states_emb, convert_to_walk
from MCTS import MCTS
from problem_mcts import GraphProblem, generate_erdos_renyi_problems, generate_regular_problems
from network_mcts import Agent
import time

In [3]:
sys.path.insert(0, '..')
moving_average = lambda x, **kw: DataFrame({'x':np.asarray(x)}).x.ewm(**kw).mean().values

In [4]:
def replace(P, source, target):
    '''Replace last occurrence of source with source-target-source.'''
    assert source in P
    ix = len(P) - P[::-1].index(source)
    return P[:ix] + [target, P[ix - 1]] + P[ix:]

In [5]:
def covering_walk(graph, source):
    P = [0]  # supporting walk
    S = [0]  # stack of nodes to check
    node2anon = {source: 0}
    anon2node = {0: source}
    checked = dict()  # nodes that has been checked for edge
    degrees = graph.degree()
    while len(S) > 0:  # grow supporting walk in DFS manner
        curr = S[-1]
        x = max(P) + 1  # next node to check

        # check if there is a node in the neighborhood that has not been explored yet
        Ncurr = list(nx.neighbors(graph, anon2node[curr]))
        if random.uniform(0, 1) < 0.99:
            random.shuffle(Ncurr)  # option 1: random order
        else:
            Ncurr = sorted(Ncurr, key=lambda v: degrees[v], reverse=True)  # option 2: top-degree
            # Ncurr = sorted(Ncurr, key=lambda v: degrees[v], reverse=False)  # option 3: low-degree
        # print(anon2node[curr], Ncurr)
        for neighbor in Ncurr:
            if neighbor in node2anon:
                continue  # already visited
            else:
                node2anon[neighbor] = x
                anon2node[x] = neighbor
                S.append(x)
                checked.setdefault(curr, set()).add(x)
                P = replace(P, curr, x)  # move to it
                break
        else:
            S.pop()  # move back in the stack

        for u in range(x-1, curr, -1):  # u is already in the supporting walk
            # check if there is connection to already discovered nodes
            if u not in checked[curr]:  # see if we already checked this edge
                if anon2node[u] in graph[anon2node[curr]]:
                    P = replace(P, curr, u)
                checked.setdefault(curr, set()).add(u)

    cover = [anon2node[v] for v in P]
    return cover, P

In [6]:
#params
NUM_PROBLEMS = 50
NUM_EPISODES = 50
BATCH_SIZE = 32
NUM_MCSIMS = 20
NUM_UPDATES = 5
NUM_VERTICES = 15
DEGREE = 6
CPUCT = 1.0
THRESHOLD = 0.75
PATHS_BUFFER_CAPACITY = 1000
REPLAY_BUFFER_CAPACITY = 10000

In [7]:
moving_average = lambda x, **kw: DataFrame({'x':np.asarray(x)}).x.ewm(**kw).mean().values

In [8]:
#generate regular train graphs (n=15, d=6)
problem_maker = generate_regular_problems(num_vertices=NUM_VERTICES, degree=DEGREE)

In [9]:
#initialize agent
agent = Agent(hid_size=256, gcn_size=256, vertex_emb_size=64, num_vertices=NUM_VERTICES)

In [10]:
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-4)

In [11]:
#initialize buffers
path_buffer = PathsBuffer(capacity=PATHS_BUFFER_CAPACITY, threshold=THRESHOLD)
train_buffer = ReplayBuffer(capacity=REPLAY_BUFFER_CAPACITY)

In [12]:
# loss stats
pi_losses_history = []
v_losses_history = []

In [13]:
problems = [next(problem_maker) for i in range(NUM_PROBLEMS)]

In [14]:
start = time.time()
for iteration in trange(len(problems)):
    
    path_buffer.flush()
    
    problem = problems[iteration]
    """
    for i in range(20):
        cover, _ = covering_walk(problem.nx_graph, random.sample(list(problem.edges.keys()), 1)[0])
        path_buffer.push(cover)
    """
    
    path_length = 2*problem.num_edges+1
    
    for i in range(NUM_EPISODES):
        
        graph_emb = agent.embed_graph(problem.edges)
    
        problem.path = [random.sample(list(problem.edges.keys()), 1)[0]]
        
        mcts = MCTS(game=problem, nnet=agent, graph_emb=graph_emb,
                    numMCTSSims=NUM_MCSIMS, cpuct=CPUCT, path_length=path_length)

        trainExamples = []
        
        path = problem.get_state()
        
        while len(path) != path_length:
            with torch.no_grad():
                pi = mcts.getActionProb(path)
            trainExamples.append([path, pi, None])
            vertex = np.random.choice(len(pi), p=pi)
            path = problem.get_next_state(path, vertex)
        
        break
        path_buffer.push(path)
        if len(path_buffer) >= 10: 
            r = path_buffer.rank_path(path)
            for x in trainExamples:
                x[-1] = r
            train_buffer.push(trainExamples)
            
        if len(train_buffer) >= BATCH_SIZE:
            for i in range(NUM_UPDATES):
                batch = train_buffer.sample(BATCH_SIZE)
                paths, pis, vs = zip(*batch)
                embs = get_states_emb(paths, graph_emb)

                target_pis = torch.FloatTensor(np.array(pis))

                target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))

                out_pi, out_v = agent(embs)
                loss_pi = -torch.sum(target_pis*out_pi)/target_pis.size()[0]
                loss_v = torch.sum((target_vs-out_v.view(-1))**2)/target_vs.size()[0]
                total_loss = loss_pi + loss_v

                pi_losses_history.append(loss_pi.item())
                v_losses_history.append(loss_v.item())

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                if iteration % 5 == 0:
                    clear_output(True)
                    plt.figure(figsize=[12, 6])
                    plt.subplot(1,2,1)
                    plt.title('Policy error'); plt.grid()
                    plt.scatter(np.arange(len(pi_losses_history)), pi_losses_history, alpha=0.1)
                    plt.plot(moving_average(pi_losses_history, span=100, min_periods=100))

                    plt.subplot(1,2,2)
                    plt.title('Value error'); plt.grid()
                    plt.scatter(np.arange(len(v_losses_history)), v_losses_history, alpha=0.1)
                    plt.plot(moving_average(v_losses_history, span=10, min_periods=10))
                    plt.show()
        
    break
end = time.time()

  0%|          | 0/50 [00:00<?, ?it/s]

[0.10526315789473684, 0.0, 0.0, 0.0, 0.42105263157894735, 0.0, 0.2631578947368421, 0.0, 0.0, 0.0, 0.21052631578947367, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.34782608695652173, 0.34782608695652173, 0.0, 0.0, 0.30434782608695654, 0.0, 0.0]
[0.23076923076923078, 0.0, 0.0, 0.0, 0.3076923076923077, 0.0, 0.23076923076923078, 0.0, 0.0, 0.0, 0.23076923076923078, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.14814814814814814, 0.0, 0.0, 0.0, 0.0, 0.18518518518518517, 0.0, 0.0, 0.25925925925925924, 0.0, 0.18518518518518517, 0.2222222222222222, 0.0, 0.0]
[0.07692307692307693, 0.0, 0.0, 0.3076923076923077, 0.2692307692307692, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15384615384615385, 0.19230769230769232, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.3333333333333333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2857142857142857, 0.14285714285714285, 0.0, 0.23809523809523808, 0.0, 0.0]
[0.0, 0.0, 0.19230769230769232, 0.0, 0.0, 0.0, 0.0, 0.038461538461538464, 0.3076923076923077, 0.2692307692307692, 0.0, 0.19230769230769232, 0.0, 0.0, 0

[0.038461538461538464, 0.0, 0.0, 0.34615384615384615, 0.2692307692307692, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15384615384615385, 0.19230769230769232, 0.0, 0.0, 0.0]
[0.03571428571428571, 0.0, 0.17857142857142858, 0.0, 0.0, 0.0, 0.0, 0.10714285714285714, 0.25, 0.25, 0.0, 0.17857142857142858, 0.0, 0.0, 0.0]
[0.0, 0.18181818181818182, 0.13636363636363635, 0.4090909090909091, 0.2727272727272727, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.03571428571428571, 0.0, 0.17857142857142858, 0.0, 0.0, 0.0, 0.0, 0.10714285714285714, 0.25, 0.25, 0.0, 0.17857142857142858, 0.0, 0.0, 0.0]
[0.038461538461538464, 0.0, 0.0, 0.34615384615384615, 0.2692307692307692, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15384615384615385, 0.19230769230769232, 0.0, 0.0, 0.0]
[0.03571428571428571, 0.0, 0.17857142857142858, 0.0, 0.0, 0.0, 0.0, 0.10714285714285714, 0.25, 0.25, 0.0, 0.17857142857142858, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.375, 0.0, 0.20833333333333334, 0.25, 0.16666666666666666, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.17391




In [None]:
hours, rem = divmod(end-start, 3600)
minutes, seconds = divmod(rem, 60)
print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours),int(minutes),seconds))

In [18]:
p = problems[10]

In [19]:
p.path = [random.sample(list(p.edges.keys()), 1)[0]]

In [20]:
graph_emb = agent.embed_graph(p.edges)
path_length = 2*p.num_edges+1
mcts = MCTS(game=p, nnet=agent, graph_emb=graph_emb,
                    numMCTSSims=NUM_MCSIMS, cpuct=CPUCT, path_length=path_length)
path = p.get_state()
while len(path) != path_length:
    with torch.no_grad():
        pi = mcts.getActionProb(path)
    vertex = np.random.choice(len(pi), p=pi)
    path = p.get_next_state(path, vertex)
print(path)

[0.05263157894736842, 0.0, 0.21052631578947367, 0.0, 0.0, 0.0, 0.2631578947368421, 0.0, 0.47368421052631576, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.14285714285714285, 0.2857142857142857, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14285714285714285, 0.17857142857142858, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.21739130434782608, 0.0, 0.13043478260869565, 0.0, 0.34782608695652173, 0.0, 0.0, 0.08695652173913043, 0.21739130434782608, 0.0, 0.0]
[0.0, 0.0, 0.14814814814814814, 0.2962962962962963, 0.25925925925925924, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14814814814814814, 0.14814814814814814, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.11538461538461539, 0.0, 0.0, 0.0, 0.15384615384615385, 0.0, 0.2692307692307692, 0.23076923076923078, 0.0, 0.0, 0.23076923076923078, 0.0, 0.0]
[0.0, 0.16, 0.16, 0.32, 0.28, 0.0, 0.0, 0.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.18518518518518517, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18518518518518517, 0.3333333333333333, 0.0, 0.0, 0.0, 0.2962962962962963, 0.0, 0.0]
[0.0, 0.0, 0.14285714285714285, 0.28

[0.0, 0.0, 0.0, 0.30434782608695654, 0.0, 0.0, 0.0, 0.0, 0.2608695652173913, 0.21739130434782608, 0.0, 0.0, 0.21739130434782608, 0.0, 0.0]
[0.0, 0.19230769230769232, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19230769230769232, 0.3076923076923077, 0.0, 0.0, 0.0, 0.3076923076923077, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.2916666666666667, 0.0, 0.0, 0.0, 0.0, 0.25, 0.20833333333333334, 0.0, 0.0, 0.25, 0.0, 0.0]
[0.0, 0.0, 0.12, 0.32, 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16, 0.16, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.18181818181818182, 0.0, 0.13636363636363635, 0.0, 0.36363636363636365, 0.0, 0.0, 0.09090909090909091, 0.22727272727272727, 0.0, 0.0]
[0.09523809523809523, 0.0, 0.19047619047619047, 0.0, 0.0, 0.0, 0.23809523809523808, 0.0, 0.47619047619047616, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.17391304347826086, 0.0, 0.13043478260869565, 0.0, 0.34782608695652173, 0.0, 0.0, 0.13043478260869565, 0.21739130434782608, 0.0, 0.0]
[0.0, 0.16666666666666666, 0.16666666666666666, 0.3333333333333333, 0.29166666666

In [None]:
path_buffer.buffer[-1]

In [16]:
torch.save(agent, "./agent.pth")

In [17]:
agent = torch.load("./agent.pth")