In [14]:
import sys
import os
current = os.getcwd()
parent = os.path.dirname(current)
sys.path.append(parent)
from Kernels.Weisfeiler_Lehman import Weisfeiler_Lehman
import torch
from torch_geometric.nn import global_mean_pool
print(torch.__version__)
import torchnet as tnt
import torch.nn.functional as F
import torch.nn as nn
import wget
import zipfile
from collections import Counter
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from numpy.linalg import inv,multi_dot
from scipy.linalg import expm
from matplotlib import pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout
import numpy as np
import pydot
from Datasets.utils_mutag import load_data , create_loaders 



1.13.0+cpu


In [15]:
dataset = load_data(path='Datasets/MUTAG/', ds_name='MUTAG',
                    use_node_labels=True, use_edge_labels=True, max_node_label=7,max_edge_label=4)


#pytorch loaders
train_dataset, val_dataset = create_loaders(
    dataset, batch_size=1, split_id=150, offset=0)
print('Data are ready')

../Datasets/MUTAG/MUTAG_graph_indicator.txt
Data are ready


In [16]:
X=np.array(dataset,dtype=object)[:,0]
y=np.array(dataset,dtype=object)[:,1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# Testing with Weisfeler Lehman

In [17]:
#transform featurex
kernel=Weisfeiler_Lehman(normalise=True,h=0,node_label='attr_dict')
k_train=kernel.fit_transform(list(X_train))
k_test=kernel.transform(list(X_test))
# Uses the SVM classifier to perform classification
clf = SVC(kernel='precomputed')
clf.fit(np.asarray(k_train), np.ravel(y_train).astype(int))
y_pred = clf.predict(np.asarray(k_test))
# Computes and prints the classification accuracy
acc = accuracy_score(np.array(y_test,dtype=int), np.array(y_pred,dtype=int))
print("Accuracy:", str(round(acc*100, 2)) + "%")

Accuracy: 63.16%


# Testing with an heterogenous arhitecture(MEWISPool)  suposedly  achieving 96% after 200 epochs

https://paperswithcode.com/sota/graph-classification-on-mutag

In [58]:

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, enhance=False):
        super(MLP, self).__init__()

        self.enhance = enhance

        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

        if enhance:
            self.bn1 = nn.BatchNorm1d(hidden_dim)
            self.bn2 = nn.BatchNorm1d(hidden_dim)
            self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.fc1(x)
        if self.enhance:
            x = self.bn1(x)
        x = torch.relu(x)
        if self.enhance:
            x = self.dropout(x)

        x = self.fc2(x)
        if self.enhance:
            x = self.bn2(x)
        x = torch.relu(x)
        if self.enhance:
            x = self.dropout(x)

        x = self.fc3(x)

        return x


class MEWISPool(nn.Module):
    def __init__(self, hidden_dim):
        super(MEWISPool, self).__init__()

        self.gc1 = GINConv(MLP(1, hidden_dim, hidden_dim))
        self.gc2 = GINConv(MLP(hidden_dim, hidden_dim, hidden_dim))
        self.gc3 = GINConv(MLP(hidden_dim, hidden_dim, 1))

    def forward(self, x, edge_index, batch):
        # computing the graph laplacian and adjacency matrix
        batch_nodes = batch.size(0)
        if edge_index.size(1) != 0:
            L_indices, L_values = get_laplacian(edge_index)
            L = torch.sparse.FloatTensor(L_indices, L_values, torch.Size([batch_nodes, batch_nodes]))
            A = torch.diag(torch.diag(L.to_dense())) - L.to_dense()

            # entropy computation
            entropies = self.compute_entropy(x, L, A, batch)  # Eq. (8)
        else:
            A = torch.zeros([batch_nodes, batch_nodes])
            norm = torch.norm(x, dim=1).unsqueeze(-1)
            entropies = norm / norm

        # graph convolution and probability scores
        probabilities = self.gc1(entropies, edge_index)
        probabilities = self.gc2(probabilities, edge_index)
        probabilities = self.gc3(probabilities, edge_index)
        probabilities = torch.sigmoid(probabilities)

        # conditional expectation; Algorithm 1
        gamma = entropies.sum()
        loss = self.loss_fn(entropies, probabilities, A, gamma)  # Eq. (9)

        mewis = self.conditional_expectation(entropies, probabilities, A, loss, gamma)

        # graph reconstruction; Eq. (10)
        x_pooled, adj_pooled = self.graph_reconstruction(mewis, x, A)
        edge_index_pooled, batch_pooled = self.to_edge_index(adj_pooled, mewis, batch)

        return x_pooled, edge_index_pooled, batch_pooled, loss, mewis

    @staticmethod
    def compute_entropy(x, L, A, batch):
        # computing local variations; Eq. (5)
        V = x * torch.matmul(L, x) - x * torch.matmul(A, x) + torch.matmul(A, x * x)
        V = torch.norm(V, dim=1)

        # computing the probability distributions based on the local variations; Eq. (7)
        P = torch.cat([torch.softmax(V[batch == i], dim=0) for i in torch.unique(batch)])
        P[P == 0.] += 1
        # computing the entropies; Eq. (8)
        H = -P * torch.log(P)

        return H.unsqueeze(-1)

    @staticmethod
    def loss_fn(entropies, probabilities, A, gamma):
        term1 = -torch.matmul(entropies.t(), probabilities)[0, 0]

        term2 = torch.matmul(torch.matmul(probabilities.t(), A), probabilities).sum()

        return gamma + term1 + term2

    def conditional_expectation(self, entropies, probabilities, A, threshold, gamma):
        sorted_probabilities = torch.sort(probabilities, descending=True, dim=0)

        dummy_probabilities = probabilities.detach().clone()
        selected = set()
        rejected = set()

        for i in range(sorted_probabilities.values.size(0)):
            node_index = sorted_probabilities.indices[i].item()
            neighbors = torch.where(A[node_index] == 1)[0]
            if len(neighbors) == 0:
                selected.add(node_index)
                continue
            if node_index not in rejected and node_index not in selected:
                s = dummy_probabilities.clone()
                s[node_index] = 1
                s[neighbors] = 0

                loss = self.loss_fn(entropies, s, A, gamma)

                if loss <= threshold:
                    selected.add(node_index)
                    for n in neighbors.tolist():
                        rejected.add(n)

                    dummy_probabilities[node_index] = 1
                    dummy_probabilities[neighbors] = 0

        mewis = list(selected)
        mewis = sorted(mewis)

        return mewis

    @staticmethod
    def graph_reconstruction(mewis, x, A):
        x_pooled = x[mewis]

        A2 = torch.matmul(A, A)
        A3 = torch.matmul(A2, A)

        A2 = A2[mewis][:, mewis]
        A3 = A3[mewis][:, mewis]

        I = torch.eye(len(mewis))
        one = torch.ones([len(mewis), len(mewis)])

        adj_pooled = (one - I) * torch.clamp(A2 + A3, min=0, max=1)

        return x_pooled, adj_pooled

    @staticmethod
    def to_edge_index(adj_pooled, mewis, batch):
        row1, row2 = torch.where(adj_pooled > 0)
        edge_index_pooled = torch.cat([row1.unsqueeze(0), row2.unsqueeze(0)], dim=0)
        batch_pooled = batch[mewis]

        return edge_index_pooled, batch_pooled


class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(Net, self).__init__()

        self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.pool1 = MEWISPool(hidden_dim=hidden_dim)
        self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.pool2 = MEWISPool(hidden_dim=hidden_dim)
        self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x, edge_index, batch):
        x = self.gc1(x, edge_index)
        x = torch.relu(x)

        x_pooled1, edge_index_pooled1, batch_pooled1, loss1, mewis = self.pool1(x, edge_index, batch)

        x_pooled1 = self.gc2(x_pooled1, edge_index_pooled1)
        x_pooled1 = torch.relu(x_pooled1)

        x_pooled2, edge_index_pooled2, batch_pooled2, loss2, mewis = self.pool2(x_pooled1, edge_index_pooled1,
                                                                                batch_pooled1)

        x_pooled2 = self.gc3(x_pooled2, edge_index_pooled2)

        readout = torch.cat([x_pooled2[batch_pooled2 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled2)],
                            dim=0)

        out = self.fc1(readout)
        out = torch.relu(out)
        out = self.fc2(out)

        return torch.log_softmax(out, dim=-1), loss1 + loss2


class Net2(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(Net2, self).__init__()

        self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.pool1 = MEWISPool(hidden_dim=hidden_dim)
        self.gc4 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x, edge_index, batch):
        x = self.gc1(x, edge_index)
        x = torch.relu(x)

        x = self.gc2(x, edge_index)
        x = torch.relu(x)

        x = self.gc3(x, edge_index)
        x = torch.relu(x)
        readout2 = torch.cat([x[batch == i].mean(0).unsqueeze(0) for i in torch.unique(batch)], dim=0)

        x_pooled1, edge_index_pooled1, batch_pooled1, loss1, mewis = self.pool1(x, edge_index, batch)

        x_pooled1 = self.gc4(x_pooled1, edge_index_pooled1)

        readout = torch.cat([x_pooled1[batch_pooled1 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled1)],
                            dim=0)

        out = readout2 + readout

        out = self.fc1(out)
        out = torch.relu(out)
        out = self.fc2(out)

        return torch.log_softmax(out, dim=-1), loss1


class Net3(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(Net3, self).__init__()

        self.gc1 = GINConv(MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.gc2 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.pool1 = MEWISPool(hidden_dim=hidden_dim)
        self.gc3 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.gc4 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.pool2 = MEWISPool(hidden_dim=hidden_dim)
        self.gc5 = GINConv(MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim, enhance=True))
        self.fc1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x, edge_index, batch):
        x = self.gc1(x, edge_index)
        x = torch.relu(x)

        x = self.gc2(x, edge_index)
        x = torch.relu(x)

        x_pooled1, edge_index_pooled1, batch_pooled1, loss1, mewis1 = self.pool1(x, edge_index, batch)

        x_pooled1 = self.gc3(x_pooled1, edge_index_pooled1)
        x_pooled1 = torch.relu(x_pooled1)

        x_pooled1 = self.gc4(x_pooled1, edge_index_pooled1)
        x_pooled1 = torch.relu(x_pooled1)

        x_pooled2, edge_index_pooled2, batch_pooled2, loss2, mewis2 = self.pool2(x_pooled1, edge_index_pooled1,
                                                                                 batch_pooled1)

        x_pooled2 = self.gc5(x_pooled2, edge_index_pooled2)
        x_pooled2 = torch.relu(x_pooled2)

        readout = torch.cat([x_pooled2[batch_pooled2 == i].mean(0).unsqueeze(0) for i in torch.unique(batch_pooled2)],
                            dim=0)

        out = self.fc1(readout)
        out = torch.relu(out)
        out = self.fc2(out)

        return torch.log_softmax(out, dim=-1), loss1 + loss2

In [64]:
import os.path as osp
import time
import torch
import os
torch.manual_seed(7)
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader 
from torch_geometric.nn import GINConv
from torch_geometric.utils import get_laplacian
os.makedirs('checkpoints', exist_ok=True)

DATASET_NAME = 'MUTAG'
BATCH_SIZE = 20
HIDDEN_DIM = 32
EPOCHS = 5
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-5
SCHEDULER_PATIENCE = 10
SCHEDULER_FACTOR = 0.1
EARLY_STOPPING_PATIENCE = 50



In [65]:
path = osp.join(osp.dirname(osp.abspath("__file__")), '../Datasets/MutagII', DATASET_NAME)
dataset = TUDataset(path, name=DATASET_NAME, use_node_attr=True, use_edge_attr=True).shuffle()
n = (len(dataset) + 9) // 10

input_dim = dataset.num_features
num_classes = dataset.num_classes

dataset = dataset.shuffle()

test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

device = torch.device('cpu')


In [66]:

# Model, Optimizer and Loss definitions
model = Net3(input_dim=input_dim, hidden_dim=HIDDEN_DIM, num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                       patience=SCHEDULER_PATIENCE,
                                                       factor=SCHEDULER_FACTOR,
                                                       verbose=True)
nll_loss = torch.nn.NLLLoss()

best_val_loss = float('inf')
best_test_acc = 0
wait = None
for epoch in range(EPOCHS):
    # Training the model
    s_time = time.time()
    train_loss = 0.
    train_corrects = 0
    model.train()
    for i, data in enumerate(train_loader):
        s = time.time()
        data = data.to(device)
        optimizer.zero_grad()
        out, loss_pool = model(data.x, data.edge_index, data.batch)
        loss_classification = nll_loss(out, data.y.view(-1))
        loss = loss_classification + 0.01 * loss_pool

        loss.backward()
        train_loss += loss.item()
        train_corrects += out.max(dim=1)[1].eq(data.y.view(-1)).sum().item()
        optimizer.step()
        # print(f'{i}/{len(train_loader)}, {time.time() - s}')

    train_loss /= len(train_loader)
    train_acc = train_corrects / len(train_dataset)
    scheduler.step(train_loss)

    # Validation
    val_loss = 0.
    val_corrects = 0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            s = time.time()
            data = data.to(device)
            out, loss_pool = model(data.x, data.edge_index, data.batch)
            loss_classification = nll_loss(out, data.y.view(-1))
            loss = loss_classification + 0.01 * loss_pool
            val_loss += loss.item()
            val_corrects += out.max(dim=1)[1].eq(data.y.view(-1)).sum().item()
            # print(f'{i}/{len(val_loader)}, {time.time() - s}')

    val_loss /= len(val_loader)
    val_acc = val_corrects / len(val_dataset)

    # Test
    test_loss = 0.
    test_corrects = 0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            s = time.time()
            data = data.to(device)
            out, loss_pool = model(data.x, data.edge_index, data.batch)
            loss_classification = nll_loss(out, data.y.view(-1))
            loss = loss_classification + 0.01 * loss_pool
            test_loss += loss.item()
            test_corrects += out.max(dim=1)[1].eq(data.y.view(-1)).sum().item()
            # print(f'{i}/{len(val_loader)}, {time.time() - s}')

    test_loss /= len(test_loader)
    test_acc = test_corrects / len(test_dataset)

    elapse_time = time.time() - s_time
    log = '[*] Epoch: {}, Train Loss: {:.3f}, Train Acc: {:.2f}, Val Loss: {:.3f}, ' \
          'Val Acc: {:.2f}, Test Loss: {:.3f}, Test Acc: {:.2f}, Elapsed Time: {:.1f}'\
        .format(epoch, train_loss, train_acc, val_loss, val_acc, test_loss, best_test_acc, elapse_time)
    print(log)

    # Early-Stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_test_acc = test_acc
        wait = 0
        # saving the model with best validation loss
        torch.save(model.state_dict(), f'checkpoints/{DATASET_NAME}.pkl')
    else:
        wait += 1
    # early stopping
    if wait == EARLY_STOPPING_PATIENCE:
        print('======== Early stopping! ========')
        break

[*] Epoch: 0, Train Loss: 1.857, Train Acc: 0.64, Val Loss: 1.483, Val Acc: 0.68, Test Loss: 1.398, Test Acc: 0.00, Elapsed Time: 0.9
[*] Epoch: 1, Train Loss: 0.654, Train Acc: 0.71, Val Loss: 1.007, Val Acc: 0.84, Test Loss: 0.828, Test Acc: 0.89, Elapsed Time: 1.0
[*] Epoch: 2, Train Loss: 0.564, Train Acc: 0.75, Val Loss: 0.690, Val Acc: 0.74, Test Loss: 0.560, Test Acc: 0.95, Elapsed Time: 1.1
[*] Epoch: 3, Train Loss: 0.635, Train Acc: 0.69, Val Loss: 0.554, Val Acc: 0.84, Test Loss: 0.241, Test Acc: 0.84, Elapsed Time: 0.9
[*] Epoch: 4, Train Loss: 0.510, Train Acc: 0.73, Val Loss: 0.442, Val Acc: 0.84, Test Loss: 0.300, Test Acc: 0.89, Elapsed Time: 0.9
