In [1]:
import numpy as np
import torch

import matplotlib.pyplot as plt

from pathlib import Path
from pol.utils.validation.scene_saver import load_scenes, count_h5_keys
from pol.utils.plotting import LossTrajectoryPlotter, WitnessMetricsPlotter
from pol.utils.validation.mso_eval import MSOSolution, MSOEvaluation, distance_sqr_matrix
from pol.utils.path import PathHelper

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf')

import matplotlib
use_tex = matplotlib.checkdep_usetex(True)
if use_tex:
    plt.rcParams['text.usetex'] = True
plt.rcParams.update({'font.size': 18})

usetex mode requires dvipng.


In [2]:
prob = 'rastrigin_2d_degen' # 'prodsin_2d' # 'conic' # 'prodsin_2d'
path_helper = PathHelper('../../tests/analytical/')

if prob == 'prodsin_2d':
    all_methods = ['pd_mlr', 'gol_res_hhlr', 'pol_res_lot']
    gt_method = 'pd_mlr'
    candidate_methods = ['gol_res_hhlr', 'pol_res_lot']
    # obj_thresholds = [-0.99, -0.98, -0.97, -0.96, -0.95]
    # precision_thresholds = torch.linspace(0.1, 2.0, 50)
    theta_range = range(0, 8)
    view_box = [[-8, 8], [-8, 8]]
    method_itr = {}
    max_num_iter = 100
elif prob == 'conic':
    # all_methods = ['pd_mlr', 'gol_res_mlr', 'pol_res_lot']
    all_methods = ['pd_mlr', 'pol_res_lot', 'pol_res_lot_hor1']
    gt_method = 'pd_mlr'
    candidate_methods = ['gol_res_mlr', 'pol_res_lot']
    max_num_iter = 50
    theta_range = range(0, 16)
    view_box = [[-5, 5], [-5, 5]]
    method_itr = {
        'pol_res_lot': 5,
        'pol_res_lot_hor1': 5,
        'gol_res_mlr': 100,
    }
elif prob == 'rastrigin_2d':
    all_methods = ['pd_llr', 'pol_res_hgot', 'pol_res_hot', 'pol_res_mot', 'pol_res_not_hor1']
    gt_method = 'pd_llr'
    candidate_methods = ['pol_res_hgot', 'pol_res_hgot_hor1']
    max_num_iter = 100
    theta_range = range(0, 1)
    view_box = [[-5, 5], [-5, 5]]
    method_itr = {
        'pol_res_hgot': 50,
        'pol_res_hgot_hor1': 50,
        'pol_res_hot': 50,
        'pol_res_mot': 50,
        'pol_res_not_hor1': 50,
    }
elif prob == 'rastrigin_2d_degen':
    all_methods = ['pd_llr', 'pol_res_hgot_hor1', 'pol_res_hot_hor1', 'pol_res_mot_hor1', 'pol_res_not_hor1']
    gt_method = 'pd_llr'
    candidate_methods = ['pol_res_hgot', 'pol_res_hgot_hor1']
    max_num_iter = 100
    theta_range = range(0, 1)
    view_box = [[-5, 5], [-5, 5]]
    method_itr = {
        'pol_res_hgot_hor1': 50,
        'pol_res_hot_hor1': 50,
        'pol_res_mot_hor1': 50,
        'pol_res_not_hor1': 50,
    }
    
    import pickle
    with open('../../tests/analytical/L2L_record_rastrigin_2.pickle', 'rb') as f:
        L2L_results = pickle.load(f)
    L2L_x_finals = L2L_results['x_finals'][-1][0]
else:
    assert(False)
    
cmap = plt.cm.get_cmap('hsv', len(all_methods)+1)
colors = {all_methods[i]: cmap(i) for i in range(len(all_methods))}

def convert_method_to_label(method):
    if method.startswith('gol'):
        return 'GOL'
    if method.startswith('pol'):
        return 'POL'
    if method.startswith('pd'):
        return 'PD'
    return 'Unknown'

In [3]:
import ipywidgets as widgets
method_select_widget = widgets.SelectMultiple(
    options=all_methods,
    value=all_methods,
    description='Methods',
    disabled=False
)

In [4]:
loss_plotter = LossTrajectoryPlotter(
    path_helper=path_helper,
    max_num_iter=max_num_iter,
    colors=colors,
    theta_name='thetas')

def vis1():
    loss_plotter.plot_all_convergence(
        problem=prob, methods=method_select_widget.value,
        loss_name='loss', theta_range=[theta_range[0], theta_range[1], theta_range[2]],
        fig_size=10,
        include_hist=False,
    method_label_fn=convert_method_to_label)

In [5]:
display(method_select_widget)

SelectMultiple(description='Methods', index=(0, 1, 2, 3, 4), options=('pd_llr', 'pol_res_hgot_hor1', 'pol_res_…

In [6]:
#vis1()

In [7]:
def plot_particles(ax, theta_idx, methods, 
                   hide=False, marker=None, name=None, s=200, clr=None):
    def plot(X, name, color):
        m = 'o' if marker is None else marker
        ax.scatter(X[:, -2], X[:, -1], color=color, 
                   label=name, s=s, alpha=1.0, marker=m)

    for method_idx, method in enumerate(methods):
        exp_name = (path_helper.format_exp_name(prob, method) 
                    if name is None else name)
        scene = loss_plotter.load_exp_scene(prob, method)
        
        if method in method_itr:
            k = method_itr[method]
        else:
            k = count_h5_keys(scene, 'itr') - 1
            
        # print(f'{method} has {count_h5_keys(scene, "itr")} iterations')
        # print(f'{scene.keys()}')
        tmp = scene['itr_{}'.format(k)]
        X = tmp['X'][theta_idx, :, :]
        S = tmp['satisfy'][theta_idx, :]
        plot(X[S], exp_name, clr if clr is not None else colors[method])
        
    # print(f'marker for {method} is {marker}')

    ax.set_xlim(view_box[0][0], view_box[0][1])
    ax.set_ylim(view_box[1][0], view_box[1][1])
    ax.set_aspect('equal')
    if hide:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        
        [i.set_linewidth(4.0) for i in ax.spines.values()]
    else:
        ax.legend()
    
def vis2(theta_range, all_in_one=False, stylish=False, names=None, hide=False, 
         plot_loss_landscape=False, special_gt=False, clrs=None):
    methods = method_select_widget.value
    fig_size = 8
    num_row = 1 if all_in_one else len(methods)
    num_col = 1 if all_in_one else len(theta_range)
    fig, axes = plt.subplots(num_row, num_col, squeeze=False)
    fig.set_figheight(fig_size*len(methods))
    fig.set_figwidth(fig_size*len(theta_range))
    
    style_list=['o', '1', '2', '3', '4']
    for t, theta in enumerate(theta_range):
        for i, method in enumerate(methods):
            r = 0 if all_in_one else i
            c = 0 if all_in_one else t
            if method == gt_method and special_gt:
                if plot_loss_landscape:
                    scene = loss_plotter.load_exp_scene(prob, method)
                    P = scene['info']['landscape_P'][:]
                    P_loss = scene['info']['landscape_P_loss'][:]
                    P_loss = P_loss.squeeze(0)
                    axes[r,c].contour(P[:, :, 0], P[:, :, 1], P_loss, 8,
                                     cmap='binary', alpha=0.3)
                continue
            
            if stylish:
                style = style_list[i]
            else:
                style = None
            plot_particles(axes[r, c], theta, [method], hide=hide, marker=style, 
                           name=names[i] if names is not None else None,
                          clr=clrs[i] if clrs is not None else None)

    #fig.savefig('figs/{}_2d_vis.png'.format(prob))

In [8]:
if prob == 'rastrigin_2d' or prob == 'rastrigin_2d_degen':
    vis2(theta_range=theta_range, all_in_one=True, stylish=True, special_gt=True,
         names=['gt', '$\lambda = 400$', '$\lambda = 10$', '$\lambda=1$', '$\lambda = 0$'],
        plot_loss_landscape=True, clrs=['red', 'tab:purple', 'tab:red', 'tab:olive', 'tab:green'])
else:
    vis2(theta_range=theta_range, hide=True)

<Figure size 576x2880 with 1 Axes>

In [9]:
def vis3():
    from pol.utils.plotting import WitnessMetricsPlotter
    plotter = WitnessMetricsPlotter(path_helper, problem=prob,
                                    gt_method=gt_method, 
                                    candidate_methods=candidate_methods)
    fig = plotter.plot(shortname_fn=convert_method_to_label,
                      colors=colors)

In [10]:
#vis3()

In [12]:
def vis4():
    fig_size = 8
    fig, ax = plt.subplots(1, 1)
    fig.set_figheight(fig_size)
    fig.set_figwidth(fig_size)
    
    scene = loss_plotter.load_exp_scene(prob, gt_method)
    P = scene['info']['landscape_P'][:]
    P_loss = scene['info']['landscape_P_loss'][:]
    P_loss = P_loss.squeeze(0)
    ax.contour(P[:, :, 0], P[:, :, 1], P_loss, 8,
                     cmap='binary', alpha=0.3)
    
    ax.set_xlim(view_box[0][0], view_box[0][1])
    ax.set_ylim(view_box[1][0], view_box[1][1])
    ax.set_aspect('equal')

    trajs = L2L_results['x_finals']
    # print(trajs)
    
    # for t in range(9, len(trajs)):
    for t in range(9, 10):
        X = trajs[t][0]
        X = X.reshape(-1, 2)
        # X = X[:, 3, :]
        ax.scatter(X[:, -2], X[:, -1], color='tab:blue', alpha = 0.3 + 0.7 * (t / len(trajs)))
    
vis4()

<Figure size 576x576 with 1 Axes>