In [1]:
from typing import Any, Callable, NamedTuple, Tuple, Union
Step = int
Schedule = Callable[[Step], float]

from IPython import display
from IPython.display import Image, clear_output
from PIL import Image
import glob, os, shutil
import os.path

import time

import scipy.io as io
import scipy.sparse.csgraph as csgraph
from scipy.sparse.csgraph import laplacian as csgraph_laplacian
import scipy as sp
from scipy.stats import gaussian_kde
from scipy.linalg import null_space

import jax
from jax import jit, vmap, random, grad, value_and_grad, hessian
from jax.experimental import optimizers
from jax.experimental.optimizers import optimizer
from jax import numpy as jnp

from functools import partial
import itertools

import math
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import collections as mc
import seaborn as sns

import datetime
from tqdm.notebook import tqdm

import networkx as nx

%load_ext autoreload
%autoreload 2

In [2]:
import utils
from utils import *
from optimizers import *



In [3]:
# Test graphs stored in SuiteSparse Matrix Collection format:
# https://www.cise.ufl.edu/research/sparse/matrices/
graphs = [
    'qh882',
    'dwt_1005',
    '3elt',
    'commanche_dual',
    'bcsstk31',
]
graphdir = './testcases/'
graphpostfix = 'dwt_1005'
assert graphpostfix in graphs

graph, G, A, L, D, n = load_graph(
    graphdir + graphpostfix,
    plot_adjacency=False,
    verbose=True,
)

Name: 
Type: Graph
Number of nodes: 1005
Number of edges: 4813
Average degree:   9.5781


In [4]:
# G, A and D from load_graph are not needed below (A is rebuilt by map_vars later).
del G, A, D

In [5]:
@jit
def f(X, A_x, A_y, b_x, b_y):
    """Layout objective for X = [x, y] (columns of X).

    Computes x^T A_x x + y^T A_y y + 2 b_x^T x + 2 b_y^T y and returns its
    real part.
    """
    x = X[:, 0]
    y = X[:, 1]
    value = x @ A_x @ x + y @ A_y @ y + 2 * b_x @ x + 2 * b_y @ y
    return value.real

@jit
def f_l(X, L, C, A_x, A_y, b_x, b_y):
    """Lagrangian-style objective tr(X^T (A_x X + 2 E)) + tr(L^T (X^T X - C)).

    E = [b_x, b_y] stacked as columns. A_y is accepted for signature
    compatibility with `f` but is not used (A_x is applied to both columns).

    tr(P^T Q) is evaluated as sum(P * Q). The previous jnp.trace(jnp.inner(P, Q))
    materialised the full N x N matrix P @ Q.T only to read its diagonal.

    Returns:
        Real part of the scalar objective.
    """
    E = 2 * jnp.stack([b_x, b_y], axis=1)
    objective = jnp.sum(X * (A_x @ X + E))
    penalty = jnp.sum(L * (X.T @ X - C))
    return (objective + penalty).real

@jit
def foc_pgd(X, L, C, A, b_x, b_y):
    """First-order-condition residual used to monitor PGD.

    Returns ||r_x|| + ||r_y|| with
        r_x = A x + L[0,0] x + L[1,0] y + b_x
        r_y = A y + L[1,1] y + L[1,0] x + b_y
    for X = [x, y]. C is unused (kept for signature compatibility).

    NOTE(review): both columns use L[1,0], so this equals the column norms of
    A @ X + E_0 + X @ L (the residual in foc_sqp) only when L is symmetric.

    The shifts are applied as scalar * vector instead of (A + s * eye(N)), which
    avoided allocating two dense N x N identity matrices per call.
    """
    x = X[:, 0]
    y = X[:, 1]
    r_x = A @ x + L[0, 0] * x + L[1, 0] * y + b_x
    r_y = A @ y + L[1, 1] * y + L[1, 0] * x + b_y
    return (jnp.linalg.norm(r_x) + jnp.linalg.norm(r_y)).real

@jit
def foc_sqp(X, L, C, A, E_0):
    """Frobenius norm of the real part of the stationarity residual A X + E_0 + X L.

    C is unused (kept for signature compatibility).
    """
    residual = (A @ X + E_0) + X @ L
    return jnp.linalg.norm(residual.real)

def g(X, v, c):
    """Weighted first-moment residual: [v^T X[:,0], v^T X[:,1]] - c, as a NumPy array."""
    moments = [v.T @ X[:, col] for col in (0, 1)]
    return np.array(moments) - c

def h(X, D, c1, c2, c3, c=jnp.array([0,0])):
    return np.array([(X[:,0]-c[0]).T@D@(X[:,0]-c[0]) - c1, 
                     (X[:,1]-c[1]).T@D@(X[:,1]-c[1]) - c2, 
                     2*((X[:,0]-c[0]).T@D@(X[:,1]-c[1]) - c3)])

@jit
def L_init(X_k, C, A, E_0):
    """Least-squares multiplier estimate inv(C) @ X_k^T @ (A @ X_k + E_0), cast to float32."""
    C_inv = jnp.linalg.inv(C)
    residual = A @ X_k + E_0
    estimate = C_inv @ X_k.T @ residual
    return estimate.astype(jnp.float32)

In [6]:
@jit
def project(X1, C, E_0, c=jnp.array([0,0])):
    """Map X1 onto the constraint set {X : X.T @ X = C}, then optionally rotate it.

    Step 1 (nearest feasible point). Let C1 = X1.T @ X1 and
    Y = X1 @ C1^{-1/2}, which has orthonormal columns. The Frobenius-nearest X with
    X.T @ X = C is
        X = Y @ U @ Vh @ C^{1/2},   where  U, s, Vh = svd(C1^{1/2} @ C^{1/2}),
    i.e. U @ Vh is the orthogonal polar factor of C1^{1/2} C^{1/2}. This assumes
    utils._sqrtm returns the symmetric square root (TODO confirm).

    Step 2. If every eigenvalue of X.T @ E_0 is <= 0, X is right-multiplied by
    the orthogonal matrix -U_E @ V_E.T built from svd(X.T @ E_0).

    Args:
        X1: iterate to project; (N, 2) in this notebook.
        C: target Gram matrix for X.T @ X; (2, 2) here.
        E_0: matrix used in step 2; stacked [b_x, b_y] columns at the call sites.
        c: unused; kept for call compatibility.

    Returns:
        Real part of the projected iterate.
    """
    C1 = X1.T @ X1
    C1sqrt = utils._sqrtm(C1)
    Csqrt = utils._sqrtm(C)
    # jnp.linalg.svd returns (U, s, Vh) with Vh already transposed, so the polar
    # factor is U @ Vh. The former `svd(Csqrt@C1sqrt)` + `U@V.T` produced a
    # different orthogonal matrix (U@U for C = I), i.e. an extra rotation.
    U, _, Vh = jnp.linalg.svd(C1sqrt @ Csqrt)
    X = X1 @ jnp.linalg.inv(C1sqrt) @ U @ Vh @ Csqrt

    # Optional rotation of the feasible point.
    # NOTE(review): V_E is Vh here too, so -U_E @ V_E.T is not the cost-minimising
    # rotation -U_E @ V_E. The negdef gate's intent is unclear, so this is left as-is.
    # This rotation keeps X.T @ X = C only if it commutes with C (true for C
    # proportional to the identity, as configured in this notebook).
    negdef = jnp.all(jnp.linalg.eigvals(X.T@E_0) <= 0)
    U_E, _, V_E = jnp.linalg.svd(X.T@E_0)
    X = jax.lax.cond(negdef,
                     lambda _ : X@(-U_E@V_E.T),
                     lambda _ : X,
                     operand=None
                    )
    return X.real

@jit
def _step_noautograd(stp, X_k, A_x, A_y, b_x, b_y, C=None):
    """Take one hand-written gradient step of size `stp` on `f`, then project.

    The step is X - stp * ([A_x x, A_y y] + [b_x, b_y]) (no autograd), and the
    result is mapped back with project(·, C, stp * [b_x, b_y]).

    Args:
        stp: step size.
        X_k: current iterate with columns x, y.
        A_x, A_y, b_x, b_y: terms of the objective `f`.
        C: constraint matrix for `project`. If None, the notebook-level global
           `C` is used (legacy behaviour). A global read inside `jit` is frozen at
           trace time, so pass C explicitly.
    """
    if C is None:
        C = globals()["C"]
    E_0 = stp*jnp.vstack([b_x,b_y]).T
    X_k_x = X_k[:,0] - stp*A_x@X_k[:,0]
    X_k_y = X_k[:,1] - stp*A_y@X_k[:,1]
    X_k_t = jnp.vstack([X_k_x,X_k_y]).T - E_0
    return project(X_k_t, C, E_0)

def pgd(X_k, A_x, A_y, b_x, b_y, C, convergence_criterion,
           maxiters=1000, alpha=1e-2, beta=0.9):
    """Perform iterations of PGD, without autograd, with backtracking line search.

    Each iteration shrinks the step by `beta` until the sufficient-decrease test
    f_new < f_old - alpha * stp holds, or the step drops below 1e-6.

    Returns:
        dict with 'x' (last iterate), 'lossh', 'sln_path', 'foc' and
        'ext_data': {'L': multiplier estimates}. All histories are aligned:
        lossh[i], foc[i] and ext_data['L'][i] belong to sln_path[i], and index 0
        is the starting point.
    """
    E_0 = jnp.stack([b_x, b_y], axis=1)
    C_inv = jnp.linalg.inv(C)  # loop-invariant

    # Least-squares multiplier estimate.
    L = C_inv@X_k.T@(A_x@X_k+E_0)
    loss = [f(X_k, A_x, A_y, b_x, b_y)]
    param_hist = [X_k]
    grad_hist = [foc_pgd(X_k, L, C, A_x, b_x, b_y)]
    Ls = [L]

    for k in tqdm(range(maxiters)):
        # backtracking line search; derphi is a unit placeholder slope
        f_x = f(X_k, A_x, A_y, b_x, b_y)
        derphi = 1
        stp = 1
        f_xp = np.inf  # guarantees at least one trial step
        X_k_t = X_k
        while f_xp >= f_x - alpha * stp * derphi:
            stp *= beta
            X_k_t = _step_noautograd(stp, X_k, A_x, A_y, b_x, b_y, C=C)
            f_xp = f(X_k_t, A_x, A_y, b_x, b_y)
            if stp < 1e-6:
                break

        X_k = X_k_t
        L = C_inv@X_k.T@(A_x@X_k+E_0)
        converged = len(loss) > 1 and np.abs(f_xp - loss[-1]) <= convergence_criterion

        param_hist.append(X_k)
        loss.append(f_xp)
        Ls.append(L)
        grad_hist.append(foc_pgd(X_k, L, C, A_x, b_x, b_y))
        if converged:
            break

    return {'x':X_k, 'lossh':loss, 'sln_path':param_hist, 'foc':grad_hist, 'ext_data':{'L':Ls}}

@jit
def step(i, opt_state, A_x, A_y, b_x, b_y):
    """One optimizer update using the autograd gradient of `f`.

    NOTE(review): reads the module-level `get_params` / `opt_update`, which are
    captured when this jitted function is first traced, not the optimizer passed
    to pgd_autograd.
    """
    params = get_params(opt_state)
    gradient = grad(f)(params, A_x, A_y, b_x, b_y)
    return opt_update(i, gradient, opt_state)

def pgd_autograd(opt_params, A_x, A_y, b_x, b_y, C, convergence_criterion, maxiters=1000):
    """Perform iterations of PGD, with autograd.

    Each iteration applies the jitted `step`, records the new iterate, and stops
    when consecutive losses differ by at most `convergence_criterion`.

    NOTE(review): `step` uses the module-level `opt_update`/`get_params`; only
    `opt_state` and `get_params` from `opt_params` are used here.

    Returns:
        dict with 'x', 'lossh', 'sln_path', 'foc' and 'L'. lossh[i], foc[i] and
        L[i] correspond to sln_path[i] (index 0 is the start point). On early
        stop, sln_path has one trailing extra entry.
    """
    opt_state, opt_update, get_params = opt_params
    E_0 = jnp.stack([b_x, b_y], axis=1)
    C_inv = jnp.linalg.inv(C)  # loop-invariant

    def _multiplier(X):
        # Least-squares multiplier estimate, used for every recorded iterate.
        return C_inv@X.T@(A_x@X+E_0)

    X_k = get_params(opt_state)
    # The start point gets the same multiplier estimate as every later iterate
    # (a dummy identity made foc[0] and L[0] meaningless).
    L = _multiplier(X_k)
    loss = [f(X_k, A_x, A_y, b_x, b_y)]
    Lh = [L]
    param_hist = [X_k]
    grad_hist = [foc_pgd(X_k, L, C, A_x, b_x, b_y)]

    for k in tqdm(range(maxiters)):
        opt_state = step(k, opt_state, A_x, A_y, b_x, b_y)
        X_k = get_params(opt_state)
        param_hist.append(X_k)
        l = f(X_k, A_x, A_y, b_x, b_y)

        assert not np.isnan(l)

        if len(loss) > 1 and np.abs(l - loss[-1]) <= convergence_criterion:
            break
        loss.append(l)

        L = _multiplier(X_k)
        Lh.append(L)
        grad_hist.append(foc_pgd(X_k, L, C, A_x, b_x, b_y))

    return {'x':X_k, 'lossh':loss, 'sln_path':param_hist, 'foc':grad_hist, 'L':Lh}

@jit
def _step(i, opt_state, Z):
    """Apply one optimizer update with a caller-supplied direction Z in place of a gradient.

    `i` is forwarded as opt_update's first argument. newton passes a step size
    here, so presumably the optimizer uses it as the step length (confirm).
    NOTE(review): uses the module-level `opt_update`.
    """
    new_state = opt_update(i, Z, opt_state)
    return new_state

def D_Z(X, A, d, e):
    """Solve the shifted KKT system for one column.

    Finds Del and Z with (A + d I) Z = e - X Del and X^T Z = 0:
        M   = inv(A + d I)
        Del = inv(X^T M X) X^T M e
        Z   = M (e - X Del)

    Returns:
        (Del, Z)
    """
    shifted_inv = jnp.linalg.inv(A + d * jnp.eye(A.shape[0]))
    weighted = X.T @ shifted_inv
    Del = jnp.linalg.inv(weighted @ X) @ weighted @ e
    Z = shifted_inv @ (-X @ Del + e)
    return Del, Z

def _D_Z(X, A, P, d, e):
    I = jnp.eye(A.shape[0])
    Adinv = jnp.linalg.inv(A + d*I)
    XtADinv = X.T@Adinv
    Del = jnp.linalg.inv(XtADinv@X)@XtADinv@e
    
    Z = P@Adinv@P@(-X@Del + e)
    
    return Del, Z

@jit
def _sqp(A, A_L, P, L, E_0, X):
    """Newton direction Z and multiplier update Del, using the P-restricted `_D_Z` solve.

    The eigenvalues of L are ordered by argsort and reversed (descending).
    Column j of the residual E = -E_0 - (A @ X + X @ L) is solved with the
    j-th of them as shift, using A_L and P.

    Returns:
        (Z, Del), each with the two per-column results stacked as columns.
    """
    eigs = jnp.linalg.eigvals(L)
    eigs = eigs[eigs.argsort()[::-1]]
    residual = -E_0 - (A @ X + X @ L)

    first_del, first_dir = _D_Z(X, A_L, P, eigs[0], residual[:, 0])
    second_del, second_dir = _D_Z(X, A_L, P, eigs[1], residual[:, 1])

    Z = jnp.stack([first_dir, second_dir], axis=1)
    Del = jnp.stack([first_del, second_del], axis=1)
    return Z, Del

@jit
def sqp(A, L, E_0, X):
    """Newton direction Z and multiplier update Del, entirely in the reduced space.

    The eigenvalues of L are ordered by argsort and reversed (descending).
    Column j of the residual E = -E_0 - (A @ X + X @ L) is passed to D_Z with
    the j-th of them as shift.

    Returns:
        (Z, Del), each with the two per-column results stacked as columns.
    """
    shifts = jnp.linalg.eigvals(L)
    shifts = shifts[shifts.argsort()[::-1]]
    residual = -E_0 - (A @ X + X @ L)
    dels, dirs = zip(*(D_Z(X, A, shifts[j], residual[:, j]) for j in range(2)))
    return jnp.stack(dirs, axis=1), jnp.stack(dels, axis=1)

def newton(opt_params, A, A_L, L, C, X_k, b_x, b_y, convergence_criterion,
           maxiters=100, alpha=1e-2, beta=0.9, P=None):
    """Perform iterations of PND + backtracking line search.

    Each iteration computes a direction Z and a multiplier update Del with
    `_sqp`. It then backtracks along -Z (multiplying the step by `beta`) until
    `f` decreases or stp * ||Z|| < 1e-8. The step is applied through `_step`,
    and the multiplier is updated as L + stp * Del.

    Args:
        opt_params: (opt_state, opt_update, get_params). Only opt_state and
            get_params are used; `_step` uses the module-level opt_update.
        A: reduced matrix of the quadratic form (acts on the iterate).
        A_L: matrix handed to `_sqp` together with P (the unreduced matrix when
            called from `cluster`).
        L: ignored; the multiplier is re-initialised with L_init.
        C: constraint matrix (X.T @ X = C).
        X_k: start point used only for L_init; the iterate is read from opt_state.
        b_x, b_y: linear terms of `f`.
        convergence_criterion: stop once successive losses differ by at most this.
        maxiters: maximum number of iterations.
        alpha: unused (the Armijo term of the line search is disabled).
        beta: backtracking shrink factor.
        P: basis passed to `_sqp`. If omitted, the notebook-level global `P` is
            used (legacy behaviour).

    Returns:
        dict with 'x', 'lossh', 'sln_path', 'ext_data', 'foc', 'step_sizes'.
        lossh[i] and foc[i] belong to sln_path[i], and index 0 is the start point,
        so sln_path[argmin(lossh)] is always valid.

    Raises:
        ValueError: if P is neither passed nor defined globally.
    """
    opt_state, opt_update, get_params = opt_params
    if P is None:
        # Legacy fallback: this function used to read the global P implicitly.
        P = globals().get("P")
        if P is None:
            raise ValueError("newton: projection basis P was not supplied")

    E_0 = np.stack([b_x, b_y], axis=1)
    # The incoming L is discarded in favour of the least-squares estimate.
    L = L_init(X_k, C, A, E_0)

    X_k = get_params(opt_state)

    # Histories are kept aligned: entry i of loss / grad_hist is for param_hist[i].
    loss = [f(X_k, A, A, b_x, b_y)]
    param_hist = [X_k]
    grad_hist = [foc_sqp(X_k, L, C, A, E_0)]
    step_sizes = []
    data = {'L':[], 'gradcorr':[], 'stp':[], 'dec':[], 'pre_proj':[]}

    for k in tqdm(range(maxiters)):
        Z, Del = _sqp(A, A_L, P, L, E_0, X_k)

        # backtracking line search; derphi is a unit placeholder slope
        f_x = f(X_k, A, A, b_x, b_y)
        derphi = 1
        len_p = jnp.linalg.norm(Z)
        stp = 1
        f_xp = np.inf  # guarantees at least one trial step
        opt_state_t = opt_state

        while f_xp >= f_x:
            stp *= beta
            opt_state_t = _step(stp, opt_state, -Z)
            f_xp = f(get_params(opt_state_t), A, A, b_x, b_y)
            if stp * len_p < 1e-8:
                break

        step_sizes.append(stp)
        L = L + stp*Del

        opt_state = opt_state_t
        X_k = get_params(opt_state)
        param_hist.append(X_k)
        grad_hist.append(foc_sqp(X_k, L, C, A, E_0))

        # f_xp is the objective at the new X_k; compare with the previous iterate.
        converged = len(loss) > 1 and np.abs(f_xp - loss[-1]) <= convergence_criterion
        loss.append(f_xp)
        data['gradcorr'].append(derphi)
        data['L'].append(L)
        data['stp'].append(stp)
        data['dec'].append(Z)
        if converged:
            break

    return {'x':X_k, 'lossh':loss, 'sln_path':param_hist, 'ext_data':data,
            'foc':grad_hist, 'step_sizes':step_sizes}
    
def ssm():
    r"""Placeholder for a subspace method; not implemented (returns None).

    Planned steps:
    1. compute newton direction z = sqp(X, Z, v, Ax + E0) & subspace S
    2. approximate locally optimal X, L on S; X = min F(\hat{X}, B, V.T@E0)
    """
    # Raw docstring: "\h" is an invalid escape sequence in a normal string
    # (DeprecationWarning / SyntaxWarning on Python 3.12+).
    pass

In [7]:
def map_vars(A, X_k, fixed_indices, centercons):
    """Eliminate the linear layout constraints and reduce to a free subspace.

    Linear equality constraints on each coordinate column:
      * vertex fixed_indices[i] keeps its coordinate from X_k;
      * the coordinates of all free (non-fixed) vertices sum to centercons.
    With `constraints` the matrix of these rows, a full-space column is written
    as x = n0 + P.T @ z. Here n0 is the minimum-norm solution (pseudo-inverse),
    and the rows of P are an orthonormal basis of the null space of
    `constraints` (scipy.linalg.null_space).

    Args:
        A: (N, N) quadratic-form matrix (the `L` from load_graph in this notebook).
        X_k: (N, 2) layout providing the fixed vertices' coordinates.
        fixed_indices: 1-D integer indices of the fixed vertices (may be empty).
        centercons: length-2 target for the sums of the free x / y coordinates.

    Returns:
        (A_red, P, b_x, b_y, n0_x, n0_y, fixed_idx) with A_red = P @ A @ P.T,
        b_* = P @ (A @ n0_*), and fixed_idx the (k, N) one-hot selector rows.
    """
    N = A.shape[0]
    fixed_indices = np.asarray(fixed_indices, dtype=int)
    k = fixed_indices.shape[0]

    # One-hot selector: fixed_idx[i, fixed_indices[i]] = 1.
    fixed_idx = np.zeros((k, N))
    fixed_idx[np.arange(k), fixed_indices] += 1

    # One row per fixed vertex plus a row summing the free vertices. For k == 0
    # the last row is all ones, so this single path also covers the no-fixed case
    # (the old separate branch read the global `n` instead of N).
    constraints = np.concatenate([fixed_idx, np.expand_dims(1 - fixed_idx.sum(0), 0)])
    rhs_x = np.concatenate([np.asarray(X_k[fixed_indices, 0]), np.expand_dims(centercons[0], 0)])
    rhs_y = np.concatenate([np.asarray(X_k[fixed_indices, 1]), np.expand_dims(centercons[1], 0)])

    P = null_space(constraints).T
    pinvcons = np.linalg.pinv(constraints)

    n0_x = pinvcons @ rhs_x
    n0_y = pinvcons @ rhs_y
    b_x = P @ (A @ n0_x)
    b_y = P @ (A @ n0_y)
    A = P @ A @ P.T

    return A, P, b_x, b_y, n0_x, n0_y, fixed_idx

def cluster(rng, opt_params, X_k, A, mapped_vars, fixed_indices, maxiters=1000, convergence_criterion=1e-3,
            c1=1, c2=1, c3=0, centroid=jnp.array([0,0]), centercons=None, v=None, D=None, eps=1e-8, method='pgd'):
    """Given an adjacency matrix A and initialization X_k, optimize X.

    The problem is solved in the reduced coordinates given by `mapped_vars`
    (output of map_vars). The iterate with the lowest recorded loss is
    re-projected and lifted back to the full (N, 2) layout as P.T @ X + n0.

    Args:
        rng, D, eps: unused.
        opt_params: (opt_init, opt_update, get_params) of the optimizer.
        X_k: (N, 2) initial layout in full coordinates.
        A: (N, N) matrix. Its shape is checked, and it is forwarded to `newton`
            as A_L; the reduced matrix comes from mapped_vars.
        mapped_vars: (A, P, b_x, b_y, n0_x, n0_y, fixed_idx) from map_vars.
        fixed_indices: unused here; the fixing is encoded in mapped_vars.
        maxiters, convergence_criterion: forwarded to the solver.
        c1, c2, c3: C = [[c1, c3], [c3, c2]], the target of X.T @ X.
        centroid: centre used by the `h` diagnostic.
        centercons: passed to `project` and used as target of the `g`
            diagnostic (default zeros(2)).
        v: vertex weights for the g/h diagnostics (default ones(N)).
        method: 'pgd' (autograd PGD) or 'pnd' (projected Newton).

    Returns:
        The solver's result dict extended with 'x' (full layout), 'mask' (free
        vertices), 'centroid' (the value of centercons), 'g', 'h', 'P', 'e', 'n'.

    Raises:
        NotImplementedError: for method 'ssm', which passes the argument check
            but has no solver (previously this printed and returned 1).
    """
    method = method.lower()
    opt_init, opt_update, get_params = opt_params

    assert method in ['pgd','pnd','ssm']
    assert len(A.shape) == 2
    assert A.shape[0] == X_k.shape[0]
    if method == 'ssm':
        raise NotImplementedError(f"method {method!r} not supported")

    N = A.shape[0]

    if v is None:
        v = jnp.ones(N)
    if centercons is None:
        centercons = jnp.zeros(2)
    # Keep the unreduced matrix for the Newton solver; A becomes the reduced one.
    A_L = A
    A, P, b_x, b_y, n0_x, n0_y, fixed_idx = mapped_vars

    C = jnp.block([[c1, c3],[c3, c2]])

    assert jnp.linalg.det(C) > 1e-5
    E_0 = jnp.stack([b_x, b_y], axis=1)

    # Map the initial layout into reduced coordinates and make it feasible.
    n0 = jnp.stack([n0_x,n0_y],axis=0)
    X_k_n = jnp.array(np.linalg.pinv(P.T)@(X_k-n0.T))
    X_k_n = project(X_k_n, C, E_0, centercons)
    L = np.eye(2)

    opt_state = opt_init(X_k_n)
    if method == "pgd":
        A_x = A
        A_y = A
        result = pgd_autograd((opt_state, opt_update, get_params), A_x, A_y, b_x, b_y, C,
                              convergence_criterion=convergence_criterion, maxiters=maxiters)
    else:  # "pnd"
        result = newton((opt_state, opt_update, get_params), A, A_L, L, C, X_k_n, b_x, b_y,
                        convergence_criterion=convergence_criterion, maxiters=maxiters, alpha=0.0, beta=0.9)

    # Best recorded iterate, re-projected and lifted back to full coordinates.
    X_k = result['sln_path'][np.argmin(result['lossh'])]
    X_k = project(X_k, C, E_0, centercons)
    X_k_n = np.zeros((N,2))
    X_k_n[:,0] = np.array(P.T@X_k[:,0]) + n0_x.T
    X_k_n[:,1] = np.array(P.T@X_k[:,1]) + n0_y.T

    result['x'] = X_k_n

    # True for free vertices. Use builtin bool: the np.bool alias was removed in NumPy 1.24.
    mask = (1-fixed_idx.sum(0)).astype(bool)
    result['mask'] = mask
    result['centroid'] = centercons
    if fixed_idx.sum() == 0:
        result['g'] = np.array(g(X_k_n, v, centercons))
        result['h'] = np.array(h(X_k_n, np.diag(v), c1, c2, c3, centroid))
    else:
        result['g'] = np.array(g(X_k_n[mask], v[mask], centercons))
        result['h'] = np.array(h(X_k_n[mask], np.diag(v[mask]), c1, c2, c3, centroid))
    result['P'] = (P)
    result['e'] = np.vstack([b_x,b_y])
    result['n'] = (n0_x, n0_y)

    return result

In [8]:
# Experiment configuration: solver choice, RNG seed and tolerances.
method = "pnd"
seed = 0
eps = 1e-8  # NOTE(review): not referenced by later cells (cluster() gets a literal eps)
alpha = 5e-3  # NOTE(review): not referenced by later cells
rng = random.PRNGKey(seed)
key, subkey = jax.random.split(rng)  # only subkey is used below

# Targets for the constraint X.T @ X = C with C = [[c1, c3], [c3, c2]].
# n * 100 / 12 is presumably the variance of a uniform spread over a width-10
# interval, times the number of vertices -- confirm.
# NOTE(review): `v` is a vector of ones here but is rebound to eigenvectors below.
v = np.ones(n)
c1=v.sum()*10**2*1/12
c2=v.sum()*10**2*1/12
c3=0
C = jnp.block([[c1, c3],[c3, c2]])

# Standard-normal draws scaled by sqrt(10); only the fixed vertices' rows are used below.
X_k_r = (random.normal(subkey, (n,2))*np.sqrt(10))

# Eigenpairs of L: loaded from cache files next to the graph when both exist,
# otherwise computed (5 smallest-magnitude eigenvalues, which='SM') and cached.
if os.path.isfile(graphdir+graphpostfix+'_evals.npy') and \
   os.path.isfile(graphdir+graphpostfix+'_evecs.npy'):
    w = np.load(graphdir+graphpostfix+'_evals.npy')
    v = np.load(graphdir+graphpostfix+'_evecs.npy')
else:
    w,v = sp.sparse.linalg.eigsh(L, k=5, which='SM')
    np.save(graphdir+graphpostfix+'_evals.npy',w)
    np.save(graphdir+graphpostfix+'_evecs.npy',v)

# Spectral initialisation from eigenvector columns 1 and 2 (column 0 skipped);
# assumes the eigenpairs are ordered by eigenvalue -- confirm for eigsh output.
X_k = v[:,1:3]

# The first 10 vertices get random positions; they are the fixed vertices passed
# to map_vars. NOTE(review): X_k is a view of v, so this also writes into v.
fixed_indices = np.arange(10)
X_k[fixed_indices] = X_k_r[fixed_indices]

#X_k = X_k.astype(jnp.float16)
#L = L.astype(jnp.int16)

In [9]:
# Free the eigendecomposition and random draw from the previous cell.
# NOTE(review): these `del`s make the cell non-idempotent; re-running it raises NameError.
del w
del v
del X_k_r
v = jnp.ones(n)
# Reduce to the free subspace: fixed vertices pinned, free-vertex coordinate sums set to 0.
A, P, b_x, b_y, n0_x, n0_y, fixed_idx = map_vars(L, X_k, fixed_indices, v.sum()*jnp.array([0,0]))

mapped_vars = (A, P, b_x, b_y, n0_x, n0_y, fixed_idx)

# Projected optimizer; the partial below evaluates project(X, C, E_0) with E_0 = [b_x, b_y].
if method == "pgd":
    pgd_lr = 5e-2
    opt_init, opt_update, get_params = padam(pgd_lr,partial(lambda x, y, z: project(z, y, x),
                                                    np.stack([b_x,b_y],axis=1), C), b1=0.9, b2=0.999, eps=1e-08)
elif method == "pnd":
    opt_init, opt_update, get_params = psgd(partial(lambda x, y, z: project(z, y, x),
                                                    np.stack([b_x,b_y],axis=1), C))
else:
    # Otherwise a stale optimizer from an earlier run would be reused silently.
    raise ValueError(f"unsupported method {method!r}; expected 'pgd' or 'pnd'")

In [10]:
# Optimise the layout with the solver selected in the configuration cell.
result = cluster(
    rng,
    (opt_init, opt_update, get_params),
    X_k,
    L,
    mapped_vars,
    fixed_indices=fixed_indices,
    c1=c1,
    c2=c2,
    c3=c3,
    centercons=v.sum()*jnp.array([0,0]),
    v=None,
    D=None,
    eps=1e-8,
    maxiters=1000,
    convergence_criterion=1e-3,
    method=method,
)
results = [result]
X_k_n = result['x']

  0%|          | 0/1000 [00:00<?, ?it/s]

TypeError: dot_general requires contracting dimensions to have the same shape, got [994] and [1005].

In [None]:
# Pass the optimised full-coordinate layout X_k_n and the graph to utils.plot_graph.
utils.plot_graph(X_k_n, graph)

In [None]:
# Keep the return value under its own name: assigning it to `f` shadowed the
# jitted objective `f` used by pgd / newton / cluster, breaking any re-run.
results_plot = utils.plot_results(result)

In [None]:
# NOTE(review): voxel_cluster comes from the star imports (utils / optimizers);
# presumably it bins the layout using the grid spec np.array([5, 5]) -- confirm.
# voxel_id is passed as `c=` to plot_graph in the next cell.
voxel_id, voxel_bound = voxel_cluster(X_k_n, np.array([5, 5]))

In [None]:
# Title summarises the best loss, the constraint residuals h and g, and the
# first-order residual at the best iterate. The h slot previously showed result['g'].
utils.plot_graph(
    X_k_n,
    graph,
    title='loss: {} h: {} g: {} foc: {}'.format(
        str(np.round(np.min(result['lossh']), 2)),
        np.round(result['h'], 2),
        np.round(result['g'], 2),
        str(np.round(result['foc'][np.argmin(result['lossh'])], 2)),
    ),
    fixed_indices=fixed_indices,
    c=voxel_id,
)