In [1]:
import os
import numpy as np
import sty
import torch
import pprint as pp
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_batch
from tqdm import tqdm
from args_2opt import get_args
from datasets.tsp import TSPDataset
from environments.tsp_2opt import TSP2OPTEnv
from models.tsp_2opt_agent import TSP2OPTAgent

In [2]:
# Parse the default run args. Presumably the empty string keeps argparse from
# reading the notebook kernel's own sys.argv -- TODO confirm in args_2opt.get_args.
args = get_args('')

# Evaluation overrides: TSP-100 checkpoint, its 256-instance validation set and
# the architecture sizes the checkpoint was trained with.
# NOTE(review): values derived inside get_args before these overrides (e.g. the
# printed save_dir still says tsp_50) are not recomputed here.
args.load_path = "outputs/tsp_100/continued_20210714T142718/best-model.pt"
args.val_dataset = "datasets/tsp_100_validation_256.pt"
args.graph_size = 100
args.tour_gnn_layers = 10
args.num_gnn_layers = 5
args.batch_size = 256
args.device = torch.device("cuda" if args.use_cuda else "cpu")

# Pretty print the run args
pp.pprint(vars(args))

# Seed every RNG used by the notebook so sampling-based decoding is repeatable.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# cudnn.benchmark=True lets cuDNN auto-select algorithms per run, which can be
# nondeterministic and defeats deterministic=True; disable it for reproducibility.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

{'batch_size': 256,
 'bias': True,
 'checkpoint_epochs': 1,
 'decode_type': 'sampling',
 'decoder_num_heads': 8,
 'device': device(type='cuda'),
 'edge_dim': 1,
 'embed_dim': 128,
 'encoder_num_heads': 8,
 'entropy_beta': 0.005,
 'epoch_size': 5120,
 'epoch_start': 0,
 'eval_batch_size': 1024,
 'eval_only': False,
 'exp_beta': 0.8,
 'gamma': 0.99,
 'graph_knn': 10,
 'graph_size': 100,
 'graph_type': 'knn',
 'horizon': 10,
 'load_path': 'outputs/tsp_100/continued_20210714T142718/best-model.pt',
 'log_dir': 'logs',
 'log_step': 10,
 'lr_critic': 0.0001,
 'lr_decay': 1.0,
 'lr_model': 0.0001,
 'max_grad_norm': 0.3,
 'max_num_steps': 200,
 'n_epochs': 200,
 'no_cuda': False,
 'no_norm_return': False,
 'no_progress_bar': False,
 'node_dim': 2,
 'normalization': 'batch',
 'num_gnn_layers': 5,
 'num_workers': 8,
 'output_dir': 'outputs',
 'pooling_method': 'mean',
 'problem': 'tsp',
 'run_name': 'run_20210715T173934',
 'save_dir': 'outputs/tsp_50/run_20210715T173934',
 'seed': 1234,
 'tanh_cl

In [3]:
# Load the trained checkpoint and rebuild the environment and agent around it.
# map_location remaps any tensors saved on another device (e.g. a GPU) onto
# args.device, so the checkpoint also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
load_data = torch.load(args.load_path, map_location=args.device)
env = TSP2OPTEnv()
model = TSP2OPTAgent(args).to(args.device)

# Last expression: displays the key-matching report returned by load_state_dict.
model.load_state_dict(load_data["model"])

<All keys matched successfully>

In [4]:
# Build the validation set from the saved file, using the graph construction
# settings (size, graph type, k for the kNN graph) taken from the run args.
val_dataset = TSPDataset(
    load_path=args.val_dataset,
    size=args.val_size,
    graph_size=args.graph_size,
    graph_type=args.graph_type,
    graph_knn=args.graph_knn,
)

In [17]:
# Episode length passed to env.reset(T=...) and how often (in steps) to report progress.
NUM_STEPS = 2000
LOG_EVERY = 200

model.eval()
model.set_decode_type("sampling")

# Optimal tour lengths of all validation instances. This is loop-invariant, so it
# is gathered once instead of being re-summed over the whole dataset at every report.
opt_lens = [g.opt for g in val_dataset]

# DataLoader does not shuffle by default, so consecutive batches cover
# consecutive dataset indices; `offset` tracks where the current batch starts.
offset = 0
for bat in DataLoader(val_dataset, batch_size=args.eval_batch_size):
    bat = bat.to(args.device)
    num_graphs = bat.num_graphs
    # Mean optimal length over exactly this batch's instances, so the reported
    # gap compares like with like when the validation set spans several batches.
    batch_opt = sum(opt_lens[offset:offset + num_graphs]) / num_graphs
    offset += num_graphs

    node_pos = to_dense_batch(bat.pos, bat.batch)[0]
    step = 0
    done = False
    with torch.no_grad():
        state = env.reset(T=NUM_STEPS, node_pos=node_pos)
        # Node embeddings are computed once per batch and reused at every step.
        embed_data = model.init_embed(bat)
        node_embeddings, _ = model.encoder(embed_data)
        while not done:
            step += 1
            action, _, _ = model(state, node_embeddings, embed_data.batch)
            # NOTE(review): squeeze() drops every size-1 dim; if a batch can hold a
            # single instance this may also drop the batch dim -- confirm action's shape.
            state, _, done, _ = env.step(action.squeeze())
            if step % LOG_EVERY == 0:
                avg_len = state.best_tour_len.mean().item()
                opt_gap = (avg_len - batch_opt) / batch_opt
                print(f"step:{step}, average length:{avg_len:.3f}, gap:{opt_gap*100:.3f}%")


step:2000, average length:8.175, gap:5.074%
step:4000, average length:8.128, gap:4.472%
step:6000, average length:8.099, gap:4.107%
step:8000, average length:8.077, gap:3.825%
step:10000, average length:8.064, gap:3.657%
step:12000, average length:8.054, gap:3.520%
step:14000, average length:8.046, gap:3.424%
step:16000, average length:8.037, gap:3.304%
step:18000, average length:8.032, gap:3.239%


step:100, average length:10.828, gap:39.182%
step:200, average length:8.669, gap:11.428%
step:300, average length:8.480, gap:8.996%
step:400, average length:8.391, gap:7.861%
step:500, average length:8.346, gap:7.284%
step:600, average length:8.316, gap:6.889%
step:700, average length:8.290, gap:6.551%
step:800, average length:8.268, gap:6.269%
step:900, average length:8.252, gap:6.069%
step:1000, average length:8.238, gap:5.890%
step:1100, average length:8.228, gap:5.766%
step:1200, average length:8.220, gap:5.651%
step:1300, average length:8.210, gap:5.534%
step:1400, average length:8.200, gap:5.398%
step:1500, average length:8.194, gap:5.318%
step:1600, average length:8.184, gap:5.192%
step:1700, average length:8.174, gap:5.069%
step:1800, average length:8.168, gap:4.992%
step:1900, average length:8.165, gap:4.948%

In [None]:
# Inspect entry 0 of the final state's current edge list (presumably the first
# instance of the last evaluated batch -- confirm the layout of curr_edge_list).
first_edge_list = state.curr_edge_list[0]
first_edge_list

In [None]:
from torch import nn


class TestResetParam(nn.Module):
    """Minimal module that invokes ``reset_parameters`` from its constructor.

    Used as a scratch check that the hook fires when the module is built as a
    child of another module.
    """

    def __init__(self):
        super(TestResetParam, self).__init__()
        self.reset_parameters()

    def reset_parameters(self):
        """Announce that the reset hook ran."""
        print("Reset Parameters")


class FOO(nn.Module):
    """Wrapper module whose constructor registers a ``TestResetParam`` child as ``test``."""

    def __init__(self):
        super(FOO, self).__init__()
        child = TestResetParam()
        self.test = child


foo = FOO()