# MMA Overview

In [30]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import imageio
import cooper
import matplotlib.animation as animation

%matplotlib qt

## Example functions to show convex approximation

In [31]:
class LamModel(nn.Module):
    """Wrap a callable ``lam_fun`` as a module with a single trainable point ``x``.

    Parameters
    ----------
    lam_fun : callable
        Function evaluated on the parameter tensor ``self.x`` by ``forward``.
    x0 : number
        Initial value of the design variable. It is stored as a one-element
        ``nn.Parameter`` so that autograd can differentiate ``lam_fun`` with
        respect to it.
    """

    def __init__(self, lam_fun, x0):
        super().__init__()
        x_init = torch.tensor([x0])
        # nn.Parameter requires grad, which torch only allows for floating-point
        # or complex tensors. An integer start point (e.g. x0 = -1) would
        # otherwise raise, so promote it to the default float dtype. Float
        # inputs keep the dtype torch inferred for them.
        if not (x_init.is_floating_point() or x_init.is_complex()):
            x_init = x_init.to(torch.get_default_dtype())
        self.x = nn.Parameter(x_init)
        self.lam_fun = lam_fun

    def forward(self):
        """Return ``lam_fun`` evaluated at the current value of ``self.x``."""
        return self.lam_fun(self.x)


def quad_fun(x):
    """Example objective f(x) = x^2; works elementwise on numbers, arrays and tensors."""
    return x ** 2


# Lower and upper limits of the design variable x
x_min, x_max = -2, 2

In [32]:


# Expansion points x0 for the animation: sweep across the whole design
# interval. The limits come from x_min / x_max so the sweep cannot drift
# out of sync with the bounds used inside animate().
x0_values = np.linspace(x_min, x_max, 6*30)
# Append the reversed sweep so the animation travels back to its start.
# The last forward sample is dropped so the turning point is not repeated.
x0_values = np.append(x0_values[:-1], np.flip(x0_values))

# NOTE: nothing visible in this notebook reads this list; the animation below
# is driven by x0_values and FuncAnimation collects the frames itself.
frames = []


# Figure and axes that animate() redraws on every frame
fig, ax = plt.subplots()

def animate(x0, ma=1):
    """Draw one frame: the objective and its MMA convex approximation built at ``x0``.

    Redraws the module-level axes ``ax`` with
      * ``quad_fun`` on [x_min, x_max],
      * a marker at the expansion point ``x0``,
      * the MMA approximation, evaluated between the move limits,
      * the three candidate lower/upper move limits as dashed vertical lines.

    Parameters
    ----------
    x0 : float
        Expansion point at which the approximation is constructed.
    ma : float, optional
        Divisor for the asymptote distance; the asymptotes are placed at
        ``x0 -/+ 0.5/ma*(x_max-x_min)``.
    """
    quad_model = LamModel(quad_fun, x0)
    output = torch.sum(quad_model.forward())

    # Gradient of the objective with respect to the model parameters,
    # flattened into a single vector
    dfdx = torch.autograd.grad(output, quad_model.parameters(), create_graph=True)
    dfdx = torch.cat([g.contiguous().view(-1) for g in dfdx])

    # Asymptotes placed symmetrically around the current point
    x_current = quad_model.x.data
    lower = x_current - 0.5/ma*(x_max-x_min)
    upper = x_current + 0.5/ma*(x_max-x_min)

    # Coefficients of the approximation  p/(upper-x) + q/(x-lower) + r;
    # r is chosen so that the approximation equals the objective at x_current
    p = (upper-x_current)**2 * (1.001 * torch.relu(dfdx) + 0.001 * torch.relu(-dfdx) + 1e-5/(x_max-x_min))
    q = (x_current-lower)**2 * (0.001 * torch.relu(dfdx) + 1.001 * torch.relu(-dfdx) + 1e-5/(x_max-x_min))
    r = output - torch.sum( p/(upper-x_current) + q/(x_current-lower) )

    # Candidate move limits; the active ones are the tightest of each set
    alphas = [float(a) for a in (x_min,
                                 lower + 0.1*(x_current-lower),
                                 x_current - 0.5*(x_max-x_min))]
    betas = [float(b) for b in (x_max,
                                upper - 0.1*(upper-x_current),
                                x_current + 0.5*(x_max-x_min))]
    alpha = max(alphas)
    beta = min(betas)

    # One colour per pair of candidate limits
    colors = ['magenta', 'cyan', 'orange']

    # Evaluate the approximation only between the active move limits
    x_t = torch.linspace(alpha, beta, 100)
    f_tilde = p/(upper-x_t) + q/(x_t-lower) + r

    x = np.linspace(x_min, x_max, 100)
    yl = [-0.5, 5]
    x_point = quad_model.x.detach().numpy()

    ax.clear()
    ax.plot(x, quad_fun(x))
    ax.plot(x_point, quad_fun(x_point), 'o')
    ax.plot(x_t, f_tilde.detach().numpy())
    # Candidate limits as dashed vertical lines. Distinct loop names are used
    # so the active alpha / beta computed above are not overwritten.
    for color, alpha_i, beta_i in zip(colors, alphas, betas):
        ax.plot([alpha_i, alpha_i], yl, '--', color=color)
        ax.plot([beta_i, beta_i], yl, '--', color=color)
    ax.set_ylim(yl)
    ax.set_xlim([x_min*1.1, x_max*1.1])
    ax.set_xlabel('$x$')
    ax.set_ylabel('$f(x)$')
    ax.set_title(rf'$x^{{(0)}}$ = {x0:.2f}')
    # No explicit canvas draw here: FuncAnimation renders the figure after each
    # call. The previous `fig.canvas.draw` lacked call parentheses and therefore
    # never did anything.

ma = 2


def anim_fun(x0):
    """Frame callback for FuncAnimation: draw the approximation at x0 with the module-level ``ma``."""
    return animate(x0, ma=ma)


# The interval is chosen so that all frames together span 6000 ms
anim = animation.FuncAnimation(fig, anim_fun, frames=x0_values, interval=6000/len(x0_values))
anim.save(f'quad_fun_ma{ma}.gif', writer='imagemagick')

# Close the figure that was just saved. A bare plt.close() closes whichever
# figure is current, which need not be this one.
plt.close(fig)

MovieWriter imagemagick unavailable; using Pillow instead.


## Show how the moving asymptotes work for the example

In [34]:
# Starting point x^(0) of the MMA iteration
x0 = -1.9
# History of iterates; mma_update reads the latest entry and appends new ones
xk = [x0]
print(xk)
# Histories of the lower / upper asymptotes, filled in by mma_update
lower_hist = []
upper_hist = []

def mma_update(frame_ct):
    """Draw animation frame ``frame_ct`` of the 1-D MMA demonstration.

    ``frame_ct`` is the GIF frame index, not the MMA iteration number: every
    MMA iteration ``k = frame_ct // 2`` is shown on two consecutive frames.
    The even frame draws the objective and the convex approximation built at
    x^(k). The odd frame draws the same picture, marks the minimiser of the
    approximation, and records the iteration in the module-level histories.

    Module-level state used: ``xk`` (iterates), ``lower_hist`` / ``upper_hist``
    (asymptotes), ``ax``, ``x_min``, ``x_max``, ``quad_fun`` and ``LamModel``.
    Because the histories are appended to on odd frames, frames are expected
    to arrive in increasing order.

    Parameters
    ----------
    frame_ct : int
        Index of the animation frame.
    """
    # Current iterate x^(k)
    x0 = xk[-1]
    k = frame_ct//2
    records_step = frame_ct%2==1
    quad_model = LamModel(quad_fun, x0)
    output = torch.sum(quad_model.forward())

    # Gradient of the objective with respect to the model parameters via AD.
    # An alternative is to call backward() and read .grad from each parameter.
    dfdx = torch.autograd.grad(output, quad_model.parameters(), create_graph=True)

    # Reshape the gradient into a vector
    dfdx = torch.cat([g.contiguous().view(-1) for g in dfdx])

    x_current = quad_model.x.data
    if k<2:
        # First two iterations: fixed asymptote distance, no adaptation yet
        lower = x_current - 0.5*(x_max-x_min)
        upper = x_current + 0.5*(x_max-x_min)
        gamma = np.nan
    else:
        # Shrink the asymptote distance if the last two steps changed
        # direction (oscillation), enlarge it if they went the same way
        deltax_1 = xk[-1] - xk[-2]
        deltax_2 = xk[-2] - xk[-3]
        if deltax_1*deltax_2 < 0:
            gamma = 0.7
        elif deltax_1*deltax_2 > 0:
            gamma = 1.2
        else:
            gamma = 1
        # Scale the PREVIOUS distance between iterate and asymptote:
        #   L^(k) = x^(k) - gamma * (x^(k-1) - L^(k-1))
        #   U^(k) = x^(k) + gamma * (U^(k-1) - x^(k-1))
        # xk[-2] is x^(k-1) and the history tails are L^(k-1) / U^(k-1).
        lower = x_current - gamma*(xk[-2] - lower_hist[-1])
        upper = x_current + gamma*(upper_hist[-1] - xk[-2])
    if records_step:
        lower_hist.append(lower)
        upper_hist.append(upper)

    # Coefficients of the approximation  p/(upper-x) + q/(x-lower) + r;
    # r is chosen so that the approximation equals the objective at x_current
    p = (upper-x_current)**2 * (1.001 * torch.relu(dfdx) + 0.001 * torch.relu(-dfdx) + 1e-5/(x_max-x_min))
    q = (x_current-lower)**2 * (0.001 * torch.relu(dfdx) + 1.001 * torch.relu(-dfdx) + 1e-5/(x_max-x_min))
    r = output - torch.sum( p/(upper-x_current) + q/(x_current-lower) )

    # Candidate move limits; the active ones are the tightest of each set
    alphas = [float(a) for a in (x_min,
                                 lower + 0.1*(x_current-lower),
                                 x_current - 0.5*(x_max-x_min))]
    betas = [float(b) for b in (x_max,
                                upper - 0.1*(upper-x_current),
                                x_current + 0.5*(x_max-x_min))]
    alpha = max(alphas)
    beta = min(betas)

    # One colour per pair of candidate limits
    colors = ['magenta', 'cyan', 'orange']

    # Minimise the approximation by sampling it between the active move limits
    x_t = torch.linspace(alpha, beta, 1000)
    f_tilde = p/(upper-x_t) + q/(x_t-lower) + r

    min_ind = torch.argmin(f_tilde)
    min_x = x_t[min_ind]
    min_f = f_tilde[min_ind].detach()
    if records_step:
        # The minimiser of the approximation becomes the next iterate
        xk.append(min_x.item())

    # Plot the function, the current point and the approximation
    x = np.linspace(x_min, x_max, 1000)
    yl = [-0.5, 4.2]
    x_point = quad_model.x.detach().numpy()

    ax.clear()
    ax.plot(x, quad_fun(x))
    ax.plot(x_point, quad_fun(x_point), 'o')
    ax.plot(x_t, f_tilde.detach().numpy())
    # Show only the second pair of candidate limits (the asymptote-based ones)
    # as dashed vertical lines. Separate names keep the active alpha / beta
    # computed above intact.
    shown = 1
    alpha_shown = alphas[shown]
    beta_shown = betas[shown]
    ax.plot([alpha_shown, alpha_shown], yl, '--', color = colors[shown])
    ax.plot([beta_shown, beta_shown], yl, '--', color = colors[shown])
    ax.set_ylim(yl)
    ax.set_xlim([x_min*1.1, x_max*1.1])
    ax.set_xlabel('$x$')
    ax.set_ylabel('$f(x)$')
    ax.set_title(rf'$\gamma^{{({k})}}$ = {gamma:.1f}, $x^{{({k})}}$ = {x0:.2f}, Error = {np.abs(min_f):.1e}')

    if records_step:
        # Mark the minimum of the MMA approximation with an empty square
        ax.plot(min_x, min_f, 's', color='red', fillstyle='none')

# Display time of each GIF frame in milliseconds
time_per_frame = 2000
# mma_update shows each MMA iteration on two frames (k = frame // 2), so
# 23 frames cover iterations 0..11; the last iteration gets only its first frame
frames = range(23)
fig, ax = plt.subplots()
anim = animation.FuncAnimation(fig, mma_update, frames=frames, interval=time_per_frame)
anim.save('mma_steps.gif', writer='imagemagick')
# Close the figure that was just saved. A bare plt.close() closes whichever
# figure is current, which need not be this one.
plt.close(fig)

MovieWriter imagemagick unavailable; using Pillow instead.


[-1.9]


Traceback (most recent call last):
  File "c:\Users\grant\anaconda3\envs\torch_pde\lib\site-packages\matplotlib\cbook\__init__.py", line 287, in process
    func(*args, **kwargs)
  File "c:\Users\grant\anaconda3\envs\torch_pde\lib\site-packages\matplotlib\animation.py", line 911, in _start
    self.event_source.add_callback(self._step)
AttributeError: 'NoneType' object has no attribute 'add_callback'
