In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data, DataLoader
import torch.optim as optim
import os.path as osp
import scipy.io as sio
from dataset import QUASARDataset
from dual_model import DualModel

In [2]:
# Locations of the QUASAR graph collections loaded below.
# NOTE(review): these are absolute, machine-specific paths; adjust them when
# running on another machine.
# The training folder is bound to `train_dir` (not `dir`) so that the builtin
# dir() stays usable for introspection in the rest of the kernel session.
train_dir = '/home/hank/Datasets/QUASAR/small'
# Train split: 100 graphs, self loops removed by the dataset class.
dataset = QUASARDataset(train_dir,num_graphs=100,remove_self_loops=True)
test_dir = '/home/hank/Datasets/QUASAR/mix'
# Test split: 200 graphs taken from the "mix" folder, self loops removed.
testset = QUASARDataset(test_dir,num_graphs=200,remove_self_loops=True)

Data graph type: 1.
Data graph type: 1.


In [3]:
# --- Hyper-parameters used to build the model ---------------------------------
GNN_TYPE = 'SAGE'          # passed as gnn_type
GNN_HIDDEN_DIM = 64        # passed as mp_hidden_dim
GNN_OUT_DIM = 64           # passed as mp_output_dim
GNN_LAYER = 4              # passed as mp_num_layers
NODE_MODE = 1              # passed as node_feature_mode
DATA_GRAPH_TYPE = 1        # not consumed in this cell; presumably mirrors the dataset's graph type - TODO confirm
DROPOUT = 0                # passed as dropout_rate
MLP_LAYER = 2              # passed as both node_mlp_num_layers and edge_mlp_num_layers

# Pre-trained weights to evaluate.
# NOTE(review): the file name looks like it encodes the hyper-parameters above;
# keep the two in sync by hand if either changes.
CHECKPOINT_PATH = './models/dual_model_SAGE_4_64_64_1_1_2000_0.0_2.pth'

model   = DualModel(node_feature_mode=NODE_MODE,
                     gnn_type=GNN_TYPE,
                     mp_hidden_dim=GNN_HIDDEN_DIM,mp_output_dim=GNN_OUT_DIM,mp_num_layers=GNN_LAYER, 
                     dual_node_mlp_hidden_dim=64,dual_node_mlp_output_dim=10,
                     node_mlp_num_layers=MLP_LAYER,
                     dual_edge_mlp_hidden_dim=64,dual_edge_mlp_output_dim=6,
                     edge_mlp_num_layers=MLP_LAYER, 
                     dropout_rate=DROPOUT,
                     relu_slope=0.1)

# map_location='cpu' lets the checkpoint deserialize even if it was saved from
# CUDA tensors and this machine has no GPU. The model is never moved off the
# CPU in this cell, so mapping the stored tensors to CPU loses nothing.
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(state_dict)

# Cast all parameters to float64 after the weights are loaded.
model.double()
# Switch to inference mode; as the last expression this also echoes the
# module structure as the cell output.
model.eval()

Model: node_feature_mode = 1, mp_input_dim = 6, relu_slope = 0.1. GNN type: SAGE.


DualModel(
  (mp_convs): ModuleList(
    (0): SAGEConv(6, 64)
    (1): SAGEConv(64, 64)
    (2): SAGEConv(64, 64)
    (3): SAGEConv(64, 64)
    (4): SAGEConv(64, 64)
    (5): SAGEConv(64, 64)
  )
  (dual_node_mlp): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Linear(in_features=64, out_features=10, bias=True)
  )
  (dual_edge_mlp): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Linear(in_features=64, out_features=6, bias=True)
  )
)

In [4]:
# results on train set
# Evaluation only: graphs are visited one at a time in a fixed order
# (shuffle=False) so the printed per-graph losses are reproducible across runs.
loader = DataLoader(dataset,batch_size=1,shuffle=False)
# Nothing is back-propagated here, so autograd bookkeeping is disabled to avoid
# building a computation graph for every forward pass.
with torch.no_grad():
    for batch in loader:
        # The model returns three values; only the last two are fed to the loss.
        _, S, Aty = model(batch)
        dual_loss = model.loss(batch,S,Aty)
        print('batch loss: {:.4f}. # nodes: {:d}'.format(
                    dual_loss.item(),batch.num_nodes))
        # Left at notebook scope for inspection: after the loop these hold the
        # first entry of the stored `batch.Aty` and of the model's `Aty` for
        # the last graph that was processed.
        Atyopt = batch.Aty[0]
        Aty    = Aty[0]



batch loss: 0.0658. # nodes: 21
batch loss: 0.0997. # nodes: 21
batch loss: 0.0497. # nodes: 21
batch loss: 0.0674. # nodes: 21
batch loss: 0.0627. # nodes: 21
batch loss: 0.0531. # nodes: 21
batch loss: 0.0537. # nodes: 21
batch loss: 0.0559. # nodes: 21
batch loss: 0.0850. # nodes: 21
batch loss: 0.0597. # nodes: 21
batch loss: 0.0674. # nodes: 21
batch loss: 0.0655. # nodes: 21
batch loss: 0.0510. # nodes: 21
batch loss: 0.0871. # nodes: 21
batch loss: 0.0651. # nodes: 21
batch loss: 0.0733. # nodes: 21
batch loss: 0.0795. # nodes: 21
batch loss: 0.0724. # nodes: 21
batch loss: 0.0870. # nodes: 21
batch loss: 0.0843. # nodes: 21
batch loss: 0.0673. # nodes: 21
batch loss: 0.0796. # nodes: 21
batch loss: 0.0495. # nodes: 21
batch loss: 0.0797. # nodes: 21
batch loss: 0.0771. # nodes: 21
batch loss: 0.0662. # nodes: 21
batch loss: 0.0588. # nodes: 21
batch loss: 0.0473. # nodes: 21
batch loss: 0.0752. # nodes: 21
batch loss: 0.0689. # nodes: 21
batch loss: 0.0549. # nodes: 21
batch lo

In [5]:
# results on test set
# Evaluation only: graphs are visited one at a time in a fixed order
# (shuffle=False) so the printed per-graph losses are reproducible across runs.
loader = DataLoader(testset,batch_size=1,shuffle=False)
# Nothing is back-propagated here, so autograd bookkeeping is disabled to avoid
# building a computation graph for every forward pass.
with torch.no_grad():
    for batch in loader:
        # The model returns three values; only the last two are fed to the loss.
        _, S, Aty = model(batch)
        dual_loss = model.loss(batch,S,Aty)
        print('batch loss: {:.4f}. # nodes: {:d}'.format(
                    dual_loss.item(),batch.num_nodes))
        # Left at notebook scope for inspection: after the loop these hold the
        # first entry of the stored `batch.Aty` and of the model's `Aty` for
        # the last graph that was processed.
        Atyopt = batch.Aty[0]
        Aty    = Aty[0]

batch loss: 0.4130. # nodes: 21
batch loss: 0.2597. # nodes: 21
batch loss: 0.3967. # nodes: 31
batch loss: 0.4582. # nodes: 31
batch loss: 0.3249. # nodes: 21
batch loss: 0.3989. # nodes: 31
batch loss: 0.2399. # nodes: 21
batch loss: 0.5908. # nodes: 31
batch loss: 0.4067. # nodes: 31
batch loss: 0.3520. # nodes: 21
batch loss: 0.5084. # nodes: 31
batch loss: 0.4294. # nodes: 31
batch loss: 0.3257. # nodes: 21
batch loss: 0.4061. # nodes: 31
batch loss: 0.4345. # nodes: 31
batch loss: 0.2260. # nodes: 21
batch loss: 0.1885. # nodes: 21
batch loss: 0.3833. # nodes: 21
batch loss: 0.2188. # nodes: 21
batch loss: 0.1755. # nodes: 21
batch loss: 0.5114. # nodes: 31
batch loss: 0.4558. # nodes: 31
batch loss: 0.4728. # nodes: 31
batch loss: 0.3957. # nodes: 31
batch loss: 0.6119. # nodes: 31
batch loss: 0.4698. # nodes: 31
batch loss: 0.3938. # nodes: 31
batch loss: 0.5417. # nodes: 31
batch loss: 0.2695. # nodes: 21
batch loss: 0.4996. # nodes: 31
batch loss: 0.3964. # nodes: 31
batch lo

In [4]:
# Scratch arithmetic: the bare expression below is echoed as the cell output
# (the remainder of 200 divided by 199, i.e. 1).
# NOTE(review): nothing else in the notebook reads this value, and the cell's
# execution count is out of sequence with the cells above - looks like leftover
# scratch work; consider deleting the cell.
# Earlier scratch expression, left disabled:
# int(0.5*5)
200 % 199

1