### Minimax Games

Consider a two-player game $\min_{\mathbf{x}} \max_{\mathbf{y}} f(\mathbf{x}, \mathbf{y})$ where $f: \mathbb{R}^m \times \mathbb{R}^n \rightarrow \mathbb{R}$ is twice continuously differentiable. An interesting example is generative adversarial networks (GANs):

\\[ \min_{G} \max_{D} \, \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log(D(x)) \right] + \mathbb{E}_{z \sim p_{\text{latent}}} \left[ \log(1 - D(G(z))) \right]. \\]

Here $\mathbf{x}$ represents the parameters of the generator $G$ which aims to transform latent vectors $z \sim p_{\text{latent}}$ into samples $G(z)$ that mimic those from a given data distribution $p_{\text{data}}$, and $\mathbf{y}$ represents the parameters of the discriminator $D$ which minimizes a [log loss](http://wiki.fast.ai/index.php/Log_Loss) to differentiate between $G(z)$ and $x \sim p_{\text{data}}$. As in the case of GANs, the players of minimax games are often parametrized by neural networks, and the loss functions $f(\mathbf{x}, \mathbf{y})$ are not necessarily convex in $\mathbf{x}$ or concave in $\mathbf{y}$. Hence, Nash equilibria do not necessarily exist, and even if they do, finding a global Nash equilibrium is hopelessly impractical. We therefore restrict to gradient-based methods that hopefully converge to local solutions.
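
As a quick sanity check on the objective, the sketch below estimates the GAN value function by Monte Carlo in a toy one-dimensional setting. The Gaussian data distribution, the affine generator `G`, and the two discriminators are illustrative assumptions, not part of the experiments in this post. A discriminator that outputs $1/2$ everywhere yields exactly $-\log 4$, and when the generator already matches the data, any other discriminator should do worse in expectation.

```python
import numpy as np

def gan_value(D, G, data_samples, latent_samples):
    """Monte Carlo estimate of E[log D(x)] + E[log(1 - D(G(z)))]."""
    real_term = np.mean(np.log(D(data_samples)))
    fake_term = np.mean(np.log(1.0 - D(G(latent_samples))))
    return real_term + fake_term

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.0, size=10_000)  # assumed p_data = N(2, 1)
z = rng.normal(size=10_000)                       # assumed p_latent = N(0, 1)

G = lambda z: 1.0 * z + 2.0                       # hypothetical generator, matches p_data
D_blind = lambda s: np.full_like(s, 0.5)          # discriminator that cannot tell the two apart
D_logistic = lambda s: 1.0 / (1.0 + np.exp(-(s - 2.0)))  # hypothetical logistic discriminator

value_blind = gan_value(D_blind, G, x, z)
value_logistic = gan_value(D_logistic, G, x, z)
print(value_blind, -np.log(4.0))
print(value_logistic)
```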

For convenience, let $\mathbf{w} = (\mathbf{x}, \mathbf{y}) \in \mathbb{R}^{m + n}$ denote the vector of combined parameters and $\xi(\mathbf{w}) = \xi(\mathbf{x}, \mathbf{y}) =(\nabla_{\mathbf{x}} f(\mathbf{w}), -\nabla_{\mathbf{y}}f(\mathbf{w})) \in \mathbb{R}^{m + n}$ denote the signed vector of partial derivatives, often known as the simultaneous gradient. The Hessian of the game is an $(m + n) \times (m + n)$ matrix of second-order derivatives, which is not necessarily symmetric:

\\[ \mathbf{H}(\mathbf{w}) := \nabla_{\mathbf{w}} \cdot \xi(\mathbf{w})^{\mathsf{T}} = \begin{bmatrix}\nabla^2_{\mathbf{x}\mathbf{x}} f(\mathbf{w}) & \nabla^2_{\mathbf{x}\mathbf{y}} f(\mathbf{w}) \\  -\nabla^2_{\mathbf{y}\mathbf{x}} f(\mathbf{w}) & -\nabla^2_{\mathbf{y}\mathbf{y}} f(\mathbf{w}) \end{bmatrix}. \\]
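
As a concrete illustration, the sketch below approximates the game Hessian by central differences of the simultaneous gradient, using the function $f(x, y) = \log(1 + e^x) + 3xy - \log(1 + e^y)$ from Figure 1. The helper names `xi` and `game_hessian` and the evaluation point $(0, 0)$ are choices made for this example only.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def xi(w):
    """Simultaneous gradient (df/dx, -df/dy) of f = softplus(x) + 3xy - softplus(y)."""
    x, y = w
    return np.array([sigmoid(x) + 3.0 * y, -(-sigmoid(y) + 3.0 * x)])

def game_hessian(xi_fn, w, eps=1e-5):
    """Central-difference Jacobian of xi: entry (i, j) approximates d xi_i / d w_j."""
    w = np.asarray(w, dtype=float)
    H = np.zeros((w.size, w.size))
    for j in range(w.size):
        step = np.zeros_like(w)
        step[j] = eps
        H[:, j] = (xi_fn(w + step) - xi_fn(w - step)) / (2.0 * eps)
    return H

H = game_hessian(xi, [0.0, 0.0])
print(np.round(H, 4))          # off-diagonal entries have opposite signs
print(np.allclose(H, H.T))     # the game Hessian is not symmetric here
```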

**Definition 1:** A point $\mathbf{w}^\star = (\mathbf{x}^\star, \mathbf{y}^\star)$ is a **local minimax** (or local Nash equilibrium) if there is a neighborhood $U$ of $(\mathbf{x}^\star, \mathbf{y}^\star)$ such that $f(\mathbf{x}^\star, \mathbf{y}) \leq f(\mathbf{x}^\star, \mathbf{y}^\star) \leq f(\mathbf{x}, \mathbf{y}^\star)$ for all $(\mathbf{x}, \mathbf{y}) \in U$. Such a point necessarily satisfies $\xi(\mathbf{w}^\star) = 0$ and $\nabla^2_{\mathbf{x}\mathbf{x}} f(\mathbf{w}^\star) \succeq 0$ and $\nabla^2_{\mathbf{y}\mathbf{y}} f(\mathbf{w}^\star) \preceq 0$; conversely, these conditions are sufficient when both inequalities are strict.
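
The first- and second-order conditions in Definition 1 are easy to check numerically once the gradient and the two diagonal Hessian blocks are known. The following is a minimal sketch; the function name `satisfies_minimax_conditions` and the tolerance are assumptions made for illustration.

```python
import numpy as np

def satisfies_minimax_conditions(grad, fxx, fyy, tol=1e-10):
    """Check xi = 0, f_xx >= 0 and f_yy <= 0 (in the semidefinite sense)."""
    stationary = np.linalg.norm(grad) <= tol
    x_ok = np.all(np.linalg.eigvalsh(np.atleast_2d(fxx)) >= -tol)
    y_ok = np.all(np.linalg.eigvalsh(np.atleast_2d(fyy)) <= tol)
    return bool(stationary and x_ok and y_ok)

# f(x, y) = x^2 - y^2 at the origin: gradient (0, 0), f_xx = 2, f_yy = -2.
saddle = satisfies_minimax_conditions(np.zeros(2), 2.0, -2.0)
# f(x, y) = 3x^2 + 4xy + y^2 at the origin: gradient (0, 0), f_xx = 6, f_yy = 2.
not_saddle = satisfies_minimax_conditions(np.zeros(2), 6.0, 2.0)
print(saddle, not_saddle)
```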

**Definition 2:** A point $\mathbf{w}$ in a discrete-time dynamical system with update rule $\mathbf{w}_{t + 1} = \omega(\mathbf{w}_{t})$ is called a **fixed point** if $\omega(\mathbf{w}) = \mathbf{w}$. A fixed point $\mathbf{w}$ is **stable** if the spectral radius $\rho(\mathbf{J}(\mathbf{w}))$ is at most 1, where $\mathbf{J}(\mathbf{w})$ is the Jacobian of $\omega$ computed at $\mathbf{w}$.

The reason we're interested in the spectral analysis of the Jacobian at fixed points is the following well-known fact: if a fixed point $\mathbf{w}$ is stable and hyperbolic (i.e. $\mathbf{J}(\mathbf{w})$ has no eigenvalues with absolute value 1), there is a small neighborhood around $\mathbf{w}$ such that all initializations in that neighborhood result in convergence to $\mathbf{w}$.

### Gradient Descent Ascent (GDA)
A straightforward optimization routine to solve $\min_{\mathbf{x}} \max_{\mathbf{y}} f(\mathbf{x}, \mathbf{y})$ is <em>gradient descent-ascent</em> (GDA), where both players take a gradient update simultaneously, $\mathbf{w}_{t + 1} = \mathbf{w}_t - \eta \xi(\mathbf{w}_t)$, i.e.:

\\[ \begin{bmatrix}\mathbf{x}_{t + 1} \\ \mathbf{y}_{t + 1} \end{bmatrix} = \begin{bmatrix}\mathbf{x}_{t} \\ \mathbf{y}_{t} \end{bmatrix} - \eta \begin{bmatrix}\nabla_{\mathbf{x}} f( \mathbf{x}_t, \mathbf{y}_t)  \\ -\nabla_{\mathbf{y}} f( \mathbf{x}_t, \mathbf{y}_t) \end{bmatrix}. \\]

**Theorem 1** ([Mescheder et al., 2017](http://papers.nips.cc/paper/6779-the-numerics-of-gans.pdf), [Daskalakis and Panageas, 2018](https://papers.nips.cc/paper/8136-the-limit-points-of-optimistic-gradient-descent-in-min-max-optimization.pdf)): If the Hessian computed at a local minimax has no purely imaginary eigenvalue, then the local minimax is a stable fixed point of GDA for a small enough learning rate.

**Proof of Theorem 1**: 
Before we prove Theorem 1, it's worth noting that stability does not guarantee local minimaxity. For example, $(0, 0)$ is not a local minimax of $f(x, y) = 3x^2 + 4xy + y^2$ because $\nabla^2_{\mathbf{y}\mathbf{y}} f(0, 0) = 2 > 0$, though it is a stable fixed point of GDA for all $0 < \eta < 1$ (the Hessian has the double eigenvalue $2$, so the Jacobian of GDA has the double eigenvalue $1 - 2\eta$).
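
The example can be reproduced in a few lines. The sketch below runs simultaneous GDA on $f(x, y) = 3x^2 + 4xy + y^2$; the learning rate, starting point, and number of steps are arbitrary choices for illustration.

```python
import numpy as np

H = np.array([[6.0, 4.0], [-4.0, -2.0]])  # game Hessian of f = 3x^2 + 4xy + y^2

def xi(w):
    # f is quadratic, so the simultaneous gradient is linear: xi(w) = H w.
    return H @ w

eta = 0.05
jacobian = np.eye(2) - eta * H
eigenvalues = np.linalg.eigvals(jacobian)
print(eigenvalues)  # both close to 1 - 2 * eta

w = np.array([6.0, 6.0])
for _ in range(500):
    w = w - eta * xi(w)  # simultaneous update of both players
print(w)  # converges to the origin, which is not a local minimax
```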

More generally, the Jacobian of GDA at a local minimax $\mathbf{w}^\star$ is $\mathbf{J}(\mathbf{w}^\star) = \mathbf{I} - \eta \mathbf{H}(\mathbf{w}^\star)$, which has eigenvalues $1 - \eta \lambda(\mathbf{H})$ where $\lambda(\mathbf{H})$ are eigenvalues of the Hessian evaluated at $\mathbf{w}^\star$. Since $\mathbf{w}^\star$ is a local minimax, $\nabla^2_{\mathbf{x}\mathbf{x}} f(\mathbf{w}^\star)$ and $-\nabla^2_{\mathbf{y}\mathbf{y}} f(\mathbf{w}^\star)$ are positive semidefinite. These are, up to a factor of 2, the diagonal blocks of the block-diagonal matrix $\mathbf{H} + \mathbf{H}^\mathsf{T}$, and thus by the Ky Fan inequality, $\text{Re}(\lambda(\mathbf{H})) \geq \frac{1}{2} \lambda_{\min}(\mathbf{H} + \mathbf{H}^\mathsf{T}) \geq 0$. Under the assumption of Theorem 1, no eigenvalue has zero real part, so in fact $\text{Re}(\lambda(\mathbf{H})) > 0$. By choosing $\eta < 2\min_{\lambda(\mathbf{H})} \left\{\text{Re}(\lambda(\mathbf{H})) / |\lambda(\mathbf{H})|^2  \right\}$, we have $|1 - \eta \lambda(\mathbf{H})|^2 = 1 - \eta \left(2\text{Re}(\lambda(\mathbf{H})) - \eta |\lambda(\mathbf{H})|^2\right) < 1$. In other words, any such local minimax $\mathbf{w}^\star$ is a stable fixed point of GDA if the learning rate $\eta$ is small enough. However, if the Hessian has an eigenvalue with a small real part but a large imaginary part ($\text{Re}(\lambda(\mathbf{H})) / |\lambda(\mathbf{H})|^2$ is small), the learning rate has to be very small, which implies extremely slow convergence.
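
To get a feel for the step-size condition, the sketch below computes the bound $2\min_\lambda \text{Re}(\lambda)/|\lambda|^2$ for a small hypothetical game whose Hessian has eigenvalues $0.1 \pm i$, then compares the spectral radius of the GDA Jacobian below and above the bound. The function names and the example game are assumptions made for illustration.

```python
import numpy as np

def gda_step_size_bound(H):
    """Return 2 * min over eigenvalues of Re(lambda) / |lambda|^2."""
    eigvals = np.linalg.eigvals(H)
    return 2.0 * np.min(eigvals.real / np.abs(eigvals) ** 2)

def gda_spectral_radius(H, eta):
    """Spectral radius of the GDA Jacobian I - eta * H."""
    return np.max(np.abs(np.linalg.eigvals(np.eye(H.shape[0]) - eta * H)))

# Hypothetical game f = 0.05 x^2 + x y - 0.05 y^2: small real part, large imaginary part.
H = np.array([[0.1, 1.0], [-1.0, 0.1]])
bound = gda_step_size_bound(H)
rho_below = gda_spectral_radius(H, 0.5 * bound)
rho_above = gda_spectral_radius(H, 2.0 * bound)
print(bound)                 # small, because Re(lambda) / |lambda|^2 is small
print(rho_below, rho_above)  # below 1 inside the bound, above 1 outside it
```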

If the Hessian has purely imaginary eigenvalues, Theorem 1 does not guarantee convergence of GDA. In fact, recent works ([Mertikopoulos et al., 2018](https://arxiv.org/abs/1709.02738), [Balduzzi et al., 2018](https://arxiv.org/abs/1802.05642)) show that GDA exhibits strong rotation around fixed points and sometimes diverges. In the simple bilinear setting where $f(\mathbf{x}, \mathbf{y}) = \mathbf{x}^{\mathsf{T}} \mathbf{A} \mathbf{y}$ for some matrix $\mathbf{A} \in \mathbb{R}^{m \times n}$, for example, the simultaneous gradient $\xi(\mathbf{x}, \mathbf{y}) = (\mathbf{A} \mathbf{y}, - \mathbf{A}^{\mathsf{T}} \mathbf{x})$ implies

\\[ \mathbf{w}_{t + 1} = \begin{bmatrix}\mathbf{I} & -\eta\mathbf{A} \\ \eta\mathbf{A}^{\mathsf{T}} & \mathbf{I}  \end{bmatrix}\mathbf{w}_{t}, \quad\qquad \|\mathbf{w}_{t + 1}\|^2 = \|\mathbf{w}_{t}\|^2 + \eta^2 \left( \|\mathbf{A} \mathbf{y}_t\|^2 + \|\mathbf{A}^{\mathsf{T}} \mathbf{x}_t\|^2 \right), \\]

because the cross terms $\pm 2\eta \, \mathbf{x}_t^{\mathsf{T}} \mathbf{A} \mathbf{y}_t$ cancel. In the scalar case $m = n = 1$ with $\mathbf{A} = a$, the update matrix equals $\sqrt{1 + \eta^2 a^2} \, \mathbf{R}$ where $\mathbf{R} \in \text{SO}(2)$ is a rotation by the angle $\arctan(\eta a)$. As a result, GDA updates cycle around $(\mathbf{0}, \mathbf{0})$ and, since the norm never decreases, can only move away from it; in the scalar case with $a \neq 0$ they diverge for every learning rate $\eta > 0$. The rotation phenomenon is not limited to the bilinear setting (see Figure 1); it is commonly observed that the generator in GANs often cycles through a subset of modes and fails to capture the diversity of the data distribution (see Figure 2), a problem often known as mode collapse.
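
The outward spiral is easy to observe numerically. The sketch below runs simultaneous GDA on the scalar bilinear game $f(x, y) = xy$ (that is, $\mathbf{A} = 1$) and tracks the distance from the origin, which grows by the factor $\sqrt{1 + \eta^2}$ at every step. The learning rate and starting point are arbitrary choices for illustration.

```python
import numpy as np

eta = 0.1
w = np.array([1.0, 1.0])         # (x_0, y_0)
update = np.array([[1.0, -eta],  # x <- x - eta * y
                   [eta, 1.0]])  # y <- y + eta * x

norms = [np.linalg.norm(w)]
for _ in range(100):
    w = update @ w
    norms.append(np.linalg.norm(w))

ratios = np.array(norms[1:]) / np.array(norms[:-1])
print(ratios[:3], np.sqrt(1.0 + eta ** 2))  # constant per-step growth factor
print(norms[0], norms[-1])                  # the iterates drift away from (0, 0)
```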

<figure>
  <img src="images/mechanics-of-differentiable-games/softplus.png" style='margin: 10px auto' alt="my alt text"/>
  <figcaption>Figure 1: The paths taken by various gradient-based methods next to a loss surface and its contour for solving $\min_x \max_y f(x, y)$ where $f(x, y) = \log(1 + e^x) + 3xy - \log(1 + e^y)$.</figcaption>
</figure>

<figure>
  <img src="images/mechanics-of-differentiable-games/gda.gif" style='margin: 10px auto' alt="my alt text"/>
  <figcaption>Figure 2: Training a GAN on a mixture of 16 Gaussians with gradient descent ascent (GDA). Left: Kernel density plot of samples generated by the generator. Middle: Scatter plots of generated samples in orange and true samples in blue together with contours of the discriminator. Right: Training loss values of the generator and the discriminator. </figcaption>
</figure>

### Consensus Optimization (CO)
 
[Mescheder et al. (2017)](http://papers.nips.cc/paper/6779-the-numerics-of-gans.pdf) observed that when the rotation phenomenon happens, simultaneous gradients $\xi(\mathbf{x}, \mathbf{y})$ decrease slowly in norm (consider the bilinear setting where $\mathbf{x}, \mathbf{y} \in \mathbb{R}$ and $\mathbf{A} = 1$). Since $\xi(\mathbf{x}, \mathbf{y})$ at fixed points are 0, one way to mitigate rotation is to directly penalize $\|\xi(\mathbf{x}, \mathbf{y})\|^2$, i.e. solving $\min_{\mathbf{x}} \ell_1(\mathbf{x}, \mathbf{y})$ and $\min_{\mathbf{y}} \ell_2(\mathbf{x}, \mathbf{y})$ simultaneously where

\\[ \ell_1(\mathbf{x}, \mathbf{y}) = f(\mathbf{x}, \mathbf{y}) + \frac{1}{2}\gamma \|\xi(\mathbf{x}, \mathbf{y})\|^2, \qquad \ell_2(\mathbf{x}, \mathbf{y}) = -f(\mathbf{x}, \mathbf{y}) + \frac{1}{2}\gamma \|\xi(\mathbf{x}, \mathbf{y})\|^2 \\]

and $\gamma > 0$ is a hyperparameter. The new optimization method is called consensus optimization (CO), whose gradient update has the form $\mathbf{w}_{t + 1} = \mathbf{w}_t - \eta \xi(\mathbf{w}_t) - \eta \gamma \mathbf{H}^{\mathsf{T}}(\mathbf{w}_t) \xi(\mathbf{w}_t)$, i.e.

\\[ \begin{bmatrix}\mathbf{x}_{t + 1} \\ \mathbf{y}_{t + 1} \end{bmatrix} = \begin{bmatrix}\mathbf{x}_{t} \\ \mathbf{y}_{t} \end{bmatrix} - \eta \begin{bmatrix}\nabla_{\mathbf{x}} f( \mathbf{x}_t, \mathbf{y}_t)  \\ -\nabla_{\mathbf{y}} f( \mathbf{x}_t, \mathbf{y}_t) \end{bmatrix} - \eta \gamma \begin{bmatrix}\nabla^2_{\mathbf{x}\mathbf{x}} f(\mathbf{x}_t, \mathbf{y}_t) & \nabla^2_{\mathbf{x}\mathbf{y}} f(\mathbf{x}_t, \mathbf{y}_t) \\  -\nabla^2_{\mathbf{y}\mathbf{x}} f(\mathbf{x}_t, \mathbf{y}_t) & -\nabla^2_{\mathbf{y}\mathbf{y}} f(\mathbf{x}_t, \mathbf{y}_t) \end{bmatrix}^{\mathsf{T}} \begin{bmatrix}\nabla_{\mathbf{x}} f( \mathbf{x}_t, \mathbf{y}_t)  \\ -\nabla_{\mathbf{y}} f( \mathbf{x}_t, \mathbf{y}_t) \end{bmatrix}. \\]

Here, we can view CO as GDA on the new loss functions where the simultaneous gradient is $\xi_\gamma(\mathbf{w}) = \xi(\mathbf{w}) + \gamma \mathbf{H}^{\mathsf{T}}(\mathbf{w}) \xi(\mathbf{w}) = (\mathbf{I} + \gamma \mathbf{H}^{\mathsf{T}} (\mathbf{w}))\xi(\mathbf{w})$ and the Hessian is $\mathbf{H}_\gamma(\mathbf{w}) = \nabla_{\mathbf{w}} \cdot \xi_\gamma(\mathbf{w})^{\mathsf{T}} = \mathbf{H}(\mathbf{w}) + \gamma \mathbf{H}^{\mathsf{T}}(\mathbf{w}) \mathbf{H}(\mathbf{w}) + \gamma (\nabla_{\mathbf{w}} \cdot \mathbf{H}(\mathbf{w})) \xi(\mathbf{w})$. Since the objective functions have changed, a natural question to ask is whether the local minimaxes of the original game $\min_{\mathbf{x}} \max_{\mathbf{y}} f(\mathbf{x}, \mathbf{y})$ are still retained. The answer is yes, if $\gamma$ is small enough, for the following two reasons:
* $\xi_\gamma(\mathbf{w}^\star) = 0$ implies $\xi(\mathbf{w}^\star) = 0$, if we pick $\gamma$ such that $-\gamma^{-1}$ is not an eigenvalue of $\mathbf{H}(\mathbf{w}^\star)$, i.e. $\mathbf{I} + \gamma \mathbf{H}^{\mathsf{T}} (\mathbf{w}^\star)$ is invertible.
* At any local minimax $\mathbf{w}^\star$ of the original game, $\xi(\mathbf{w}^\star) = 0$ implies that $\mathbf{H}_\gamma(\mathbf{w}^\star) = \mathbf{H}(\mathbf{w}^\star) + \gamma \mathbf{H}^{\mathsf{T}}(\mathbf{w}^\star) \mathbf{H}(\mathbf{w}^\star)$, whose diagonal blocks are $\nabla^2_{\mathbf{x}\mathbf{x}} \ell_1 (\mathbf{w}^\star)$ and $\nabla^2_{\mathbf{y}\mathbf{y}} \ell_2 (\mathbf{w}^\star)$. Since $\nabla^2_{\mathbf{x}\mathbf{x}} f(\mathbf{w}^\star) \succeq 0$ and $\nabla^2_{\mathbf{y}\mathbf{y}} f(\mathbf{w}^\star) \preceq 0$ and $\mathbf{H}^{\mathsf{T}}(\mathbf{w}^\star) \mathbf{H}(\mathbf{w}^\star)$ is positive semidefinite, it is clear that $\nabla^2_{\mathbf{x}\mathbf{x}} \ell_1 (\mathbf{w}^\star) \succeq 0$ and $\nabla^2_{\mathbf{y}\mathbf{y}} \ell_2 (\mathbf{w}^\star) \succeq 0$ for all $\gamma \geq 0$ (recall that the second player now minimizes $\ell_2$, which contains $-f$). Note that the term $\mathbf{H}^{\mathsf{T}}(\mathbf{w}^\star) \mathbf{H}(\mathbf{w}^\star)$ only makes these blocks "more positive semidefinite" than $\nabla^2_{\mathbf{x}\mathbf{x}} f(\mathbf{w}^\star)$ and $-\nabla^2_{\mathbf{y}\mathbf{y}} f(\mathbf{w}^\star)$.

By an argument similar to the proof of Theorem 1, we can easily show that any eigenvalue $\lambda(\mathbf{H}_\gamma)$ of $\mathbf{H}_\gamma(\mathbf{w}^\star)$ has nonnegative real part: $\text{Re}(\lambda(\mathbf{H}_\gamma)) \geq \frac{1}{2} \lambda_{\min} (\mathbf{H}_\gamma + \mathbf{H}_\gamma^{\mathsf{T}}) \geq 0$. [Mescheder et al. (2017)](http://papers.nips.cc/paper/6779-the-numerics-of-gans.pdf) also derived an upper bound for the imaginary-to-real ratio of these eigenvalues, showing that convergence of CO is potentially faster and more stable, although the bound is not intuitive. Empirically, consensus optimization works quite well in settings where GDA struggles with rotational forces (see Figure 1), but its performance is very sensitive to the choice of $\gamma$.
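
For the scalar bilinear game $f(x, y) = xy$ the effect of the regularizer can be checked directly, since the simultaneous gradient is linear ($\xi(\mathbf{w}) = \mathbf{H}\mathbf{w}$) and the CO update reduces to multiplication by $\mathbf{I} - \eta(\mathbf{H} + \gamma \mathbf{H}^{\mathsf{T}}\mathbf{H})$. The sketch below compares the spectral radius with and without the penalty; the values of $\eta$ and $\gamma$ and the helper names are assumptions made for illustration.

```python
import numpy as np

H = np.array([[0.0, 1.0], [-1.0, 0.0]])  # game Hessian of f(x, y) = x * y

def co_jacobian(H, eta, gamma):
    """Jacobian of the CO update at a fixed point: I - eta * (H + gamma * H^T H)."""
    return np.eye(H.shape[0]) - eta * (H + gamma * H.T @ H)

def spectral_radius(J):
    return np.max(np.abs(np.linalg.eigvals(J)))

eta = 0.1
rho_gda = spectral_radius(co_jacobian(H, eta, gamma=0.0))  # gamma = 0 recovers GDA
rho_co = spectral_radius(co_jacobian(H, eta, gamma=0.5))
print(rho_gda, rho_co)  # above 1 for GDA, below 1 for CO

# Iterating the linear CO update from (1, 1) shrinks the distance to the origin.
w = np.array([1.0, 1.0])
J = co_jacobian(H, eta, gamma=0.5)
for _ in range(200):
    w = J @ w
final_norm = np.linalg.norm(w)
print(final_norm)
```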

<figure>
  <img src="images/mechanics-of-differentiable-games/co.gif" style='margin: 10px auto' alt="my alt text"/>
  <figcaption>Figure 3: Training a GAN on a mixture of 16 Gaussians with consensus optimization (CO). Left: Kernel density plot of samples generated by the generator. Middle: Scatter plots of generated samples in orange and true samples in blue together with contours of the discriminator. Right: Training loss values of the generator and the discriminator. </figcaption>
</figure>

### Symplectic Gradient Adjustment (SGA)
To be continued...

### Experiments with Toy Datasets

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scipy import linalg
from IPython.core.display import HTML
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from collections import OrderedDict


def plot_3d(ax, func, xrange, yrange, cmap="viridis", elev=None, azim=None):
    X = np.arange(xrange[0], xrange[1], 0.01)
    Y = np.arange(yrange[0], yrange[1], 0.01)
    X, Y = np.meshgrid(X, Y)
    Z = func.fn(X, Y)

    ax.plot_surface(X, Y, Z, cmap=cmap)
    ax.view_init(elev=elev, azim=azim)
    ax.grid(False)
    ax.dist = 7.5

    ax.set_xlim(xrange)
    ax.set_ylim(yrange)
    ax.set_zlim(np.min(Z), np.max(Z))


def plot_2d(ax, func, xrange=[-4.5, 4.5], yrange=[-4.5, 4.5], logz=False, num_lines=60, **kwargs):
    X = np.arange(xrange[0], xrange[1], 0.01)
    Y = np.arange(yrange[0], yrange[1], 0.01)
    X, Y = np.meshgrid(X, Y)
    Z = func.fn(X, Y)

    Z = np.log(Z) if logz else Z
    ax.contour(X, Y, Z, num_lines, linewidths=1.0, **kwargs)
    ax.set_xlim(xrange)
    ax.set_ylim(yrange)


def plot_path(
    func, methods, lr, init_point, min_point=None, num_steps=200, xrange=(-4.5, 4.5), yrange=(-4.5, 4.5), azim=None, elev=None, cmap="viridis", num_lines=60, figsize=(14, 4), **kwargs,
):
    xy_values, loss_values = OrderedDict(), OrderedDict()
    for method in methods:
        method_fn = globals()[method]  # look up the optimizer function by its name
        xy_values[method], loss_values[method] = method_fn(func, lr, init_point, num_steps, **kwargs)

    fig = plt.figure(figsize=figsize)
    ax1 = fig.add_subplot(121, projection="3d")
    plot_3d(ax1, func, xrange, yrange, elev=elev, azim=azim, cmap=cmap)
    if min_point:
        ax1.plot([min_point[0]], [min_point[1]], [func.fn(min_point[0], min_point[1])], zorder=10, marker=(5, 1), markersize=10, alpha=0.4, color="C3")
    for i, method in enumerate(xy_values.keys()):
        ax1.plot(xy_values[method][:, 0], xy_values[method][:, 1], loss_values[method], "-o", zorder=10, markersize=1.5, linewidth=1, color="C%d" % i, label=method.upper())

    ax2 = fig.add_subplot(122)
    plot_2d(ax2, func, xrange, yrange, num_lines=num_lines, cmap=cmap)
    if min_point:
        ax2.plot([min_point[0]], [min_point[1]], marker=(5, 1), markersize=10, alpha=0.4, color="C3")
    for i, (method, xy_value) in enumerate(xy_values.items()):
        ax2.plot(xy_value[:, 0], xy_value[:, 1], "-o", markersize=1.5, linewidth=1, color="C%d" % i, label=method.upper())
    ax2.legend()

In [None]:
def gda(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    for _ in range(num_steps):
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        x = x - lr * func.dx(x, y)
        
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        y = y + lr * func.dy(x, y)
    return np.array(xy_values), np.array(loss_values)


def sga(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    for _ in range(num_steps):
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        x = x - lr * (func.dx(x, y) + kwargs['lambd'] * func.dxy(x, y) * func.dy(x, y))
        
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        y = y + lr * (func.dy(x, y) - kwargs['lambd'] * func.dyx(x, y) * func.dx(x, y))
    return np.array(xy_values), np.array(loss_values)


def eg(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    for _ in range(num_steps):
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        x = x - lr * func.dx(x - lr * func.dx(x, y), y + lr * func.dy(x, y))
        
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        y = y + lr * func.dy(x - lr * func.dx(x, y), y + lr * func.dy(x, y))
    return np.array(xy_values), np.array(loss_values)


def ogda(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    x_prev, y_prev = x, y

    for _ in range(num_steps):
        # Remember where this step started: it becomes the "previous" point
        # for the optimistic correction in the next step.
        x_start, y_start = x, y

        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        x = x - 2 * lr * func.dx(x, y) + lr * func.dx(x_prev, y_prev)

        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        y = y + 2 * lr * func.dy(x, y) - lr * func.dy(x_prev, y_prev)
        x_prev, y_prev = x_start, y_start
    return np.array(xy_values), np.array(loss_values)


def cgd(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    
    for _ in range(num_steps):
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        dx, dy = func.dx(x, y), func.dy(x, y)
        dxy, dyx = func.dxy(x, y), func.dyx(x, y)
        dxx, dyy = func.dxx(x, y), func.dyy(x, y)
        
        deltax, deltay = -linalg.solve(np.array([[1. / lr + dxx, dxy], [-dyx, 1. / lr - dyy]]), np.array([dx, -dy]))
        x = x + deltax
    
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        dx, dy = func.dx(x, y), func.dy(x, y)
        dxy, dyx = func.dxy(x, y), func.dyx(x, y)
        dxx, dyy = func.dxx(x, y), func.dyy(x, y)
        deltax, deltay = -linalg.solve(np.array([[1. / lr + dxx, dxy], [-dyx, 1. / lr - dyy]]), np.array([dx, -dy]))
        y = y + deltay
#         x = x - lr * (dx + lr * dxy * dy) / (1 + lr ** 2 * dxy * dyx)
#         y = y - lr * (-dy + lr * dyx * dx) / (1 + lr ** 2 * dyx * dxy)
    return np.array(xy_values), np.array(loss_values)


def fr(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    for _ in range(num_steps):
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        x = x - lr * func.dx(x, y)
        
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        y = y + lr * func.dy(x, y) + lr * (1. / func.dyy(x, y)) * func.dyx(x, y) * func.dx(x, y)
    return np.array(xy_values), np.array(loss_values)


def co(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    for _ in range(num_steps):
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        x = x - lr * func.dx(x, y) - 2 * lr * kwargs['gamma'] * (func.dx(x, y) * func.dxx(x, y) + func.dy(x, y) * func.dyx(x, y))
        
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)        
        y = y + lr * func.dy(x, y) - 2 * lr * kwargs['gamma'] * (func.dy(x, y) * func.dyy(x, y) + func.dx(x, y) * func.dxy(x, y))
    return np.array(xy_values), np.array(loss_values)


def fun(func, lr, init_point, num_steps=100, **kwargs):
    x, y = init_point
    xy_values, loss_values = [], []
    for _ in range(num_steps):
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
#         x = x - lr * func.dx(x, y) + lr * (1. / func.dxx(x, y)) * func.dxy(x, y) * func.dy(x, y)
#         x = x - lr * func.dx(x, y)
#         y = y + lr * func.dy(x, y) - lr * func.dyx(x, y) * func.dx(x, y)
#         y = y + lr * func.dy(x, y)

        x = x - lr * (func.dx(x, y) + kwargs['lambd'] * func.dxy(x, y) * func.dy(x, y))
        loss = func.fn(x, y)
        xy_values.append((x, y))
        loss_values.append(loss)
        y = y + lr * (func.dy(x, y) - kwargs['lambd'] * func.dyx(x, y) * func.dx(x, y))
    return np.array(xy_values), np.array(loss_values)

In [None]:
from collections import namedtuple

Function = namedtuple('Function', ['fn', 'dx', 'dy', 'dxx', 'dxy', 'dyx', 'dyy'])
func = Function(
    fn=lambda x, y: x * y,
    dx=lambda x, y: y,
    dy=lambda x, y: x,
    dxx=lambda x, y: 0,
    dxy=lambda x, y: 1,
    dyx=lambda x, y: 1,
    dyy=lambda x, y: 0,
)
plot_path(func, methods=['gda', 'sga', 'co', 'cgd'], lr=5e-2, lambd=1, gamma=0.5, init_point=(6, 6), num_steps=300, min_point=(0, 0), num_lines=60, xrange=(-15, 15), yrange=(-15, 15), elev=30)

In [None]:
from collections import namedtuple

Function = namedtuple('Function', ['fn', 'dx', 'dy', 'dxx', 'dxy', 'dyx', 'dyy'])
func = Function(
    fn=lambda x, y: 3 * x * x + y * y + 4 * x * y,
    dx=lambda x, y: 6 * x + 4 * y,
    dy=lambda x, y: 2 * y + 4 * x,
    dxx=lambda x, y: 6,
    dxy=lambda x, y: 4,
    dyx=lambda x, y: 4,
    dyy=lambda x, y: 2,
)
plot_path(func, methods=['gda', 'sga', 'eg', 'fr', 'co', 'ogda', 'cgd'], lr=5e-2, lambd=1.0, gamma=0.5, init_point=(6, 6), num_steps=300, min_point=(0, 0), num_lines=60, xrange=(-15, 15), yrange=(-15, 15), elev=30)

In [None]:
from collections import namedtuple

Function = namedtuple('Function', ['fn', 'dx', 'dy', 'dxx', 'dxy', 'dyx', 'dyy'])
func = Function(
    fn=lambda x, y: np.log(1 + np.exp(x)) + 3 * x * y - np.log(1 + np.exp(y)),
    dx=lambda x, y: 1. / (1. + np.exp(-x)) + 3 * y,
    dy=lambda x, y: -1. / (1. + np.exp(-y)) + 3 * x,
    dxx=lambda x, y: 1. / (1. + np.exp(-x)) - 1. / (1. + np.exp(-x)) ** 2,
    dxy=lambda x, y: 3,
    dyx=lambda x, y: 3,
    dyy=lambda x, y: -1. / (1. + np.exp(-y)) + 1. / (1. + np.exp(-y)) ** 2,
)
fig = plot_path(func, methods=['fr', 'fun'], lr=5e-2, lambd=1.0, gamma=0.5, init_point=(6, 6), num_steps=300, min_point=(0, 0), num_lines=60, xrange=(-10, 10), yrange=(-10, 10), elev=30, figsize=(14, 3.5))

In [None]:
from collections import namedtuple

Function = namedtuple('Function', ['fn', 'dx', 'dy', 'dxx', 'dxy', 'dyx', 'dyy'])
func = Function(
    fn=lambda x, y: np.exp(-0.01 * (x * x + y * y)) * (4 * x * x - (y - 3 * x + 0.05 * x ** 3) ** 2 - 0.1 * y ** 4),
    dx=lambda x, y: np.exp(-0.01 * (x * x + y * y)) * (0.00005 * x ** 7 - 0.021 * x ** 5 + 0.002 * x ** 4 * y + 1.3 * x ** 3 - 0.42 * x ** 2 * y + x * (0.002 * y ** 4 + 0.02 * y ** 2 - 10) + 6 * y),
    dy=lambda x, y: np.exp(-0.01 * (x * x + y * y)) * (0.00005 * x ** 6 * y - 0.006 * x ** 4 * y + x ** 3 * (0.002 * y ** 2 - 0.1) + 0.1 * x ** 2 * y + x * (6 - 0.12 * y ** 2) + 0.002 * y ** 5 - 0.38 * y ** 3 - 2 * y),
    dxx=lambda x, y: np.exp(-0.01 * (x * x + y * y)) * (-0.04 * x * (8 * x - 2 * (0.15 * x ** 2 - 3) * (0.05 * x ** 3 - 3 * x + y)) + (-0.6 * x * (0.05 * x ** 3 - 3 * x + y) - 2 * (0.15 * x ** 2 - 3) ** 2 + 8) + 0.0004 * x ** 2 * (-(0.05 * x ** 3 - 3 * x + y) ** 2 + 4 * x ** 2 - 0.1 * y ** 4) - 0.02 * (-(0.05 * x ** 3 - 3 * x + y) ** 2 + 4 * x ** 2 - 0.1 * y ** 4)),
    dxy=lambda x, y: np.exp(-0.01 * (x * x + y * y)) * (-2 * (0.15 * x ** 2 - 3) - 0.02 * y * (8 * x - 2 * (0.15 * x ** 2 - 3) * (0.05 * x ** 3 - 3 * x + y)) + 0.0004 * x * y * (-(0.05 * x ** 3 - 3 * x + y) ** 2 + 4 * x ** 2 - 0.1 * y ** 4) - 0.02 * x * (-2 * (0.05 * x ** 3 - 3 * x + y) - 0.4 * y ** 3)),
    dyx=lambda x, y: np.exp(-0.01 * (x * x + y * y)) * (-2 * (0.15 * x ** 2 - 3) - 0.02 * y * (8 * x - 2 * (0.15 * x ** 2 - 3) * (0.05 * x ** 3 - 3 * x + y)) + 0.0004 * x * y * (-(0.05 * x ** 3 - 3 * x + y) ** 2 + 4 * x ** 2 - 0.1 * y ** 4) - 0.02 * x * (-2 * (0.05 * x ** 3 - 3 * x + y) - 0.4 * y ** 3)),
    dyy=lambda x, y: np.exp(-0.01 * (x * x + y * y)) * ((-1.2 * y ** 2 - 2) + 0.0004 * y ** 2 * (-(0.05 * x ** 3 - 3 * x + y) ** 2 + 4 * x ** 2 - 0.1 * y ** 4) - 0.02 * (-(0.05 * x ** 3 - 3 * x + y) ** 2 + 4 * x ** 2 - 0.1 * y ** 4) - 0.04 * y * (-2 * (0.05 * x ** 3 - 3 * x + y) - 0.4 * y ** 3)),
)
# plot_path(func, methods=['gda', 'sga', 'eg', 'fr', 'co', 'ogda', 'cgd'], lr=5e-2, lambd=1.0, gamma=0.03, init_point=(5, 5), num_steps=100, min_point=(0, 0), num_lines=60, xrange=(-7.5, 7.5), yrange=(-7.5, 7.5), elev=70)
# plot_path(func, methods=['cgd', 'sga', 'co', 'gda'], lr=5e-2, lambd=1.0, gamma=0.03, init_point=(5, 5), num_steps=100, min_point=(0, 0), num_lines=60, xrange=(-7.5, 7.5), yrange=(-7.5, 7.5), elev=70)
plot_path(func, methods=['gda', 'sga', 'fun'], lr=5e-2, lambd=5e-2, gamma=0.03, init_point=(5, 5), num_steps=100, min_point=(0, 0), num_lines=60, xrange=(-7.5, 7.5), yrange=(-7.5, 7.5), elev=70)

### Experiments with GAN Training

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools


def build_generator():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(2, activation=None),
    ])


def build_discriminator():
    return tf.keras.Sequential([
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(384, activation='relu'),
        tf.keras.layers.Dense(1, activation=None),
    ])


def data_generator(batch_size=256, sigma=0.02):
    mus = np.mgrid[-1.5:2:1, -1.5:2:1].reshape(2, -1).T.astype(np.float32)
    mus = np.tile(mus, (batch_size // 16 + 1, 1))[:batch_size].astype(np.float32)
    return mus + sigma * tf.random.normal((batch_size, 2))


def train(step_fn, num_iterations=10001):
    generator = build_generator()
    discriminator = build_discriminator()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9999)
    
    fixed_noise = tf.random.normal([2560, 64])
    X, Y = np.mgrid[-3:3:0.03, -3:3:0.03]
    fixed_grid = np.vstack([X.flatten(), Y.flatten()]).T.astype(np.float32)
    fixed_samples = data_generator(2560).numpy()
    
    gen_losses, dis_losses = [], []
    for it in range(num_iterations):
        real_samples = data_generator()
        gen_loss, dis_loss = step_fn(generator, discriminator, optimizer, real_samples)
        # The step functions return per-sample losses; log their batch mean.
        gen_losses.append(float(tf.reduce_mean(gen_loss)))
        dis_losses.append(float(tf.reduce_mean(dis_loss)))

        if it % 1000 == 0:
            x = generator(fixed_noise, training=False).numpy()
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4))
            sns.kdeplot(x=x[:, 0], y=x[:, 1], fill=True, ax=ax1, cmap='Reds')
            ax1.set_title('Iteration {}'.format(it))

            ax2.contour(X, Y, discriminator(fixed_grid).numpy().reshape(200, 200), levels=10)
            ax2.scatter(fixed_samples[:, 0], fixed_samples[:, 1], s=5, alpha=0.2, c='C0')
            ax2.scatter(x[:, 0], x[:, 1], s=5, alpha=0.2, c='C1')
            ax2.set_title('Discriminator {:.4f}'.format(dis_losses[-1]))
            
            ax3.plot(range(len(gen_losses)), gen_losses, alpha=0.8, label='Generator')
            ax3.plot(range(len(dis_losses)), dis_losses, alpha=0.8, label='Discriminator')
            ax3.set_title('Generator {:.4f}'.format(gen_losses[-1]))
            ax3.legend()

            for ax in [ax1, ax2]:
                ax.set_xlim([-3, 3])
                ax.set_ylim([-3, 3])
                ax.set_aspect('equal')
            plt.show()

In [None]:
@tf.function
def gradient_descent_ascent(generator, discriminator, optimizer, real_samples):
    noise = tf.random.normal((256, 64))
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        fake_samples = generator(noise)
        real_outputs = discriminator(real_samples)
        fake_outputs = discriminator(fake_samples)

        gen_loss = tf.keras.losses.binary_crossentropy(tf.ones_like(fake_outputs), fake_outputs, from_logits=True)
        dis_loss_real = tf.keras.losses.binary_crossentropy(tf.ones_like(real_outputs), real_outputs, from_logits=True)
        dis_loss_fake = tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_outputs), fake_outputs, from_logits=True)
        dis_loss = dis_loss_real + dis_loss_fake
    
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    dis_grads = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    optimizer.apply_gradients(zip(dis_grads, discriminator.trainable_variables))
    return gen_loss, dis_loss


train(step_fn=gradient_descent_ascent)

In [None]:
@tf.function
def concensus_optimization(generator, discriminator, optimizer, real_samples):
    noise = tf.random.normal((256, 64))
    with tf.GradientTape(persistent=True) as reg_tape:
        with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
            fake_samples = generator(noise)
            real_outputs = discriminator(real_samples)
            fake_outputs = discriminator(fake_samples)

            gen_loss = tf.keras.losses.binary_crossentropy(tf.ones_like(fake_outputs), fake_outputs, from_logits=True)
            dis_loss_real = tf.keras.losses.binary_crossentropy(tf.ones_like(real_outputs), real_outputs, from_logits=True)
            dis_loss_fake = tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_outputs), fake_outputs, from_logits=True)
            dis_loss = dis_loss_real + dis_loss_fake

        gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
        dis_grads = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
        grad_norm = 0.5 * sum(tf.reduce_sum(tf.square(g)) for g in itertools.chain(gen_grads, dis_grads))
        
    gen_reg_grads = reg_tape.gradient(grad_norm, generator.trainable_variables)
    dis_reg_grads = reg_tape.gradient(grad_norm, discriminator.trainable_variables)

    optimizer.apply_gradients((grad + 1. * reg, var) for grad, reg, var in zip(gen_grads, gen_reg_grads, generator.trainable_variables))
    optimizer.apply_gradients((grad + 1. * reg, var) for grad, reg, var in zip(dis_grads, dis_reg_grads, discriminator.trainable_variables))
    del reg_tape
    return gen_loss, dis_loss


train(step_fn=concensus_optimization)

In [None]:
@tf.function
def symplectic_gradient(generator, discriminator, optimizer, real_samples, lambd=1.0):
    # Sketch of SGA for two players: xi + lambd * A^T xi with A = (H - H^T) / 2.
    # Assumption: each player's loss has a symmetric Hessian, so only the
    # cross-player blocks of H contribute to A. `lambd` replaces the factor 1.
    noise = tf.random.normal((256, 64))
    gen_vars, dis_vars = generator.trainable_variables, discriminator.trainable_variables
    zero = tf.UnconnectedGradients.ZERO
    with tf.GradientTape(persistent=True) as reg_tape:
        with tf.GradientTape(persistent=True) as tape:
            fake_samples = generator(noise)
            real_outputs = discriminator(real_samples)
            fake_outputs = discriminator(fake_samples)

            gen_loss = tf.keras.losses.binary_crossentropy(tf.ones_like(fake_outputs), fake_outputs, from_logits=True)
            dis_loss_real = tf.keras.losses.binary_crossentropy(tf.ones_like(real_outputs), real_outputs, from_logits=True)
            dis_loss_fake = tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_outputs), fake_outputs, from_logits=True)
            dis_loss = dis_loss_real + dis_loss_fake

        # Each player's own gradient (xi) and its gradient of the opponent's loss.
        gen_grads = tape.gradient(gen_loss, gen_vars)
        dis_grads = tape.gradient(dis_loss, dis_vars)
        gen_cross = tape.gradient(dis_loss, gen_vars, unconnected_gradients=zero)
        dis_cross = tape.gradient(gen_loss, dis_vars, unconnected_gradients=zero)

        # <grad_d(L_d - L_g), stop(xi_d)> and <grad_g(L_g - L_d), stop(xi_g)>.
        gen_dot = sum(tf.reduce_sum((g - c) * tf.stop_gradient(g)) for g, c in zip(dis_grads, dis_cross))
        dis_dot = sum(tf.reduce_sum((g - c) * tf.stop_gradient(g)) for g, c in zip(gen_grads, gen_cross))

    # (A^T xi)_g = 0.5 * grad_g(gen_dot) and (A^T xi)_d = 0.5 * grad_d(dis_dot).
    gen_adj = reg_tape.gradient(gen_dot, gen_vars, unconnected_gradients=zero)
    dis_adj = reg_tape.gradient(dis_dot, dis_vars, unconnected_gradients=zero)

    optimizer.apply_gradients((grad + 0.5 * lambd * adj, var) for grad, adj, var in zip(gen_grads, gen_adj, gen_vars))
    optimizer.apply_gradients((grad + 0.5 * lambd * adj, var) for grad, adj, var in zip(dis_grads, dis_adj, dis_vars))
    del reg_tape, tape
    return gen_loss, dis_loss


train(step_fn=symplectic_gradient)
