In [1]:
# Re-import edited modules (including the local *_ac modules) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Expose only GPU 0 to this process; must be set before torch is imported / initializes CUDA.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [7]:
import os, sys
from IPython.display import clear_output
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from pandas import DataFrame
import torch, torch.nn as nn
import numpy as np
import random
import networkx as nx
from utils_ac import random_walk, ReplayBuffer, PathsBuffer, get_states_emb, convert_to_walk, is_valid_path, covering_walk 
from problem_ac import GraphProblem, generate_erdos_renyi_problems, generate_regular_problems
from network_ac import ActorCriticAct
import time

In [8]:
# Make modules in the parent directory importable. The guard keeps re-runs of this
# cell from prepending duplicate '..' entries to sys.path.
# NOTE(review): this runs after the project imports in the previous cell, so it does
# not influence how utils_ac / problem_ac / network_ac were resolved — confirm it is needed.
if '..' not in sys.path:
    sys.path.insert(0, '..')


def moving_average(x, **kw):
    """Exponentially weighted moving average of a 1-D sequence.

    Args:
        x: sequence of numbers (e.g. a list of per-step losses).
        **kw: forwarded to ``pandas.Series.ewm`` (e.g. ``span``, ``min_periods``).

    Returns:
        NumPy array with the same length as ``x``; positions with fewer than
        ``min_periods`` observations are NaN.
    """
    return DataFrame({'x': np.asarray(x)}).x.ewm(**kw).mean().values

In [9]:
# --- Experiment parameters -----------------------------------------------------
NUM_PROBLEMS = 15              # number of graphs drawn from problem_maker for training
NUM_EPISODES = 50              # sampled walks per start vertex
NUM_VERTICES = 15              # vertex count passed to generate_regular_problems
DEGREE = 6                     # degree passed to generate_regular_problems
THRESHOLD = 0.75               # PathsBuffer(threshold=...); semantics live in utils_ac — TODO document
PATHS_BUFFER_CAPACITY = 100    # capacity of the walk-ranking buffer
REPLAY_BUFFER_CAPACITY = 100   # capacity of the ReplayBuffer
SEED = 0                       # global seed so Restart & Run All is reproducible

# Seed every RNG the notebook uses before problems are generated and the agent's
# weights are initialised (later cells). Assumes the project generators draw from
# these global RNGs — confirm in problem_ac if results still vary between runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [10]:
# Source of training/test problems; consumed with next() further down.
problem_maker = generate_regular_problems(
    degree=DEGREE,
    num_vertices=NUM_VERTICES,
)

In [11]:
# Actor-critic policy network (network_ac); layer sizes passed by keyword.
agent = ActorCriticAct(
    vertex_emb_size=64,
    gcn_size=256,
    hid_size=256,
)

In [12]:
# A single Adam optimizer updates actor and critic parameters together.
optimizer = torch.optim.Adam(params=agent.parameters(), lr=1e-4)

In [13]:
# (Re)create the buffers used during training.
# path_buffer: finished walks that new walks are ranked against.
path_buffer = PathsBuffer(
    capacity=PATHS_BUFFER_CAPACITY,
    threshold=THRESHOLD,
)
# NOTE(review): train_buffer is not referenced by any visible cell.
train_buffer = ReplayBuffer(
    capacity=REPLAY_BUFFER_CAPACITY,
)

In [14]:
# Per-update loss histories appended to by the training loop and plotted there.
actor_losses, critic_losses = [], []

In [15]:
# Draw a fixed set of training problems up front.
problems = [next(problem_maker) for _ in range(NUM_PROBLEMS)]

In [38]:
# Ranked-reward actor-critic training.
# For every problem and every start vertex, sample NUM_EPISODES walks of PATH_LENGTH
# vertices with the current policy. Score each finished walk with path_buffer.rank_path,
# and use that score as the return for every step of the walk.


def _plot_losses(actor_losses, critic_losses):
    """Redraw the actor/critic loss histories: raw points plus an EWM trend line."""
    clear_output(True)
    plt.figure(figsize=[12, 6])

    plt.subplot(1, 2, 1)
    plt.title('Actor loss'); plt.grid()
    plt.scatter(np.arange(len(actor_losses)), actor_losses, alpha=0.1)
    plt.plot(moving_average(actor_losses, span=100, min_periods=100))

    plt.subplot(1, 2, 2)
    plt.title('Critic loss'); plt.grid()
    plt.scatter(np.arange(len(critic_losses)), critic_losses, alpha=0.1)
    plt.plot(moving_average(critic_losses, span=10, min_periods=10))
    plt.show()


# Number of start vertices processed so far; drives the logging cadence below.
# It must not share a name with any inner loop variable, or the cadence breaks.
vertex_count = 0

for problem in tqdm(problems):
    # edges[v] is the list of valid next vertices from v; the policy's outputs index into it.
    edges = problem.get_edges()
    # The walk length depends only on the problem, so compute it once per problem.
    # Kept as a notebook global because test_agent (below) reads PATH_LENGTH.
    PATH_LENGTH = 2 * problem.num_edges + 1

    for vertex in problem.get_actions():
        path_buffer.flush()
        vertex_count += 1

        for _episode in range(NUM_EPISODES):
            # --- Roll out one walk without tracking gradients ----------------------
            problem.path = [vertex]
            path = problem.get_state()
            states = []
            actions = []

            with torch.no_grad():
                graph_emb = agent.embed_graph(problem.edges)

            for _ in range(PATH_LENGTH - 1):
                with torch.no_grad():
                    probs, _value = agent.get_dist([path], graph_emb, edges)
                # The sampled index refers to a position in edges[current vertex].
                valids = edges[path[-1]]
                dist = torch.distributions.Categorical(probs[0])
                next_vertex = valids[dist.sample().item()]

                actions.append(next_vertex)
                states.append(path)
                path = problem.get_next_state(path, next_vertex)

            # Push each finished walk exactly once; pushing it twice would duplicate
            # it in the buffer that walks are ranked against.
            path_buffer.push(path)

            # --- Update once the buffer holds enough walks to rank against --------
            if len(path_buffer) >= 10:
                reward = path_buffer.rank_path(path)
                rewards = torch.FloatTensor([reward] * (PATH_LENGTH - 1))

                # Re-embed outside no_grad so gradients flow into the graph encoder.
                graph_emb = agent.embed_graph(problem.edges)
                probs, values = agent.get_dist(states, graph_emb, edges)

                log_probs = []
                for step, step_probs in enumerate(probs):
                    # Recover the index of the taken action within that step's options.
                    valids = edges[states[step][-1]]
                    action_idx = valids.index(actions[step])
                    step_dist = torch.distributions.Categorical(step_probs)
                    log_probs.append(step_dist.log_prob(torch.tensor(action_idx)))

                # values presumably has shape (len(states), 1) — TODO confirm in network_ac.
                advantage = rewards - values.squeeze(1)

                # The actor uses a detached advantage so only the critic loss trains the value head.
                actor_loss = -(torch.stack(log_probs) * advantage.detach()).mean()
                critic_loss = advantage.pow(2).mean()
                actor_losses.append(actor_loss.item())
                critic_losses.append(critic_loss.item())

                loss = actor_loss + critic_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if vertex_count % 25 == 0:
            print(path_buffer.buffer)
        if vertex_count % 5 == 0:
            _plot_losses(actor_losses, critic_losses)

  7%|▋         | 1/15 [04:20<1:00:42, 260.21s/it]

KeyboardInterrupt: 

In [None]:
def test_agent(agent, problem, vertex):
    problem.path = [vertex]
    i = 0
    with torch.no_grad():
        graph_emb = agent.embed_graph(problem.edges)
        path = problem.get_state()
        for _ in range(PATH_LENGTH-1):
            probs, v = agent(get_states_emb([path], graph_emb))
            valids = problem.get_valid_actions(path[-1])
            probs = probs*torch.FloatTensor(valids)
            probs_sum = torch.sum(probs)
            if probs_sum > 0:
                probs /= probs_sum
            else:
                print("All valid moves were masked, do workaround.")
                probs = probs + valids
                probs /= np.sum(probs)
            dist = torch.distributions.Categorical(probs)
            next_vertex = dist.sample().item()
            path = problem.get_next_state(path, next_vertex)
    return path

In [None]:
# Draw a fresh problem (not one of the training `problems`) to evaluate the agent on.
p = next(problem_maker)

In [None]:
# Pick one start vertex uniformly at random from the test problem's vertices.
(v,) = random.sample(list(p.edges.keys()), k=1)

In [None]:
# Roll out one walk on the held-out problem from the chosen start vertex.
path = test_agent(agent=agent, problem=p, vertex=v)

In [None]:
# Show the walk produced by the agent.
print(path)