In [None]:
from dolfin import *
from dolfin_adjoint import *
import moola 

import jax.numpy as jnp
import jax
import jax.random as jrandom
import numpy as np
from jax.config import config
config.update("jax_enable_x64", True)

import time

In [34]:
# Discretisation: an n x n unit-square mesh, continuous P1 for the state
# and piecewise-constant DG0 for the control.
n = 4
mesh = UnitSquareMesh(n, n)
V = FunctionSpace(mesh, "CG", 1)   # state space
W = FunctionSpace(mesh, "DG", 0)   # control space

f = Function(W)                    # control; its values are overwritten by f1 / f2
u = Function(V, name='State')      # state; overwritten by every forward solve
v = TestFunction(V)

x = SpatialCoordinate(mesh)

# Desired state d(x, y) = sin(pi x) sin(pi y) / (2 pi^2).
w = Expression("sin(pi*x[0])*sin(pi*x[1])", degree=3)
d = Expression("d*w", d=1 / (2 * pi ** 2), w=w, degree=3)

def f1(f_np):
    """Gradient of the reduced tracking objective at control values ``f_np``.

    The module-level control ``f`` is overwritten with ``f_np``. The state
    ``u`` is then obtained from the weak problem
    ``(grad u, grad v) - (f, v) = 0`` with ``u = 0`` on the boundary, and the
    objective

        J(f) = 1/2 ||u - d||^2 + alpha/2 ||f||^2,   alpha = 1e-6,

    is assembled. Because ``dolfin_adjoint`` is imported, the solve and the
    assembly are recorded on the working tape. ``ReducedFunctional`` then
    differentiates J with respect to ``f``.

    Parameters
    ----------
    f_np : numpy.ndarray
        Local DOF values to load into the control function ``f``.

    Returns
    -------
    numpy.ndarray
        Local DOF values of ``ReducedFunctional.derivative()``.

    Side effects: mutates the globals ``f`` and ``u`` and appends to the
    dolfin-adjoint tape on every call.
    """
    f.vector().set_local(f_np)

    residual = (inner(grad(u), grad(v)) - f * v) * dx
    boundary = DirichletBC(V, 0.0, "on_boundary")
    solve(residual == 0, u, boundary)

    reg_weight = Constant(1e-6)
    objective = assemble(
        0.5 * inner(u - d, u - d) * dx + reg_weight / 2 * f ** 2 * dx
    )
    reduced = ReducedFunctional(objective, Control(f))

    return reduced.derivative().vector().get_local()




In [None]:
def f2(f_np, direction):
    """Hessian action of the reduced tracking objective at ``f_np``.

    The module-level control ``f`` is overwritten with ``f_np``. The state
    ``u`` is solved from ``(grad u, grad v) - (f, v) = 0`` with ``u = 0`` on
    the boundary, and ``J = 1/2 ||u - d||^2 + 1e-6/2 ||f||^2`` is assembled
    on the dolfin-adjoint tape. The Hessian of the reduced functional is
    then applied to ``direction``.

    Parameters
    ----------
    f_np : numpy.ndarray
        Local DOF values to load into the control function ``f``.
    direction : array-like
        Local DOF values of the direction (a function on ``W``) the Hessian
        is applied to.

    Returns
    -------
    numpy.ndarray
        Local DOF values of ``ReducedFunctional.hessian(direction)``.
    """
    f.vector().set_local(f_np)

    residual = (inner(grad(u), grad(v)) - f * v) * dx
    boundary = DirichletBC(V, 0.0, "on_boundary")
    solve(residual == 0, u, boundary)

    reg_weight = Constant(1e-6)
    objective = assemble(
        0.5 * inner(u - d, u - d) * dx + reg_weight / 2 * f ** 2 * dx
    )
    reduced = ReducedFunctional(objective, Control(f))

    perturbation = Function(W)
    perturbation.vector().set_local(direction)
    # NOTE(review): pyadjoint's hessian appears to rely on a preceding
    # derivative() call (adjoint pass) -- keep this ordering.
    reduced.derivative()
    return reduced.hessian(perturbation).vector().get_local()

In [None]:
def multilevel_beta_newton_update_2B1_1B1(f1, x_0, alpha, N, R, d_prime, jrandom_key):
    """Estimate an update direction at ``x_0`` from gradients at sampled points.

    ``N`` points are drawn around ``x_0`` with ``beta_sampling``, restricted
    to a random subset of ``d_prime`` coordinates. Gradients at those points
    are correlated with the centred samples on the chosen coordinates, which
    gives a (d', d') matrix ``M``. The result is

        -cov @ U @ M^{-1} @ U^T @ grad(x_0)

    where ``cov`` is the sample covariance and ``U`` holds the unit vectors
    of the chosen coordinates as columns.

    Parameters
    ----------
    f1 : callable
        Maps a numpy array of shape (d,) to the gradient at that point.
    x_0 : array of shape (d,)
        Current point.
    alpha, N, R
        Beta shape parameter, number of samples and sampling radius passed
        to ``beta_sampling``.
    d_prime : int
        Number of coordinates in the random subspace (must be <= d, since
        indices are drawn without replacement).
    jrandom_key : jax PRNG key.

    Returns
    -------
    jax array of shape (d,).
    """
    d = len(x_0)

    # Random coordinate subspace; U's columns are the selected unit vectors.
    jrandom_key, subkey = jrandom.split(jrandom_key)
    U_idxs = jrandom.choice(subkey, a=d, shape=(d_prime,), replace=False)
    U = jnp.eye(d)[U_idxs].T  # (d, d')

    jrandom_key, subkey = jrandom.split(jrandom_key)
    sample_points = beta_sampling(x_0, x_0.shape[0], N, alpha, R, subkey, chosen_basis_idx=U_idxs)
    ru = (sample_points - jnp.mean(sample_points, axis=0)).T  # (d, N)

    # One gradient evaluation per sample point (the expensive part).
    out_grads = jnp.array([f1(np.array(point)) for point in sample_points])  # (N, d)

    gradF = f1(x_0)
    grad_X_low = out_grads.T[U_idxs].dot(ru[U_idxs].T) / float(N)  # (d', d')
    cov = jnp.cov(sample_points.T)
    # Solve instead of forming the explicit inverse (better conditioned, cheaper).
    return -cov.dot(U.dot(jnp.linalg.solve(grad_X_low, U.T.dot(gradF))))

In [None]:
# Initial control f(x, y) = x + y on the DG0 space.
W = FunctionSpace(mesh, "DG", 0)
f = interpolate(Expression("x[0]+x[1]", name='Control', degree=1), W)

# Sampling settings for the subspace Newton estimate.
# Use get_local() (as elsewhere in this notebook) rather than np.array() on the
# dolfin vector object.
x_0 = f.vector().get_local()
alpha = 10          # Beta(alpha, alpha) shape parameter of the sampling radius
N = 100             # number of sampled gradients
R = 0.1             # sampling radius
d_prime = len(x_0)  # subspace size: all coordinates
jrandom_key = jrandom.PRNGKey(0)

start_time = time.time()
res = multilevel_beta_newton_update_2B1_1B1(f1, x_0, alpha, N, R, d_prime, jrandom_key)
print(time.time() - start_time)

In [35]:
def beta_2E1(x_0, alpha, N, R, jrandom_key, control_variate=True, grad_fn=None):
    """Estimate the Hessian at ``x_0`` from gradients at Beta-sampled points.

    Computes ``mean_i g(x_i) (x_i - c)^T @ cov^{-1}``. Here ``x_i`` are
    ``N`` points drawn by ``beta_sampling`` around ``x_0``, ``g`` is the
    gradient, ``cov`` is the sample covariance of the points, and ``c`` is
    the sample mean (``control_variate=True``) or ``x_0``.

    Parameters
    ----------
    x_0 : array of shape (d,)
        Point at which the Hessian is estimated.
    alpha, N, R
        Beta shape parameter, number of samples and sampling radius.
    jrandom_key : jax PRNG key.
    control_variate : bool
        Centre the samples on their mean instead of on ``x_0``.
    grad_fn : callable, optional
        Maps a numpy array of shape (d,) to the gradient there. Defaults to
        the module-level ``f1``.

    Returns
    -------
    jax array of shape (d, d).
    """
    if grad_fn is None:
        grad_fn = f1

    jrandom_key, subkey = jrandom.split(jrandom_key)
    sample_points = beta_sampling(x_0, x_0.shape[0], N, alpha, R, subkey)
    center = jnp.mean(sample_points, axis=0) if control_variate else x_0
    ru = sample_points - center  # (N, d)

    out_grads = jnp.array([grad_fn(np.array(point)) for point in sample_points])  # (N, d)

    g_ru = out_grads.T.dot(ru) / float(N)  # (d, d)
    cov = jnp.cov(sample_points.T)         # (d, d), symmetric
    # g_ru @ inv(cov) == solve(cov, g_ru.T).T since cov is symmetric; avoids an explicit inverse.
    return jnp.linalg.solve(cov, g_ru.T).T

In [36]:

W = FunctionSpace(mesh, "DG", 0)
f = interpolate(Expression("x[0]+x[1]", name='Control', degree=1), W)

# Sampling settings for the Hessian estimate.
x_0 = f.vector().get_local()
alpha, N, R = 10, 100, 0.1
d_prime = len(x_0)  # not used by beta_2E1; mirrors the subspace-Newton cell
jrandom_key = jrandom.PRNGKey(0)

start_time = time.time()
H_est = beta_2E1(x_0, alpha, N, R, jrandom_key, control_variate=True)
print("time est hess", time.time() - start_time)


time est hess 2.986694574356079


In [None]:
# H^{-1} g at x_0 (the negated Newton step) from the assembled Hessian `H`.
# NOTE: `H` is built by the Hessian-assembly cell (loop over f2) below; run it first.
# `solve` replaces inv(H).dot(g): same result, without forming the inverse.
jnp.linalg.solve(jnp.array(H), f1(x_0))

In [None]:
# Build the exact Hessian column by column: entry i is the Hessian action
# (via f2) on the i-th unit vector, evaluated at x_0.
U = jnp.eye(len(x_0))
start_time = time.time()
H = [f2(x_0, unit_vector) for unit_vector in U]
print("get true", time.time() - start_time)

In [1]:
from dolfin import *
from dolfin_adjoint import *
import moola 

from moola.adaptors.dolfin_vector import DolfinPrimalVector


import jax.numpy as jnp
import jax
import jax.random as jrandom
import numpy as np
from jax.config import config
config.update("jax_enable_x64", True)

import time

class BetaApproximations():

    '''
    Sampling-based update directions for moola / dolfin-adjoint objectives.

    Gradients of the objective are evaluated at points drawn around the
    current iterate with ``beta_sampling``. Those gradients are combined with
    the sample covariance into an update direction, without assembling the
    Hessian.

    Parameters
    ----------
    jrandom_key : jax PRNG key used for all sampling; advanced on each call.
    alpha : Beta(alpha, alpha) shape parameter of the sampling radius.
    N : number of sample points (gradient evaluations) per update.
    R : sampling radius.
    d_prime : size of the random coordinate subspace; ``None`` means the
        full space.
    '''
    def __init__(self, jrandom_key, alpha=10, N=100, R=0.1, d_prime=None):
        self.alpha = alpha
        self.N = N
        self.R = R
        self.d_prime = d_prime

        self.jrandom_key = jrandom_key

    def multi_two_B_one(self, obj, x_0):
        """Return a sampled Newton-type update direction at ``x_0``.

        Parameters
        ----------
        obj : moola objective. ``obj(p)`` evaluates it and
            ``obj.derivative(p)`` returns the gradient as a moola vector
            whose ``.data`` is a dolfin Function.
        x_0 : moola.DolfinPrimalVector holding the current control.

        Returns
        -------
        DolfinPrimalVector wrapping a new Function on ``x_0``'s function
        space with value ``-cov @ U @ M^{-1} @ U^T @ grad(x_0)`` (see body).
        """
        alpha = self.alpha
        N = self.N
        R = self.R

        x_0_np = x_0.data.vector().get_local()
        f = x_0.copy()  # scratch vector reused for every sample point

        d = len(x_0_np)
        # Honour a configured subspace size; default to the full space.
        d_prime = d if self.d_prime is None else self.d_prime

        jrandom_key, subkey = jrandom.split(self.jrandom_key)
        U_idxs = jrandom.choice(subkey, a=d, shape=(d_prime,), replace=False)
        U = jnp.eye(d)[U_idxs].T  # (d, d')

        jrandom_key, subkey = jrandom.split(jrandom_key)
        sample_points = beta_sampling(x_0_np, x_0_np.shape[0], N, alpha, R, subkey, chosen_basis_idx=U_idxs)
        ru = (sample_points - jnp.mean(sample_points, axis=0)).T  # (d, N)

        # Extra split kept so the stored key sequence matches earlier runs.
        jrandom_key, _ = jrandom.split(jrandom_key)

        curr_f = moola.DolfinPrimalVector(Function(x_0.data.function_space()))
        out_grads = []
        for point in sample_points:
            f.data.vector().set_local(np.array(point))
            # Assign through a separate vector: assigning directly to f caused a
            # hashing issue in moola's objective caching (per original author).
            curr_f.assign(f)
            obj(curr_f)
            out_grads.append(obj.derivative(curr_f).data.vector().get_local())
        out_grads = jnp.array(out_grads)  # (N, d)

        gradF = jnp.array(obj.derivative(x_0).data.vector().get_local())

        grad_X_low = out_grads.T[U_idxs].dot(ru[U_idxs].T) / float(N)  # (d', d')
        cov = jnp.cov(sample_points.T)

        self.jrandom_key = jrandom_key

        # Solve instead of forming the explicit inverse.
        update_dir_np = -cov.dot(U.dot(jnp.linalg.solve(grad_X_low, U.T.dot(gradF))))
        update_dir = Function(x_0.data.function_space())
        update_dir.vector().set_local(update_dir_np)
        return DolfinPrimalVector(update_dir)

def beta_sampling(x_0, dim, N, alpha, radius, new_jrandom_key, chosen_basis_idx=None):
    """Draw ``N`` points around ``x_0`` along random directions with Beta radii.

    Each point is ``x_0 + r * u``. The unit vector ``u`` is a normalised
    Gaussian direction, and ``r = (b - 0.5) * 2 * radius`` with
    ``b ~ Beta(alpha, alpha)``, so ``r`` lies in ``[-radius, radius]``.

    Parameters
    ----------
    x_0 : array of shape (dim,)
        Centre of the samples.
    dim : int
        Ambient dimension.
    N : int
        Number of samples.
    alpha : float
        Shape parameter of the symmetric Beta distribution.
    radius : float
        Maximum distance of a sample from ``x_0``.
    new_jrandom_key : jax PRNG key.
    chosen_basis_idx : array of int, optional
        If given, directions live only in these coordinates. All other
        coordinates of every sample equal those of ``x_0``.

    Returns
    -------
    jax array of shape (N, dim).
    """
    new_jrandom_key, subkey = jrandom.split(new_jrandom_key)
    if chosen_basis_idx is None:
        dirs = jrandom.normal(subkey, shape=(N, dim))
    else:
        sub_dirs = jrandom.normal(subkey, shape=(N, len(chosen_basis_idx)))
        # Embed subspace directions into the full space (zeros elsewhere).
        # `.at[].set` replaces `jax.ops.index_update`, which was removed in JAX 0.3.
        dirs = jnp.zeros((dim, N)).at[chosen_basis_idx].set(sub_dirs.T).T

    dirs = dirs / jnp.linalg.norm(dirs, axis=1, keepdims=True)

    _, subkey = jrandom.split(new_jrandom_key)
    radii = (jrandom.beta(subkey, alpha, alpha, shape=(N, 1)) - 0.5) * 2 * radius

    return x_0 + dirs * radii

In [2]:
# Problem setup. Mesh and spaces: CG1 for the state u, DG0 for the control f.
n = 4
mesh = UnitSquareMesh(n, n)

V = FunctionSpace(mesh, "CG", 1)
W = FunctionSpace(mesh, "DG", 0)

# Initial control f(x, y) = x + y; u is the state, v the test function.
f = interpolate(Expression("x[0]+x[1]", name='Control', degree=1), W)
u = Function(V, name='State')
v = TestFunction(V)

# Forward solve of (grad u, grad v) = (f, v) with u = 0 on the boundary.
# With dolfin_adjoint imported, this solve and the assemble below are
# recorded on the tape, so statement order in this cell matters.
F = (inner(grad(u), grad(v)) - f * v) * dx
bc = DirichletBC(V, 0.0, "on_boundary")
solve(F == 0, u, bc)


# NOTE(review): x is not used in this cell.
x = SpatialCoordinate(mesh)
# Desired state d(x, y) = sin(pi x) sin(pi y) / (2 pi^2).
w = Expression("sin(pi*x[0])*sin(pi*x[1])", degree=3)
d = 1 / (2 * pi ** 2)
d = Expression("d*w", d=d, w=w, degree=3)

# Tracking objective 1/2 ||u - d||^2 plus a small Tikhonov term on f.
alpha = Constant(1e-6)
J = assemble((0.5 * inner(u - d, u - d)) * dx + alpha / 2 * f ** 2 * dx)
control = Control(f)


rf = ReducedFunctional(J, control)


# Wrap the reduced functional for moola and compute one sampled update
# direction starting from the current control f.
problem = MoolaOptimizationProblem(rf)
f_moola = moola.DolfinPrimalVector(f)


BA = BetaApproximations(jrandom.PRNGKey(0))
start_time = time.time()
# res = BA.multi_two_B_one(rf, f)
res = BA.multi_two_B_one(problem.obj, f_moola)

print("time", time.time() - start_time)




time 5.143530368804932


In [5]:
# Local DOF values of the sampled update direction returned by multi_two_B_one.
res.data.vector().get_local()

array([-0.30166127, -0.29917819, -0.2746284 , -0.31940874, -0.24310204,
       -0.0057808 , -1.01043119, -0.49620249, -0.31701142, -0.26429647,
        0.23372132,  0.24267653, -0.03265319,  0.26577341, -0.74947903,
       -0.52948927, -0.01439345, -0.23969558,  0.27276698, -0.01738293,
       -0.26765003, -0.26068044, -1.27390548, -1.33091054, -0.49705388,
       -1.01470235, -0.51571138, -0.75715749, -1.32440965, -1.27652431,
       -1.81784233, -1.8029376 ])

In [11]:
# Reference direction that the sampled `res` is compared against.
# NOTE(review): these look like the exact H^{-1} g values from the
# exact-Hessian cells, pasted in by hand -- confirm provenance and regenerate
# if the mesh or objective changes.
true_dir = np.array([ 0.2952727 ,  0.2952727 ,  0.26725353,  0.31252623,
              0.24504388,  0.01229741,  1.        ,  0.49504388,
              0.31252623,  0.26725353, -0.24167915, -0.24167915,
              0.02955094, -0.26969832,  0.74504388,  0.51229741,
              0.01229741,  0.24504388, -0.26969832,  0.02955094,
              0.25832085,  0.25832085,  1.26725353,  1.31252623,
              0.49504388,  1.        ,  0.51229741,  0.74504388,
              1.31252623,  1.26725353,  1.7952727 ,  1.7952727 ])

In [26]:
# Element-wise relative error of the sampled direction (sign-flipped) against
# the reference. `res` is a DolfinPrimalVector, so its dolfin vector is reached
# through `.data`, as in the cell that displays `res`.
(-res.data.vector().get_local() - true_dir) / true_dir

array([ 0.02163618,  0.01322672,  0.02759504,  0.02202219, -0.00792448,
       -0.5299175 ,  0.01043119,  0.00234042,  0.0143514 , -0.01106464,
       -0.03292726,  0.00412686,  0.10497961, -0.01455297,  0.00595287,
        0.03355836,  0.170446  , -0.0218259 ,  0.01137811, -0.41176383,
        0.0361147 ,  0.00913434,  0.00524911,  0.01400681,  0.00406024,
        0.01470235,  0.00666405,  0.01625892,  0.00905385,  0.00731565,
        0.0125717 ,  0.00426949])

In [12]:
# Time (1) loading the reference direction into the control, then
# (2) one evaluation of the reduced functional plus its gradient there.
# NOTE: this overwrites the values of `f`, which is also the control of `rf`.
start_time = time.time()
f.vector().set_local(true_dir)
print(time.time() - start_time)

start_time = time.time()
rf(f)
rf.derivative().vector().get_local()
print(time.time() - start_time)

0.00015091896057128906
0.08802056312561035
