# This notebook is based on the paper: "Global-Convergence-Nonconvex-Optimization". 

The aim of this project is to find the global solution to  
\begin{equation}
  \min_{x \in \mathbb{R}^n} f(x).
\end{equation}

To obtain a global minimizer, the main idea is to minimize the Moreau envelope instead, which "convexifies" the original function. To make the Moreau envelope tractable, we use connections to Hamilton-Jacobi equations via the Cole-Hopf and Hopf-Lax formulas to efficiently compute the gradients of the Moreau envelope.

In [None]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import math
import time
import random
from scipy.integrate import quad
from scipy.special import roots_hermite
from typing import Optional 

# Machine epsilon of IEEE double precision (np.double is np.float64), ~2.22e-16.
epsilon_double = np.finfo(np.double).eps

from test_functions import Griewank, AlpineN1, Drop_Wave, Levy, Rastrigin, Ackley
from test_functions import Griewank_numpy, AlpineN1_numpy, Drop_Wave_numpy, Levy_numpy, Rastrigin_numpy, Ackley_numpy

from test_functions import MultiMinimaFunc, MultiMinimaAbsFunc
from test_functions import MultiMinimaFunc_numpy, MultiMinimaAbsFunc_numpy

# Seed every RNG this notebook imports (Python `random`, NumPy, PyTorch) so that
# Restart Kernel -> Run All reproduces the same samples and results.
seed = 30
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

### Create Class for Hamilton-Jacobi Moreau Adaptive Descent (HJ_MAD)

The Moreau envelope of $f$ is given by
\begin{equation}
  u(x,t) \triangleq \inf_{z\in \mathbb{R}^n} f(z) + \dfrac{1}{2t}\|z-x\|^2.
\end{equation}

We leverage the fact that the Moreau envelope above is the (Hopf-Lax) solution of the Hamilton-Jacobi equation
\begin{equation}
  \begin{split}
    u_t^\delta  + \frac{1}{2}\|Du^\delta  \|^2 = \frac{\delta}{2} \Delta u^\delta \qquad &\text{ in }  \mathbb{R}^n\times (0,T]
    \\
    u^\delta = f \qquad &\text{ in } \mathbb{R}^n\times \{t = 0\}
  \end{split}
\end{equation}
when $\delta = 0$. 

By adding a viscous term ($\delta > 0$), we are able to approximate the solution to the HJ equation using the Cole-Hopf formula to obtain
\begin{equation}
  u^\delta(x,t) = - \delta \ln\Big(\Phi_t * \exp(-f/\delta)\Big)(x) = - \delta \ln \int_{\mathbb{R}^n} \Phi(x-y,t)  \exp\left(\frac{-f(y)}{\delta}\right) dy 
\end{equation}
where 
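\begin{equation}
  \Phi(x,t) = \frac{1}{{(2\pi \delta t)}^{n/2}} \exp\left(\frac{-\|x\|^2}{2\delta t}\right)
\end{equation}
is the heat kernel of $v_t = \frac{\delta}{2}\Delta v$ (the equation satisfied by $v = \exp(-u^\delta/\delta)$), i.e. the density of $\mathcal{N}(0, \delta t I)$.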
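This allows us to write the Moreau envelope (and its gradient) explicitly as an expectation. In this case, we compute the gradient as
\begin{equation}
  \nabla u^\delta(x,t) = \dfrac{1}{t}\cdot  \dfrac{\mathbb{E}_{y\sim  \mathbb{P}_{x,t}}\left[(x-y) \exp\left(-\delta^{-1}\tilde{f}(y)\right) \right]}
    {\mathbb{E}_{y\sim  \mathbb{P}_{x,t}}\left[ \exp\left(-\delta^{-1} \tilde{f}(y)\right) \right]},
\end{equation}
where $\mathbb{P}_{x,t} = \mathcal{N}(x, \delta t I)$ is the Gaussian with density $\Phi(x-\cdot\,,t)$, and $\tilde f = f - c$ for a constant $c$. The constant cancels in the ratio; in the code it is chosen so that the largest sampled exponent is zero, which prevents overflow.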


### Define classes for HJ-MAD, Pure Random Search, and Gradient Descent 

In [None]:
# ------------------------------------------------------------------------------------------------------------
# HJ Moreau Adaptive Descent
# ------------------------------------------------------------------------------------------------------------

class HJ_MAD:
    '''
        Hamilton-Jacobi Moreau Adaptive Descent (HJ_MAD) is used to solve nonconvex minimization
        problems via a zeroth-order sampling scheme: every step moves xk towards a sampled estimate
        of prox_{tf}(xk), i.e. along the gradient of the (viscous) Moreau envelope of f.

        Inputs:
          1)  f                  = function to be minimized. Inputs have size (n_samples x n_features). Outputs have size n_samples
          2)  x_true             = true global minimizer (used for the error history and the tol stopping test)
          3)  delta              = coefficient of viscous term in the HJ equation
          4)  int_samples        = number of samples (MC/NMC) or quadrature nodes (GHQ) used to approximate the expectations
          5)  t_vec              = time vector containing [initial time, minimum time allowed, maximum time]
          6)  max_iters          = max number of iterations (cast to int, so e.g. 5e4 is accepted)
          7)  tol                = stopping tolerance on ||xk - x_true||
          8)  theta              = parameter used to update tk
          9)  beta               = exponential averaging term for gradient (stored, currently unused)
          10) eta_vec            = vector containing [eta_minus, eta_plus], where eta_minus < 1 and eta_plus > 1 (part of time update)
          11) alpha              = step size. has to be in between (1-sqrt(eta_minus), 1+sqrt(eta_plus))
          12) fixed_time         = if True, tk stays at t_vec[0]; otherwise it is adapted by update_time
          13) verbose            = boolean for printing
          14) rescale0           = initial rescale factor applied to f/delta to avoid exp under/overflow
          15) momentum           = momentum coefficient for accelerated descent (None disables acceleration)
          16) saturate_tol       = relative step size below which a coordinate-descent run (cd=True) stops
          17) integration_method = one of 'MC', 'NMC', 'GHQ', 'quad'
          18) rootsGHQ           = optional (nodes, weights) of Gauss-Hermite quadrature; built from int_samples
                                   when integration_method == 'GHQ' and none is given

        Outputs of run():
          1) x_opt                    = iterate with the lowest function value seen
          2) xk_hist                  = update history
          3) tk_hist                  = time history
          4) xk_error_hist            = error to true solution history
          5) rel_grad_uk_norm_hist    = relative grad norm history of Moreau envelope
          6) fk_hist                  = function value history
    '''
    _INTEGRATION_METHODS = ('MC', 'NMC', 'GHQ', 'quad')

    def __init__(self, f, x_true, delta=0.1, int_samples=100, t_vec = [1.0, 1e-3, 1e1], max_iters=5e4, 
                 tol=5e-2, theta=0.9, beta=[0.9], eta_vec = [0.9, 1.1], alpha=1.0, fixed_time=False, 
                 verbose=True,rescale0=1e-1, momentum=None,saturate_tol=1e-9, integration_method='MC',
                 rootsGHQ=None):
        if integration_method not in self._INTEGRATION_METHODS:
            raise ValueError(f"integration_method must be one of {self._INTEGRATION_METHODS}, "
                             f"got {integration_method!r}")

        self.delta              = delta
        self.f                  = f
        self.int_samples        = int_samples
        # max_iters is used for torch.zeros(...) and range(...), which both require an int.
        self.max_iters          = int(max_iters)
        self.tol                = tol
        self.t_vec              = t_vec
        self.theta              = theta
        self.x_true             = x_true
        self.beta               = beta
        self.alpha              = alpha
        self.eta_vec            = eta_vec
        self.fixed_time         = fixed_time
        self.verbose            = verbose
        self.momentum           = momentum
        self.rescale0           = rescale0
        self.saturate_tol       = saturate_tol
        self.integration_method = integration_method

        # Always define the attribute so compute_grad_uk_GHQ never hits an AttributeError.
        self.rootsGHQ = rootsGHQ
        if integration_method == 'GHQ' and rootsGHQ is None:
            device = x_true.device if torch.is_tensor(x_true) else None
            nodes, weights = roots_hermite(int_samples)
            self.rootsGHQ = (torch.tensor(nodes, dtype=torch.float64, device=device),
                             torch.tensor(weights, dtype=torch.float64, device=device))

        # check that alpha is in right interval
        assert(alpha >= 1-np.sqrt(eta_vec[0]))
        assert(alpha <= 1+np.sqrt(eta_vec[1]))

    def compute_grad_uk_NMC(self, x, t, dim=slice(None)):
        """
        Estimate t * grad u(x, t) by self-normalised Monte Carlo written in terms of the noise z.

        Samples y = x - sqrt(delta * t_rescaled) * self.z (all coordinates, or only `dim`) and weights
        them with exp(-rescale * f(y) / delta), shifted by its maximum. The rescale factor is halved until
        the mean weight lies in [1e-15, 1e15] or drops below 1e-15. self.rescale0 is not modified.

        Parameters:
            x (torch.Tensor): point of shape (n_features,).
            t (float): time parameter.
            dim (int or slice): coordinate to sample; slice(None) samples every coordinate.

        Returns:
            grad_uk (torch.Tensor): x - prox_xk, of shape (n_features,) (the factor t is cancelled by the update).
            prox_xk (torch.Tensor): estimated proximal point.
        """
        underflow = 1e-15
        overflow = 1e15
        min_rescale = 1e-15
        rescale_factor = self.rescale0

        n_features = x.shape[0]
        device = x.device
        full_dim = dim == slice(None)

        y = x.expand(self.int_samples, n_features).clone()
        while True:
            t_rescaled = t / rescale_factor
            sqrt_factor = np.sqrt(self.delta * t_rescaled)

            if full_dim:
                y = x - self.z * sqrt_factor
            else:
                y[:, dim] = x[dim] - self.z * sqrt_factor

            rescaled_exponent = -rescale_factor * self.f(y) / self.delta
            # Subtract the max exponent so the largest weight is exactly 1 (no overflow).
            exp_term = torch.exp(rescaled_exponent - torch.max(rescaled_exponent))
            denominator = torch.mean(exp_term)

            if (denominator >= underflow and denominator <= overflow) or rescale_factor < min_rescale:
                break
            rescale_factor /= 2

        weights = exp_term.view(self.int_samples, 1) if full_dim else exp_term
        # Average over the sample axis only, so that each coordinate keeps its own estimate.
        numerator = torch.mean(self.z * weights, dim=0)

        # y = x - sqrt_factor * z  =>  x - prox = sqrt_factor * E_w[z]
        grad_uk = torch.zeros(n_features, dtype=torch.float64, device=device)
        grad_uk[dim] = sqrt_factor * numerator / denominator

        prox_xk = x - grad_uk
        return grad_uk, prox_xk

    def compute_grad_uk_MC(self, x, t, dim=slice(None)):
        """
        Estimate the proximal point (and x - prox) of f by self-normalised Monte Carlo.

        Samples y = x + sqrt(delta * t / rescale) * self.z (all coordinates, or only `dim`) and weights
        them with exp(-rescale * f(y) / delta), shifted by its maximum. The rescale factor is halved until
        the mean weight lies in [1e-15, 1e15] or drops below 1e-10. Afterwards self.rescale0 is doubled
        if no halving was needed, otherwise set to the rescale factor that was used.

        Parameters:
            x (torch.Tensor): point of shape (n_features,).
            t (float): time parameter.
            dim (int or slice): coordinate to sample; slice(None) samples every coordinate.

        Returns:
            grad_uk (torch.Tensor): x - prox_xk (the factor t is cancelled by the update formula).
            prox_xk (torch.Tensor): weighted mean of the samples.
        """
        underflow = 1e-15
        overflow = 1e15
        min_rescale = 1e-10
        rescale_factor = self.rescale0
        iterations = 0
        n_features = x.shape[0]

        while True:
            standard_dev = np.sqrt(self.delta * t / rescale_factor)

            if dim == slice(None):
                y = standard_dev * self.z + x
            else:
                y = x.expand(self.int_samples, n_features).clone()
                y[:, dim] = standard_dev * self.z + x[dim]

            rescaled_exponent = -rescale_factor * self.f(y) / self.delta
            # Subtract the max exponent so the largest weight is exactly 1 (no overflow).
            exp_term = torch.exp(rescaled_exponent - torch.max(rescaled_exponent))
            v_delta = torch.mean(exp_term)

            if (v_delta >= underflow and v_delta <= overflow) or rescale_factor < min_rescale:
                break
            rescale_factor /= 2
            iterations += 1

        # Increase the initial rescale factor if it was not limiting, to better utilize samples.
        if iterations == 0:
            self.rescale0 = self.rescale0 * 2
        else:
            self.rescale0 = rescale_factor

        numerator = torch.mean(y * exp_term.view(self.int_samples, 1), dim=0)
        prox_xk = numerator / v_delta
        grad_uk = x - prox_xk
        return grad_uk, prox_xk

    def compute_grad_uk_GHQ(self, x, t, dim):
        """
        Estimate t * grad u(x, t) along one coordinate using Gauss-Hermite quadrature.

        With nodes/weights (z_i, w_i) of self.rootsGHQ, y_i = x - z_i * sqrt(2 * delta * t_rescaled) in
        coordinate `dim`. The sums sum_i w_i F(z_i) and sum_i w_i z_i F(z_i) approximate the
        integrals against exp(-z^2). The rescale factor is halved until the (positive) denominator lies in
        [1e-15, 1e15] or drops below 1e-15. self.rescale0 is then doubled if no halving was needed,
        otherwise set to the rescale factor that was used.

        Parameters:
            x (torch.Tensor): point of shape (n_features,).
            t (float): time parameter.
            dim (int): coordinate to update (only a single coordinate is supported).

        Returns:
            step (torch.Tensor): t_rescaled * grad u, i.e. x - prox_xk; nonzero only at `dim`.
            prox_xk (torch.Tensor): estimated proximal point.
        """
        n_features = x.shape[0]
        min_rescale = 1e-15
        underflow = 1e-15
        overflow = 1e15

        z, weights = self.rootsGHQ
        rescale_factor = self.rescale0
        rescale_counter = 0
        device = x.device

        # One row per quadrature node; only column `dim` varies.
        y = x.clone().expand(self.int_samples, n_features).clone().contiguous()

        while True:
            t_rescaled = t / rescale_factor
            y[:, dim] = x[dim] - z * np.sqrt(2 * self.delta * t_rescaled)

            rescaled_f_exponent = -rescale_factor * self.f(y) / self.delta
            F_exp = torch.exp(rescaled_f_exponent - torch.max(rescaled_f_exponent))

            # The quadrature of a positive integrand is positive. It must not be negated, otherwise
            # the range check below can never pass and rescale_factor collapses to min_rescale.
            v_delta = torch.sum(weights * F_exp)

            if (v_delta >= underflow and v_delta <= overflow) or rescale_factor < min_rescale:
                break
            rescale_factor /= 2
            rescale_counter += 1

        # Increase the initial rescale factor if it was not limiting, to better utilize samples.
        if rescale_counter == 0:
            self.rescale0 = self.rescale0 * 2
        else:
            self.rescale0 = rescale_factor

        numerator = np.sqrt(2 / (self.delta * t_rescaled)) * torch.sum(weights * z * F_exp)

        # t_rescaled * grad = sqrt(2*delta*t_rescaled) * E_w[z] = x - prox
        grad_uk = torch.zeros(n_features, dtype=torch.float64, device=device)
        grad_uk[dim] = self.delta * numerator / v_delta

        prox_xk = x - t_rescaled * grad_uk
        return t_rescaled * grad_uk, prox_xk

    def compute_grad_uk_quad(self, x, t, dim):
        """
        Estimate t * grad u(x, t) along one coordinate with SciPy's adaptive `quad`.

        Integrates over z in [-10, 10] with y = x - z * sqrt(2 * delta * t_rescaled) in coordinate `dim`
        and integrand exp(-rescale * f(y) / delta - z^2). The rescale factor is halved until the
        denominator lies in [1e-15, 1e15] or drops below 1e-15. self.rescale0 is not modified.

        Parameters:
            x (torch.Tensor): point of shape (n_features,).
            t (float): time parameter.
            dim (int): coordinate to update.

        Returns:
            step (torch.Tensor): t_rescaled * grad u, i.e. x - prox_xk; nonzero only at `dim`.
            prox_xk (torch.Tensor): estimated proximal point.
        """
        n_features = x.shape[0]
        min_rescale = 1e-15
        underflow = 1e-15
        overflow = 1e15

        rescale_factor = self.rescale0
        device = x.device

        def integrand(z, rescale, sqrt_factor, for_numerator):
            # Evaluate exp(-rescale * f(y)/delta - z^2) (times z for the numerator) at one node z.
            y = x.clone()
            y[dim] = x[dim] - z * sqrt_factor
            exponent = -rescale * self.f(y.view(1, -1)) / self.delta
            value = torch.exp(exponent - z**2).item()
            return z * value if for_numerator else value

        while True:
            t_rescaled = t / rescale_factor
            sqrt_factor = np.sqrt(2 * self.delta * t_rescaled)

            denominator_result, _ = quad(integrand, -10, 10,
                                         args=(rescale_factor, sqrt_factor, False))

            if (underflow <= denominator_result <= overflow) or rescale_factor < min_rescale:
                break
            rescale_factor /= 2

        numerator_result, _ = quad(integrand, -10, 10,
                                   args=(rescale_factor, sqrt_factor, True))

        # y = x - sqrt_factor * z  =>  t_rescaled * grad = sqrt_factor * N / D = x - prox
        grad_uk = torch.zeros(n_features, dtype=torch.float64, device=device)
        grad_uk[dim] = (self.delta * np.sqrt(2 / (self.delta * t_rescaled))
                        * numerator_result / denominator_result)

        prox_xk = x - t_rescaled * grad_uk
        return t_rescaled * grad_uk, prox_xk

    def gradient_descent(self, xk, tk, update_dim=slice(None)):
        """
        Take one HJ-MAD step from xk using the configured integration method.

        Returns:
            xk_plus1 (torch.Tensor): updated iterate (only `update_dim` changes).
            grad_uk (torch.Tensor): the step direction returned by the estimator.

        Raises:
            ValueError: if self.integration_method is not recognised.
        """
        if self.integration_method == 'MC':
            grad_uk, prox_xk = self.compute_grad_uk_MC(xk, tk, update_dim)
            xk_plus1 = xk.clone()
            xk_plus1[update_dim] = xk[update_dim] - self.alpha * (xk[update_dim] - prox_xk[update_dim])
        elif self.integration_method == 'GHQ':
            grad_uk, _ = self.compute_grad_uk_GHQ(xk, tk, update_dim)
            xk_plus1 = xk - self.alpha * grad_uk
        elif self.integration_method == 'quad':
            grad_uk, _ = self.compute_grad_uk_quad(xk, tk, update_dim)
            xk_plus1 = xk - self.alpha * grad_uk
        elif self.integration_method == 'NMC':
            grad_uk, _ = self.compute_grad_uk_NMC(xk, tk, update_dim)
            xk_plus1 = xk - self.alpha * grad_uk
        else:
            raise ValueError(f"Unknown integration_method {self.integration_method!r}")

        return xk_plus1, grad_uk

    def update_time(self, tk, rel_grad_uk_norm):
        '''
          time step rule

          if rel grad norm <= theta (gradient is shrinking): increase tk, capped at T = t_vec[2]
          else: decrease tk, floored at t_min = t_vec[1]
        '''
        eta_minus = self.eta_vec[0]
        eta_plus = self.eta_vec[1]
        T = self.t_vec[2]
        t_min = self.t_vec[1]

        if rel_grad_uk_norm <= self.theta:
            # increase t when relative gradient norm is smaller than theta
            tk = min(eta_plus*tk, T)
        else:
            # decrease t when relative gradient norm is larger than theta
            tk = max(eta_minus*tk, t_min)

        return tk

    def stopping_criteria(self, k, cd, history):
        '''
          Stopping Criteria for HJ-MAD and HJ-MAD-CD.

          Returns True when ||xk - x_true|| < tol, when k is the last allowed iteration, or (cd=True only)
          when the relative step falls below saturate_tol or the error increased in more than 3 of the
          last 10 steps. Returns False otherwise.
        '''
        xk_hist, xk_error_hist, rel_grad_uk_norm_hist, fk_hist, tk_hist = history
        # Approximate count: one batch of int_samples evaluations per completed iteration.
        n_evals = (k + 1) * self.int_samples

        if xk_error_hist[k] < self.tol:
            if self.verbose:
                print('HJ-MAD converged with rel grad norm {:6.2e}'.format(rel_grad_uk_norm_hist[k]))
                print('iter = ', k, ', number of function evaluations = ', n_evals)
            return True

        # k runs over range(max_iters), so the last iteration is max_iters - 1.
        if k >= self.max_iters - 1:
            if self.verbose:
                print('HJ-MAD failed to converge with rel grad norm {:6.2e}'.format(rel_grad_uk_norm_hist[k]))
                print('iter = ', k, ', number of function evaluations = ', n_evals)
                print('Used fixed time = ', self.fixed_time)
            return True

        if cd:
            if k > 0 and torch.norm(xk_hist[k] - xk_hist[k-1]) < self.saturate_tol*torch.norm(xk_hist[k-1]):
                if self.verbose:
                    print('HJ-MAD converged due to error saturation with rel grad norm {:6.2e}'.format(rel_grad_uk_norm_hist[k]))
                    print('iter = ', k, ', number of function evaluations = ', n_evals)
                return True
            # TODO: heuristic; replace with a proper stopping criterion.
            if k > 10 and (torch.diff(xk_error_hist[k-10:k+1]) > 0).sum() > 3:
                if self.verbose:
                    print('HJ-MAD stopped due to non-monotonic error decrease with rel grad norm {:6.2e}'.format(rel_grad_uk_norm_hist[k]))
                    print('iter = ', k, ', number of function evaluations = ', n_evals)
                return True

        return False

    def run(self, x0, cd=False, update_dim=slice(None)):
        """
        Run the HJ-MAD algorithm to minimize the function.

        The noise self.z (MC/NMC) is drawn once per run and reused at every iteration. x0 is never
        modified in place.

        Parameters:
        x0 (torch.Tensor): Initial guess for the minimizer, shape (n_features,).
        cd (bool): Coordinate descent flag (enables the extra stopping tests).
        update_dim (int or slice): Dimension to update.

        Returns:
        x_opt (torch.Tensor): Iterate with the lowest recorded function value.
        xk_hist (torch.Tensor): Update history (dtype of x0).
        tk_hist (torch.Tensor): Time history.
        xk_error_hist (torch.Tensor): Error to true solution history.
        rel_grad_uk_norm_hist (torch.Tensor): Relative grad norm history of Moreau envelope.
        fk_hist (torch.Tensor): Function value history.
        """
        n_features = x0.shape[0]

        # Keep iterates in x0's precision so the saturate_tol test is not swamped by float32 rounding.
        xk_hist = torch.zeros(self.max_iters, n_features, dtype=x0.dtype)
        xk_error_hist = torch.zeros(self.max_iters)
        rel_grad_uk_norm_hist = torch.zeros(self.max_iters)
        fk_hist = torch.zeros(self.max_iters)
        tk_hist = torch.zeros(self.max_iters)

        # Work on copies: in-place momentum updates must not leak into the caller's x0.
        xk = x0.clone()
        x_opt = xk.clone()
        best_f = float('inf')
        tk = self.t_vec[0]

        if self.momentum is not None:
            xk_minus_1 = xk.clone()

        rel_grad_uk_norm = 1.0

        if self.integration_method in ('NMC', 'MC'):
            if update_dim == slice(None):  # randomize all n_features of y
                self.z = torch.randn(self.int_samples, n_features, device=x0.device)
            else:
                self.z = torch.randn(self.int_samples, device=x0.device)

        fmt = '[{:3d}]: fk = {:6.2e} | xk_err = {:6.2e} | |grad_uk| = {:6.2e} | tk = {:6.2e}'
        if self.verbose:
            print('-------------------------- RUNNING HJ-MAD ---------------------------')
            print('dimension = ', n_features, 'n_samples = ', self.int_samples)

        # Initial gradient, needed for the first relative gradient norm.
        _, grad_uk = self.gradient_descent(xk, tk, update_dim)

        k = -1
        for k in range(self.max_iters):
            xk_hist[k, :] = xk
            rel_grad_uk_norm_hist[k] = rel_grad_uk_norm
            xk_error_hist[k] = torch.norm(xk - self.x_true)
            tk_hist[k] = tk
            fk_hist[k] = self.f(xk.view(1, n_features)).reshape(-1)[0]

            if self.verbose:
                print(fmt.format(k + 1, fk_hist[k], xk_error_hist[k], rel_grad_uk_norm_hist[k], tk))

            # Track the best iterate *before* a possible break so a converged xk is returned.
            if fk_hist[k].item() < best_f:
                best_f = fk_hist[k].item()
                x_opt = xk.clone()

            if self.stopping_criteria(k, cd, [xk_hist, xk_error_hist, rel_grad_uk_norm_hist, fk_hist, tk_hist]):
                break

            grad_uk_norm_old = torch.norm(grad_uk)

            # Accelerate gradient descent if momentum is not None
            yk = xk.clone()
            if self.momentum is not None and k > 0:
                yk[update_dim] = xk[update_dim] + self.momentum * (xk[update_dim] - xk_minus_1[update_dim])
                xk_minus_1 = xk.clone()

            xk, grad_uk = self.gradient_descent(yk, tk, update_dim)

            grad_uk_norm = torch.norm(grad_uk)
            rel_grad_uk_norm = grad_uk_norm / (grad_uk_norm_old + 1e-12)

            if not self.fixed_time:
                tk = self.update_time(tk, rel_grad_uk_norm)

        return x_opt, xk_hist[0:k+1,:], tk_hist[0:k+1], xk_error_hist[0:k+1], rel_grad_uk_norm_hist[0:k+1], fk_hist[0:k+1]
    

In [None]:
class HJ_MAD_CoordinateDescent(HJ_MAD):
    """
    Hamilton-Jacobi Moreau Adaptive Descent (HJ_MAD) coordinate descent.

    Cycles through the coordinates of x. For each coordinate it runs the base HJ_MAD solver with only
    that coordinate free (update_dim=dim), treating f as a 1D function, and warm-starts the next run
    from the result.

    Attributes:
        f (callable): The function to be minimized.
        x_true (torch.Tensor): The true global minimizer.
        delta (float): Coefficient of the viscous term in the HJ equation.
        int_samples (int): Number of samples (or GHQ nodes) used to approximate the expectations.
        t_vec (list): Time vector containing [initial time, minimum time allowed, maximum time].
        max_iters (int): Maximum number of iterations per 1D run (cast to int).
        tol (float): Stopping tolerance on ||xk - x_true||.
        alpha (float): Step size.
        beta (float): Exponential averaging term for gradient beta.
        eta_vec (list): Vector containing [eta_minus, eta_plus].
        theta (float): Parameter used to update tk.
        fixed_time (bool): Whether to keep tk fixed instead of adapting it.
        plot (bool): Whether to plot every 1D descent.
        verbose (bool): Whether to print progress.
        rescale0 (float): Initial rescale factor.
        momentum (float): Momentum term for acceleration.
        integration_method (str): 'MC', 'NMC', 'GHQ' or 'quad'.

    Methods:
        plot_1d_descent(xk, xk_new, dim): Plots f along one coordinate around xk.
        run(x0, num_cycles): Runs the coordinate descent optimization process.
    """

    def __init__(self, f, x_true, delta=0.1, int_samples=100, t_vec=[1.0, 1e-3, 1e1], max_iters=5e4,
                 tol=5e-2, theta=0.9, beta=[0.9], eta_vec=[0.9, 1.1], alpha=1.0, fixed_time=False,
                 plot=False, verbose=True, rescale0=1e-1, momentum=None,saturate_tol=1e-9,integration_method='MC'):
        self.tol = tol
        self.plot = plot

        rootsGHQ = None
        if integration_method == 'GHQ':
            # Gauss-Hermite nodes/weights depend only on int_samples, so compute them once.
            device = x_true.device
            nodes, weights = roots_hermite(int_samples)
            rootsGHQ = (torch.tensor(nodes, dtype=torch.float64, device=device),
                        torch.tensor(weights, dtype=torch.float64, device=device))

        # max_iters is cast to int because the default 5e4 is a float and is used for tensor sizes.
        super().__init__(f=f, x_true=x_true, delta=delta, int_samples=int_samples, t_vec=t_vec,
                         max_iters=int(max_iters), tol=self.tol, alpha=alpha, beta=beta, eta_vec=eta_vec,
                         theta=theta, fixed_time=fixed_time, verbose=verbose, rescale0=rescale0,
                         momentum=momentum, saturate_tol=saturate_tol,
                         integration_method=integration_method, rootsGHQ=rootsGHQ)

    def plot_1d_descent(self, xk, xk_new, dim, domain=(-15, 15), num_points=1000):
        """
        Plots the 1D descent for the current dimension.

        Args:
            xk (torch.Tensor): Position before the 1D run.
            xk_new (torch.Tensor): Position after the 1D run.
            dim (int): The current dimension being optimized.
            domain (tuple): The range over which to vary the current dimension.
            num_points (int): Number of points to sample in the domain.
        """
        t0 = self.t_vec[0]
        x_vals = np.linspace(domain[0], domain[1], num_points)
        f_vals = []
        h_vals = []

        for x_val in x_vals:
            xk_varied = xk.clone()
            xk_varied[dim] = x_val
            f_val = self.f(xk_varied.unsqueeze(0)).item()
            f_vals.append(f_val)
            # Moreau-envelope objective at time t0: f(x) + ||x - xk||^2 / (2 t0)
            h_vals.append(f_val + 1/(2*t0) * torch.norm(xk_varied - xk).item()**2)

        # Sampling standard deviation at the initial time and rescale factor.
        std_dev = np.sqrt(self.delta * t0 / self.rescale0)
        std_dev_minus = xk.clone()
        std_dev_plus = xk.clone()
        std_dev_minus[dim] -= std_dev
        std_dev_plus[dim] += std_dev

        fig, ax = plt.subplots()
        ax.plot(x_vals, f_vals, '-', color='black', label=f'f(x) Dimension {dim + 1}')
        ax.plot(x_vals, h_vals, '-', color='blue', label=r'$f(x) + \frac{1}{2t_0} ||x - x_k||^2$')
        ax.plot(xk[dim].item(), self.f(xk.unsqueeze(0)).item(), '*', color='red', label=f'xk Dimension {dim + 1}')
        ax.plot(xk_new[dim].item(), self.f(xk_new.unsqueeze(0)).item(), '*', color='green', label=f'New xk Dimension {dim + 1}')
        ax.plot(std_dev_minus[dim].item(), self.f(std_dev_minus.unsqueeze(0)).item(), 'x', color='purple', label='Std Devs')
        ax.plot(std_dev_plus[dim].item(), self.f(std_dev_plus.unsqueeze(0)).item(), 'x', color='purple')
        ax.set_xlabel(f'Dimension {dim + 1}')
        ax.set_ylabel('Function Value')
        ax.set_title(f'1D Descent for Dimension {dim + 1}')
        ax.legend()
        plt.show()

    def run(self, x0, num_cycles=10):
        """
        Runs the coordinate descent optimization process.

        Args:
            x0 (torch.Tensor): Initial guess for the minimizer.
            num_cycles (int): Number of cycles to run the coordinate descent.

        Returns:
            torch.Tensor: Optimal solution found by the coordinate descent.
            torch.Tensor: Iterate after every (cycle, coordinate) step; row 0 is x0. On early
                convergence only the rows filled so far are returned.
            list: Concatenated xk history of all 1D runs.
            list: Concatenated error history of all 1D runs.
            list: Concatenated function value history of all 1D runs.
        """
        xk = x0.clone()
        n_features = x0.shape[0]
        CD_xk_hist = torch.zeros(n_features*num_cycles + 1, n_features, dtype=x0.dtype)
        CD_xk_hist[0, :] = xk
        full_xk_hist = []
        full_fk_hist = []
        full_xk_error_hist = []

        for cycle in range(num_cycles):
            for dim in range(n_features):
                xk_prev = xk.clone()

                if self.verbose:
                    print(f"Cycle {cycle + 1}/{num_cycles} and Dimension {dim + 1}/{n_features}")

                # Optimize with respect to the current coordinate only
                xk, xk_hist, tk_hist, xk_error_hist, rel_grad_uk_norm_hist, fk_hist = super().run(xk, cd=True, update_dim=dim)

                if self.plot:
                    self.plot_1d_descent(xk_prev, xk, dim)

                full_xk_hist.extend(xk_hist.numpy())
                full_xk_error_hist.extend(xk_error_hist.numpy())
                full_fk_hist.extend(fk_hist.numpy())

                # One row per (cycle, coordinate) step; row 0 holds x0.
                row = cycle*n_features + dim + 1
                CD_xk_hist[row, :] = xk

                if xk_error_hist[-1] < self.tol:
                    if self.verbose:
                        print(f'HJ-MAD-CD converged. Error: {xk_error_hist[-1]:.3f}, tolerance: {self.tol}.')
                    return xk, CD_xk_hist[:row + 1], full_xk_hist, full_xk_error_hist, full_fk_hist

        return xk, CD_xk_hist, full_xk_hist, full_xk_error_hist, full_fk_hist

In [None]:
import concurrent.futures

class HJ_MAD_CoordinateDescent_parallel(HJ_MAD):
    """
    Hamilton-Jacobi Moreau Adaptive Descent (HJ_MAD) with parallel (Jacobi-style) coordinate descent.

    Every cycle runs one independent 1D HJ-MAD descent per coordinate. All of these runs start from
    the SAME iterate xk (the other coordinates stay frozen at their xk values) and are dispatched to a
    thread pool. Once every run has finished, the optimized coordinates are merged into xk. This
    differs from the sequential (Gauss-Seidel) variant, where each 1D run starts from the result of
    the previous one. Any number of dimensions is supported, not only 2D.

    Attributes:
        f (callable): The function to be minimized.
        x_true (torch.Tensor): The true global minimizer (used to measure the error).
        delta (float): Coefficient of the viscous term in the HJ equation.
        int_samples (int): Number of samples used to approximate expectation in the heat equation solution.
        t_vec (list): Time vector containing [initial time, minimum time allowed, maximum time].
        max_iters (int): Maximum number of iterations (forwarded to HJ_MAD for each 1D run).
        tol (float): Stopping tolerance. Forwarded to HJ_MAD and also compared against
            ||xk - x_true|| of the merged iterate after every cycle.
        alpha (float): Step size.
        beta (float): Exponential averaging term for gradient beta.
        eta_vec (list): Vector containing [eta_minus, eta_plus].
        theta (float): Parameter used to update tk.
        fixed_time (bool): Whether to keep the time t fixed (presumably disables the adaptive
            time update -- confirm in HJ_MAD).
        verbose (bool): Whether to print per-dimension progress and convergence messages.
        plot (bool): Whether to plot the 1D slice of f for every coordinate update.
        rescale0 (float): Initial rescale factor.
        momentum (float): Momentum term for acceleration.

    Any additional keyword arguments (e.g. ``integration_method``) are forwarded to ``HJ_MAD``.

    NOTE(review): the per-coordinate runs share ``self`` across threads. That is only safe if
    ``HJ_MAD.run(..., cd=True, update_dim=...)`` does not mutate instance state -- TODO confirm.
    If the integration method draws from torch's global RNG (presumably MC/NMC), the draws of the
    concurrent runs interleave, so results are not bit-reproducible even with a fixed seed.

    Methods:
        plot_1d_descent(xk, dim): Plots f along dimension ``dim`` through the point xk.
        optimize_dimension(xk, dim): Runs a 1D HJ-MAD descent on a single coordinate.
        run(x0, num_cycles): Runs the parallel coordinate descent optimization process.
    """

    def __init__(self, f, x_true, delta=0.1, int_samples=100, t_vec=[1.0, 1e-3, 1e1], max_iters=5e4,
                 tol=5e-2, theta=0.9, beta=[0.9], eta_vec=[0.9, 1.1], alpha=1.0, fixed_time=False,
                 verbose=True, plot=False, rescale0=1e-1, momentum=None, saturate_tol=1e-9, **kwargs):
        self.tol = tol
        self.plot = plot
        # **kwargs lets callers pass HJ_MAD options this class does not name explicitly,
        # such as integration_method='MC'/'NMC'/'GHQ'.
        super().__init__(f=f, x_true=x_true, delta=delta, int_samples=int_samples, t_vec=t_vec, max_iters=max_iters,
                         tol=self.tol, alpha=alpha, beta=beta, eta_vec=eta_vec, theta=theta, fixed_time=fixed_time,
                         verbose=verbose, rescale0=rescale0, momentum=momentum, saturate_tol=saturate_tol,
                         **kwargs)

    def plot_1d_descent(self, xk, dim, domain=(-10, 10), num_points=100):
        """
        Plots f along a single dimension through the point xk.

        The other coordinates are held at their xk values. The markers show xk and the points one
        standard deviation sqrt(delta * t_vec[0]) on either side of it along ``dim``.

        Args:
            xk (torch.Tensor): Point through which the 1D slice is taken.
            dim (int): The dimension being varied.
            domain (tuple): The range over which to vary the current dimension.
            num_points (int): Number of points to sample in the domain.
        """
        x_vals = np.linspace(domain[0], domain[1], num_points)
        f_vals = []

        for x in x_vals:
            xk_varied = xk.clone()
            xk_varied[dim] = x
            f_vals.append(self.f(xk_varied.unsqueeze(0)).item())

        std_dev = np.sqrt(self.delta * self.t_vec[0])
        std_dev_minus = xk.clone()
        std_dev_plus = xk.clone()
        std_dev_minus[dim] -= std_dev
        std_dev_plus[dim] += std_dev

        plt.figure()
        plt.plot(x_vals, f_vals, '-', color='black', label=f'f(x) Dimension {dim + 1}')
        plt.plot(xk[dim].item(), self.f(xk.unsqueeze(0)).item(), '*', color='red', label=f'xk Dimension {dim + 1}')
        plt.plot(std_dev_minus[dim].item(), self.f(std_dev_minus.unsqueeze(0)).item(), 'x', color='purple', label='Std Dev Minus')
        plt.plot(std_dev_plus[dim].item(), self.f(std_dev_plus.unsqueeze(0)).item(), '*', color='purple', label='Std Dev Plus')
        plt.xlabel(f'Dimension {dim + 1}')
        plt.ylabel('Function Value')
        plt.title(f'1D Descent for Dimension {dim + 1}')
        plt.legend()
        plt.show()

    def optimize_dimension(self, xk, dim):
        """
        Optimize with respect to a single dimension.

        Args:
            xk (torch.Tensor): Starting point; only coordinate ``dim`` is updated by HJ_MAD.
            dim (int): The dimension to optimize.

        Returns:
            torch.Tensor: Full position vector returned by the 1D run.
            History of positions for the run (as returned by HJ_MAD.run).
            History of errors for the run.
            History of function values for the run.
        """
        xk, xk_hist, tk_hist, xk_error_hist, rel_grad_uk_norm_hist, fk_hist = super().run(xk, cd=True, update_dim=dim)
        return xk, xk_hist, xk_error_hist, fk_hist

    def run(self, x0, num_cycles=10):
        """
        Runs the parallel coordinate descent optimization process.

        Args:
            x0 (torch.Tensor): Initial guess for the minimizer (1D tensor of length n_features).
            num_cycles (int): Number of cycles; each cycle updates every coordinate once.

        Returns:
            torch.Tensor: Final (merged) iterate.
            torch.Tensor: Iterate after every coordinate merge, shape (n_features * num_cycles + 1, n_features).
                Row 0 is x0; rows after the last completed update stay zero on early convergence.
            list: Concatenated position histories of all 1D runs.
            list: Concatenated error histories of all 1D runs.
            list: Concatenated function-value histories of all 1D runs.
        """
        xk = x0.clone()
        n_features = x0.shape[0]
        CD_xk_hist = torch.zeros(n_features * num_cycles + 1, n_features)
        CD_xk_hist[0, :] = xk
        full_xk_hist = []
        full_fk_hist = []
        full_xk_error_hist = []

        # One pool serves every cycle instead of being rebuilt each time.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for cycle in range(num_cycles):
                print("Cycle", cycle + 1)

                # Jacobi-style update: every 1D run starts from this frozen copy of the iterate.
                xk_start = xk.clone()
                futures = {dim: executor.submit(self.optimize_dimension, xk_start.clone(), dim)
                           for dim in range(n_features)}

                # Merge all results (in dimension order, so the histories are deterministic)
                # before judging convergence; returning mid-merge would drop updated coordinates.
                for dim, future in futures.items():
                    xk_dim, xk_hist, xk_error_hist, fk_hist = future.result()
                    if self.verbose:
                        print(f"Cycle {cycle + 1}/{num_cycles} and Dimension {dim + 1}/{n_features}")

                    if self.plot:
                        # Slice through the point this 1D run actually started from.
                        self.plot_1d_descent(xk_start, dim)

                    xk[dim] = xk_dim[dim]
                    full_xk_hist.extend(xk_hist.numpy())
                    full_xk_error_hist.extend(xk_error_hist.numpy())
                    full_fk_hist.extend(fk_hist.numpy())

                    # One row per coordinate update, offset by the row that holds x0.
                    CD_xk_hist[cycle * n_features + dim + 1, :] = xk

                # Convergence is judged on the merged iterate, not on any single 1D run
                # (each of those only saw the other coordinates at their old values).
                error = torch.norm(xk - self.x_true).item()
                print(f"Error = {error}")
                if error < self.tol:
                    if self.verbose:
                        print(f'HJ-MAD-CD converged. Error: {error:.3f}, tolerance: {self.tol}.')
                    return xk, CD_xk_hist, full_xk_hist, full_xk_error_hist, full_fk_hist

        return xk, CD_xk_hist, full_xk_hist, full_xk_error_hist, full_fk_hist

### Set up hyperparameters for HJ-MAD for different functions

In [None]:
# ---- Default hyperparameters --------------------------------------------------------------
# The per-function configuration blocks below may override any of these.
delta      = 5e-3    # coefficient of the viscous term in the HJ equation
max_iters  = 1000    # previously tried: int(1e5)
tol        = 5e-2    # previously tried: 7e-4
momentum   = 0.64
rescale0   = 0.5

avg_trials = 1       # number of trials the results are averaged over
sat_tol    = 1e-10

# # def f(x):
# #   return MultiMinimaFunc(x)
# # def f_numpy(x):
# #   return MultiMinimaFunc_numpy(x)
# # ax_bry  = 30
# # f_name  = 'MultiMinimaFunc'
# # dim = 1; int_samples = int(100);
# # x0      = -30*torch.ones(dim, dtype=torch.double)
# # x_true  = -1.51034569*torch.ones(dim, dtype=torch.double)

# # delta         = 0.1
# # max_iters     = int(100)
# # tol           = 1e-3
# # momentum = 0.5

# # theta         = 1.0 # note: larger theta => easier to increase time
# # beta          = 0.0
# # t_min     = 1e-1
# # t_max     = 300
# # t_init    = 220
# # alpha     = 0.1
# # eta_min = 0.99
# # eta_plus = 5.0
# # eta_vec = [eta_min, eta_plus]


# # # ----------------------------------------------------------------------------------------------------

# def f(x):
#   return Griewank(x)
# def f_numpy(x):
#   return Griewank_numpy(x)
# ax_bry  = 20
# f_name  = 'Griewank'
# dim = 2; int_samples = int(10000);
# x0      = 10*torch.ones(dim, dtype=torch.double)
# x_true  = torch.zeros(dim, dtype=torch.double)
# rescale0      = 1
# delta         = 1e-6
# max_iters     = int(1e4)
# tol           = 1e-4
# momentum = 0.0

# theta         = 1.0 # note: larger theta => easier to increase time
# beta          = 0.9
# t_min     = 1e-2/delta
# t_max     = int(2)/delta
# t_init    = 1e-2/delta
# alpha     = 5e-2
# eta_min = 0.99
# eta_plus = 5.0
# eta_vec = [eta_min, eta_plus]

# # # ----------------------------------------------------------------------------------------------------

# # def f(x):
# #   return Griewank(x)
# # def f_numpy(x):
# #   return Griewank_numpy(x)
# # ax_bry  = 20
# # f_name  = 'Griewank'
# # dim = 20; int_samples = int(1000);#int(1000000);
# # x0      = 10*torch.ones(dim, dtype=torch.double)
# # x_true  = torch.zeros(dim, dtype=torch.double)
# # rescale0      = 2**(-15)
# # delta         = 1e-6
# # max_iters     = int(1e5)
# # tol           = 5e-2
# # momentum = 0.64

# # theta         = 1.0 # note: larger theta => easier to increase time
# # beta          = 0.9
# # t_min     = 1e-1/delta
# # t_max     = int(2e1)/delta
# # t_init    = 1e-1/delta
# # alpha     = 5e-2
# # eta_min = 0.99
# # eta_plus = 5.0
# # eta_vec = [eta_min, eta_plus]

# # ----------------------------------------------------------------------------------------------------
# # def f(x):
# #   return Griewank(x)
# # def f_numpy(x):
# #   return Griewank_numpy(x)
# # f_name  = 'Griewank'
# # dim = 200; int_samples = int(100); # this one has higher dimension
# # x0      = 10*torch.ones(dim, dtype=torch.double)
# # x_true  = torch.zeros(dim, dtype=torch.double)
# # rescale0      = 2**(-5)
# # sat_tol = 1e-9 #7e-7 or 7e-10 (not sure) for 100 dims, 7e-8 for less than 100 dims
# # theta     = 1.0 # note: larger theta => easier to increase time
# # beta      = 0.9
# # # momentum  = 0.5
# # # beta      = 0.0
# # momentum  = 0.0
# # t_min     = 2e1
# # t_max     = 1e5
# # t_init    = 2e1
# # alpha     = 1.2
# # eta_min = 0.5
# # eta_plus = 5.0
# # eta_vec = [eta_min, eta_plus]

# # ----------------------------------------------------------------------------------------------------

# Active configuration: 500-dimensional Griewank, true minimizer at the origin.
def f(x):
  return Griewank(x)
def f_numpy(x):
  return Griewank_numpy(x)
f_name  = 'Griewank'
# Half-width of the plotting domain. Only the 2D plotting cell reads it; it is defined here (as in
# every other configuration) so that switching dim to 2 does not raise a NameError there.
ax_bry  = 20
dim = 500
int_samples = int(100)
x0      = 10*torch.ones(dim, dtype=torch.double)
x_true  = torch.zeros(dim, dtype=torch.double)
rescale0      = 2**(-6)  # previously tried: 256
sat_tol = 1e-9 # 7e-7 or 7e-10 (not sure) for 100 dims, 7e-8 for less than 100 dims
theta     = 1.0 # note: larger theta => easier to increase time
beta      = 0.9
momentum  = 0.0
t_min     = 2e1
t_max     = 1e5
t_init    = 2e1
alpha     = 1.2
eta_min = 0.5
eta_plus = 5.0
eta_vec = [eta_min, eta_plus]

# # ----------------------------------------------------------------------------------------------------

# def f(x):
#   return Drop_Wave(x)
# def f_numpy(x):
#   return Drop_Wave_numpy(x)
# ax_bry  = 20
# max_iters     = 1000
# f_name  = 'Drop_Wave'
# rescale0      = 1
# dim = 2; int_samples = int(10000)
# x0      = 10*torch.ones(dim, dtype=torch.double)
# x_true  = torch.zeros(dim, dtype=torch.double)

# momentum      = 0.5
# delta         = 1e-4
# theta         = 1.0 # note: larger theta => easier to increase time
# beta          = 0.8
# t_min     = 1e-6
# t_max     = int(2e1)/delta
# t_init    = 1e3
# alpha     = 0.5
# eta_min = 0.5
# eta_plus = 5.0
# eta_vec = [eta_min, eta_plus]

# # # ----------------------------------------------------------------------------------------------------

# def f(x):
#   return AlpineN1(x)
# def f_numpy(x):
#   return AlpineN1_numpy(x)
# ax_bry  = 20
# f_name  = 'AlpineN1'

# dim = 2; int_samples = int(100000)# int(10000);
# x0      = 10*torch.ones(dim, dtype=torch.double)
# x_true  = torch.zeros(dim, dtype=torch.double)

# momentum      = 0.45

# theta         = 1.0 # note: larger theta => easier to increase time
# beta          = 0.0
# # t_max     = int(2e1)/delta
# # t_init    = 1e-3
# # t_min     = t_init
# t_max     = int(2e3)/delta
# t_init    = 1e-3
# t_min     = 1e-4
# alpha     = 0.25
# eta_min = 0.6
# eta_plus = 5.0
# eta_vec = [eta_min, eta_plus]

# # ----------------------------------------------------------------------------------------------------

# def f(x):
#   return Levy(x)
# def f_numpy(x):
#   return Levy_numpy(x)
# ax_bry  = 20
# f_name  = 'Levy'

# # Set the number of trials to run
# rescale0 = 2**(-7)
# tol           = 5e-2
# sat_tol = 1e-12
# max_iters     = 1000

# dim = 2; int_samples = int(500000)
# x0      = -15*torch.ones(dim, dtype=torch.double)
# x_true  = torch.ones(dim, dtype=torch.double)

# theta         = 0.9 # note: larger theta => easier to increase time
# beta          = 0.5

# t_max     = int(2e5)/delta
# t_init    = 1e6
# t_min     = 1e2
# alpha     = 0.25
# eta_min = 0.6
# eta_plus = 1.5
# eta_vec = [eta_min, eta_plus]

# # ----------------------------------------------------------------------------------------------------

# def f(x):
#   return Rastrigin(x)
# def f_numpy(x):
#   return Rastrigin_numpy(x)
# ax_bry  = 20
# f_name  = 'Rastrigin'
# delta=5e-3
# dim = 2; int_samples = int(10000);
# x0      = 10*torch.ones(dim, dtype=torch.double)
# x_true  = torch.zeros(dim, dtype=torch.double)
# momentum      = 0.25
# theta         = 1.0 # note: larger theta => easier to increase time
# beta          = 0.5
# t_max     = int(2e1)/delta
# t_init    = 5.0
# t_min     = t_init
# alpha     = 0.5
# eta_min = 0.5
# eta_plus = 5.0
# eta_vec = [eta_min, eta_plus]
# tol=2e-10

# # # ----------------------------------------------------------------------------------------------------

# def f(x):
#   return Ackley(x)
# def f_numpy(x):
#   return Ackley_numpy(x)
# ax_bry  = 20
# f_name  = 'Ackley'

# dim = 2; int_samples = int(100000);
# x0      = 10*torch.ones(dim, dtype=torch.double)
# x_true  = torch.zeros(dim, dtype=torch.double)
# momentum      = 0.25
# theta         = 1.0 # note: larger theta => easier to increase time
# beta          = 0.9
# t_max     = int(2e1)/delta
# t_init    = 1e-3
# t_min     = t_init
# alpha     = 5e-1
# eta_min = 0.5
# eta_plus = 5.0
# eta_vec = [eta_min, eta_plus]

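Create the HJ-MAD coordinate-descent solvers (GHQ, MC, and NMC integration) for the selected test function, then run HJ-MAD and average its results over `avg_trials` trials.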

In [None]:
# Note: Under this transformation the standard deviation of the Gaussian is 1, hence we have more control over t and delta

# Create HJ_MAD_CoordinateDescent solvers (one per integration method) for the selected function.
# Reset all three first: not every branch below builds every solver, and without the reset a
# solver left over from a previously selected function would silently be reused by later cells.
hj_mad_cd_GHQ = hj_mad_cd_MC = hj_mad_cd_NMC = None

if f_name == 'Ackley':
    hj_mad_cd_GHQ = HJ_MAD_CoordinateDescent(f, x_true, delta=delta*1e-10,
                    int_samples=int(1000), t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=sat_tol,integration_method='GHQ')
elif f_name == 'Griewank' and dim == 500:
    # NOTE: this reassigns the notebook-level tol, which later cells also read.
    tol           = 5e-2
    hj_mad_cd_GHQ = HJ_MAD_CoordinateDescent(f, x_true, delta= 1e-8,
                    int_samples=int(20), t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=1e-10,integration_method='GHQ')

    hj_mad_cd_MC = HJ_MAD_CoordinateDescent(f, x_true, delta=delta,
                    int_samples=int(1000), t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=momentum,saturate_tol=sat_tol,integration_method="MC")

    hj_mad_cd_NMC = HJ_MAD_CoordinateDescent(f, x_true, delta=1e-6,
                    int_samples=int(1000), t_vec=[t_init*1e4, t_min, t_max*1e5], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=momentum,saturate_tol=sat_tol,integration_method="NMC")

elif f_name == 'Rastrigin':
    # NOTE: this reassigns the notebook-level tol, which later cells also read.
    tol = 1e-4
    hj_mad_cd_GHQ = HJ_MAD_CoordinateDescent(f, x_true, delta=delta,
                    int_samples=int(10000), t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=sat_tol,integration_method='GHQ')

    hj_mad_cd_MC = HJ_MAD_CoordinateDescent(f, x_true, delta=delta,
                    int_samples=int(5000), t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=sat_tol*1e5,integration_method="MC")

    hj_mad_cd_NMC = HJ_MAD_CoordinateDescent(f, x_true, delta=5e-5,
                    int_samples=int(5000), t_vec=[20, 1e-5, 1e5], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.1,saturate_tol=sat_tol*1e5,integration_method="NMC")
elif f_name == 'Levy': # All have been tuned
    hj_mad_cd_GHQ = HJ_MAD_CoordinateDescent(f, x_true, delta=1e-13,
                    int_samples=int(80), t_vec=[0.5, 0.5, 1], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.6,saturate_tol=sat_tol,integration_method='GHQ')

    hj_mad_cd_MC = HJ_MAD_CoordinateDescent(f, x_true, delta=delta,
                    int_samples=int(100000), t_vec=[t_init, t_min, t_max], max_iters=10, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=momentum,saturate_tol=sat_tol,integration_method="MC")

    hj_mad_cd_NMC = HJ_MAD_CoordinateDescent(f, x_true, delta=1e-2,
                        int_samples=int(10), t_vec=[10, 1e-2, 1000], max_iters=10, tol=tol, alpha=alpha,
                        beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                        momentum=0.5,saturate_tol=sat_tol,integration_method="NMC")
elif f_name == 'AlpineN1': # None are tuned
    hj_mad_cd_GHQ = HJ_MAD_CoordinateDescent(f, x_true, delta=5e-15,
                    int_samples=int(1000), t_vec=[50, 10, 60], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=sat_tol,integration_method='GHQ')

    hj_mad_cd_MC = HJ_MAD_CoordinateDescent(f, x_true, delta=delta,
                    int_samples=int_samples, t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=sat_tol,integration_method="MC")

    hj_mad_cd_NMC = HJ_MAD_CoordinateDescent(f, x_true, delta=delta,
                    int_samples=int(50), t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=sat_tol,integration_method="NMC")

elif f_name == 'Drop_Wave':
    # NOTE: this reassigns the notebook-level tol, which later cells also read.
    tol = 5e-2
    hj_mad_cd_GHQ = HJ_MAD_CoordinateDescent(f, x_true, delta=1e-21,
                    int_samples=int(1000), t_vec=[1e6, 1e4, 1e6], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.4,saturate_tol=sat_tol,integration_method='GHQ')
else:
    # Fallback for functions (or dimensions) without a tuned setup: GHQ only.
    hj_mad_cd_GHQ = HJ_MAD_CoordinateDescent(f, x_true, delta=delta,
                    int_samples=int_samples, t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=False, verbose=True,rescale0=rescale0,
                    momentum=0.0,saturate_tol=sat_tol,integration_method='GHQ')


In [None]:
# # Initialize accumulators for averages
# avg_func_evals = 0
# sum_elapsed_time = 0
# total_iterations = 0  # To store total iterations across trials

# # Run the specified number of trials
# for _ in range(avg_trials):
#     start_time = time.time()  # Record the start time

#     # Execute the HJ_MAD_CD algorithm and retrieve results
#     x_opt_cd_GHQ, coordinate_wise_xk_hist_GHQ, xk_hist_cd_GHQ, xk_error_hist_cd_GHQ, fk_hist_cd_GHQ = hj_mad_cd_GHQ.run(x0, num_cycles=20)

#     elapsed_time = time.time() - start_time  # Calculate elapsed time
#     sum_elapsed_time += elapsed_time  # Accumulate elapsed time

#     total_iterations += len(xk_error_hist_cd_GHQ)  # Add iterations used in this trial
#     avg_func_evals += len(xk_error_hist_cd_GHQ) * int_samples  # Update average function evaluations

#     print(f"Elapsed time: {elapsed_time:.4f} seconds")  # Print elapsed time for the current trial


# # Compute averages after all trials
# avg_func_evals /= avg_trials  # Average function evaluations per trial
# average_iterations = total_iterations / avg_trials  # Average number of iterations per trial

# # Output results
# # print('\n\n avg_func_evals = ', avg_func_evals)
# print(f"Average iterations before convergence/stopping: {average_iterations:.2f}")
# print(f"Average elapsed time: {sum_elapsed_time / avg_trials:.4f} seconds")

In [None]:


# # Initialize accumulators for averages
# avg_func_evals = 0
# sum_elapsed_time = 0
# total_iterations = 0  # To store total iterations across trials

# # Run the specified number of trials
# for _ in range(avg_trials):
#     start_time = time.time()  # Record the start time

#     # Execute the HJ_MAD_CD algorithm and retrieve results
#     x_opt_cd_MC, coordinate_wise_xk_hist_MC, xk_hist_cd_MC, xk_error_hist_cd_MC, fk_hist_cd_MC = hj_mad_cd_MC.run(x0, num_cycles=20)

#     elapsed_time = time.time() - start_time  # Calculate elapsed time
#     sum_elapsed_time += elapsed_time  # Accumulate elapsed time

#     total_iterations += len(xk_error_hist_cd_MC)  # Add iterations used in this trial
#     avg_func_evals += len(xk_error_hist_cd_MC) * int_samples  # Update average function evaluations

#     print(f"Elapsed time: {elapsed_time:.4f} seconds")  # Print elapsed time for the current trial


# # Compute averages after all trials
# avg_func_evals /= avg_trials  # Average function evaluations per trial
# average_iterations = total_iterations / avg_trials  # Average number of iterations per trial

# # Output results
# # print('\n\n avg_func_evals = ', avg_func_evals)
# print(f"Average iterations before convergence/stopping: {average_iterations:.2f}")
# print(f"Average elapsed time: {sum_elapsed_time / avg_trials:.4f} seconds")

In [None]:


# # Initialize accumulators for averages
# avg_func_evals = 0
# sum_elapsed_time = 0
# total_iterations = 0  # To store total iterations across trials

# # Run the specified number of trials
# for _ in range(avg_trials):
#     start_time = time.time()  # Record the start time

#     # Execute the HJ_MAD_CD algorithm and retrieve results
#     x_opt_cd_NMC, coordinate_wise_xk_hist_NMC, xk_hist_cd_NMC, xk_error_hist_cd_NMC, fk_hist_cd_NMC = hj_mad_cd_NMC.run(x0, num_cycles=20)

#     elapsed_time = time.time() - start_time  # Calculate elapsed time
#     sum_elapsed_time += elapsed_time  # Accumulate elapsed time

#     total_iterations += len(xk_error_hist_cd_NMC)  # Add iterations used in this trial
#     avg_func_evals += len(xk_error_hist_cd_NMC) * int_samples  # Update average function evaluations

#     print(f"Elapsed time: {elapsed_time:.4f} seconds")  # Print elapsed time for the current trial


# # Compute averages after all trials
# avg_func_evals /= avg_trials  # Average function evaluations per trial
# average_iterations = total_iterations / avg_trials  # Average number of iterations per trial

# # Output results
# # print('\n\n avg_func_evals = ', avg_func_evals)
# print(f"Average iterations before convergence/stopping: {average_iterations:.2f}")
# print(f"Average elapsed time: {sum_elapsed_time / avg_trials:.4f} seconds")

In [None]:
# # Create an instance of HJ_MAD_CoordinateDescent
# hj_mad_cd = HJ_MAD_CoordinateDescent_parallel(f, x_true, delta=delta,
#                     int_samples=int_samples, t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol, alpha=alpha,
#                     beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False,plot=True, verbose=False,rescale0=rescale0,
#                     momentum=momentum,saturate_tol=sat_tol,integration_method='MC')

# # Initialize accumulators for averages
# avg_func_evals = 0
# sum_elapsed_time = 0
# total_iterations = 0  # To store total iterations across trials

# # Run the specified number of trials
# for _ in range(avg_trials):
#     start_time = time.time()  # Record the start time

#     # Execute the HJ_MAD_CD algorithm and retrieve results
#     x_opt_cd_para, coordinate_wise_xk_hist_para, xk_hist_cd_para, xk_error_hist_cd_para, fk_hist_cd_para = hj_mad_cd.run(x0, num_cycles=20)

#     elapsed_time = time.time() - start_time  # Calculate elapsed time
#     sum_elapsed_time += elapsed_time  # Accumulate elapsed time

#     total_iterations += len(xk_error_hist_cd_para)  # Add iterations used in this trial
#     avg_func_evals += len(xk_error_hist_cd_para) * int_samples  # Update average function evaluations

#     print(f"Elapsed time: {elapsed_time:.4f} seconds")  # Print elapsed time for the current trial


# # Compute averages after all trials
# avg_func_evals /= avg_trials  # Average function evaluations per trial
# average_iterations = total_iterations / avg_trials  # Average number of iterations per trial

# # Output results
# print('\n\n avg_func_evals = ', avg_func_evals)
# print(f"Average iterations before convergence/stopping: {average_iterations:.2f}")
# print(f"Average elapsed time: {sum_elapsed_time / avg_trials:.4f} seconds")

In [None]:
# Baseline: full-dimensional HJ-MAD (not coordinate descent) with NMC integration, averaged
# over avg_trials runs.
# NOTE(review): the assignments below overwrite the notebook-level int_samples, rescale0, delta,
# t_min, t_max and t_init set in the configuration cell. Later cells (e.g. the 2D plot's
# `t_final = t_max`) and any re-run of the solver-construction cell will see these values.
int_samples = int(100000);
rescale0=1
# delta         = 5e-7
# t_min     = 2e1
# t_max     = 1e6
# t_init    = 2e3#2e1
delta         = 5e-7
t_min     = 2e1
t_max     = 1e5
t_init    = 2e3#2e1
# tol is tightened by 10x relative to the coordinate-descent solvers.
HJ_MAD_alg = HJ_MAD(f, x_true, delta=delta,
                    int_samples=int_samples, t_vec=[t_init, t_min, t_max], max_iters=max_iters, tol=tol*1e-1, alpha=alpha,
                    beta=beta, eta_vec = eta_vec, theta=theta, fixed_time=False, verbose=True,rescale0=rescale0,momentum=0.5,
                    integration_method='NMC')
# Initialize accumulators for averages
avg_func_evals = 0
sum_elapsed_time = 0
total_iterations = 0  # To store total iterations across trials

# Run the specified number of trials
for _ in range(avg_trials):
    #x0 = 10*torch.ones(dim, dtype=torch.double)
    start_time = time.time()  # Record the start time

    # Execute the HJ_MAD algorithm and retrieve results
    x_opt_MAD, xk_hist_MAD, tk_hist_MAD, xk_error_hist_MAD, rel_grad_uk_norm_hist_MAD, fk_hist_MAD = HJ_MAD_alg.run(x0)

    elapsed_time = time.time() - start_time  # Calculate elapsed time
    sum_elapsed_time += elapsed_time  # Accumulate elapsed time
    
    total_iterations += len(xk_error_hist_MAD)  # Add iterations used in this trial
    # Estimate: iterations x int_samples. This assumes one batch of int_samples evaluations per
    # iteration -- TODO confirm against HJ_MAD.run.
    avg_func_evals += len(xk_error_hist_MAD) * int_samples  # Update average function evaluations

    print(f"Elapsed time: {elapsed_time:.4f} seconds")  # Print elapsed time for the current trial

# Compute averages after all trials
avg_func_evals /= avg_trials  # Average function evaluations per trial
average_iterations = total_iterations / avg_trials  # Average number of iterations per trial

# Output results
print('\n\n avg_func_evals = ', avg_func_evals)
# NOTE(review): this counts all iterations, including runs that stopped at max_iters without converging.
print(f"Average iterations before convergence: {average_iterations:.2f}")
print(f"Average elapsed time: {sum_elapsed_time / avg_trials:.4f} seconds")

### Generate Convergence Histories and Optimization Path Plots

In [None]:

# Convergence history: distance of the HJ-MAD iterates to the true minimizer.
title_fontsize = 22
fontsize       = 18

# Matplotlib 3.6 renamed the seaborn styles to 'seaborn-v0_8-*', and 3.8 removed the old names,
# so plt.style.use('seaborn-whitegrid') raises on current versions. Use whichever name exists.
plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available
              else 'seaborn-whitegrid')
fig1, ax = plt.subplots()

ax.semilogy(xk_error_hist_MAD, color='purple', linewidth=3,label='HJ-MAD(NMC)');
#ax.semilogy(xk_error_hist_cd_NMC, color='red', linewidth=3,label='HJ-MAD-NMC');
#ax.semilogy(xk_error_hist_cd_GHQ, color='blue', linewidth=3,label='HJ-MAD-CD-GHQ');
#ax.semilogy(xk_error_hist_cd_MC, color='green', linewidth=3,label='HJ-MAD-CD-MC');
# ax.semilogy(xk_error_hist_EGD[0:len(xk_error_hist_GD)], 'm-', linewidth=3)
#ax.semilogy(xk_error_hist_GD[0:len(xk_error_hist_GD)], 'g-', linewidth=3)
ax.set_title(f'Dims={dim},Func={f_name},\n Adaptive Rescale Factor', fontsize=title_fontsize)
ax.set_xlabel("Iterations", fontsize=title_fontsize)
ax.set_ylabel("Errors", fontsize=title_fontsize)
ax.legend(fontsize=fontsize)
# title_str = 'Relative Errors'
# ax.set_title(title_str, fontsize=title_fontsize)
ax.tick_params(labelsize=fontsize, which='both', direction='in')

# save_str = 'griewank_error_hist.png'
# fig1.savefig(save_str, dpi=300 , bbox_inches="tight", pad_inches=0.0)

In [None]:
# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in Matplotlib 3.6 and
# removed in 3.8 (plt.style.use raises OSError); use whichever name is available.
whitegrid_style = ('seaborn-v0_8-whitegrid'
                   if 'seaborn-v0_8-whitegrid' in plt.style.available
                   else 'seaborn-whitegrid')
plt.style.use(whitegrid_style)

fig1 = plt.figure()
ax = plt.axes()

# Objective value per iteration for each method. Each curve gets its own color.
# HJ-MAD previously shared red with HJ-MAD-NMC; it now uses purple, matching the error plot.
# NOTE(review): semilogy cannot show non-positive values; if the chosen f can
# reach values <= 0, those points are dropped — confirm for f_name.
ax.semilogy(fk_hist_MAD, color='purple', linewidth=3, label='HJ-MAD');
ax.semilogy(fk_hist_cd_NMC, color='red', linewidth=3, label='HJ-MAD-NMC');
ax.semilogy(fk_hist_cd_GHQ, color='blue', linewidth=3, label='HJ-MAD-CD-GHQ');
ax.semilogy(fk_hist_cd_MC, color='green', linewidth=3, label='HJ-MAD-CD-MC');

ax.set_xlabel("Iterations", fontsize=title_fontsize)
ax.set_ylabel("fk", fontsize=title_fontsize)
ax.legend(fontsize=fontsize)
title_str = 'Objective Function Values'
ax.set_title(title_str, fontsize=title_fontsize)
ax.tick_params(labelsize=fontsize, which='both', direction='in')

# save_str = 'griewank_func_hist.png'
# fig1.savefig(save_str, dpi=300 , bbox_inches="tight", pad_inches=0.0)

## 2D Plots


In [None]:
if dim == 2:

    # Start point and known global minimizer used for the 2D visualisation.
    if f_name == 'Levy':
        x0     = -15*torch.ones(dim, dtype=torch.double)
        x_true = torch.ones(dim, dtype=torch.double)
    else:
        x0     = 10*torch.ones(dim, dtype=torch.double)
        x_true = torch.zeros(dim, dtype=torch.double)

    surface_plot_resolution = 50
    x = np.linspace(-ax_bry, ax_bry, surface_plot_resolution)
    y = np.linspace(-ax_bry, ax_bry, surface_plot_resolution)

    X, Y = np.meshgrid(x, y)
    n_features = 2

    t_final = t_max

    Z     = np.zeros(X.shape)
    Z_MAD = np.zeros(X.shape)  # NOTE(review): allocated but not filled in this cell

    # Evaluate f point-by-point on the grid. f returns a tensor; .item() turns it
    # into a Python float instead of relying on NumPy's deprecated array->scalar cast.
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            Z[i, j] = f(torch.FloatTensor([X[i, j], Y[i, j]]).view(1, n_features)).item()

    # Reset the style *before* creating the figure. Called after plt.subplots, as
    # before, it only affected later figures, not this contour plot.
    plt.style.use('default')
    fig, ax = plt.subplots(1, 1)
    im = ax.contourf(X, Y, Z, 20, cmap=plt.get_cmap('gray'))

    title_fontsize = 22
    fontsize       = 15

    # Paths of the three coordinate-descent variants, each in a distinct color.
    # GHQ and NMC previously both used 'm-o' and were indistinguishable.
    path_MC  = np.vstack(xk_hist_cd_MC)
    path_GHQ = np.vstack(xk_hist_cd_GHQ)
    path_NMC = np.vstack(xk_hist_cd_NMC)
    ax.plot(path_MC[:, 0],  path_MC[:, 1],  '-o', color='blue', label='HJ-MAD-CD-MC')
    ax.plot(path_GHQ[:, 0], path_GHQ[:, 1], 'm-o', label='HJ-MAD-CD-GHQ')
    ax.plot(path_NMC[:, 0], path_NMC[:, 1], '-o', color='darkorange', label='HJ-MAD-CD-NMC')

    ax.plot(x_true[0].item(), x_true[1].item(), 'rx', markeredgewidth=3, markersize=12, label='global min')
    ax.plot(x0[0].item(), x0[1].item(), 'kx', markeredgewidth=3, markersize=12, label='initial guess')

    ax.legend(fontsize=12, facecolor='white', markerfirst=False, loc='lower right')
    ax.set_title(f'Optimization paths, Func={f_name}', fontsize=fontsize)

    # Both axes share the same plotting window.
    ax.set_xlim(-ax_bry, ax_bry)
    ax.set_ylim(-ax_bry, ax_bry)
    cb = fig.colorbar(im, ax=ax)

    # save_loc = 'optimization_paths.png'
    # plt.savefig(save_loc,bbox_inches='tight')
    plt.show()

## Interactive 3D Plots of the 2D Objective


In [None]:
if dim == 2:
    # Needed for projection='3d' on older Matplotlib versions; harmless on newer ones.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    # The 3D plots use a larger window than the contour plot above.
    ax_bry_3D_plot = 20
    surface_plot_resolution = 50
    x = np.linspace(-ax_bry_3D_plot, ax_bry_3D_plot, surface_plot_resolution)
    y = np.linspace(-ax_bry_3D_plot, ax_bry_3D_plot, surface_plot_resolution)

    X, Y = np.meshgrid(x, y)

    # Re-evaluate f on THIS grid. Previously Z was left over from the contour cell,
    # which was computed on the +/-ax_bry grid. The surface was therefore drawn with
    # heights that did not match these X, Y coordinates.
    Z = np.array([
        [f(torch.FloatTensor([X[i, j], Y[i, j]]).view(1, 2)).item() for j in range(X.shape[1])]
        for i in range(X.shape[0])
    ])

    # Convert PyTorch tensors to NumPy
    # xk_hist_MAD_np = xk_hist_MAD.numpy()
    # coordinate_wise_xk_hist_np = np.vstack(xk_hist_cd)

    # # Ensure z_values are scalars
    # HJ_MAD_f_values = np.array([
    #     f(torch.FloatTensor([[xk_hist_MAD_np[i, 0], xk_hist_MAD_np[i, 1]]])).item()
    #     for i in range(len(xk_hist_MAD_np))
    # ])

    # Height of f along the MC coordinate-descent path. The path is stacked once;
    # previously np.vstack ran inside every comprehension step.
    path_MC = np.vstack(xk_hist_cd_MC)
    HJ_MAD_CD_f_values = np.array([
        f(torch.FloatTensor([[px, py]])).item() for px, py in path_MC[:, :2]
    ])

    # Global minimum and initial guess as (1, 2) tensors, for f and for the
    # [0, k] indexing in the plotting cells. Guarded so re-running is safe.
    if x_true.dim() == 1:
        x_true = x_true.unsqueeze(0)
    global_min_f = f(x_true).item()

    if x0.dim() == 1:
        x0 = x0.unsqueeze(0)
    f_initial = f(x0).item()


In [None]:
if dim == 2:
    import plotly.graph_objects as go

    def _cross_marker(point, height, colour, label):
        """Build a single 'x'-shaped Scatter3d marker.

        point  : (1, 2) tensor whose [0, 0] / [0, 1] entries give the x / y position.
        height : z coordinate (the objective value at that point).
        """
        return go.Scatter3d(
            x=[point[0, 0].item()],
            y=[point[0, 1].item()],
            z=[height],
            mode='markers',
            marker={'size': 8, 'color': colour, 'symbol': 'x'},
            name=label,
        )

    # Objective surface over the grid prepared in the previous cell.
    surface_trace = go.Surface(x=X, y=Y, z=Z, colorscale='Viridis', showscale=True, name='Surface')

    # MC coordinate-descent path, lifted onto the surface by its f-values.
    # (An HJ-MAD path trace built from xk_hist_MAD_np was previously commented out here.)
    mc_path = np.vstack(xk_hist_cd_MC)
    HJ_MAD_CD_trace = go.Scatter3d(
        x=mc_path[:, 0],
        y=mc_path[:, 1],
        z=HJ_MAD_CD_f_values,
        mode='lines+markers',
        marker={'size': 5, 'color': 'blue'},
        line={'color': 'blue', 'width': 3},
        name='HJ-MAD-CD',
    )

    global_min_trace = _cross_marker(x_true, global_min_f, 'black', 'Global min')
    initial_guess_trace = _cross_marker(x0, f_initial, 'green', 'Initial guess')

    layout_options = {
        'title': "Interactive 3D Optimization Path",
        'scene': {'xaxis_title': "X-axis", 'yaxis_title': "Y-axis", 'zaxis_title': "f-axis"},
        'margin': {'l': 0, 'r': 0, 't': 40, 'b': 0},
        # Legend pinned to the top-left corner on a semi-transparent white background.
        'legend': {'x': 0.02, 'y': 0.98, 'bgcolor': 'rgba(255, 255, 255, 0.5)'},
    }

    fig = go.Figure(data=[surface_trace, HJ_MAD_CD_trace, global_min_trace, initial_guess_trace])
    fig.update_layout(**layout_options)
    fig.show(renderer="notebook")

In [None]:
if dim == 2:

    # Static 3D version of the interactive figure above.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # By default mplot3d orders artists by depth and ignores user zorder, so the
    # path/markers can be hidden behind the surface. Disabling computed_zorder
    # makes the zorder values below take effect (Matplotlib >= 3.5; this is a
    # plain attribute and is ignored on older versions).
    ax.computed_zorder = False

    # Plot the surface
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='viridis', edgecolor='none', zorder=1)

    # Plot the HJ-MAD optimization path
    # ax.plot(xk_hist_MAD_np[:, 0], xk_hist_MAD_np[:, 1], HJ_MAD_f_values, '-o', color='red', label="HJ-MAD", zorder=2)

    # Plot the HJ-MAD-CD (MC) optimization path
    path_MC = np.vstack(xk_hist_cd_MC)
    ax.plot(path_MC[:, 0], path_MC[:, 1], HJ_MAD_CD_f_values, '-o', color='blue', label="HJ-MAD-CD", zorder=2)

    # Single points are wrapped in lists because Axes3D.plot expects sequences.
    ax.plot(
        [x_true[0, 0].item()],
        [x_true[0, 1].item()],
        [global_min_f],
        'x', color='black', label="Global min", zorder=3
    )
    ax.plot(
        [x0[0, 0].item()],
        [x0[0, 1].item()],
        [f_initial],
        'x', color='green', label="Initial guess", zorder=3
    )

    # Camera: 50 degrees above the x-y plane, rotated 30 degrees about the z-axis.
    ax.view_init(elev=50, azim=30)

    # Add labels and legend
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.set_zlabel('f-axis')
    ax.legend()

    plt.show()