In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.utils import to_scipy_sparse_matrix
import scipy
from torch_geometric.datasets import Planetoid, TUDataset, QM9
from scipy.sparse.csgraph import floyd_warshall, dijkstra
import seaborn as sns
import igraph

In [None]:
# Atom symbols used as labels in the plots of this notebook. Position i in the
# list is the label for node-feature index i (presumably the one-hot atom type
# of the Mutagenicity dataset -- confirm against the dataset documentation).
# Kept as a real list (not a dict view) so it can be indexed and passed to
# plotting libraries as an ordinary sequence.
mutagenity_feature_names = [
    'C',   # 0
    'O',   # 1
    'Cl',  # 2
    'H',   # 3
    'N',   # 4
    'F',   # 5
    'Br',  # 6
    'S',   # 7
    'P',   # 8
    'I',   # 9
    'Na',  # 10
    'K',   # 11
    'Li',  # 12
    'Ca',  # 13
]

In [343]:
class GNAN(nn.Module):
    """Additive graph model built from one small MLP per input feature plus one
    MLP over node distances.

    * ``fs[k]`` maps the scalar value of feature ``k`` of a node to
      ``out_channels`` values.
    * ``m`` maps a scalar node-to-node distance to ``out_channels`` values
      (or to a single value when ``limited_m`` is set and ``num_layers != 1``).

    The prediction for a node is the sum, over all nodes of the graph, of the
    distance weight ``m(distance)`` times the summed feature outputs of that
    node (see ``forward``).

    Args:
        in_channels: number of input features; one network is created per feature.
        out_channels: output width of every feature network and of the result.
        num_layers: number of linear layers in each network. ``1`` gives a
            single ``Linear(1, out_channels)``.
        hidden_channels: hidden width; required whenever ``num_layers != 1``.
        bias: whether the linear layers use a bias term.
        dropout: dropout probability used inside the feature networks only.
        device: device on which ``forward`` allocates its intermediate buffers.
        limited_m: if true (and ``num_layers != 1``) the distance network ends
            in a single output that is broadcast over all output channels.
        normalize_m: if true, distance weights are divided by the matching
            entries of ``inputs.normalization_matrix``.
        init_std: gain passed to Xavier initialisation in ``init_params``.

    Raises:
        ValueError: if ``hidden_channels`` is missing while ``num_layers != 1``.
    """

    def __init__(self, in_channels, out_channels, num_layers, hidden_channels=None, bias=True, dropout=0.0,
                 device='cpu', limited_m=False, normalize_m=True, init_std=1.0):
        super().__init__()

        if num_layers != 1 and hidden_channels is None:
            raise ValueError("hidden_channels must be provided when num_layers != 1")

        self.device = device
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.limited_m = limited_m
        self.normalize_m = normalize_m
        # Read by init_params(); it was previously never set, which made
        # init_params() fail with AttributeError.
        self.init_std = init_std

        # One network per input feature (Linear -> ReLU -> Dropout blocks).
        self.fs = nn.ModuleList(
            [self._build_mlp(out_channels, dropout=dropout) for _ in range(in_channels)]
        )

        # Distance network (Linear -> ReLU blocks, no dropout). The single-layer
        # variant always emits out_channels values, even when limited_m is set.
        m_width = 1 if (limited_m and num_layers != 1) else out_channels
        self.m = self._build_mlp(m_width, dropout=None)

    def _build_mlp(self, final_width, dropout=None):
        """Build a scalar-input MLP with ``self.num_layers`` linear layers.

        A ``Dropout`` module follows every ``ReLU`` only when ``dropout`` is not
        ``None``. The layer order is fixed so that parameter names in the
        state dict stay stable.
        """
        if self.num_layers == 1:
            return nn.Sequential(nn.Linear(1, final_width, bias=self.bias))

        layers = [nn.Linear(1, self.hidden_channels, bias=self.bias), nn.ReLU()]
        if dropout is not None:
            layers.append(nn.Dropout(p=dropout))
        for _ in range(1, self.num_layers - 1):
            layers.append(nn.Linear(self.hidden_channels, self.hidden_channels, bias=self.bias))
            layers.append(nn.ReLU())
            if dropout is not None:
                layers.append(nn.Dropout(p=dropout))
        layers.append(nn.Linear(self.hidden_channels, final_width, bias=self.bias))
        return nn.Sequential(*layers)

    def init_params(self):
        """Re-initialise all weights with Xavier-normal (gain ``init_std``) and
        set all biases to zero."""
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param, gain=self.init_std)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def forward(self, inputs, node_ids):
        """Compute predictions for the nodes listed in ``node_ids``.

        Args:
            inputs: object exposing ``x`` (indexed as ``[node, feature]``),
                ``node_distances`` (indexed by node, one distance per node of
                the graph) and, when ``normalize_m`` is set,
                ``normalization_matrix`` (indexed the same way).
            node_ids: iterable of node indices to predict for.

        Returns:
            Tensor of shape ``(len(node_ids), out_channels)``.
        """
        x, node_distances = inputs.x, inputs.node_distances

        # fx[n, k] holds the output of feature network k for node n.
        fx = torch.empty(x.size(0), x.size(1), self.out_channels, device=self.device)
        for feature_index in range(x.size(1)):
            feature_col = x[:, feature_index].reshape(-1, 1)
            fx[:, feature_index] = self.fs[feature_index](feature_col)

        # Additive combination over features: one vector per node.
        f_sums = fx.sum(dim=1)

        stacked_results = torch.empty(len(node_ids), self.out_channels, device=self.device)
        for j, node in enumerate(node_ids):
            m_dist = self.m(node_distances[node].view(-1, 1))
            if self.normalize_m:
                # A single broadcast division covers both the one-column
                # (limited_m) and the multi-column case, and avoids writing
                # in place into a tensor that is part of the autograd graph.
                normalization = inputs.normalization_matrix[node].view(-1, 1)
                m_dist = m_dist / normalization.to(m_dist.dtype)
            # Weight every node's feature sum by its distance weight, then
            # aggregate over all nodes of the graph.
            stacked_results[j] = torch.sum(m_dist * f_sums, dim=0)

        return stacked_results

In [345]:
data_path = 'data'
dataset = TUDataset(root=data_path, name='Mutagenicity')

mutagenicity_gnan = <path to model params>
model = GNAN(in_channels=<>, hidden_channels=<>, num_layers=<>, out_channels=<>, bias=<>, limited_m=0, normalize_m=1)
model.load_state_dict(torch.load(f"{mutagenicity_gnan}", map_location=torch.device('cpu') ))

data = list(dataset)

num_classes = dataset.num_classes
num_features = data[0].x.size(-1)

# Largest graph diameter (in hops) over all graphs of the dataset; each graph
# is rebuilt as an undirected igraph graph to measure its diameter.
max_distance = 0
for graph_data in data:
    undirected_graph = igraph.Graph(directed=False)
    undirected_graph.add_vertices(graph_data.x.size(0))
    undirected_graph.add_edges(graph_data.edge_index.T.numpy())
    graph_diameter = undirected_graph.diameter(directed=False)
    max_distance = max(max_distance, graph_diameter)


In [346]:
# Evaluate the distance network (model.m) on the normalized distance 1 / (1 + d)
# for every hop distance d = 0 .. max_distance.
y_input_values = torch.tensor([1 / (1 + i) for i in range(max_distance + 1)])
m_y_values = np.zeros(shape=(y_input_values.size(0), ))
with torch.no_grad():
    for i, val in enumerate(y_input_values):
        mdist = model.m(val.view(-1, 1))
        # model.m returns a (1, C) tensor. Store the first output channel as a
        # Python float (the same channel the feature-score cell uses) instead
        # of relying on implicit array-to-scalar conversion, which fails when
        # C > 1.
        m_y_values[i] = mdist.flatten()[0].item()

In [None]:
# Line plot of the distance-function output for hop distances 0 .. max_distance.
sns.set_style("whitegrid")
plt.figure(figsize=(4, 2))

x_ticks = list(range(max_distance + 1))

# Trim the padding at the beginning and end of the plot to half a unit on each
# side. The upper bound must be max_distance + 0.5; with max_distance - 0.5 the
# last plotted point (x = max_distance) falls outside the axes.
plt.xlim(-0.5, max_distance + 0.5)
plt.plot(x_ticks, m_y_values, marker='.', markersize=4)
plt.xlabel('Distance', size=9)
plt.ylabel('Distance function output', size=9)
plt.xticks(fontsize=8)
plt.yticks(fontsize=8)

plt.show()

In [None]:
# Score of each feature: first output channel of its feature network when the
# feature value is 1.0.
unit_input = torch.ones(1, 1)
f_scores = torch.zeros(num_features)
for feature_idx in range(num_features):
    feature_output = model.fs[feature_idx].forward(unit_input).detach()
    f_scores[feature_idx] = feature_output.reshape(-1)[0]

In [None]:
# Bar plot of f_scores, one bar per feature, labelled with the feature names.
bar_fig, bar_ax = plt.subplots(figsize=(4, 2))
bar_ax.bar(list(mutagenity_feature_names), f_scores)

bar_ax.set_xlabel('Feature (atom)', fontsize=9)
bar_ax.set_ylabel('Feature function output', fontsize=9)
bar_ax.tick_params(axis='both', labelsize=8)
plt.show()

In [None]:
from matplotlib.colors import LinearSegmentedColormap
# Diverging colormap for the heatmap: red -> white -> green.
cmap_name = "custom_colormap"
colors = ["red", "white", "green"]
n_bins = 100  # number of discrete colour levels in the colormap
cmap = LinearSegmentedColormap.from_list(name=cmap_name, colors=colors, N=n_bins)

# Helper tensors: integer distances 0 .. max_distance - 1 and a zero column
# with one row per feature score.
plot_x = torch.arange(max_distance, dtype=torch.long)
plot_y = torch.zeros(len(f_scores), 1)


In [None]:
# Heatmap of the outer product: feature score (rows) x distance-function
# output (columns).
z = np.outer(f_scores, m_y_values)
fig, ax = plt.subplots(figsize=(45, 15.5))
sns.heatmap(
    z,
    annot=True,
    fmt=".2f",
    xticklabels=x_ticks,
    yticklabels=mutagenity_feature_names,
    cmap=cmap,
    center=0,
    annot_kws={"fontsize": 16},
    cbar=False,
    ax=ax,
)

# The colour bar is drawn manually so that its position and tick size can be set.
cbar = fig.colorbar(ax.collections[0], ax=ax, use_gridspec=True, aspect=70)
cbar.ax.set_position([0.75, 0.1, 2, 0.755])
cbar.ax.tick_params(labelsize=20)

plt.setp(ax.get_xticklabels(), fontsize=20, weight='bold')
plt.setp(ax.get_yticklabels(), fontsize=20, weight='bold')
ax.set_xlabel('Distance', fontsize=30)
ax.set_ylabel('Feature (atom)', fontsize=30)
ax.set_title('Mutagenicity, is mutagenic', fontsize=20)
plt.show()
