In [1]:
import os 
import random 
import numpy as np
import torch 
from stable_baselines3 import PPO, DQN, HerReplayBuffer
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy

from src.env import MultiODEnv, SparseMultiODEnv
from src.solution import MultiODSolution
from src.problem import MultiODProblem
from src.utils import read_instance_data, get_lkh3_tour, get_ortools_tour
from src.rl.stable_baselines3.nn import PSExtractor
from src.rl.stable_baselines3.callback import SaveBestSolCallback

  if not hasattr(tensorboard, "__version__") or LooseVersion(
  ) < LooseVersion("1.15"):


In [2]:
# Directory that holds the random-uniform instance files.
instance_dir = os.path.join('data', 'tsppdlib', 'instances', 'random-uniform')
# os.listdir() returns entries in arbitrary order; sorting keeps the instance
# list (and any later sampling from it) stable across runs and machines.
instances = sorted(i for i in os.listdir(instance_dir) if i.endswith('.tsp'))
# Size tags that occur in the instance file names ('random-<tag>-<id>.tsp').
num_Os = ["005", "010", "020", "050"]
# Size tag selected for this run.
num_O = '005'

In [3]:
# Choose the five instances to run for the selected num_O.
# The fixed lists below are the selections recorded for each size tag (per the
# original note: obtained with random.sample(sub_instances, k=5)).
resample = False

fixed_sub_instances = {
    "005": [
        'random-005-06203.tsp',
        'random-005-14680.tsp',
        'random-005-27025.tsp',
        'random-005-22010.tsp',
        'random-005-27053.tsp',
    ],
    "010": [
        'random-010-05876.tsp',
        'random-010-13200.tsp',
        'random-010-07248.tsp',
        'random-010-11763.tsp',
        'random-010-20971.tsp',
    ],
    "020": [
        'random-020-13151.tsp',
        'random-020-32388.tsp',
        'random-020-19723.tsp',
        'random-020-02593.tsp',
        'random-020-10770.tsp',
    ],
    "050": [
        'random-050-13219.tsp',
        'random-050-29393.tsp',
        'random-050-04371.tsp',
        'random-050-12086.tsp',
        'random-050-21722.tsp',
    ],
}

# Every instance whose file name carries the '-<num_O>-' tag.
sub_instances = [name for name in instances if '-' + num_O + '-' in name]

# Draw a fresh sample when asked to, or when no fixed list exists for num_O;
# otherwise fall back to the recorded selection.
if resample or num_O not in fixed_sub_instances:
    sub_instances = random.sample(sub_instances, k=5)
else:
    sub_instances = list(fixed_sub_instances[num_O])

In [4]:
# Directories whose files are later handed to get_lkh3_tour / get_ortools_tour.
lkh3_dir = os.path.join('./', 'U')
ortools_dir = os.path.join('./', 'tmp', 'ortools')

# File names available in each directory; the training cell searches these
# listings for the entry that matches the current instance.
lkh3_results = os.listdir(lkh3_dir)
ortools_results = os.listdir(ortools_dir)

In [5]:
# --- Network / optimiser settings ---
# hidden_dim and num_heads are forwarded to the PSExtractor feature extractor;
# net_arch goes into policy_kwargs; lr is the PPO learning rate.
hidden_dim, num_heads = 64, 4
lr = 1e-3
net_arch = [64, 64]
batch_size = 100

# --- Environment / rollout sizes ---
episode_max_time_length = 1_000   # passed to the env as max_time_length
episode_max_length = 4_000        # passed to the env as max_length
n_steps = episode_max_length      # PPO n_steps / DQN train_freq
# Total steps given to model.learn(): 500 times episode_max_length.
# (Spelling of the name is kept as-is because the training cell uses it.)
learn_totoal_steps = 500 * episode_max_length
k_recent = 5

# --- Callback / logging ---
verbose = 1
early_stop = True
tensorboard_log = '../tmp/ppo'
callback_log_dir = '../tmp/paths'

# --- Mode switches ---
use_sparse_reward = False   # True selects SparseMultiODEnv instead of MultiODEnv
use_her = False             # True selects DQN + HerReplayBuffer instead of PPO

# --- HER replay buffer settings (only used when use_her is True) ---
n_sampled_goal = 4
goal_selection_strategy = 'future'
replay_buffer_kwargs = {
    'n_sampled_goal': n_sampled_goal,
    'goal_selection_strategy': goal_selection_strategy,
}

In [6]:
def _first_result_for(result_files, instance_head, solver_name):
    """Return the first entry of ``result_files`` that contains ``instance_head``.

    Args:
        result_files: iterable of file names (a directory listing).
        instance_head: instance file name without its '.tsp' extension.
        solver_name: label used only in the error message.

    Raises:
        FileNotFoundError: if no entry matches, so a missing solver result is
            reported by name instead of as a bare IndexError.
    """
    for file_name in result_files:
        if instance_head in file_name:
            return file_name
    raise FileNotFoundError(
        f'no {solver_name} result found for instance {instance_head!r}')


# Train one model per selected instance, using the better of the LKH3 and
# ortools tours as the target handed to the callback.
for instance_name in sub_instances:
    # File name without '.tsp'; used for result lookup and as the run name.
    instance_name_head = instance_name[:instance_name.index('.tsp')]
    print(f'instance: {instance_name_head}')

    # Load the reference tours produced by the two solvers for this instance.
    lkh3_instance_result = _first_result_for(lkh3_results, instance_name_head, 'LKH3')
    lkh3_raw_tour = get_lkh3_tour(os.path.join(lkh3_dir, lkh3_instance_result))
    ortools_instance_result = _first_result_for(ortools_results, instance_name_head, 'ortools')
    ortools_raw_tour = get_ortools_tour(os.path.join(ortools_dir, ortools_instance_result))

    # Build the problem from the instance file.
    instance = os.path.join(instance_dir, instance_name)
    locations = read_instance_data(instance)
    problem = MultiODProblem(locations=locations, ignore_to_dummy_cost=False, ignore_from_dummy_cost=False)
    problem.convert_distance_matrix_to_int()

    # Wrap both tours as solutions of this problem and compare their costs.
    lkh3_tour = MultiODSolution([lkh3_raw_tour], problem)
    ortools_tour = MultiODSolution([ortools_raw_tour], problem)
    lkh3_cost, ortools_cost = problem.calc_cost(lkh3_tour), problem.calc_cost(ortools_tour)
    print(f'LKH3 cost: {lkh3_cost}, ortools cost: {ortools_cost}')
    # Strict '<': on a tie the ortools tour is used as the target.
    if lkh3_cost < ortools_cost:
        target_tour = lkh3_tour
        print('Target tour is LKH3')
    else:
        target_tour = ortools_tour
        print('Target tour is ortools')

    if use_sparse_reward:
        # Sparse variant receives a target cost: the target tour's cost plus 5%.
        env = SparseMultiODEnv(target_cost=int(problem.calc_cost(target_tour) * (1 + 0.05)),
                               problem=problem,
                               max_length=episode_max_length,
                               max_time_length=episode_max_time_length,
                               k_recent=k_recent)
    else:
        env = MultiODEnv(problem=problem,
                         max_length=episode_max_length,
                         max_time_length=episode_max_time_length,
                         k_recent=k_recent)

    # Extractor output width: last dim of the 'solution' observation plus the
    # first dim of the 'problem' observation.
    sol_input_dim = env.observation_space['solution'].shape[-1]
    features_dim = sol_input_dim + env.observation_space['problem'].shape[0]

    policy_kwargs = dict(
        features_extractor_class=PSExtractor,
        features_extractor_kwargs=dict(features_dim=features_dim,
                                       sol_input_dim=sol_input_dim,
                                       hidden_dim=hidden_dim,
                                       num_heads=num_heads),
        net_arch=net_arch
    )

    if use_her:
        # NOTE(review): unlike the PPO branch, learning_rate=lr is not passed
        # here, so DQN runs with its library default - confirm this is intended.
        model = DQN("MultiInputPolicy", env, policy_kwargs=policy_kwargs, verbose=verbose,
                    train_freq=n_steps, batch_size=batch_size, tensorboard_log=tensorboard_log,
                    replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=replay_buffer_kwargs
                    )
    else:
        model = PPO("MultiInputPolicy", env, policy_kwargs=policy_kwargs, verbose=verbose,
                    n_steps=n_steps, batch_size=batch_size, learning_rate=lr,
                    tensorboard_log=tensorboard_log)

    # Same stem as instance_name_head; the name is kept for later cells.
    instance_save_as = instance_name_head
    model.learn(learn_totoal_steps,
                tb_log_name=instance_save_as,
                callback=SaveBestSolCallback(log_dir=callback_log_dir,
                                             instance_name=instance_save_as,
                                             verbose=verbose,
                                             target_tour=target_tour,
                                             early_stop=early_stop)
                )

instance: random-005-06203
LKH3 cost: 2967.0, ortools cost: 2967.0
Target tour is ortools
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ../tmp/ppo/random-005-06203_1
Target cost: 2967
Best solution cost: 6491, found at 0 step, 0.02 seconds used
Best solution cost: 3754, found at 2 step, 0.02 seconds used
Best solution cost: 3525, found at 6 step, 0.02 seconds used
Best solution cost: 3367, found at 13 step, 0.03 seconds used
Best solution cost: 3024, found at 41 step, 0.04 seconds used
Best solution cost: 2967, found at 72 step, 0.06 seconds used
Rollout best solution cost: 2967, 
                  found at 72 step, 0.04 seconds used. 
                  Convergence gap: 2967.0. Target gap: 0
---------------------------------------------
| best/                          |          |
|    best_cost                   | 2967     |
|    best_sol_at_step            | 72       |
|    best_sol_found_time         | 0.0571   |
| rollout/

  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


Rollout best solution cost: 2942, 
                  found at 38 step, 0.02 seconds used. 
                  Convergence gap: 2942.0. Target gap: 0
---------------------------------------------
| best/                          |          |
|    best_cost                   | 2942     |
|    best_sol_at_step            | 38       |
|    best_sol_found_time         | 0.0321   |
| rollout/                       |          |
|    convergence_gap             | 2.94e+03 |
|    rollout_best_cost           | 2942     |
|    rollout_best_sol_at_step    | 38       |
|    rollout_best_sol_found_time | 0.0217   |
|    target_gap                  | 0        |
---------------------------------------------
instance: random-005-27025
LKH3 cost: 3030.0, ortools cost: 3030.0
Target tour is ortools
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ../tmp/ppo/random-005-27025_1
Target cost: 3030
Best solution cost: 3423, found at 1 step, 0.01 seconds u

In [7]:
# Number of iterations for the loop below (int(4e4) == 40000).
test_epoch_length = int(4e4)


# Placeholder loop: the visible body is empty, so nothing is evaluated here.
for _ in range(test_epoch_length):
    pass 