## Test AttentionDecoder

In [1]:
from importlib import reload

import numpy as np
import torch
import torch.nn.functional as F
from datasets.tsp import gen_fully_connected_graph
from models.decoder import AttentionDecoder
from torch_geometric.data import Batch, Data
from torch_geometric.nn import global_max_pool
from torch_geometric.utils import to_dense_batch

In [2]:
embed_dim = 4
num_heads = 2
attn = AttentionDecoder(query_dim=embed_dim, embed_dim=embed_dim, num_heads=num_heads)
attn

AttentionDecoder(
  (query_proj): Linear(in_features=4, out_features=4, bias=True)
  (key_proj): Linear(in_features=4, out_features=4, bias=True)
)

In [3]:
g1 = gen_fully_connected_graph(5)
g2 = gen_fully_connected_graph(10)
g1.mask = torch.ones((51), dtype=torch.bool)
g2.mask = torch.ones((101), dtype=torch.bool)
g1.mask[0:4] = False
g2.mask[0:4] = False
data_list = [g1 if (i % 2) == 0 else g2 for i in range(64)]

In [None]:
batch = Batch.from_data_list(data_list)
batch

In [5]:
dense_batch = to_dense_batch(batch.x, batch.batch)[0]
dense_mask = to_dense_batch(batch.mask, batch.batch)[0]

In [6]:
key = dense_batch.permute(102)
key.shape

torch.Size([10, 64, 4])

In [7]:
mask = dense_mask.permute(021)
mask.shape

torch.Size([64, 1, 10])

In [8]:
query = global_max_pool(batch.x, batch.batch)
query = query[None, :, :]
query.shape

torch.Size([1, 64, 4])

In [9]:
attn_weight = attn(query, key, ~mask)

In [10]:
mask[0:2]

tensor([[[False, False, False, False,  True, False, False, False, False, False]],

        [[False, False, False, False,  True,  True,  True,  True,  True,  True]]])

In [11]:
attn_weight.exp()[0:2]

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.1521, 0.1673, 0.1746, 0.1700,
          0.1728, 0.1631]]], grad_fn=<SliceBackward>)

In [15]:
attn_weight.exp().squeeze().multinomial(1)[0:2]

tensor([[4],
        [7]])

## Test TSP Environment

In [1]:
import torch
from datasets.tsp import gen_fully_connected_graph
from environments.tsp import TSPEnv
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch

In [2]:
num_nodes = 10
batch_size = 5

In [3]:
g = gen_fully_connected_graph(num_nodes)
g_list = [g for _ in range(batch_size)]
batch = Batch.from_data_list(g_list)
batch

Batch(batch=[50], edge_index=[2, 500], pos=[50, 2], x=[50, 4])

In [4]:
node_pos, dense_mask = to_dense_batch(batch.pos, batch.batch)
assert dense_mask.all()
node_pos.shape

torch.Size([5, 10, 2])

In [5]:
env = TSPEnv(node_pos)
assert env.num_nodes == num_nodes
assert env.batch_size == batch_size
print(env._avail_mask)

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True]])


In [6]:
a = env.random_action()
env.step(a)

(TSPState(first_node=tensor([[0],
         [0],
         [0],
         [0],
         [0]]), pre_node=tensor([[9],
         [4],
         [2],
         [1],
         [8]]), avail_mask=tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [False,  True,  True,  True, False,  True,  True,  True,  True,  True],
         [False,  True, False,  True,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
         [False,  True,  True,  True,  True,  True,  True,  True, False,  True]])),
 tensor([0., 0., 0., 0., 0.]),
 False,
 {})

In [7]:
done = False
while not done:
    state, reward, done, _ = env.step(env.random_action())
state, reward

(TSPState(first_node=tensor([[0],
         [0],
         [0],
         [0],
         [0]]), pre_node=tensor([[1],
         [5],
         [5],
         [2],
         [6]]), avail_mask=tensor([[False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False, False]])),
 tensor([-6.7641, -6.1001, -7.0537, -6.4557, -7.4225]))

## Test TSPAgent, TSPCritic

In [1]:
from collections import namedtuple

import torch
from datasets.tsp import gen_fully_connected_graph
from environments.tsp import TSPEnv
from models.tsp_agents import TSPAgent, TSPCritic
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch

In [2]:
num_nodes = 10
batch_size = 64

class args:
    input_dim = 4
    embed_dim = 64
    num_embed_layers = 2
    num_gnn_layers = 2
    encoder_num_heads = 1
    decoder_num_heads = 1
    bias = True
    pooling_method = "add"
    decode_type = "sampling"
    eval_batch_size = 64
    warmup_batch_size = 256
    device = torch.device("cuda:0")

In [3]:
model = TSPAgent(args).to(args.device)

In [4]:
critic = TSPCritic(args).to(args.device)

In [5]:
batch = Batch.from_data_list([gen_fully_connected_graph(num_nodes) for _ in range(batch_size)])
batch = batch.to(args.device)

In [6]:
node_pos = to_dense_batch(batch.pos, batch.batch)[0]
env = TSPEnv(node_pos)
env

TSP Environment, with 64 graphs of size 10

In [7]:
model.encode(batch)

Batch(batch=[640], edge_index=[2, 6400], pos=[640, 2], x=[640, 64])

In [60]:
critic.eval(batch)

TypeError: eval() missing 1 required positional argument: 'target'

In [31]:
log_p_s = []
selected_s = []
reward_s = []
done = False
state = env.reset(node_pos)
step = 0
while (not done) and (step < 999):
    selected, log_p = model(state)
    state, reward, done, _ = env.step(selected)
    log_p_s.append(log_p)
    selected_s.append(selected)
    reward_s.append(reward)
    step += 1

In [32]:
seqs = torch.stack(selected_s, 1)

In [33]:
logps = torch.stack(log_p_s, 1)

In [60]:
log_likelihood = logps.gather(2, seqs).squeeze().sum(1)
log_likelihood

tensor([-12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018, -12.8018,
        -12.8018], grad_fn=<SumBackward1>)

## Test Train

In [8]:
from args import get_args
from datasets.tsp import TSPDataset
from train import rollout, validate
from rl_algorithms.reinforce import _calc_log_likelihood
from train import warmup_baseline
from tqdm import tqdm

In [12]:
dataset = TSPDataset(1000, min_num_node=num_nodes, max_num_node=num_nodes)


Processing...
Done!


In [10]:
rollout(model, dataset, env, args)

tensor([-3.8022, -3.8708, -5.5611,  ..., -3.2411, -4.6239, -4.9271])

In [46]:
validate(model, dataset, env, args)

Validating...
Validation overall avg_cost: 4.079802513122559 +- 0.02389952540397644


tensor(4.0798)

In [14]:
import torch.optim as optim
optimizer = optim.Adam(
    [{"params": model.parameters(), "lr": 0.001}] + 
    [{"params": critic.parameters(), "lr": 0.001}]
)
warmup_baseline(critic, dataset, env, optimizer, args)

100%|██████████| 4/4 [00:00<00:00, 22.57it/s]

Warmup Critic Baseline with training dataset





In [42]:
for i in tqdm(range(1000)):
    batch = batch.to(args.device)
    model.train()
    model.encode(batch)
    model.set_decode_type("sampling")
    log_p_s = []
    selected_s = []
    reward_s = []
    done = False
    state = env.reset(to_dense_batch(batch.pos, batch.batch)[0])
    step = 0
    while (not done) and (step < 999):
        selected, log_p = model(state)
#         selected = torch.tensor([[step+1]], dtype=torch.long, device=args.device)
        state, reward, done, _ = env.step(selected)
        log_p_s.append(log_p)
        selected_s.append(selected)
        reward_s.append(reward)
        step += 1

    _log_p = torch.stack(log_p_s, 1)
    actions = torch.stack(selected_s, 1)
    log_likelihood = _calc_log_likelihood(_log_p, actions)
    reward = -reward.unsqueeze(-1)
    bl_val, bl_loss = critic.eval(batch, reward)
    rl_loss = ((reward-bl_val)*log_likelihood).mean()
    loss = rl_loss + bl_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    

100%|██████████| 1000/1000 [00:47<00:00, 20.99it/s]


In [43]:
print(rl_loss)
print(bl_loss)
print(log_likelihood)

tensor(0.3281, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.3749, device='cuda:0', grad_fn=<MseLossBackward>)
tensor([-11.9626, -13.4902, -12.8316, -12.7669, -12.3703, -12.4275, -12.1155,
        -11.9612, -12.8306, -12.7549, -12.7864, -12.8981, -12.7001, -13.1047,
        -13.1736, -12.8995, -12.4332, -13.2156, -11.6555, -13.6180, -13.1413,
        -13.3393, -12.5452, -12.1836, -12.1994, -13.1358, -12.7499, -13.2295,
        -11.7873, -12.8880, -12.9807, -13.2419, -12.5407, -13.2133, -12.7517,
        -13.2434, -13.2781, -12.7076, -12.6697, -13.0293, -12.4183, -12.3979,
        -13.3256, -13.0834, -13.2280, -12.6623, -12.3575, -12.0317, -12.2767,
        -12.9623, -11.6372, -13.3611, -12.3948, -12.2976, -12.5602, -12.6557,
        -12.4446, -12.2365, -13.7505, -12.2146, -13.6701, -13.2010, -12.7464,
        -12.5830], device='cuda:0', grad_fn=<SumBackward1>)


In [44]:
_log_p.exp().gather(2, actions).squeeze(-1)[0:8:2]

tensor([[0.1451, 0.1215, 0.1905, 0.1667, 0.1898, 0.2674, 0.3751, 0.5984, 1.0000],
        [0.1396, 0.1160, 0.0994, 0.1837, 0.2139, 0.2614, 0.3131, 0.5169, 1.0000],
        [0.1140, 0.1405, 0.1660, 0.1308, 0.2130, 0.3266, 0.3131, 0.5600, 1.0000],
        [0.1459, 0.0923, 0.1592, 0.2028, 0.2299, 0.3024, 0.3132, 0.5790, 1.0000]],
       device='cuda:0', grad_fn=<SliceBackward>)

In [45]:
print(_log_p[0].exp())

tensor([[0.0000, 0.1039, 0.1433, 0.0961, 0.1013, 0.1098, 0.1128, 0.1124, 0.1451,
         0.0753],
        [0.0000, 0.1215, 0.1675, 0.1124, 0.1185, 0.1284, 0.1320, 0.1315, 0.0000,
         0.0881],
        [0.0000, 0.0000, 0.1905, 0.1280, 0.1349, 0.1462, 0.1502, 0.1497, 0.0000,
         0.1005],
        [0.0000, 0.0000, 0.0000, 0.1581, 0.1667, 0.1806, 0.1856, 0.1850, 0.0000,
         0.1240],
        [0.0000, 0.0000, 0.0000, 0.1898, 0.0000, 0.2167, 0.2226, 0.2219, 0.0000,
         0.1490],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2674, 0.2748, 0.2738, 0.0000,
         0.1839],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3751, 0.3739, 0.0000,
         0.2509],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5984, 0.0000,
         0.4016],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000]], device='cuda:0', grad_fn=<ExpBackward>)


In [25]:
# print(reward)
# print(bl_val)


In [38]:
g = batch.to_data_list()[0]
g.x = torch.stack([torch.arange(10).type(torch.float)]*4, 1)

In [39]:
g.edge_index[:,0:10]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [40]:
from torch_geometric.nn import TransformerConv
import torch.nn.functional as F
import torch.nn as nn

gnn = TransformerConv(4,4)

opt = optim.Adam(
    [{"params": gnn.parameters(), "lr": 0.001}]
)
target = torch.arange(10).type(torch.float)
target

tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

In [75]:
for i in range(100):
    out = gnn(g.x, g.edge_index).mean(1)
    loss = F.mse_loss(out, target)
    
    opt.zero_grad()
    loss.backward()
    opt.step()

In [76]:
gnn(g.x, g.edge_index).mean(1)

tensor([0.3300, 1.1001, 1.9340, 2.8558, 3.8780, 4.9753, 6.0786, 7.1146, 8.0510,
        8.8963], grad_fn=<MeanBackward1>)

In [77]:
loss

tensor(0.0192, grad_fn=<MseLossBackward>)