In [193]:
import torch
from torch_geometric.datasets import TUDataset
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.logging import init_wandb, log
from torch_geometric.nn import GATConv, GCNConv
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from torch.autograd import Variable

In [194]:
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

In [195]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: MUTAG(188):
Number of graphs: 188
Number of features: 7
Number of classes: 2

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [196]:
torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:150]
test_dataset = dataset[150:]
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 150
Number of test graphs: 38


In [197]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [198]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
# from torch_geometric.nn.dense import DenseGCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def get_embedding_outputs(self, data):
        x, edge_index, edge_weight, batch = data.x, data.edge_index, data.edge_weight, data.batch
        
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout layer
        embedding = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(embedding, p=0.5, training=self.training)
        x = self.lin(x)
        
        return embedding, x

    def forward(self, x, edge_index, batch, edge_weight=None):
        if edge_weight is None: edge_weight = torch.ones(edge_index.shape[1])
        
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

In [199]:
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
print(model)

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [200]:
def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


In [201]:
for epoch in range(1, 5):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368


In [202]:
from torchviz import make_dot

class ProbGraph:
    def __init__(self, data):
        self.has_edge_attributes = data.edge_attr is not None 
        self.has_node_attributes = data.x is not None
        # self.Omega = Variable(torch.Tensor(to_dense_adj(data.edge_index)), requires_grad=True)
        self.Omega = Variable(torch.Tensor(0.5*torch.ones((data.edge_index.shape[0], data.edge_index.shape[0]))), requires_grad=True)
        self.params = {"Omega": self.Omega}
        if self.has_edge_attributes: 
            # self.H = Variable(torch.Tensor(data.edge_attr), requires_grad=True)
            self.H = Variable(torch.Tensor(0.5*torch.ones(data.edge_attr.shape)), requires_grad=True)
            self.params["H"] = self.H
        if self.has_node_attributes: 
            # self.Xi = Variable(torch.Tensor(data.x), requires_grad=True)
            self.Xi = Variable(torch.Tensor(0.5*torch.ones(data.x.shape)), requires_grad=True)
            self.params["Xi"] = self.Xi

    def parameters(self):
        return self.params

    def sample_train(self, K, tau_a, tau_z, tau_x):
        sampled_graphs = []
        for _ in range(K):
            sampled_graph = dict()
            a_epsilon = torch.rand(self.Omega.shape)
            a = torch.sigmoid((self.Omega + torch.log(a_epsilon)-torch.log(1-a_epsilon))/tau_a)
            # print([type(a[i]) for i in range(a.shape[0])])
            # print([type(dense_to_sparse(a[i])) for i in range(a.shape[0])])
            sampled_graph["edge_index"], sampled_graph["edge_weight"] = dense_to_sparse(a)

            if self.has_edge_attributes:
                z_epsilon = torch.rand(self.H.shape)
                z = torch.softmax((self.H - torch.log(-torch.log(z_epsilon)))/tau_z, 0) # UNFINISHED: DOUBLE CHECK DIMENSION
                sampled_graph["edge_attr"] = z
            
            if self.has_node_attributes:
                x_epsilon = torch.rand(self.Xi.shape)
                x = torch.softmax((self.Xi - torch.log(-torch.log(x_epsilon)))/tau_x, 0) # UNFINISHED: DOUBLE CHECK DIMENSION
                sampled_graph["x"] = x

            sampled_graphs.append(Data(**sampled_graph))

        return Batch.from_data_list(sampled_graphs)
    
    def get_latents(self):
        latent_dict = {"Theta": torch.sigmoid(self.Omega)}
        if self.has_node_attributes:
            latent_dict["P"] = torch.softmax(self.Xi, 0) # UNFINISHED: DOUBLE CHECK DIMENSION
        if self.has_edge_attributes:
            latent_dict["Q"] = torch.softmax(self.H, 0) # UNFINISHED: DOUBLE CHECK DIMENSION
        return latent_dict
    
    def sample_explanations(self, n=1):
        latent_dict = self.get_latents()

        A_dist = torch.distributions.Bernoulli(latent_dict["Theta"])
        if self.has_node_attributes:
            X_dist = torch.distributions.Categorical(latent_dict["P"])
        if self.has_edge_attributes:
            Z_dist = torch.distributions.Categorical(latent_dict["Q"])

        sampled_graphs = []
        for _ in range(n):
            sampled_graph = {"A": A_dist.sample()}
            if self.has_node_attributes:
                sampled_graph["X"] = X_dist.sample()
            if self.has_edge_attributes:
                sampled_graph["Z"] = Z_dist.sample()
            sampled_graphs.append(sampled_graph)

        return sampled_graphs

class GNNInterpreter():
    def __init__(self, get_embedding_outputs, train_loader):
        super().__init__()

        self.tau_a = 0.2
        self.tau_z = 0.2
        self.tau_x = 0.2
        self.mu = 0 #10

        self.K = 10 # Monte Carlo # Samples
        self.B = 20 # Max Budget

        self.reg_weights = {
            "L1": [10, 5],
            "L2": [5, 2],
            "Rb": [20, 10],
            "Rc": [1,2],
        }

        self.get_embedding_outputs = get_embedding_outputs

        self.average_phi = self.get_average_phi(train_loader).detach()

    def get_average_phi(self, dataset):
        dataloader = DataLoader(train_dataset, batch_size=1)
        embedding_sum = None
        n_instances = torch.zeros(dataset.num_classes)
        for batch in dataloader:
            embeddings = self.get_embedding_outputs(batch)[0]
            if embedding_sum is None: 
                embedding_sum = torch.zeros(dataset.num_classes, embeddings.shape[-1]) # UNFINISHED: Ensure Correct Dimension
            embedding_sum[batch.y] += embeddings
            n_instances[batch.y] += train_loader.batch_size
        return embedding_sum / torch.unsqueeze(n_instances, 1)
    
    def bernoulli_kl(self, p1, p2):
        return p1*torch.log(p1/p2) + (1-p1)*torch.log((1-p1)/(1-p2))

    def regularizer(self, pg, class_index):
        omega_l1 = torch.norm(pg.Omega, 1)
        omega_l2 = torch.norm(pg.Omega, 2)

        budget_penalty = F.softplus(omega_l1-self.B)**2

        Theta = torch.sigmoid(pg.Omega).squeeze()
        connectivity_incentive = 0
        for i in range(pg.Xi.shape[0]):
            for j in range(pg.Xi.shape[0]):
                Pij = Theta[i][j]
                for k in range(pg.Xi.shape[0]):
                    Pik = Theta[i][k]
                    connectivity_incentive += self.bernoulli_kl(Pij, Pik)

        return omega_l1 * self.reg_weights["L1"][class_index] \
              + omega_l2 * self.reg_weights["L2"][class_index] \
              + budget_penalty * self.reg_weights["Rb"][class_index] \
              + connectivity_incentive * self.reg_weights["Rc"][class_index]

    def train(self, init_graph, class_index, max_iter=1000):

        pg = ProbGraph(init_graph)

        optimizer = torch.optim.SGD(pg.parameters().values(), lr=1, momentum=0.9)

        for _ in range(max_iter):
            optimizer.zero_grad()

            sampled_graphs = pg.sample_train(self.K, self.tau_a, self.tau_z, self.tau_x)
            # print(sampled_graphs.get_example(0).requires_grad_())

            embeddings, outputs = self.get_embedding_outputs(sampled_graphs)

            class_probabilities = outputs[:, class_index]
            embedding_similarities = F.cosine_similarity(embeddings, self.average_phi[class_index], dim=1) # UNFINISHED: Check Dimension
            # print(class_probabilities, embedding_similarities)
            loss = torch.mean(class_probabilities+self.mu*embedding_similarities)# + self.regularizer(pg, class_index)

            if False:
                all_params = pg.parameters()
                all_params.update(model.named_parameters())
                self.computation_graph = make_dot(loss, all_params)
            print("Loss:", float(loss))

            loss.backward()
            # print("Node features equal:", torch.equal(torch.round(pg.Xi), init_graph.x))
            optimizer.step()

        return pg

In [203]:
i = GNNInterpreter(model.get_embedding_outputs, train_dataset)

pg = i.train(dataset[0], 1)
# pg.sample_explanations()

Loss: -0.11080356687307358
False False True
True True
Loss: -0.5464295744895935
False False True
True True
Loss: -0.10576684772968292
False False True
True True
Loss: -0.11658446490764618
False False True
True True
Loss: -0.11443356424570084
False False True
True True
Loss: -0.13089314103126526
False False True
True True
Loss: -0.2650243937969208
False False True
True True
Loss: -0.11258850991725922
False False True
True True
Loss: -0.7054144144058228
False False True
True True
Loss: -52.28873825073242
False False True
True True
Loss: -31594.07421875
False False True
True True
Loss: -0.11444133520126343
False False True
True True
Loss: -0.11350683867931366
False False True
True True
Loss: -0.11381115764379501
False False True
True True
Loss: -0.11308854818344116
False False True
True True
Loss: -0.1130816712975502
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.1130816712975502
False False True
True True
Loss: -0.11308164894580841
False False T

  print(sampled_graphs.edge_weight.grad is None, sampled_graphs.edge_weight.requires_grad)


False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
Loss: -0.11308164894580841
False False True
True True
L