In [2]:
import numpy as np
from hdmm import workload

def err(A, W):
    AtA1 = np.linalg.pinv(A.T.dot(A))
    WtW = W.T.dot(W)
    return np.sum(WtW * AtA1)

def obj(As, W):
    V = np.zeros((W.shape[0], len(As)))
    Bs = [np.linalg.pinv(A) for A in As]
    Xs = [W.dot(B) for B in Bs]
    V = np.vstack([np.sum(X**2, axis=1) for X in Xs])
    s = np.sum(1.0/V, axis=0)
    f = np.sum(1.0 / s)
    
    dV = 1 / (V**2 * s**2)
    dXs = [2*X*dv[:,None] for X, dv in zip(Xs, dV)]
    dBs = [W.T.dot(dX) for dX in dXs]
    dAs = []
    for A, B, dB in zip(As, Bs, dBs):
        m, n = A.shape
        dA = -B.dot(dB.T).dot(B)
        dA += B.dot(B.T).dot(dB).dot(np.eye(m) - A.dot(B))
        dA += (np.eye(n) - B.dot(A)).dot(dB).dot(B.T).dot(B)
        dAs.append(dA.T)
    
    return f, dAs
    

In [6]:
W = workload.AllRange(16)
A = np.vstack([np.eye(16), np.random.rand(16)])
I = np.eye(16)
f, dA = obj([A, I], W)

approx = np.zeros(16)
for i in range(16):
    A[i,i] += 1e-5
    f1, _ = obj([A,I], W)
    approx[i] = (f1 - f) / 1e-5
    A[i,i] -= 1e-5
print(np.diag(dA[0]))
print(approx)

[-11.01274238 -20.86668479 -17.04529003 -13.00929539 -10.60649447
 -32.64858989 -23.29262734  -9.86175454 -38.40060603 -14.22138308
 -35.52036922 -33.53803085 -29.81053493 -19.46231608 -15.96811841
 -17.80842562]
[-11.01259884 -20.86640997 -17.0450691  -13.0091495  -10.60638053
 -32.64815327 -23.29232371  -9.86164854 -38.40008985 -14.2212227
 -35.51989183 -33.53758048 -29.81013569 -19.46206008 -15.96790725
 -17.80820211]


In [13]:
import autograd.numpy as np
from autograd import grad
from autograd.extend import defvjp
from functools import reduce

#def pinv_vjp(g, ans, vs, gvs, A):
#    A1 = np.linalg.pinv(A)
#    In = np.eye(A.shape[1])
#    Im = np.eye(A.shape[0])
#    term1 = -np.dot(A1, np.dot(g.T, A1))
#    term2 = np.dot(np.dot(A1, A1.T), np.dot(g, Im - np.dot(A, A1)))
#    term3 = np.dot(In - np.dot(A1, A), np.dot(g, np.dot(A1.T, A1)))
#    return (term1 + term2 + term3).T

#np.linalg.pinv.defvjp(pinv_vjp)

def pinv_vjp(ans, A):
    A1 = np.linalg.pinv(A)
    In = np.eye(A.shape[1])
    Im = np.eye(A.shape[0])
    def foo(g):
        term1 = -np.dot(A1, np.dot(g.T, A1))
        term2 = np.dot(np.dot(A1, A1.T), np.dot(g, Im - np.dot(A, A1)))
        term3 = np.dot(In - np.dot(A1, A), np.dot(g, np.dot(A1.T, A1)))
        return (term1 + term2 + term3).T
    return foo

defvjp(np.linalg.pinv, pinv_vjp)

def kron_obj(As, Ws):
    # As is a l x d table
    # Ws is a k x d table
    L = len(As)
    K = len(Ws)
    D = len(As[0])
    
    #deltas = sum([reduce(np.kron, [np.sum(A, axis=0)[:,None] for A in kron]) for kron in As])
    #delta = np.max(deltas)**2
    #delta = np.sum(eps)**2
    delta = 1.0
    
    #print 'delta', np.max(deltas) / np.min(deltas)
    
    # Todo: global normalization rather than local
    Bs = [[np.linalg.pinv(A/np.sum(A, axis=0)) for A in kron] for kron in As]
    V = [[None for _ in range(K)] for _ in range(L)]
    for l in range(L):
        for k in range(K):
            v = [None for _ in range(D)]
            for d in range(D):
                A = As[l][d] / np.sum(As[l][d], axis=0)
                X = np.dot(Ws[k][d], Bs[l][d])
                # check to make sure strategy supports workload
                v[d] = np.sum(X**2, axis=1)[:,None]
                if not np.allclose(np.dot(X, A), Ws[k][d]):
                    print('checkpt')
            V[l][k] = reduce(np.kron, v).flatten()
            
    V2 = np.array([np.concatenate(vs) for vs in V]) / eps[:,None]**2
    
    s = np.sum(1.0/V2, axis=0)
    f = np.sum(1.0 / s)
    return delta*f
                
I = np.eye(8)
A1 = np.vstack([np.eye(8),np.random.rand(8)])
A2 = np.vstack([np.eye(5),np.random.rand(5)])
As = [[A1, A2]]
W1 = np.random.rand(8,8)
W2 = np.random.rand(4,5)
Ws = [[W1, W2]]
eps = np.ones(2)

print(kron_obj(As, Ws))
dA = grad(kron_obj)(As, Ws)

print(dA[0][0].shape, dA[0][1].shape)

93.08239456914286
(9, 8) (6, 5)


  return lambda g: g[idxs]


In [17]:
from scipy import optimize
from hdmm import workload
from experiments.census_workloads import CensusSF1
from scipy.optimize import minimize

sf1 = CensusSF1()

Ws = [[S.W for S in K.workloads] for K in sf1.workloads]
ps = [1,1,6,1,10]
As = [[np.vstack([np.eye(n), np.random.rand(p,n)]) for p, n in zip(ps, sf1.domain)] for _ in range(2)]
D = len(As[0])
L = len(As)

def vect_to_mats(params):
    idx = 0
    ans = []
    for _ in range(2):
        Ai = []
        for n, p in zip(sf1.domain, ps):
            stop = idx+n*(n+p)
            Ai.append(params[idx:stop].reshape(n+p, n))
            idx = stop
        ans.append(Ai)
    return ans

def mats_to_vect(As):
    vects = []
    for i in range(2):
        vects.append(np.concatenate([A.flatten() for A in As[i]]))
    return np.concatenate(vects)

gradient1 = grad(kron_obj, argnum=0)
#gradient2 = grad(kron_obj, argnum=1)
id_err = kron_obj([[np.eye(n) for n in sf1.domain]], Ws)

def loss_and_grad(params):
    #eps = params[:2]
    As = vect_to_mats(params)
    #eps = params[:2]
    ans = kron_obj(As, Ws)
    dAs = gradient1(As, Ws)
    #deps = gradient2(As, eps, Ws)
    dparams = mats_to_vect(dAs)
    print(id_err / ans)
    #print ans, params.sum(), np.sum([[np.sum(A) for A in Ai] for Ai in As])
    return ans, dparams

#print kron_obj(As, Ws)
#print grad(kron_obj)(As, Ws)

#eps = np.ones(2)
params = mats_to_vect(As)
bounds = [(0, None)] * params.size
res = optimize.minimize(loss_and_grad, x0=params, method='L-BFGS-B', jac=True, bounds=bounds)

ModuleNotFoundError: No module named 'experiments.census_workloads'

In [208]:
As = vect_to_mats(res.x)
print sf1.domain, len(Ws)
print kron_obj(As, Ws) / id_err
As = vect_to_mats(res.x)
def normalize(A):
    return A / A.sum(axis=0)
print normalize(As[1][1])

(2, 2, 64, 17, 115) 33
0.079835069431
[[ 0.01244673  0.        ]
 [ 0.          1.        ]
 [ 0.98755327  0.        ]]


In [131]:
Ws = [[S.W for S in K.workloads] for K in sf1.workloads]
As = vect_to_mats(res.x)
_, V = kron_obj(As, Ws)
#dAs = gradient(As, Ws)



In [129]:
V

array([[  0.00000000e+00,   0.00000000e+00,   0.00000000e+00, ...,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00],
       [  4.65193887e-04,   2.31354605e-05,   2.63760322e-05, ...,
          6.99569653e-06,   6.55013243e-06,   8.45687156e-05]])