In [None]:
import fenics as fn
import matplotlib.pyplot as plt
from dolfin_utils.meshconvert import meshconvert
import os
from subprocess import call

import numpy as np
import jax
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers
import matplotlib.pyplot as plt
import pinns 
import datetime
import jax.scipy.optimize
import jax.flatten_util
import scipy
import scipy.optimize


In [None]:
class FEM():
    """Finite-element reference solution computed with FEniCS on a gmsh mesh.

    Construction prefixes the geometry template ``./slice/slice.geo`` with the
    requested mesh size, meshes it with ``gmsh``, converts the result with
    ``dolfin-convert`` and solves the variational problem once.  Afterwards the
    solution (``self.az``) and its projected gradient (``self.G``) can be
    sampled at arbitrary points through ``__call__`` and ``grad``.
    """

    def __init__(self,rhs=100,c1=1.0,c2=10.0, meshsize = 0.05, verb = False):
        """Mesh the geometry and solve the problem.

        Args:
            rhs: source value assigned to the cells tagged 11.
            c1: coefficient used on the cells tagged 12.
            c2: coefficient used on the cells tagged 11.
            meshsize: value written as ``meshsize`` at the top of the .geo file.
            verb: if true, print progress and keep the output of the external
                tools visible.

        Raises:
            RuntimeError: if ``gmsh`` or ``dolfin-convert`` exits with a
                non-zero status (otherwise a stale ``tmp.xml`` from an earlier
                run could be loaded silently).
        """
        path='./slice/'

        # Tags used by the .geo file (physical regions / boundary facets).
        source_tag = 11      # cells carrying the source term and coefficient c2
        passive_tag = 12     # cells with coefficient c1 and no source
        dirichlet_tag = 10   # facets with homogeneous Dirichlet condition

        def run(command):
            """Run an external tool, hiding its output unless `verb` is set."""
            if not verb:
                command = command + ' >/dev/null 2>&1'
            status = os.system(command)
            if status != 0:
                raise RuntimeError(
                    'command failed with status %d: %s (rerun with verb=True for details)'
                    % (status, command))

        with open(path + 'slice.geo', 'r') as file:
            data = file.read()

        # The template expects `meshsize` to be defined before its own content.
        s = "meshsize=%.18f;\n"%(meshsize) + data

        with open(path+"tmp.geo", "w") as file:
            file.write(s)
        if verb: print('geo file created',flush = True)

        run('gmsh %stmp.geo -nt 20 -3 -o %stmp.msh -format msh2'%(path,path))
        if verb: print('mesh file created',flush=True)

        run('dolfin-convert %stmp.msh %stmp.xml'%(path,path))
        if verb: print('mesh file converted in fenics format',flush=True)

        mesh = fn.Mesh(path+'tmp.xml')
        domains = fn.MeshFunction("size_t", mesh, path+'tmp_physical_region.xml')
        boundaries = fn.MeshFunction('size_t', mesh, path+'tmp_facet_region.xml')

        self.mesh = mesh

        def nonlin(u):
            # Placeholder for a field-dependent coefficient; currently constant.
            return c2

        def setup_coil(mesh,subdomains):
            """Return a piecewise-constant source: `rhs` on source cells, 0 elsewhere."""
            DG = fn.FunctionSpace(mesh,"DG",0)
            J = fn.Function(DG)
            markers = subdomains.array()
            # NOTE(review): indexing the DG0 vector by cell number assumes that
            # dof i belongs to cell i - confirm for the dolfin version in use.
            idx = [cell_no for cell_no, tag in enumerate(markers) if tag == source_tag]
            J.vector()[:] = 0
            J.vector()[idx] = rhs
            return J

        # Function space and boundary condition.
        CG = fn.FunctionSpace(mesh, 'CG', 1) # Continuous Galerkin
        bc = fn.DirichletBC(CG, fn.Constant(0.0), boundaries, dirichlet_tag)

        # Integration measure that knows about the subdomain markers.
        dx = fn.Measure('dx', domain=mesh, subdomain_data=domains)

        J = setup_coil(mesh, domains)

        class Coefficient(fn.UserExpression): # UserExpression instead of Expression
            """Cell-wise coefficient: 0 on source cells, c1 on passive cells."""
            def __init__(self, markers, **kwargs):
                super().__init__(**kwargs)
                self.markers = markers
            def eval_cell(self, values, x, cell):
                if self.markers[cell.index] == source_tag:
                    # The source region gets its coefficient (c2) through the
                    # separate dx(source_tag) term of the weak form.
                    values[0] = 0.0
                elif self.markers[cell.index] == passive_tag:
                    values[0] = c1
                else:
                    # NOTE(review): values[0] is left untouched for unknown tags.
                    print('no such domain',self.markers[cell.index] )

        coeff = Coefficient(domains, degree=1)

        # Weak formulation: find u such that a(u, v) - L(v) = 0 for all v.
        u  = fn.Function(CG)
        v  = fn.TestFunction(CG)
        a = fn.inner(coeff*fn.grad(u), fn.grad(v))*dx + fn.inner(nonlin(fn.grad(u))*fn.grad(u),fn.grad(v))*dx(source_tag)
        L  = J*v*dx(source_tag)

        F = a - L
        fn.solve(F == 0, u, bc)
        self.az = u
        # Piecewise-constant vector space used to store the projected gradient.
        W = fn.VectorFunctionSpace(mesh, fn.FiniteElement("DP", fn.triangle, 0),1)
        self.G = fn.project(fn.grad(u), W)

    def __call__(self,x_eval,y_eval):
        """Evaluate the FEM solution at the points ``(x_eval[i], y_eval[i])``.

        Args:
            x_eval, y_eval: array-likes of coordinates with matching shapes.

        Returns:
            Floating-point NumPy array with the shape of ``x_eval``; entries
            whose evaluation raised (e.g. the point is not inside the mesh)
            are NaN.
        """
        xs = np.asarray(x_eval)
        ys = np.asarray(y_eval)
        # Always use a float result so that NaN can be stored for integer input.
        out_dtype = xs.dtype if np.issubdtype(xs.dtype, np.floating) else float
        Afem = np.zeros(xs.shape, dtype=out_dtype)
        xs_flat, ys_flat = xs.ravel(), ys.ravel()
        for i in range(xs.size):
            try:
                Afem.flat[i] = self.az(xs_flat[i], ys_flat[i])
            except Exception:
                # Best effort: mark the point as missing, but let
                # KeyboardInterrupt / SystemExit propagate.
                Afem.flat[i] = np.nan
        return Afem

    def grad(self,x_eval,y_eval):
        """Evaluate the projected gradient at the points ``(x_eval[i], y_eval[i])``.

        Returns:
            NumPy array with one row per point; rows whose evaluation raised
            are ``[nan, nan]``.
        """
        xs = np.asarray(x_eval).ravel()
        ys = np.asarray(y_eval).ravel()
        Gfem = []
        for i in range(xs.size):
            try:
                Gfem.append(self.G(xs[i], ys[i]))
            except Exception:
                Gfem.append([ np.nan , np.nan])
        return np.array(Gfem)

In [None]:
# Enable double precision before any JAX array is created.  The flag is set
# through `jax.config` directly because the `from jax.config import config`
# submodule import is not available in newer JAX releases.
config = jax.config  # keep the notebook-level name `config` bound
config.update("jax_enable_x64", True)
rnd_key = jax.random.PRNGKey(1234)

In [None]:

def create_geometry(key, scale = 1):
    """Build the three NURBS patches used by the multi-patch model.

    Every patch is parametrised over [-1, 1]^2.  The third patch is the square
    [0, r]^2; the first and second patches connect the top and right edge of
    that square to an outer curved boundary with control points at distance
    R from the origin.

    Args:
        key: passed unchanged to every `pinns.geometry.PatchNURBS` constructor.
        scale: accepted but not used by this function.

    Returns:
        Tuple ``(geom1, geom2, geom3)`` of `pinns.geometry.PatchNURBS` objects.
    """
    R = 1
    r = 0.2
    diag = R/np.sqrt(2)            # coordinate of the outer corner on the diagonal
    shoulder = np.tan(np.pi/8)*R   # offset of the middle outer control point

    def make_patch(control_points, degrees, weighted_entry=None):
        """Assemble one patch; `weighted_entry` gets weight sin(pi/4), all others 1."""
        ctrl = np.array(control_points)
        w = np.ones(ctrl.shape[:2])
        if weighted_entry is not None:
            w[weighted_entry] = np.sin(np.pi/4)
        bases = [pinns.bspline.BSplineBasis(np.array([-1,1]), deg) for deg in degrees]
        return pinns.geometry.PatchNURBS(bases, ctrl, w, key)

    # Patch above the inner square (degree 2 x degree 1).
    upper = make_patch(
        [ [[0,r],[0,R]], [[r/2,r],[shoulder,R]], [[r,r],[diag,diag]] ],
        (2, 1),
        weighted_entry=(1,-1),
    )

    # Patch to the right of the inner square (degree 1 x degree 2).
    right = make_patch(
        [ [[r,0],[r,r/2],[r,r]], [[R,0],[R,shoulder],[diag,diag]] ],
        (1, 2),
        weighted_entry=(1,1),
    )

    # Inner square [0, r]^2 (degree 2 x degree 2, all weights equal to one).
    inner = make_patch(
        [ [[0,0],[0,r/2],[0,r]], [[r/2,0],[r/2,r/2],[r/2,r]], [[r,0],[r,r/2],[r,r]] ],
        (2, 2),
    )

    return upper, right, inner

In [None]:

geom1, geom2, geom3 = create_geometry(rnd_key)

# Visual check of the geometry: overlay 10000 sampled points of every patch
# in a single figure (one colour per patch).
plt.figure()
for patch in (geom1, geom2, geom3):
    pts,_ = patch.importance_sampling(10000)
    plt.scatter(pts[:,0], pts[:,1], s = 1)

In [None]:
def interface_function2d(bases, dim, end, nn, opened = (False,False), reversed = False):
    """Build the ansatz term attached to one edge of a 2D parametric patch.

    The returned function evaluates the 1D network `nn` on the coordinate that
    runs along the edge and multiplies it by

    * two linear factors in that same coordinate which vanish at the start /
      end of the edge (each one is replaced by 1.0 when the corresponding
      entry of `opened` is true), and
    * a linear factor in the coordinate `dim` which is 0 at one end of
      `bases[dim].interval` and 1 at the other (`end > 0`: 0 at
      ``interval[0]``, 1 at ``interval[1]``; otherwise the opposite).

    Args:
        bases: per-direction objects exposing an ``interval`` pair.
        dim: index (0 or 1) of the coordinate across the edge.
        end: selects the orientation of the linear factor in direction `dim`.
        nn: callable ``nn(ws, t)`` applied to the along-edge coordinate with a
            trailing axis of length 1.
        opened: pair of flags disabling the vanishing factor at the start /
            end of the edge.
        reversed: accepted but not used.

    Returns:
        Callable ``f(ws, x)`` mapping points ``x[..., 2]`` to values with a
        trailing axis of length 1.
    """
    along = 1 - dim  # index of the coordinate running along the edge

    def _one(t):
        return 1.0

    def _zero_at_start(t):
        side = bases[along]
        return (t - side.interval[0]) / (side.interval[1] - side.interval[0])

    def _zero_at_end(t):
        side = bases[along]
        return (t - side.interval[1]) / (side.interval[0] - side.interval[1])

    start_factor = _one if opened[0] else _zero_at_start
    end_factor = _one if opened[1] else _zero_at_end

    # Ends of the interval in direction `dim` where the cross factor is 0 / 1.
    zero_idx, one_idx = (0, 1) if end > 0 else (1, 0)

    def _cross_factor(s):
        span = bases[dim]
        return (s - span.interval[zero_idx]) / (span.interval[one_idx] - span.interval[zero_idx])

    # Column of `x` used along the edge / across it.
    t_col, s_col = (1, 0) if dim == 0 else (0, 1)

    def fret(ws, x):
        t = x[..., t_col]
        value = start_factor(t) * end_factor(t) * nn(ws, t[..., None]).flatten() * _cross_factor(x[..., s_col])
        return value[..., None]

    return fret


class Model(pinns.PINN):
    """Three-patch PINN trained by minimising an energy functional.

    Each patch (`geom1`, `geom2`, `geom3`, read from the notebook namespace)
    has its own network (`u1`, `u2`, `u3`).  Both sides of a shared edge use
    the same 1D network (`u12`, `u13`, `u23`) through `interface_function2d`,
    and the scalar `u123` scales a corner term present in all three patches.
    """

    def __init__(self, rand_key):
        """Create the networks, the quadrature data and the interface terms.

        Args:
            rand_key: JAX PRNG key stored as ``self.key``.
        """
        super().__init__()
        self.key = rand_key

        N = [80,80]  # Gauss-Legendre points per parametric direction
        nl = 8       # width of the dense layers
        acti = stax.Tanh
        # Residual-style block: two dense+tanh layers in parallel with a single
        # dense layer, outputs summed.
        block = stax.serial(stax.FanOut(2),stax.parallel(stax.serial(stax.Dense(nl), acti, stax.Dense(nl), acti),stax.Dense(nl)),stax.FanInSum)

        self.add_neural_network('u1',stax.serial(block,block,block,block,block, stax.Dense(1)),(-1,2)) # iron
        self.add_neural_network('u2',stax.serial(block,block,block,block,block, stax.Dense(1)),(-1,2)) # air
        self.add_neural_network('u3',stax.serial(block,block,block,block,block, stax.Dense(1)),(-1,2)) # copper
        self.add_neural_network('u12',stax.serial(block, block,block, block, stax.Dense(1)),(-1,1))
        self.add_neural_network('u13',stax.serial(block, block,block, block, stax.Dense(1)),(-1,1))
        self.add_neural_network('u23',stax.serial(block, block,block, block, stax.Dense(1)),(-1,1))
        self.add_trainable_parameter('u123',(1,))
        self.init_points(N)

        # One edge function per (patch, neighbour) pair; both sides of an edge
        # share the same 1D network.
        self.interface12 = interface_function2d(geom1.basis,0,1,self.neural_networks['u12'])
        self.interface21 = interface_function2d(geom2.basis,1,1,self.neural_networks['u12'])
        self.interface23 = interface_function2d(geom2.basis,0,0,self.neural_networks['u23'])
        self.interface32 = interface_function2d(geom3.basis,0,1,self.neural_networks['u23'])
        self.interface13 = interface_function2d(geom1.basis,1,0,self.neural_networks['u13'])
        self.interface31 = interface_function2d(geom3.basis,1,1,self.neural_networks['u13'])
        # Coefficients and source value used by the loss.
        self.c1 = 1
        self.c2 = 10
        self.rhs = 100

    def init_points(self, N):
        """Set up tensor-product Gauss-Legendre quadrature on [-1, 1]^2.

        Fills ``self.points`` with, for every patch ``p`` in 1..3: the
        parametric points ``ys<p>``, the quadrature weights ``ws<p>``, the
        matrices ``K<p>`` and the absolute determinants ``omega<p>``.

        Args:
            N: pair with the number of Gauss points per parametric direction.
        """
        self.points = {}

        nodes0, weights0 = np.polynomial.legendre.leggauss(N[0])
        nodes1, weights1 = np.polynomial.legendre.leggauss(N[1])

        grid = np.meshgrid(nodes0, nodes1)
        ys = np.concatenate((grid[0].flatten()[:,None], grid[1].flatten()[:,None]), -1)
        # With the default 'xy' indexing of np.meshgrid the first coordinate
        # varies fastest after flattening, so the matching weight ordering is
        # kron(weights1, weights0).  For N[0] == N[1] this gives exactly the
        # same values as kron(weights0, weights1); for N[0] != N[1] only this
        # ordering pairs each point with its own weight.
        Weights = np.kron(weights1, weights0).flatten()

        for label, geom in (('1', geom1), ('2', geom2), ('3', geom3)):
            self._register_patch_quadrature(label, geom, ys, Weights)

    def _register_patch_quadrature(self, label, geom, ys, weights):
        """Store points, weights, K and omega of one patch in ``self.points``."""
        self.points['ys' + label] = ys
        self.points['ws' + label] = weights
        # NOTE(review): `_eval_omega` is presumably the Jacobian of the
        # geometry map at `ys` - confirm against the pinns package.
        DGys = geom._eval_omega(ys)
        Inv = np.linalg.inv(DGys)
        det = np.abs(np.linalg.det(DGys))
        # K = Inv * Inv^T * |det|, evaluated per quadrature point.
        self.points['K' + label] = np.einsum('mij,mjk,m->mik',Inv,np.transpose(Inv,[0,2,1]),det)
        self.points['omega' + label] = det

    def solution1(self, ws, x):
        """Ansatz on patch 1 (iron): network * bubble + edge and corner terms."""
        u = self.neural_networks['u1'](ws['u1'],x)
        # Bubble function vanishing on all four edges of [-1, 1]^2.
        v = ((1-x[...,0])*(x[...,0] + 1)*(1-x[...,1])*(x[...,1]+1))[...,None]
        w =  self.interface12(ws['u12'],x) + self.interface13(ws['u13'],x)
        w = w + ws['u123']*( (x[...,0]+1) * (1-x[...,1]) )[...,None]**3
        return u*v+w

    def solution2(self, ws, x):
        """Ansatz on patch 2 (air): network * bubble + edge and corner terms."""
        u = self.neural_networks['u2'](ws['u2'],x)
        v = ((1-x[...,0])*(x[...,0] + 1)*(1-x[...,1])*(x[...,1]+1))[...,None]
        w = self.interface21(ws['u12'],x) + self.interface23(ws['u23'],x)
        w = w + ws['u123']*( (1-x[...,0]) * (1+x[...,1]) )[...,None]**3
        return u*v+w

    def solution3(self, ws, x):
        """Ansatz on patch 3 (copper): edge and corner terms only.

        NOTE(review): the interior network `u3` is multiplied by 0, so only
        the edge and corner terms contribute - confirm this is intended.
        """
        u = self.neural_networks['u3'](ws['u3'],x)*0
        v = ((1-x[...,0])*(x[...,0] + 1)*(1-x[...,1])*(x[...,1]+1))[...,None]
        w =  self.interface32(ws['u23'],x) + self.interface31(ws['u13'],x)
        w = w + ws['u123']*( (x[...,0]+1) * (x[...,1]+1) )[...,None]**3
        return (u*v + w)

    def loss_pde(self, ws):
        """Energy functional summed over the three patches.

        Each patch contributes ``0.5 * c * sum_m w_m * grad_m^T K_m grad_m``
        (c1 on patches 1 and 2, c2 on patch 3); patch 3 additionally subtracts
        the quadrature of ``rhs * u3 * omega3``.
        """
        grad1 = pinns.operators.gradient(lambda x : self.solution1(ws,x))(self.points['ys1'])[...,0,:]
        grad2 = pinns.operators.gradient(lambda x : self.solution2(ws,x))(self.points['ys2'])[...,0,:]
        grad3 = pinns.operators.gradient(lambda x : self.solution3(ws,x))(self.points['ys3'])[...,0,:]

        lpde1 = 0.5*self.c1*jnp.dot(jnp.einsum('mi,mij,mj->m',grad1,self.points['K1'],grad1), self.points['ws1'])
        lpde2 = 0.5*self.c1*jnp.dot(jnp.einsum('mi,mij,mj->m',grad2,self.points['K2'],grad2), self.points['ws2'])
        lpde3 = 0.5*self.c2*jnp.dot(jnp.einsum('mi,mij,mj->m',grad3,self.points['K3'],grad3), self.points['ws3'])  - jnp.dot(self.rhs*self.solution3(ws,self.points['ys3']).flatten()*self.points['omega3']  ,self.points['ws3'])
        return lpde1+lpde2+lpde3

    def loss(self, ws):
        """Total training loss (currently only the energy term)."""
        lpde = self.loss_pde(ws)
        return lpde
    

In [None]:
# Re-seed, build the model and fetch its initial parameters.
rnd_key = jax.random.PRNGKey(4321)
model = Model(rnd_key)
# w0 is later passed to scipy as the starting point `x0` of the optimisation.
w0 = model.init_unravel()
# Parameter container as stored on the model; replaced after the optimisation.
weights = model.weights 



In [None]:


loss_compiled = jax.jit(model.loss_handle)
lossgrad_compiled = jax.jit(model.lossgrad_handle)

print('Starting optimization')

def loss_grad(w):
    """Return (loss, gradient) as NumPy arrays for `scipy.optimize.minimize`.

    `np.array(...)` is used for the conversion instead of the `.to_py()`
    method, which is not available on the array type of newer JAX releases.
    """
    l, gr = lossgrad_compiled(jnp.array(w))
    return np.array(l), np.array(gr)

tme = datetime.datetime.now()
result = scipy.optimize.minimize(loss_grad, x0 = np.array(w0), method = 'L-BFGS-B', jac = True, tol = 1e-11, options = {'disp' : True, 'maxiter' : 3000, 'iprint': 1})
tme = datetime.datetime.now() - tme

# Store the optimised flat vector back on the model in its structured form.
weights = model.weights_unravel(jnp.array(result.x))
model.weights = weights
print()
print('Elapsed time', tme)

In [None]:
# Evaluate the trained model on a 100 x 100 grid of the parametric square and
# draw all three patches in physical coordinates with shared colour limits.
grid_axis = np.linspace(-1,1,100)
x,y = np.meshgrid(grid_axis, grid_axis)
ys = np.column_stack((x.flatten(), y.flatten()))

xy1, xy2, xy3 = geom1(ys), geom2(ys), geom3(ys)

u1 = model.solution1(weights, ys).reshape(x.shape)
u2 = model.solution2(weights, ys).reshape(x.shape)
u3 = model.solution3(weights, ys).reshape(x.shape)

u_lo = min([u1.min(),u2.min(),u3.min()])
u_hi = max([u1.max(),u2.max(),u3.max()])

plt.figure()
ax = plt.gca()
for patch_xy, patch_u in ((xy1, u1), (xy2, u2), (xy3, u3)):
    ax.contourf(patch_xy[:,0].reshape(x.shape), patch_xy[:,1].reshape(x.shape), patch_u, levels = 100, vmin = u_lo, vmax = u_hi)
# plt.colorbar()

In [None]:
# Reference solution with the same coefficients as the PINN model.
fem = FEM(rhs = model.rhs, c1 = model.c1, c2 = model.c2,meshsize = 0.01)

# Sample the reference on the physical images of the plotting grid.
u1_ref = fem(xy1[:,0],xy1[:,1]).reshape(x.shape)
u2_ref = fem(xy2[:,0],xy2[:,1]).reshape(x.shape)
u3_ref = fem(xy3[:,0],xy3[:,1]).reshape(x.shape)

# The FEM evaluation yields NaN where it fails, hence the nan-aware limits.
ref_lo = min([np.nanmin(u1_ref),np.nanmin(u2_ref),np.nanmin(u3_ref)])
ref_hi = max([np.nanmax(u1_ref),np.nanmax(u2_ref),np.nanmax(u3_ref)])

plt.figure()
ax = plt.gca()
for patch_xy, patch_ref in ((xy1, u1_ref), (xy2, u2_ref), (xy3, u3_ref)):
    ax.contourf(patch_xy[:,0].reshape(x.shape), patch_xy[:,1].reshape(x.shape), patch_ref, levels = 100, vmin = ref_lo, vmax = ref_hi)
#plt.colorbar()

# Pointwise absolute error between the PINN and the reference.
delta1 = np.abs(u1-u1_ref)
delta2 = np.abs(u2-u2_ref)
delta3 = np.abs(u3-u3_ref)

err_lo = min([np.nanmin(delta1),np.nanmin(delta2),np.nanmin(delta3)])
err_hi = max([np.nanmax(delta1),np.nanmax(delta2),np.nanmax(delta3)])

plt.figure()
ax = plt.gca()
for patch_xy, patch_err in ((xy1, delta1), (xy2, delta2), (xy3, delta3)):
    ax.contourf(patch_xy[:,0].reshape(x.shape), patch_xy[:,1].reshape(x.shape), patch_err, levels = 100, vmin = err_lo, vmax = err_hi)
plt.show()

In [None]:
# Relative squared L2 error per patch, integrated with the training quadrature.
# NOTE: this cell rebinds u1..u3 and u1_ref..u3_ref (previously the values on
# the plotting grid) to flat arrays at the quadrature points.
xs1 = geom1(model.points['ys1'])
xs2 = geom2(model.points['ys2'])
xs3 = geom3(model.points['ys3'])

u1_ref = fem(xs1[:,0],xs1[:,1]).flatten()
u2_ref = fem(xs2[:,0],xs2[:,1]).flatten()
u3_ref = fem(xs3[:,0],xs3[:,1]).flatten()

u1 = model.solution1(weights, model.points['ys1']).flatten()
u2 = model.solution2(weights, model.points['ys2']).flatten()
u3 = model.solution3(weights, model.points['ys3']).flatten()

def _relative_sq_l2(u_model, u_fem, omega, ws):
    """Ratio of the quadratures of (u - u_ref)^2 and u_ref^2, ignoring NaNs."""
    return jnp.nansum((u_model-u_fem)**2*omega*ws) / jnp.nansum((u_fem)**2*omega*ws)

int1 = _relative_sq_l2(u1, u1_ref, model.points['omega1'], model.points['ws1'])
int2 = _relative_sq_l2(u2, u2_ref, model.points['omega2'], model.points['ws2'])
int3 = _relative_sq_l2(u3, u3_ref, model.points['omega3'], model.points['ws3'])

print(int1, int2, int3)


In [None]:
# Mean absolute error of every patch divided by the maximum of its solution.
# NOTE: delta1..delta3 come from the plotting-grid cell, while u1..u3 are
# whatever the most recently executed cell bound to those names.
for patch_err, patch_sol in ((delta1, u1), (delta2, u2), (delta3, u3)):
    print(np.nanmean(patch_err)/np.max(patch_sol))

In [None]:
# Quadrature points of all three patches in physical coordinates.
plt.figure()
for quad_pts in (xs1, xs2, xs3):
    plt.scatter(quad_pts[:,0], quad_pts[:,1], s= 1)

# Mesh used by the FEM reference.
plt.figure()
fn.plot(fem.mesh)

In [None]:
# Trace along the edge of patch 3 where the second parametric coordinate is 1:
# full ansatz, FEM reference and the `u13` edge term alone.
t = np.linspace(-1,1,1000)
edge_pts = np.concatenate((t[:,None],t[:,None]*0+1),1)

z = model.solution3(weights, edge_pts).flatten()
u = model.interface31(weights['u13'], edge_pts).flatten()
print(u.shape)
xy = geom3(edge_pts)

plt.plot(t,z)
plt.plot(t,fem(xy[:,0],xy[:,1]))
plt.plot(t,u)