<a href="https://colab.research.google.com/github/neo-pan/RL_GNN_TSP/blob/master/graph_conv_net_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Mount Google Drive


In [0]:
# Mount Google Drive at /gdrive so the project folder used below is reachable.
from google.colab import drive
drive.mount('/gdrive')

In [0]:
import os
# Work from the project folder on Drive; relative paths such as "runs" and
# "model" resolve against it, and datasets are read from ./data.
os.chdir("/gdrive/My Drive/NIPS2020")
current_path = os.getcwd()
data_root = os.path.join(current_path, "data")

## Install Libraries and Extensions

In [0]:
!pip install dgl-cu101 memory_profiler line_profiler

In [0]:
%load_ext memory_profiler
%load_ext line_profiler

In [0]:
# Start TensorBoard inline, reading the event files written under ./runs.
%reload_ext tensorboard
%tensorboard --logdir runs

## Import libraries and Set config


In [0]:
import gc
import dgl
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl.function as dglfn

from copy import deepcopy
from tqdm.notebook import tqdm
from scipy.spatial.distance import cdist
# Set logger
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
ch.setLevel(logging.DEBUG)
# Guard so re-running this cell does not attach a second handler (which would
# print every record twice). NOTE: hasHandlers() also counts handlers on
# ancestor loggers, so if the root logger already has one, `ch` is not added.
if not logger.hasHandlers():
  logger.addHandler(ch)

# Experiment tag: names both the TensorBoard run directory and the checkpoint
# folder (config.model_path).
exp_index = "s2v_dqn_50_100_cluster_3"
# Tensorboard writer
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(f"runs/{exp_index}")


In [0]:
class config():
  # Static hyper-parameter namespace; every consumer in this notebook reads the
  # class attributes directly (it is never instantiated).

  # --- paths / device ---
  model_path = os.path.join(current_path, "model", exp_index)
  device = torch.device("cpu:0")

  # --- network ---
  node_dim = 6          # width of ndata["feat"] built by modify_graph_feature
  edge_dim = 4          # width of edata["feat"] built by modify_graph_feature
  embed_dim = 64
  activation = F.relu
  max_bp_iter = 4       # message-passing rounds in S2VTSPNet
  reg_hidden = 32       # hidden width of the Q-value head (<= 0: single layer)
  w_scale = 0.01        # not referenced elsewhere in this notebook's visible code

  # --- RL / replay ---
  num_env = 1
  decay = 0.1           # multiplier on the bootstrapped max-Q term of the target
  n_step = 1
  mem_size = 50000

  # --- problem size (max_n is also the TSPEnv reward normaliser) ---
  max_n = 50
  min_n = 40

  # --- optimisation ---
  batch_size = 128
  learning_rate = 0.0001
  max_iter = 200000


## Prepare Dataset

In [0]:
from dgl.data.utils import load_graphs

# Pre-generated TSP instances stored as DGL binary graph files.
train_dataset_path = os.path.join(data_root, "tsp_train_50_100_cluster.bin")
val_dataset_path = os.path.join(data_root, "tsp_val_50_100_cluster.bin")
assert os.path.exists(train_dataset_path) and os.path.exists(val_dataset_path)

# load_graphs returns (graph_list, labels); only the graphs are used.
train_dataset, _ = load_graphs(train_dataset_path)
val_dataset, _ = load_graphs(val_dataset_path)

In [0]:
# from torch.utils.data.dataloader import DataLoader

# def collate_fn(graphs):
#   batched_graph = dgl.batch(graphs)
#   return batched_graph

# train_set = DataLoader(train_dataset, batch_size = config.batch_size,
#             collate_fn = collate_fn)
# val_set = DataLoader(val_dataset, batch_size = config.batch_size,
#             collate_fn = collate_fn)

class GraphPool():
  """Indexable collection of graphs that can also hand out a random member.

  `size` is fixed at construction; `len()` reflects the underlying list.
  """

  def __init__(self, g_list):
    self.graph_pool = g_list
    self.size = len(g_list)

  def __len__(self):
    return len(self.graph_pool)

  def __getitem__(self, key):
    # Only keys equal to some value in 0..size-1 pass (no negatives, no slices).
    assert key in range(self.size)
    return self.graph_pool[key]

  def sample(self):
    """Return one graph drawn uniformly at random via np.random."""
    chosen = np.random.randint(self.size)
    return self[chosen]


In [0]:
# Wrap the loaded graph lists for indexing and uniform random sampling.
train_set = GraphPool(train_dataset)
val_set = GraphPool(val_dataset)

## Utilities

In [0]:
def edge_dist_init(edges):
  """DGL edge UDF: Euclidean distance between each edge's two endpoints.

  Reads the "pos" node feature of source and destination and returns
  {"dist": tensor of shape (num_edges, 1)}.
  """
  delta = edges.src["pos"] - edges.dst["pos"]
  squared_norm = (delta * delta).sum(dim=1, keepdim=True)
  return {"dist": squared_norm.sqrt()}


In [0]:
def modify_graph_feature(g, states, actions=None):
  """Build a feature-annotated copy of a batched TSP graph for the Q-network.

  The returned graph has the same nodes as `g`. It also has edges along each
  partial tour. Edges that join two covered nodes but are not part of that
  tour are removed.

  Args:
    g: batched DGL graph whose nodes carry ndata["pos"].
    states: one partial tour per graph in the batch, given as per-graph node
      indices in tour order.
    actions: optional, one per-graph node index per graph. When given,
      ndata["action_select"] is a (num_nodes, batch_size) matrix with a single
      1 per column, at the selected node of that graph.

  Returns:
    A new DGL graph moved to `config.device`, carrying ndata["feat"]
    (config.node_dim columns) and edata["feat"] (config.edge_dim columns).

  Feature columns written below (all other columns keep their fill value):
    node feat (filled with 1): [0:2] position; [2] 0 for each graph's first
      node; [3] 0 for a first node that is covered; [4] 0 if covered.
    edge feat (filled with 0): [0] 1 if the source is covered; [1] distance;
      [2] 1 if exactly one endpoint is covered; [3] constant 1.
  """
  has_actions = actions is not None
  assert g.batch_size == len(states)
  if has_actions:
    assert g.batch_size == len(actions)
  batch_size = g.batch_size

  # Rebuild the graph from its edge list instead of deep-copying it; only the
  # node positions are carried over from `g`.
  s, d = g.edges()
  graph = dgl.DGLGraph((s, d))
  graph.ndata["pos"] = g.ndata["pos"].clone()
  graph.readonly(False)

  # Batched index of each graph's first node.
  cum_num_nodes = np.cumsum([0] + g.batch_num_nodes[:-1])

  # Per-graph tour node indices -> batched node indices.
  covered_list = [np.add(states[i], cum_num_nodes[i]) for i in range(batch_size)]
  all_covered = np.concatenate(covered_list, axis=0)

  # First, add the edges along each partial tour (in both directions).
  tour_list = []
  for covered in covered_list:
    if len(covered) == 1:
      continue
    srcs = covered.tolist()
    dsts = srcs[1:] + srcs[0:1]
    if len(covered) == 2:
      # The cyclic pairs (a, b), (b, a) already cover both directions.
      tour_pair = np.array((srcs, dsts))
    else:
      tour_pair = np.array((srcs + dsts, dsts + srcs))
    tour_list.append(tour_pair)
  if tour_list:
    all_tour = np.concatenate(tour_list, axis=1)
    tour_in_graph = graph.has_edges_between(all_tour[0, :], all_tour[1, :])
    # Select the missing pairs with a 1-D index array so the result is always
    # (2, k). The previous `np.where(...)` + `.squeeze()` collapsed to shape
    # (2,) when exactly one edge was missing, which broke `[0, :]`.
    missing = np.flatnonzero(np.asarray(tour_in_graph) == 0)
    tour_to_add = all_tour[:, missing]
    if tour_to_add.shape[1] > 0:
      graph.add_edges(tour_to_add[0, :], tour_to_add[1, :])
  else:
    # No tour has two or more nodes yet; use an integer-typed empty (2, 0) array.
    all_tour = np.zeros((2, 0), dtype=np.int64)

  # Recalculate distances (so the tour edges added above get one too).
  graph.apply_edges(edge_dist_init)

  # Second, remove edges between covered nodes that are not tour edges.
  _edges_src_covered = graph.out_edges(all_covered, form="eid")
  _edges_dst_covered = graph.in_edges(all_covered, form="eid")
  edges_bothside_covered = np.intersect1d(
      _edges_src_covered,
      _edges_dst_covered,
      assume_unique=True
  )
  edges_to_reserve = graph.edge_ids(all_tour[0, :], all_tour[1, :])
  edges_to_remove = np.setdiff1d(
      edges_bothside_covered,
      edges_to_reserve,
      assume_unique=True
  )
  graph.remove_edges(edges_to_remove)

  # Initialize node and edge features.
  num_nodes = graph.number_of_nodes()
  if has_actions:
    graph.init_ndata(
        ndata_name="action_select",
        shape=(num_nodes, batch_size),
        dtype="float32"
    )
  graph.init_ndata(
      ndata_name="feat",
      shape=(num_nodes, config.node_dim),
      dtype="float32"
  )
  graph.init_edata(
      edata_name="feat",
      shape=(graph.number_of_edges(), config.edge_dim),
      dtype="float32"
  )
  if has_actions:
    graph.ndata["action_select"].fill_(0.0)
  graph.ndata["feat"].fill_(1.0)
  graph.edata["feat"].fill_(0.0)

  # One column per graph, with a 1 at that graph's selected (batched) node.
  if has_actions:
    assert len(actions) == cum_num_nodes.shape[0]
    action_nodes = np.add(actions, cum_num_nodes)
    graph.ndata["action_select"][action_nodes, np.arange(batch_size)] = 1.0

  start_nodes_covered = np.intersect1d(cum_num_nodes, all_covered, assume_unique=True)

  graph.ndata["feat"][cum_num_nodes, 2] = 0.0
  graph.ndata["feat"][start_nodes_covered, 3] = 0.0
  graph.ndata["feat"][all_covered, 4] = 0.0

  edges_src_covered = graph.out_edges(all_covered, form="eid")
  edges_dst_covered = graph.in_edges(all_covered, form="eid")
  edges_oneside_covered = np.setxor1d(
    edges_src_covered,
    edges_dst_covered,
    assume_unique=True
  )

  graph.edata["feat"][edges_src_covered, 0] = 1.0
  graph.edata["feat"][edges_oneside_covered, 2] = 1.0
  graph.edata["feat"][:, 3].fill_(1.0)

  graph.ndata["feat"][:, 0:2] = graph.ndata["pos"].clone()
  graph.edata["feat"][:, 1:2] = graph.edata["dist"].clone()

  return graph.to(config.device)

In [0]:
def modify_graph_feature_rebuild(g, states, actions=None):
  """Alias of `modify_graph_feature`, kept so existing callers keep working.

  This cell used to hold a line-for-line copy of `modify_graph_feature` (both
  rebuild the graph from its edge list). Delegating keeps the two from
  drifting apart: any fix made there now applies here too.
  """
  return modify_graph_feature(g, states, actions)

## Build Net Model

In [0]:
class S2VLayer(nn.Module):
  """One structure2vec message-passing round over a DGL graph.

  For node embeddings mu this computes
    act(W1 @ sum_{u -> v} act(W0 @ mu_u + e_uv) + W2 @ mu_v)
  where e_uv is edata["edge_init"], which must already be present on the graph.
  """

  def __init__(self, embed_dim, activation):
    super().__init__()
    self.activation = activation
    # Three square, bias-free projections (W0, W1, W2 above), registered in this order.
    self.node_linear_conv = nn.Linear(embed_dim, embed_dim, bias=False)
    self.trans_node_1 = nn.Linear(embed_dim, embed_dim, bias=False)
    self.trans_node_2 = nn.Linear(embed_dim, embed_dim, bias=False)

  def edge_msg(self, edges):
    """Per-edge message: act(projected source embedding + edge embedding)."""
    summed = edges.src["node_msg"] + edges.data["edge_init"]
    return {"edge_msg": self.activation(summed)}

  def forward(self, g, cur_node_embed):
    """Return updated node embeddings; fields set here live only in a local scope."""
    with g.local_scope():
      g.ndata["node_msg"] = self.node_linear_conv(cur_node_embed)
      g.apply_edges(self.edge_msg)
      # Sum the incoming edge messages into each destination node.
      g.update_all(
          message_func=dglfn.copy_e(e="edge_msg", out="msg"),
          reduce_func=dglfn.sum(msg="msg", out="node_reduce")
      )
      aggregated = self.trans_node_1(g.ndata["node_reduce"])
      self_term = self.trans_node_2(cur_node_embed)
      return self.activation(aggregated + self_term)


In [0]:
class S2VTSPNet(nn.Module):
  """structure2vec node encoder followed by a linear regression head for Q-values.

  Args:
    node_dim / edge_dim: widths of ndata["feat"] / edata["feat"].
    embed_dim: latent embedding width.
    activation: nonlinearity applied to the initial node embedding and inside
      every message-passing round.
    max_bp_iter: number of message-passing rounds. One S2VLayer is shared
      across all rounds.
    reg_hidden: hidden width of the head; <= 0 means a single linear layer.
  """
  def __init__(self, node_dim, edge_dim, embed_dim,
            activation, max_bp_iter, reg_hidden):
    super().__init__()
    self.activation = activation
    self.max_bp_iter = max_bp_iter
    self.node_to_latent = nn.Linear(node_dim, embed_dim, bias=False)
    self.edge_to_latent = nn.Linear(edge_dim, embed_dim, bias=False)
    self.s2v_layer = S2VLayer(embed_dim, self.activation)

    if reg_hidden > 0:
      # NOTE(review): there is no nonlinearity between these two layers, so
      # they compose to a single linear map. The S2V-DQN head is usually
      # described with a ReLU on the hidden layer. Left unchanged to keep
      # existing checkpoints' behaviour; confirm the intent.
      self.reg_layers = nn.Sequential(
          nn.Linear(embed_dim, reg_hidden, bias=False),
          nn.Linear(reg_hidden, 1, bias=False)
      )
    else:
      self.reg_layers = nn.Linear(embed_dim, 1, bias=False)

  def _forward(self, g):
    """Encode the graph; returns (num_nodes, embed_dim) node embeddings.

    Side effect: stores the projected edge features in g.edata["edge_init"].
    """
    node_init = self.node_to_latent(g.ndata["feat"])
    g.edata["edge_init"] = self.edge_to_latent(g.edata["feat"])
    cur_node_embed = self.activation(node_init)
    for _ in range(self.max_bp_iter):
      cur_node_embed = self.s2v_layer(g, cur_node_embed)
    return cur_node_embed

  def forward(self, g):
    """Q-value of the selected node in each graph of the batch.

    Expects g.ndata["action_select"] to be the (num_nodes, batch) selection
    matrix built by modify_graph_feature. Returns a (batch, 1) tensor.
    """
    cur_node_embed = self._forward(g)
    # Row b of action_select^T picks out graph b's selected node embedding.
    action_embed = torch.matmul(g.ndata["action_select"].t(), cur_node_embed)
    # Call the head directly: this works both for nn.Sequential and for the
    # bare nn.Linear used when reg_hidden <= 0. Iterating over the latter
    # raised TypeError.
    return self.reg_layers(action_embed)

  def predict_all(self, g):
    """Q-value for every node; returns a (num_nodes, 1) tensor."""
    return self.reg_layers(self._forward(g))


In [0]:
class DQNTSPModel(nn.Module):
  """DQN wrapper holding an online network (`eval_net`) and a target network.

  The target network has gradients disabled and only changes through
  `update_target_network`.
  """
  def __init__(self, node_dim, edge_dim, embed_dim,
            activation, max_bp_iter, reg_hidden):
    super().__init__()
    self.eval_net = S2VTSPNet(node_dim, edge_dim, embed_dim,
            activation, max_bp_iter, reg_hidden)
    self.target_net = S2VTSPNet(node_dim, edge_dim, embed_dim,
            activation, max_bp_iter, reg_hidden)

    self.target_net.requires_grad_(False)

  def forward(self, graph):
    """Q-value of each graph's selected action, from the online network."""
    return self.eval_net(graph)

  def predict(self, graph):
    """Target-network Q-values for every node, with covered nodes masked out.

    Nodes whose ndata["feat"][:, 4] is 0 (covered, as encoded by
    modify_graph_feature) get -inf, so argmax/max never pick them.
    Returns a (num_nodes, 1) tensor computed without autograd.

    NOTE(review): validate() and Simulator.run() also call this, so they act
    with the target network rather than eval_net; confirm that is intended.
    """
    with torch.no_grad():
      q_on_all = self.target_net.predict_all(graph)
      covered = graph.ndata["feat"][:, 4].unsqueeze(-1) == 0.0
      # masked_fill keeps q's own device and dtype, instead of building the
      # -inf constant on the global config.device.
      return q_on_all.masked_fill(covered, float("-inf"))

  def update_target_network(self):
    """Copy the online network's weights into the target network."""
    self.target_net.load_state_dict(
        self.eval_net.state_dict()
    )

## TSP Environment

In [0]:
class TSPEnv():
  """Single-graph TSP environment that builds a tour by cheapest insertion.

  The tour always starts as [0]. Each step inserts one new node at the
  position that minimises the added length. The reward is that added length
  divided by `norm`.
  """
  def __init__(self, norm):
    self.norm = norm
    self.graph = None # Graph used in GNN
    self.graph_size = 0
    self.dist_matrix = None
    self.state_seq = []     # copy of the tour before each step
    self.act_seq = []       # node inserted at each step
    self.action_list = []   # current tour, in visiting order
    self.reward_seq = []
    self.sum_rewards = []   # per-step rewards; ReplayMemory.add_env turns these into suffix sums in place
    self.partial_set = set()
    self.trajectory_len = 0

  def clear(self):
    """Reset all per-episode state; the tour restarts as [0]."""
    self.dist_matrix = None
    self.state_seq.clear()
    self.act_seq.clear()
    self.action_list.clear()
    self.reward_seq.clear()
    self.sum_rewards.clear()
    self.partial_set.clear()
    self.action_list.append(0)
    self.partial_set.add(0)
    self.trajectory_len = 0

  def reset(self, graph):
    """Start a new episode on `graph` (needs ndata["pos"])."""
    self.clear()
    self.graph = graph
    self.graph_size = graph.number_of_nodes()
    self.dist_matrix = cdist(graph.ndata["pos"], graph.ndata["pos"])

  @property
  def has_graph(self):
    return self.graph is not None

  def _add_node(self, node):
    """Insert `node` where it adds the least tour length; return that cost / norm."""
    srcs = np.array(self.action_list)
    dsts = np.roll(srcs, -1)  # successor of each tour node (the tour wraps around)

    cost_list = (self.dist_matrix[srcs, node]
                 + self.dist_matrix[dsts, node]
                 - self.dist_matrix[srcs, dsts])
    assert cost_list.ndim == 1
    pos = int(np.argmin(cost_list))
    cost = cost_list[pos]
    if cost < 0:
      # With Euclidean distances this can only be floating-point rounding
      # (triangle inequality), so clamp to zero. The old message referenced
      # the undefined name `src_dst_node` and raised NameError here.
      logger.warning(
          f"cost less than zero\n src:{srcs[pos]}, dst:{dsts[pos]}, node:{node}"
      )
      cost = 0

    self.action_list.insert(pos + 1, node)
    self.partial_set.add(node)

    return cost / self.norm

  def step(self, a):
    """Insert node `a` into the tour and return its normalised insertion cost.

    `a` may be a Python int, a NumPy integer or a 0-d tensor (for example from
    `torch.argmax`). It is converted to `int` first, because a tensor hashes
    by identity: `tensor(3) in {3}` is False, which silently broke the
    `partial_set` bookkeeping and let `random_action` re-pick covered nodes.
    """
    a = int(a)
    assert(self.has_graph)
    assert(not self.is_terminal)
    assert(a not in self.partial_set)
    assert(a > 0)
    assert(a < self.graph_size), f"a:{a}, num_nodes:{self.graph_size}"

    self.state_seq.append(self.action_list.copy())
    self.act_seq.append(a)

    r_t = self._add_node(a)

    self.reward_seq.append(r_t)
    self.sum_rewards.append(r_t)
    self.trajectory_len += 1

    assert(
        len(self.action_list)
        ==len(self.partial_set)
        ==self.trajectory_len+1
    )
    assert(
        len(self.state_seq)
        ==len(self.act_seq)
        ==len(self.reward_seq)
        ==len(self.sum_rewards)
        ==self.trajectory_len
    )

    return r_t

  def random_action(self):
    """Return a uniformly random node that is not yet in the tour."""
    assert(self.has_graph)
    avail_list = [i for i in range(self.graph_size) if i not in self.partial_set]
    idx = np.random.randint(len(avail_list))
    return avail_list[idx]

  @property
  def is_terminal(self):
    """True once every node of the graph is in the tour."""
    assert(self.has_graph)
    return (self.trajectory_len+1)==self.graph_size


## Replay Memory and Simulator

In [0]:
class Sample():
  """Column-wise container for one minibatch drawn from ReplayMemory."""
  def __init__(self, batch_size):
    self.size = batch_size
    # Parallel lists, aligned by index: graph, state, action, reward,
    # next state, terminal flag.
    (self.g_list, self.s_list, self.a_list,
     self.r_list, self.s_prime_list, self.t_list) = ([] for _ in range(6))

class ReplayMemory():
  """Fixed-capacity ring buffer of (g, s, a, r, s', terminal) transitions."""
  def __init__(self, mem_size):
    self.capacity = mem_size
    self._memory = [None] * mem_size
    self.size = 0   # number of valid entries, at most `capacity`
    self.idx = 0    # slot the next store() writes to

  def store(self, experience):
    """Add one transition, overwriting the oldest entry once the buffer is full."""
    # experience=(g, s, a, r, s', t)
    assert len(experience) == 6
    g, *rest = experience
    # Copy everything except the graph. The state lists may alias
    # env.action_list, which is cleared in place on reset. The graph is kept
    # by reference to avoid a deep copy of a DGL graph for every transition.
    # NOTE(review): this assumes stored graphs are never mutated afterwards;
    # the visible code only reads them (dgl.batch, cdist).
    self._memory[self.idx] = (g, *deepcopy(tuple(rest)))
    self.idx = (self.idx + 1) % self.capacity
    # Count stores instead of using max(size, idx). idx wraps to 0 exactly
    # when the last slot is filled, so the old formula capped size at
    # capacity - 1 and that slot was never sampled.
    self.size = min(self.size + 1, self.capacity)

  def add_env(self, env):
    """Store every transition of a finished episode using n-step returns.

    env.sum_rewards is converted in place into suffix sums (reward-to-go), so
    the n-step reward from step i is sum_rewards[i] - sum_rewards[i + n_step].
    NOTE(review): this n-step reward is undiscounted, while the bootstrapped
    term uses config.decay; confirm the intent.
    """
    assert env.is_terminal
    num_steps = env.trajectory_len
    for i in range(num_steps - 2, -1, -1):
      # Calculate the cumulative (reward-to-go) sum.
      env.sum_rewards[i] = env.sum_rewards[i + 1] + env.sum_rewards[i]
    for i in range(num_steps):
      if i + config.n_step >= num_steps:
        r = env.sum_rewards[i]
        s_prime = env.action_list
        t = True
      else:
        r = env.sum_rewards[i] - env.sum_rewards[i + config.n_step]
        s_prime = env.state_seq[i + config.n_step]
        t = False
      # experience=(g, s, a, r, s', t)
      self.store((env.graph, env.state_seq[i], env.act_seq[i], r, s_prime, t))

  def clear(self):
    """Forget all stored transitions. Old slots are overwritten lazily."""
    self.idx = 0
    self.size = 0

  def sample(self, batch_size):
    """Draw `batch_size` transitions uniformly at random, with replacement."""
    if self.size == 0:
      raise ValueError("cannot sample from an empty ReplayMemory")
    sample = Sample(batch_size)
    for _ in range(batch_size):
      idx = np.random.randint(0, self.size)
      g, s, a, r, s_prime, t = self._memory[idx]
      sample.g_list.append(g)
      sample.s_list.append(s)
      sample.a_list.append(a)
      sample.r_list.append(r)
      sample.s_prime_list.append(s_prime)
      sample.t_list.append(t)

    return sample


In [0]:
class Simulator():
  """Runs `num_env` TSP environments and feeds finished episodes to a replay memory."""
  def __init__(self, num_env, replay_mem):
    self.num_env = num_env
    self.replay_mem = replay_mem
    self.reset()

  def reset(self):
    """Replace every environment with a fresh, graph-less TSPEnv."""
    self.env_list = [TSPEnv(config.max_n) for _ in range(self.num_env)]
    self.state_list = [None for _ in range(self.num_env)]

  def run(self, num_seqs, eps, model, graph_pool):
    """Collect `num_seqs` completed episodes with an eps-greedy policy.

    Args:
      num_seqs: number of finished episodes to push into the replay memory.
      eps: probability that a step is random. There is one draw per step,
        shared by all environments.
      model: DQNTSPModel whose `predict` supplies the greedy Q-values.
      graph_pool: GraphPool from which a new instance is drawn on every reset.
    """
    assert 0.0 <= eps <= 1.0
    n = 0
    while n < num_seqs:
      for i, env in enumerate(self.env_list):
        if (not env.has_graph) or env.is_terminal:
          if env.has_graph and env.is_terminal:
            n += 1
            self.replay_mem.add_env(env)
          # Reset env graph
          env.reset(graph_pool.sample())
          # Aliases the live tour list, which step() mutates in place.
          self.state_list[i] = env.action_list
      if n >= num_seqs:
        break

      if np.random.uniform(low=0.0, high=1.0) < eps:
        for env in self.env_list:
          env.step(env.random_action())
      else:
        g_batch = dgl.batch([env.graph for env in self.env_list])
        cum_num_nodes = np.cumsum([0] + g_batch.batch_num_nodes)
        # predict() takes one feature-annotated graph. The raw batch has to be
        # converted first; the old call passed the state list as an extra
        # positional argument, which predict() does not accept.
        graph = modify_graph_feature(g_batch, self.state_list)
        q_pred = model.predict(graph)
        for i, env in enumerate(self.env_list):
          a = torch.argmax(q_pred[cum_num_nodes[i]:cum_num_nodes[i + 1]]).item()
          env.step(a)

## Train

In [0]:
# Create the checkpoint directory. makedirs also creates missing parents
# (e.g. "model/"), where os.mkdir would fail, and exist_ok makes re-runs safe.
os.makedirs(config.model_path, exist_ok=True)

In [0]:
def validate(graph_pool, model):
  """Greedily build a tour on every graph in `graph_pool`; return their summed length.

  Rewards from TSPEnv are insertion costs divided by config.max_n (the env
  norm), so they are multiplied back by config.max_n here.
  """
  batch_size = graph_pool.size
  g_batch = dgl.batch(graph_pool.graph_pool)
  cum_num_nodes = np.cumsum([0] + g_batch.batch_num_nodes)
  env_list = [TSPEnv(config.max_n) for _ in range(batch_size)]
  s_list = []
  for i, env in enumerate(env_list):
    env.reset(graph_pool[i])
    # Aliases the env's live tour, so s_list always reflects the current state.
    s_list.append(env.action_list)
  r_list = [0] * batch_size

  # Stop as soon as every tour is complete. The old loop made one extra,
  # unused forward pass just to flip its termination flags.
  while not all(env.is_terminal for env in env_list):
    graph = modify_graph_feature(g_batch, s_list)
    q_list = model.predict(graph)
    for i, env in enumerate(env_list):
      if env.is_terminal:
        continue
      # .item() hands TSPEnv a plain int rather than a 0-d tensor.
      a = torch.argmax(q_list[cum_num_nodes[i]:cum_num_nodes[i + 1]]).item()
      r_list[i] += env.step(a) * config.max_n
  return np.sum(r_list)

In [0]:
# Online and target Q-networks, sized from config (keyword args for clarity).
model = DQNTSPModel(
    node_dim=config.node_dim,
    edge_dim=config.edge_dim,
    embed_dim=config.embed_dim,
    activation=config.activation,
    max_bp_iter=config.max_bp_iter,
    reg_hidden=config.reg_hidden,
)

In [0]:
# Move parameters to config.device; nn.Module.to returns the module, so the cell displays it.
model.to(config.device)

In [0]:
# Shared replay buffer, the simulator that fills it, and a spare environment.
replay_memory = ReplayMemory(config.mem_size)
simulator = Simulator(config.num_env, replay_memory)
test_env = TSPEnv(config.max_n)

In [0]:
# Replay Memory Startup
simulator.reset()
for i in tqdm(range(10)):
  simulator.run(100, 1, model, train_set)

In [0]:
# Number of transitions currently stored (displayed as the cell output).
replay_memory.size

In [0]:
lr = config.learning_rate
# NOTE(review): eps_start == eps_end, so the schedule in the training loop
# always yields eps = 1.0 and every simulated step is random. Confirm this is
# intended.
eps_start = 1.0
eps_end = 1.0
eps_step = 10000.0
current_iter = 0

loss_fn = nn.MSELoss(reduction='mean')
# Only the online network is optimised; the target net is synced by copying weights.
optimizer = optim.Adam(model.eval_net.parameters(), lr=lr)

In [0]:
# Optionally warm-start from an earlier run. Note that this path points at a
# different experiment directory than config.model_path.
model_load_path = "/gdrive/My Drive/NIPS2020/model/s2v_dqn_50_100_cluster_2/nrange_40_50_iter_34400.pth"

if os.path.exists(model_load_path):
  # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
  checkpoint = torch.load(model_load_path, map_location=config.device)
  current_iter = checkpoint['iter'] + 1
  model.load_state_dict(checkpoint["model_state_dict"])
  optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [0]:
# Main DQN training loop:
#   - every iteration: one gradient step on a replayed minibatch;
#   - every 10 iterations: collect more episodes;
#   - every 100 iterations: validate and save a checkpoint;
#   - every 1000 iterations: sync the target network and decay the LR by 5%.
# NOTE(review): rewards are positive insertion costs (tour length), yet the
# target bootstraps with max over Q and actions are chosen by argmax. That
# pushes the policy toward *longer* tours; confirm the intended sign.
for iter_idx in tqdm(range(current_iter, config.max_iter + 1)):
  lr = optimizer.param_groups[0]["lr"]

  # Get samples from replay memory.
  sample = replay_memory.sample(config.batch_size)
  g_batch = dgl.batch(sample.g_list)
  batch_size = g_batch.batch_size
  cum_num_nodes = np.cumsum([0] + g_batch.batch_num_nodes)
  graph_predict = modify_graph_feature(g_batch, sample.s_prime_list)
  graph_forward = modify_graph_feature(g_batch, sample.s_list, sample.a_list)

  # Target: r + decay * max_a' Q_target(s', a') for non-terminal transitions,
  # and r alone for terminal ones (whose s' has every node masked to -inf).
  q_pred = model.predict(graph_predict)
  q_rhs = np.zeros((sample.size, 1))
  for i in range(batch_size):
    if not sample.t_list[i]:
      q_rhs[i] = config.decay \
        * torch.max(q_pred[cum_num_nodes[i]:cum_num_nodes[i + 1]]).item()

  q_target = np.add(
      q_rhs,
      np.asarray(sample.r_list)[..., np.newaxis]
  ).astype(np.float32)
  q_target = torch.from_numpy(q_target).to(config.device)

  q_eval = model(graph_forward)

  loss = loss_fn(q_eval, q_target)
  # Log and save a plain float, not a tensor still attached to the autograd graph.
  loss_value = loss.item()
  writer.add_scalar("loss", loss_value, global_step=iter_idx)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  eps = eps_end + max(0., (eps_start - eps_end) * (eps_step - iter_idx) / eps_step)
  if iter_idx % 10 == 0:
    simulator.run(10, eps, model, train_set)
  if iter_idx % 100 == 0:
    tour_len = validate(val_set, model)
    writer.add_scalar("val_tour_len", tour_len, global_step=iter_idx)
    model_path = os.path.join(
      config.model_path,
      f"nrange_{config.min_n}_{config.max_n}_iter_{iter_idx}.pth"
    )

    tqdm.write(
      f"\riter: {iter_idx}, lr: {lr}, eps: {eps},\
  average tour length: {tour_len/len(val_set)},\
  loss: {loss_value},\
  model saved: {model_path}",
      end=''
    )

    torch.save({
        'iter': iter_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value
      },
      model_path
    )
    gc.collect()
  if iter_idx % 1000 == 0:
    model.update_target_network()
    lr = lr * 0.95
    for p_group in optimizer.param_groups:
      p_group["lr"] = lr


In [0]:
# Greedy tour length summed over the whole validation set.
validate(val_set, model)

## Debug Demo

In [0]:
# NOTE(review): prun_test exists only as a commented-out signature inside the
# training-loop cell, so this cell fails unless that function is defined
# (or imported, as a later cell does from mprun_demo).
%lprun -f prun_test -f modify_graph_feature -f validate prun_test(iter=0)

In [0]:
# Draw a 128-transition minibatch to use as input for the debugging cells below.
samplep = replay_memory.sample(128)
g_batch = dgl.batch(samplep.g_list)
g_sample = samplep.g_list[1]

In [0]:
# Plot one sampled graph, placing each node at its stored coordinates.
nx.draw(g_sample.to_networkx(), pos={i:g_sample.ndata["pos"][i] for i in range(g_sample.number_of_nodes())})

In [0]:
%reload_ext autoreload
%autoreload 1
%aimport mprun_demo

In [0]:
import mprun_demo
from mprun_demo import prun_test, model, validate, simulator

In [0]:
import tracemalloc

# Memory profiling: trace allocations (keeping 25 frames per traceback) while
# calling modify_graph_feature 5000 times, then list the 10 lines that
# allocated the most.
tracemalloc.start(25)

for i in tqdm(range(5000)):
  mgraph = modify_graph_feature(dgl.batch(samplep.g_list),samplep.s_list,samplep.a_list)

snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')

print("[ Top 10 ]")
for stat in top_stats[:10]:
    print(stat)

In [0]:
# Print the full allocation traceback for each of the top 10 entries.
for stat in top_stats[0:10]:
  print(f"------Block------")
  for line in stat.traceback.format():
    print(line)

In [0]:
# NOTE(review): no module-level `graph` is defined in the visible cells (the
# name only appears inside functions), so this raises NameError.
# get_object_traceback also returns None for untracked objects, which would
# make the loop fail.
for line in tracemalloc.get_object_traceback(graph):
  print(line)

In [0]:
# for i in tqdm(range(5000)):
# Line-profile a single modify_graph_feature call on the debug minibatch.
# for i in tqdm(range(5000)):
%lprun -f modify_graph_feature mgraph = modify_graph_feature(dgl.batch(samplep.g_list),samplep.s_list,samplep.a_list)

In [0]:
# NOTE(review): the commented-out prun_test signature in the training cell
# takes `iter`, not `iters`; confirm against the mprun_demo version imported above.
%mprun -f prun_test -f model.forward -f model.predict -f validate -f simulator.run prun_test(iters=10)

In [0]:
# Toy example: a 3x3 grid of points, joined by k-NN edges, with self-loops
# removed and edges made bidirectional. Then compute edge distances and
# all-pairs shortest paths.
pos_x = np.arange(3)
pos_y = np.arange(3)
pos = np.meshgrid(pos_x, pos_y)
pos = np.concatenate((pos[0].flatten()[..., np.newaxis], pos[1].flatten()[..., np.newaxis]), axis=1)
gdemo = dgl.transform.knn_graph(torch.tensor(pos), 8)
# nx.draw(gdemo.to_networkx())
gdemo = dgl.transform.remove_self_loop(gdemo)
# nx.draw(gdemo.to_networkx())
gdemo = dgl.transform.to_bidirected(gdemo)

# NOTE(review): this assigns a NumPy array, while edge_dist_init calls the
# tensor method .square(); confirm DGL converts it to a tensor here.
gdemo.ndata["pos"] = pos.astype(np.float32)
gdemo.apply_edges(edge_dist_init)
gnet = gdemo.to_networkx(edge_attrs=["dist"])
dmatrix = nx.floyd_warshall_numpy(gnet, weight="dist")

nx.draw(gdemo.to_networkx(), pos={i : pos[i] for i in range(9)})
gdemo.number_of_edges()


In [0]:
# Hand-written tour for experiments; the first assignment is overwritten immediately.
states = np.arange(5).tolist()
states = [0,1,2,5,4,3]
states

In [0]:
# Single hand-picked action used for experiments.
actions = 3

In [0]:
# Benchmark two ways of copying a graph: deep-copying its networkx export vs.
# rebuilding a DGLGraph from the edge list (the approach modify_graph_feature uses).
def nx_test(g):
  graph = deepcopy(g.to_networkx())

def rebuild_test(g):
  s, d = g.edges()
  graph = dgl.DGLGraph((s,d))

%timeit nx_test(g_batch)
%timeit rebuild_test(g_batch)

In [0]:
import sys
import time

In [0]:
# (current, peak) bytes traced since tracemalloc.start().
tracemalloc.get_traced_memory()

## Result Visualization

In [0]:
# NOTE(review): validate() returns a single summed tour length, so this
# three-way unpack fails. The plotting cells below need per-graph state and
# action histories that validate() does not currently return.
r, s_list, a_list = validate(val_set, model)

In [0]:
def plot_tsp(p, x_coord, W, W_val, W_target, title="default"):
    """
    Helper function to plot TSP tours.

    Args:
        p: Matplotlib Axes to draw on (e.g. from plt.subplot)
        x_coord: Coordinates of nodes
        W: Edge adjacency matrix
        W_val: Edge values (distance) matrix
        W_target: One-hot matrix with 1s on groundtruth/predicted edges
        title: Title of figure/subplot

    Returns:
        p: Updated Axes

    """

    def _edges_to_node_pairs(W):
        """Return the (row, col) index pairs where the 0/1 matrix `W` equals 1."""
        n = len(W)
        return [(r, c) for r in range(n) for c in range(n) if W[r][c] == 1]

    # from_numpy_array replaces from_numpy_matrix, which networkx 3.0 removed.
    G = nx.from_numpy_array(W_val)
    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))
    adj_pairs = _edges_to_node_pairs(W)
    target_pairs = _edges_to_node_pairs(W_target)
    colors = ['g'] + ['b'] * (len(x_coord) - 1)  # Green for 0th node, blue for others
    # Draw onto `p` explicitly rather than onto whatever axes pyplot treats as current.
    nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=50, ax=p)
    nx.draw_networkx_edges(G, pos, edgelist=adj_pairs, alpha=0.3, width=0.5, ax=p)
    nx.draw_networkx_edges(G, pos, edgelist=target_pairs, alpha=1, width=1, edge_color='r', ax=p)
    p.set_title(title)
    return p

In [0]:
# Draw the partial tour of one validation graph after each construction step.
graph_index = 1

for i in range(len(s_list[graph_index])-1):
  a = s_list[graph_index][i]

  # Directed tour edges, including the closing edge back to the first node.
  tour = list(zip(a[:], a[1:]+a[0:1]))
  # NOTE(review): hard-coded for 50-node graphs; instances with fewer nodes
  # (config.min_n is 40) would need a matching size here.
  w_target = np.zeros((50,50))
  for p in tour:
    w_target[p[0], p[1]]=1

  fig = plt.subplot(111)
  plot_tsp(
      fig,
      val_set[graph_index].ndata["pos"],
      val_set[graph_index].adjacency_matrix_scipy(return_edge_ids=False).todense().tolist(),
      cdist(val_set[graph_index].ndata["pos"],val_set[graph_index].ndata["pos"]),
      w_target,
      title=f"step {i}"
  )
  plt.show()

In [0]:
from termcolor import colored

# Print each recorded tour state, 20 nodes per line, highlighting in red the
# node(s) added by the preceding action.
for i, state in enumerate(s_list[graph_index]):
  if i==0:
    # NOTE: the inner loops reuse `i`; the outer loop rebinds it on the next iteration.
    for i,n in enumerate(state):
      print(n, end=" ")
      if (i+1)%20==0:
        print("")
  else:
    action = a_list[graph_index][i-1].tolist()
    for i,n in enumerate(state):
      if not n in action:
        print(n, end=" ")
      else:
        print(colored(n, color="red"), end=" ")
      if (i+1)%20==0:
        print("")
  print("\n")
  