In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers
import matplotlib.pyplot as plt
import pinns 
import datetime
import jax.scipy.optimize
import jax.flatten_util
import scipy
import scipy.optimize

In [None]:
# Enable 64-bit floats in JAX; per the JAX docs this should be set at startup,
# before any arrays are created. `jax.config.update` works on old and current
# JAX releases, whereas `from jax.config import config` fails on newer ones.
config = jax.config  # keep the `config` name available to later cells
config.update("jax_enable_x64", True)
rnd_key = jax.random.PRNGKey(1234)

In [None]:
# Patch geometry. Its control points span the rectangle [0, w] x [0, L];
# both bases use the knot interval [-1, 1] (second argument presumably the degree).
w = 0.5
L = 2

# Control points (stored as `knots`): two columns at x = 0 and x = w, each with
# three points along y; every NURBS weight equals one.
knots = np.array([[[x0, y0] for y0 in (0, L / 2, L)] for x0 in (0, w)])
weights = np.ones(knots.shape[:2])
basis1, basis2 = (pinns.bspline.BSplineBasis(np.array([-1, 1]), order) for order in (1, 2))

geom = pinns.geometry.PatchNURBS([basis1, basis2], knots, weights, rnd_key)

# Sample points on the patch and scatter them as a visual sanity check.
pts, _ = geom.importance_sampling(10000)

plt.figure()
plt.scatter(pts[:, 0], pts[:, 1], s=1)

In [None]:

def interface_function2d(bases, dim, end, nn, opened = (False,False), reversed = False):
    """Build ``fret(ws, x) = f1(t) * f2(t) * nn(ws, t) * faux(s)`` on a 2D reference box.

    Here ``t = x[..., 1 - dim]`` (the coordinate passed to ``nn``) and
    ``s = x[..., dim]`` (the coordinate ``faux`` ramps along).

    Args:
        bases: pair of basis objects; only their ``interval`` attribute
            (``interval[0]``, ``interval[1]``) is read, at call time.
        dim: 0 or 1, the coordinate index along which ``faux`` ramps.
        end: if ``end > 0``, ``faux`` is 0 at ``bases[dim].interval[0]`` and 1 at
            ``interval[1]``; otherwise it is 0 at ``interval[1]`` and 1 at ``interval[0]``.
        nn: callable ``nn(ws, t)`` taking ``t`` of shape ``(..., 1)`` and returning
            one value per point.
        opened: ``(open0, open1)``. When ``open0`` is False, ``f1`` is a linear factor
            vanishing at ``bases[1 - dim].interval[0]``; when ``open1`` is False,
            ``f2`` vanishes at ``interval[1]``. An opened side uses the constant 1.0.
        reversed: not used; kept for interface compatibility.

    Returns:
        Callable ``fret(ws, x)`` mapping ``x`` of shape ``(..., 2)`` to ``(..., 1)``.
    """
    other = 1 - dim  # coordinate index fed to the network

    def _ramp(x, zero_at, one_at):
        # Linear function equal to 0 at `zero_at` and 1 at `one_at`.
        return (x - zero_at) / (one_at - zero_at)

    if not opened[0]:
        f1 = lambda t: _ramp(t, bases[other].interval[0], bases[other].interval[1])
    else:
        f1 = lambda t: 1.0

    if not opened[1]:
        f2 = lambda t: _ramp(t, bases[other].interval[1], bases[other].interval[0])
    else:
        f2 = lambda t: 1.0

    # Interval end where faux vanishes (lo) and where it reaches 1 (hi).
    lo, hi = (0, 1) if end > 0 else (1, 0)
    faux = lambda s: _ramp(s, bases[dim].interval[lo], bases[dim].interval[hi])

    def fret(ws, x):
        t = x[..., other]
        s = x[..., dim]
        # Reshape (instead of flatten) so extra leading batch dims of x broadcast
        # correctly; a single unbatched point (t.shape == ()) falls back to flatten.
        nn_val = nn(ws, t[..., None]).reshape(t.shape or (-1,))
        return (f1(t) * f2(t) * nn_val * faux(s))[..., None]

    return fret


class Model(pinns.PINN):
    """PINN for a single scalar field ('u1') trained by minimising a quadrature functional.

    The network acts on reference coordinates in [-1, 1]^2. `solution` multiplies it
    by a bubble that vanishes on the boundary of the square and adds a lifting
    term, so the boundary values are built into the ansatz. `loss_pde` evaluates
    the functional with tensor-product Gauss-Legendre quadrature.

    NOTE(review): `init_points` reads the module-level `geom` created in an
    earlier cell, so that cell must run before a Model is constructed.
    """

    def __init__(self, rand_key):
        super().__init__()
        self.key = rand_key

        N = [80, 80]  # Gauss-Legendre nodes per reference direction
        nl = 8  # width of every residual block
        acti = stax.Tanh  # alternative tried: stax.elementwise(lambda x: jax.nn.leaky_relu(x)**2)
        # Residual block: (Dense-acti-Dense-acti) branch summed with a Dense branch.
        block = stax.serial(
            stax.FanOut(2),
            stax.parallel(stax.serial(stax.Dense(nl), acti, stax.Dense(nl), acti), stax.Dense(nl)),
            stax.FanInSum,
        )

        # Five residual blocks and a scalar output layer; input shape (-1, 2).
        self.add_neural_network('u1', stax.serial(block, block, block, block, block, stax.Dense(1)), (-1, 2))  # iron
        self.init_points(N)

        self.freq = 4  # coefficient of the u^2 term in loss_pde

    def init_points(self, N):
        """Build the quadrature rule on [-1, 1]^2 and the per-node geometry terms.

        Args:
            N: pair ``[n0, n1]``, number of Gauss-Legendre nodes along reference
                coordinates 0 and 1.

        Populates ``self.points`` with:
            'ys'    : (n0*n1, 2) quadrature nodes,
            'ws'    : (n0*n1,) tensor-product quadrature weights,
            'K'     : inv(DG) @ inv(DG).T * |det DG| per node,
            'omega' : |det DG| per node,
        where ``DG = geom._eval_omega(ys)`` (presumably the Jacobian of the
        reference-to-physical map — confirm in pinns.geometry).
        """
        self.points = {}

        nodes0, wts0 = np.polynomial.legendre.leggauss(N[0])
        nodes1, wts1 = np.polynomial.legendre.leggauss(N[1])
        # Default 'xy' indexing: grid[0][i, j] = nodes0[j], grid[1][i, j] = nodes1[i].
        grid = np.meshgrid(nodes0, nodes1)
        ys = np.concatenate((grid[0].flatten()[:, None], grid[1].flatten()[:, None]), -1)
        # Flattened node k = i*n0 + j needs weight wts1[i] * wts0[j], i.e. kron(wts1, wts0).
        # kron(wts0, wts1) would only coincide with this when n0 == n1.
        Weights = np.kron(wts1, wts0)

        self.points['ys'] = ys
        self.points['ws'] = Weights
        DGys = geom._eval_omega(ys)
        Inv = np.linalg.inv(DGys)
        det = np.abs(np.linalg.det(DGys))
        self.points['K'] = np.einsum('mij,mjk,m->mik', Inv, np.transpose(Inv, [0, 2, 1]), det)
        self.points['omega'] = det

    def solution(self, ws, x):
        """Evaluate the field at reference points ``x`` of shape (..., 2); returns (..., 1).

        ``u * v`` vanishes on the boundary of [-1, 1]^2, so the boundary values
        come from ``w``: ``2*cos(pi/2 * x0)`` on ``x1 = -1`` and zero on the
        other three sides.
        """
        # iron
        u = self.neural_networks['u1'](ws['u1'], x)
        v = ((1 - x[..., 0]) * (x[..., 0] + 1) * (1 - x[..., 1]) * (x[..., 1] + 1))[..., None]
        w = (jnp.cos(np.pi / 2 * x[..., 0]) * (1 - x[..., 1]))[..., None]
        return u * v + w

    def loss_pde(self, ws):
        """Quadrature value of 0.5 * sum(grad.K.grad * ws) - freq * sum(u^2 * ws * omega).

        Both terms use the same rule: reference weights times |det DG|. For the
        gradient term, |det DG| is already folded into 'K'.
        """
        ys = self.points['ys']
        grad = pinns.operators.gradient(lambda x: self.solution(ws, x))(ys)
        uval = self.solution(ws, ys).flatten()
        stiffness = 0.5 * jnp.dot(jnp.einsum('mi,mij,mj->m', grad, self.points['K'], grad), self.points['ws'])
        # Weighted like the gradient term; an unweighted sum over nodes would make
        # the loss depend on the number of quadrature points.
        mass = self.freq * jnp.dot(uval * uval, self.points['ws'] * self.points['omega'])
        return stiffness - mass

    def loss(self, ws):
        lpde = self.loss_pde(ws)
        return lpde
    

In [None]:
# Separate PRNG key for the model's network initialisation.
rnd_key = jax.random.PRNGKey(4321)
model = Model(rnd_key)
# Flat initial parameter vector; used as scipy's `x0` in the optimization cell.
w0 = model.init_unravel()
# NOTE(review): read after `init_unravel()`, which presumably populates `model.weights` — confirm.
weights = model.weights 

In [None]:
loss_compiled = jax.jit(model.loss_handle)
lossgrad_compiled = jax.jit(model.lossgrad_handle)

print('Starting optimization')


def loss_grad(w):
    """Objective for scipy: return (loss, gradient) at the flat parameter vector `w`.

    `np.array(...)` copies the JAX results to host NumPy arrays. It replaces
    `.to_py()`, which current JAX arrays no longer provide.
    """
    l, gr = lossgrad_compiled(jnp.array(w))
    return np.array(l), np.array(gr)


tme = datetime.datetime.now()
result = scipy.optimize.minimize(
    loss_grad,
    x0=np.array(w0),
    method='L-BFGS-B',
    jac=True,  # loss_grad returns (value, gradient)
    tol=1e-11,
    options={'disp': True, 'maxiter': 3000, 'iprint': 1},
)
tme = datetime.datetime.now() - tme

weights = model.weights_unravel(jnp.array(result.x))
model.weights = weights
print()
print('Elapsed time', tme)

In [None]:
# Evaluate the trained solution on a 100x100 reference grid and plot it over the
# mapped points (geom(ys) presumably maps reference to physical coordinates).
x, y = np.meshgrid(np.linspace(-1, 1, 100), np.linspace(-1, 1, 100))
ys = np.concatenate((x.flatten()[:, None], y.flatten()[:, None]), 1)
xy1 = geom(ys)

u1 = model.solution(weights, ys).reshape(x.shape)

fig, ax = plt.subplots()
cs = ax.contourf(xy1[:, 0].reshape(x.shape), xy1[:, 1].reshape(x.shape), u1, levels=100)
# The patch is w x L = 0.5 x 2; equal aspect keeps it from being stretched.
ax.set_aspect('equal')
ax.set(xlabel='x', ylabel='y', title='Solution u1')
fig.colorbar(cs, ax=ax)