In [1]:
import torch
import numpy as np
from bbob_test_functions import *
import time

device = 'cpu'
# set default to type double 
torch.set_default_dtype(torch.float64)

In [2]:
def compute_prox(x, t, f, delta=1e-1, int_samples=100, alpha=1.0, linesearch_iters=0, device='cpu'):
  '''
      compute prox.
      input is a single vector x of size (dim x 1)
  '''
  assert(x.shape[1]==1)
  assert(x.shape[0]>=1)
  linesearch_iters +=1
  standard_dev = np.sqrt(delta*t/alpha)

  dim = x.shape[0]

  y = standard_dev * torch.randn(int_samples, dim, device=device) + x.permute(1,0) # here y has shape (n_samples x dim)

  z = -f(y)*(alpha/delta) # shape =  n_samples
  w = torch.softmax(z, dim=0) # shape = n_samples

  softmax_overflow_check = (w < np.inf)
  if softmax_overflow_check.prod()==0.0:
    print('x = ', x)
    print('z = ', z)
    print('w = ', w)
    alpha = 0.5*alpha
    return compute_prox(x, t, f, delta=delta, int_samples=int_samples, alpha=alpha, linesearch_iters=linesearch_iters, device=device)
  else:
    prox_term = torch.matmul(w.t(), y)
    prox_term = prox_term.view(-1,1)

    prox_overflow = (prox_term < np.inf)
    if prox_overflow.prod() == 0.0:
      print('prox overflowed: ', prox_term)
    assert(prox_overflow.prod() == 1.0)

    envelope = f(prox_term.view(1,-1)) + (1/(2*t)) * torch.norm(prox_term - x, p=2)**2

    return prox_term, envelope, linesearch_iters

In [3]:
# function_name = 'sphere_function'
# function_name = 'ellipsoidal_function'
# function_name = 'rastrigin_function'
# function_name = 'bueche_rastrigin_function'
# function_name = 'attractive_sector_function'
# function_name = 'rosenbrock_function'
# function_name = 'rotated_rosenbrock_function'
# function_name = 'discus_function'
# function_name = 'bent_cigar_function'
# function_name = 'sharp_ridge_function'
# function_name = 'different_powers_function'
# function_name = 'weierstrass_function'
# function_name = 'schaffers_f7_function'
# function_name = 'schaffers_f7_moderately_ill_cond_function'
# function_name = 'composite_griewank_rosenbrock_function'
function_name = 'schwefel_function'


dim = 10
x_true = torch.zeros(dim, device=device)
if function_name == 'schwefel_function':
    one_plus_minus = torch.ones(dim, device=device)
    one_plus_minus[1::2] = -1

    x_true = 4.2096874633/2 * one_plus_minus


x0 = 4*torch.ones(dim, 1, device=device)
n_trials = 3
max_iters = int(1e4)
t = 1.0

# define function and choose dimension
def f(x, return_gradient=False):
    # return sphere_function(x)
    # turn function_name string into pytorch function

    if return_gradient==True:
        x.requires_grad = True
    
    fx = eval(function_name)(x)
    if return_gradient==True:
        grad_fx = torch.autograd.grad(outputs=fx, inputs=x, grad_outputs=torch.ones_like(fx), create_graph=True)[0]
        x.requires_grad = False
        fx = fx.detach()
        grad_fx = grad_fx.detach()
        return fx, grad_fx
    else:
        x.requires_grad = False
        return fx

### run gradient descent

In [4]:
# now run gradient descent with the same initial condition

step_size_gd = 1e-5
xk_gd = x0.clone()
fx, grad_fx = f(xk_gd.view(1,-1), return_gradient=True)

f_hist_gd_array = torch.zeros(max_iters)
grad_norm_hist_gd_array = torch.zeros(max_iters)
rel_err_gd_array = torch.zeros(max_iters)

for i in range(max_iters):

    f_hist_gd_array[i] = fx.cpu()
    grad_norm_hist_gd_array[i] = torch.norm(grad_fx).cpu()
    rel_err_gd_array[i] = torch.norm(xk_gd - x_true)

    start_time = time.time()
    fx, grad_fx = f(xk_gd.view(1,-1), return_gradient=True)
    xk_gd = xk_gd - step_size_gd*grad_fx.view(-1,1)
    end_time = time.time()
    iter_time = end_time - start_time
    print('iter: ', (i+1), ' fx: ', "{:5.2e}".format(fx.item()),
            ' |grad_fx|: ', "{:5.2e}".format(torch.norm(grad_fx).item()),
            ' rel_err: ', "{:5.2e}".format(torch.norm(xk_gd - x_true).item()),
            ' time = ', '{:5.2f}'.format(iter_time))

iter:  1  fx:  9.47e+04  |grad_fx|:  2.60e+04  rel_err:  4.47e+01  time =   0.00
iter:  2  fx:  8.79e+04  |grad_fx|:  2.51e+04  rel_err:  4.42e+01  time =   0.00
iter:  3  fx:  8.16e+04  |grad_fx|:  2.37e+04  rel_err:  4.38e+01  time =   0.00
iter:  4  fx:  7.60e+04  |grad_fx|:  2.25e+04  rel_err:  4.33e+01  time =   0.00
iter:  5  fx:  7.09e+04  |grad_fx|:  2.23e+04  rel_err:  4.29e+01  time =   0.00
iter:  6  fx:  6.61e+04  |grad_fx|:  2.03e+04  rel_err:  4.25e+01  time =   0.00
iter:  7  fx:  6.19e+04  |grad_fx|:  2.01e+04  rel_err:  4.22e+01  time =   0.00
iter:  8  fx:  5.79e+04  |grad_fx|:  1.89e+04  rel_err:  4.18e+01  time =   0.00
iter:  9  fx:  5.44e+04  |grad_fx|:  1.87e+04  rel_err:  4.15e+01  time =   0.00
iter:  10  fx:  5.10e+04  |grad_fx|:  1.69e+04  rel_err:  4.11e+01  time =   0.00
iter:  11  fx:  4.81e+04  |grad_fx|:  1.68e+04  rel_err:  4.08e+01  time =   0.00
iter:  12  fx:  4.53e+04  |grad_fx|:  1.62e+04  rel_err:  4.05e+01  time =   0.00
iter:  13  fx:  4.27e+04 

In [5]:
# ---------------------------------------------------------------------------------------------------
# Proximal Point using Laplace's approximation
# ---------------------------------------------------------------------------------------------------
def laplace_proximal_point(x0, f, max_iters, t, x_true=x_true, int_samples=100, delta=1e-1, verbose=True, print_freq=1, device='cpu'):

  assert len(x0.shape)==2 and x0.shape[1]==1
  xk = x0.clone()
  fk_hist = torch.zeros(max_iters)
  rel_err_hist = torch.zeros(max_iters)
  time_hist = torch.zeros(max_iters)
  grad_norm_hist = torch.zeros(max_iters)

  for i in range(max_iters):

    start_time = time.time()

    fk, grad_fk = f(xk.permute(1,0), return_gradient=True) # need to input xk with dimensions (n_samples x 1)
    grad_norm = torch.norm(grad_fk)

    rel_err = torch.norm(xk - x_true)

    x_new, uk, ls_iters = compute_prox(xk, t, f, delta=delta, int_samples=int_samples, device=device)

    fk_hist[i] = fk.cpu()
    rel_err_hist[i] = rel_err.cpu()
    grad_norm_hist[i] = grad_norm.cpu()

    end_time = time.time()
    iter_time = end_time - start_time

    time_hist[i] = iter_time

    if verbose:
      if (i+1)%print_freq == 0:
        print('iter: ', (i+1), ' fk: ', "{:5.2e}".format(fk.item()),
              ' rel_err: ', "{:5.2e}".format(rel_err.item()),
              ' grad_norm:' , "{:5.2e}".format(grad_norm.item()),
              ' uk: ', "{:5.2e}".format(uk.item()),
              # ', |xnew-xk|: ', "{:5.2e}".format(norm_diff.item()),
              ' ls: ', ls_iters,
              ' time = ', '{:5.2f}'.format(iter_time))

    xk = x_new

  return xk, fk_hist, rel_err_hist, grad_norm_hist

In [6]:
delta_array = [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]
samples_array = [1e1, 1e2, 1e3, 1e4]

f_hist_array = torch.zeros(n_trials, len(delta_array), len(samples_array), max_iters)
rel_err_array = torch.zeros(n_trials, len(delta_array), len(samples_array), max_iters)
grad_norm_hist_array = torch.zeros(n_trials, len(delta_array), len(samples_array), max_iters)

for i in range(n_trials):

  for j in range(len(delta_array)):
    current_delta = delta_array[j]
    max_ls_iters = 1

    for k in range(len(samples_array)):
      current_int_samples = int(samples_array[k])

      print('\n --------------- Trial: ', i+1, ', delta: ', current_delta, ', n_samples:', current_int_samples, ' --------------- \n')

      temp_output_tuple = laplace_proximal_point(x0,
                                                 f,
                                                 max_iters,
                                                 t,
                                                 int_samples = current_int_samples,
                                                 delta = current_delta,
                                                 device=device,
                                                 print_freq=100)
      xopt                = temp_output_tuple[0]
      fk_hist             = temp_output_tuple[1]
      rel_err_hist        = temp_output_tuple[2]
      grad_norm_hist      = temp_output_tuple[3]


      f_hist_array[i,j,k,:] = fk_hist
      rel_err_array[i,j,k,:] = rel_err_hist
      grad_norm_hist_array[i,j,k,:] = grad_norm_hist


 --------------- Trial:  1 , delta:  0.0001 , n_samples: 10  --------------- 

iter:  100  fk:  5.98e+04  rel_err:  4.27e+01  grad_norm: 2.00e+04  uk:  5.95e+04  ls:  1  time =   0.00
iter:  200  fk:  3.51e+04  rel_err:  3.98e+01  grad_norm: 1.35e+04  uk:  3.47e+04  ls:  1  time =   0.00
iter:  300  fk:  1.78e+04  rel_err:  3.67e+01  grad_norm: 8.67e+03  uk:  1.77e+04  ls:  1  time =   0.00
iter:  400  fk:  7.36e+03  rel_err:  3.35e+01  grad_norm: 4.91e+03  uk:  7.28e+03  ls:  1  time =   0.00
iter:  500  fk:  2.20e+03  rel_err:  3.06e+01  grad_norm: 2.15e+03  uk:  2.15e+03  ls:  1  time =   0.00
iter:  600  fk:  2.99e+02  rel_err:  2.77e+01  grad_norm: 8.45e+02  uk:  2.90e+02  ls:  1  time =   0.00
iter:  700  fk:  2.50e+00  rel_err:  2.58e+01  grad_norm: 3.52e+02  uk:  3.63e+00  ls:  1  time =   0.00
iter:  800  fk:  3.23e+00  rel_err:  2.55e+01  grad_norm: 3.26e+02  uk:  2.80e+00  ls:  1  time =   0.00
iter:  900  fk:  3.76e+00  rel_err:  2.55e+01  grad_norm: 3.86e+02  uk:  1.53e+0

### Save Results

In [9]:
# save_dir = '/content/drive/MyDrive/Projects/2023-Inf-Convolutions/Proximal_Point_Experiments/least_squares_full_rank/exp_results/'
save_dir = 'exp_results/'
data_name = function_name + '_dim' + str(int(dim)) + '_t' + "{:1.0e}".format(t) + '.pth'
file_name = save_dir + data_name
state = {
    'f_hist_array': f_hist_array,
    'f_hist_gd_array': f_hist_gd_array,
    'rel_err_array': rel_err_array,
    'rel_err_gd_array': rel_err_gd_array,
    'grad_norm_hist_array': grad_norm_hist_array,
    'grad_norm_hist_gd_array': grad_norm_hist_gd_array,
    'dim': dim,
    'delta_array': delta_array,
    'samples_array': samples_array,
    'n_trials': n_trials,
    't': t
}
torch.save(state, file_name)
print('files saved to ' + file_name)

files saved to exp_results/schwefel_function_dim10_t1e+00.pth
