In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path
import networkx as nx
from pol.utils.validation.scene_saver import load_scenes, count_h5_keys, find_max_h5_key
from pol.utils.plotting import LossTrajectoryPlotter
from pol.utils.path import PathHelper

# Ask the inline backend to emit figures in PDF format.
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf')
import matplotlib
# Turn on LaTeX text rendering only when matplotlib reports it as usable.
# NOTE(review): checkdep_usetex looks deprecated/removed in newer Matplotlib
# releases — confirm against the Matplotlib version this notebook is pinned to.
use_tex = matplotlib.checkdep_usetex(True)
if use_tex:
    plt.rcParams['text.usetex'] = True
# Default font size for all figure text.
plt.rcParams.update({'font.size': 28})

In [2]:
# ---- Experiment configuration shared by every visualisation below ----
# Problem name; passed to the plotters and scene loaders.
prob = 'k8_mixed_circle_linear' # 'k8_circle'
# Helper that resolves experiment names/paths under the given root directory.
path_helper = PathHelper('../../tests/maxcut/')


# Methods offered in the selection widget and assigned a colour below.
all_methods = ['pd_llr', 'gol_res_hlr',
               'pol_res_hot', #'pol_res_mot', 'pol_res_lot', 'pol_res_hmot', 'pol_res_mlot', 'pol_res_hot',
               ]
              #]
# Method handed to WitnessMetricsPlotter as `gt_method` in vis3().
gt_method = 'pd_mlr'
# Every method except the first entry of `all_methods`.
candidate_methods = all_methods[1:]
max_num_iter = 10000
# Theta indices visualised by vis1() and vis2().
theta_range = range(0, 16) # range(0, 16)
#view_box = [[-5, 5], [-5, 5]]
# Per-method iteration index used to pick the saved 'itr_<k>' snapshot.
# Methods missing from this dict fall back to a key lookup on the scene
# (see plot_particles and vis_graph).
method_itr = {
    #'pd_mlr': 1000,
    # 'pol_res_mot': 10,
    # 'pol_res_lot': 10,
    # 'pol_res_hmot': 10,
    # 'pol_res_mlot': 10,
    'pol_res_mot': 100,
    'pol_res_lot': 100,
    'pol_res_hmot': 100,
    'pol_res_mlot': 100,
    'pol_res_hot': 100,
    # 'pol_res_mot': 5,
    # 'pol_res_lot': 5,
    # 'pol_res_hmot': 5,
    # 'pol_res_mlot': 5,
    # 'gol_res_mlr': 100,
    # 'gol_res_hlr': 100,
    'gol_res_mlr': 100,
    'gol_res_hlr': 100,
}

    
# One colour per method, sampled from the 'hsv' colormap; the map is requested
# with one extra entry beyond the number of methods.
cmap = plt.cm.get_cmap('hsv', len(all_methods)+1)
colors = {all_methods[i]: cmap(i) for i in range(len(all_methods))}

def convert_method_to_label(method):
    """Map an experiment method name to its short display label.

    The label is selected by the prefix of ``method``:
    ``'gol'`` -> ``'GOL'``, ``'pol'`` -> ``'POL'``, ``'pd'`` -> ``'PD'``.
    Any other name yields ``'Unknown'``.
    """
    prefix_to_label = (('gol', 'GOL'), ('pol', 'POL'), ('pd', 'PD'))
    matches = (label for prefix, label in prefix_to_label
               if method.startswith(prefix))
    return next(matches, 'Unknown')

In [3]:
import ipywidgets as widgets
# Multi-select list of all methods, with every method selected initially.
# vis1() and vis2() read the current selection from `.value`.
method_select_widget = widgets.SelectMultiple(
    options=all_methods,
    value=all_methods,
    description='Methods',
    disabled=False
)

In [4]:
# Shared plotter instance; besides plotting convergence it is used below to
# load saved experiment scenes through `load_exp_scene`.
loss_plotter = LossTrajectoryPlotter(
    path_helper=path_helper,
    max_num_iter=max_num_iter,
    colors=colors,
    theta_name='thetas')

def vis1():
    """Plot convergence curves for the methods currently selected in the widget."""
    convergence_kwargs = dict(
        problem=prob,
        methods=method_select_widget.value,
        loss_name='X_loss',
        satisfy_name='satisfy',
        theta_range=theta_range,
        fig_size=10,
        method_itr=method_itr,
        include_hist=True,
    )
    loss_plotter.plot_all_convergence(**convergence_kwargs)

    

In [5]:
# Show the method selector; vis1() and vis2() read its current selection.
display(method_select_widget)

SelectMultiple(description='Methods', index=(0, 1, 2), options=('pd_llr', 'gol_res_hlr', 'pol_res_hot'), value…

In [6]:
#vis1()

In [7]:
from sklearn.manifold import TSNE

def plot_particles(ax, theta_idx, methods, hide=False, random_state=0):
    """Scatter a 2-D t-SNE embedding of each method's satisfying particles.

    For every method the saved scene is loaded, the snapshot ``itr_<k>`` is
    selected (``k`` comes from ``method_itr`` when the method is listed there,
    otherwise from ``count_h5_keys(scene, 'itr') - 1``), the particles of
    instance ``theta_idx`` are embedded with t-SNE, and only the particles
    flagged in ``satisfy`` are drawn.

    Args:
        ax: Matplotlib axes to draw on.
        theta_idx: Index along the first axis of the saved ``X`` and
            ``satisfy`` arrays.
        methods: Iterable of method names; each must be a key of ``colors``.
        hide: If true, hide both axis objects; otherwise add a legend.
        random_state: Seed forwarded to ``TSNE``. The embedding uses random
            initialisation, so without a seed the figure changes on every
            run. Pass ``None`` to get the previous unseeded behaviour.
    """
    for method in methods:
        exp_name = path_helper.format_exp_name(prob, method)
        scene = loss_plotter.load_exp_scene(prob, method)

        if method in method_itr:
            itr_idx = method_itr[method]
        else:
            itr_idx = count_h5_keys(scene, 'itr') - 1
        snapshot = scene['itr_{}'.format(itr_idx)]

        particles = snapshot['X'][theta_idx, :, :]
        embedding = TSNE(n_components=2, learning_rate='auto', init='random',
                         random_state=random_state).fit_transform(particles)

        # Force a boolean mask: if `satisfy` were stored as 0/1 integers,
        # indexing with it directly would select rows 0 and 1 instead of
        # filtering. (Assumes `satisfy` is a per-particle flag — TODO confirm.)
        satisfied = np.asarray(snapshot['satisfy'][theta_idx, :]).astype(bool)
        kept = embedding[satisfied]
        ax.scatter(kept[:, -2], kept[:, -1], color=colors[method],
                   label=exp_name, alpha=0.5)

    if hide:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    else:
        ax.legend()
    
def vis2():
    """Draw a grid of particle embeddings.

    Rows correspond to the methods currently selected in the widget and
    columns to the entries of ``theta_range``.
    """
    selected = method_select_widget.value
    panel_size = 8
    n_rows, n_cols = len(selected), len(theta_range)

    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    fig.set_figheight(panel_size * n_rows)
    fig.set_figwidth(panel_size * n_cols)

    # Column-major traversal: every method for one theta, then the next theta.
    for col, theta in enumerate(theta_range):
        for row, method in enumerate(selected):
            plot_particles(axes[row, col], theta, [method], hide=False)
    # To save the figure: fig.savefig('figs/{}_2d_vis.png'.format(prob))

In [8]:
#vis2()

In [9]:
def vis3():
    """Plot witness metrics of the candidate methods against ``gt_method``."""
    from pol.utils.plotting import WitnessMetricsPlotter

    witness_plotter = WitnessMetricsPlotter(
        path_helper,
        problem=prob,
        gt_method=gt_method,
        candidate_methods=candidate_methods,
    )
    witness_plotter.plot(shortname_fn=convert_method_to_label, colors=colors)

#vis3()

In [13]:
def draw_cut(ax, edges, weights, Y, include_weights=True, hide=False, small=True):
    """Draw a graph and one cut of it on ``ax``.

    Only edges with strictly positive weight are drawn. Vertices with
    ``Y > 0`` are blue and vertices with ``Y < 0`` are red. Edges whose two
    endpoints have different ``Y`` values are dashed; the others are solid.

    Args:
        ax: Matplotlib axes to draw on.
        edges: Array indexed as ``edges[i, 0]`` / ``edges[i, 1]`` giving the
            two vertex indices of edge ``i``.
        weights: Per-edge weights aligned with ``edges``.
        Y: Per-vertex assignment; ``len(Y)`` is the number of vertices.
        include_weights: If true, label each drawn edge with its weight
            formatted to two decimals.
        hide: Unused by this function; kept so existing callers keep working.
        small: Selects the smaller or larger node, line and font sizes.
    """
    num_vertex = len(Y)

    # Keep only the positively weighted edges.
    pos_mask = weights > 0
    edges = edges[pos_mask, :]
    weights = weights[pos_mask]

    # Use plain Python int tuples for the edges: this is the edge-list form
    # networkx documents, and it avoids boolean-mask indexing, which fails
    # when no edge survives (an empty list converts to a float array).
    edge_list = [(int(edges[i, 0]), int(edges[i, 1])) for i in range(len(edges))]
    cut_edges = []
    uncut_edges = []
    for u, v in edge_list:
        (cut_edges if bool(Y[u] != Y[v]) else uncut_edges).append((u, v))

    G = nx.Graph()
    G.add_nodes_from(range(num_vertex))
    G.add_edges_from(edge_list)

    if small:
        node_options = {"node_size": 1000, "edgecolors": "black", "linewidths": 3}
        edge_options = {'width': 3}
        font_size = 20
    else:
        node_options = {"node_size": 4000, "edgecolors": "black", "linewidths": 5}
        edge_options = {'width': 6}
        font_size = 40

    # Keep pyplot's current axes in sync (as before), but pass `ax` explicitly
    # to every draw call so drawing does not depend on that global state.
    plt.sca(ax)
    pos = nx.spring_layout(G, seed=48)  # fixed seed keeps the layout stable

    positive_nodes = [i for i in range(num_vertex) if Y[i] > 0]
    negative_nodes = [i for i in range(num_vertex) if Y[i] < 0]
    nx.draw_networkx_nodes(G, pos, nodelist=positive_nodes,
                           node_color='tab:blue', ax=ax, **node_options)
    nx.draw_networkx_nodes(G, pos, nodelist=negative_nodes,
                           node_color='tab:red', ax=ax, **node_options)

    if cut_edges:
        nx.draw_networkx_edges(G, pos, edgelist=cut_edges, style='dashed',
                               ax=ax, **edge_options)
    if uncut_edges:
        nx.draw_networkx_edges(G, pos, edgelist=uncut_edges, ax=ax, **edge_options)

    if include_weights:
        edge_labels = {edge: '{:.2f}'.format(w)
                       for edge, w in zip(edge_list, weights)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)

    nx.draw_networkx_labels(G, pos, font_size=font_size,
                            font_family="sans-serif", ax=ax)
    
    
from pol.utils.filter_cuts import filter_cuts, is_integral
from scipy.sparse.csgraph import connected_components

def is_connected(edges, weights, silent=True):
    """Return whether the edges with weight above 0.5 connect all vertices.

    The vertex count is taken as ``edges.max() + 1`` over *all* edges, so a
    vertex touched only by edges at or below the 0.5 threshold counts as its
    own component. ``silent`` is accepted but not used.
    """
    from scipy.sparse import csr_matrix

    vertex_count = edges.max() + 1
    heavy = weights > 0.5
    rows, cols = edges[heavy, 0], edges[heavy, 1]
    adjacency = csr_matrix((weights[heavy], (rows, cols)),
                           shape=(vertex_count, vertex_count))
    # Treat the sparse matrix as an undirected graph.
    component_count, _ = connected_components(adjacency, directed=False)
    return component_count == 1
    
def vis_graph(method, num_theta, num_sol, min_unique,
              block_mode=False, num_row=None, num_col=None, hide=False,
              special_i=None, small=True):
    """Draw filtered cuts found by ``method`` for selected problem instances.

    Instances (rows of the saved ``thetas``) are scanned in order. An instance
    is skipped unless its weights pass ``is_integral``, its graph passes
    ``is_connected``, and ``filter_cuts`` returns at least ``min_unique``
    entries. For each accepted instance up to ``num_sol`` cuts are drawn, and
    scanning stops after ``num_theta`` accepted instances.

    Args:
        method: Method name whose saved scene is loaded.
        num_theta: Maximum number of instances to draw (also the number of
            rows of the figure when ``block_mode`` is false).
        num_sol: Maximum number of cuts drawn per instance (also the number of
            columns of the figure when ``block_mode`` is false).
        min_unique: Minimum number of filtered cuts an instance must have.
        block_mode: If true, lay the cuts of an instance out on a
            ``num_row`` x ``num_col`` grid, filled row by row.
        num_row: Grid rows; required when ``block_mode`` is true.
        num_col: Grid columns; required when ``block_mode`` is true.
        hide: If true, switch the axes off instead of adding a title.
        special_i: If given, only this instance index is examined.
        small: Forwarded to ``draw_cut`` to select the drawing sizes.

    Raises:
        ValueError: If ``block_mode`` is true and ``num_row`` or ``num_col``
            is missing.
    """
    fig_size = 8
    if not block_mode:
        fig, axes = plt.subplots(num_theta, num_sol, squeeze=False)
        fig.set_figheight(fig_size * num_theta)
        fig.set_figwidth(fig_size * num_sol)
        ax_flatten = None
    else:
        if num_row is None or num_col is None:
            raise ValueError('block_mode=True requires both num_row and num_col')
        fig, axes = plt.subplots(num_row, num_col, squeeze=False,
                                 constrained_layout=True)
        fig.set_figheight(fig_size * num_row)
        fig.set_figwidth(fig_size * num_col)
        # Row-major list of axes, so cut j lands on the j-th grid cell.
        ax_flatten = [axes[r][c] for r in range(num_row) for c in range(num_col)]

    scene = loss_plotter.load_exp_scene(prob, method)
    info = scene['info']
    edges = info['edges']  # Ex2
    k = method_itr.get(method, -1)
    if k == -1:
        # No configured iteration for this method: look it up from the scene keys.
        k = find_max_h5_key(scene, 'itr', return_itr=True)
    itr = scene['itr_{}'.format(k)]
    thetas = info['thetas']

    # The title below uses the LaTeX-only \textrm command. The notebook only
    # enables usetex when LaTeX is available, so fall back to plain text when
    # it is off; Matplotlib's built-in math renderer may reject \textrm.
    latex_titles = bool(plt.rcParams['text.usetex'])

    count = 0
    candidates = [special_i] if special_i is not None else range(0, thetas.shape[0])
    for i in candidates:
        weights = thetas[i, :]  # E
        if not is_integral(weights):
            continue
        if not is_connected(edges, weights):
            continue

        Y = itr['Y'][i]  # BxV
        Y_loss = itr['Y_loss'][i]  # B
        filtered = list(filter_cuts(Y, Y_loss)['filtered'])
        if len(filtered) < min_unique:
            continue
        print('Found {}'.format(i))

        for j in range(min(num_sol, len(filtered))):
            ax = ax_flatten[j] if block_mode else axes[count, j]
            draw_cut(ax, edges, weights, Y=filtered[j][0],
                     include_weights=False, small=small)
            if hide:
                ax.axis('off')
                continue
            num_cuts = int(round(filtered[j][1]))
            if latex_titles:
                title = r'$\#\textrm{cuts}=' + '{}$'.format(num_cuts)
            else:
                title = '#cuts={}'.format(num_cuts)
            ax.set_title(title)
            for side in ('top', 'right', 'bottom', 'left'):
                ax.spines[side].set_visible(False)

        count += 1
        if count == num_theta:
            break

In [14]:
# Draw 8 cuts for each of the first 10 instances that pass vis_graph's filters
# (integral weights, connected graph, at least `min_unique` filtered cuts).
vis_graph('pol_res_hot', num_theta=10, num_sol=8, min_unique=8, hide=False)

Found 17
Found 32
Found 33
Found 60
Found 89
Found 97
Found 100
Found 105
Found 141
Found 209


<Figure size 4608x5760 with 80 Axes>

In [12]:
# vis_graph('pol_res_hot', num_theta=1, num_sol=18, min_unique=18,
#          block_mode=True, num_row=2, num_col=9, hide=True, small=False, special_i=5583)

Found 5583


<Figure size 5184x1152 with 18 Axes>