In [1]:
import json
import os

import numpy as np
import torch
import torch.nn as nn
from IPython.core.debugger import set_trace
from torch.autograd import Variable

from gcn_model.chaitjo import config as chaitjo_config
from gcn_model.chaitjo.models.gcn_model import ResidualGatedGCNModel
from gcn_model.chaitjo.utils.graph_reader import GraphReader


In [2]:
# Input graphs that are fed to GraphReader.
graph_path = 'dataset/round5/node/'
# Output directory: one saved feature array per graph is written here.
feat_path = 'dataset/round5/feat300/'
# Root of the GCN model package; the config and checkpoint paths are built from it.
model_path = 'gcn_model/chaitjo/'

In [3]:
# Load the hyper-parameters of the pretrained TSP20 model.
config_path = model_path + 'tsp-models/tsp20/config.json'
# Use a context manager so the file handle is closed as soon as the JSON is parsed.
with open(config_path) as config_file:
    config = chaitjo_config.Settings(json.load(config_file))
print(config)

{'expt_name': 'tsp20', 'gpu_id': '1', 'train_filepath': './data/tsp20_train_concorde.txt', 'val_filepath': './data/tsp20_val_concorde.txt', 'test_filepath': './data/tsp20_test_concorde.txt', 'num_nodes': 20, 'num_neighbors': -1, 'node_dim': 2, 'voc_nodes_in': 2, 'voc_nodes_out': 2, 'voc_edges_in': 3, 'voc_edges_out': 2, 'beam_size': 1280, 'hidden_dim': 300, 'num_layers': 30, 'mlp_layers': 3, 'aggregation': 'mean', 'max_epochs': 1500, 'val_every': 5, 'test_every': 100, 'batch_size': 20, 'batches_per_epoch': 500, 'accumulation_steps': 1, 'learning_rate': 0.001, 'decay_rate': 1.01}


In [4]:
# Legacy CUDA tensor types; they are handed to the model constructor and reused
# later when converting batches.
dtypeLong = torch.cuda.LongTensor
dtypeFloat = torch.cuda.FloatTensor

# Build the bare model first, then wrap it in DataParallel.
# NOTE(review): presumably the wrapper is needed so the state-dict keys match the
# checkpoint (saved from a DataParallel model) - confirm.
gcn_model = ResidualGatedGCNModel(config, dtypeFloat, dtypeLong)
net = nn.DataParallel(gcn_model)
print(net)

DataParallel(
  (module): ResidualGatedGCNModel(
    (nodes_coord_embedding): Linear(in_features=2, out_features=300, bias=False)
    (edges_values_embedding): Linear(in_features=1, out_features=150, bias=False)
    (edges_embedding): Embedding(3, 150)
    (gcn_layers): ModuleList(
      (0): ResidualGatedGCNLayer(
        (node_feat): NodeFeatures(
          (U): Linear(in_features=300, out_features=300, bias=True)
          (V): Linear(in_features=300, out_features=300, bias=True)
        )
        (edge_feat): EdgeFeatures(
          (U): Linear(in_features=300, out_features=300, bias=True)
          (V): Linear(in_features=300, out_features=300, bias=True)
        )
        (bn_node): BatchNormNode(
          (batch_norm): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
        (bn_edge): BatchNormEdge(
          (batch_norm): BatchNorm2d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
      )
      (1): Re

In [6]:
# Restore the best-validation weights of the pretrained TSP20 model.
checkpoint_path = model_path + 'tsp-models/tsp20/best_val_checkpoint.tar'
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# map_location='cpu' keeps the load independent of the GPU id the checkpoint was
# saved from and avoids holding a second copy of the weights on the GPU; the
# model is moved to the GPU explicitly below.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
net.load_state_dict(checkpoint['model_state_dict'])
net.cuda()
# Switch to evaluation mode; as the cell's last expression this also displays
# the module repr.
net.eval()

DataParallel(
  (module): ResidualGatedGCNModel(
    (nodes_coord_embedding): Linear(in_features=2, out_features=300, bias=False)
    (edges_values_embedding): Linear(in_features=1, out_features=150, bias=False)
    (edges_embedding): Embedding(3, 150)
    (gcn_layers): ModuleList(
      (0): ResidualGatedGCNLayer(
        (node_feat): NodeFeatures(
          (U): Linear(in_features=300, out_features=300, bias=True)
          (V): Linear(in_features=300, out_features=300, bias=True)
        )
        (edge_feat): EdgeFeatures(
          (U): Linear(in_features=300, out_features=300, bias=True)
          (V): Linear(in_features=300, out_features=300, bias=True)
        )
        (bn_node): BatchNormNode(
          (batch_norm): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
        (bn_edge): BatchNormEdge(
          (batch_norm): BatchNorm2d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
      )
      (1): Re

In [8]:
# Sanity check: print only the first batch produced by GraphReader to inspect
# its contents, then stop iterating.
for batch in GraphReader(graph_path):
    print(batch)
    break

{'edges': array([[[2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 2., 1., 

In [9]:
# Run every graph through the network and save its output as one array per graph.

# Create the output directory up front so np.save does not fail on a fresh checkout.
os.makedirs(feat_path, exist_ok=True)

with torch.no_grad():
    for batch in GraphReader(graph_path):
        # Convert the batch fields to tensors of the types the model was built with.
        # The deprecated Variable wrapper is unnecessary: plain tensors work directly.
        x_edges = torch.LongTensor(batch.edges).type(dtypeLong)
        x_edges_values = torch.FloatTensor(batch.edges_values).type(dtypeFloat)
        x_nodes = torch.LongTensor(batch.nodes).type(dtypeLong)
        x_nodes_coord = torch.FloatTensor(batch.nodes_coord).type(dtypeFloat)
        gFilename = batch.gFilename

        # Forward pass: call the module itself rather than .forward() so that
        # nn.Module's call machinery (hooks) is not bypassed.
        feats = net(x_edges, x_edges_values, x_nodes, x_nodes_coord)
        # Drop the leading axis; np.squeeze raises if that axis is not of size 1.
        feats = np.squeeze(feats.detach().cpu().numpy(), axis=0)
        print(gFilename)
        # np.save appends the '.npy' extension when the file name lacks it.
        np.save(feat_path + gFilename, feats)



GobvEa8_29
ncjIZph_29
DmlPhPL_23
VSFeJbu_27
v1QgjK2_21
HajFJy8_25
mK2LEUc_20
ifqkROg_25
S1vWrMZ_27
g8D8XJ0_24
Q5zGOH2_23
aMzwSDU_26
SK7yT4R_28
RywxplT_25
TlLuIx6_25
kAxMdeS_20
KD4NC2D_22
WwGBiXP_25
rf5Qw05_25
41CM1ou_20
9RUT1km_24
u4PNLAv_22
UcfOzyg_23
djKyhGm_25
gmjfpu7_24
UUU46YX_22
vbuOIxv_29
iJ12zqQ_28
oNczb7D_21
jloraYj_22
D5HsbR1_27
EKNhTe5_24
dDvy76B_20
2U60EbM_20
Cysp6Xz_25
4QoDZso_27
JDQhboQ_28
sSOrVOk_29
RaZkiNO_28
pVnjRLy_20
FXvGbbc_20
FFPNocV_21
9GjxOmH_22
h54wYsQ_27
8R2PMqt_25
vz6U5aU_21
XygL65H_22
VmwvLoc_20
7cov5QY_23
xGbpljr_28
sdLYopd_22
un4DYCY_24
AMpZvq7_29
6DPxHyU_25
iRJi3rZ_24
RzA4bLg_28
W2jwHI4_23
dtaDvZB_23
bbGs1hW_26
ZEsXrsF_26
3qtjWCZ_26
KLKnb6R_26
DlJ1DHc_20
fxWio3P_26
v51LunR_21
1U8GiGk_22
t4HaTuT_24
Hd0D1od_29
2n6TjeS_23
4Sq1ZKL_25
FzI252H_28
6MtqVKk_24
CNfi0jE_20
qAbH071_21
wNxhCFd_20
F568jmC_23
lfpaqbm_29
Wee1WkU_25
87Tm4tS_23
k9vh15y_20
kF9SBGM_23
OJ0qxVA_27
SyKTu6a_21
AVpr39m_26
GUApY74_27
mP15g2c_21
3hJFpCm_24
hHonR9v_21
Wy1wIEv_21
RJfT2HM_25
041NNsa_29