In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data, DataLoader
import torch.optim as optim
import os.path as osp
import scipy.io as sio
from dataset import QUASARDataset
from primal_model import PrimalModel

In [7]:
dir = '/home/hank/Datasets/QUASAR/small'
dataset = QUASARDataset(dir,num_graphs=100,remove_self_loops=True)
test_dir = '/home/hank/Datasets/QUASAR/mix'
testset = QUASARDataset(test_dir,num_graphs=200,remove_self_loops=True)

Data graph type: 1.
Data graph type: 1.


Processing...


Expected # graphs: 200. Actual # graphs: 200.


Done!


In [8]:
GNN_TYPE = 'SAGE'
GNN_HIDDEN_DIM = 64
GNN_OUT_DIM = 64
GNN_LAYER = 4
NODE_MODE = 1
DATA_GRAPH_TYPE = 1
DROPOUT = 0
MLP_LAYER = 2
model   = PrimalModel(node_feature_mode=NODE_MODE,
                     gnn_type=GNN_TYPE,
                     mp_hidden_dim=GNN_HIDDEN_DIM,mp_output_dim=GNN_OUT_DIM,mp_num_layers=GNN_LAYER, 
                     primal_node_mlp_hidden_dim=64,primal_node_mlp_output_dim=10,
                     node_mlp_num_layers=MLP_LAYER,
                     primal_edge_mlp_hidden_dim=64,primal_edge_mlp_output_dim=10,
                     edge_mlp_num_layers=MLP_LAYER, 
                     dropout_rate=DROPOUT,
                     relu_slope=0.1)
model.load_state_dict(torch.load('./models/primal_model_SAGE_4_64_64_1_1_1000_0.0_2.pth'))
model.double()
model.eval()

Model: node_feature_mode = 1, mp_input_dim = 6, relu_slope = 0.1. GNN type: SAGE.


PrimalModel(
  (mp_convs): ModuleList(
    (0): SAGEConv(6, 64)
    (1): SAGEConv(64, 64)
    (2): SAGEConv(64, 64)
    (3): SAGEConv(64, 64)
    (4): SAGEConv(64, 64)
    (5): SAGEConv(64, 64)
  )
  (primal_node_mlp): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Linear(in_features=64, out_features=10, bias=True)
  )
  (primal_edge_mlp): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [9]:
# results on train set
loader = DataLoader(dataset,batch_size=1,shuffle=True)
for batch in loader:
    xorg = batch.x
    x, X = model(batch)
    primal_loss = model.loss(batch,X)
    print('batch loss: {:.4f}.'.format(
                primal_loss.item()))
    X = X[0]
    Xopt = batch.X[0]



batch loss: 0.0219.
batch loss: 0.0515.
batch loss: 0.4335.
batch loss: 0.0777.
batch loss: 0.0629.
batch loss: 0.0476.
batch loss: 0.0529.
batch loss: 0.0604.
batch loss: 0.0528.
batch loss: 0.0867.
batch loss: 0.0546.
batch loss: 0.0475.
batch loss: 0.0450.
batch loss: 0.0388.
batch loss: 0.0404.
batch loss: 0.0300.
batch loss: 0.0997.
batch loss: 0.0815.
batch loss: 0.1194.
batch loss: 0.0530.
batch loss: 0.0523.
batch loss: 0.0761.
batch loss: 0.0556.
batch loss: 0.0638.
batch loss: 0.0721.
batch loss: 0.0469.
batch loss: 0.0648.
batch loss: 0.0523.
batch loss: 0.0473.
batch loss: 0.0451.
batch loss: 0.0535.
batch loss: 0.0366.
batch loss: 0.0776.
batch loss: 0.0793.
batch loss: 0.0416.
batch loss: 0.0789.
batch loss: 0.0451.
batch loss: 0.0918.
batch loss: 0.0734.
batch loss: 0.0496.
batch loss: 0.0475.
batch loss: 0.0683.
batch loss: 0.0677.
batch loss: 0.0232.
batch loss: 0.1393.
batch loss: 0.0998.
batch loss: 0.0626.
batch loss: 0.0725.
batch loss: 0.0517.
batch loss: 0.0658.


In [13]:
# results on test set
loader = DataLoader(testset,batch_size=1,shuffle=True)
for batch in loader:
    xorg = batch.x
    x, X = model(batch)
    primal_loss = model.loss(batch,X)
    print('batch loss: {:.4f}, # nodes: {:d}'.format(
                primal_loss.item(),batch.num_nodes))
    X = X[0]
    Xopt = batch.X[0]

batch loss: 1.1055, # nodes: 31
batch loss: 1.0099, # nodes: 31
batch loss: 0.6909, # nodes: 31
batch loss: 1.0999, # nodes: 21
batch loss: 1.1855, # nodes: 21
batch loss: 1.0193, # nodes: 31
batch loss: 1.0875, # nodes: 21
batch loss: 0.9957, # nodes: 21
batch loss: 1.1677, # nodes: 21
batch loss: 1.0508, # nodes: 31
batch loss: 1.1411, # nodes: 31
batch loss: 1.1745, # nodes: 21
batch loss: 0.8419, # nodes: 31
batch loss: 0.9663, # nodes: 31
batch loss: 1.1386, # nodes: 21
batch loss: 1.1529, # nodes: 21
batch loss: 1.2428, # nodes: 31
batch loss: 1.0744, # nodes: 31
batch loss: 1.2386, # nodes: 21
batch loss: 0.9069, # nodes: 31
batch loss: 1.1059, # nodes: 21
batch loss: 1.1685, # nodes: 21
batch loss: 1.1222, # nodes: 31
batch loss: 1.0433, # nodes: 31
batch loss: 1.1852, # nodes: 21
batch loss: 1.0296, # nodes: 21
batch loss: 1.1383, # nodes: 21
batch loss: 1.0608, # nodes: 31
batch loss: 1.2631, # nodes: 31
batch loss: 1.2420, # nodes: 31
batch loss: 1.2987, # nodes: 31
batch lo

In [3]:
GNN_TYPE = 'SAGE'
GNN_HIDDEN_DIM = 64
GNN_OUT_DIM = 64
GNN_LAYER = 4
NODE_MODE = 1
DATA_GRAPH_TYPE = 1
NUM_EPOCHES = 2000
DROPOUT = 0.0
MLP_LAYER = 2
pfname = "./models/p_model_{}_{}_{}_{}.pth".format(
    GNN_TYPE,GNN_LAYER,MLP_LAYER,NUM_EPOCHES)
print(pfname)

./models/p_model_SAGE_4_2_2000.pth
