In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data, DataLoader
import torch.optim as optim
import os.path as osp
import scipy.io as sio
from dataset import QUASARDataset
from primal_model import PrimalModel

In [16]:
dir = '/home/hank/Datasets/QUASAR/small'
dataset = QUASARDataset(dir,num_graphs=100,remove_self_loops=True)

In [17]:
GNN_TYPE = 'SAGE'
GNN_HIDDEN_DIM = 64
GNN_OUT_DIM = 64
GNN_LAYER = 4
NODE_MODE = 1
DATA_GRAPH_TYPE = 1
DROPOUT = 0.2
MLP_LAYER = 1
model   = PrimalModel(node_feature_mode=NODE_MODE,
                     gnn_type=GNN_TYPE,
                     mp_hidden_dim=GNN_HIDDEN_DIM,mp_output_dim=GNN_OUT_DIM,mp_num_layers=GNN_LAYER, 
                     primal_node_mlp_hidden_dim=64,primal_node_mlp_output_dim=10,
                     node_mlp_num_layers=MLP_LAYER,
                     primal_edge_mlp_hidden_dim=64,primal_edge_mlp_output_dim=10,
                     edge_mlp_num_layers=MLP_LAYER, 
                     dropout_rate=DROPOUT,
                     relu_slope=0.1)
model.load_state_dict(torch.load('./models/primal_model_SAGE_4_64_64_1_1_1000_0_1.pth'))
model.double()
model.eval()

Model: node_feature_mode = 1, mp_input_dim = 6, relu_slope = 0.1. GNN type: SAGE.


PrimalModel(
  (mp_convs): ModuleList(
    (0): SAGEConv(6, 64)
    (1): SAGEConv(64, 64)
    (2): SAGEConv(64, 64)
    (3): SAGEConv(64, 64)
    (4): SAGEConv(64, 64)
    (5): SAGEConv(64, 64)
  )
  (primal_node_mlp): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=10, bias=True)
  )
  (primal_edge_mlp): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [18]:
loader = DataLoader(dataset,batch_size=1,shuffle=True)
for batch in loader:
    xorg = batch.x
    x, X = model(batch)
    primal_loss = model.loss(batch,X)
    print('batch loss: {:.4f}.'.format(
                primal_loss.item()))
    X = X[0]
    Xopt = batch.X[0]
    # break

batch loss: 0.0937.
batch loss: 0.0512.
batch loss: 0.0386.
batch loss: 0.1086.
batch loss: 0.0488.
batch loss: 0.1014.
batch loss: 0.0587.
batch loss: 0.0942.
batch loss: 0.0940.
batch loss: 0.0631.
batch loss: 0.1092.
batch loss: 0.0386.
batch loss: 0.1228.
batch loss: 0.0508.
batch loss: 0.0957.
batch loss: 0.1143.
batch loss: 0.0390.
batch loss: 0.1035.
batch loss: 0.0468.
batch loss: 0.0753.
batch loss: 0.0605.
batch loss: 0.0408.
batch loss: 0.0745.
batch loss: 0.0534.
batch loss: 0.1288.
batch loss: 0.0574.
batch loss: 0.0500.
batch loss: 0.0440.
batch loss: 0.0559.
batch loss: 0.0994.
batch loss: 0.0958.
batch loss: 0.0252.
batch loss: 0.0768.
batch loss: 0.0651.
batch loss: 0.0315.
batch loss: 0.1294.
batch loss: 0.0493.
batch loss: 0.0679.
batch loss: 0.0812.
batch loss: 0.0794.
batch loss: 0.0965.
batch loss: 0.1216.
batch loss: 0.0480.
batch loss: 0.0715.
batch loss: 0.0772.
batch loss: 0.0631.
batch loss: 0.1045.
batch loss: 0.0523.
batch loss: 0.0621.
batch loss: 0.0526.
