In [119]:
import numpy as np
import torch

from hdmm import error
from hdmm import workload
from hdmm.templates import TemplateStrategy

In [67]:
# Scratch check of the construction used by McKennaConvex._loss: scatter a
# vector into the strictly-lower triangle of a zero matrix, symmetrise it,
# and confirm that the diagonal is left untouched (its max stays 0).
mask = torch.ones(5, 5, dtype=torch.uint8).tril(diagonal=-1)
z = torch.rand(10)  # one value per strictly-lower-triangular entry: 5*4/2
B = torch.zeros(5, 5)
B[mask] = z
B = B + B.t()
B.diagonal().max()

tensor(0.)

In [238]:
class McKennaConvex(TemplateStrategy):
    """Template parameterised by a symmetric, unit-diagonal n x n matrix X.

    The free parameters are the strictly-lower-triangular entries of X
    (n*(n-1)/2 values). The objective is ``sum(inverse(X) * V)`` where
    ``V = W.gram()`` is the dense Gram matrix of the workload ``W``; it is
    minimised by gradient descent with a backtracking step size.

    NOTE(review): ``TemplateStrategy.__init__`` is not called here -- confirm
    the base class does not need it.
    NOTE: ``torch.symeig`` and ``torch.cholesky`` are deprecated in newer
    PyTorch releases in favour of ``torch.linalg.eigh`` and
    ``torch.linalg.cholesky``; they are kept so this class does not raise the
    notebook's PyTorch requirement.
    """

    def __init__(self, n):
        """Allocate the mask, parameter vector and scratch buffer for size ``n``."""
        self.n = n
        # Mask selecting the strictly-lower triangle (diagonal excluded).
        self._mask = torch.tril(torch.ones(n, n, dtype=torch.uint8), diagonal=-1)
        self._params = torch.zeros(n * (n - 1) // 2)
        # Scratch buffer that torch.eye(..., out=...) writes into in _loss.
        self.X = torch.zeros(n, n)

    def _set_workload(self, W):
        """Store ``W`` and its dense float32 Gram matrix as a tensor in ``self.V``."""
        self.V = torch.tensor(W.gram().dense_matrix().astype(np.float32))
        self.W = W

    def _loss(self):
        """Return ``sum(inverse(X) * V)`` for the X encoded by ``self._params``.

        X is assembled as 0.5*I plus the parameters in the strictly-lower
        triangle, then added to its transpose, giving a symmetric matrix with
        ones on the diagonal. If the Cholesky factorisation or the inverse
        raises ``RuntimeError``, a constant ``tensor(inf)`` that carries no
        gradient is returned instead.
        """
        V = self.V
        X = 0.5 * torch.eye(self.n, out=self.X)

        X[self._mask] = self._params
        X = X + torch.t(X)  # warning: don't do this in place

        try:
            # The factorisation is only a positive-definiteness check; its
            # result is not used.
            torch.cholesky(X)
            iX = torch.inverse(X)
        except RuntimeError:
            # Only linear-algebra failures are mapped to inf; anything else
            # (e.g. KeyboardInterrupt) propagates to the caller.
            return torch.tensor(np.inf)

        return torch.sum(iX * V)

    def optimize(self, W, iters=500):
        """Minimise the loss for workload ``W`` by backtracking gradient descent.

        Parameters
        ----------
        W : workload object providing ``gram().dense_matrix()`` and ``shape``.
        iters : int
            Maximum number of outer gradient steps. The default of 500 is the
            count the loop previously hard-coded while ignoring this argument.

        Raises
        ------
        RuntimeError
            If the loss at the current parameters is the non-differentiable
            ``inf`` sentinel returned by ``_loss``.

        Every 100th iteration prints the step size and
        ``sqrt(loss / W.shape[0])``. Returns ``None``; the result is left in
        ``self._params``.
        """
        self._set_workload(W)

        # Initialise from the symmetric square root of V, scaled so that its
        # largest diagonal entry equals 1.
        eig, P = torch.symeig(self.V, eigenvectors=True)
        eig[eig < 1e-10] = 0.0  # clamp tiny/negative eigenvalues before sqrt
        X = P @ torch.diag(torch.sqrt(eig)) @ torch.t(P)
        X /= torch.diag(X).max()

        self._params = X[self._mask].requires_grad_(True)

        # have to implement the optimization loop manually :(
        max_halvings = 25
        beta = 1.0  # step size; only ever shrinks, and carries over between iterations
        for it in range(iters):
            # Fresh leaf tensor each iteration so gradients never accumulate.
            params = self._params.detach().requires_grad_(True)
            self._params = params
            curr_loss = self._loss()
            if not curr_loss.requires_grad:
                raise RuntimeError(
                    "loss is not differentiable at the current parameters "
                    "(X is not positive definite)"
                )
            curr_loss.backward()
            grad = params.grad

            # Backtracking: halve beta until the trial point lowers the loss.
            loss = None
            for _ in range(max_halvings):
                self._params = params - beta * grad
                trial_loss = self._loss()
                if trial_loss < curr_loss:
                    loss = trial_loss
                    break
                beta *= 0.5

            if loss is None:
                # No trial improved the loss: restore the last accepted point
                # rather than keeping a rejected one, and stop, since beta can
                # only get smaller from here.
                self._params = params
                break

            if it % 100 == 0:
                print(beta, torch.sqrt(loss / self.W.shape[0]))

In [239]:
# Build a McKennaConvex template of size n and optimise it against
# workload.Prefix(n).
n = 512
W = workload.Prefix(n)
# Alternative workload, currently disabled:
#W = workload.AllRange(n)
temp = McKennaConvex(n)
# optimize() prints (step size, sqrt(loss / W.shape[0])) every 100 iterations.
temp.optimize(W)

0.015625 tensor(3.0193, grad_fn=<SqrtBackward>)
0.0009765625 tensor(2.8757, grad_fn=<SqrtBackward>)
0.0009765625 tensor(2.8355, grad_fn=<SqrtBackward>)
0.0009765625 tensor(2.8132, grad_fn=<SqrtBackward>)
0.0009765625 tensor(2.7986, grad_fn=<SqrtBackward>)


In [266]:
# Compare WtW with WtW @ pinv(AtA) @ AtA: each pair of entries from their
# `.matrices` is applied to the same random vector z, and the norm of the
# difference is printed.
R = workload.Prefix(64)
I = workload.Identity(64)
T = workload.Total(64)

W = workload.VStack([workload.Kronecker([R,T]), workload.Kronecker([T,R])])
WtW = W.gram()

# Strategy under test. An earlier `A = workload.Kronecker([I,T])` was
# overwritten by this assignment before it was ever read, so it is dropped.
A = workload.Marginals((64,64), np.array([0,0,1,0]))
AtA = A.gram()
AtA1 = AtA.pinv()

z = np.random.rand(WtW.shape[0])  # unseeded: printed values vary between runs
for X, Y in zip(WtW.matrices, (WtW @ AtA1 @ AtA).matrices):
    err = np.linalg.norm(X.dot(z) - Y.dot(z))
    print(err)

2.0682157488120567e-10
6390.995693388852


In [267]:
# Evaluate hdmm's expected_error for the W and A currently held by the kernel.
# `error` is imported in the notebook's top import cell.
error.expected_error(W, A)

3477.499999999999

In [261]:
# Display WtW.
# NOTE(review): this cell ran as In [261], before the In [266] cell that
# assigns WtW = W.gram(); the displayed array reflects that earlier kernel
# state and may not match a fresh top-to-bottom run.
WtW

array([[64., 64., 64., ...,  1.,  1.,  1.],
       [64., 64., 64., ...,  1.,  1.,  1.],
       [64., 64., 64., ...,  1.,  1.,  1.],
       ...,
       [ 1.,  1.,  1., ...,  1.,  1.,  1.],
       [ 1.,  1.,  1., ...,  1.,  1.,  1.],
       [ 1.,  1.,  1., ...,  1.,  1.,  1.]])