# Example code for running GISST
This notebook is a quick tutorial on the APIs in this repository. It trains a node-classification GISST model and then obtains node-feature and edge interpretation results for a single node. The other models and interpretation methods follow the same API design.

Dataset splitting for cross-validation, model evaluation, and hyperparameter tuning are intentionally omitted here for simplicity. For details, refer to the source code, e.g. `synthesize_graph.py` and `train_model.py`.

In [1]:
import random

import numpy as np
import torch
from torch_geometric.data import Data

from synthesize_graph import syn_ba_house, syn_node_feat
from sig.utils.synthetic_graph import build_house, build_cycle
from sig.utils.pyg_utils import get_pyg_edge_index
from sig.nn.loss.regularization_loss import reg_sig_loss
from sig.nn.loss.classification_loss import cross_entropy_loss
from sig.nn.models.sigcn import SIGCN
from sig.explainers.sig_explainer import SIGExplainer

### 1. Prepare the node classification graph
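First, prepare a graph using `networkx`. Here the graph comes from the synthetic generator `syn_ba_house` in `synthesize_graph.py`, and `syn_node_feat` generates the node features. Store the node class labels as a list and the node feature matrix as a NumPy array.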

In [2]:
# Seed every global RNG up front so that Restart & Run All reproduces the
# synthetic graph, the node features and torch-side randomness such as the
# model initialization further down.
# NOTE(review): assumes the generators in synthesize_graph.py draw from the
# global `random` / NumPy state rather than taking their own seed — confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

graph, labels, _ = syn_ba_house()
node_feat, _ = syn_node_feat(labels, sigma_scale=0.1)
# Number of distinct class labels that occur in the graph.
num_class = len(set(labels))

In [3]:
# Show the Python type of each input object, one per line.
print(
    'graph: {}'.format(type(graph)),
    'labels: {}'.format(type(labels)),
    'node_feat: {}'.format(type(node_feat)),
    sep='\n',
)

graph: <class 'networkx.classes.graph.Graph'>
labels: <class 'list'>
node_feat: <class 'numpy.ndarray'>


In [4]:
# Sanity-check sizes: the label list and feature rows should match the node count.
n_nodes = graph.number_of_nodes()
n_edges = graph.number_of_edges()
print(
    'graph num nodes: {}'.format(n_nodes),
    'graph num edges: {}'.format(n_edges),
    'labels len: {}'.format(len(labels)),
    'node_feat shape: {}'.format(node_feat.shape),
    sep='\n',
)

graph num nodes: 700
graph num edges: 2238
labels len: 700
node_feat shape: (700, 50)


### 2. Format the graph data for PyTorch Geometric
Under the hood, the APIs are implemented using PyTorch Geometric. Hence the graph data should be converted into an instance of `torch_geometric.data.Data` (https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs). As a side note, for graph-level classification, the graphs should be converted into an instance of `torch_geometric.data.Dataset` or a list of `torch_geometric.data.Data` for mini-batching via a data loader (https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#mini-batches).

In [5]:
# Pack the graph into a PyG `Data` object:
#   x          - node feature matrix as float32 (one row per node)
#   y          - node class labels as int64
#   edge_index - connectivity produced by get_pyg_edge_index (shape [2, num_edges])
# `torch.tensor` with an explicit dtype replaces the legacy `torch.Tensor` /
# `torch.LongTensor` constructors; the resulting dtypes are unchanged.
data = Data(
    x=torch.tensor(node_feat, dtype=torch.float),
    y=torch.tensor(labels, dtype=torch.long),
    edge_index=get_pyg_edge_index(graph),
)

In [6]:
# Compact summary of the PyG object: each attribute with its tensor shape.
print(str(data))

Data(edge_index=[2, 4476], x=[700, 50], y=[700])


### 3. Initialize the model and optimizer

In [7]:
# GISST node classifier. hidden_conv_sizes / hidden_dropout_probs presumably
# hold one entry per hidden layer (here two layers of width 8, dropout off).
model = SIGCN(
    input_size=data.x.shape[1],  # number of node features
    output_size=num_class,
    hidden_conv_sizes=(8, 8),
    hidden_dropout_probs=(0, 0),
)

# Adam; weight_decay adds an L2 penalty on the parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=0.001)

### 4. Train the model

In [8]:
def train(reg_coeffs=None):
    """Run one full-batch optimization step of the GISST model.

    Uses the notebook-level globals ``model``, ``optimizer`` and ``data``.
    The objective is the node-classification cross-entropy plus the four
    regularization terms returned by ``reg_sig_loss`` on the node-feature
    and edge probabilities (presumably L1 and entropy penalties, judging by
    the coefficient keys).

    Args:
        reg_coeffs: optional dict with keys 'x_l1', 'x_ent', 'edge_l1' and
            'edge_ent' giving the weight of each regularization term. Defaults
            to {'x_l1': 0.01, 'x_ent': 0.05, 'edge_l1': 0.01, 'edge_ent': 0.05}.

    Returns:
        float: value of the total loss for this step, so callers can monitor
        convergence.
    """
    if reg_coeffs is None:
        # Built per call so a caller can never mutate a shared default.
        reg_coeffs = {
            'x_l1': 0.01,
            'x_ent': 0.05,
            'edge_l1': 0.01,
            'edge_ent': 0.05
        }

    model.train()
    optimizer.zero_grad()
    # return_probs=True also yields the node-feature and edge probabilities
    # that the regularizers act on.
    out, x_prob, edge_prob = model(
        data.x,
        data.edge_index,
        return_probs=True
    )
    loss_x_l1, loss_x_ent, loss_edge_l1, loss_edge_ent = reg_sig_loss(
        x_prob,
        edge_prob,
        coeffs=reg_coeffs
    )
    loss = (
        cross_entropy_loss(out, data.y)
        + loss_x_l1 + loss_x_ent + loss_edge_l1 + loss_edge_ent
    )
    # NOTE(review): retain_graph=True looks unnecessary because each call runs
    # a fresh forward pass; kept in case SIGCN reuses autograd state across
    # calls. Confirm before removing — retaining the graph costs memory.
    loss.backward(retain_graph=True)
    optimizer.step()
    return loss.item()

In [9]:
NUM_EPOCHS = 50  # total optimization steps (one full-batch step per epoch)
LOG_EVERY = 10   # report progress every LOG_EVERY epochs

for epoch in range(1, NUM_EPOCHS + 1):
    train()
    if not epoch % LOG_EVERY:
        print('Finished training for %d epochs' % epoch)

Finished training for 10 epochs
Finished training for 20 epochs
Finished training for 30 epochs
Finished training for 40 epochs
Finished training for 50 epochs


### 5. Run the explainer on a node

In [10]:
# Wrap the trained model in an explainer; its explain_node method is queried below.
explainer = SIGExplainer(model)

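#### 5.1 Node feature and edge probability
With `use_grad=False`, `explain_node` returns the model's node-feature and edge probabilities for the chosen node (`node_index=15`).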

In [11]:
# Probability-based explanation for node 15. The gradient-related arguments
# (y, loss_fn, take_abs, pred_for_grad) are presumably ignored when
# use_grad=False; they are passed explicitly for clarity.
node_feat_prob, edge_prob = explainer.explain_node(
    x=data.x,
    edge_index=data.edge_index,
    node_index=15,
    use_grad=False,
    take_abs=False,
    pred_for_grad=False,
    y=None,
    loss_fn=None,
)

In [12]:
# One probability per input feature, followed by the values themselves.
print(
    'node_feat_prob shape: {}'.format(node_feat_prob.shape),
    node_feat_prob,
    sep='\n',
)

node_feat_prob shape: torch.Size([50])
tensor([0.6073, 0.5938, 0.5674, 0.6318, 0.6238, 0.6164, 0.6025, 0.5960, 0.5018,
        0.5844, 0.6137, 0.6091, 0.5878, 0.5704, 0.6085, 0.5630, 0.6246, 0.5523,
        0.6082, 0.5521, 0.5457, 0.6023, 0.5746, 0.6129, 0.5883, 0.5937, 0.4954,
        0.6117, 0.6076, 0.6095, 0.5415, 0.5882, 0.5887, 0.5526, 0.5419, 0.5848,
        0.5249, 0.5669, 0.5302, 0.6130, 0.5383, 0.5244, 0.5030, 0.5338, 0.5114,
        0.5007, 0.5078, 0.5055, 0.4957, 0.5635], grad_fn=<ClampBackward>)


In [13]:
# One value per entry of data.edge_index, followed by the (truncated) tensor.
print(
    'edge_prob shape: {}'.format(edge_prob.shape),
    edge_prob,
    sep='\n',
)

edge_prob shape: torch.Size([4476])
tensor([1.0000e-05, 1.0000e-05, 1.0000e-05,  ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00], grad_fn=<IndexPutBackward>)


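#### 5.2 Node feature and edge probability gradient
With `use_grad=True` (plus the labels `y` and a `loss_fn`), `explain_node` returns gradient-based scores instead of the probabilities from 5.1. Unlike probabilities, these scores can be negative, as the output below shows.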

In [14]:
# Gradient-based explanation for node 15, using the true labels and the
# training loss function; signed scores are kept (take_abs=False).
node_feat_score, edge_score = explainer.explain_node(
    x=data.x,
    edge_index=data.edge_index,
    node_index=15,
    use_grad=True,
    pred_for_grad=True,
    y=data.y,
    loss_fn=cross_entropy_loss,
    take_abs=False,
)

In [15]:
# One signed score per input feature, followed by the values themselves.
print(
    'node_feat_score shape: {}'.format(node_feat_score.shape),
    node_feat_score,
    sep='\n',
)

node_feat_score shape: torch.Size([50])
tensor([ 1.8113e-07,  1.7804e-07,  1.3067e-07,  1.4501e-07,  9.5189e-08,
         2.6519e-07,  8.1901e-08,  9.8601e-08, -3.8628e-09,  1.9757e-07,
         1.3460e-07,  1.6377e-07,  1.0515e-07,  1.6313e-07,  1.9874e-07,
         4.8434e-08,  2.7607e-07,  1.5328e-07,  3.0147e-07,  1.1535e-07,
         2.2737e-08,  9.6002e-08,  1.7883e-07,  1.8335e-07,  1.4471e-07,
         3.5769e-08, -1.9045e-07,  9.6916e-08,  2.2862e-07,  2.6558e-07,
         3.1914e-08,  7.0649e-08,  1.7684e-07,  2.5435e-07,  1.2272e-07,
         2.2647e-08, -7.0268e-09, -7.1459e-08,  8.7770e-08,  2.1514e-07,
         4.7412e-10, -4.0880e-09, -1.6669e-09, -1.2434e-08,  4.4608e-09,
        -1.2842e-08,  1.9108e-09, -1.0655e-08, -1.7267e-08, -6.8925e-09])


In [16]:
# One signed score per entry of data.edge_index, followed by the (truncated) tensor.
print(
    'edge_score shape: {}'.format(edge_score.shape),
    edge_score,
    sep='\n',
)

edge_score shape: torch.Size([4476])
tensor([-4.8283e-11, -4.0702e-11, -4.8302e-11,  ...,  0.0000e+00,
         0.0000e+00,  0.0000e+00])
