Imports:

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers
import matplotlib.pyplot as plt
import pyvista as pv
import pinns 
import datetime
import jax.scipy.optimize
import jax.flatten_util
import scipy
import scipy.optimize
import random

from pygccx import model as ccx_model
from pygccx import model_keywords as mk
from pygccx import step_keywords as sk
from pygccx import enums

# Fixed seeds for the JAX PRNG key and for NumPy's global random state.
JAX_SEED = 1234
NUMPY_SEED = 14124

rnd_key = jax.random.PRNGKey(JAX_SEED)
np.random.seed(NUMPY_SEED)


Set the default floating-point precision (64-bit mode disabled, so arrays default to float32) and select the execution device.

In [None]:
# Keep 64-bit mode disabled so JAX arrays default to single precision.
jax.config.update("jax_enable_x64", False)


def select_device():
    """Return the first GPU device if JAX has a GPU backend, else the first CPU device.

    ``jax.devices('gpu')`` raises ``RuntimeError`` when no GPU backend is
    available; that case falls back to the CPU instead of crashing. A single
    GPU is sufficient to select it (the total device count is irrelevant).
    """
    try:
        gpus = jax.devices('gpu')
    except RuntimeError:
        gpus = []
    return gpus[0] if gpus else jax.devices('cpu')[0]


dev = select_device()
print(dev)

### Geometry definition 

Define the geometry patches:

In [None]:
def get_domain(r0: float, r1: float, R: float, h: float, H: float):
    """Build the NURBS patches describing the holder geometry.

    Five 3D patches are created, all sharing the same three B-spline bases
    (quadratic with one interior knot, quadratic, linear):

    * ``'flat'``   - a 90 degree ring sector between radii ``r0`` and ``r1``
                     whose outer face is made straight (flat).
    * ``'spoke'``  - starts at the flat outer face of ``'flat'`` and ends at a
                     copy of that face with x scaled by ``R/r1`` and y scaled by
                     ``H/h``; interior control rows are blended between the two.
    * ``'round_0'``, ``'round_1'``, ``'round_2'`` - the same 90 degree ring
                     sector, passed to ``rotate((0, 0, k*pi/2))`` for k = 1, 2, 3.

    Args:
        r0: inner radius of the ring sectors.
        r1: outer radius of the ring sectors.
        R: scale of the spoke tip in x (its x coordinates are multiplied by R/r1).
        h: extent of the ring in z (z runs from -h/2 to h/2).
        H: scale of the spoke tip in y (its y coordinates are multiplied by H/h).

    Returns:
        dict mapping patch name -> ``pinns.geometry.PatchNURBS``, in the order
        'flat', 'spoke', 'round_0', 'round_1', 'round_2'.
    """
    basis1 = pinns.functions.BSplineBasisJAX(np.array([-1, 0, 1]), 2)
    basis2 = pinns.functions.BSplineBasisJAX(np.array([-1, 1]), 2)
    basis3 = pinns.functions.BSplineBasisJAX(np.array([-1, 1]), 1)

    def make_patch(points, wts):
        # A fresh bases list per patch, as each patch was built with its own list.
        return pinns.geometry.PatchNURBS([basis1, basis2, basis3], points, wts, 0, 3)

    def ring_sector(angle, r_inner, r_outer):
        """Control points and weights of a ring sector spanning ``angle`` radians.

        Control-point array shape is (4, 3, 2, 3): 4 rows with radii
        linspace(r_inner, r_outer, 4), 3 points along a rational quadratic arc
        centred on the x axis, and 2 layers at z = -1 and z = +1.
        """
        half = angle / 2
        a = np.pi / 2 - half
        # Unit-radius arc control polygon; the middle point sits at 1/sin(a)
        # and gets weight sin(a) so the arc is exactly circular.
        unit_arc = np.array([
            [np.cos(-half), np.sin(-half)],
            [1 / np.sin(a), 0.0],
            [np.cos(half), np.sin(half)],
        ])
        radii = np.linspace(r_inner, r_outer, 4)

        points = np.zeros([4, 3, 2, 3])
        points[:, :, 0, :2] = radii[:, None, None] * unit_arc[None, :, :]
        points[:, :, 0, 2] = -1
        points[:, :, 1, :] = points[:, :, 0, :]
        points[:, :, 1, 2] = 1

        wts = np.ones([4, 3, 2])
        wts[:, 1, :] = np.sin(a)
        return points, wts

    geoms = {}

    # --- 'flat': ring sector whose outer face is straightened -------------
    pts, weights = ring_sector(np.pi / 2, r0, r1)
    pts[..., 2] *= h / 2
    # Pull the outer middle control point onto the chord (and give it unit
    # weight), then place rows 1 and 2 at 1/3 and 2/3 between row 0 and row 3.
    pts[3, 1, :, 0] = pts[3, 0, :, 0]
    pts[1, 1, :, 0] = 2 * pts[0, 1, :, 0] / 3 + pts[-1, 1, :, 0] / 3
    pts[2, 1, :, 0] = pts[0, 1, :, 0] / 3 + 2 * pts[-1, 1, :, 0] / 3
    weights[-1, 1, :] = 1.0
    # Copies: pts/weights keep being edited below to build the spoke.
    geoms['flat'] = make_patch(pts.copy(), weights.copy())

    # --- 'spoke': extruded from the flat outer face ------------------------
    root = pts[-1].copy()
    weights[...] = 1.0
    t = np.linspace(0, 1, basis1.n)

    pts[0] = root
    pts[-1] = root
    pts[-1, :, :, 0] *= R / r1
    pts[-1, :, :, 1] *= H / h

    for i in range(1, basis1.n - 1):
        ti = t[i]
        start, end = pts[0], pts[-1]
        # Each coordinate uses its own blending exponent between the two ends.
        pts[i, :, :, 0] = (1 - ti) * start[:, :, 0] + ti * end[:, :, 0]
        pts[i, :, :, 1] = (1 - ti**4) * start[:, :, 1] + ti**4 * end[:, :, 1]
        pts[i, :, :, 2] = (1 - ti**0.25) * start[:, :, 2] + ti**0.25 * end[:, :, 2]
        # Narrow the cross-section (y and z); the factor is 0.5 at t = 1/2.
        pts[i, :, :, 1:] *= 2 * (ti - 1 / 2)**2 + 0.5

    geoms['spoke'] = make_patch(pts, weights)

    # --- 'round_*': plain ring sectors rotated by k * 90 degrees ----------
    for k, name in enumerate(('round_0', 'round_1', 'round_2'), start=1):
        pts, weights = ring_sector(np.pi / 2, r0, r1)
        pts[..., 2] *= h / 2
        geoms[name] = make_patch(pts, weights)
        geoms[name].rotate((0, 0, k * np.pi / 2))

    return geoms


# Instantiate the holder patches and remember their names (insertion order).
geoms = get_domain(r0=0.4, r1=0.8, R=3.0, h=1.0, H=1.5)
names = [*geoms]

Determine the connectivity of the patches:

In [None]:
# Match patch faces with JIT disabled; eps is presumably the geometric
# matching tolerance used by pinns - confirm against match_patches.
with jax.disable_jit():
    connectivity = pinns.geometry.match_patches(geoms, eps=1e-4, verbose=False)

for interface in connectivity:
    print(interface)

In [None]:
# Face key (axis, end) -> index used to restrict a patch to that face. The
# index is passed to the pinns patch's __getitem__; the -1 / 1 entries
# presumably select the lower / upper parametric bound (confirm in pinns).
# Order matters: it fixes the order in which gmsh points/surfaces get tags.
_FACE_SLICES = (
    ((2, 0), np.s_[:, :, -1]),
    ((2, -1), np.s_[:, :, 1]),
    ((1, 0), np.s_[:, -1, :]),
    ((1, -1), np.s_[:, 1, :]),
    ((0, 0), np.s_[-1, :, :]),
    ((0, -1), np.s_[1, :, :]),
)


def _bspline_surface_from_points(gmsh, points, n: int, meshsize: float) -> int:
    """Add ``points`` as gmsh OCC points and build a degree-2 B-spline surface on them.

    ``points`` is a sequence of n*n (x, y, z) rows; they are used as control
    points with ``n`` points along the first (fastest-varying) direction.
    Returns the tag of the new surface.
    """
    point_tags = [
        gmsh.model.occ.addPoint(p[0], p[1], p[2], tag=-1, meshSize=meshsize)
        for p in points
    ]
    return gmsh.model.occ.addBSplineSurface(point_tags, n, -1, 2, 2)


def construct_inp(geoms: dict, E: float, mu: float, n: int, meshsize: float):
    """Build a gmsh volume mesh of the holder from its NURBS patches.

    Each of the six faces of every patch is sampled on an n x n parameter grid
    over [-1, 1]^2. The samples become control points of a B-spline surface.
    The exterior surfaces are assembled into one surface loop and volume. The
    mesh is generated and written to ``test.geo_unrolled`` and ``holder.msh``
    in the working directory.

    Args:
        geoms: patches keyed by name; must contain 'flat', 'spoke', 'round_0',
            'round_1' and 'round_2'.
        E: currently unused (presumably Young's modulus for a future CalculiX
            material definition).
        mu: currently unused (presumably Poisson's ratio, see ``E``).
        n: number of samples per face direction.
        meshsize: target element size used for points and the size callback.

    Physical groups created: 'HOLDER' (the volume) and 'FIX' (the spoke face
    with key ('spoke', 0, -1)).
    """
    # Same parameter grid for every face: computed once.
    grid_u, grid_v = np.meshgrid(np.linspace(-1, 1, n), np.linspace(-1, 1, n))
    face_params = np.concatenate((grid_u.reshape([-1, 1]), grid_v.reshape([-1, 1])), -1)

    with ccx_model.Model("", "", jobname='holder', working_dir="./") as model:
        gmsh = model.get_gmsh()

        surfaces = {}
        for name in geoms:
            for (axis, end), face in _FACE_SLICES:
                # Convert once to host floats instead of indexing the JAX array
                # element by element (3 * n * n tiny device ops per face).
                positions = np.asarray(geoms[name][face](face_params)).tolist()
                surfaces[(name, axis, end)] = _bspline_surface_from_points(
                    gmsh, positions, n, meshsize)

        gmsh.model.occ.remove_all_duplicates()

        # Faces bounding the holder; the omitted faces are presumably the
        # interfaces between adjacent patches (compare with match_patches).
        loop_keys = [('spoke', 0, -1), ('spoke', 1, 0), ('spoke', 1, -1),
                     ('spoke', 2, -1), ('spoke', 2, 0),
                     ('flat', 0, 0), ('flat', 2, 0), ('flat', 2, -1)]
        for round_name in ('round_0', 'round_1', 'round_2'):
            loop_keys += [(round_name, 0, 0), (round_name, 0, -1),
                          (round_name, 2, 0), (round_name, 2, -1)]

        surface_loop = gmsh.model.occ.add_surface_loop([surfaces[key] for key in loop_keys])
        tag_volume = gmsh.model.occ.add_volume([surface_loop])

        gmsh.model.occ.synchronize()

        gmsh.model.add_physical_group(3, [tag_volume], name='HOLDER')
        gmsh.model.add_physical_group(2, [surfaces[('spoke', 0, -1)]], name='FIX')

        # Uniform element size everywhere.
        gmsh.model.mesh.set_size_callback(lambda dim, tag, x, y, z, lc: meshsize)
        gmsh.model.mesh.generate(3)

        gmsh.write("test.geo_unrolled")
        gmsh.write('holder.msh')
            
    

In [None]:
# Generate the holder mesh with JIT disabled (patch evaluation runs eagerly).
with jax.disable_jit():
    construct_inp(geoms, E=2000, mu=0.0, n=16, meshsize=0.1)