## Test AttentionDecoder

In [1]:
import torch
import numpy as np
import torch.nn.functional as F

from datasets.tsp import gen_fully_connected_graph
from models.decoder import AttentionDecoder
from importlib import reload
# Batch collates a list of graphs into one object (used below via Batch.from_data_list).
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import global_max_pool

In [2]:
# Small dimensions so the module repr and printed tensors stay easy to read.
embed_dim, num_heads = 4, 2
attn = AttentionDecoder(
    query_dim=embed_dim,
    embed_dim=embed_dim,
    num_heads=num_heads,
)
attn

AttentionDecoder(
  (query_proj): Linear(in_features=4, out_features=4, bias=True)
  (key_proj): Linear(in_features=4, out_features=4, bias=True)
)

In [3]:
# Two fully connected graphs of different sizes (5 and 10 nodes) so the batch
# exercises padding when it is converted to a dense tensor.
g1 = gen_fully_connected_graph(5)
g2 = gen_fully_connected_graph(10)

# One boolean flag per node, shaped (num_nodes, 1) so that the collated batch
# carries a mask of shape [480, 1] (32 * 5 + 32 * 10 nodes).
g1.mask = torch.ones((5, 1), dtype=torch.bool)
g2.mask = torch.ones((10, 1), dtype=torch.bool)

# Flag the first four nodes of each graph as False.
g1.mask[0:4] = False
g2.mask[0:4] = False

# Alternate the two graphs: even positions get g1, odd positions get g2.
data_list = [g1 if (i % 2) == 0 else g2 for i in range(64)]


In [4]:
# Collate the 64 graphs in data_list into a single batch object and display its summary.
batch = Batch.from_data_list(data_list)
batch

Batch(batch=[480], edge_index=[2, 4000], mask=[480, 1], pos=[480, 2], x=[480, 4])

In [5]:
# to_dense_batch returns a (dense values, validity mask) pair; only the dense
# values ([0]) are kept for the node features and for the per-node mask.
dense_batch = to_dense_batch(batch.x, batch.batch)[0]
dense_mask = to_dense_batch(batch.mask, batch.batch)[0]

In [6]:
# Swap the first two axes so the node axis comes first:
# (64, 10, 4) -> (10, 64, 4), matching the shape recorded below.
key = dense_batch.permute(1, 0, 2)
key.shape

torch.Size([10, 64, 4])

In [7]:
# Swap the last two axes so each graph's node flags lie along the final axis:
# (64, 10, 1) -> (64, 1, 10), matching the shape recorded below.
mask = dense_mask.permute(0, 2, 1)
mask.shape

torch.Size([64, 1, 10])

In [8]:
# One max-pooled feature vector per graph, with a leading length-1 axis added
# so the query has three axes like `key`.
query = global_max_pool(batch.x, batch.batch).unsqueeze(0)
query.shape

torch.Size([1, 64, 4])

In [9]:
# The inverted mask is passed to the decoder: in the outputs below, positions where
# `mask` is True are the ones that receive probability mass.
# NOTE(review): presumably the decoder treats True in its mask argument as "blocked" — confirm.
attn_weight = attn(query, key, ~mask)

In [10]:
# Show the mask for the first two graphs (a 5-node graph, then a 10-node graph per data_list).
mask[0:2]

tensor([[[False, False, False, False,  True, False, False, False, False, False]],

        [[False, False, False, False,  True,  True,  True,  True,  True,  True]]])

In [11]:
# Exponentiate before display; in the recorded output each row sums to ~1, so the
# decoder output appears to be log-probabilities.
attn_weight.exp()[0:2]

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.1521, 0.1673, 0.1746, 0.1700,
          0.1728, 0.1631]]], grad_fn=<SliceBackward>)

In [15]:
# Sample one node index per graph from the attention distribution.
# squeeze(1) removes only the singleton query axis ([64, 1, 10] -> [64, 10]);
# a bare squeeze() would also drop the batch axis if the batch held a single graph.
attn_weight.exp().squeeze(1).multinomial(1)[0:2]

tensor([[4],
        [7]])

## Test TSP Environment

In [1]:
import torch
from environments.tsp import TSPEnv
from datasets.tsp import gen_fully_connected_graph
from torch_geometric.utils import to_dense_batch
from torch_geometric.data import Batch


In [2]:
# Graph size and number of graphs used throughout the environment test.
num_nodes, batch_size = 10, 5


In [3]:
# The batch is the same graph object repeated batch_size times.
g = gen_fully_connected_graph(num_nodes)
g_list = [g] * batch_size
batch = Batch.from_data_list(g_list)
batch

Batch(batch=[50], edge_index=[2, 500], pos=[50, 2], x=[50, 4])

In [4]:
# Convert the flat node coordinates to a dense per-graph tensor.
node_pos, dense_mask = to_dense_batch(batch.pos, batch.batch)
# All graphs have the same size here, so no entry should be padding.
assert dense_mask.all()
node_pos.shape

torch.Size([5, 10, 2])

In [5]:
# Build the environment from the dense coordinates and check the sizes it reports.
env = TSPEnv(node_pos)
assert env.num_nodes == num_nodes
assert env.batch_size == batch_size
# Initial availability mask: in the output below only node 0 is False for every graph.
print(env._avail_mask)

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True]])


In [6]:
# Take a single random step; the output below shows step() returning a
# (state, reward, done, info) 4-tuple.
a = env.random_action()
env.step(a)

(TSPState(first_node=tensor([[0],
         [0],
         [0],
         [0],
         [0]]), pre_node=tensor([[9],
         [4],
         [2],
         [1],
         [8]]), avail_mask=tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [False,  True,  True,  True, False,  True,  True,  True,  True,  True],
         [False,  True, False,  True,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
         [False,  True,  True,  True,  True,  True,  True,  True, False,  True]])),
 tensor([0., 0., 0., 0., 0.]),
 False,
 {})

In [7]:
# Keep taking random actions until the environment reports the episode is done.
done = False
while not done:
    state, reward, done, _ = env.step(env.random_action())
# Display the final state and reward: the first step above returned zero reward,
# while the reward recorded below for the finished episode is negative.
state, reward

(TSPState(first_node=tensor([[0],
         [0],
         [0],
         [0],
         [0]]), pre_node=tensor([[1],
         [5],
         [5],
         [2],
         [6]]), avail_mask=tensor([[False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False]])),
 tensor([-6.7641, -6.1001, -7.0537, -6.4557, -7.4225]))

## Test TSPAgent, TSPCritic

In [8]:
import torch
from collections import namedtuple
from environments.tsp import TSPEnv
from datasets.tsp import gen_fully_connected_graph
from torch_geometric.utils import to_dense_batch
from torch_geometric.data import Batch
from models.tsp_agents import TSPAgent, TSPCritic


In [51]:
# Graph size and number of graphs used for the agent / critic tests.
num_nodes, batch_size = 10, 5


# Plain attribute container passed to TSPAgent, TSPCritic, rollout and validate below.
class args:
    input_dim = 4
    embed_dim = 64
    num_embed_layers = 2
    num_gnn_layers = 2
    encoder_num_heads = 1
    decoder_num_heads = 1
    bias = True
    pooling_method = "add"
    decode_type = "sampling"
    eval_batch_size = 64


In [10]:
# Policy network, configured from the `args` container defined above.
model = TSPAgent(args)

In [11]:
# Critic network, built from the same `args` container as the agent.
critic = TSPCritic(args)

In [12]:
# One random graph repeated batch_size times, so every graph in the batch is identical.
graph = gen_fully_connected_graph(num_nodes)
batch = Batch.from_data_list([graph] * batch_size)
batch

Batch(batch=[50], edge_index=[2, 500], pos=[50, 2], x=[50, 4])

In [13]:
# Dense node coordinates ([0] drops the validity mask returned by to_dense_batch),
# then a fresh environment over them.
node_pos = to_dense_batch(batch.pos, batch.batch)[0]
env = TSPEnv(node_pos)
env

TSP Environment, with 5 graphs of size 10

In [14]:
# Run only the encoder; in the output below the node features grow from 4 to 64 columns.
model.encode(batch)

Batch(batch=[50], edge_index=[2, 500], pos=[50, 2], x=[50, 64])

In [17]:
# One value estimate per graph. The batch repeats a single graph, which is
# presumably why every row of the recorded output is the same.
critic(batch)

tensor([[-0.0133],
        [-0.0133],
        [-0.0133],
        [-0.0133],
        [-0.0133]], grad_fn=<LeakyReluBackward0>)

In [19]:
# Roll the agent through one episode, recording what it returns at every step.
log_p_s = []  # `log_p` returned by the model at each step
selected_s = []  # `selected` returned by the model at each step
reward_s = []  # reward returned by env.step at each step
done = False
state = env.reset(node_pos)
step = 0
# The step < 999 bound is presumably a safety cap so a faulty env/model cannot loop forever.
while (not done) and (step<999):
    selected, log_p = model(state)
    state, reward, done, _ = env.step(selected)
    log_p_s.append(log_p)
    selected_s.append(selected)
    reward_s.append(reward)
    step+=1
    

In [20]:
# Stack the per-step selections along a new axis 1 (the step axis).
seqs = torch.stack(selected_s, 1)

In [21]:
# Stack the per-step log_p tensors along a new axis 1 (the step axis).
logps = torch.stack(log_p_s, 1)

In [29]:
# Log-likelihood of each sampled tour: pick the log-probability of the chosen node
# at every step (gather along the last axis), then sum over the step axis.
# squeeze(-1) removes only the trailing index axis; a bare squeeze() would also drop
# the batch or step axis when either has length 1 and break the sum over axis 1.
# NOTE(review): the recorded value -12.8018 equals -log(9!), which is what a uniform
# policy over the 9 remaining nodes would give — confirm this is expected for an
# untrained model.
logps.gather(2, seqs).squeeze(-1).sum(1)

tensor([-12.8018, -12.8018, -12.8018, -12.8018, -12.8018],
       grad_fn=<SumBackward1>)

## Test Train

In [47]:
from args import get_args
from train import rollout, validate
from datasets.tsp import TSPDataset

In [41]:
# Display the module tree of the TSPAgent created in the "Test TSPAgent, TSPCritic" section.
model

TSPAgent(
  (linear_embedder): MLP(
    (mlp): Sequential(
      (0): Sequential(
        (0): Linear(in_features=4, out_features=34, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
      (1): Sequential(
        (0): Linear(in_features=34, out_features=64, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
      )
    )
  )
  (encoder): GATEncoder(
    (gnn_layer_list): ModuleList(
      (0): GATConv(64, 64, heads=1)
      (1): GATConv(64, 64, heads=1)
    )
  )
  (decoder): AttentionDecoder(
    (query_proj): Linear(in_features=192, out_features=64, bias=True)
    (key_proj): Linear(in_features=64, out_features=64, bias=True)
  )
)

In [44]:
# Build a TSPDataset with min and max node count both set to 10.
# NOTE(review): the first positional argument (5) is presumably the number of graphs — confirm.
dataset = TSPDataset(5, min_num_node=10, max_num_node=10)
# `env` is the TSPEnv created in the "Test TSPAgent, TSPCritic" section, displayed here for reference.
env

Processing...
Done!


TSP Environment, with 5 graphs of size 10

In [52]:
# Run the training module's rollout helper with the model, env and args defined in
# earlier cells; the recorded output has one value per graph.
rollout(model, dataset, env, args)

tensor([-7.2590, -5.6745, -4.0938, -3.5245, -5.0056])

In [53]:
# Run the validation helper on the same inputs; per the output below it prints a summary
# line and returns a single scalar tensor.
validate(model, dataset, env, args)

Validating...
Validation overall avg_cost: -5.111471176147461 +- 0.6517035961151123


tensor(-5.1115)