In [2]:
import time
import torch
from torch.distributions import Categorical, kl
from d2l.torch import Animator

from aco import ACO
from utils import load_test_dataset, print_results

# Seed torch's global random number generator.
torch.manual_seed(1234)

# NOTE(review): `lr` is not read by any cell visible in this notebook — presumably a
# learning rate left over from the training script; confirm before removing.
lr = 3e-4
# Small constant added to every heuristic matrix before it is assigned to the ACO solver
# (presumably to keep all entries strictly positive — confirm).
EPS = 1e-10
# Iteration count at which the dynamic run no longer refreshes the heuristic; the temporal
# network is also constructed with T_max // T_update as its argument.
T_max = 250
T_update = 50 # the dynamic run regenerates the heuristic every T_update ACO iterations; with T_max = 250 that is at T=50, 100, 150, 200 (no refresh at T_max itself)
# Device string passed to ACO, load_test_dataset, the networks and torch.load.
device = 'cuda:0'

In [3]:
@torch.no_grad()
def infer_instance(model, pyg_data, distances, n_ants, t_aco_diff):
    """Run ACO on one instance, refreshing the learned heuristic during the search.

    The heuristic is first produced by ``model(pyg_data, 0)``. Afterwards, every
    time the running iteration count reaches a multiple of ``T_update`` (except
    when it equals ``T_max``), the heuristic is regenerated with
    ``model(pyg_data, sum_t // T_update)`` and assigned to the solver.

    Args:
        model: network called as ``model(pyg_data, stage)``; must also provide
            ``model.reshape(pyg_data, heu_vec)``. It is switched to eval mode.
        pyg_data: graph input forwarded unchanged to ``model``.
        distances: distance data forwarded unchanged to ``ACO``.
        n_ants: forwarded to ``ACO`` as ``n_ants``.
        t_aco_diff: iterable of non-negative ints; the number of ACO iterations
            to run before each checkpoint (differences of the cumulative
            checkpoint list).

    Returns:
        list with one entry per element of ``t_aco_diff``: ``.item()`` of the
        value returned by the most recent ``aco.run(1)`` call at that checkpoint.

    Raises:
        ValueError: if a checkpoint is reached before any ACO iteration has run
            (i.e. ``t_aco_diff`` starts with 0), so there is no cost to report.
    """
    model.eval()
    aco = ACO(
        n_ants=n_ants,
        distances=distances,
        device=device
    )

    def refresh_heuristic(stage):
        # Single place that turns the model output into the solver's heuristic.
        heu_vec = model(pyg_data, stage)
        aco.heuristic = model.reshape(pyg_data, heu_vec) + EPS

    refresh_heuristic(0)

    sum_t = 0
    best_cost = None
    results = []
    for t in t_aco_diff:
        # Step one iteration at a time so the heuristic can be swapped mid-segment.
        for _ in range(t):
            best_cost = aco.run(1)
            sum_t += 1
            if sum_t % T_update == 0 and sum_t != T_max:
                refresh_heuristic(sum_t // T_update)
        if best_cost is None:
            # Previously this surfaced as an UnboundLocalError on `best_cost`.
            raise ValueError(
                "t_aco_diff must start with a positive iteration count: "
                "no ACO iteration has run before the first checkpoint"
            )
        results.append(best_cost.item())
    return results

@torch.no_grad()
def infer_instance_static(model, pyg_data, distances, n_ants, t_aco_diff):
    """Run ACO on one instance with a heuristic that is computed once up front.

    Args:
        model: network called as ``model(pyg_data)``; must also provide
            ``model.reshape(pyg_data, heu_vec)``. It is switched to eval mode.
        pyg_data: graph input forwarded unchanged to ``model``.
        distances: distance data forwarded unchanged to ``ACO``.
        n_ants: forwarded to ``ACO`` as ``n_ants``.
        t_aco_diff: iterable of iteration counts, one ``run`` call per entry.

    Returns:
        list with ``.item()`` of the value returned by each ``run`` call, in the
        order of ``t_aco_diff``.
    """
    model.eval()
    solver = ACO(n_ants=n_ants, distances=distances, device=device)

    # Single forward pass; the resulting heuristic stays fixed for the whole search.
    solver.heuristic = model.reshape(pyg_data, model(pyg_data)) + EPS

    # The same solver object is reused for every segment.
    return [solver.run(n_iters).item() for n_iters in t_aco_diff]
        
    
@torch.no_grad()
def test(dataset, model, n_ants, t_aco, static):
    """Evaluate ``model`` on every instance of ``dataset`` and print a summary.

    Args:
        dataset: iterable of ``(pyg_data, distances)`` pairs.
        model: network forwarded to the per-instance inference function.
        n_ants: forwarded to the per-instance inference function.
        t_aco: list of cumulative ACO iteration counts at which costs are
            recorded, e.g. ``[1, 50, 100]``.
        static: if truthy, use ``infer_instance_static``; otherwise use
            ``infer_instance``.

    Returns:
        The per-checkpoint results of the LAST instance only (an empty list if
        ``dataset`` yields nothing).
        NOTE(review): returning ``all_results`` looks like the intent; the
        original return value is kept so existing callers are unaffected.

    Side effects:
        Prints the wall-clock duration of the loop and calls
        ``print_results(all_results, t_aco)``.
    """
    # Turn cumulative checkpoints into per-segment iteration counts.
    t_aco_diff = [cur - prev for prev, cur in zip([0] + t_aco, t_aco)]
    infer = infer_instance_static if static else infer_instance

    all_results = []
    results = []  # stays defined even when the dataset yields no instances
    # perf_counter is monotonic, so the measured interval is immune to clock adjustments.
    start = time.perf_counter()
    for pyg_data, distances in dataset:
        results = infer(model, pyg_data, distances, n_ants, t_aco_diff)
        all_results.append(results)
    duration = time.perf_counter() - start
    print('total duration: ', duration)
    print_results(all_results, t_aco)
    return results

### Test on TSP20

Dynamic

In [3]:
# NOTE(review): this cell is saved with the same execution count (In [3]) as the cell that
# defines the inference helpers, so the notebook was not executed strictly top to bottom.
from net import Net

# n_node selects both the test dataset and the checkpoint file name below.
n_node = 20
n_ants = 20
# k_sparse is forwarded to load_test_dataset (presumably the sparsification degree of the
# instance graph — confirm against utils).
k_sparse = 10

# Cumulative ACO iteration counts at which the cost is recorded and printed.
t_aco = [1, 50, 100, 150, 200, 250]
test_list = load_test_dataset(n_node, k_sparse, device)
# The temporal network is constructed with T_max // T_update (= 5 for T_max=250, T_update=50).
net_tsp = Net(T_max // T_update).to(device)
# NOTE(review): depending on the torch version, torch.load may unpickle arbitrary objects;
# only load checkpoints from a trusted source.
net_tsp.load_state_dict(torch.load(f'../pretrained/tsp-temporal/tsp{n_node}.pt', map_location=device))
# static=False makes test() use infer_instance, which regenerates the heuristic during the run.
results = test(test_list, net_tsp, n_ants, t_aco, static=False)

total duration:  353.0747723579407
T=1, 3.9111075401306152
T=50, 3.8073511123657227
T=100, 3.804638147354126
T=150, 3.8040575981140137
T=200, 3.8035595417022705
T=250, 3.8035595417022705


Static

In [4]:
# NOTE(review): mid-notebook path mutation and import. Prepending '../' to sys.path is
# presumably what makes the `tsp` package importable — confirm the directory layout.
import sys
sys.path.insert(0, '../')

# This rebinds the global name `Net`, replacing the class imported from `net` by the
# dynamic cell; which `Net` is active therefore depends on cell execution order.
from tsp.net import Net

# n_node selects both the test dataset and the checkpoint file name below.
n_node = 20
n_ants = 20
# k_sparse is forwarded to load_test_dataset (presumably the sparsification degree of the
# instance graph — confirm against utils).
k_sparse = 10

# Cumulative ACO iteration counts at which the cost is recorded and printed.
t_aco = [1, 50, 100, 150, 200, 250]
test_list = load_test_dataset(n_node, k_sparse, device)
# The static network is constructed without arguments.
net_tsp = Net().to(device)
# NOTE(review): depending on the torch version, torch.load may unpickle arbitrary objects;
# only load checkpoints from a trusted source.
net_tsp.load_state_dict(torch.load(f'../pretrained/tsp/tsp{n_node}.pt', map_location=device))
# static=True makes test() use infer_instance_static, which computes the heuristic once.
results = test(test_list, net_tsp, n_ants, t_aco, static=True) 

total duration:  343.49689078330994
T=1, 3.9243578910827637
T=50, 3.811661720275879
T=100, 3.8096370697021484
T=150, 3.8076534271240234
T=200, 3.8073697090148926
T=250, 3.806710720062256


### Test on TSP100

Dynamic

In [5]:
# Rebinds the global name `Net` to the class from `net`, replacing the `tsp.net` class
# bound by the preceding static cell.
from net import Net

# n_node selects both the test dataset and the checkpoint file name below.
n_node = 100
n_ants = 20
# k_sparse is forwarded to load_test_dataset (presumably the sparsification degree of the
# instance graph — confirm against utils).
k_sparse = 20

# Cumulative ACO iteration counts at which the cost is recorded and printed.
t_aco = [1, 50, 100, 150, 200, 250]
test_list = load_test_dataset(n_node, k_sparse, device)
# The temporal network is constructed with T_max // T_update (= 5 for T_max=250, T_update=50).
net_tsp = Net(T_max // T_update).to(device)
# NOTE(review): depending on the torch version, torch.load may unpickle arbitrary objects;
# only load checkpoints from a trusted source.
net_tsp.load_state_dict(torch.load(f'../pretrained/tsp-temporal/tsp{n_node}.pt', map_location=device))
# static=False makes test() use infer_instance, which regenerates the heuristic during the run.
results = test(test_list, net_tsp, n_ants, t_aco, static=False)  

total duration:  1005.4726014137268
T=1, 9.008529663085938
T=50, 8.273301124572754
T=100, 8.211779594421387
T=150, 8.189105033874512
T=200, 8.177424430847168
T=250, 8.168805122375488


Static

In [4]:
# NOTE(review): this cell is saved with execution count In [4] although it follows In [5],
# so the recorded outputs come from an out-of-order run. Re-running it also prepends
# '../' to sys.path again.
import sys
sys.path.insert(0, '../')

# This rebinds the global name `Net`, replacing the class imported from `net` by the
# dynamic cell; which `Net` is active therefore depends on cell execution order.
from tsp.net import Net

# n_node selects both the test dataset and the checkpoint file name below.
n_node = 100
n_ants = 20
# k_sparse is forwarded to load_test_dataset (presumably the sparsification degree of the
# instance graph — confirm against utils).
k_sparse = 20

# Cumulative ACO iteration counts at which the cost is recorded and printed.
t_aco = [1, 50, 100, 150, 200, 250]
test_list = load_test_dataset(n_node, k_sparse, device)
# The static network is constructed without arguments.
net_tsp = Net().to(device)
# NOTE(review): depending on the torch version, torch.load may unpickle arbitrary objects;
# only load checkpoints from a trusted source.
net_tsp.load_state_dict(torch.load(f'../pretrained/tsp/tsp{n_node}.pt', map_location=device))
# static=True makes test() use infer_instance_static, which computes the heuristic once.
results = test(test_list, net_tsp, n_ants, t_aco, static=True)  

total duration:  1018.3225774765015
T=1, 8.971996307373047
T=50, 8.279731750488281
T=100, 8.23666000366211
T=150, 8.20466423034668
T=200, 8.188815116882324
T=250, 8.17532730102539
