In [None]:
# GDA and CGD updates in PyTorch; see https://github.com/julienroyd/competitive_gradient_descent/blob/master/exp2_gaussian_mixture/CGD_vs_GDA_GaussianMixture_GAN.ipynb
def compute_gda_update(f, x, g, y, eta=None):
    """
    Return each player's plain gradient of its own loss.

    f: loss function to minimise for player X
    x: current action (parameters) of player X, passed as `inputs` to `grad`
    g: loss function to minimise for player Y
    y: current action (parameters) of player Y, passed as `inputs` to `grad`
    eta: not used by this function

    Returns a pair of lists `(x_update, y_update)` holding df/dx and dg/dy.
    """
    # The graph is retained for the first call so that g can still be
    # differentiated afterwards; the second call uses grad's default.
    grads_wrt_x = grad(outputs=f, inputs=x, retain_graph=True)
    grads_wrt_y = grad(outputs=g, inputs=y)

    return [*grads_wrt_x], [*grads_wrt_y]

def _cgd_dot(a, b):
    """Inner product of two equally-structured sequences of tensors."""
    return sum(torch.sum(a_i * b_i) for a_i, b_i in zip(a, b))


def compute_cgd_update(f, x, g, y, eta, max_it=10, tol=1e-6):
    """
    Iteratively estimate the solution for the local Nash equilibrium using the conjugate gradient method

    The routine runs conjugate-gradient iterations on the linear system
    ``A u = b`` where ``b = (df/dx, dg/dy)`` and ``A`` acts on a pair
    ``(p_x, p_y)`` as::

        A p = (p_x + eta * d/dx <df/dy, p_y>,
               p_y + eta * d/dy <dg/dx, p_x>)

    f: loss function to minimise for player X
    x: current action (parameters) of player X, passed as `inputs` to `grad`
    g: loss function to minimise for player Y
    y: current action (parameters) of player y, passed as `inputs` to `grad`
    eta: scalar weighting the mixed second-order term in ``A``
    max_it: maximum number of conjugate-gradient iterations
    tol: iterations stop once the residual's Euclidean norm is <= tol

    Returns ``(x_update, y_update)``: two lists of tensors, one entry per
    gradient returned for ``x`` and ``y`` respectively, holding the
    approximate solution ``u``.  Applying it to the parameters (sign and
    scaling) is left to the caller.

    NOTE(review): conjugate gradient is only guaranteed to converge when
    ``A`` is symmetric positive definite; this is not checked here.
    """
    # First-order terms. create_graph=True keeps them differentiable so the
    # mixed second-order products can be taken inside the loop.
    df_dx = grad(outputs=f, inputs=x, create_graph=True, retain_graph=True)
    dg_dy = grad(outputs=g, inputs=y, create_graph=True, retain_graph=True)

    df_dy = grad(outputs=f, inputs=y, create_graph=True, retain_graph=True)
    dg_dx = grad(outputs=g, inputs=x, create_graph=True, retain_graph=True)

    with torch.no_grad():
        # Solution estimate starts at 0, so the initial residual and the
        # initial search direction are both equal to the right-hand side b.
        # x and y are handled independently so they may have different
        # numbers of parameter groups.
        x_update = [torch.zeros_like(t) for t in df_dx]
        y_update = [torch.zeros_like(t) for t in dg_dy]

        r_xk = [torch.clone(t) for t in df_dx]
        r_yk = [torch.clone(t) for t in dg_dy]
        p_xk = [torch.clone(t) for t in df_dx]
        p_yk = [torch.clone(t) for t in dg_dy]

        # Squared norm of the current residual, carried across iterations.
        r_sq = _cgd_dot(r_xk, r_xk) + _cgd_dot(r_yk, r_yk)

    # Iteratively solve for the local Nash Equilibrium
    for _ in range(max_it):

        # Mixed second-derivative / vector products evaluated by autograd
        hvp_x = grad(outputs=df_dy, inputs=x, grad_outputs=p_yk, retain_graph=True)
        hvp_y = grad(outputs=dg_dx, inputs=y, grad_outputs=p_xk, retain_graph=True)

        with torch.no_grad():

            # Matrix / basis-vector product Ap
            Ap_x = [p + eta * h for p, h in zip(p_xk, hvp_x)]
            Ap_y = [p + eta * h for p, h in zip(p_yk, hvp_y)]

            # Step size alpha_k = (r_k . r_k) / (p_k . A p_k)
            denom = _cgd_dot(p_xk, Ap_x) + _cgd_dot(p_yk, Ap_y)
            if denom == 0:
                # Happens e.g. when all gradients are exactly zero; dividing
                # would only inject NaNs into the updates.
                break
            alpha_k = r_sq / denom

            # New solution estimate (in place on the zero-initialised tensors)
            for upd, p in zip(x_update, p_xk):
                upd += alpha_k * p
            for upd, p in zip(y_update, p_yk):
                upd += alpha_k * p

            # New residuals
            r_xk = [r - alpha_k * ap for r, ap in zip(r_xk, Ap_x)]
            r_yk = [r - alpha_k * ap for r, ap in zip(r_yk, Ap_y)]
            r_sq_next = _cgd_dot(r_xk, r_xk) + _cgd_dot(r_yk, r_yk)

            # Check convergence condition
            if torch.sqrt(r_sq_next) <= tol:
                break

            # beta_k = (r_{k+1} . r_{k+1}) / (r_k . r_k).  r_sq is non-zero
            # here: a zero residual makes denom zero on the first iteration
            # and is caught by the convergence test on later ones.
            beta_k = r_sq_next / r_sq

            # New basis vectors
            p_xk = [r + beta_k * p for r, p in zip(r_xk, p_xk)]
            p_yk = [r + beta_k * p for r, p in zip(r_yk, p_yk)]

            r_sq = r_sq_next

    return x_update, y_update