Imports:

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers
import matplotlib.pyplot as plt
import pyvista as pv
import pinns 
import datetime
import jax.scipy.optimize
import jax.flatten_util
import scipy
import scipy.optimize
import random

# Base JAX PRNG key (fixed seed).
rnd_key = jax.random.PRNGKey(1234)
# Seed NumPy's global random number generator.
np.random.seed(14124)

Set the default precision and the execution device.

In [None]:
jax.config.update("jax_enable_x64", False)
# print("GPU devices: ", jax.devices('gpu'))
dev = jax.devices('gpu')[0] if jax.device_count()>1 and len(jax.devices('gpu'))>0 else jax.devices('cpu')[0]
print(dev)

### Geometry definition 

Define the geometry patches:

In [None]:
def get_domain(r0: float, r1: float, R: float, h: float, H: float):
    """Build the NURBS patches that make up the holder geometry.

    Args:
        r0: inner radius of the ring sectors.
        r1: outer radius of the ring sectors.
        R: the x-coordinates of the far face of the 'spoke' patch are those of
            the 'flat' patch's outer face scaled by ``R/r1``.
        h: extent of the ring sectors along z (control points at ``-h/2`` and
            ``h/2``).
        H: the y-coordinates of the far face of the 'spoke' patch are scaled
            by ``H/h``.

    Returns:
        A dict with the keys 'flat', 'spoke', 'round_0', 'round_1' and
        'round_2' (inserted in this order), each mapping to a
        ``pinns.geometry.PatchNURBS``.
    """
    # One basis per parametric direction of the control grid (4 x 3 x 2 points).
    # NOTE(review): arguments are presumably (knots, degree) -- confirm in pinns.
    radial_basis = pinns.functions.BSplineBasisJAX(np.array([-1, 0, 1]), 2)
    angular_basis = pinns.functions.BSplineBasisJAX(np.array([-1, 1]), 2)
    axial_basis = pinns.functions.BSplineBasisJAX(np.array([-1, 1]), 1)

    def ring_sector(angle, r_0, r_1):
        """Control points and weights of a ring sector between r_0 and r_1.

        The sector is centred on the x axis, spans ``angle`` and has its two
        z-layers at -1 and +1.
        """
        half = angle/2
        a = np.pi/2 - half

        # Three points describing the arc at unit radius (start, middle
        # control point, end).
        unit_arc = np.array([
            [np.cos(-half), np.sin(-half)],
            [1/np.sin(a), 0],
            [np.cos(half), np.sin(half)],
        ])
        radii = np.linspace(r_0, r_1, 4)

        ctrl = np.zeros([4, 3, 2, 3])
        # Bottom layer: scale the unit arc by each of the four radii.
        ctrl[:, :, 0, :2] = radii[:, None, None]*unit_arc
        ctrl[:, :, 0, 2] = -1
        # Top layer: same x/y, mirrored z.
        ctrl[:, :, 1, :] = ctrl[:, :, 0, :]
        ctrl[:, :, 1, 2] = 1

        w = np.ones([4, 3, 2])
        w[:, 1, :] = np.sin(a)

        return ctrl, w

    def make_patch(ctrl, w):
        # The trailing (0, 3) arguments are forwarded unchanged to PatchNURBS.
        return pinns.geometry.PatchNURBS(
            [radial_basis, angular_basis, axial_basis], ctrl, w, 0, 3)

    geoms = dict()

    # 'flat': a ring sector whose outer face is flattened.
    ctrl, w = ring_sector(np.pi/2, r0, r1)
    ctrl[..., 2] *= h/2
    ctrl[3, 1, :, 0] = ctrl[3, 0, :, 0]
    ctrl[1, 1, :, 0] = 2*ctrl[0, 1, :, 0]/3+ctrl[-1, 1, :, 0]/3
    ctrl[2, 1, :, 0] = ctrl[0, 1, :, 0]/3+2*ctrl[-1, 1, :, 0]/3
    w[-1, 1, :] = 1.0
    geoms['flat'] = make_patch(ctrl.copy(), w.copy())

    # 'spoke': starts on the outer face of 'flat' and ends on a scaled copy
    # of that face; the arrays of 'flat' are reused in place.
    w[...] = 1.0
    ctrl[0] = ctrl[-1]
    ctrl[-1, :, :, 0] *= R/r1
    ctrl[-1, :, :, 1] *= H/h

    n_radial = radial_basis.n
    t = np.linspace(0, 1, n_radial)
    for i in range(1, n_radial-1):
        s = t[i]
        # Each coordinate is blended between the two end faces with its own
        # exponent on the parameter.
        ctrl[i, :, :, 2] = (1-s**0.25)*ctrl[0, :, :, 2] + \
            s**0.25*ctrl[-1, :, :, 2]
        ctrl[i, :, :, 0] = (1-s)*ctrl[0, :, :, 0] + \
            s*ctrl[-1, :, :, 0]
        ctrl[i, :, :, 1] = (1-s**4)*ctrl[0, :, :, 1] + \
            s**4*ctrl[-1, :, :, 1]
        # Additional scaling of the interior cross-sections in y and z.
        ctrl[i, :, :, 1] *= 2*(s-1/2)**2+0.5
        ctrl[i, :, :, 2] *= 2*(s-1/2)**2+0.5

    geoms['spoke'] = make_patch(ctrl, w)

    # 'round_k': the plain ring sector rotated by (k+1)*pi/2 (third entry of
    # the tuple passed to rotate).
    for k in range(3):
        ctrl, w = ring_sector(np.pi/2, r0, r1)
        ctrl[..., 2] *= h/2
        sector = make_patch(ctrl, w)
        sector.rotate((0, 0, (k+1)*np.pi/2))
        geoms['round_%d' % k] = sector

    return geoms


geoms = get_domain(0.4, 0.8, 3.0, 1.0, 1.5)
# Patch names in insertion order.
names = list(geoms)

Export the geometry as a VTK file for visualization in ParaView.

In [None]:
# Build one plot object per patch (no extra fields) with JIT disabled.
with jax.disable_jit():
    objects = [pinns.extras.plot(patch, {}, N=32) for patch in geoms.values()]

# Merge all patches into a single object and write it to disk.
obj = objects[0]
for extra in objects[1:]:
    obj = obj.merge(extra)

obj.save('testing.vtk')

Determine the connectivity of the patches:

In [None]:
# Match the patches against each other (tolerance 1e-4) with JIT disabled.
with jax.disable_jit():
    connectivity = pinns.geometry.match_patches(
        geoms, eps=1e-4, verbose=False)

# Print every detected connection on its own line.
for link in connectivity:
    print(link)

In [None]:
import pickle

# Write each patch to its own 'holder_<name>.geom' file.
for k, patch in geoms.items():
    pinns.geometry.save_patch(f'holder_{k}.geom', patch)

# Store the connectivity information alongside the patches.
with open('connectivity_holder.pkl', 'wb') as file:
    pickle.dump(connectivity, file)

### ANN spaces definition

The network is an MLP with residual connections; its width is set by the `nl` parameter.
Two spaces are defined: the first enforces homogeneous (zero) Dirichlet BCs on one facet, while the second enforces no Dirichlet BCs.

In [None]:
# Width of the hidden layers.
nl = 20


def _squared_leaky_relu(x):
    """Activation: leaky ReLU followed by squaring."""
    return jax.nn.leaky_relu(x)**2


acti = stax.elementwise(_squared_leaky_relu)
w_init = jax.nn.initializers.normal()


def _residual_block():
    """Two Dense+activation layers summed with a parallel Dense branch."""
    main_branch = stax.serial(
        stax.Dense(nl, W_init=w_init), acti,
        stax.Dense(nl, W_init=w_init), acti,
    )
    skip_branch = stax.Dense(nl, W_init=w_init)
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(main_branch, skip_branch),
        stax.FanInSum,
    )


block_first = _residual_block()
block = _residual_block()
# Two residual blocks followed by a 3-component output layer.
nn = stax.serial(block_first, block, stax.Dense(3))

# Same network wrapped in a Dirichlet mask (dim 0, end -1) ...
masked_nn = pinns.DirichletMask(
    nn, 3, [(-1, 1), (-1, 1), (-1, 1)], [{'dim': 0, 'end': -1}])
space_bc = pinns.FunctionSpaceNN(masked_nn, ((-1, 1), (-1, 1), (-1, 1)))
# ... and without any mask.
space = pinns.FunctionSpaceNN(nn, ((-1, 1), (-1, 1), (-1, 1)))

### PINNs

Define the PINN class; its `loss` method has to be implemented. In this case, the nonlinear geometry is used.

In [None]:
class Pinn(pinns.PINN):
    """PINN for the multi-patch holder.

    The loss is the sum over all patches of the integrated strain energy,
    minus the integrated body-force term, plus a penalty on two contact
    constraints evaluated on the 'ds2' and 'ds0' point sets.
    """

    def __init__(self):
        """Set up per-patch weights, coupled solutions and model constants."""
        # The 'spoke' patch uses the Dirichlet-masked space, every other patch
        # the unconstrained one.  All networks are initialised from rnd_key.
        patch_spaces = {}
        self.weights = {}
        for n in names:
            patch_space = space_bc if 'spoke' in n else space
            patch_spaces[n] = patch_space
            self.weights[n] = patch_space.init_weights(rnd_key)
        self.solutions = pinns.connectivity_to_interfaces(
            patch_spaces, connectivity)

        # Material constants and the Lame parameters derived from them.
        self.E = 0.02e5
        self.nu = 0.3
        self.lamda = self.E*self.nu/(1+self.nu)/(1-2*self.nu)
        self.mu = self.E/2/(1+self.nu)

        # Body force: gravity along -y scaled by the density.
        g = 9.81
        self.rho = 0.2
        self.f = np.array([0, -g*self.rho, 0])

        def strain_energy(F, E, J, params):
            # 0.5*lambda*tr(E)^2 + mu*E:E ; F, J and params are not used.
            trace_E = E[..., 0, 0]+E[..., 1, 1]+E[..., 2, 2]
            return 0.5*self.lamda*trace_E**2+self.mu*jnp.sum(E*E, axis=(-1, -2))

        self.energy = strain_energy

        # Coefficients handed to self.energy as `params`.
        self.a = 0.5*self.mu
        self.b = 0.0
        self.c = 0.0
        self.d = self.mu
        self.e = -1.5*self.mu

        # Contact penalty factor and the two (A, b) pairs of the obstacles.
        self.kpen = 1e3
        self.Ab = (
            np.array([[0.0, 1.0, 0.0]]), np.array([[-0.7]]),
            np.array([[0.0, -1.0, 0.0]]), np.array([[-0.7]]),
        )
        super().__init__(geoms)

    def loss(self, training_parameters, points):
        """Evaluate the loss for the given network parameters.

        Args:
            training_parameters: parameters passed to every patch solution.
            points: mapping from patch names (and the labels 'ds2', 'ds0')
                to the point sets used for integration.

        Returns:
            Sum of patch energies minus sum of body-force terms plus
            ``kpen`` times the contact residual.
        """
        identity = jnp.eye(3)
        energy_params = [self.a, self.b, self.c, self.d, self.e]

        Es = []
        rhss = []
        for name in names:
            u = self.solutions[name]
            pts = points[name]

            # Displacement gradient: reference coordinates -> transformed.
            grad_ref = pinns.functions.jacobian(
                lambda x: u(training_parameters, x))(pts.points_reference)
            grad_x = pts.jacobian_transformation(grad_ref)

            # Deformation gradient, strain 0.5*(F^T F - I) and det(F).
            F = identity+grad_x
            strain = 0.5*(jnp.einsum('mij,mik->mjk', F, F)-identity[None, ...])
            J = jnp.linalg.det(F)

            dx = pts.dx()
            Es.append(jnp.dot(self.energy(F, strain, J, energy_params), dx))

            # Body-force term f.u weighted by det(F).
            f_dot_u = jnp.einsum(
                'k,mk->m', self.f, u(training_parameters, pts.points_reference))
            rhss.append(jnp.dot(J*f_dot_u, dx))

        def contact_term(label, patch, A, b):
            # Gap of the displaced points w.r.t. the polytope (A, b),
            # integrated with the point-set weights.
            pts = points[label]
            displaced = pts.points_physical + \
                self.solutions[patch](training_parameters, pts.points_reference)
            gap = pinns.geometry.gap_to_convex_polytope(A, b, displaced)
            return jnp.dot(gap, pts.weights)

        contact_res = contact_term('ds2', 'round_2', self.Ab[0], self.Ab[1]) \
            + contact_term('ds0', 'round_0', self.Ab[2], self.Ab[3])

        return sum(Es) - sum(rhss) + self.kpen * contact_res


model = Pinn()

### Training


In [None]:
opt_type = 'ADAM'

if opt_type == 'ADAM':

    n_batches = 500
    n_points = 1000000
    n_points_batch = n_points//n_batches

    # Step size 0.005 until optimizer step 7000, 0.001 afterwards.
    lr_opti = optimizers.piecewise_constant([7000], [0.01/2, 0.001])
    opt_init, opt_update, get_params = optimizers.adam(lr_opti)

    opt_state = opt_init(model.weights)
    weights_init = model.weights

    # get initial parameters
    params = get_params(opt_state)

    # NOTE(review): the `device` argument of jax.jit is deprecated in recent
    # JAX releases -- confirm against the installed version.
    loss_grad = jax.jit(jax.value_and_grad(model.loss), device=dev)

    def step(params, opt_state, key, it=0):
        """Perform one Adam update on a fresh Monte Carlo batch.

        Args:
            params: current network parameters.
            opt_state: current optimizer state.
            key: JAX PRNG key used to draw the batch points.
            it: index of the optimizer step. It drives the step-size schedule
                and Adam's bias correction, so it must grow by one per call.

        Returns:
            Tuple (new params, new optimizer state, loss before the update).
        """
        points = model.points_MonteCarlo(n_points_batch, key, [
            {'patch': 'round_2', 'label': 'ds2', 'axis': 0, 'end': -1, 'n': n_points_batch},
            {'patch': 'round_0', 'label': 'ds0', 'axis': 0, 'end': -1, 'n': n_points_batch},
        ])
        loss, grads = jax.value_and_grad(model.loss)(params, points)
        opt_state = opt_update(it, grads, opt_state)

        params = get_params(opt_state)

        return params, opt_state, loss

    step_compiled = jax.jit(step, device=dev)
    # Warm-up call to trigger compilation; its result is discarded.
    step_compiled(params, opt_state, rnd_key, 0)

    n_epochs = 20

    hist = []
    hist_weights = []

    # Running count of optimizer updates across all epochs.
    opt_step = 0

    tme = datetime.datetime.now()
    for ep in range(n_epochs):

        losses = []
        # Visit the batch indices in a random order every epoch.
        for b in random.sample(range(n_batches), n_batches):

            # The random offset is multiplied by 0, so the key depends only on
            # the batch index b.
            batch_key = jax.random.PRNGKey(b+0*np.random.randint(1000))
            params, opt_state, loss = step_compiled(
                params, opt_state, batch_key, opt_step)
            opt_step += 1
            hist.append(loss)
            losses.append(loss)

        # Keep a snapshot of the parameters at the end of every epoch.
        hist_weights.append(params.copy())
        print('Epoch %d/%d - mean loss %e, max loss %e, min loss %e, std loss %e'%(ep+1, n_epochs, np.mean(losses), np.max(losses), np.min(losses), np.std(losses)))

    # update params
    model.weights = params
    weights = params
    tme = datetime.datetime.now() - tme
    print('Elapsed time ', tme)


### Plot 

Save the solution as a `.vtk` file.

In [None]:
def _final_displacement(patch_name):
    """Return the displacement field of one patch for the trained weights."""
    return lambda y: model.solutions[patch_name](weights, y)


# Plot every patch together with its displacement field, JIT disabled.
with jax.disable_jit():
    pv_objects = [
        pinns.extras.plot(
            geoms[n], {'displacement': _final_displacement(n)}, N=25)
        for n in geoms
    ]

# Merge the patches into one object and save it.
obj = pv_objects[0]
for piece in pv_objects[1:]:
    obj = obj.merge(piece)
obj.save('solution_holder_contacts.vtk')

In [None]:
def _displacement_at_epoch(patch_name, epoch):
    """Return the displacement field of one patch for the weights of `epoch`."""
    return lambda y: model.solutions[patch_name](hist_weights[epoch], y)


# Write one file per epoch using the weight snapshots in hist_weights.
for ep in range(n_epochs):
    with jax.disable_jit():
        pv_objects = [
            pinns.extras.plot(
                geoms[n], {'displacement': _displacement_at_epoch(n, ep)}, N=25)
            for n in geoms
        ]

    obj = pv_objects[0]
    for piece in pv_objects[1:]:
        obj = obj.merge(piece)
    obj.save('solution_holder_contacts_%d.vtk'%(ep+1))