Imports:

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers
from flax import nnx
import matplotlib.pyplot as plt
import pyvista as pv
import pinns 
import datetime
import jax.scipy.optimize
import jax.flatten_util
import scipy
import scipy.optimize
import optax
import random
from jaxkan.KAN import KAN

# Reproducibility: seed NumPy's global generator and create a base JAX PRNG key.
np.random.seed(14124)
rnd_key = jax.random.PRNGKey(seed=1234)

Set the default precision and the execution device.

In [None]:
# Run in single precision (64-bit floats disabled).
jax.config.update("jax_enable_x64", False)
print("Devices: ", jax.devices())

# Prefer the first GPU and fall back to the CPU when no GPU backend exists.
# `jax.devices('gpu')` raises RuntimeError in that case, so the lookup has to
# be guarded; the previous unconditional `jax.devices('gpu')[0]` made this
# cell fail on CPU-only machines.
try:
    dev = jax.devices('gpu')[0]
except RuntimeError:
    dev = jax.devices('cpu')[0]
print(dev)


### Geometry definition 

Define the geometry patches:

In [None]:
def get_domain(r0: float, r1: float, R: float, h: float, H: float):
    """Build the NURBS patches that make up the holder geometry.

    The returned dict contains, in this insertion order, the patches
    'flat', 'spoke', 'round_0', 'round_1' and 'round_2'.  'flat' and the
    three 'round_*' patches are annular sectors between the radii ``r0`` and
    ``r1`` with control points at z = -h/2 and z = +h/2; 'spoke' starts on
    the outer face of 'flat'.

    Args:
        r0: inner radius of the annular sectors.
        r1: outer radius of the annular sectors.
        R: the x-coordinates of the far end of the spoke are those of its
            near end scaled by ``R/r1``.
        h: extent of the sectors in z.
        H: the y-coordinates of the far end of the spoke are those of its
            near end scaled by ``H/h``.

    Returns:
        dict mapping the patch name to a ``pinns.geometry.PatchNURBS``.
    """
    opening = np.pi/4
    basis_u = pinns.functions.BSplineBasisJAX(np.array([-1, 0, 1]), 2)
    basis_v = pinns.functions.BSplineBasisJAX(np.array([-1, 1]), 2)
    basis_w = pinns.functions.BSplineBasisJAX(np.array([-1, 1]), 1)

    def make_patch(ctrl, wts):
        # A fresh basis list is handed to every patch.
        return pinns.geometry.PatchNURBS(
            [basis_u, basis_v, basis_w], ctrl, wts, 0, 3)

    def sector(theta, r_in, r_out):
        """Control points/weights of an annular sector of opening `theta`
        centred on the x-axis, with z = -1 and z = +1 layers."""
        ctrl = np.zeros([4, 3, 2, 3])
        wts = np.ones([4, 3, 2])

        half = np.pi/2-theta/2
        radii = np.linspace(r_in, r_out, 4)
        # Unit-radius arc, stored temporarily in the last radial row.
        ctrl[-1, 0, 0, :] = [np.cos(-theta/2), np.sin(-theta/2), 0]
        ctrl[-1, 1, 0, :] = [1/np.sin(half), 0, 0]
        ctrl[-1, 2, 0, :] = [np.cos(theta/2), np.sin(theta/2), 0]
        # Row 3 is the template itself, hence it has to be scaled last.
        for k in range(4):
            ctrl[k, :, 0, :2] = radii[k] * ctrl[-1, :, 0, :2]
        ctrl[:, :, 0, 2] = -1
        # The second layer mirrors the first one in z.
        ctrl[:, :, 1, :] = ctrl[:, :, 0, :]
        ctrl[:, :, 1, 2] = -ctrl[:, :, 1, 2]
        wts[:, 1, :] = np.sin(half)

        return ctrl, wts

    patches = dict()

    # --- 'flat': sector whose outer face is straightened ---------------------
    ctrl, wts = sector(opening, r0, r1)
    ctrl[:, :, :, 2] *= h/2
    # Put the middle control point of the outer face at the same x as the
    # corner and give it unit weight; the two interior rows are spaced at
    # thirds between the inner and the outer one.
    ctrl[3, 1, :, 0] = ctrl[3, 0, :, 0]
    ctrl[1, 1, :, 0] = 2*ctrl[0, 1, :, 0]/3+ctrl[-1, 1, :, 0]/3
    ctrl[2, 1, :, 0] = ctrl[0, 1, :, 0]/3+2*ctrl[-1, 1, :, 0]/3
    wts[-1, 1, :] = 1.0

    patches['flat'] = make_patch(ctrl.copy(), wts.copy())

    # --- 'spoke': starts on the outer face of 'flat' --------------------------
    wts[...] = 1.0
    ts = np.linspace(0, 1, basis_u.n)

    ctrl[0, :, :, :] = ctrl[-1, :, :, :]
    ctrl[-1, :, :, 0] *= R/r1
    ctrl[-1, :, :, 1] *= H/h

    # Interior rows: blend the two end faces with a different exponent per
    # coordinate, then scale y and z by a factor that is smallest at t = 1/2.
    for i in range(1, basis_u.n-1):
        t = ts[i]
        ctrl[i, :, :, 2] = (1-t**0.25)*ctrl[0, :, :, 2] + \
            t**0.25*ctrl[-1, :, :, 2]
        ctrl[i, :, :, 0] = (1-t)*ctrl[0, :, :, 0] + \
            t*ctrl[-1, :, :, 0]
        ctrl[i, :, :, 1] = (1-t**4)*ctrl[0, :, :, 1] + \
            t**4*ctrl[-1, :, :, 1]
        waist = 2*(t-1/2)**2+0.5
        ctrl[i, :, :, 1] *= waist
        ctrl[i, :, :, 2] *= waist

    patches['spoke'] = make_patch(ctrl, wts)

    # --- 'round_*': three equal sectors covering the rest of the ring --------
    arc = (2*np.pi-opening)/3
    rotations = {
        'round_0': opening/2 + arc/2,
        'round_1': np.pi,
        'round_2': arc + np.pi,
    }
    for key, rot_z in rotations.items():
        ctrl, wts = sector(arc, r0, r1)
        ctrl[:, :, :, 2] *= h/2
        patches[key] = make_patch(ctrl, wts)
        patches[key].rotate((0, 0, rot_z))

    return patches


# Instantiate the geometry; `names` records the patch names in the order in
# which they were inserted into the dict.
geoms = get_domain(r0=0.5, r1=0.8, R=3.0, h=1.0, H=1.5)
names = [*geoms]

Export the geometry as a VTK file for visualization in ParaView.

In [None]:
# Sample every patch (N=32, no data fields attached) with JIT disabled.
objects = []
with jax.disable_jit():
    for patch_name in geoms:
        objects.append(pinns.extras.plot(geoms[patch_name], dict(), N=32))

# Merge the per-patch meshes into one object and write it to disk.
obj = objects[0]
for piece in objects[1:]:
    obj = obj.merge(piece)

obj.save('testing.vtk')

Determine the connectivity of the patches:

In [None]:
# Match the patches against each other with JIT disabled; `eps` is the
# tolerance handed to the matcher.
with jax.disable_jit():
    connectivity = pinns.geometry.match_patches(geoms, verbose=False, eps=1e-4)

# Print one entry of the matching result per line.
for link in connectivity:
    print(link)

In [None]:
import pickle

# One 'holder_<patch>.geom' file per patch.
for patch_name, patch in geoms.items():
    pinns.geometry.save_patch('holder_' + patch_name + '.geom', patch)

# The connectivity is pickled next to the patch files.
with open('connectivity_holder.pkl', mode='wb') as file:
    pickle.dump(connectivity, file)

### ANN spaces definition

The network is an MLP with residual connections; its width is set by the `n_layers` parameter.
Two spaces are defined: the first enforces homogeneous (zero) Dirichlet BCs on one facet, and the second enforces no Dirichlet BCs.

In [None]:
def _leaky_relu_squared(z):
    """Return the square of the leaky ReLU of `z`."""
    return jax.nn.leaky_relu(z)**2


# NOTE: despite its name, `n_layers` is passed as the hidden width (`dmid`)
# when the networks are created.
n_layers = 16
# stax-style activation layer and weight initialiser.
acti = stax.elementwise(_leaky_relu_squared)
w_init = jax.nn.initializers.normal()

class PIMLP(nnx.Module):
  """Fully connected network with six linear layers.

  Every layer except the last is followed by a squared leaky-ReLU.  The
  activation of the second layer is added to the activation of the fifth
  layer (one skip connection) before the final linear layer is applied.

  Args:
    din: number of input features.
    dmid: width of the hidden layers.
    dout: number of output features.
    rngs: random number streams used to initialise the layers.
  """

  def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
    # The layers are created in the order linear1 ... linear6, so `rngs` is
    # consumed in that sequence.
    self.linear1 = nnx.Linear(in_features=din, out_features=dmid, rngs=rngs)
    self.linear2 = nnx.Linear(in_features=dmid, out_features=dmid, rngs=rngs)
    self.linear3 = nnx.Linear(in_features=dmid, out_features=dmid, rngs=rngs)
    self.linear4 = nnx.Linear(in_features=dmid, out_features=dmid, rngs=rngs)
    self.linear5 = nnx.Linear(in_features=dmid, out_features=dmid, rngs=rngs)
    self.linear6 = nnx.Linear(in_features=dmid, out_features=dout, rngs=rngs)

  @staticmethod
  def _activation(z: jax.Array) -> jax.Array:
    """Squared leaky ReLU."""
    return nnx.leaky_relu(z)**2

  def __call__(self, x: jax.Array):
    """Evaluate the network on `x` (features on the last axis)."""
    hidden = self._activation(self.linear1(x))
    skip = self._activation(self.linear2(hidden))
    hidden = skip
    for layer in (self.linear3, self.linear4, self.linear5):
      hidden = self._activation(layer(hidden))
    return self.linear6(hidden + skip)

def monom(x, axis, zero, one, alpha=1.0):
    """Return a ramp in coordinate `axis` of `x`, raised to the power `alpha`.

    The ramp is ``(x[..., axis] - zero) / (one - zero)``: it equals 0 where
    the coordinate is `zero` and 1 where it is `one`.  A trailing axis of
    length 1 is appended so that the result broadcasts against vector-valued
    fields.
    """
    shifted = x[..., axis] - zero
    span = one - zero
    return (shifted[..., None] / span) ** alpha
  
class MultipatchPIMLP(nnx.Module):
  """Multi-patch ansatz assembled from per-patch and per-interface networks.

  `self.nns` holds one `PIMLP` per entry of the module-level list `names`
  (input: 3 reference coordinates + 2 parameters), one per name in
  `names_nn_2d` (2 coordinates + 2 parameters) and one per name in
  `names_nn_1d` (1 coordinate + 2 parameters).  All networks have 3 outputs.

  On each patch the output is a sum of network outputs, each multiplied by
  ramp factors from `monom` that vanish on selected faces of the reference
  cube [-1, 1]^3.  Networks with a compound name such as "flat_spoke" are
  evaluated on both patches they mention; this is presumably what makes the
  field continuous across patch interfaces (not verifiable from this file
  alone).

  Args:
    dmid: hidden width of every `PIMLP`.
    rngs: random number streams used to initialise the networks.
  """

  def __init__(self, dmid: int, rngs: nnx.Rngs):
    names_nn_2d = ["flat_spoke", "flat_round_0", "flat_round_2", "round_0_round_1", "round_1_round_2"]
    names_nn_1d = ["spoke_round_0_flat", "spoke_round_2_flat"]
    # NOTE: depends on the module-level `names` (list of patch names).
    # Creation order (patch nets, 2-coordinate nets, 1-coordinate nets)
    # determines how `rngs` is consumed.
    # A KAN-based variant with layer_dims [5|4|3, 12, 12, 3] was tried here
    # previously.
    self.nns = {n: PIMLP(5, dmid, 3, rngs) for n in names}
    self.nns.update({n: PIMLP(4, dmid, 3, rngs) for n in names_nn_2d})
    self.nns.update({n: PIMLP(3, dmid, 3, rngs) for n in names_nn_1d})

  def __call__(self, domain_name: str, x: jax.Array, p: jax.Array):
    """Evaluate the ansatz on one patch.

    Args:
      domain_name: one of "spoke", "flat", "round_0", "round_1", "round_2".
      x: reference coordinates, last axis of length 3.
      p: parameters, concatenated to the coordinates along the last axis.

    Returns:
      The blended network output for the requested patch.

    Raises:
      ValueError: if `domain_name` is not one of the known patches
        (previously the call silently returned None).
    """
    nns = self.nns

    # Network inputs: all coordinates, coordinates (1, 2), coordinates (0, 2)
    # and coordinate 2 only, each followed by the parameters.
    vol = jnp.concat((x, p), axis=-1)
    f12 = jnp.concat((x[...,1:3], p), axis=-1)
    f02 = jnp.concat((x[...,0:3:2], p), axis=-1)
    f2 = jnp.concat((x[...,2][...,None], p), axis=-1)

    # Linear ramps: `up*` is 0 at -1 and 1 at +1, `down*` the opposite.
    up0 = monom(x, 0, -1, 1)
    down0 = monom(x, 0, 1, -1)
    up1 = monom(x, 1, -1, 1)
    down1 = monom(x, 1, 1, -1)

    if domain_name == "spoke":
      return nns["spoke"](vol)*up0*down0 \
           + nns["flat_spoke"](f12)*down0*down1*up1 \
           + nns["spoke_round_0_flat"](f2)*down0*up1 \
           + nns["spoke_round_2_flat"](f2)*down0*down1

    if domain_name == "flat":
      return nns["flat"](vol)*down0*down1*up1 \
           + nns["flat_spoke"](f12)*up0*down1*up1 \
           + nns["flat_round_0"](f02)*down0*up1 \
           + nns["flat_round_2"](f02)*down0*down1 \
           + nns["spoke_round_0_flat"](f2)*up0*up1 \
           + nns["spoke_round_2_flat"](f2)*up0*down1

    if domain_name == "round_0":
      return nns["round_0"](vol)*up1*down1 \
           + nns["round_0_round_1"](f02)*up1 \
           + nns["flat_round_0"](f02)*down0*down1 \
           + nns["spoke_round_0_flat"](f2)*up0*down1

    if domain_name == "round_1":
      return nns["round_1"](vol)*up1*down1 \
           + nns["round_0_round_1"](f02)*down1 \
           + nns["round_1_round_2"](f02)*up1

    if domain_name == "round_2":
      return nns["round_2"](vol)*up1*down1 \
           + nns["spoke_round_2_flat"](f2)*up0*up1 \
           + nns["round_1_round_2"](f02)*down1 \
           + nns["flat_round_2"](f02)*down0*up1

    raise ValueError(f"Unknown domain name: {domain_name!r}")
    
      




### PINNs

Define the PINN class, which must implement the loss function. In this case, the geometrically nonlinear formulation is used.

In [None]:
class Pinn(pinns.PINN):
    """Energy-based PINN for the holder with two penalised contact conditions.

    Relies on the module-level names `geoms`, `names` and `n_layers`.
    """

    def __init__(self):
        # Multi-patch network ansatz; `n_layers` is used as the hidden width.
        self.solutions = MultipatchPIMLP(n_layers, nnx.Rngs(0))

        # Material constants and the derived Lame parameters.
        young = 0.02e5
        poisson = 0.3
        self.E = young
        self.nu = poisson
        self.lamda = young*poisson/(1+poisson)/(1-2*poisson)
        self.mu = young/2/(1+poisson)

        # Body force; it is zero because rho is zero.
        rho = 0.0
        g = 9.81
        self.rho = rho
        self.f = np.array([0,-g*rho,0])

        # Strain-energy density 0.5*lamda*tr(E)^2 + mu*sum(E*E), where E is
        # the strain tensor passed as second argument (St. Venant-Kirchhoff
        # type).  `F`, `J` and `params` are accepted but not used.  Two
        # earlier candidates (a form weighted by `params` and a
        # log(J)-based one) used to be assigned here and were immediately
        # overwritten by this one.
        def strain_energy(F, E, J, params):
            return 0.5*self.lamda*(E[...,0,0]+E[...,1,1]+E[...,2,2])**2+self.mu*jnp.sum(E*E, axis=(-1,-2))

        self.energy = strain_energy

        # Coefficients handed to `self.energy` as `params`.
        self.a = 0.5*self.mu
        self.b = 0.0
        self.c = 0.0
        self.d = self.mu
        self.e = -1.5*self.mu

        # Penalty factor of the contact terms and the (A, b) pairs passed to
        # `gap_to_convex_polytope` for the two obstacles.
        self.kpen = 8e3
        self.Ab = np.array([[0.0,1.0,0.0]]), np.array([[-0.75]]), np.array([[0.0,-1.0,0.0]]), np.array([[-0.75]])
        super().__init__(geoms)

    def loss(self, model, points, parameters):
        """Return strain energy minus body-force work plus the contact penalty.

        Args:
            model: callable ``model(patch_name, x_ref, parameters)`` giving
                the displacement on a patch.
            points: mapping from patch names and the boundary labels 'ds0'
                and 'ds2' to quadrature point sets.
            parameters: one row per point; column 0 drives the obstacle
                acting on 'ds2', column 1 the one acting on 'ds0'.
        """
        coeffs = [self.a, self.b,self.c,self.d,self.e]
        identity = jnp.eye(3)

        energies = []
        works = []
        for name in names:
            pts = points[name]
            # Displacement gradient: reference coordinates first, then
            # mapped to physical coordinates by the point set.
            grad_ref = pinns.functions.jacobian(lambda x, p: model(name, x, p))(pts.points_reference, parameters)
            grad_phys = pts.jacobian_transformation(grad_ref)
            F = identity+grad_phys
            # 0.5*(F F^T - I)
            strain = 0.5*(jnp.einsum('mji,mki->mjk', F, F)-identity[None,...])
            J = jnp.linalg.det(F)

            energies.append(jnp.dot(self.energy(F, strain, J, coeffs), pts.dx()))
            works.append(jnp.dot(J * jnp.einsum('k,mk->m', self.f, model(name, pts.points_reference, parameters)), pts.dx()))

        def contact_penalty(A, b, drive, direction, label, patch):
            # Squared gap of the displaced boundary points, shifted by
            # drive*0.1/2 along `direction`, integrated with the boundary
            # weights.
            face = points[label]
            shift = jnp.einsum('i,j->ij', drive*0.1/2, np.array(direction))
            position = shift + face.points_physical+model(patch, face.points_reference, parameters)
            return jnp.dot(pinns.geometry.gap_to_convex_polytope(A, b, position)**2, face.weights)

        contact_res = contact_penalty(self.Ab[0], self.Ab[1], parameters[:,0], [0,1,0], 'ds2', 'round_2')
        contact_res += contact_penalty(self.Ab[2], self.Ab[3], parameters[:,1], [0,-1,0], 'ds0', 'round_0')

        return sum(energies) - sum(works) + self.kpen * contact_res

model = Pinn()


In [None]:
# Count the trainable parameters of every sub-network and of the whole model.
total_params = 0
for k, net in model.solutions.nns.items():
    params = nnx.state(net, nnx.Param)
    leaf_sizes = (np.prod(leaf.shape) for leaf in jax.tree.leaves(params))
    trainable_params = sum(leaf_sizes)
    total_params += trainable_params
    print(f"NN {k}, params {trainable_params}")
print(f"======\nTotal number of parameters {total_params}")

In [None]:
# Shape of the first-layer kernel of the "flat" patch network.
flat_params = nnx.state(model.solutions.nns["flat"], nnx.Param)
flat_params["linear1"]["kernel"].value.shape

### Training


In [None]:
opt_type = 'ADAM'

if opt_type == 'ADAM':

    n_batches = 100
    n_points = 4000000
    n_points_batch = n_points//n_batches
    learning_rate_init = 1e-2
    learning_rate_decay = 0.96
    # Constant learning rate for the first 20*n_batches optimizer steps,
    # then a staircase exponential decay that drops once every n_batches
    # steps.
    lr_schedule = optax.schedules.join_schedules(
        [optax.constant_schedule(learning_rate_init),
         optax.schedules.exponential_decay(learning_rate_init, transition_steps = n_batches, decay_rate=learning_rate_decay, staircase = True)],
        [20*n_batches])

    optimizer = nnx.Optimizer(model.solutions, optax.adam(lr_schedule))

    @nnx.jit
    def train_step(mdl, optimizer: nnx.Optimizer, key, step_no):
        """Run one optimisation step on a freshly sampled batch.

        Args:
            mdl: the network module being trained.
            optimizer: optimizer wrapping `mdl`.
            key: PRNG key that determines the batch.
            step_no: accepted for call compatibility; not used.

        Returns:
            The loss evaluated before the parameter update.
        """
        # Use independent keys for the quadrature points and for the load
        # parameters: drawing both from the same key can correlate the two
        # samples.
        key_points, key_params = jax.random.split(key)

        # Volume points on every patch plus boundary points labelled 'ds2'
        # (on round_2) and 'ds0' (on round_0) for the contact terms.
        points = model.points_MonteCarlo(n_points_batch, key_points, [{'patch': 'round_2', 'label': 'ds2', 'axis': 0, 'end': -1, 'n': n_points_batch}, {'patch': 'round_0', 'label': 'ds0', 'axis': 0, 'end': -1, 'n': n_points_batch}])
        # Two parameters per point, uniform in [-1, 1).
        parameters = jax.random.uniform(key_params, (n_points_batch, 2))*2-1
        grad_fn = nnx.value_and_grad(lambda m: model.loss(m, points, parameters), has_aux=False)
        loss, grads = grad_fn(mdl)
        optimizer.update(grads)

        return loss

    n_epochs = 100

    hist = []

    tme = datetime.datetime.now()
    for ep in range(n_epochs):
        losses = []
        # Visit the batch indices in random order.
        for b in random.sample(range(n_batches), n_batches):

            # The random term is multiplied by zero, so the key depends on
            # the batch index only and every epoch revisits the same
            # n_batches batches.
            loss = train_step(model.solutions, optimizer, jax.random.PRNGKey(b+0*np.random.randint(100000)), ep)

            hist.append(loss)
            losses.append(loss)

        print('Epoch %d/%d - mean loss %e, max loss %e, min loss %e, std loss %e'%(ep+1, n_epochs, np.mean(losses), np.max(losses), np.min(losses), np.std(losses)))

    tme = datetime.datetime.now() - tme
    print('Elapsed time ', tme)


In [None]:
# Loss history (one entry per training step) on a logarithmic y-axis.
plt.semilogy(hist)

### Plot 

Save the solution as a `.vtk` file.

In [None]:
def plot(p1, p2, fname):
    """Write the solution for the parameter pair (p1, p2) to `fname`.

    Every patch in `geoms` is sampled with the network output attached as
    the field 'displacement'; the per-patch meshes are merged and saved.
    """
    load = np.array([p1, p2])

    def displacement_on(patch):
        # The same parameter pair is used at every sample point.
        return lambda y: model.solutions(patch, y, np.zeros((y.shape[0], 2)) + load)

    pv_objects = []
    with jax.disable_jit():
        for patch in geoms:
            pv_objects.append(pinns.extras.plot(geoms[patch], {'displacement': displacement_on(patch)}, N= 25))

    merged = pv_objects[0]
    for piece in pv_objects[1:]:
        merged = merged.merge(piece)
    merged.save(fname)


# One file per corner of the parameter square [-1, 1]^2.
for first, second, out_name in (
    (-1.0, -1.0, "solution_holder_contacts.0.vtk"),
    (-1.0, 1.0, "solution_holder_contacts.1.vtk"),
    (1.0, -1.0, "solution_holder_contacts.2.vtk"),
    (1.0, 1.0, "solution_holder_contacts.3.vtk"),
):
    plot(first, second, out_name)

In [None]:
#for ep in range(n_epochs):
#    with jax.disable_jit():
#        pv_objects = [pinns.extras.plot(geoms[n], {'displacement': lambda y: model.solutions[n](hist_weights[ep], y, np.zeros((y.shape[0], 2)))}, N= 25) for n in geoms]
#
#    obj = pv_objects[0]
#    for i in range(1,len(pv_objects)):
#        obj = obj.merge(pv_objects[i])
#    obj.save('solution_holder_contacts_%d.vtk'%(ep+1))

In [None]:

# Named `plotter` rather than `plt` so that matplotlib's `plt` is not
# shadowed for the other cells.
plotter = pv.Plotter()

p1 = 0.0
p2 = 0.0


def update_mesh(param, value):
    """Slider callback: store the new parameter value and redraw the mesh.

    Args:
        param: 0 to update the first parameter, anything else for the second.
        value: new value of that parameter.
    """
    print(value)
    global p1, p2
    if param==0:
        p1 = value
    else:
        p2 = value

    load = np.array([p1, p2])
    # Evaluate the network the same way the file export does: through
    # `model.solutions(patch, y, parameters)` and with JIT disabled.
    with jax.disable_jit():
        pv_objects = [pinns.extras.plot(geoms[n], {'displacement': lambda y: model.solutions(n, y, np.zeros((y.shape[0], 2)) + load)}, N= 25) for n in geoms]
    obj = pv_objects[0]
    for piece in pv_objects[1:]:
        obj = obj.merge(piece)

    # Add the merged mesh (not the list of pieces).  Reusing one actor name
    # makes every redraw replace the previous mesh instead of stacking
    # another one on top of it.
    plotter.add_mesh(obj, name='holder')

    return


# Put the two sliders side by side; with the default end points they would
# be drawn on top of each other.
plotter.add_slider_widget(lambda v: update_mesh(0, v), [-1, 1], value=p1, title='Param 1', pointa=(0.05, 0.9), pointb=(0.45, 0.9))
plotter.add_slider_widget(lambda v: update_mesh(1, v), [-1, 1], value=p2, title='Param 2', pointa=(0.55, 0.9), pointb=(0.95, 0.9))
# Draw the initial state so that a mesh is visible before a slider is moved.
update_mesh(0, p1)
plotter.show()