In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data
import os.path as osp
import scipy.io as sio
from dataset import QUASARDataset

In [30]:
dir = '/Users/hankyang/Datasets/QUASAR'
dataset = QUASARDataset(dir)
data = dataset[0]

Processing...
Done!


In [31]:
class ModelS(nn.Module):
    def __init__(self, mp_input_dim=6,
                       mp_hidden_dim=32,
                       mp_output_dim=64,
                       mp_num_layers=1, 
                       primal_node_mlp_hidden_dim=64,
                       primal_node_mlp_output_dim=10,
                       dual_node_mlp_hidden_dim=64,
                       dual_node_mlp_output_dim=10,
                       node_mlp_num_layers=1,
                       primal_edge_mlp_hidden_dim=64, 
                       primal_edge_mlp_output_dim=10, 
                       dual_edge_mlp_hidden_dim=64, 
                       dual_edge_mlp_output_dim=16, 
                       edge_mlp_num_layers=1, 
                       dropout_rate=0.2):
        super(ModelS,self).__init__()
        # Message passing
        self.mp_convs = nn.ModuleList()
        self.mp_convs.append(pyg_nn.SAGEConv(mp_input_dim,mp_hidden_dim))
        for i in range(mp_num_layers):
            self.mp_convs.append(pyg_nn.SAGEConv(mp_hidden_dim,mp_hidden_dim))
        self.mp_convs.append(pyg_nn.SAGEConv(mp_hidden_dim,mp_output_dim))

        # Post message passing
        # Primal node MLP
        self.primal_node_mlp = nn.ModuleList()
        self.primal_node_mlp.append(
            nn.Linear(mp_output_dim,primal_node_mlp_hidden_dim,dtype=torch.float64))
        for i in range(node_mlp_num_layers):
            self.primal_node_mlp.append(
                nn.Linear(primal_node_mlp_hidden_dim,primal_node_mlp_hidden_dim,dtype=torch.float64))
        self.primal_node_mlp.append(
            nn.Linear(primal_node_mlp_hidden_dim,primal_node_mlp_output_dim,dtype=torch.float64))
        # Dual node MLP
        self.dual_node_mlp = nn.ModuleList()
        self.dual_node_mlp.append(
            nn.Linear(mp_output_dim,dual_node_mlp_hidden_dim,dtype=torch.float64))
        for i in range(node_mlp_num_layers):
            self.dual_node_mlp.append(
                nn.Linear(dual_node_mlp_hidden_dim,dual_node_mlp_hidden_dim,dtype=torch.float64))
        self.dual_node_mlp.append(
            nn.Linear(dual_node_mlp_hidden_dim,dual_node_mlp_output_dim,dtype=torch.float64))
        # Primal edge MLP
        self.primal_edge_mlp = nn.ModuleList()
        self.primal_edge_mlp.append(
            nn.Linear(mp_output_dim,primal_edge_mlp_hidden_dim,dtype=torch.float64))
        for i in range(edge_mlp_num_layers):
            self.primal_edge_mlp.append(
                nn.Linear(primal_edge_mlp_hidden_dim,primal_edge_mlp_hidden_dim,dtype=torch.float64))
        self.primal_edge_mlp.append(
            nn.Linear(primal_edge_mlp_hidden_dim,primal_edge_mlp_output_dim,dtype=torch.float64))
        # Dual edge MLP
        self.dual_edge_mlp = nn.ModuleList()
        self.dual_edge_mlp.append(
            nn.Linear(mp_output_dim,dual_edge_mlp_hidden_dim,dtype=torch.float64))
        for i in range(edge_mlp_num_layers):
            self.dual_edge_mlp.append(
                nn.Linear(dual_edge_mlp_hidden_dim,dual_edge_mlp_hidden_dim,dtype=torch.float64))
        self.dual_edge_mlp.append(
            nn.Linear(dual_edge_mlp_hidden_dim,dual_edge_mlp_output_dim,dtype=torch.float64))
        self.dropout_rate = dropout_rate

    def forward(self,data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        num_nodes = data.num_nodes
        ud_edges  = data.ud_edges
        edge_map  = data.edge_map
        # Message passing
        for mp_layer in self.mp_convs:
            x = mp_layer(x,edge_index)
            x = F.relu(x)
            x = F.dropout(x,p=self.dropout_rate,training=self.training)
        
        # Post message passing
        # Primal node
        vp = []
        for i in range(num_nodes):
            xi = x[i,:] # feature of i-th node
            for mlp_layer in self.primal_node_mlp:
                xi = mlp_layer(xi)
                xi = F.relu(xi)
                xi = F.dropout(xi,p=self.dropout_rate,training=self.training)
            vp.append(xi)
        vp = torch.stack(vp) # num_nodes x primal_node_mlp_output_dim
        # Dual node
        vd = []
        for i in range(num_nodes):
            xi = x[i,:]
            for mlp_layer in self.dual_node_mlp:
                xi = mlp_layer(xi)
                xi = F.relu(xi)
                xi = F.dropout(xi,p=self.dropout_rate,training=self.training)
            vd.append(xi)
        vd = torch.stack(vd) # num_nodes x dual_node_mlp_output_dim
        # Primal edge
        ep = []
        for edge in ud_edges:
            xi  = x[edge[0],:]
            xj  = x[edge[1],:]
            xij = xi + xj
            for mlp_layer in self.primal_edge_mlp:
                xij = mlp_layer(xij)
                xij = F.relu(xij)
                xij = F.dropout(xij,p=self.dropout_rate,training=self.training)
            ep.append(xij)
        ep = torch.stack(ep)
        # Dual edge
        ed = []
        for edge in ud_edges:
            xi  = x[edge[0],:]
            xj  = x[edge[1],:]
            xij = xi + xj
            for mlp_layer in self.dual_edge_mlp:
                xij = mlp_layer(xij)
                xij = F.relu(xij)
                xij = F.dropout(xij,p=self.dropout_rate,training=self.training)
            ed.append(xij)
        ed = torch.stack(ed)

        # Recover primal X
        X = recover_X(vp,ep,edge_map)
        print(X[0,0].grad_fn)
        print(x.grad_fn)
        return x, X

    def smat(self,x):
        X = torch.tensor([[x[0],x[1],x[2],x[3]],
                          [x[1],x[4],x[5],x[6]], 
                          [x[2],x[5],x[7],x[8]], 
                          [x[3],x[6],x[8],x[9]]],dtype=torch.float64)
        return X

    def mat(self,x,n):
        return x.view((n,n))

    def recover_X(self,vp,ep,edge_map):
        N = vp.shape[0] # number of nodes
        rows = []
        for i in range(N):
            row = []
            for j in range(N):
                if i == j: # diagonal blocks, using node features vp
                    blk = smat(vp[i,:])
                else: # off-diagonal blocks, using edge features ep
                    edge_id = edge_map[i,j]
                    blk = smat(ep[edge_id,:])
                row.append(blk)
            row_mat = torch.cat(row,dim=1)
            rows.append(row_mat)
        X = torch.cat(rows,dim=0)
        return X

    

In [32]:
# Test evaluate model
model = ModelS()
model.double()
model.eval()
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name)
#         print(param.data)
x, X = model(data)
print(torch.norm(X-X.t(),p='fro'))
print(X)

None
<ReluBackward0 object at 0x7fc060500430>
tensor(0., dtype=torch.float64)
tensor([[0.0000, 0.0300, 0.0000,  ..., 0.0000, 0.0974, 0.0959],
        [0.0300, 0.0986, 0.0489,  ..., 0.1042, 0.0000, 0.0460],
        [0.0000, 0.0489, 0.0000,  ..., 0.0000, 0.0000, 0.0578],
        ...,
        [0.0000, 0.1042, 0.0000,  ..., 0.1001, 0.0482, 0.1189],
        [0.0974, 0.0000, 0.0000,  ..., 0.0482, 0.0000, 0.1006],
        [0.0959, 0.0460, 0.0578,  ..., 0.1189, 0.1006, 0.1263]],
       dtype=torch.float64)
