In [13]:
import torch
from torch_geometric.datasets import TUDataset
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.logging import init_wandb, log
from torch_geometric.nn import GATConv, GCNConv
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from torch.autograd import Variable

In [14]:
# Load the MUTAG graph-classification benchmark, stored under data/TUDataset.
dataset = TUDataset(name='MUTAG', root='data/TUDataset')

In [15]:
# --- Collection-level summary ------------------------------------------------
print()
print(f'Dataset: {dataset}:')
print('====================')
print(
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

data = dataset[0]  # First graph of the collection; inspected below.

# --- Summary of that single graph --------------------------------------------
print()
print(data)
print('=============================================================')
print(
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)


Dataset: MUTAG(188):
Number of graphs: 188
Number of features: 7
Number of classes: 2

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [16]:
# Seed the global RNG so the shuffle below yields the same permutation each run.
# NOTE: `dataset` is rebound to its shuffled version, so re-running this cell
# shuffles an already shuffled dataset.
torch.manual_seed(12345)
dataset = dataset.shuffle()

# First 150 graphs are used for training, the remainder for testing.
train_dataset, test_dataset = dataset[:150], dataset[150:]
print(
    f'Number of training graphs: {len(train_dataset)}',
    f'Number of test graphs: {len(test_dataset)}',
    sep='\n',
)

Number of training graphs: 150
Number of test graphs: 38


In [17]:
from torch_geometric.loader import DataLoader

# Mini-batch loaders over the two splits; only the training loader reshuffles.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=64,
)
test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=64,
)

In [18]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    """Three-layer GCN graph classifier with a mean-pool readout.

    The input feature width and the number of output classes are read from the
    notebook-level ``dataset`` at construction time.

    Args:
        hidden_channels: Output width of every ``GCNConv`` layer, and therefore
            the width of the pooled graph embedding.
    """

    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        # Seed the global RNG right before the layers are created so that the
        # weight initialisation is repeatable.
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def _embed(self, x, edge_index, batch, edge_weight=None):
        """Run the three conv layers and mean-pool nodes into one row per graph."""
        # 1. Obtain node embeddings (no non-linearity after the last layer).
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout layer
        return global_mean_pool(x, batch)  # [batch_size, hidden_channels]

    def _classify(self, embedding):
        """Map pooled graph embeddings to class logits (dropout only in training mode)."""
        # 3. Apply a final classifier
        x = F.dropout(embedding, p=0.5, training=self.training)
        return self.lin(x)

    def get_embedding_outputs(self, data):
        """Return ``(embedding, logits)`` for a (batched) graph object.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``. When it
                also carries a non-``None`` ``edge_weight`` attribute, those
                weights are forwarded to every conv layer; otherwise edges
                are unweighted.

        Returns:
            Tuple of the pooled graph embedding (taken before dropout) and the
            class logits computed from it.
        """
        edge_weight = getattr(data, 'edge_weight', None)
        embedding = self._embed(data.x, data.edge_index, data.batch, edge_weight)
        return embedding, self._classify(embedding)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return class logits for a batch of graphs.

        Args:
            x: Node feature matrix.
            edge_index: Edge list passed to the conv layers.
            batch: Node-to-graph assignment used by the mean-pool readout.
            edge_weight: Optional per-edge weights forwarded to every conv
                layer; ``None`` (the default) keeps edges unweighted.
        """
        return self._classify(self._embed(x, edge_index, batch, edge_weight))

In [19]:
# Build the classifier together with its loss function and optimiser.
model = GCN(hidden_channels=64)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
print(model)

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [20]:
def train():
    """Run one optimisation epoch over ``train_loader``.

    Relies on the notebook-level ``model``, ``criterion`` and ``optimizer``.
    Puts the model in training mode and performs one optimiser step per batch.
    """
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        # Clear gradients *before* the forward/backward pass: anything left on
        # the parameters by an earlier backward call made outside this function
        # must not leak into this update.
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


In [21]:
# Train for four epochs, reporting accuracy on both splits after each one.
for epoch in range(1, 5):
    train()
    train_acc, test_acc = test(train_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368


In [22]:
class ProbGraph:
    """Learnable probabilistic graph that can be relaxed into differentiable samples.

    Latent parameters (all float32 leaf tensors with ``requires_grad=True``):

    * ``Omega`` -- edge logits, initialised from ``to_dense_adj(data.edge_index)``
      and kept in the shape that call returns.
    * ``H``     -- edge-attribute logits, initialised from ``data.edge_attr``
      (only created when the graph has edge attributes).
    * ``Xi``    -- node-feature logits, initialised from ``data.x``
      (only created when the graph has node features).

    Args:
        data: Graph object exposing ``edge_index``, ``edge_attr`` and ``x``
            (the latter two may be ``None``).
    """

    def __init__(self, data):
        self.has_edge_attributes = data.edge_attr is not None
        self.has_node_attributes = data.x is not None
        self.Omega = self._as_parameter(to_dense_adj(data.edge_index))
        self.params = [self.Omega]
        if self.has_edge_attributes:
            self.H = self._as_parameter(data.edge_attr)
            self.params.append(self.H)
        if self.has_node_attributes:
            self.Xi = self._as_parameter(data.x)
            self.params.append(self.Xi)

    @staticmethod
    def _as_parameter(values):
        """Return an independent float32 leaf copy of ``values`` that tracks gradients.

        The copy matters: the parameters are updated in place by an optimiser,
        and that must never write through to the tensors of the source graph.
        """
        return values.detach().clone().to(torch.float32).requires_grad_(True)

    @staticmethod
    def _uniform_noise(like):
        """Sample U(0, 1) noise shaped like ``like``, kept strictly inside (0, 1).

        ``torch.rand`` may return exactly 0, which would turn the ``log`` terms
        of the relaxations below into ``-inf``.
        """
        tiny = torch.finfo(like.dtype).eps
        return torch.rand_like(like).clamp_(tiny, 1 - tiny)

    def parameters(self):
        """Return the list of learnable tensors (``Omega`` and, if present, ``H`` and ``Xi``)."""
        return self.params

    def sample_train(self, K, tau_a, tau_z, tau_x):
        """Draw ``K`` relaxed (differentiable) graph samples and batch them.

        Args:
            K: Number of Monte Carlo samples.
            tau_a: Temperature of the edge relaxation.
            tau_z: Temperature of the edge-attribute relaxation.
            tau_x: Temperature of the node-feature relaxation.

        Returns:
            ``Batch.from_data_list`` of ``K`` ``Data`` objects. Each carries
            ``edge_index`` and ``edge_weight`` (the relaxed edge values, which
            keep the gradient path to ``Omega``), plus ``edge_attr`` and/or
            ``x`` when the initial graph had them.
        """
        sampled_graphs = []
        for _ in range(K):
            sampled_graph = dict()

            # Edges: sigmoid of the logits perturbed with logistic noise.
            a_epsilon = self._uniform_noise(self.Omega)
            a = torch.sigmoid((self.Omega + torch.log(a_epsilon) - torch.log(1 - a_epsilon)) / tau_a)
            # The edge index alone is an integer tensor and carries no
            # gradient; the relaxed values are kept as edge weights so that a
            # consumer using them can back-propagate into Omega.
            edge_index, edge_weight = dense_to_sparse(a)
            sampled_graph["edge_index"] = edge_index
            sampled_graph["edge_weight"] = edge_weight

            if self.has_edge_attributes:
                # Gumbel-softmax over the attribute categories (last axis), so
                # every edge gets its own distribution.
                # NOTE(review): H has one row per edge of the *initial* graph,
                # while edge_index above enumerates every non-zero entry of the
                # relaxed adjacency, so the row counts generally differ --
                # TODO reconcile before a model consumes edge_attr.
                z_epsilon = self._uniform_noise(self.H)
                z = torch.softmax((self.H - torch.log(-torch.log(z_epsilon))) / tau_z, dim=-1)
                sampled_graph["edge_attr"] = z

            if self.has_node_attributes:
                # Gumbel-softmax over the feature categories (last axis), so
                # every node gets its own distribution.
                x_epsilon = self._uniform_noise(self.Xi)
                x = torch.softmax((self.Xi - torch.log(-torch.log(x_epsilon))) / tau_x, dim=-1)
                sampled_graph["x"] = x

            sampled_graphs.append(Data(**sampled_graph))

        return Batch.from_data_list(sampled_graphs)

    def sample_explanation(self):
        """Sample a discrete explanation graph -- not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        # Sigma = torch.sigmoid(self.Omega)
        # P = torch.nn.Softmax(self.Xi)
        # Q = torch.nn.Softmax(self.H)
        raise NotImplementedError

class GNNInterpreter():
    """Learns a probabilistic graph that a trained model scores highly for a class.

    Args:
        get_embedding_outputs: Callable mapping a (batched) graph object to a
            ``(graph_embeddings, class_scores)`` pair.
        train_loader: Despite its name, this is handed to ``get_average_phi``
            as a *dataset*: it must expose ``num_classes`` and be accepted by
            ``DataLoader``.
    """

    def __init__(self, get_embedding_outputs, train_loader):
        super().__init__()

        # Temperatures of the edge / edge-attribute / node-feature relaxations.
        self.tau_a = 5
        self.tau_z = 5
        self.tau_x = 5
        # Weight of the embedding-similarity term in the objective.
        self.mu = 1

        self.K = 10 # Monte Carlo # Samples

        self.get_embedding_outputs = get_embedding_outputs

        self.average_phi = self.get_average_phi(train_loader).detach()

    def get_average_phi(self, dataset):
        """Return the per-class mean graph embedding over ``dataset``.

        Args:
            dataset: Collection of labelled graphs exposing ``num_classes``.

        Returns:
            Tensor with one row per class holding the mean embedding of the
            graphs with that label. A class without any graph gets a zero row
            instead of NaN.

        Raises:
            ValueError: If ``dataset`` yields no graphs.
        """
        dataloader = DataLoader(dataset, batch_size=1)
        embedding_sum = None
        n_instances = None
        # Statistics only -- no gradients are needed for the class averages.
        with torch.no_grad():
            for batch in dataloader:
                embeddings = self.get_embedding_outputs(batch)[0]
                if embedding_sum is None:
                    embedding_sum = embeddings.new_zeros(dataset.num_classes, embeddings.shape[-1])
                    n_instances = embeddings.new_zeros(dataset.num_classes)
                labels = batch.y.view(-1)
                # index_add_ accumulates correctly even when a label repeats
                # inside one batch, so this stays right if batch_size is raised.
                embedding_sum.index_add_(0, labels, embeddings)
                n_instances.index_add_(0, labels, torch.ones_like(labels, dtype=n_instances.dtype))
        if embedding_sum is None:
            raise ValueError("cannot average embeddings over an empty dataset")
        return embedding_sum / n_instances.clamp(min=1).unsqueeze(1)

    def regularizer(self):
        """Penalty added to the minimised loss; currently a stub returning 0."""
        return 0

    def train(self, init_graph, class_index, max_iter=1000):
        """Optimise a ``ProbGraph`` so its samples score highly for ``class_index``.

        The objective being *maximised* is the mean, over ``K`` relaxed samples,
        of the model's score for ``class_index`` plus ``mu`` times the cosine
        similarity between the sample embedding and the class-average embedding.
        Because the optimiser performs gradient *descent*, the loss is the
        negated objective (plus the regularisation penalty).

        If the parameters behind ``get_embedding_outputs`` require gradients,
        ``backward`` also accumulates into them; only the ``ProbGraph``
        parameters are stepped here.

        Args:
            init_graph: Graph used to initialise the ``ProbGraph``.
            class_index: Index of the class to explain.
            max_iter: Number of optimisation steps.

        Returns:
            The optimised ``ProbGraph``.
        """
        pg = ProbGraph(init_graph)

        optimizer = torch.optim.SGD(pg.parameters(), lr=0.001, momentum=0.9)

        # Explicit leading axis so the broadcast against [K, hidden] is unambiguous.
        target_embedding = self.average_phi[class_index].unsqueeze(0)

        for _ in range(max_iter):
            optimizer.zero_grad()

            sampled_graphs = pg.sample_train(self.K, self.tau_a, self.tau_z, self.tau_x)

            embeddings, outputs = self.get_embedding_outputs(sampled_graphs)

            class_scores = outputs[:, class_index]
            embedding_similarities = F.cosine_similarity(embeddings, target_embedding, dim=-1)
            objective = torch.mean(class_scores + self.mu * embedding_similarities)
            loss = self.regularizer() - objective

            loss.backward()
            optimizer.step()

        return pg


In [23]:
# Scratch (disabled): inspect one training batch and run it through the model.
# for data in train_loader:
#     print(data.x.shape, data.edge_index.shape, data.batch.shape)
#     print(dir(data))
#     model.get_embedding_outputs(data)
#     break

In [24]:
# Evaluation mode disables dropout inside the model while explaining it.
model.eval()
i = GNNInterpreter(
    get_embedding_outputs=model.get_embedding_outputs,
    train_loader=train_dataset,
)

# Explain class 0, starting from the first graph of the (shuffled) dataset.
i.train(init_graph=dataset[0], class_index=0)

<__main__.ProbGraph at 0x7ff73f5cee90>