In [76]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
from utilfuncs  import computeFilter, implicit_Ke, implicit_Ae, stiffindex4AK
import matplotlib.pyplot as plt

In [77]:
# INITIALIZATION

nelx, nely = 15, 6
nel = nelx*nely;
ndof = (nelx+1)*(nely+1)*2;
filterRadius = 3
H, Hs = computeFilter(nelx,nely,filterRadius)

In [78]:
# DESIGN TARGET DEFORMATION

#case1: the upper u_x =0 ; u_y is with a parabolic distribution
ut_obj= jnp.zeros((ndof,));
# upper bounary deformation  freedom
node_upper = jnp.arange(1,(nely+1)*nelx+1,nelx+1);
dof_upper = jnp.concatenate((node_upper*2-1,node_upper*2),axis=0)-1;
dof_upper_uy = dof_upper[1::2];
value_upper_uy = jnp.linspace(0,0.1,nely+1)**2;
ut_obj = ut_obj.at[dof_upper_uy].set(value_upper_uy);# target deformation freedom, with deformation at non-interested freedom set to zero
index_u_t = jnp.zeros((ndof,));
index_u_t = index_u_t.at[dof_upper_uy].set(1);# identity matrix of size (node*2,1) with 1 in the target deformation freedom; 


In [79]:
# %% FEM SETUP AND SOLVE

# MATERIAL 
alpha = jnp.ones(((nelx)*(nely),1)); # design variable
alpha = jnp.maximum(jnp.minimum(alpha, 1.), 0.) # filter smoothing the design variable
E = 2; # constant Young's modulus field(not design variable at this moment)
nu = 0.1; # constant Poisson's ratio field(not design variable at this moment)
# Ca = jnp.zeros((9,1));
Ca  =  (E/(1-nu**2))* jnp.array([[1,nu,0],[nu,1,0],[0,0,(1-nu)/2]]).reshape(-1,order='F');
da = Ca.reshape(3,3,order='F') @jnp.array([1,1,0]).reshape(-1,1)
# dalpha = jnp.dot(jnp.power(alpha,self.q),da.T).T # power order interpolation
da_ = jnp.dot(alpha.reshape(-1,1),da.T).T; # linear interpolation
Ca_ = jnp.dot(jnp.ones((nel,1)),Ca.reshape(1,-1,order='F')).T # power order interpolation
sA = implicit_Ae(da_); # element area matrix
sK = implicit_Ke(Ca_); # element stiffness matrix

# ASSEMBEL STIFFNESS MATRIX
A = jnp.zeros((ndof,(nelx+1)*(nely+1)));
K = jnp.zeros((ndof,ndof));
iK, jK, iA,jA = stiffindex4AK(nelx,nely)
A = A.at[(iA,jA)].add(sA.flatten('F'))  ; 
K = K.at[(iK,jK)].add(sK.flatten('F')); # global stiffness matrix

# DEFINE LOADS AND SUPPORTS (HALF MBB-BEAM)
u = jnp.zeros((ndof,1)); # deformation field
theta = jnp.ones(((nelx+1)*(nely+1),1)); # temperater loading

fix_left = jnp.arange(0,(nely+1)*2,2); # left boundary
fix_right = jnp.arange((nely+1)*2-1,(nely+1)*2*(nelx+1),2); # right boundary
dof_fix = jnp.concatenate((fix_left,fix_right),axis=0); # fixed left and right boundary
dof_free = jnp.setdiff1d(jnp.arange(0,ndof),dof_fix); # free freedom

# SOLVE 
Kr = K[dof_free,:][:,dof_free]; # reduced stiffness matrix


f_t = A@theta # equivalent nodal force vector
f_r = f_t[dof_free]; # reduced equivalent nodal force vector
u_t = jnp.linalg.solve(Kr,f_r); # reduced deformation vector
u = u.at[dof_free].set(u_t); # deformation vector


In [82]:
# %% OBJECTIVE FUNCTION AND SENSITIVITY ANALYSIS
# objective function: \sum(|u*index_u_t - ut_obj|**2)
obj = jnp.sum(jnp.power((u*index_u_t - ut_obj),2));
# sensitivity analysis of the objective function 'obj' to the design variable  'alpha'
obj 

Array(73350.055, dtype=float32)

In [89]:
import jax
import jax.numpy as jnp

def my_function(x):
    a = jnp.sin(x)
    y = jnp.stack([jnp.sum(a ** 2), jnp.sum(a)])
    return y, a

x = jnp.array([1., 2., 3.])
y, a = my_function(x)

# Compute Jacobian with respect to x
jacobian_fn = jax.jacobian(my_function, argnums=0)
grad_x = jacobian_fn(x)

# Compute Jacobian with respect to a
grad_a = jacobian_fn(x)[1]

print("y =", y)
print("a =", a)
print("grad_x =", grad_x)
print("grad_a =", grad_a)


y = [1.55481   1.8918884]
a = [0.84147096 0.9092974  0.14112   ]
grad_x = (Array([[ 0.90929735, -0.7568025 , -0.2794155 ],
       [ 0.5403023 , -0.41614684, -0.9899925 ]], dtype=float32), Array([[ 0.5403023 , -0.        , -0.        ],
       [ 0.        , -0.41614684, -0.        ],
       [ 0.        , -0.        , -0.9899925 ]], dtype=float32))
grad_a = [[ 0.5403023  -0.         -0.        ]
 [ 0.         -0.41614684 -0.        ]
 [ 0.         -0.         -0.9899925 ]]


In [None]:
# OBTAIN TARGET DEFORAMTION AND OBJECTIVE FUNCTION

u.shape
# theta.shape

In [None]:
# ASSEMBLE STIFFNESS MATRICIES K AND THERMAL STRESS MATRIX A

In [None]:
# SOLVE THE LINEAR SYSTEM Ku = f
u_solve = jnp.linalg.solve(K,f)

In [None]:
# OBJECTIVE FUNCTION
obj =  jnp.sum(((u_solve- ut_obj)*index_u_t)**2)
# OPTIMIZATION