In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data, DataLoader
import torch.optim as optim
import os.path as osp
import scipy.io as sio
from dataset import QUASARDataset
from dual_model import DualModelFtype
from dual_train import validate

In [4]:
# Device: use the first CUDA GPU when one is present, otherwise fall back to
# the CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Model hyperparameters. load_state_dict() below needs the architecture built
# from these values to match the one stored in the checkpoint.
GNN_TYPE = 'SAGE'
GNN_HIDDEN_DIM = 64
GNN_OUT_DIM = GNN_HIDDEN_DIM
GNN_LAYER = 31
NODE_MODE = 1
DATA_GRAPH_TYPE = 1
DROPOUT = 0.2
MLP_LAYER = 2
RESIDUAL = True
BATCHNORM = True
# Output width of the edge MLP head; the inference cell branches on this value.
DUAL_OUT_DIM = 6
model = DualModelFtype(
        node_feature_mode=NODE_MODE,
        gnn_type=GNN_TYPE,
        mp_hidden_dim=GNN_HIDDEN_DIM,mp_output_dim=GNN_OUT_DIM,mp_num_layers=GNN_LAYER, 
        dual_node_mlp_hidden_dim=GNN_OUT_DIM,dual_node_mlp_output_dim=10,
        node_mlp_num_layers=MLP_LAYER,
        dual_edge_mlp_hidden_dim=GNN_OUT_DIM,dual_edge_mlp_output_dim=DUAL_OUT_DIM,
        edge_mlp_num_layers=MLP_LAYER,
        dropout_rate=DROPOUT,
        relu_slope=0.1,
        residual=RESIDUAL,
        batchnorm=BATCHNORM)
# Deserialize onto the CPU first: a checkpoint written from a CUDA device
# would otherwise fail to load on a host where that device does not exist.
# The weights are moved to DEVICE explicitly a few lines below.
state_dict = torch.load(
    './models/dual_model_N30-5000_SAGE_31_64_2_6_True_True.pth',
    map_location='cpu')
model.load_state_dict(state_dict)
# Run in float64; the inference cell builds its reference tensors as float64.
model.double()
model.to(DEVICE)
# Dataset
setname = 'N30-100'
# NOTE(review): `dir` shadows the Python builtin and the path is machine-specific;
# the name is kept because other cells may reference it.
dir = f'/home/hank/Datasets/QUASAR/{setname}'
data = QUASARDataset(dir,num_graphs=100,remove_self_loops=True)
# batch_size=1 and shuffle=False: one graph per step, in dataset order.
loader = DataLoader(data,batch_size=1,shuffle=False)

def rel_norm_err(A,B):
    """Relative Frobenius-norm error of ``A`` with respect to the reference ``B``.

    Computes ``||A - B||_F / ||B||_F`` and returns it as a tensor.
    No guard is applied for a zero-norm ``B``.
    """
    return torch.norm(A - B, p='fro') / torch.norm(B, p='fro')

Model: node_feature_mode = 1, mp_input_dim = 6, relu_slope = 0.1. GNN type: SAGE. Residual: True. BatchNorm: True.
Data graph type: 1.


In [5]:
# Run the trained model over every graph in `loader`, report the relative
# Frobenius errors of the predicted Aty and S against the stored ground truth,
# and save the per-graph Aty predictions to a .mat file.
dual_pred = []
print(f'DUAL_OUT_DIM: {DUAL_OUT_DIM}.')

# DUAL_OUT_DIM selects how the first element returned by model.recover() is
# interpreted: as Aty (6) or as S (16). The other quantity is then obtained
# as C minus the recovered one. Resolve the mode and the output file name up
# front so an unsupported value fails here with a clear message instead of
# silently skipping inference and crashing later on an undefined `fname`.
if DUAL_OUT_DIM == 6:
    recovers_Aty = True
    fname = f'{setname}_dual_sol_Aty.mat'
elif DUAL_OUT_DIM == 16:
    recovers_Aty = False
    fname = f'{setname}_dual_sol_S.mat'
else:
    raise ValueError(
        f'Unsupported DUAL_OUT_DIM: {DUAL_OUT_DIM} (expected 6 or 16).')

model.eval()
with torch.no_grad():
    # Generate data
    for batch in loader:
        # Rebind so the loop works whether .to() moves in place or returns a copy.
        batch = batch.to(DEVICE)
        _, V, E = model(batch)
        sol = model.recover(V,E,batch)
        recovered = sol[0]
        # batch_size is 1, so index 0 selects the single graph's matrices.
        C = torch.tensor(batch.C[0],dtype=torch.float64,device=recovered.device)
        if recovers_Aty:
            Aty_hat = recovered
            S_hat = C - Aty_hat
        else:
            S_hat = recovered
            Aty_hat = C - S_hat
        Aty_gt = torch.tensor(batch.Aty[0],dtype=torch.float64,device=Aty_hat.device)
        S_gt = torch.tensor(batch.S[0],dtype=torch.float64,device=Aty_hat.device)
        err_Aty = rel_norm_err(Aty_hat,Aty_gt)
        err_S = rel_norm_err(S_hat,S_gt)
        print('Aty norm err: {:.4f}, S norm err: {:.4f}.'.format(err_Aty,err_S))
        # Aty is what gets exported in both modes.
        dual_pred.append(Aty_hat.cpu().numpy())

mdict = {"dual_pred": dual_pred}
sio.savemat(fname,mdict)

DUAL_OUT_DIM: 6.
Aty norm err: 0.6997, S norm err: 0.6953.
Aty norm err: 0.7080, S norm err: 0.7017.
Aty norm err: 0.7076, S norm err: 0.6996.
Aty norm err: 0.7018, S norm err: 0.6948.
Aty norm err: 0.7209, S norm err: 0.7160.
Aty norm err: 0.7113, S norm err: 0.7049.
Aty norm err: 0.7173, S norm err: 0.7117.
Aty norm err: 0.7152, S norm err: 0.7092.
Aty norm err: 0.7098, S norm err: 0.7025.
Aty norm err: 0.7030, S norm err: 0.6961.
Aty norm err: 0.7019, S norm err: 0.6967.
Aty norm err: 0.7009, S norm err: 0.6950.
Aty norm err: 0.7016, S norm err: 0.6959.
Aty norm err: 0.7035, S norm err: 0.6974.
Aty norm err: 0.7107, S norm err: 0.7049.
Aty norm err: 0.7068, S norm err: 0.7009.
Aty norm err: 0.7190, S norm err: 0.7141.
Aty norm err: 0.7052, S norm err: 0.6990.
Aty norm err: 0.7009, S norm err: 0.6961.
Aty norm err: 0.7054, S norm err: 0.6997.
Aty norm err: 0.7015, S norm err: 0.6958.
Aty norm err: 0.7019, S norm err: 0.6954.
Aty norm err: 0.7173, S norm err: 0.7106.
Aty norm err: 0.6