In [None]:
import os 
import numpy as np
import torch 
from stable_baselines3 import PPO, DQN

from src.env import MultiODEnv, SparseMultiODEnv
from src.problem import MultiODProblem
from src.utils import read_instance_data
from src.rl.stable_baselines3.nn import PSExtractor
from src.rl.stable_baselines3.callback import SaveBestSolCallback

In [None]:
# Folder with the random-uniform instances, given relative to the working directory.
instance_dir = os.path.join("data", "tsppdlib", "instances", "random-uniform")

# Names (not full paths) of every entry in that folder ending in '.tsp'.
instances = list(filter(lambda fname: fname.endswith(".tsp"), os.listdir(instance_dir)))

# Zero-padded size tags and the one selected here.
# NOTE(review): these look like the size field of the instance file names
# (e.g. 'random-050-...') -- confirm against the dataset.
num_Os = ["005", "010", "020", "050"]
num_O = "050"

In [None]:
# Per-episode limits; both are forwarded to the environment constructor
# (max_time_length / max_length).
episode_max_time_length = 1_000
episode_max_length = 20_000

# Rollout length handed to the RL algorithm: one maximum-length episode.
n_steps = episode_max_length

# Total training budget in environment steps: 40,000 rollouts of n_steps each.
# (The misspelled name is kept as-is because the training cell refers to it.)
learn_totoal_steps = 40_000 * episode_max_length

verbose = 1
batch_size = 1_000

# Output locations, relative to the working directory.
tensorboard_log = '../tmp/ppo'
callback_log_dir = '../tmp/paths'

In [None]:
# Instance to solve. target_cost is later handed to SaveBestSolCallback.
# NOTE(review): presumably the best-known cost for this instance -- confirm.
instance_name = 'random-050-00272.tsp'
target_cost = 9447

# Build the path from instance_dir (the same folder the instance listing uses)
# instead of a machine-specific absolute home directory, so the notebook runs
# on any checkout as long as it is started from the project root.
instance = os.path.join(instance_dir, instance_name)
if not os.path.isfile(instance):
    # Fail early with an actionable message; the path depends on the working directory.
    raise FileNotFoundError(
        f"Instance file not found: {os.path.abspath(instance)} "
        f"(working directory: {os.getcwd()})"
    )

# Parse the instance, wrap it in a problem object, and build the environment
# with the per-episode limits defined in the configuration cell.
locations = read_instance_data(instance)
problem = MultiODProblem(locations=locations, ignore_to_dummy_cost=False)
env = MultiODEnv(problem=problem, max_length=episode_max_length, max_time_length=episode_max_time_length)

In [None]:
# Sizes read from the environment's dict observation space: the last axis of
# the 'solution' entry and the first axis of the 'problem' entry.
sol_input_dim = env.observation_space['solution'].shape[-1]
problem_input_dim = env.observation_space['problem'].shape[0]
features_dim = sol_input_dim + problem_input_dim

# Extractor hyperparameters, passed through to PSExtractor.
hidden_dim = 256
num_heads = 16

# Policy configuration: custom feature extractor, then separate
# policy (pi) and value (vf) layer lists with ReLU activations.
policy_kwargs = {
    'features_extractor_class': PSExtractor,
    'features_extractor_kwargs': {
        'features_dim': features_dim,
        'sol_input_dim': sol_input_dim,
        'hidden_dim': hidden_dim,
        'num_heads': num_heads,
    },
    'net_arch': {'pi': [128, 128], 'vf': [128, 128]},
    'activation_fn': torch.nn.ReLU,
}

model = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=verbose,
    n_steps=n_steps,
    batch_size=batch_size,
    tensorboard_log=tensorboard_log,
)
# Alternative algorithm, kept for reference (swap in for the PPO call above):
# model = DQN("MultiInputPolicy", env, policy_kwargs=policy_kwargs, verbose=verbose, train_freq=n_steps, batch_size=batch_size, tensorboard_log=tensorboard_log)

In [None]:
# Run label: the instance file name up to its first '.tsp'
# (str.index raises ValueError if the marker is absent).
tsp_suffix_at = instance_name.index('.tsp')
instance_save_as = instance_name[:tsp_suffix_at]

# Callback configured with the log directory, run label and target cost.
best_sol_callback = SaveBestSolCallback(
    log_dir=callback_log_dir,
    instance_name=instance_save_as,
    verbose=verbose,
    target_cost=target_cost,
)

# Train for the configured step budget, logging to TensorBoard under the run label.
model.learn(
    total_timesteps=learn_totoal_steps,
    tb_log_name=instance_save_as,
    callback=best_sol_callback,
)