<a href="https://colab.research.google.com/github/m-mehabadi/grad-maker/blob/main/_notebooks/Testing_GradientMaker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

http://www.cs.cmu.edu/~pradeepr/convexopt/Lecture_Slides/dual-ascent.pdf

https://web.stanford.edu/class/ee364b/lectures/primal_dual_subgrad_slides.pdf

https://www.cvxpy.org/examples/basic/quadratic_program.html

In [5]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import cvxpy as cp

In [17]:
def gradient_maker(grads, solver=None):
    """
    - make sure to install `cvxpy`. you can use: `pip install cvxpy`
    - `grads` in a numpy's `ndarray`
    - `grads.shape == (n, d)`, where `n` is the number of domains and `d` is the dimension
    - this method will return a tuple of size two, where:
        * the first one is the generalized vector to use with size `d`
        * the second one is the weight vector of the linear combination
    - finally, use g, _ = gradient_maker(grads), if you have no need to use the 2nd return
    """

    import cvxpy as cp
    from numpy import linalg as la

    def nearestPD(A):

        B = (A + A.T) / 2
        _, s, V = la.svd(B)

        H = np.dot(V.T, np.dot(np.diag(s), V))

        A2 = (B + H) / 2

        A3 = (A2 + A2.T) / 2

        if isPD(A3):
            return A3

        spacing = np.spacing(la.norm(A))
        
        I = np.eye(A.shape[0])
        k = 1
        while not isPD(A3):
            mineig = np.min(np.real(la.eigvals(A3)))
            A3 += I * (-mineig * k**2 + spacing)
            k += 1

        return A3


    def isPD(B):
        try:
            _ = la.cholesky(B)
            return True
        except la.LinAlgError:
            return False

    #
    G = grads.T
    n, d = grads.shape
    g_ = np.mean(grads, axis=0).reshape(-1, 1)

    #
    P = nearestPD(n*G.T@G)
    q = -n*G.T@g_
    F = -G.T@G
    h = np.zeros(n, dtype=np.float32)
    A = np.ones(n, dtype=np.float32).reshape(1, -1)
    b = np.ones((1, 1), dtype=np.float32)

    # define opt variable
    x = cp.Variable(n)
    prob = cp.Problem(cp.Minimize((1/2)*cp.quad_form(x, P) + q.T @ x),
                    [F @ x <= h,
                    A @ x == b])
    #
    if solver is None:
        solver = cp.OSQP
    prob.solve(solver=solver, verbose=False)
    s = np.array(x.value)

    return G@s, s

### Testcases

Now let's write some test cases to make sure everything is working correctly


In [46]:
grads = (1000*(np.random.randn(20, 23)+1))*np.random.randn(20, 23)

In [47]:
g, _ = gradient_maker(grads)

In [48]:
grads@g

array([ 3.58234724e+06,  1.40550831e+06,  9.95251171e+05,  2.99566385e+06,
        4.47032769e+06,  1.23865763e+04,  2.54277280e+06,  5.31377469e+05,
        1.05736584e+06,  6.40284270e-10, -1.14960130e-09,  1.33006823e+06,
        1.85425661e+06,  4.04886102e+06,  7.00754961e+06, -2.32830644e-10,
        1.03882485e+06,  3.82335840e+06,  4.11144654e+06,  5.52447919e+06])

In [49]:
from numpy.linalg import norm

In [50]:
grads_norm = norm(grads, axis=1)

In [51]:
grads_norm/np.min(grads_norm)

array([1.28438341, 2.23292966, 2.21013496, 1.63182956, 3.22268433,
       1.79012374, 2.1102299 , 1.        , 1.91065134, 2.38551134,
       1.33586117, 1.87075704, 2.18937582, 2.02508051, 4.03531556,
       1.51772263, 1.72153433, 2.45259443, 2.30904057, 2.93477631])

In [52]:
grads_scaled = grads/np.min(grads_norm)

In [53]:
_g, _ = gradient_maker(grads_scaled)

In [54]:
grads_scaled@g

array([ 1.13163214e+03,  4.43987772e+02,  3.14391134e+02,  9.46303988e+02,
        1.41213739e+03,  3.91281102e+00,  8.03239671e+02,  1.67857491e+02,
        3.34012614e+02,  1.84741111e-13, -2.91322522e-13,  4.20156913e+02,
        5.85743434e+02,  1.27899976e+03,  2.21362359e+03, -8.52651283e-14,
        3.28155677e+02,  1.20776545e+03,  1.29876997e+03,  1.74513462e+03])

In [55]:
g/_g

array([3165.64613219, 3165.64721819, 3165.64647622, 3165.64646544,
       3165.64644994, 3165.6462149 , 3165.64783211, 3165.64677526,
       3165.64663906, 3165.64585135, 3165.64677767, 3165.64681481,
       3165.64661183, 3165.64640462, 3165.64508234, 3165.64718399,
       3165.64588334, 3165.64666064, 3165.64626457, 3165.64635597,
       3165.64628764, 3165.64632591, 3165.64651894])