# Multilevel GNN

Code for the Final Year Project (FYP) _"Efficient Graph Neural Networks for Travelling Salesman Problem using Multilevel Clustering"_

## 1. Setup

In [None]:
# Whether the next cell should clone and install pyconcorde when `concorde.tsp` cannot be imported
INSTALL_PYCONCORDE = True
# Whether to try to run on a CUDA GPU (the device-selection cell stays on CPU if CUDA is unavailable)
USE_GPU = True

In [None]:
if INSTALL_PYCONCORDE:
  try: 
    from concorde.tsp import TSPSolver
  except:
    ! git clone https://github.com/jvkersch/pyconcorde.git
    ! cd pyconcorde && pip install -e .

In [None]:
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.distributions.categorical import Categorical

import datetime
import math
import time
import os
import sys

# For visualization of graphs
%matplotlib inline
from IPython.display import set_matplotlib_formats, clear_output
set_matplotlib_formats('png2x','pdf')
import matplotlib.pyplot as plt

# For visualization of concorde
import numpy as np
import pandas as pd
import networkx as nx
from scipy.spatial.distance import pdist, squareform

try: 
    sys.path.append(os.path.abspath('./pyconcorde'))
    from concorde.tsp import TSPSolver
except:
    print("PyConcorde not found")

# Reduce printed warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Device selection: start on CPU and switch to CUDA only when requested and available
device = torch.device("cpu")
gpu_id = -1 # Stays -1 when no GPU is requested

if USE_GPU:
  gpu_id = '0' # if single GPU is present
  #gpu_id = '0,1,2' # if multiple GPU are present
  os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
  if torch.cuda.is_available():
    device = torch.device("cuda")
    print('GPU name: {:s}, gpu_id: {:s}'.format(torch.cuda.get_device_name(0),gpu_id))

# Always report the selected device, including when running on CPU
print(device)

## 2. Helper functions

In [None]:
# To perform heavy-edge-matching clustering on given nodes
def cluster_hem(nodes, device="cpu"):
  """
  Greedily pairs up nodes: the first non-visited node is matched with its nearest non-visited neighbour
  Inputs:
          nodes  - [batch_size, node_count, dim] Batch of node coordinates (expected to be finite)
          device - Device on which the returned tensor is allocated
  Output:
          clusters - [batch_size, node_count] (int64) Cluster ID of every node. IDs are 0,1,2,... in pairing
                     order; with an odd node_count the leftover node gets a cluster of its own
  """
  bsz, node_count, _ = nodes.size()
  # int64 IDs: uint8 would silently wrap around once more than 256 clusters are needed
  clusters = torch.zeros(bsz, node_count, device=device, dtype=torch.long)
  nodes = torch.clone(nodes) # Visited nodes are overwritten with inf below, so work on a copy
  batch_idx = torch.arange(bsz, device=nodes.device)

  curr_idx = 0
  for _ in range(node_count//2):
    # Find the first non-visited node
    node_1_idx = torch.argmin(torch.isinf(nodes).long(), dim=1)[:,0]
    node_1 = nodes[batch_idx, node_1_idx].unsqueeze(1)

    # Find the closest non-visited node to current (visited nodes end up at distance inf)
    dist = nodes.add( - node_1).pow(2).sum(dim=2).pow(.5)
    # Exclude the node itself explicitly: relying on the position inside an unsorted
    # top-k result could pick the node as its own partner
    dist[batch_idx, node_1_idx] = float('inf')
    node_2_idx = torch.argmin(dist, dim=1)

    # Mark both as current cluster
    clusters[batch_idx, node_1_idx] = curr_idx
    clusters[batch_idx, node_2_idx] = curr_idx
    curr_idx += 1

    # Mark both nodes as visited
    nodes[batch_idx, node_1_idx] = float('inf')
    nodes[batch_idx, node_2_idx] = float('inf')

  # If odd number of nodes, the single remaining node forms the last cluster
  if node_count%2 != 0:
    last_node_idx = torch.argmin(torch.isinf(nodes).long(), dim=1)[:,0]
    clusters[batch_idx, last_node_idx] = curr_idx

  return clusters

In [None]:
# To calculate coordinates of clusters (average of members)
def get_parent_coordinates_multilevel(graph_dataset_original, order_dataset, device="cpu"):
  """
  Fills every level above 0 with the pairwise averages of the level below it, re-ordered by order_dataset
  Inputs:
          graph_dataset_original - [levels, batch_size, node_count, node_size] Level 0 holds the nodes;
                                   the input tensor itself is never written to
          order_dataset          - [levels, batch_size, node_count] Per level, the index used to pick each
                                   averaged pair (only the first node_count/2**level entries are read)
          device                 - Unused, kept for interface compatibility
  Output:
          graph_dataset - Same shape as the input, with levels 1.. filled in
  """
  # Averages every two consecutive rows, i.e. the two members of a cluster
  pair_mean = nn.AvgPool2d((2, 1), stride=(2, 1))
  feature_dim = graph_dataset_original.size(3)

  current = graph_dataset_original
  for level in range(1, current.size(0)):
    # Number of real (non-placeholder) nodes on this level
    valid_count = int(current.size(2)/(2**level))

    # Parents of the previous level, before re-ordering
    parents = pair_mean(current[level-1])

    # Pick the parents in the order stored for this level
    gather_idx = order_dataset[level,:,0:valid_count].unsqueeze(2).repeat(1,1,feature_dim).long()
    reordered = torch.gather(parents, 1, gather_idx)

    # Write into a fresh copy so the tensor from the previous iteration is left untouched
    updated = current.clone()
    updated[level,:,0:reordered.size(1),:] = reordered
    current = updated

  return current

In [None]:
def compute_tour_length(graphs, tours):
  """
  Computes the lengths of a batch of closed tours (the last node connects back to the first)
  Inputs: 
          graphs - [batch_size, node_count, 2] Batch of TSP problem instances
          tours  - [batch_size, node_count] Batch of sequences of node indices of TSP tours
                                            tours[batch, time] contains the index of the node visited in given batch at given time
  Output: 
          tour_lengths - [batch_size] Length of each TSP tour in the batch (computed without gradient tracking);
                                      a single-node tour has length 0
  """
  with torch.no_grad():
    # Coordinates of the cities in visiting order: [batch_size, node_count, 2]
    gather_idx = tours.long().unsqueeze(-1).expand(-1, -1, graphs.size(-1))
    ordered_cities = torch.gather(graphs, 1, gather_idx)

    # City visited right after each one; rolling wraps the last city back to the first, closing the tour
    next_cities = torch.roll(ordered_cities, shifts=-1, dims=1)

    # Euclidean length of every leg, summed per tour
    leg_lengths = torch.sum((next_cities - ordered_cities)**2, dim=-1)**0.5
    tour_lengths = leg_lengths.sum(dim=1)

  return tour_lengths

## 3. Defining model hyperparameters

In [None]:
class DotDict(dict):
  """Dictionary whose entries can also be read and written as attributes (d.key is d['key'])"""
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # The dict doubles as its own attribute namespace
    self.__dict__ = self

# Container for every hyperparameter of the run
args = DotDict()

# Size (in nodes) of each input graph
args.node_count = 16

# Size (in nodes) of the base level graph
args.node_count_base = 2

# Batch size of input graphs to model: 512 for TSP16 or TSP32, 128 for TSP64 or TSP128
args.batch_size = 512 if args.node_count <= 50 else 128

# Dimensionality of the node embeddings
args.embedding_size = 128

# Neurons in the hidden layer of the encoder
args.ff_neurons = 512

# Coordinates used to represent a node
args.input_node_size = 2

# Depth of the base level encoder and decoder
args.layers_encoder_base = 6
args.layers_decoder_base = 3

# Depth of the encoders and decoders of all other levels
args.layers_encoder = 1
args.layers_decoder = 2

# Heads used in Multi-Headed Attention
args.mha_heads = 8

# Epochs to train for
args.epochs = 100

# Batches trained per epoch
args.batch_per_epoch_train = 250

# Batches evaluated per epoch
args.batch_per_epoch_eval = 20

# ID of the GPU selected
args.gpu_id = gpu_id

# Learning rate of the model
# NOTE(review): 1e-9 is unusually small for the Adam optimizer used in training - confirm this is intended
args.lr = 1e-9

# Tolerance needed to update baseline
args.tol = 1e-3

# Should the model use batchnorm?
args.batchnorm = True

# Input graph size is a power of 2
assert args.node_count != 0 and (args.node_count & (args.node_count-1)) == 0
# Base level graph size is a power of 2
assert args.node_count_base != 0 and (args.node_count_base & (args.node_count_base-1)) == 0
# Base level graph can't be larger than input graph
assert args.node_count_base <= args.node_count

In [None]:
# Hyperparameter overrides for debugging runs
# NOTE(review): batch_per_epoch_train, epochs and the layer counts equal the regular values; only lr differs
DEBUG = False

if DEBUG:
  args.update(
    batch_per_epoch_train=250,
    epochs=100,
    layers_encoder=1,
    layers_decoder=2,
    lr=1e-8,
  )

In [None]:
# Show the final hyperparameters (after any DEBUG overrides)
print(args)

## 4. Generating the test dataset

In [None]:
# Whether to generate and save a set of 1000 random instances
# NOTE(review): while this is True the saved file is overwritten on every run of this cell
save_1000tsp = True

if save_1000tsp:
  # Generating the nodes
  batch_size = 1000
  generated_nodes = torch.rand(batch_size, args.node_count, args.input_node_size, device='cpu') 

  # Creating the directory (exist_ok avoids the check-then-create race)
  data_dir = os.path.join("data")
  os.makedirs(data_dir, exist_ok=True)

  # Saving the data
  torch.save({ 'x': generated_nodes, }, os.path.join(data_dir, f"1000tsp{args.node_count}.pkl"))

In [None]:
# Load the saved instances if the file exists, otherwise fall back to generating new ones
try:
  checkpoint = torch.load(f"data/1000tsp{args.node_count}.pkl")
except FileNotFoundError:
  # Without this handler a missing file crashes the cell and the fallback below can never run
  checkpoint = None

if checkpoint is not None:
  print("Loading...")
  tsp_data = checkpoint['x'].to(device)
  tsp_data_size = tsp_data.size(1)
  print('TSP Nodes:', tsp_data_size)
  print("Loaded!")
else:
  print("Checkpoint not found, generating...")
  tsp_data = torch.rand(1000, args.node_count, args.input_node_size, device=device)
  tsp_data_size = tsp_data.size(1)
  print('TSP Nodes:', tsp_data_size)
  print("Generated!")

## 5. Creating the model

### i. Multilevel clustering algorithm

In [None]:
class Cluster(nn.Module):
  """
  Builds the multilevel hierarchy of a batch of graphs by repeatedly pairing nodes (cluster_hem)
  and replacing every pair by the average of its members
  Inputs:
          nodes  - [batch_size, node_count, node_size] Batch of node coordinates
          device - Device on which the output tensors are allocated
  Outputs:
          graph_dataset - [total_levels, batch_size, node_count, node_size] Nodes of every level, arranged so that
                          the two members of a cluster are adjacent; unused trailing slots hold inf
          order_dataset - [total_levels, batch_size, node_count] (int64) Per level, the index of each arranged node
                          in that level's un-arranged input; unused trailing slots hold node_count
  """
  def __init__(self, input_node_size, embedding_size):
    super(Cluster, self).__init__()

    # For averaging coordinates of neighboring nodes
    self.avg_pool = nn.AvgPool2d((2, 1), stride=(2, 1))

  def forward(self, nodes, device="cpu"):
    # Defining cluster parameters
    batch_size, node_count, node_size = nodes.size()
    # Levels are derived from the graph actually passed in, so they always agree with the per-level
    # node counts below; log2 is exact for powers of two, unlike the float division in log(x, 2)
    total_levels = int(math.ceil(math.log2(node_count // args.node_count_base)))+1

    # Placeholders for final clusters and nodes
    graph_dataset = torch.full((total_levels, batch_size, node_count, node_size), float("inf"), device=device)
    # int64 so the node_count filler and every node index fit (uint8 only holds values up to 255)
    order_dataset = torch.full((total_levels, batch_size, node_count), node_count, device=device, dtype=torch.long)
    # Filled with node_count so sort operations move empty spaces to the end

    # Iterate from the input graph down to the base level
    for i in range(total_levels):
      # Get number of actual nodes in this level
      level_node_count = node_count // (2**i)

      # Get nodes for current row
      if i == 0:
        level_nodes = nodes # Generate nodes for first row
      else: 
        level_nodes = self.avg_pool(graph_dataset[i-1,:,:,:])[:,0:level_node_count,:] # Nodes are average of clusters from previous level

      # Cluster the nodes in this level
      clusters = cluster_hem(level_nodes, device=device)

      # Arrange nodes in order (i.e. from [0,3,2,5,3...] to [0,0,1,1,2,2...])
      sort_result = torch.sort(clusters, dim=1)
      ordered_idx = sort_result.indices
      sorted_nodes = torch.gather(level_nodes, 1, ordered_idx.unsqueeze(2).repeat(1,1,node_size))

      # Insert into final datasets
      graph_dataset[i,:,0:level_node_count,:] = sorted_nodes
      order_dataset[i,:,0:level_node_count] = ordered_idx

    return graph_dataset, order_dataset

### ii. Encoder

In [None]:
class Encoder(nn.Module):
  """
  Stack of self-attention layers (multi-head attention + feedforward, each followed by a residual
  connection and a normalization) for encoding the set of points
  Inputs: 
          node_embeddings - [batch_size, node_count, embedding_size] Node embeddings after passing through FF NN
  Output:
          node_embeddings - [batch_size, node_count, embedding_size] Encoded node embeddings
  """
  def __init__(self, layers, embedding_size, mha_heads, ff_neurons, batchnorm):
    super(Encoder, self).__init__()

    assert embedding_size == mha_heads * (embedding_size//mha_heads) # Check if embedding size is divisble by number of heads

    self.MHA_layers = nn.ModuleList([nn.MultiheadAttention(embedding_size, mha_heads) for _ in range(layers)])
    self.linear1_layers = nn.ModuleList([nn.Linear(embedding_size, ff_neurons) for _ in range(layers)])
    self.linear2_layers = nn.ModuleList([nn.Linear(ff_neurons, embedding_size) for _ in range(layers)])
    if batchnorm:
      self.norm1_layers = nn.ModuleList([nn.BatchNorm1d(embedding_size) for _ in range(layers)])
      self.norm2_layers = nn.ModuleList([nn.BatchNorm1d(embedding_size) for _ in range(layers)])
    else:
      self.norm1_layers = nn.ModuleList([nn.LayerNorm(embedding_size) for _ in range(layers)])
      self.norm2_layers = nn.ModuleList([nn.LayerNorm(embedding_size) for _ in range(layers)])

    self.layers = layers
    self.mha_heads = mha_heads
    self.batchnorm = batchnorm # Remembered because the two normalizations need different layouts

  def _normalize(self, norm, node_embeddings):
    """Applies `norm` to a [node_count, batch_size, embedding_size] tensor and returns the same layout"""
    if self.batchnorm:
      # BatchNorm1d expects the channel (embedding) dimension second: [batch_size, embedding_size, node_count]
      node_embeddings = node_embeddings.permute(1,2,0).contiguous()
      node_embeddings = norm(node_embeddings)
      return node_embeddings.permute(2,0,1).contiguous()
    # LayerNorm normalizes the last dimension, which is already the embedding;
    # permuting first would hand it node_count as last dimension and raise a shape error
    return norm(node_embeddings)

  def forward(self, node_embeddings):
    # nn.MultiheadAttention (batch_first=False) expects [node_count, batch_size, embedding_size]
    node_embeddings = node_embeddings.transpose(0,1)

    for i in range(self.layers):
      # Self-attention; the attention weights are not used, so they are not computed
      residual_connection = node_embeddings
      node_embeddings, _ = self.MHA_layers[i](node_embeddings, node_embeddings, node_embeddings, need_weights=False)
      # Adding the residual connection
      node_embeddings = residual_connection + node_embeddings
      node_embeddings = self._normalize(self.norm1_layers[i], node_embeddings)
      # Feedforward
      residual_connection = node_embeddings
      node_embeddings = torch.relu(self.linear1_layers[i](node_embeddings))
      node_embeddings = self.linear2_layers[i](node_embeddings)
      node_embeddings = residual_connection + node_embeddings
      node_embeddings = self._normalize(self.norm2_layers[i], node_embeddings)

    # Back to [batch_size, node_count, embedding_size]
    node_embeddings = node_embeddings.transpose(0,1)

    return node_embeddings

### iii. Decoder

#### a. Base level decoder

In [None]:
class BaseDecoder(nn.Module):
  """
  Decoder of the base level: attends from the query over all nodes and returns a distribution over them
  Inputs: 
          Q    - [batch_size, 1, embedding_size] Query vector (already projected to embedding_size)
          K    - [batch_size, node_count, embedding_size*layers] Key vector (one embedding_size slice per layer)
          V    - [batch_size, node_count, embedding_size*layers] Value vector (one embedding_size slice per layer)
          mask - [batch_size, node_count] Boolean, True for nodes that must not be selected
  Output:
          next_node_probabilities - [batch_size, node_count] Probabilities of which node is best for selection
  """
  def __init__(self, embedding_size, mha_heads, layers):
    super(BaseDecoder, self).__init__()
    self.embedding_size = embedding_size
    self.mha_heads = mha_heads
    self.layers = layers
    # Output projection of every layer except the last one (which only produces probabilities)
    self.WO = nn.ModuleList([nn.Linear(embedding_size, embedding_size) for _ in range(layers-1)])

  def MHA(self, Q, K, V, mha_heads, mask=None, clip_value=None):
    """
    Scaled dot-product attention of a single query over the nodes, optionally multi-headed
    Returns (attention_output [batch_size, 1, embedding_size],
             attention_weights [batch_size, 1, node_count], averaged over the heads)
    """
    batch_size, node_count, embedding_size = K.size() # Get the dimensions of the key
    
    # Reshape Q, K, and V for MHA (heads are folded into the batch dimension)
    if mha_heads > 1:
      Q = Q.transpose(1,2).contiguous()
      Q = Q.view(batch_size * mha_heads, embedding_size//mha_heads, 1)
      Q = Q.transpose(1,2).contiguous()

      K = K.transpose(1,2).contiguous()
      K = K.view(batch_size * mha_heads, embedding_size//mha_heads, node_count)
      K = K.transpose(1,2).contiguous()

      V = V.transpose(1,2).contiguous()
      V = V.view(batch_size * mha_heads, embedding_size//mha_heads, node_count)
      V = V.transpose(1,2).contiguous()

    # Multiply Q and K to get attention weights
    attention_weights = torch.bmm(Q, K.transpose(1,2)) / Q.size(-1)**0.5
      
    # Clip attention weights between [-C, C]
    if clip_value is not None:
      attention_weights = clip_value * torch.tanh(attention_weights)

    # Mask already visited nodes
    if mask is not None:
      if mha_heads > 1:
        mask = torch.repeat_interleave(mask, repeats=mha_heads, dim=0)
      attention_weights = attention_weights.masked_fill(mask.unsqueeze(1), float('-inf'))

    # Get softmax of attention weights
    attention_weights = torch.softmax(attention_weights, dim=-1)

    # Get decoder output
    attention_output = torch.bmm(attention_weights, V)

    # Reshape and get mean if MHA is used
    if mha_heads > 1:
      attention_output = attention_output.transpose(1, 2).contiguous()
      attention_output = attention_output.view(batch_size, embedding_size, 1)
      attention_output = attention_output.transpose(1, 2).contiguous()
      attention_weights = attention_weights.view(batch_size, mha_heads, 1, node_count)
      attention_weights = attention_weights.mean(dim=1) # Take mean across heads

    return attention_output, attention_weights

  def forward(self, Q, K, V, mask=None):

    for i in range(self.layers):
      Ki = K[:,:,i * self.embedding_size : (i+1) * self.embedding_size].contiguous()
      Vi = V[:,:,i * self.embedding_size : (i+1) * self.embedding_size].contiguous()
      if i < self.layers-1: # Use MHA
        residual_Q = Q
        attention_output, _ = self.MHA(Q, Ki, Vi, self.mha_heads, mask)
        Q = self.WO[i](attention_output) # MHA output
        Q = residual_Q + torch.relu(Q)
      else: # Use single head for last layer
        _, attention_weights = self.MHA(Q, Ki, Vi, 1, mask, 10)
    # Drop only the query dimension; a bare squeeze() would also drop a batch dimension of size 1
    next_node_probabilities = attention_weights.squeeze(1)

    return next_node_probabilities

#### b. Multilevel decoder

In [None]:
class Decoder(nn.Module):
  """
  Decoder of the non-base levels: attends from the query over the candidate nodes only and returns
  a distribution over all nodes that is zero everywhere except at the candidates
  Inputs: 
          Q    - [batch_size, 1, embedding_size] Query vector (already projected to embedding_size)
          K    - [batch_size, node_count, embedding_size*layers] Key vector (one embedding_size slice per layer)
          V    - [batch_size, node_count, embedding_size*layers] Value vector (one embedding_size slice per layer)
          to_visit - [batch_size, 2] (int64) Neighbouring nodes to be visited next
  Output:
          next_node_probabilities - [batch_size, node_count] Probabilities of which node is best for selection
                                    (non-zero only at the two nodes in to_visit)
  """
  def __init__(self, embedding_size, mha_heads, layers):
    super(Decoder, self).__init__()
    self.embedding_size = embedding_size
    self.mha_heads = mha_heads
    self.layers = layers
    # Output projection of every layer except the last one (which only produces probabilities)
    self.WO = nn.ModuleList([nn.Linear(embedding_size, embedding_size) for _ in range(layers-1)])

  def MHA(self, Q, K, V, mha_heads, clip_value=None):
    """
    Scaled dot-product attention of a single query over the nodes, optionally multi-headed
    Returns (attention_output [batch_size, 1, embedding_size],
             attention_weights [batch_size, 1, node_count], averaged over the heads)
    """
    batch_size, node_count, embedding_size = K.size() # Get the dimensions of the key
    
    # Reshape Q, K, and V for MHA (heads are folded into the batch dimension)
    if mha_heads > 1:
      Q = Q.transpose(1,2).contiguous()
      Q = Q.view(batch_size * mha_heads, embedding_size//mha_heads, 1)
      Q = Q.transpose(1,2).contiguous()

      K = K.transpose(1,2).contiguous()
      K = K.view(batch_size * mha_heads, embedding_size//mha_heads, node_count)
      K = K.transpose(1,2).contiguous()

      V = V.transpose(1,2).contiguous()
      V = V.view(batch_size * mha_heads, embedding_size//mha_heads, node_count)
      V = V.transpose(1,2).contiguous()

    # Multiply Q and K to get attention weights
    attention_weights = torch.bmm(Q, K.transpose(1,2)) / Q.size(-1)**0.5
      
    # Clip attention weights between [-C, C]
    if clip_value is not None:
      attention_weights = clip_value * torch.tanh(attention_weights)

    # Get softmax of attention weights
    attention_weights = torch.softmax(attention_weights, dim=-1)

    # Get decoder output
    attention_output = torch.bmm(attention_weights, V)

    # Reshape and get mean if MHA is used
    if mha_heads > 1:
      attention_output = attention_output.transpose(1, 2).contiguous()
      attention_output = attention_output.view(batch_size, embedding_size, 1)
      attention_output = attention_output.transpose(1, 2).contiguous()
      attention_weights = attention_weights.view(batch_size, mha_heads, 1, node_count)
      attention_weights = attention_weights.mean(dim=1) # Take mean across heads

    return attention_output, attention_weights

  def forward(self, Q, K, V, to_visit=None):
    if to_visit is None:
      raise ValueError("to_visit is required: it selects the candidate nodes to attend over")
    bsz, nc, ns = K.size()

    # Keep only the keys and values of the candidate nodes
    K = torch.gather(K, 1, to_visit.unsqueeze(2).repeat(1,1,K.size(2)))
    V = torch.gather(V, 1, to_visit.unsqueeze(2).repeat(1,1,V.size(2)))

    for i in range(self.layers):
      Ki = K[:,:,i * self.embedding_size : (i+1) * self.embedding_size].contiguous()
      Vi = V[:,:,i * self.embedding_size : (i+1) * self.embedding_size].contiguous()
      if i < self.layers-1: # Use MHA
        residual_Q = Q
        attention_output, _ = self.MHA(Q, Ki, Vi, self.mha_heads)
        Q = self.WO[i](attention_output) # MHA output
        Q = residual_Q + torch.relu(Q)
      else: # Use single head for last layer
        _, attention_weights = self.MHA(Q, Ki, Vi, 1, 10)
    # Drop only the query dimension; a bare squeeze() would also drop a batch dimension of size 1
    next_node_probabilities = attention_weights.squeeze(1)

    # Scatter the candidate probabilities back to their positions among all nodes.
    # The output lives on the device of the inputs rather than on a module-level global
    batch_idx = torch.arange(bsz, device=K.device)
    output = torch.zeros((bsz, nc), device=K.device, dtype=next_node_probabilities.dtype)
    output[batch_idx, to_visit[:,0]] = next_node_probabilities[batch_idx,0]
    output[batch_idx, to_visit[:,1]] = next_node_probabilities[batch_idx,1]

    return output

### iv. Multilevel Graph Neural Network

In [None]:
class MGNN(nn.Module):
  """
  Multilevel model: clusters the input graphs level by level, builds a tour over the coarsest (base)
  level with the base decoder, then refines it one level at a time by ordering the two children of
  every cluster in the order given by the parent tour
  Inputs (forward):
          nodes         - [batch_size, node_count, input_node_size] Batch of TSP problem instances
          deterministic - If True take the most probable node at every step, otherwise sample
  Outputs (forward):
          sorted_graph               - [batch_size, node_count, input_node_size] Level-0 nodes as arranged by the clustering
          sorted_graph_base          - [batch_size, base_node_count, input_node_size] Nodes of the base level
          tours                      - [batch_size, node_count] (int64) Tour as indices into sorted_graph
          tours_base                 - [batch_size, base_node_count] (int64) Tour as indices into sorted_graph_base
          log_prob_actions_sum_final - [batch_size] Sum of the log-probabilities recorded at level 0
          log_prob_actions_sum_base  - [batch_size] Sum of the log-probabilities recorded at the base level
  """
  def __init__(self, input_node_size, embedding_size, ff_neurons, 
               layers_encoder, layers_encoder_base,
               layers_decoder, layers_decoder_base,
               mha_heads, cluster_levels, batchnorm=True):
    super(MGNN, self).__init__()

    # Embed nodes into higher dimension
    self.input_embeddings = nn.Linear(input_node_size, embedding_size)

    # Cluster nodes for multiple levels
    self.clustering = Cluster(input_node_size, embedding_size)

    # Encoder layers
    self.encoder_base = Encoder(layers_encoder_base, embedding_size, mha_heads, ff_neurons, batchnorm)
    self.encoder = Encoder(layers_encoder, embedding_size, mha_heads, ff_neurons, batchnorm)

    # Query dimensionality projector
    self.query_projecter = nn.Linear(3*embedding_size, embedding_size)
    self.embedding_size = embedding_size

    # Embeddings of the first node of the tour and the current node
    self.place_holder_start = nn.Parameter(torch.randn(embedding_size))
    self.place_holder_now = nn.Parameter(torch.randn(embedding_size))

    # Decoder layers
    self.decoder_base = BaseDecoder(embedding_size, mha_heads, layers_decoder_base)
    self.WK_decoder_base = nn.Linear(embedding_size, layers_decoder_base * embedding_size)
    self.WV_decoder_base = nn.Linear(embedding_size, layers_decoder_base * embedding_size)

    # One decoder per entry of cluster_levels; forward indexes them by level number
    self.decoders = nn.ModuleList([Decoder(embedding_size, mha_heads, layers_decoder) for _ in range(cluster_levels)])
    self.WK_decoder = nn.Linear(embedding_size, layers_decoder * embedding_size)
    self.WV_decoder = nn.Linear(embedding_size, layers_decoder * embedding_size)

  def forward(self, nodes, deterministic=False):
    # Allocate every intermediate tensor on the device of the input instead of a module-level global
    device = nodes.device

    # Sort the nodes along the X-axis
    _, sort_indices = torch.sort(nodes[:,:,0], dim=1)
    nodes = torch.gather(nodes, 1, sort_indices.unsqueeze(2).repeat(1,1,nodes.size(2)))

    # Cluster the nodes
    graph_dataset, order_dataset = self.clustering(nodes, device)

    # Extract parameters from graph
    cluster_levels, batch_size, node_count, _ = graph_dataset.size()
    zero_to_batch_size = torch.arange(batch_size, device=device) # [0,1,...,batch_size-1]
    base_level = cluster_levels-1
    base_node_count = int(graph_dataset.size(2)/(2**base_level)) # Number of nodes in the base level

    # Convert nodes to higher dimensionality
    h = self.input_embeddings(graph_dataset[0,:,:,:])

    # Encode the nodes and get level-by-level cluster coordinates (for 0...M-1)
    # One slot per clustering level, so the levels always line up with order_dataset
    embedded_graph_dataset = torch.full((cluster_levels, batch_size, node_count, self.embedding_size), float("inf"), device=device)
    embedded_graph_dataset[0,:,:,:] = self.encoder(h)
    h_levels_encoded = get_parent_coordinates_multilevel(embedded_graph_dataset, order_dataset, device)
    h_graph_encoded = embedded_graph_dataset[0,:,:,:].mean(dim=1) # Get the graph embedding as well
    
    # Encode the nodes and get level-by-level cluster coordinates (for M)
    embedded_graph_dataset_base = torch.full((cluster_levels, batch_size, node_count, self.embedding_size), float("inf"), device=device)
    embedded_graph_dataset_base[0,:,:,:] = self.encoder_base(h)
    h_levels_encoded_base = get_parent_coordinates_multilevel(embedded_graph_dataset_base, order_dataset, device)
    h_graph_encoded_base = embedded_graph_dataset_base[0,:,:,:].mean(dim=1) # Get the graph embedding as well

    # For storing the summed log-probabilities of the choices made at level 0 (the main graph)
    log_prob_actions_sum_final = torch.zeros(batch_size, device=device)
    # For storing the summed log-probabilities of the choices made at the base level
    log_prob_actions_sum_base = torch.zeros(batch_size, device=device)

    # Create list for mask of visited cities
    mask_visited_nodes = torch.zeros(batch_size, base_node_count, device=device).bool()

    # Iterate through each level of clustering, starting from coarsest graph
    for i in range(base_level, -1, -1):
      if i == base_level:
        # Extract parameters from current level
        node_count = int(graph_dataset.size(2)/(2**i))
        iterations = node_count
        nodes_added = 0

        # Get current level
        h_encoded = h_levels_encoded_base[i,:,0:node_count,:]

        # Create list for storing probs of the choices made at each time step
        log_prob_actions_sum = torch.zeros(batch_size, node_count, device=device)

        # Create list for storing tours for the batch
        tours = torch.zeros(batch_size, node_count, device=device)

        # Create initial key and value for the decoder
        K_encoded = self.WK_decoder_base(h_encoded)
        V_encoded = self.WV_decoder_base(h_encoded)

        # Store the start and current nodes of the tour
        h_start = self.place_holder_start.view(1, self.embedding_size).expand_as(h_graph_encoded_base)
        h_now = self.place_holder_now.view(1, self.embedding_size).expand_as(h_graph_encoded_base)

        # Construct the tour recursively
        for t in range(iterations):
          # Compute probability over the next node in the tour
          query = torch.cat((h_graph_encoded_base, h_start, h_now), dim=-1).unsqueeze(1)
          query = self.query_projecter(query)
          prob_next_node = self.decoder_base(query, K_encoded, V_encoded, mask_visited_nodes)

          if deterministic:
            index = torch.argmax(prob_next_node, dim=1)
          else:
            index = Categorical(prob_next_node).sample()

          # Compute log of probabilities of the action items
          prob_of_choices = prob_next_node[zero_to_batch_size, index]
          log_prob_actions_sum[:, nodes_added] = torch.log(prob_of_choices)
          
          # Update tour
          tours[:, nodes_added] = index
          nodes_added += 1

          # Update current node
          h_now = h_encoded[zero_to_batch_size, index, :]
          if t == 0: # Embedding of first node of tour
            h_start = h_now

          # Update mask of visited nodes (on a copy, so the mask used by earlier steps stays intact)
          mask_visited_nodes = mask_visited_nodes.clone()
          mask_visited_nodes[zero_to_batch_size, index] = True
      else:
        # Extract parameters from current level
        node_count = int(graph_dataset.size(2)/(2**i))
        iterations = int(node_count/2)
        nodes_added = 0

        # Get current level
        h_encoded = h_levels_encoded[i,:,0:node_count,:]

        # Create list for storing probs of the choices made at each time step
        log_prob_actions_sum = torch.zeros(batch_size, node_count, device=device)

        # To store nodes to be visited next, starting from 0,1
        next_neighbors = torch.arange(2, device=device).unsqueeze(0).repeat(batch_size,1).long()

        # Create list for storing tours for the batch
        tours = torch.zeros(batch_size, node_count, device=device)

        # Create initial key and value for the decoder
        K_encoded = self.WK_decoder(h_encoded)
        V_encoded = self.WV_decoder(h_encoded)

        # Store the start and current nodes of the tour
        h_start = self.place_holder_start.view(1, self.embedding_size).expand_as(h_graph_encoded)
        h_now = self.place_holder_now.view(1, self.embedding_size).expand_as(h_graph_encoded)

        # Construct the tour recursively
        for t in range(iterations):
          # Compute probability over the next node in the tour
          query = torch.cat((h_graph_encoded, h_start, h_now), dim=-1).unsqueeze(1)
          query = self.query_projecter(query)
          prob_next_node = self.decoders[i](query, K_encoded, V_encoded, next_neighbors)

          if deterministic:
            index = torch.argmax(prob_next_node, dim=1)
          else:
            index = Categorical(prob_next_node).sample()

          # Get cluster ID of the current node (members of cluster c sit at positions 2c and 2c+1)
          cluster_id = index//2

          # The sibling is the other member of the cluster: pick column 1 if the choice was column 0, and vice versa
          neighbors = torch.stack((2*cluster_id, 2*cluster_id+1), dim=1).long()
          neighbor_idx = torch.where(torch.flatten(neighbors[:,0])==index,1,0)

          first_child = index
          second_child = torch.gather(neighbors, 1, neighbor_idx.unsqueeze(1)).flatten()

          # Compute log of probabilities of the action items
          prob_of_choices = prob_next_node[zero_to_batch_size, first_child]
          log_prob_actions_sum[:, nodes_added] = torch.log(prob_of_choices)

          # Compute log of probabilities of the siblings of those selected
          prob_of_choices = prob_next_node[zero_to_batch_size, second_child]
          log_prob_actions_sum[:, nodes_added+1] = torch.log(prob_of_choices)

          # Update current node
          h_now = h_encoded[zero_to_batch_size, second_child, :]
          if t == 0: # Embedding of first node of tour
            h_start = h_encoded[zero_to_batch_size, first_child, :]
          
          # Update tour
          tours[:, nodes_added] = first_child
          tours[:, nodes_added+1] = second_child
          nodes_added += 2

          # Extract index of current cluster from parent tour
          parent_cluster_id = torch.argmax(torch.where(order_dataset[i+1,:,0:int(node_count/2)] == cluster_id.unsqueeze(1), 1, 0), dim=1).unsqueeze(1).repeat(1,int(node_count/2))
          parent_cluster_pos = torch.argmax(torch.where(parent_tours == parent_cluster_id, 1, 0), dim=1)

          # Get the cluster ID of the children of the next cluster in parent tour
          next_cluster_pos = (parent_cluster_pos+1)%parent_tours.size(1)
          next_parent_cluster_id = torch.gather(parent_tours, 1, next_cluster_pos.unsqueeze(1)).long()
          next_cluster_id = order_dataset[i+1,zero_to_batch_size,next_parent_cluster_id.squeeze()]

          # Get 2 nodes in current level belonging to next cluster
          next_neighbors = torch.stack((2*next_cluster_id, 2*next_cluster_id+1), dim=1).long()

      # The tour of this level guides the level below it
      parent_tours = tours
      if i == 0:
        log_prob_actions_sum_final =  log_prob_actions_sum.sum(dim=1)
      if i == base_level:
        log_prob_actions_sum_base = log_prob_actions_sum.sum(dim=1)
        tours_base = tours

    # Creating final variables to be returned
    sorted_graph = graph_dataset[0]
    sorted_graph_base = graph_dataset[base_level,:,0:base_node_count]
    tours = tours.long()
    tours_base = tours_base.long()

    return (sorted_graph, 
          sorted_graph_base, 
          tours, 
          tours_base, 
          log_prob_actions_sum_final,
          log_prob_actions_sum_base)

## 6. Training the model

In [None]:
# Drop the models left over from a previous run of this cell, if any
try:
  del model_train
  del model_baseline
except NameError:
  pass

# Setting up the model to be trained
model_train = MGNN(args.input_node_size, args.embedding_size, args.ff_neurons,
                  args.layers_encoder, args.layers_encoder_base,
                  args.layers_decoder, args.layers_decoder_base,
                  args.mha_heads, int(args.node_count/args.node_count_base), batchnorm=args.batchnorm)

# Setting up the baseline model
model_baseline = MGNN(args.input_node_size, args.embedding_size, args.ff_neurons,
                  args.layers_encoder, args.layers_encoder_base,
                  args.layers_decoder, args.layers_decoder_base,
                  args.mha_heads, int(args.node_count/args.node_count_base), batchnorm=args.batchnorm)

# Use more than 1 GPU if available
# (the models can only be wrapped once they exist; wrapping before creation raises NameError)
if torch.cuda.device_count() > 1:
  print(f"Using {torch.cuda.device_count()} GPUs...")
  model_train = nn.DataParallel(model_train)
  model_baseline = nn.DataParallel(model_baseline)

model_train = model_train.to(device)
model_baseline = model_baseline.to(device)
model_baseline.eval()

# Create the Adam optimizer
optimizer = torch.optim.Adam(model_train.parameters(), lr = args.lr)

print(args)

# Training logs
os.makedirs("logs", exist_ok=True)
time_stamp = datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")
file_name = "logs" + "/" + time_stamp + "-n{}".format(args.node_count) + "-gpu{}".format(args.gpu_id) + ".txt"
# Line-buffered and left open on purpose: presumably later cells keep writing to it - TODO confirm it gets closed
file = open(file_name, "w", 1)
file.write(time_stamp + "\n\n")
# One "name=value" line per hyperparameter
for arg in vars(args):
  file.write(arg)
  hyper_param_val = "={}".format(getattr(args, arg))
  file.write(hyper_param_val)
  file.write("\n")
file.write("\n\n")

# For tracking the train and baseline tour lengths
plot_performance_train = []
plot_performance_baseline = []

# Main training loop
for epoch in range(0, args.epochs):
  #############################
  # TRAIN MODEL FOR ONE EPOCH #
  #############################
  start = time.time()
  model_train.train()

  for step in range(1, args.batch_per_epoch_train+1):
    # Generate a batch of TSP instances
    nodes = torch.rand(args.batch_size, args.node_count, args.input_node_size, device=device)

    # Compute the tours for the model
    (sorted_nodes_train, 
    sorted_nodes_train_base, 
    tours_train, 
    tours_train_base,
    log_prob_actions_sum_final,
    log_prob_actions_sum_base) = model_train(nodes, deterministic=False)

    # Compute the tours for the baseline
    with torch.no_grad():
      (sorted_nodes_baseline, 
      sorted_nodes_baseline_base, 
      tours_baseline, 
      tours_baseline_base,
      _, _) = model_train(nodes, deterministic=True)

    # Get the lengths of the tours
    train_length = compute_tour_length(sorted_nodes_train, tours_train)
    baseline_length = compute_tour_length(sorted_nodes_baseline, tours_baseline)
    train_length_base = compute_tour_length(sorted_nodes_train_base, tours_train_base)
    baseline_length_base = compute_tour_length(sorted_nodes_baseline_base, tours_baseline_base)

    # Backpropogation
    loss_final = torch.mean((train_length - baseline_length) * log_prob_actions_sum_final)
    loss_base = torch.mean((train_length_base - baseline_length_base) * log_prob_actions_sum_base)
    loss = loss_final + loss_base
           
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  
  epoch_time = time.time() - start

  ###############################
  # EVALUATE MODEL AND BASELINE #
  ###############################
  model_train.eval()
  mean_tour_length_train = 0
  mean_tour_length_baseline = 0

  for step in range(0, args.batch_per_epoch_eval):
    # Generate a batch of TSP instances
    nodes = torch.rand(args.batch_size, args.node_count, args.input_node_size, device=device)

    # Compute the tours of the model and the baseline
    with torch.no_grad():
      sorted_nodes_train, _, tours_train, _, _, _ = model_train(nodes, deterministic=True)
      sorted_nodes_baseline, _, tours_baseline, _, _, _ = model_baseline(nodes, deterministic=True)

    # Get the lengths of the tours
    train_length = compute_tour_length(sorted_nodes_train, tours_train)
    baseline_length = compute_tour_length(sorted_nodes_baseline, tours_baseline)

    # Add tour length
    mean_tour_length_train += train_length.mean().item()
    mean_tour_length_baseline += baseline_length.mean().item()

  # Compute mean of tour lengths across batches
  mean_tour_length_train =  mean_tour_length_train / args.batch_per_epoch_eval
  mean_tour_length_baseline =  mean_tour_length_baseline / args.batch_per_epoch_eval

  # Update baseline if new length is lower than baseline
  update_baseline = mean_tour_length_train + args.tol < mean_tour_length_baseline
  if update_baseline:
    model_baseline.load_state_dict(model_train.state_dict())

  # Compute TSP tours for the test set
  with torch.no_grad():
    sorted_tsp_data, _, tours_baseline, _, _, _ = model_baseline(tsp_data, deterministic=True)
  mean_tour_length_test = compute_tour_length(sorted_tsp_data, tours_baseline).mean().item()

  # Storing the training and baseline tour lengths
  plot_performance_train.append([(epoch+1), mean_tour_length_train])
  plot_performance_baseline.append([(epoch+1), mean_tour_length_baseline])

  # Log epoch details
  epoch_details = 'Epoch: {:d}, Epoch Time: {:.3f}min, Train Length: {:.3f}, Baseline Length: {:.3f}, Test Length: {:.3f}, Baseline Updated: {}'.format(
      epoch, epoch_time/60, mean_tour_length_train, mean_tour_length_baseline, mean_tour_length_test, update_baseline) 
  print(epoch_details)
  file.write(epoch_details + '\n')

  # Save checkpoints
  checkpoint_dir = os.path.join("checkpoint")
  if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
  torch.save({
      'epoch': epoch,
      'time': epoch_time,
      'loss': loss.item(),
      'TSP_length': [torch.mean(train_length).item(), torch.mean(baseline_length).item()],
      'plot_performance_train': plot_performance_train,
      'plot_performance_baseline': plot_performance_baseline,
      'mean_tour_length_test': mean_tour_length_test,
      'model_baseline': model_baseline.state_dict(),
      'model_train': model_train.state_dict(),
      'optimizer': optimizer.state_dict(),
  }, '{}.pkl'.format(checkpoint_dir + "/checkpoint"))

## 7. Plotting the results

### i. Load from checkpoints (if needed)

In [None]:
# Set to False to plot the results of the current session instead of a checkpoint
LOAD = True
if LOAD:
  def _build_mgnn():
    """Create a freshly initialised MGNN configured from `args`."""
    return MGNN(args.input_node_size, args.embedding_size, args.ff_neurons,
                args.layers_encoder, args.layers_encoder_base,
                args.layers_decoder, args.layers_decoder_base,
                args.mha_heads, int(args.node_count/args.node_count_base),
                batchnorm=args.batchnorm)

  # Both models live on `device` so later cells can feed them data directly
  model_train = _build_mgnn().to(device)
  model_baseline = _build_mgnn().to(device)
  # The baseline is only used for inference, as in the training cell
  model_baseline.eval()

  optimizer = torch.optim.Adam(model_train.parameters(), lr=args.lr)

  # NOTE: torch.load unpickles the file; only load checkpoints you trust.
  # Tensors are first mapped to the CPU; load_state_dict then copies them into
  # the models' own parameters.
  checkpoint = torch.load("./checkpoint/checkpoint.pkl", map_location=torch.device('cpu'))
  epoch_ckpt = checkpoint['epoch']
  time_ckpt = checkpoint['time']
  loss_ckpt = checkpoint['loss']
  TSP_length_ckpt = checkpoint['TSP_length']
  plot_performance_train_ckpt = checkpoint['plot_performance_train']
  plot_performance_baseline_ckpt = checkpoint['plot_performance_baseline']
  mean_tour_length_test = checkpoint['mean_tour_length_test']
  model_baseline.load_state_dict(checkpoint['model_baseline'])
  model_train.load_state_dict(checkpoint['model_train'])
  optimizer.load_state_dict(checkpoint['optimizer'])

  # Summarise what was restored
  print(args); print('')
  ckpt_details = 'Epoch: {:d}, Epoch Time: {:.3f}min, Train Length: {:.3f}, Baseline Length: {:.3f}, Test Length: {:.3f}'.format(
          epoch_ckpt, time_ckpt/60, TSP_length_ckpt[0], TSP_length_ckpt[1], mean_tour_length_test)
  print(ckpt_details)

### ii. Plot tour lengths vs. epochs

In [None]:
clear_output()

# Use the curves restored from the checkpoint when one was loaded, otherwise
# the ones recorded by the training loop in this session. Each row of a curve
# is [epoch, mean tour length].
if LOAD:
  train_curve = torch.Tensor(plot_performance_train_ckpt)
  baseline_curve = torch.Tensor(plot_performance_baseline_ckpt)
else:
  train_curve = torch.Tensor(plot_performance_train)
  baseline_curve = torch.Tensor(plot_performance_baseline)

plt.plot(train_curve[:,0], train_curve[:,1], 'r-', label="Train")
plt.plot(baseline_curve[:,0], baseline_curve[:,1], 'b-', label="Baseline")
plt.title("Tour length vs. Epoch")
plt.xlabel('Epoch')
plt.ylabel('Length')
plt.legend()  # without this the "Train"/"Baseline" labels are never shown
plt.show()

# Best (lowest) baseline length, truncated to 5 characters, and instances seen per epoch
print(f"Min: {str(baseline_curve[:,1].min().item())[:5]}")
print(f"#TSP/Epoch: {str(args.batch_per_epoch_train * args.batch_size)}")

In [None]:
# Plotting baseline tour length vs. epochs

# Line style per problem size. Matplotlib format strings only accept the
# single-letter colour codes, so black is "k-" ("black-" is rejected).
curve_styles = {16: "r-", 32: "b-", 64: "g-", 128: "k-"}
# Fall back to magenta for any other size instead of leaving `color` undefined
color = curve_styles.get(args.node_count, "m-")

# Rows are [epoch, mean baseline tour length]; take them from the checkpoint
# when one was loaded, otherwise from this session's training loop.
baseline_history = torch.Tensor(plot_performance_baseline_ckpt if LOAD else plot_performance_baseline)

clear_output()
plt.plot(baseline_history[:,0], baseline_history[:,1], color, label=f"TSP{args.node_count}")
plt.title("Baseline Tour Length vs. Training Epochs")
plt.xlabel('Epoch')
plt.ylabel('Tour Length')
plt.legend()
plt.show()

### iii. Plot the TSP tours

In [None]:
# Value of `width` for each supported node count; for any other node count
# `width` is left as it was (or undefined).
_width_by_node_count = {16: 0.5, 32: 0.2, 64: 0.05, 128: 0.02}
if args.node_count in _width_by_node_count:
  width = _width_by_node_count[args.node_count]

def plot_gnn_tsp(x_coord, x_path, plot_dist_pair=True):
  """Plot the first (up to three) TSP instances together with the given tours.

  Args:
    x_coord: tensor of node coordinates, indexed [instance, node, coordinate].
      Assumed to be 2-D coordinates since they are used as plot positions.
    x_path: tensor of tours; x_path[i] holds the node indices of instance i
      in visiting order.
    plot_dist_pair: when True, all pairwise edges are drawn faintly behind
      the tour for better visualization.

  Each subplot is titled with the tour length reported by
  `compute_tour_length`, truncated to 5 characters.
  """
  x_coord = x_coord.detach().cpu()
  x_path = x_path.detach().cpu()

  # Compute TSP lengths
  tsp_length = compute_tour_length(x_coord, x_path)

  # Prepare variables for plotting
  x_coord = np.array(x_coord)
  x_path = np.array(x_path)
  batch_size = x_coord.shape[0]
  node_count = x_coord.shape[1]
  colors = ['g'] + ['b'] * (node_count - 1) # Green for first node, blue for others
  rows = 1
  cols = 3
  # Never index past the end of the batch when fewer instances are supplied
  plot_count = min(batch_size, rows * cols)
  f = plt.figure(figsize=(15, 5))

  # Loop over TSPs and plot
  for i in range(plot_count):
    x_coord_i = x_coord[i]
    pos_i = dict(enumerate(x_coord_i.tolist()))
    if plot_dist_pair: # Compute pairwise distances matrix for better visualization
      adjacency_i = squareform(pdist(x_coord_i, metric='euclidean'))
    else: # Nodes only; an all-zero matrix yields a graph without edges
      adjacency_i = np.zeros((node_count, node_count))
    # from_numpy_array replaces from_numpy_matrix, which networkx 3.0 removed
    G = nx.from_numpy_array(adjacency_i)

    x_path_i = x_path[i]
    length_tsp_i = tsp_length[i]
    # Consecutive pairs along the tour, wrapping from the last node back to the first
    nodes_pair_tsp_i = [(x_path_i[r], x_path_i[(r + 1) % node_count])
                        for r in range(node_count)]

    # Plot the tour produced by the model
    subf = f.add_subplot(rows,cols,i+1)
    nx.draw_networkx_nodes(G, pos_i, node_color=colors, node_size=50)
    nx.draw_networkx_edges(G, pos_i, edgelist=nodes_pair_tsp_i, alpha=1, width=1, edge_color='r')
    if plot_dist_pair:
      nx.draw_networkx_edges(G, pos_i, alpha=0.3, width=0.5)
    subf.set_title('Length: ' + str(length_tsp_i.item())[:5])

  plt.show()

def plot_concorde_tsp(x_coord, plot_dist_pair=True):
  """Solve the first (up to three) TSP instances with Concorde and plot the tours.

  Args:
    x_coord: tensor of node coordinates, indexed [instance, node, coordinate].
      Columns 0 and 1 are passed to Concorde as the two coordinates.
    plot_dist_pair: when True, all pairwise edges are drawn faintly behind
      the tour for better visualization.

  Coordinates are multiplied by FACTOR before being handed to the solver and
  the reported optimal value is divided by FACTOR again for the subplot title.
  Requires `TSPSolver` from pyconcorde to be importable.
  """
  # Scale applied before solving with the "EUC_2D" norm
  # NOTE(review): presumably compensates for integer rounding of EUC_2D distances - confirm
  FACTOR = 100
  x_coord = x_coord.detach().cpu()*FACTOR

  # Prepare variables for plotting
  x_coord = np.array(x_coord)
  batch_size = x_coord.shape[0]
  node_count = x_coord.shape[1]
  colors = ['g'] + ['b'] * (node_count - 1) # Green for first node, blue for others
  rows = 1
  cols = 3
  # Never index past the end of the batch when fewer instances are supplied
  plot_count = min(batch_size, rows * cols)
  f = plt.figure(figsize=(15, 5))

  # Loop over TSPs and plot
  for i in range(plot_count):
    x_coord_i = x_coord[i]
    pos_i = dict(enumerate(x_coord_i.tolist()))
    if plot_dist_pair: # Compute pairwise distances matrix for better visualization
      adjacency_i = squareform(pdist(x_coord_i, metric='euclidean'))
    else: # Nodes only; an all-zero matrix yields a graph without edges
      adjacency_i = np.zeros((node_count, node_count))
    # from_numpy_array replaces from_numpy_matrix, which networkx 3.0 removed
    G = nx.from_numpy_array(adjacency_i)

    # Solve graph using Concorde (once; the result carries both tour and value)
    graph = pd.DataFrame({'lat': x_coord_i[:,0], 'lon': x_coord_i[:,1]})
    solver = TSPSolver.from_data(graph.lat, graph.lon, norm="EUC_2D")
    solved = solver.solve()
    solution = solved.tour
    # Consecutive pairs along the tour, wrapping from the last node back to the first
    nodes_pair_concorde_i = [(solution[r], solution[(r + 1) % node_count])
                             for r in range(node_count)]

    length_concorde = solved.optimal_value/FACTOR

    subf = f.add_subplot(rows,cols,i+1)
    nx.draw_networkx_nodes(G, pos_i, node_color=colors, node_size=50)
    nx.draw_networkx_edges(G, pos_i, edgelist=nodes_pair_concorde_i, alpha=1, width=1, edge_color='b')
    if plot_dist_pair:
      nx.draw_networkx_edges(G, pos_i, alpha=0.3, width=0.5)
    subf.set_title('Length: ' + str(length_concorde))

  plt.show()

In [None]:
# Plotting the TSP tours generated by the model

# Make sure the baseline sits on the compute device up front instead of relying
# on a failed forward pass to find out (a no-op when it is already there).
model_baseline = model_baseline.to(device)
model_baseline.eval()

# Inference only: no gradients are needed for plotting
with torch.no_grad():
  sorted_nodes, _, tour_baseline, _, _, _ = model_baseline(tsp_data, deterministic=True)
plot_gnn_tsp(sorted_nodes, tour_baseline)

In [None]:
# Plotting the optimal TSP tours
if INSTALL_PYCONCORDE:
  # The import of TSPSolver at the top of the notebook is allowed to fail, so
  # check that the solver is really available before using it.
  if "TSPSolver" in globals():
    plot_concorde_tsp(tsp_data)
  else:
    print("PyConcorde not found")