In [None]:
import torch_geometric as pyg    
import torch
import networkx as nx
import matplotlib.pyplot as plt
import random
import numpy as np

In [None]:
# Atom symbol for each node-class index; the drawing cells below look these up
# with np.argmax over a node's feature row (or a node's integer 'label').
NODE_CLS = dict(enumerate(['C', 'N', 'O', 'F', 'I', 'Cl', 'Br']))

# Matplotlib colour used to draw each node-class index (same indexing as NODE_CLS).
NODE_COLOR = dict(enumerate([
    'lightgray',    # C
    'deepskyblue',  # N
    'red',          # O
    'cyan',         # F
    'magenta',      # I
    'springgreen',  # Cl
    'chocolate',    # Br
]))

In [None]:
def convert(G, generate_label=False, label_dict=None, color_dict=None):
    """Convert a networkx graph into a ``torch_geometric.data.Data`` object.

    Nodes are first relabelled to consecutive integers ``0..n-1`` so that node
    ``i`` corresponds to row ``i`` of ``x``.

    Args:
        G: networkx graph. A graph-level ``'label'`` attribute, if present,
            becomes ``y``; otherwise ``y`` is ``-1``.
        generate_label: if True, nodes lacking a ``'label'`` attribute get one
            drawn with ``random.choice`` from the keys of ``label_dict``. If
            False, every node must already carry a ``'label'`` (KeyError
            otherwise). Only consulted when ``label_dict`` is given.
        label_dict: mapping whose keys are the possible node labels. When
            given, ``x`` is a float one-hot encoding of each node's label, with
            columns ordered like the keys of ``label_dict``. When None, ``x``
            is a ``(num_nodes, 1)`` tensor of ones.
        color_dict: unused; kept so existing callers keep working.

    Returns:
        ``pyg.data.Data`` holding ``G`` (the relabelled graph), ``x``, ``y``,
        ``edge_index`` (both directions of every edge, via
        ``pyg.utils.to_undirected``) and ``edge_attr`` (always None).
    """
    G = nx.convert_node_labels_to_integers(G)
    num_nodes = G.number_of_nodes()

    if label_dict is None:
        # Unlabelled graph: every node gets the same constant feature.
        # (Previously this path raised NameError on the undefined node_labels.)
        x = torch.ones((num_nodes, 1))
    else:
        classes = list(label_dict)  # hoisted: one list for all nodes
        node_labels = [
            G.nodes[i]['label']
            if 'label' in G.nodes[i] or not generate_label
            else random.choice(classes)
            for i in G.nodes
        ]
        # Map each label to its column so keys need not be exactly 0..K-1.
        column = {label: idx for idx, label in enumerate(classes)}
        indices = torch.tensor([column[label] for label in node_labels], dtype=torch.long)
        x = torch.eye(len(classes))[indices].float()

    if G.number_of_edges() > 0:
        # G.edges lists each undirected edge once; to_undirected adds the reverse.
        edge_index = pyg.utils.to_undirected(torch.tensor(list(G.edges)).T)
    else:
        edge_index = torch.empty(2, 0).long()

    return pyg.data.Data(
        G=G,
        x=x,
        y=torch.tensor(G.graph.get('label', -1)).long(),
        edge_index=edge_index,
        edge_attr=None,
    )

In [None]:
import pickle
import os
from utils import draw_graph
# Render every baseline explainer graph. Files are named "<method>_<rest>.<ext>";
# each drawing is saved to ./results/Images/<method>/<rest>.png.
gdir = "./results/baseline_graphs/"
for filename in os.listdir(gdir):
    path = os.path.join(gdir, filename)
    # os.listdir also returns sub-directories; skip them like the solution-pickle cell does.
    if not os.path.isfile(path):
        continue
    random.seed(7)
    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
    with open(path, "rb") as f:
        G = pickle.load(f)
    fn = filename.split(".")[0]
    method = fn.split("_")[0]
    pos = nx.spring_layout(G, seed=7)

    # For MUTAG files each node's integer 'label' indexes NODE_CLS / NODE_COLOR;
    # other datasets are drawn without labels and with default colours.
    is_mutag = "MUTAG" in filename
    if is_mutag:
        node_labels = {i: NODE_CLS[G.nodes[i]['label']] for i in G.nodes}
        node_colors = [NODE_COLOR[G.nodes[i]['label']] for i in G.nodes]
    else:
        node_labels = node_colors = None

    # Frameless figure whose single axes fills the whole canvas.
    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    nx.draw_networkx(G, pos=pos, with_labels=is_mutag, labels=node_labels, node_color=node_colors, ax=ax, node_size=800, font_size=18, width=3)

    out_dir = f"./results/Images/{method}"
    os.makedirs(out_dir, exist_ok=True)
    imgname = f"{out_dir}/{fn[len(method)+1:]}.png"
    fig.savefig(imgname)
    plt.close(fig)  # close this figure explicitly to avoid piling up figures in the loop

In [None]:
# Render the MIP explainer's solutions. Each pickle holds a sequence whose last
# element is a dict with a node-feature matrix "X" and an adjacency matrix "A".
gdir = "./results/solution_pickles/"
method = "mipexplainer"
out_dir = f"./results/Images/{method}"
os.makedirs(out_dir, exist_ok=True)
for filename in os.listdir(gdir):
    path = os.path.join(gdir, filename)
    if not os.path.isfile(path):
        continue
    random.seed(7)
    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
    with open(path, "rb") as f:
        solution = pickle.load(f)[-1]

    X = solution["X"]
    A = solution["A"]

    # One feature column per NODE_CLS entry => molecule-style (atom-typed) graph.
    molecules = X.shape[1] == len(NODE_CLS)

    G = nx.from_numpy_array(A, create_using=nx.Graph)

    fn = filename.split(".")[0]
    pos = nx.spring_layout(G, seed=7)

    if molecules:
        # argmax of each node's feature row picks its class index (computed once per node).
        atom = {i: int(np.argmax(X[i])) for i in G.nodes}
        node_labels = {i: NODE_CLS[atom[i]] for i in G.nodes}
        node_colors = [NODE_COLOR[atom[i]] for i in G.nodes]
    else:
        node_labels = node_colors = None

    # Frameless figure whose single axes fills the whole canvas.
    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    # Show labels exactly when they were computed, instead of keying off the filename.
    nx.draw_networkx(G, pos=pos, with_labels=molecules, labels=node_labels, node_color=node_colors, ax=ax, node_size=800, font_size=18, width=3)
    fig.savefig(f"{out_dir}/{fn}.png")
    plt.close(fig)

In [None]:
from sklearn.model_selection import train_test_split
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
import torch
import os
import pickle
import numpy as np
import torch.nn.utils.prune as prune
import pandas as pd
from gnn import GNN, test
from torch_geometric.utils import to_dense_adj

In [None]:
N_EXAMPLES_PER_CLASS = 5
EXAMPLE_DIR = "./figures/examples"


def _save_class_examples(dataset, dataset_name, class_label, rng, k=N_EXAMPLES_PER_CLASS):
    """Save PNG drawings of up to ``k`` distinct graphs whose label is ``class_label``.

    Args:
        dataset: sequence of PyG ``Data`` objects with ``x``, ``edge_index`` and ``y``.
        dataset_name: file-name prefix for the saved images.
        class_label: graphs with ``int(g.y) == class_label`` are eligible.
        rng: ``random.Random`` used to pick examples without replacement.
        k: maximum number of examples to save.
    """
    pool = [g for g in dataset if int(g.y) == class_label]
    os.makedirs(EXAMPLE_DIR, exist_ok=True)
    for j, data in enumerate(rng.sample(pool, k=min(k, len(pool)))):
        X = data.x.detach().numpy()
        # max_num_nodes keeps trailing isolated nodes that edge_index alone cannot reveal;
        # [0] (rather than squeeze) keeps a single-node graph's adjacency 2-D.
        A = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0].detach().numpy()
        molecules = X.shape[1] == len(NODE_CLS)
        G = nx.from_numpy_array(A, create_using=nx.Graph)
        pos = nx.spring_layout(G, seed=7)

        if molecules:
            atom = {i: int(np.argmax(X[i])) for i in G.nodes}
            node_labels = {i: NODE_CLS[atom[i]] for i in G.nodes}
            node_colors = [NODE_COLOR[atom[i]] for i in G.nodes]
        else:
            node_labels = node_colors = None

        # Frameless figure whose single axes fills the whole canvas.
        fig = plt.figure(frameon=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        nx.draw_networkx(G, pos=pos, with_labels=molecules, labels=node_labels, node_color=node_colors, ax=ax, node_size=800, font_size=18, width=3)
        fig.savefig(f"{EXAMPLE_DIR}/{dataset_name}_class_{class_label}_example_{j}.png")
        plt.close(fig)


pairs = [
    ("Is_Acyclic_Ones", "models/Is_Acyclic_Ones_model.pth"),
    ("Shapes_Ones", "models/Shapes_Ones_model.pth"),
    ("MUTAG", "models/MUTAG_model_new.pth"),
]

evals = {}
# Dedicated seeded RNG so the chosen examples do not depend on whatever earlier
# cells did to the global `random` state.
example_rng = random.Random(7)

for dataset_name, model_path in pairs:
    print(dataset_name)
    # NOTE: pickle.load / torch.load unpickle arbitrary objects; only use trusted local files.
    with open(f"data/{dataset_name}/dataset.pkl", "rb") as f:
        dataset = pickle.load(f)

    ys = [int(d.y) for d in dataset]
    # Iterate over the labels that actually occur instead of assuming they are 0..K-1.
    classes = sorted(set(ys))
    num_classes = len(classes)
    num_node_features = dataset[0].x.shape[1]

    for class_index in classes:
        _save_class_examples(dataset, dataset_name, class_index, example_rng)

    train_dataset, test_dataset = train_test_split(dataset, train_size=0.8, stratify=ys, random_state=7)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects fully
    # pickled models; pass weights_only=False there if loading fails (trusted files only).
    model = torch.load(model_path)
    print(model)

    train_acc = test(model, train_loader)
    test_acc = test(model, test_loader)
    class_test_accuracies = {
        label: test(model, DataLoader([g for g in test_dataset if int(g.y) == label], batch_size=32, shuffle=False))
        for label in classes
    }
    # "Is_Acyclic_Ones" -> "Is_Acyclic"; the suffix presumably just marks all-ones node features.
    short_name = dataset_name[:-len("_Ones")] if dataset_name.endswith("_Ones") else dataset_name
    evals[short_name] = {
        "Number of Graphs": len(dataset),
        "Number of Classes": num_classes,
        "Average Number of Nodes": np.mean([g.num_nodes for g in dataset]),
        "Average Number of Edges": np.mean([g.num_edges for g in dataset]),
        "Number of Node Features": num_node_features,
        "Train Accuracy": train_acc,
        "Test Accuracy": test_acc,
        "Number of Model Parameters": sum(param.numel() for param in model.parameters()),
        "Classwise Test Accuracies": class_test_accuracies,
    }

In [None]:
# One row per dataset (keys of `evals`), one column per collected metric.
df = pd.DataFrame.from_dict(data=evals, orient="index")
df.head(n=5)

In [None]:
# Dataset-statistics table (first five columns) as LaTeX. escape=True makes pandas escape "_"
# and other specials itself; a manual "_" -> "\_" replace double-escapes on pandas < 2.0,
# whose to_latex already escaped by default.
print(df[df.columns[:5]].to_latex(index=True, float_format="{:.3f}".format, escape=True))

In [None]:
# Model table (accuracies and parameter count; the dict-valued class-wise column is excluded).
# escape=True lets pandas escape "_" itself, avoiding the double escape that a manual
# replace produces on pandas < 2.0.
print(df[df.columns[5:-1]].to_latex(index=True, float_format="{:.3f}".format, escape=True))