In [None]:
import os
import numpy as np
import sty
import torch
import pprint as pp
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_batch
from tqdm import tqdm
from args_2opt import get_args
from datasets.tsp import TSPDataset
from environments.tsp_2opt import TSP2OPTEnv
from models.tsp_2opt_agent import TSP2OPTAgent
from torch_discounted_cumsum import discounted_cumsum_right

In [None]:
# Start from the project's default arguments, then override what this analysis needs.
args = get_args('')

# Checkpoint to inspect and the validation instances to roll it out on
args.load_path = "outputs/tsp_50/transformer_decoder_only_20210719T204806/best-model.pt"
args.val_dataset = "datasets/tsp_50_validation_256.pt"
args.graph_size = 50
args.tour_gnn_layers = 10
args.num_gnn_layers = 3
# NOTE(review): the evaluation loop in this notebook reads args.eval_batch_size,
# not args.batch_size — confirm which of the two is meant to be overridden here.
args.batch_size = 256
args.use_cuda = False
args.device = torch.device("cuda" if args.use_cuda else "cpu")

# Pretty print the run args
pp.pprint(vars(args))

# Set the random seed
torch.manual_seed(args.seed)
# cuDNN benchmarking auto-tunes the convolution algorithm per run and may pick a
# different one each time; it has to be off for `deterministic = True` to give
# reproducible results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Build the environment and the agent first, then restore the trained weights.
env = TSP2OPTEnv()
model = TSP2OPTAgent(args).to(args.device)

# Tensors in the checkpoint are mapped onto the CPU while loading.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
load_data = torch.load(args.load_path, map_location="cpu")
model.load_state_dict(load_data["model"])

In [None]:
# Validation set, configured entirely from the run arguments.
val_dataset = TSPDataset(
    size=args.val_size,
    graph_size=args.graph_size,
    graph_type=args.graph_type,
    graph_knn=args.graph_knn,
    load_path=args.val_dataset,
)

In [None]:
# Roll the agent out on the validation set with sampled actions, recording the
# per-step actions, values, log-probabilities and rewards for later inspection.
model.eval()
model.set_decode_type("sampling")
values = []
rewards = []
logps = []
actions = []

# Mean of `g.opt` over the validation set (presumably each instance's reference
# optimal tour length — confirm). It does not change during the rollout, so it
# is computed once instead of re-scanning the whole dataset at every step.
opt = sum(g.opt for g in val_dataset) / len(val_dataset) if len(val_dataset) else float("nan")
log_every = 1  # print progress every `log_every` steps

# NOTE(review): the record lists above are shared by all batches, and `avg_len`
# below is a per-batch mean compared against the dataset-wide `opt`. Both are
# only consistent when a single batch covers the whole validation set — confirm
# args.eval_batch_size >= len(val_dataset).
for bat in DataLoader(val_dataset, batch_size=args.eval_batch_size):
    step = 0
    bat = bat.to(args.device)
    # Dense per-graph node coordinates; the mask returned alongside is not needed
    node_pos = to_dense_batch(bat.pos, bat.batch)[0]
    done = False
    with torch.no_grad():
        state = env.reset(T=10, node_pos=node_pos)
        # Node embeddings are computed once per batch and reused at every step
        embed_data = model.init_embed(bat)
        node_embeddings, _ = model.encoder(embed_data)
        while not done:
            step += 1
            action, log_p, value = model(state, node_embeddings, embed_data.batch)
            state, reward, done, _ = env.step(action.squeeze())
            actions.append(action)
            values.append(value)
            logps.append(log_p)
            rewards.append(reward)
            if step % log_every == 0:
                avg_len = state.best_tour_len.mean().item()
                opt_gap = (avg_len - opt) / opt
                print(f"step:{step}, average length:{avg_len:.3f}, gap:{opt_gap*100:.3f}%")


In [None]:
# Stack the per-step rollout records along a new leading axis.
rewards_t = torch.stack(rewards)
values_t = torch.stack(values)
logp_t = torch.stack(logps)

# Discounted cumulative sum of the rewards with factor 0.99. The helper is fed
# the transposed rewards (axis 2 squeezed out), and the result is transposed
# back and given its trailing axis again.
returns_t = discounted_cumsum_right(rewards_t.squeeze(2).T, 0.99).T.unsqueeze(2)

# Standardise the returns with their global mean and standard deviation.
eps = torch.finfo(torch.float).eps  # small number to avoid div/0
r_mean = returns_t.mean()
r_std = returns_t.std()
returns_t = (returns_t - r_mean) / (r_std + eps)

In [None]:
# Largest exp(log_p) and its index along the last axis for every recorded entry
# (assuming logp_t holds log-probabilities, this is the most likely action).
torch.max(torch.exp(logp_t), dim=-1)

In [None]:
import torch
from torch_geometric.nn.conv import MessagePassing
from typing import Optional

class Edge(MessagePassing):
    """Expose per-edge concatenated endpoint features via a message-passing pass.

    ``forward`` runs ``propagate`` only for its side effect: ``message`` stores
    ``cat([x_i, x_j], dim=-1)`` on the instance, and ``forward`` hands that
    tensor back while ignoring the aggregated output of ``propagate``.
    """

    def __init__(self):
        super().__init__()
        # Scratch slot that ``message`` fills while ``propagate`` is running
        self._edge = None

    def forward(self, x, edge_index):
        """Return the features captured by ``message`` for the given edges.

        Args:
            x: Node features, forwarded to ``propagate`` as ``x``.
            edge_index: Edge list, forwarded to ``propagate``.

        Returns:
            Whatever ``message`` stored during this call (``None`` if it was
            never invoked).
        """
        self.propagate(edge_index, x=x)
        # Hand the captured tensor back and clear the slot in one step so no
        # stale result is kept on the module between calls.
        captured, self._edge = self._edge, None
        return captured

    def message(self, x_j, x_i):
        """Record ``cat([x_i, x_j], dim=-1)`` and emit ``x_i`` as the message."""
        # The argument names must stay as they are: PyG maps the `_i` / `_j`
        # suffixes to the two endpoints of each edge by name.
        self._edge = torch.cat((x_i, x_j), dim=-1)
        return x_i


In [None]:
# Toy problem: `bsz` graphs of `gsz` nodes each, where every node's single
# feature is simply its own index within the graph.
bsz, gsz = 5, 10
x = torch.arange(gsz).reshape(1, gsz, 1).repeat(bsz, 1, 1)
# Arg-sorting uniform noise row by row yields one random permutation per graph.
tour = torch.argsort(torch.rand(bsz, gsz), dim=1)
tour[0]

In [None]:
# Reorder each graph's node features into tour order and show the first graph.
torch.gather(x, 1, tour.unsqueeze(-1).expand_as(x))[0]

In [None]:
# Build one flat edge list for the whole batch of tours and run it through Edge.
edge = Edge()
# Merge the batch and node axes so all graphs share a single node table
node_x = x.flatten(start_dim=0, end_dim=1)
# Graph b's nodes occupy rows [b * gsz, (b + 1) * gsz) of the flattened table
edge_index_offset = gsz * torch.arange(bsz)
# Pair every tour entry with its successor; the roll wraps the last entry
# around to the first, which closes the tour
dense_edge_index = torch.stack(
    (tour, torch.roll(tour, shifts=-1, dims=1)), dim=2
)
# Shift to global node ids, then lay the pairs out as a (2, bsz * gsz) edge list
edge_index = (
    (dense_edge_index + edge_index_offset.view(-1, 1, 1)).flatten(0, 1).t()
)
edge(node_x, edge_index)