In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data, DataLoader
import torch.optim as optim
import os.path as osp
import scipy.io as sio
from datetime import datetime
from tensorboardX import SummaryWriter
from dataset import QUASARDataset
from model import ModelS

In [2]:
dir = '/home/hank/Datasets/QUASAR/small'
dataset = QUASARDataset(dir,num_graphs=100,remove_self_loops=True)

In [3]:
model   = ModelS(node_feature_mode=3,
                     gnn_type='SAGE',
                     mp_hidden_dim=128,mp_output_dim=128,mp_num_layers=8, 
                     primal_node_mlp_hidden_dim=64,primal_node_mlp_output_dim=10,
                     dual_node_mlp_hidden_dim=64,dual_node_mlp_output_dim=10,
                     node_mlp_num_layers=0,
                     primal_edge_mlp_hidden_dim=64,primal_edge_mlp_output_dim=10, 
                     dual_edge_mlp_hidden_dim=64,dual_edge_mlp_output_dim=6, 
                     edge_mlp_num_layers=0, 
                     dropout_rate=0.0,
                     relu_slope=0.1)
model.load_state_dict(torch.load('./models/model_SAGE_8_128_128_20220505-210509.pth'))
model.double()
model.eval()

Model: node_feature_mode = 3, mp_input_dim = 16, relu_slope = 0.1. GNN type: SAGE.


ModelS(
  (mp_convs): ModuleList(
    (0): SAGEConv(16, 128)
    (1): SAGEConv(128, 128)
    (2): SAGEConv(128, 128)
    (3): SAGEConv(128, 128)
    (4): SAGEConv(128, 128)
    (5): SAGEConv(128, 128)
    (6): SAGEConv(128, 128)
    (7): SAGEConv(128, 128)
    (8): SAGEConv(128, 128)
    (9): SAGEConv(128, 128)
  )
  (primal_node_mlp): ModuleList(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=10, bias=True)
  )
  (dual_node_mlp): ModuleList(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=10, bias=True)
  )
  (primal_edge_mlp): ModuleList(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=10, bias=True)
  )
  (dual_edge_mlp): ModuleList(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=6, bias=True)
  )
)

In [4]:
loader = DataLoader(dataset,batch_size=1,shuffle=True)
for batch in loader:
    xorg = batch.x
    x, X, S, Aty = model(batch)
    primal_loss, dual_loss = model.loss(batch,X,S,Aty)
    loss = primal_loss + dual_loss
    print('batch loss: {:.4f}, primal: {:.4f}, dual: {:.4f}'.format(
                loss.item(),primal_loss.item(),dual_loss.item()))
    X = X[0]
    Aty = Aty[0]
    Xopt = batch.X[0]
    Atyopt = batch.Aty[0]
    print(X)
    print(Aty)
    print(Xopt)
    print(Atyopt)
    break


batch loss: 1.4276, primal: 0.9489, dual: 0.4788
tensor([[ 0.2011,  0.0745, -0.0923,  ..., -0.0093,  0.0368,  0.0552],
        [ 0.0745,  0.1219, -0.0401,  ...,  0.0105, -0.0066,  0.0514],
        [-0.0923, -0.0401,  0.2030,  ..., -0.0066,  0.1044, -0.0700],
        ...,
        [-0.0093,  0.0105, -0.0066,  ...,  0.2140, -0.0444,  0.0487],
        [ 0.0368, -0.0066,  0.1044,  ..., -0.0444,  0.1680, -0.0308],
        [ 0.0552,  0.0514, -0.0700,  ...,  0.0487, -0.0308,  0.4000]],
       dtype=torch.float64, grad_fn=<CatBackward0>)
tensor([[ 5.2864e+03, -4.0884e+01,  2.0559e+02,  ...,  7.5325e-01,
          1.3839e-01, -2.1197e-01],
        [-4.0884e+01,  5.2038e+03,  1.1447e+02,  ...,  0.0000e+00,
          6.5001e-01, -4.0258e-02],
        [ 2.0559e+02,  1.1447e+02,  5.3085e+03,  ..., -6.5001e-01,
          0.0000e+00,  5.0069e-01],
        ...,
        [ 7.5325e-01,  0.0000e+00, -6.5001e-01,  ..., -2.1735e+02,
         -3.7321e+00,  7.1932e+00],
        [ 1.3839e-01,  6.5001e-01,  0.00



In [12]:
a = torch.tensor([1,1,-1])
print(a[:-1])
print(a[-1])
filename = f'./models/model_SAGE_{4}_{128}_{128}_{datetime.now().strftime("%Y%m%d-%H%M%S")}.pth'
print(filename)
torch.device('cuda')

tensor([1, 1])
tensor(-1)
./models/model_SAGE_4_128_128_20220505-175752.pth


device(type='cuda')