In [3]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import collections
import time
import numpy as np

import sys

import os

# Learning Graph Structures
Consider the following Structure Learning Convolution (SLC) of Zhang et al. (2020):

\begin{equation}
  y_i=f\left(\sum_{e_{i,j}\in\varepsilon}S_{ij}w_jx_j\right) \quad (1)
\end{equation}
We introduce a learnable graph structure $S$ whose entries $S_{ij}$ encode the correlation between nodes $n_i$ and $n_j$. In $(1)$, the sum runs over all edges $e_{i,j}$ in the edge set $\varepsilon$, $x_j$ is the embedded feature vector of the $j$-th node, $w_j$ is its weight, $f$ is a nonlinear activation and $y_i$ is the convolved output signal of node $i$.

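Given a graph $G=(V,E)$ with $N$ nodes, graph signal $X\in\mathbb{R}^{N\times D}$ and adjacency matrix $A$, we can use $(1)$ to define a graph convolutional layer based on the symmetrically normalized adjacency matrix $\tilde{A}$. In $(2)$, $D$ denotes the diagonal degree matrix with $D_{ii}=\sum_j A_{ij}$, not the feature dimension: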
\begin{equation}
  \tilde{A}=D^{-1/2}AD^{-1/2} \quad (2)
\end{equation}
\begin{equation}
  H^{(l+1)}=\sigma\left((\tilde{A}\circ S)\,H^{(l)}W^{(l)}\right) \quad (3)
\end{equation}
where $W^{(l)}\in\mathbb{R}^{C_{in}\times C_{out}}$ is a learnable weight matrix, $C_{in}$ is the number of input channels, $C_{out}$ is the number of output channels, $\circ$ denotes the Hadamard (element-wise) product and $\sigma$ is a nonlinear activation. Note that $H^{(0)}=X\in\mathbb{R}^{N\times D}$ is the signal of all nodes of the graph.

## Target of the structure learning
Given a signal $X_t$ at time step $t$ and a graph convolutional network built from equation $(3)$ with $l$ layers $H^{(1)},\dots,H^{(l)}$ and associated filters $W^{(1)},\dots,W^{(l)}$, we try to predict the graph's signal $X_{t+1}$ at time step $t+1$.

In [5]:
# TODO: load the data into a dataset (preferably a time series)

# Generate a data loader

In [None]:
import math

from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
import torch
import torch.nn.functional as F

# Define the Structure Learning Convolution (SLC) layer
class SLConv(Module):
    """Structure Learning Convolution (SLC) layer.

    Computes ``(S * adj) @ (x @ W)``: the input features are linearly
    transformed by the learnable weight ``W``, the adjacency matrix is masked
    element-wise by the structure matrix ``S``, and the masked matrix is used
    to aggregate the transformed features (equation (3) without the
    nonlinearity). The layer does not normalize ``adj`` itself; pass the
    normalized adjacency of equation (2) if that is what you want.

    Args:
        in_chanels: number of input channels C_in per node.
        out_chanels: number of output channels C_out per node.
    """

    def __init__(self, in_chanels, out_chanels):
        super(SLConv, self).__init__()
        self.in_chanels = in_chanels
        self.out_chanels = out_chanels
        # torch.empty replaces the legacy torch.FloatTensor constructor; the
        # values are overwritten by reset_parameters() below.
        self.weight = Parameter(torch.empty(in_chanels, out_chanels, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise W uniformly in [-1/sqrt(C_out), 1/sqrt(C_out))."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        # In-place initialisation must not be tracked by autograd.
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)

    def forward(self, x, adj, S):
        """Apply the layer.

        Args:
            x: node features of shape (N, C_in), or (..., N, C_in) for a batch
               of signals on the same graph.
            adj: adjacency matrix of shape (N, N). An integer 0/1 matrix also
                 works because it is type-promoted when multiplied with ``S``.
            S: structure matrix of shape (N, N).

        Returns:
            Tensor of shape (N, C_out), or (..., N, C_out) for batched ``x``.
        """
        # torch.matmul equals torch.mm for 2-D inputs and broadcasts over
        # leading batch dimensions otherwise.
        support = torch.matmul(x, self.weight)   # (..., N, C_out)
        weighting = torch.mul(S, adj)             # (N, N)
        return torch.matmul(weighting, support)   # (..., N, C_out)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_chanels) + ' -> ' \
               + str(self.out_chanels) + ')'

# Model definition
class GCN(nn.Module):
    """Two-layer graph convolutional network built from SLConv layers.

    A single learnable structure matrix ``S`` of shape (N, N) is shared by both
    layers. ReLU is applied between the layers, and the second layer's output
    is returned without an activation.

    Args:
        nfeat: number of input features per node.
        nhid: number of hidden channels.
        nclass: number of output channels per node.
        N: number of nodes in the graph (side length of ``S``).
    """

    def __init__(self, nfeat, nhid, nclass, N):
        super(GCN, self).__init__()

        self.gc1 = SLConv(nfeat, nhid)
        self.gc2 = SLConv(nhid, nclass)

        # S must be stored as the nn.Parameter itself so that it is registered
        # with the module. Only then does it receive gradients, show up in
        # model.parameters() (and therefore get updated by an optimizer), and
        # follow model.to(device). Assigning `Parameter(...).data.uniform_(...)`
        # would instead store the plain `.data` tensor, silently freezing S.
        self.S = Parameter(torch.empty(N, N, dtype=torch.float32))
        with torch.no_grad():
            self.S.uniform_(-1., 1.)

    def forward(self, x, adj):
        """Run both SLConv layers.

        Args:
            x: node features, assumed shape (N, nfeat).
            adj: (N, N) adjacency matrix; each layer multiplies it
                 element-wise with ``S``.

        Returns:
            Node outputs with ``nclass`` channels per node.
        """
        x = F.relu(self.gc1(x, adj, self.S))
        return self.gc2(x, adj, self.S)

In [None]:
# 3 x 3 toy graph with self-loops (not symmetric)
def normalize_adjacency(adj):
    """Symmetrically normalize an adjacency matrix as in equation (2).

    Returns D^{-1/2} A D^{-1/2} as a float32 tensor, where D is the diagonal
    matrix of row sums of ``adj``. Nodes with zero degree get weight 0 instead
    of producing inf/nan.
    """
    adj = adj.to(torch.float32)
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.
    return deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]


torch.manual_seed(0)  # make W, S and X reproducible across re-runs

adj = torch.tensor([[1, 0, 0], [1, 1, 0], [0, 1, 1]])
adj_norm = normalize_adjacency(adj)

in_chanels = 1
out_chanels = 1
X = torch.rand(adj.size(0), in_chanels)

# The model creates and owns its learnable structure matrix S, so no
# external S is needed here.
model = GCN(in_chanels, 1, out_chanels, adj.size(0))
# Call the module itself rather than .forward() so that registered hooks run.
model(X, adj_norm)

In [39]:
# Untrained forward pass on a 3-node toy graph (raw, unnormalized adjacency).
adj = torch.tensor([[1, 0, 0],
                    [1, 1, 0],
                    [0, 1, 1]])
S = torch.rand(3, 3)  # NOTE: not consumed by the model, which owns its own S
in_chanels, out_chanels = 1, 1
X = torch.rand(3, 1)

# nfeat=1, nhid=1, nclass=1, N=3 nodes
model = GCN(1, 1, 1, 3)
model(X, adj)

tensor([[-0.0081],
        [-0.0740],
        [-0.0450]], grad_fn=<MmBackward>)