## Make predictions

Now let's make some predictions on the validation dataset with the trained model. This notebook is designed to generate heat maps on TSP200, TSP500, TSP1000 and TSP10000.

In [1]:
import os
import json
import argparse
import time
import math
import numpy as np
from scipy.special import softmax
import torch.nn.functional as F
import torch.nn as nn
from utils.plot_utils import plot_predictions_cluster

from sklearn.utils.class_weight import compute_class_weight

# Remove warning
# import warnings
# warnings.filterwarnings("ignore", category=UserWarning)
# from scipy.sparse import SparseEfficiencyWarning
# warnings.simplefilter('ignore', SparseEfficiencyWarning)

from utils.process import *
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform

from data.data_generator import tsp_instance_reader

from utils.tsplib import read_tsplib_coor, read_tsplib_opt, write_tsplib_prob


from multiprocessing import Pool
from multiprocessing import cpu_count

import tqdm

### 1. Loading trained Att-GCN based on TSP50-trainset

In [3]:
# model-parameter
config_path = "configs/tsp50.json"
config = get_config(config_path)

# Seed every random number generator this notebook relies on with 1, so that
# repeated runs are comparable: NumPy drives the random cluster sampling in the
# graph-sampling section (np.random.shuffle / np.random.choice), and torch
# covers the CPU generator. Previously the CUDA branch seeded only the CUDA
# generators and NumPy was never seeded.
np.random.seed(1)
torch.manual_seed(1)

# Pick the tensor types matching the available device
if torch.cuda.is_available():
    print("Using cuda")
    dtypeFloat = torch.cuda.FloatTensor
    dtypeLong = torch.cuda.LongTensor
    torch.cuda.manual_seed_all(1)
else:
    dtypeFloat = torch.FloatTensor
    dtypeLong = torch.LongTensor

In [None]:
# Build the model wrapped in DataParallel and move it to the GPU when one is available
net = nn.DataParallel(ResidualGatedGCNModel(config, dtypeFloat, dtypeLong))
if torch.cuda.is_available():
    net.cuda()

# Optimizer, created with the learning rate taken from the config
learning_rate = config.learning_rate
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Restore the TSP-50 checkpoint; without CUDA the stored tensors are remapped to the CPU
log_dir = f"./logs/{config.expt_name}/"
checkpoint = torch.load(
    "logs/tsp50/best_val_checkpoint.tar",
    map_location=None if torch.cuda.is_available() else 'cpu',
)

# Network state, then optimizer state
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Remaining training bookkeeping stored in the checkpoint
epoch, train_loss, val_loss = checkpoint['epoch'], checkpoint['train_loss'], checkpoint['val_loss']

# Keep the learning rate found in the restored optimizer (last parameter group wins)
for param_group in optimizer.param_groups:
    learning_rate = param_group['lr']

### 2. Graph Sampling 

In [None]:
def test_one_tsp(tsp_source, coor_buff, node_num=20, 
                    cluster_center = 0, top_k = 19, top_k_expand = 19):
    mean_rank_sum, mean_greater_zero_edges = 0, 0
    coor, opt = tsp_instance_reader(tspinstance=tsp_source,
                       buff = coor_buff, num_node=node_num)
    coors = [coor]
    
    distA = pdist(coors[0], metric='euclidean')
    distB_raw = squareform(distA)
    distB = squareform(distA) + 10.0 * np.eye(N = node_num, M =node_num, dtype = np.float64)
    
    edges_probs = np.zeros(shape = (node_num, node_num), dtype = np.float64)
    
    pre_edges = np.ones(shape = (top_k + 1, top_k + 1), dtype = np.int32) + np.eye(N = top_k + 1, M = top_k + 1)
    pre_node = np.ones(shape = (top_k + 1, ))
    
    pre_node_target = np.arange(0, top_k + 1)
    pre_node_target = np.append(pre_node_target, 0)
    pre_edge_target = np.zeros(shape = (top_k + 1, top_k + 1)) 
    pre_edge_target[pre_node_target[:-1], pre_node_target[1:]] = 1
    pre_edge_target[pre_node_target[1:], pre_node_target[:-1]] = 1
    
    neighbor = np.argpartition(distB, kth = top_k, axis=1)
    
    neighbor_expand = np.argpartition(distB, kth = top_k_expand, axis=1)
    Omega_w = np.zeros(shape=(node_num, ), dtype = np.int32)
    Omega = np.zeros(shape=(node_num, node_num), dtype = np.int32)
    
    edges, edges_values = [], []
    nodes, nodes_coord = [], []
    edges_target, nodes_target = [], []
    meshs = []
    num_clusters = 0
    if node_num==20:
        num_clusters_threshold = 1
    else:
        num_clusters_threshold = math.ceil((node_num / (top_k+1) ) * 5)
    all_visited = False
    num_batch_size = 0
    
    while num_clusters < num_clusters_threshold or all_visited == False:
        if all_visited==False:
            
            cluster_center_neighbor = neighbor[cluster_center, :top_k]
            cluster_center_neighbor = np.insert(cluster_center_neighbor,
                                                0, cluster_center)
        else:
            np.random.shuffle(neighbor_expand[cluster_center, :top_k_expand])
            cluster_center_neighbor = neighbor_expand[cluster_center, :top_k]
            cluster_center_neighbor = np.insert(cluster_center_neighbor,
                                                0, cluster_center)
        
        Omega_w[cluster_center_neighbor] += 1

        # case 4
        node_coord = coors[0][cluster_center_neighbor]
        x_y_min = np.min(node_coord, axis=0)
        scale = 1.0 / np.max(np.max(node_coord, axis=0)-x_y_min)
        node_coord = node_coord - x_y_min
        node_coord *= scale
        nodes_coord.append(node_coord)

        # case 1-2
        edges.append(pre_edges)
        mesh = np.meshgrid(cluster_center_neighbor, cluster_center_neighbor)
        
        edges_value = distB_raw[mesh].copy()
        edges_value *= scale
        edges_values.append(edges_value)
        meshs.append(mesh)
        Omega[mesh] += 1

        # case 3
        nodes.append(pre_node)

        # case 5-6
        edges_target.append(pre_edge_target)
        nodes_target.append(pre_node_target[:-1])

        num_clusters += 1
        
        if 0 not in Omega_w:
            all_visited = True
        
        cluster_center = np.random.choice(np.where(Omega_w==np.min(Omega_w))[0])
    
    return edges, edges_values, nodes, nodes_coord, edges_target, nodes_target, meshs, Omega, opt

In [None]:
def multiprocess_write(sub_prob, meshgrid, omega, node_num = 20,
                       tsplib_name = './sample.txt', statiscs = False, opt = None):
    """Merge per-cluster edge probabilities into one heat map and write it out.

    Args:
        sub_prob: Array indexed as ``[cluster, row, col, class]``; the values
            of class 1 are accumulated.
        meshgrid: Per cluster, the pair of index grids (as produced by
            ``np.meshgrid``) that maps cluster-local positions to node ids.
        omega: ``(node_num, node_num)`` array with the number of clusters each
            node pair occurred in; used to average the accumulated values.
        node_num: Number of nodes of the instance.
        tsplib_name: Output path handed to ``write_tsplib_prob``.
        statiscs: When true, compute ranking statistics against ``opt`` and
            pass them to the writer; otherwise zeros are written.
        opt: Node sequence used for the statistics (only read when
            ``statiscs`` is true; must support array-style indexing).

    Returns:
        The mean rank computed over the first ``node_num - 1`` consecutive
        pairs of ``opt`` when ``statiscs`` is true, otherwise 0.
    """
    edges_probs = np.zeros(shape=(node_num, node_num), dtype=np.float32)
    for i in range(len(meshgrid)):
        # Index with a tuple: recent NumPy releases no longer treat a list of
        # index arrays as a multidimensional index.
        edges_probs[tuple(meshgrid[i])] += sub_prob[i, :, :, 1]
    # Average over the number of clusters that contained each pair
    edges_probs = edges_probs / (omega + 1e-8)
    # Symmetrise, then normalise each row so that it sums to one
    edges_probs = edges_probs + edges_probs.T
    edges_probs_norm = edges_probs / np.sum(edges_probs, axis=1, keepdims=True)

    # Defaults written (and returned) when no statistics are requested.
    # Initialising them here also fixes the UnboundLocalError the final
    # `return mean_rank` used to raise when statiscs was False.
    mean_rank, num_fne, greater_zero_edges = 0, 0, 0

    if statiscs:
        # Rank of each edge (opt[i], opt[i+1]) within the row of opt[i]:
        # the number of entries that are at least as large as that edge's value
        for i in range(node_num-1):
            row = edges_probs_norm[opt[i], :]
            mean_rank += int(np.sum(row >= edges_probs_norm[opt[i], opt[i+1]]))
        mean_rank /= (node_num-1)

        # false negative edges in an instance: consecutive pairs of opt whose
        # normalised value is (almost) zero
        num_fne = int(np.sum(edges_probs_norm[opt[:-1], opt[1:]] < 1e-5))

        # average number of non-negligible entries per node
        greater_zero_edges = int(np.sum(edges_probs_norm > 1e-6)) / node_num

    write_tsplib_prob(tsplib_name, edge_prob = edges_probs_norm,
                      num_node=node_num, mean=mean_rank, fnn = num_fne, greater_zero=greater_zero_edges)
    return mean_rank
# Put the network into evaluation mode before the inference cells below
net.eval()

#### 2.1 Heatmap generator on TSP200

All output is stored in the directory `./results/heatmap/tsp200`. Running the code cells below produces 128 probabilistic heat maps for the TSP200 instances; afterwards, copy them to the directory `MCTS/tsp-200-500-1000/heatmap`.

In [None]:
num_nodes = 200
# Read the test instances (one per line). The context manager closes the file
# even when reading fails.
with open('./data/tsp{}_test_concorde.txt'.format(num_nodes), 'r') as f:
    testset_tsp = f.readlines()

config.expt_name = 'tsp{}'.format(num_nodes)
# K neighbours per sub-graph; K_expand is the larger neighbour pool used by
# test_one_tsp once every node has been covered
K = 49
K_expand = 79
avg_mean_rank = []
top_k, cluster_center = K, 0
batch_size = 128
# Sub-graphs per instance; same formula as the cluster threshold inside test_one_tsp
threshold = math.ceil((num_nodes / (top_k+1) ) * 5)
# Number of full batches (instances beyond the last full batch are not processed).
# NOTE: this rebinds `epoch`, which previously held the checkpoint's epoch.
epoch = int(len(testset_tsp)/batch_size)
buff_coor = np.zeros(shape=(num_nodes, 2), dtype = np.float64)
start_row_num = 0

In [None]:
# init: pre-allocated batch buffers. Every instance contributes `threshold`
# sub-graphs of K+1 nodes, so one batch holds batch_size*threshold sub-graphs.
count_buff = np.zeros(shape=(batch_size*threshold, ), dtype=np.int32)
edges = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.int32)
edges_values = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.float16)
nodes = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
nodes_coord = np.zeros(shape = (batch_size*threshold, K+1, 2), dtype=np.float16)
edges_target = np.zeros(shape = (batch_size*threshold, K+1, K+1), dtype=np.int32)
nodes_target = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
meshs = np.zeros(shape = (batch_size*threshold, 2, K+1, K+1), dtype=np.int32)
Omegas = np.zeros(shape = (batch_size, num_nodes, num_nodes), dtype=np.int32)
opts = np.zeros(shape = (batch_size, num_nodes+1), dtype=np.int32)
num_neighbors = config.num_neighbors
beam_size = config.beam_size

# Create the output directory up front so that writing the heat maps does not
# depend on it having been created beforehand.
heatmap_dir = f'./results/heatmap/tsp{num_nodes}'
os.makedirs(heatmap_dir, exist_ok=True)

sum_time = 0
for j in tqdm.tqdm(range(epoch)):
    start = time.time()
    # Sample the sub-graphs of every instance of this batch
    for i in range(batch_size):
        edge, edges_value, node, node_coord, edge_target, node_target, mesh, omega, opt = test_one_tsp(
            tsp_source=testset_tsp[start_row_num+i],
            coor_buff=buff_coor, node_num=num_nodes,
            cluster_center=0, top_k=K, top_k_expand=K_expand)
        if len(edge) != threshold:
            # test_one_tsp keeps sampling until every node is covered, so it
            # can return more sub-graphs than the fixed-size buffers can hold.
            raise ValueError(f'instance {start_row_num+i}: expected {threshold} sub-graphs, '
                             f'got {len(edge)}')
        rows = slice(i*threshold, (i+1)*threshold)
        edges[rows, ...] = edge
        edges_values[rows, ...] = edges_value
        nodes[rows, ...] = node
        nodes_coord[rows, ...] = node_coord
        edges_target[rows, ...] = edge_target
        nodes_target[rows, ...] = node_target
        meshs[rows, ...] = mesh
        Omegas[i, ...] = omega
        opts[i, ...] = opt

    with torch.no_grad():
        # Convert the batch to tensors (the deprecated Variable wrapper is not
        # needed; plain tensors are accepted directly)
        x_edges = torch.LongTensor(edges).type(dtypeLong)
        x_edges_values = torch.FloatTensor(edges_values).type(dtypeFloat)
        x_nodes = torch.LongTensor(nodes).type(dtypeLong)
        x_nodes_coord = torch.FloatTensor(nodes_coord).type(dtypeFloat)
        y_edges = torch.LongTensor(edges_target).type(dtypeLong)
        y_nodes = torch.LongTensor(nodes_target).type(dtypeLong)

        # Compute class weights
        edge_labels = y_edges.cpu().numpy().flatten()
        edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

        # Forward pass; softmax over the last (class) dimension
        y_preds, loss = net(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
        y_preds_prob = F.softmax(y_preds, dim=3)
        y_preds_prob_numpy = y_preds_prob.cpu().numpy()

    # multi - processes (disabled alternative to the single-process loop below)
#     progress_pool = Pool(processes=10)
#     for i in range(batch_size):
#         heatmap_path = f'./results/heatmap/tsp{num_nodes}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
#         progress_pool.apply_async(multiprocess_write, args=(y_preds_prob_numpy[i*threshold:(i+1)*threshold, ...],
#                                                            meshs[i*threshold:(i+1)*threshold, ...], Omegas[i, ...],
#                                                            num_nodes, heatmap_path, True, opts[i, ...]))
#     progress_pool.close()
#     progress_pool.join()
    # Timing covers sampling and inference; writing the heat maps is excluded
    end = time.time()
    sum_time += end - start
    # single - process
    for i in range(batch_size):
        heatmap_path = f'{heatmap_dir}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
        rows = slice(i*threshold, (i+1)*threshold)
        rank = multiprocess_write(y_preds_prob_numpy[rows, ...],
                                  meshs[rows, ...], Omegas[i, ...],
                                  num_nodes, heatmap_path, True, opts[i, ...])
        avg_mean_rank.append(rank)
    start_row_num+= batch_size


#### 2.2 Heatmap generator on TSP500

All output is stored in the directory `./results/heatmap/tsp500`. Running the code cells below produces 128 probabilistic heat maps for the TSP500 instances; afterwards, copy them to the directory `MCTS/tsp-200-500-1000/heatmap`.

In [None]:
num_nodes = 500
# Read the test instances (one per line). The context manager closes the file
# even when reading fails.
with open('./data/tsp{}_test_concorde.txt'.format(num_nodes), 'r') as f:
    testset_tsp = f.readlines()

config.expt_name = 'tsp{}'.format(num_nodes)
# K neighbours per sub-graph (passed to test_one_tsp as top_k)
K = 49
avg_mean_rank = []
top_k, cluster_center = K, 0
batch_size = 64
# Sub-graphs per instance; same formula as the cluster threshold inside test_one_tsp
threshold = math.ceil((num_nodes / (top_k+1) ) * 5)
# Number of full batches (instances beyond the last full batch are not processed).
# NOTE: this rebinds `epoch`, which previously held the checkpoint's epoch.
epoch = int(len(testset_tsp)/batch_size)
buff_coor = np.zeros(shape=(num_nodes, 2), dtype = np.float64)
start_row_num = 0

In [None]:
# init: pre-allocated batch buffers. Every instance contributes `threshold`
# sub-graphs of K+1 nodes, so one batch holds batch_size*threshold sub-graphs.
K_expand = 99
count_buff = np.zeros(shape=(batch_size*threshold, ), dtype=np.int32)
edges = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.int32)
edges_values = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.float16)
nodes = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
nodes_coord = np.zeros(shape = (batch_size*threshold, K+1, 2), dtype=np.float16)
edges_target = np.zeros(shape = (batch_size*threshold, K+1, K+1), dtype=np.int32)
nodes_target = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
meshs = np.zeros(shape = (batch_size*threshold, 2, K+1, K+1), dtype=np.int32)
Omegas = np.zeros(shape = (batch_size, num_nodes, num_nodes), dtype=np.int32)
opts = np.zeros(shape = (batch_size, num_nodes+1), dtype=np.int32)
num_neighbors = config.num_neighbors
beam_size = config.beam_size

# Create the output directory up front so that writing the heat maps does not
# depend on it having been created beforehand.
heatmap_dir = f'./results/heatmap/tsp{num_nodes}'
os.makedirs(heatmap_dir, exist_ok=True)

sum_time = 0
for j in tqdm.tqdm(range(epoch)):
    start = time.time()
    # Sample the sub-graphs of every instance of this batch
    for i in range(batch_size):
        edge, edges_value, node, node_coord, edge_target, node_target, mesh, omega, opt = test_one_tsp(
            tsp_source=testset_tsp[start_row_num+i],
            coor_buff=buff_coor, node_num=num_nodes,
            cluster_center=0, top_k=K, top_k_expand=K_expand)
        if len(edge) != threshold:
            # test_one_tsp keeps sampling until every node is covered, so it
            # can return more sub-graphs than the fixed-size buffers can hold.
            raise ValueError(f'instance {start_row_num+i}: expected {threshold} sub-graphs, '
                             f'got {len(edge)}')
        rows = slice(i*threshold, (i+1)*threshold)
        edges[rows, ...] = edge
        edges_values[rows, ...] = edges_value
        nodes[rows, ...] = node
        nodes_coord[rows, ...] = node_coord
        edges_target[rows, ...] = edge_target
        nodes_target[rows, ...] = node_target
        meshs[rows, ...] = mesh
        Omegas[i, ...] = omega
        opts[i, ...] = opt

    with torch.no_grad():
        # Convert the batch to tensors (the deprecated Variable wrapper is not
        # needed; plain tensors are accepted directly)
        x_edges = torch.LongTensor(edges).type(dtypeLong)
        x_edges_values = torch.FloatTensor(edges_values).type(dtypeFloat)
        x_nodes = torch.LongTensor(nodes).type(dtypeLong)
        x_nodes_coord = torch.FloatTensor(nodes_coord).type(dtypeFloat)
        y_edges = torch.LongTensor(edges_target).type(dtypeLong)
        y_nodes = torch.LongTensor(nodes_target).type(dtypeLong)

        # Compute class weights
        edge_labels = y_edges.cpu().numpy().flatten()
        edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

        # Forward pass; softmax over the last (class) dimension
        y_preds, loss = net(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
        y_preds_prob = F.softmax(y_preds, dim=3)
        y_preds_prob_numpy = y_preds_prob.cpu().numpy()

    # multi - processes (disabled alternative to the single-process loop below)
#     progress_pool = Pool(processes=10)
#     for i in range(batch_size):
#         heatmap_path = f'./results/heatmap/tsp{num_nodes}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
#         progress_pool.apply_async(multiprocess_write, args=(y_preds_prob_numpy[i*threshold:(i+1)*threshold, ...],
#                                                            meshs[i*threshold:(i+1)*threshold, ...], Omegas[i, ...],
#                                                            num_nodes, heatmap_path, True, opts[i, ...]))
#     progress_pool.close()
#     progress_pool.join()
    # Timing covers sampling and inference; writing the heat maps is excluded
    end = time.time()
    sum_time += end - start
    # single - process
    for i in range(batch_size):
        heatmap_path = f'{heatmap_dir}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
        rows = slice(i*threshold, (i+1)*threshold)
        rank = multiprocess_write(y_preds_prob_numpy[rows, ...],
                                  meshs[rows, ...], Omegas[i, ...],
                                  num_nodes, heatmap_path, True, opts[i, ...])
        avg_mean_rank.append(rank)
    start_row_num+= batch_size


#### 2.3 Heatmap generator on TSP1000

All output is stored in the directory `./results/heatmap/tsp1000`. Running the code cells below produces 128 probabilistic heat maps for the TSP1000 instances; afterwards, copy them to the directory `MCTS/tsp-200-500-1000/heatmap`.

In [None]:
num_nodes = 1000
# Read the test instances (one per line). The context manager closes the file
# even when reading fails.
with open('./data/tsp{}_test_concorde.txt'.format(num_nodes), 'r') as f:
    testset_tsp = f.readlines()

config.expt_name = 'tsp{}'.format(num_nodes)
# K neighbours per sub-graph (passed to test_one_tsp as top_k)
K = 49
avg_mean_rank = []
top_k, cluster_center = K, 0
batch_size = 64
# Sub-graphs per instance; same formula as the cluster threshold inside test_one_tsp
threshold = math.ceil((num_nodes / (top_k+1) ) * 5)
# Number of full batches (instances beyond the last full batch are not processed).
# NOTE: this rebinds `epoch`, which previously held the checkpoint's epoch.
epoch = int(len(testset_tsp)/batch_size)
buff_coor = np.zeros(shape=(num_nodes, 2), dtype = np.float64)
start_row_num = 0

In [None]:
# init: pre-allocated batch buffers. Every instance contributes `threshold`
# sub-graphs of K+1 nodes, so one batch holds batch_size*threshold sub-graphs.
K_expand = 99
count_buff = np.zeros(shape=(batch_size*threshold, ), dtype=np.int32)
edges = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.int32)
edges_values = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.float16)
nodes = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
nodes_coord = np.zeros(shape = (batch_size*threshold, K+1, 2), dtype=np.float16)
edges_target = np.zeros(shape = (batch_size*threshold, K+1, K+1), dtype=np.int32)
nodes_target = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
meshs = np.zeros(shape = (batch_size*threshold, 2, K+1, K+1), dtype=np.int32)
Omegas = np.zeros(shape = (batch_size, num_nodes, num_nodes), dtype=np.int32)
opts = np.zeros(shape = (batch_size, num_nodes+1), dtype=np.int32)
num_neighbors = config.num_neighbors
beam_size = config.beam_size

# Create the output directory up front so that writing the heat maps does not
# depend on it having been created beforehand.
heatmap_dir = f'./results/heatmap/tsp{num_nodes}'
os.makedirs(heatmap_dir, exist_ok=True)

sum_time = 0
for j in tqdm.tqdm(range(epoch)):
    start = time.time()
    # Sample the sub-graphs of every instance of this batch
    for i in range(batch_size):
        edge, edges_value, node, node_coord, edge_target, node_target, mesh, omega, opt = test_one_tsp(
            tsp_source=testset_tsp[start_row_num+i],
            coor_buff=buff_coor, node_num=num_nodes,
            cluster_center=0, top_k=K, top_k_expand=K_expand)
        if len(edge) != threshold:
            # test_one_tsp keeps sampling until every node is covered, so it
            # can return more sub-graphs than the fixed-size buffers can hold.
            raise ValueError(f'instance {start_row_num+i}: expected {threshold} sub-graphs, '
                             f'got {len(edge)}')
        rows = slice(i*threshold, (i+1)*threshold)
        edges[rows, ...] = edge
        edges_values[rows, ...] = edges_value
        nodes[rows, ...] = node
        nodes_coord[rows, ...] = node_coord
        edges_target[rows, ...] = edge_target
        nodes_target[rows, ...] = node_target
        meshs[rows, ...] = mesh
        Omegas[i, ...] = omega
        opts[i, ...] = opt

    with torch.no_grad():
        # Convert the batch to tensors (the deprecated Variable wrapper is not
        # needed; plain tensors are accepted directly)
        x_edges = torch.LongTensor(edges).type(dtypeLong)
        x_edges_values = torch.FloatTensor(edges_values).type(dtypeFloat)
        x_nodes = torch.LongTensor(nodes).type(dtypeLong)
        x_nodes_coord = torch.FloatTensor(nodes_coord).type(dtypeFloat)
        y_edges = torch.LongTensor(edges_target).type(dtypeLong)
        y_nodes = torch.LongTensor(nodes_target).type(dtypeLong)

        # Compute class weights
        edge_labels = y_edges.cpu().numpy().flatten()
        edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

        # Forward pass; softmax over the last (class) dimension
        y_preds, loss = net(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
        y_preds_prob = F.softmax(y_preds, dim=3)
        y_preds_prob_numpy = y_preds_prob.cpu().numpy()

    # multi - processes (disabled alternative to the single-process loop below)
#     progress_pool = Pool(processes=10)
#     for i in range(batch_size):
#         heatmap_path = f'./results/heatmap/tsp{num_nodes}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
#         progress_pool.apply_async(multiprocess_write, args=(y_preds_prob_numpy[i*threshold:(i+1)*threshold, ...],
#                                                            meshs[i*threshold:(i+1)*threshold, ...], Omegas[i, ...],
#                                                            num_nodes, heatmap_path, True, opts[i, ...]))
#     progress_pool.close()
#     progress_pool.join()
    # Timing covers sampling and inference; writing the heat maps is excluded
    end = time.time()
    sum_time += end - start
    # single - process
    for i in range(batch_size):
        heatmap_path = f'{heatmap_dir}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
        rows = slice(i*threshold, (i+1)*threshold)
        rank = multiprocess_write(y_preds_prob_numpy[rows, ...],
                                  meshs[rows, ...], Omegas[i, ...],
                                  num_nodes, heatmap_path, True, opts[i, ...])
        avg_mean_rank.append(rank)
    start_row_num+= batch_size


#### 2.4 Heatmap generator on TSP10000

All output is stored in the directory `./results/heatmap/tsp10000`. Running the code cells below produces 16 probabilistic heat maps for the TSP10000 instances; afterwards, copy them to the directory `MCTS/tsp-10000/heatmap`.

In [None]:
num_nodes = 10000
# Read the test instances (one per line). The context manager closes the file
# even when reading fails.
with open('./data/tsp{}_test_concorde.txt'.format(num_nodes), 'r') as f:
    testset_tsp = f.readlines()

config.expt_name = 'tsp{}'.format(num_nodes)
# K neighbours per sub-graph (passed to test_one_tsp as top_k)
K = 49
avg_mean_rank = []
top_k, cluster_center = K, 0
batch_size = 8
# Sub-graphs per instance; same formula as the cluster threshold inside test_one_tsp
threshold = math.ceil((num_nodes / (top_k+1) ) * 5)
# Number of full batches (instances beyond the last full batch are not processed).
# NOTE: this rebinds `epoch`, which previously held the checkpoint's epoch.
epoch = int(len(testset_tsp)/batch_size)
buff_coor = np.zeros(shape=(num_nodes, 2), dtype = np.float64)
start_row_num = 0

In [None]:
# init: pre-allocated batch buffers. Every instance contributes `threshold`
# sub-graphs of K+1 nodes, so one batch holds batch_size*threshold sub-graphs.
K_expand = 149
count_buff = np.zeros(shape=(batch_size*threshold, ), dtype=np.int32)
edges = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.int32)
edges_values = np.zeros(shape=(batch_size*threshold, K+1, K+1), dtype=np.float16)
nodes = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
nodes_coord = np.zeros(shape = (batch_size*threshold, K+1, 2), dtype=np.float16)
edges_target = np.zeros(shape = (batch_size*threshold, K+1, K+1), dtype=np.int32)
nodes_target = np.zeros(shape = (batch_size*threshold, K+1), dtype=np.int32)
meshs = np.zeros(shape = (batch_size*threshold, 2, K+1, K+1), dtype=np.int32)
Omegas = np.zeros(shape = (batch_size, num_nodes, num_nodes), dtype=np.int32)
opts = np.zeros(shape = (batch_size, num_nodes+1), dtype=np.int32)
num_neighbors = config.num_neighbors
beam_size = config.beam_size

# Create the output directory up front so that writing the heat maps does not
# depend on it having been created beforehand.
heatmap_dir = f'./results/heatmap/tsp{num_nodes}'
os.makedirs(heatmap_dir, exist_ok=True)

sum_time = 0
for j in tqdm.tqdm(range(epoch)):
    start = time.time()
    # Sample the sub-graphs of every instance of this batch
    for i in range(batch_size):
        edge, edges_value, node, node_coord, edge_target, node_target, mesh, omega, opt = test_one_tsp(
            tsp_source=testset_tsp[start_row_num+i],
            coor_buff=buff_coor, node_num=num_nodes,
            cluster_center=0, top_k=K, top_k_expand=K_expand)
        if len(edge) != threshold:
            # test_one_tsp keeps sampling until every node is covered, so it
            # can return more sub-graphs than the fixed-size buffers can hold.
            raise ValueError(f'instance {start_row_num+i}: expected {threshold} sub-graphs, '
                             f'got {len(edge)}')
        rows = slice(i*threshold, (i+1)*threshold)
        edges[rows, ...] = edge
        edges_values[rows, ...] = edges_value
        nodes[rows, ...] = node
        nodes_coord[rows, ...] = node_coord
        edges_target[rows, ...] = edge_target
        nodes_target[rows, ...] = node_target
        meshs[rows, ...] = mesh
        Omegas[i, ...] = omega
        opts[i, ...] = opt

    with torch.no_grad():
        # Convert the batch to tensors (the deprecated Variable wrapper is not
        # needed; plain tensors are accepted directly)
        x_edges = torch.LongTensor(edges).type(dtypeLong)
        x_edges_values = torch.FloatTensor(edges_values).type(dtypeFloat)
        x_nodes = torch.LongTensor(nodes).type(dtypeLong)
        x_nodes_coord = torch.FloatTensor(nodes_coord).type(dtypeFloat)
        y_edges = torch.LongTensor(edges_target).type(dtypeLong)
        y_nodes = torch.LongTensor(nodes_target).type(dtypeLong)

        # Compute class weights
        edge_labels = y_edges.cpu().numpy().flatten()
        edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

        # Forward pass; softmax over the last (class) dimension
        y_preds, loss = net(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
        y_preds_prob = F.softmax(y_preds, dim=3)
        y_preds_prob_numpy = y_preds_prob.cpu().numpy()

    # multi - processes (disabled alternative to the single-process loop below)
#     progress_pool = Pool(processes=10)
#     for i in range(batch_size):
#         heatmap_path = f'./results/heatmap/tsp{num_nodes}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
#         progress_pool.apply_async(multiprocess_write, args=(y_preds_prob_numpy[i*threshold:(i+1)*threshold, ...],
#                                                            meshs[i*threshold:(i+1)*threshold, ...], Omegas[i, ...],
#                                                            num_nodes, heatmap_path, True, opts[i, ...]))
#     progress_pool.close()
#     progress_pool.join()
    # Timing covers sampling and inference; writing the heat maps is excluded
    end = time.time()
    sum_time += end - start
    # single - process
    for i in range(batch_size):
        heatmap_path = f'{heatmap_dir}/heatmaptsp{num_nodes}_{i+start_row_num}.txt'
        rows = slice(i*threshold, (i+1)*threshold)
        rank = multiprocess_write(y_preds_prob_numpy[rows, ...],
                                  meshs[rows, ...], Omegas[i, ...],
                                  num_nodes, heatmap_path, True, opts[i, ...])
        avg_mean_rank.append(rank)
    start_row_num+= batch_size
