In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data, DataLoader
import torch.optim as optim
import os.path as osp
import scipy.io as sio
from dataset import QUASARDataset
from dual_model import DualModelFtype
from dual_train import validate

In [14]:
# Device
# NOTE(review): hard-coded to the second CUDA device; adjust for the machine in use.
DEVICE = torch.device('cuda:1')
# Model hyper-parameters. They must describe the same architecture the checkpoint
# below was trained with: load_state_dict is strict by default and raises on any
# missing/unexpected key or tensor-shape mismatch.
GNN_TYPE = 'SAGE'
GNN_HIDDEN_DIM = 64
GNN_OUT_DIM = GNN_HIDDEN_DIM
GNN_LAYER = 15
NODE_MODE = 1
DATA_GRAPH_TYPE = 1  # NOTE(review): not passed to the model or dataset in this cell
DROPOUT = 0.2
MLP_LAYER = 2
RESIDUAL = True
BATCHNORM = True
DUAL_NODE_OUT_DIM = 10  # output width of the per-node dual MLP
DUAL_OUT_DIM = 6  # output width of the per-edge dual MLP; also read by the evaluation cell
RELU_SLOPE = 0.1
# Alternative checkpoint: './models/good/dual_model_N30-5000_SAGE_31_64_2_6_True_True.pth'
CHECKPOINT = './models/good/dual_model_N50-5000_SAGE_15_64_2_6_True_True_1_1201.pth'

model = DualModelFtype(
    node_feature_mode=NODE_MODE,
    gnn_type=GNN_TYPE,
    mp_hidden_dim=GNN_HIDDEN_DIM,
    mp_output_dim=GNN_OUT_DIM,
    mp_num_layers=GNN_LAYER,
    dual_node_mlp_hidden_dim=GNN_OUT_DIM,
    dual_node_mlp_output_dim=DUAL_NODE_OUT_DIM,
    node_mlp_num_layers=MLP_LAYER,
    dual_edge_mlp_hidden_dim=GNN_OUT_DIM,
    dual_edge_mlp_output_dim=DUAL_OUT_DIM,
    edge_mlp_num_layers=MLP_LAYER,
    dropout_rate=DROPOUT,
    relu_slope=RELU_SLOPE,
    residual=RESIDUAL,
    batchnorm=BATCHNORM)
# Deserialize the weights onto the CPU (where `model` lives right now) so loading works
# regardless of which device the checkpoint was saved from; the model is then cast to
# float64 and moved to DEVICE.
model.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'))
model.double()
model.to(DEVICE)

# Dataset
setname = 'N100-100'
num_graphs = 100
# NOTE(review): `dir` shadows the built-in dir(); the name is kept because later cells may
# reference it. The absolute path is machine-specific.
dir = f'/home/hank/Datasets/QUASAR/{setname}'
data = QUASARDataset(dir, num_graphs=num_graphs, remove_self_loops=True)
# One graph per batch, in dataset order, so per-graph results collected later line up
# with the dataset's graph order.
loader = DataLoader(data, batch_size=1, shuffle=False)

def rel_norm_err(A, B):
    """Relative Frobenius-norm error ``||A - B||_F / ||B||_F``.

    Args:
        A: estimate tensor.
        B: reference tensor, same shape as ``A``.

    Returns:
        0-dim tensor holding the relative error. If ``B`` is all zeros the
        result is ``inf`` (or ``nan`` when ``A`` is also zero), following IEEE
        floating-point division.
    """
    # vector_norm over all elements (dim=None, ord=2) equals the Frobenius norm
    # for matrices; it replaces the deprecated torch.norm(..., p='fro').
    return torch.linalg.vector_norm(A - B) / torch.linalg.vector_norm(B)

Model: node_feature_mode = 1, mp_input_dim = 6, relu_slope = 0.1. GNN type: SAGE. Residual: True. BatchNorm: True.
Data graph type: 1.


In [15]:
# Run the trained model over every graph in `loader`, compare the recovered dual
# quantities with the ground truth stored on each batch, and save the predictions
# plus per-graph relative errors to a .mat file.
#
# How `sol[0]` (the first output of model.recover) is interpreted depends on DUAL_OUT_DIM:
#   6  -> sol[0] is A^T y, and S is derived as C - A^T y
#   16 -> sol[0] is S,     and A^T y is derived as C - S
# Any other value is rejected up front (previously it skipped evaluation and then
# crashed with a NameError on `fname` at savemat time).
PRED_TARGET_BY_DIM = {6: 'Aty', 16: 'S'}
OUT_FNAME_BY_DIM = {
    6: f'{setname}_dual_sol_Aty.mat',
    16: f'{setname}_dual_sol_S.mat',
}

dual_pred = []
S_norm_err = []
Aty_norm_err = []
print(f'DUAL_OUT_DIM: {DUAL_OUT_DIM}.')
if DUAL_OUT_DIM not in PRED_TARGET_BY_DIM:
    raise ValueError(
        f'Unsupported DUAL_OUT_DIM={DUAL_OUT_DIM}; '
        f'expected one of {sorted(PRED_TARGET_BY_DIM)}.')
pred_target = PRED_TARGET_BY_DIM[DUAL_OUT_DIM]

model.eval()
with torch.no_grad():
    for batch in loader:
        batch = batch.to(DEVICE)
        _, V, E = model(batch)
        sol = model.recover(V, E, batch)
        # NOTE(review): index 0 assumes one graph per batch (loader uses batch_size=1).
        C = torch.tensor(batch.C[0], dtype=torch.float64, device=sol[0].device)
        if pred_target == 'Aty':
            Aty_hat = sol[0]
            S_hat = C - Aty_hat
        else:
            S_hat = sol[0]
            Aty_hat = C - S_hat
        Aty_gt = torch.tensor(batch.Aty[0], dtype=torch.float64, device=Aty_hat.device)
        S_gt = torch.tensor(batch.S[0], dtype=torch.float64, device=Aty_hat.device)
        err_Aty = rel_norm_err(Aty_hat, Aty_gt)
        err_S = rel_norm_err(S_hat, S_gt)
        print('Aty norm err: {:.4f}, S norm err: {:.4f}.'.format(err_Aty, err_S))
        # dual_pred always stores A^T y, whichever quantity the model predicts directly.
        dual_pred.append(Aty_hat.cpu().numpy())
        S_norm_err.append(err_S.cpu().numpy())
        Aty_norm_err.append(err_Aty.cpu().numpy())

S_norm_err = np.array(S_norm_err)
Aty_norm_err = np.array(Aty_norm_err)
mdict = {"dual_pred": dual_pred, "S_norm_err": S_norm_err, "Aty_norm_err": Aty_norm_err}
fname = OUT_FNAME_BY_DIM[DUAL_OUT_DIM]
sio.savemat(fname, mdict)

DUAL_OUT_DIM: 6.
Aty norm err: 0.5077, S norm err: 0.5045.
Aty norm err: 0.5101, S norm err: 0.5056.
Aty norm err: 0.5221, S norm err: 0.5162.
Aty norm err: 0.5187, S norm err: 0.5135.
Aty norm err: 0.5229, S norm err: 0.5194.
Aty norm err: 0.5085, S norm err: 0.5040.
Aty norm err: 0.5292, S norm err: 0.5251.
Aty norm err: 0.5215, S norm err: 0.5171.
Aty norm err: 0.5051, S norm err: 0.4999.
Aty norm err: 0.5063, S norm err: 0.5012.
Aty norm err: 0.5078, S norm err: 0.5041.
Aty norm err: 0.5057, S norm err: 0.5015.
Aty norm err: 0.5221, S norm err: 0.5178.
Aty norm err: 0.4989, S norm err: 0.4947.
Aty norm err: 0.5184, S norm err: 0.5142.
Aty norm err: 0.5110, S norm err: 0.5067.
Aty norm err: 0.5351, S norm err: 0.5315.
Aty norm err: 0.5492, S norm err: 0.5443.
Aty norm err: 0.5194, S norm err: 0.5159.
Aty norm err: 0.5035, S norm err: 0.4995.
Aty norm err: 0.5053, S norm err: 0.5012.
Aty norm err: 0.4949, S norm err: 0.4903.
Aty norm err: 0.5273, S norm err: 0.5223.
Aty norm err: 0.5