In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from pathlib import Path
from pol.utils.validation.scene_saver import load_scenes, count_h5_keys, find_max_h5_key
from pol.utils.plotting import LossTrajectoryPlotter
from pol.utils.path import PathHelper

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf')
import matplotlib
use_tex = matplotlib.checkdep_usetex(True)
if use_tex:
    plt.rcParams['text.usetex'] = True
plt.rcParams.update({'font.size': 28})

In [2]:
prob = 'l1_norm_8d'
path_helper = PathHelper('../../tests/analytical/')
method = 'pol_res_mot'
all_satisfy = True

In [3]:
def ISTA(X, t=1/2):
    # X: BxD
    # OT weight is 1, lasso weight is also 1
    # ISTA(y) = ||y||_1 + ||x-y||_2^2
    cond1 = X >= t
    cond2 = torch.logical_and(X < t, X > -t)
    cond3 = X <= -t
    
    result = torch.where(cond1, X - t, X)
    result = torch.where(cond2, torch.zeros_like(X), X)
    result = torch.where(cond3, X + t, X)
    return result

In [4]:
def plot_diff(ax):
    exp_name = path_helper.format_exp_name(prob, method)
    h5_path = path_helper.locate_scene_h5(prob, method)
    scene = load_scenes(h5_path)[0]
    k = find_max_h5_key(scene, 'itr', return_itr=True)
    itr_first = scene['itr_{}'.format(0)]
    itr_last = scene['itr_{}'.format(1)]
    X_first = torch.from_numpy(itr_first['X'][0, :, :]) # BxD
    X_last = torch.from_numpy(itr_last['X'][0, :, :]) # BxD
    X_gt = ISTA(X_first)
    
    diff = (X_last - X_gt).square().sum(-1) # B
    diff = diff.mean()
    
    print(diff)
    
def vis2():
    fig_size = 8
    fig, axes = plt.subplots(1, 1, squeeze=False)
    fig.set_figheight(fig_size)
    fig.set_figwidth(fig_size)
    plot_diff(axes[0, 0])
    #fig.savefig('figs/{}_2d_vis.png'.format(prob))
    
vis2()

tensor(0.8691)


<Figure size 576x576 with 1 Axes>