# Load the trained model

In [1]:
from types import SimpleNamespace
from training.utils import get_model
import torch

# Hyper-parameters used to build the TSP transformer via `get_model`.
# Architecture-related fields (d_model, nhead, layer counts, norm, ...) must match
# the configuration the checkpoint was trained with, otherwise the strict
# `load_state_dict` call below raises on missing/unexpected keys or size mismatches.
model_args = SimpleNamespace(**{
    'model': 'custom',
    'add_cross_attn': True,
    'use_q_proj_ca': False,
    'use_feedforward_block_sa': False,
    'clip_logit_c': None,
    'device': 'cpu',
    'in_features': 2,
    'nhead': 8,
    'norm_eps': 1e-5,
    'dropout_p': 0.,
    'sinkhorn_i': 20,
    'positional_encoding': 'sin',
    'num_hidden_encoder_layers': 5,
    'd_model': 128,
    'dim_feedforward': 512,
    'activation': 'relu',
    'norm': 'custom_batch',
    'sinkhorn_tau': 0.02,
    'use_feedforward_block_ca': True,
})

# Path to a checkpoint whose top-level 'model' entry holds the state dict.
# Fill this in before running the notebook.
checkpoint_path = ''

# Fail early with an actionable message instead of the opaque error that
# torch.load raises for an empty path.
if not checkpoint_path:
    raise FileNotFoundError(
        "checkpoint_path is empty - set it to a trained model checkpoint before running this cell"
    )

model = get_model(model_args)
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


TSPCustomTransformer(
  (pe): SinPositionalEncoding()
  (input_ff): Linear(in_features=2, out_features=128, bias=True)
  (input_norm): CustomBatchNorm(
    (bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (encoder): TSPCustomEncoder(
    (layers): ModuleList(
      (0): TSPCustomEncoderLayer(
        (sa_block): TSPCustomEncoderBlock(
          (attn): CustomMHA(
            (qkv_proj): Linear(in_features=128, out_features=384, bias=True)
          )
          (norm): CustomBatchNorm(
            (bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (ca_block): TSPCustomEncoderBlock(
          (attn): CustomMHA(
            (kv_proj): Linear(in_features=128, out_features=256, bias=True)
          )
          (norm): CustomBatchNorm(
            (bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (ff_block): TransformerFeedfo

# Load data

In [13]:
# Paths to a few test graphs. Each file is expected to hold a dict with
# 'coords' and 'ref_tour' tensors (the two keys read below).
graph_files = [
    'tsp_data/test/b580f87b-8c74-4c75-b3ae-1e0b1ef6758a.pt',
    'tsp_data/test/6f46cd69-faad-44b7-a98c-c544ee289159.pt',
    'tsp_data/test/625d000b-4415-44d2-a976-b155f8d439cf.pt',
    'tsp_data/test/90a29060-e4f0-4e41-9534-3e879ebdc4c4.pt',
    'tsp_data/test/82d83c5b-a82e-4aa8-b0b3-b28eb8a6ed90.pt',
    'tsp_data/test/c81cf13b-2636-4eb2-a8b5-3aa15b7cf8db.pt',
    'tsp_data/test/3d2b6ae5-5ea8-41c0-a761-8394892217b9.pt',
    'tsp_data/test/d6f7ccc2-aa9b-4d1a-9a7d-4848d8a5c7c2.pt',
    'tsp_data/test/8fe7d43f-11e8-4539-8057-4b16486bc6ea.pt',
    'tsp_data/test/fd909412-18ad-4f51-9033-871218495339.pt',
]

# map_location keeps the data on the CPU, matching the CPU-loaded model,
# even if the tensors were saved from a GPU.
# NOTE: torch.load unpickles the file; only load data from trusted sources.
graph_objects = [torch.load(path, map_location=torch.device('cpu')) for path in graph_files]

# torch.stack requires every graph to have the same number of nodes.
graph_coords = torch.stack([obj['coords'] for obj in graph_objects])
ref_tours = torch.stack([obj['ref_tour'] for obj in graph_objects])

# Inference only: no_grad avoids building the autograd graph.
with torch.no_grad():
    tours = model(graph_coords).tour

# The plotting loop indexes all three tensors with the same graph index.
assert len(tours) == len(graph_coords) == len(ref_tours), (
    tours.shape, graph_coords.shape, ref_tours.shape)

print(graph_coords.shape, tours.shape, ref_tours.shape)

torch.Size([10, 50, 2]) torch.Size([10, 51]) torch.Size([10, 51])


# Plot: our model's tours vs. the NetworkX (Christofides) reference tours

In [14]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils import get_tour_coords

def create_arrows(coords, color='red', xref="x", yref="y"):
    """Build one plotly arrow annotation per consecutive pair of tour points.

    Args:
        coords: 2-D array-like whose rows are (x, y) points, in visiting order.
            For a closed tour the last row presumably repeats the first, so the
            final arrow returns to the start.
        color: Arrow colour.
        xref, yref: Axis references of the subplot to draw into (e.g. 'x2', 'y2').

    Returns:
        A list of ``go.layout.Annotation`` objects, ``len(coords) - 1`` of them.
        Each arrow points from row ``i`` to row ``i + 1``. Inputs with fewer than
        two points yield an empty list.
    """
    # Derive the number of segments from the input instead of assuming 51 points.
    n_segments = max(len(coords) - 1, 0)
    arrows = []
    for i in range(n_segments):
        arrows.append(go.layout.Annotation(
            # Head of the arrow: the next point of the tour.
            x=float(coords[i + 1, 0]),
            y=float(coords[i + 1, 1]),
            xref=xref, yref=yref,
            text="",
            showarrow=True,
            # Tail of the arrow, expressed in data coordinates of the same subplot.
            axref=xref, ayref=yref,
            ax=float(coords[i, 0]),
            ay=float(coords[i, 1]),
            arrowhead=3,
            arrowwidth=1.5,
            arrowcolor=color,
        ))
    return arrows

# One figure per graph: the reference tour on the left, the model's tour on the right.
for idx, coords in enumerate(graph_coords):
    # Keep a batch dimension of size 1 when calling get_tour_coords, then take that single entry.
    batch = slice(idx, idx + 1)
    reference_xy = get_tour_coords(graph_coords[batch], ref_tours[batch])[0]
    predicted_xy = get_tour_coords(graph_coords[batch], tours[batch])[0]

    node_markers = go.Scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='markers',
        marker=dict(color='red'),
    )

    fig = make_subplots(1, 2, subplot_titles=['Christofides', 'Ours'])
    for col in (1, 2):
        fig.add_trace(node_markers, 1, col)

    # Left subplot uses axes x/y, right subplot uses x2/y2.
    annotations = [
        *create_arrows(reference_xy, color='blue', xref='x', yref='y'),
        *create_arrows(predicted_xy, color='green', xref='x2', yref='y2'),
    ]
    for ann in annotations:
        fig.add_annotation(ann)

    fig.update(layout_showlegend=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)

    fig.show()