In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from utils.notebook_prelude import *

In [None]:
IMAGE_FOLDER = 'tmp/wl_examples'   # output directory for every saved figure
EXTENSIONS = ['png', 'pdf']        # each figure is saved once per extension
FIG_DIM_WIDTH = 2                  # figure width per subplot (matplotlib figsize units, inches)
FIG_DIM_HEIGHT = 2.4               # figure height (inches)
FONT_SIZE = 9                      # font size for node labels
CMAP = 'Set1'                      # qualitative colormap for node colours; alternative: 'Pastel1'

os.makedirs(IMAGE_FOLDER, exist_ok=True)

In [None]:
# g1: a star-like tree centred on node '2'.
g1 = nx.Graph()
g1.add_edges_from([('1', '2'), ('2', '3'), ('2', '4')])

# g2: same as g1 plus the edge 3-4, which closes the triangle 2-3-4.
g2 = g1.copy()
g2.add_edge('3', '4')

In [None]:
def save_fig(fig, filename_without_ext, folder = IMAGE_FOLDER, extensions = EXTENSIONS):
    """Write `fig` to disk once for each requested extension.

    Every file is written to '<folder>/<filename_without_ext>.<ext>'.
    """
    targets = ['{}/{}.{}'.format(folder, filename_without_ext, ext) for ext in extensions]
    for target in targets:
        fig.savefig(target)

def cleanup_axes(ax):
    """Strip all decoration from `ax` so only the drawn graph remains.

    Turns the grid off, hides every spine and hides both axes (ticks and labels).
    """
    # Pass a real bool: a string such as 'off' is truthy and would request the
    # grid ON in matplotlib versions that no longer translate 'on'/'off'.
    ax.grid(False)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

def plot_graphs(graphs, pos, nodes = None, node_colors = None, figdim_width = FIG_DIM_WIDTH, figdim_height = FIG_DIM_HEIGHT, font_size = FONT_SIZE):
    """Draw `graphs` side by side, one subplot per graph, sharing one layout.

    Args:
        graphs: non-empty sequence of networkx graphs.
        pos: node -> (x, y) mapping used for every subplot.
        nodes: optional per-graph node lists, passed as `nodelist`.
        node_colors: optional per-graph colour lists, passed as `node_color`.
        figdim_width: figure width per subplot.
        figdim_height: figure height.
        font_size: node-label font size.

    Returns:
        (fig, axes) exactly as returned by `plt.subplots`; `axes` is a single
        Axes when only one graph is drawn, otherwise an array of Axes.
    """
    assert len(graphs)
    num_graphs = len(graphs)
    # Size from the parameters so callers can override the global defaults.
    fig, axes = plt.subplots(ncols = num_graphs, figsize = (num_graphs * figdim_width, figdim_height))

    if nodes is None:
        nodes = [None] * num_graphs
    if node_colors is None:
        node_colors = [None] * num_graphs
    # zip() below would silently drop graphs if these lists were too short.
    assert len(nodes) == num_graphs and len(node_colors) == num_graphs, \
        'nodes and node_colors must have one entry per graph'

    # With ncols=1 plt.subplots returns a bare Axes, which has no .flatten().
    axes_flat = axes.flatten() if isinstance(axes, np.ndarray) else [axes]
    for g, nodelist, node_color, ax in zip(graphs, nodes, node_colors, axes_flat):
        nx.draw_networkx(g, nodelist = nodelist, node_color = node_color, pos=pos, ax = ax, font_size = font_size)
        cleanup_axes(ax)
    fig.tight_layout()
    return fig, axes

def create_coloring(graphs, colors):
    """Assign one colour per distinct node label, shared across all graphs.

    Graphs are visited in order and their nodes in sorted order; each label
    seen for the first time takes the next unused entry of `colors`.

    Args:
        graphs: sequence of graphs exposing `nodes()`.
        colors: sequence of colours with at least as many entries as there
            are distinct node labels across `graphs`.

    Returns:
        (node_colors, mapping): node_colors[i] lists the colours of
        graphs[i]'s nodes in sorted node order; mapping maps label -> colour.

    Raises:
        AssertionError: if `colors` runs out before every label is coloured.
    """
    mapping = {}
    node_colors = []
    for graph in graphs:
        nodes = sorted(graph.nodes())
        for node in nodes:
            if node not in mapping:
                # Check per new label (not per graph). This fails clearly instead of
                # raising IndexError mid-graph, and does not reject a graph whose
                # labels are all already coloured.
                assert len(mapping) < len(colors), \
                    'not enough colors ({}) for the distinct node labels'.format(len(colors))
                mapping[node] = colors[len(mapping)]
        node_colors.append([mapping[node] for node in nodes])
    return node_colors, mapping

def get_next_coloring(graphs, compress = False):
    """Compute one relabelling step of the Weisfeiler-Lehman scheme.

    Nodes are used as their own labels. For each graph, returns a dict built
    in sorted node order:

    * compress=False: node -> "node,<sorted neighbour labels>" as a
      comma-joined string (the multiset-signature step).
    * compress=True: node -> small integer id, numbered 1, 2, ... in order of
      first appearance across all `graphs` (the compression step).
    """
    label_ids = {}
    relabelled = []
    for graph in graphs:
        labels = {}
        for node in sorted(graph.nodes()):
            if compress:
                # len(label_ids) + 1 is evaluated before insertion, so new
                # labels receive consecutive ids starting at 1.
                labels[node] = label_ids.setdefault(node, len(label_ids) + 1)
            else:
                signature = [node] + sorted(graph.neighbors(node))
                labels[node] = ','.join(map(str, signature))
        relabelled.append(labels)
    return relabelled

def get_phi(graphs):
    """Build indicator feature vectors over the union of node labels.

    Args:
        graphs: sequence of graphs exposing `nodes()`; node identities are
            used as labels.

    Returns:
        (phis, all_labels): all_labels is the sorted list of distinct labels
        across all graphs; phis[i] is an unsigned-int numpy vector of length
        len(all_labels) with 1 where graphs[i] contains that label, else 0.
    """
    all_labels = sorted({node for g in graphs for node in g.nodes()})
    # Dict lookup instead of list.index(), which is O(len(all_labels)) per node.
    index_of = {label: i for i, label in enumerate(all_labels)}
    phis = []
    for g in graphs:
        # One slot per distinct label. The previous length (total node count over
        # all graphs) padded every vector with meaningless trailing zeros.
        phi = np.zeros(len(all_labels), dtype=np.uint)
        for node in g.nodes():
            phi[index_of[node]] = 1
        phis.append(phi)
    return phis, all_labels

def relabel_graphs(graphs, label_mappings, pos, colors = None):
    """Rename the nodes of each graph and carry over layout positions and colours.

    Args:
        graphs: sequence of networkx graphs.
        label_mappings: one old-label -> new-label dict per graph (e.g. the
            output of `get_next_coloring`).
        pos: old-label -> (x, y) layout.
        colors: colour sequence for `create_coloring`. Defaults to the
            `CMAP` colormap's colours, the same value the notebook binds to
            the global `colors`. This avoids depending on a variable that is
            defined in a later cell.

    Returns:
        (new_graphs, new_colors, new_pos): the relabelled graphs, their
        per-graph node colours, and a new-label -> (x, y) layout. When several
        old labels map to the same new label, the position processed last wins.
    """
    if colors is None:
        colors = plt.get_cmap(CMAP).colors
    new_pos = {}
    for label_mapping in label_mappings:
        for old, new in label_mapping.items():
            new_pos[new] = pos[old]
    new_graphs = [nx.relabel_nodes(graph, mapping=label_mapping)
                  for graph, label_mapping in zip(graphs, label_mappings)]
    new_colors, _ = create_coloring(new_graphs, colors)
    return new_graphs, new_colors, new_pos

In [None]:
def add_phis_to_fig(graphs, axes, fig):
    """Write each graph's feature vector phi (from `get_phi`) under its subplot.

    Args:
        graphs: the graphs drawn in `axes`, in the same order.
        axes: a single Axes or an array of Axes (as returned by `plot_graphs`).
        fig: the figure owning `axes`; its layout is re-tightened afterwards.
    """
    phis, _ = get_phi(graphs)
    axes_flat = axes.flatten() if isinstance(axes, np.ndarray) else [axes]
    for phi, ax in zip(phis, axes_flat):
        # Raw string: '\p' in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+). The runtime text is unchanged.
        label = r'$\phi$ = [{}]'.format(','.join(str(x) for x in phi))
        # y=-1.35 sits just below nodes laid out by circular_layout(scale=1).
        ax.text(x = 0, y = -1.35, s = label, fontdict={'horizontalalignment': 'center'})
    fig.tight_layout()

# Shared layout and palette for all figures; nodes are their own labels.
pos = nx.circular_layout(g1, scale = 1)
colors = plt.get_cmap(CMAP).colors
graphs = [g1, g2]
NUM_WL_ITERATIONS = 1

# Iteration 0: the input graphs with their initial labels and phi vectors.
node_colors, mapping = create_coloring(graphs, colors)
fig, axes = plot_graphs(graphs, pos = pos, node_colors = node_colors)
add_phis_to_fig(graphs, axes, fig)
save_fig(fig, 'wl_iteration_0')

# Feature vectors of the input graphs (not displayed; kept for inspection).
phis, all_labels = get_phi(graphs)

for i in range(NUM_WL_ITERATIONS):
    iteration = i + 1

    # Stage 0: relabel every node with "own label,sorted neighbour labels".
    new_graphs, new_colors, new_pos = relabel_graphs(graphs, get_next_coloring(graphs), pos)
    fig, axes = plot_graphs(new_graphs, pos = new_pos, node_colors = new_colors)
    save_fig(fig, 'wl_iteration_{}_stage_0_recolored'.format(iteration))

    # Stage 1: compress those long labels to short integer ids and show phi.
    new_graphs, new_colors, new_pos = relabel_graphs(new_graphs, get_next_coloring(new_graphs, compress=True), new_pos)
    fig, axes = plot_graphs(new_graphs, pos = new_pos, node_colors = new_colors)
    add_phis_to_fig(new_graphs, axes, fig)
    save_fig(fig, 'wl_iteration_{}_stage_1_compressed'.format(iteration))

    # The compressed graphs become the input of the next iteration.
    graphs = new_graphs
    pos = new_pos
