In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
from utilfuncs  import computeFilter, implicit_Ke, implicit_Ae, stiffindex4AK,unknow_degree
import matplotlib.pyplot as plt

In [None]:
class ProgDef():
    """Thermo-elastic forward model plus a displacement-matching objective.

    The design variable is the per-element distribution of the coefficient
    of thermal expansion (CTE).  ``objective`` assembles the stiffness matrix
    ``K`` and the thermal load matrix ``A`` from the element routines
    ``implicit_Ke`` / ``implicit_Ae``, solves the reduced linear system for
    the displacement field and measures how far the displacements at the
    selected degrees of freedom are from their target values.

    Attributes set by ``__init__``:
        nelx, nely, nel, ndof : mesh sizes (elements per direction, number of
            elements, number of displacement DOFs = 2 per node).
        filterRadius, H, Hs   : filter radius and the filter data returned by
            ``computeFilter``.
        iK, jK, iA, jA        : assembly indices returned by ``stiffindex4AK``.
        E, nu                 : material constants used in the constitutive matrix.
        dof_free              : free DOFs returned by ``unknow_degree``.
        ut_obj                : (ndof,) target displacement, zero at DOFs that
            are not of interest.
        index_u_t             : (ndof,) selector, 1 at target DOFs and 0 elsewhere.
        u                     : last concrete displacement vector computed by
            ``objective`` (set on the first non-traced call).
    """

    def __init__(self,nelx,nely,index,value):
        """Set up mesh data, filter, material constants and the target.

        Args:
            nelx, nely: number of elements in the two mesh directions.
            index: DOF indices whose displacement is prescribed as target.
            value: target displacement values at ``index``.
        """
        self.nelx, self.nely = nelx, nely
        self.nel = self.nelx*self.nely
        self.ndof = 2*(self.nelx+1)*(self.nely+1)
        self.filterRadius = 3
        self.H, self.Hs = computeFilter(self.nelx,self.nely,self.filterRadius)
        self.iK, self.jK, self.iA, self.jA = stiffindex4AK(self.nelx,self.nely)
        self.E = 2
        self.nu = 0.3
        self.dof_free = unknow_degree(nelx,nely)

        # Target displacement, with the displacement at DOFs that are not of
        # interest left at zero.
        self.ut_obj = jnp.zeros((self.ndof,)).at[index].set(value)
        # Selector vector: 1 at the target DOFs, 0 everywhere else.
        self.index_u_t = jnp.zeros((self.ndof,)).at[index].set(1)

        # Compiled forward model; built lazily on the first objective() call.
        self._forward = None

    def _forward_impl(self,x):
        """Map the CTE distribution ``x`` to the flattened displacement vector."""
        nnode = (self.nelx+1)*(self.nely+1)

        # Constitutive matrix and the thermal stress vector it produces for a
        # unit isotropic thermal strain [1, 1, 0].
        C = (self.E/(1-self.nu**2))*jnp.array([[1,self.nu,0],[self.nu,1,0],[0,0,(1-self.nu)/2]])
        Ca = C.reshape(-1,order='F')
        da = C @ jnp.array([1,1,0]).reshape(-1,1)
        # One column per element: the thermal term scales linearly with x,
        # the constitutive matrix is the same for every element.
        da_ = jnp.dot(x.reshape(-1,1),da.T).T
        Ca_ = jnp.dot(jnp.ones((self.nel,1)),Ca.reshape(1,-1,order='F')).T

        # Element matrices and global assembly (duplicate indices accumulate).
        sK = implicit_Ke(Ca_)
        sA = implicit_Ae(da_)
        K = jnp.zeros((self.ndof,self.ndof)).at[(self.iK,self.jK)].add(sK.flatten('F'))
        A = jnp.zeros((self.ndof,nnode)).at[(self.iA,self.jA)].add(sA.flatten('F'))

        # Uniform nodal temperature load of 10 and the equivalent nodal forces.
        theta = jnp.ones((nnode,1))*10
        dof_free = self.dof_free
        Kr = K[dof_free,:][:,dof_free]      # reduced stiffness matrix
        f_r = (A@theta)[dof_free]           # reduced equivalent nodal force vector
        u_r = jnp.linalg.solve(Kr,f_r)      # reduced displacement vector

        # Scatter back into the full vector; constrained DOFs stay at zero.
        u = jnp.zeros((self.ndof,1)).at[dof_free].set(u_r)
        return u.reshape(-1,order='F')

    def objective(self,x):
        """Error between the computed and the target displacement.

        Args:
            x: CTE distribution, one value per element (reshaped to a column).

        Returns:
            (obj, u): ``obj`` is the Euclidean norm of the difference between
            the displacement at the target DOFs and ``ut_obj``; ``u`` is the
            full (ndof,) displacement vector.

        Notes:
            The forward model is jit-compiled once per instance and reused,
            instead of re-creating (and recompiling) the jitted helpers on
            every call.  Mesh data and material constants are captured at the
            first call; set ``self._forward = None`` after changing them.
        """
        forward = getattr(self,'_forward',None)
        if forward is None:
            forward = self._forward = jit(self._forward_impl)

        u = forward(x)
        obj =  jnp.sqrt(jnp.sum(jnp.power(jnp.abs((u*self.index_u_t)- self.ut_obj),2)))

        # Only remember concrete results: when objective() is being traced by
        # jax.grad / jax.jacobian, u is a tracer that must not outlive the trace.
        if not isinstance(u, jax.core.Tracer):
            self.u = u

        return obj,u

    def filt(self,filter_obj):
        """Apply the filter to ``filter_obj`` and return an (n, 1) column.

        The input is flattened, multiplied by ``H`` and divided by ``Hs``.
        """
        return (self.H @ filter_obj.reshape(-1,order = 'F') / self.Hs).reshape(-1,1,order='F')

In [None]:
# Mesh size: number of elements in each direction.
nelx = 20
nely = 20

# 1-based node numbers, one every (nely+1) nodes starting from node 1
# (named "upper" by the author: presumably the nodes of the upper edge).
node_upper = np.arange(1, (nely + 1) * nelx + 2, nely + 1)

# 0-based DOF numbers of those nodes: ux is the even DOF, uy the odd one.
dof_upper_ux = 2 * node_upper - 2
dof_upper_uy = 2 * node_upper - 1
dof_upper = np.concatenate((dof_upper_ux, dof_upper_uy), axis=0)

# Target displacements: uy grows linearly from 0 to 0.1, ux is zero.
value_upper_uy = np.linspace(0, 0.1, nelx + 1)
value_upper_ux = np.zeros(dof_upper_ux.shape)

# Only uy is prescribed as target here.  To prescribe ux and uy together use:
# defo_index = jnp.concatenate((dof_upper_ux,dof_upper_uy),axis=0);
# defo_value = jnp.concatenate((value_upper_ux,value_upper_uy),axis=0);
defo_index = dof_upper_uy
defo_value = value_upper_uy

# Problem object holding the mesh data, the filter and the target displacement.
opti = ProgDef(nelx, nely, defo_index, defo_value)
# obj_fun = jit(opti.objective);
# dobj_dx_fun = jax.jacfwd(obj_fun,argnums=0);

In [None]:
# import numpy as np
# nely = 20; nelx = 20;
# node_upper = np.arange(1,(nely+1)*nelx+1+1,nely+1);
# dof_upper = np.concatenate((node_upper*2-1,node_upper*2),axis=0)-1;
# dof_upper_uy = node_upper *2-1;
# defo_index = dof_upper_uy;
# aaa = unknow_degree(nelx,nely,defo_index)
# aaa.shape

In [None]:
from MMA import mmasub,subsolv
# Problem size handed to mmasub: m constraints, n design variables
# (one design variable per element).
m = 1
n = nelx * nely

# Initial guess: the same small value in every element.
x = np.full((n, 1), 0.0001)
xval = x

# Box constraints on the design variables.
# NOTE(review): zeros * -0.02 evaluates to all zeros, so the effective lower
# bound is 0 and not -0.02 -- confirm whether np.ones was intended here.
xmin = np.zeros((n,1))*-0.02
xmax = np.full((n, 1), 0.02)

# Design history of the two previous iterations, both starting at the guess.
xold1 = xval.copy()
xold2 = xval.copy()

# Asymptote arrays; they are overwritten with mmasub's output every iteration.
low = np.ones((n, 1))
upp = np.ones((n, 1))

# Remaining mmasub arguments (a0, a, c, d) and the move limit.
a0 = 1.0
a = np.zeros((m, 1))
c = np.full((m, 1), 10000.0)
d = np.zeros((m, 1))
move = 0.002

In [None]:
# %%  OPTIMIZATION
change = 1; loop = 0;
while (change>0.0000000001) and (loop<20):
    loop = loop+1;

    # OBJECTIVE FUNCTION AND SENSITIVITY ANALYSIS
    x = jnp.array(x).reshape(-1,1,order='F');
    obj,u =opti.objective(x);
    jacobian_fn = jax.jacobian(opti.objective, argnums=0);
    dobj_dx = jacobian_fn(x)[0];
    
    # FILTERING OF SENSITIVITIES
    dobj_dx = opti.filt(dobj_dx);

    # OPTIMIZATION
    mu0 = 1.0 # Scale factor for objective function
    mu1 = 1.0 # Scale factor for volume constraint function
    f0val = mu0*obj; # [1,1] objective function value
    df0dx = mu0*dobj_dx.reshape(-1,1,order='F') # gradient of the objective function
    df0dx = np.array(df0dx) # jnp to numpy, (n,1)
    fval = np.array([0])[np.newaxis]; # [1,1] constraint function value, not used here
    dfdx =np.zeros([1,n]); # [1,n] gradient of the constraint function, not used here
    xval = np.array(x).reshape(-1,1,order='F') ; # jnp to numpy, (n,1)
    xmma,ymma,zmma,lam,xsi,eta,mu,zet,s,low,upp = \
        mmasub(m,n,loop,xval,xmin,xmax,xold1,xold2,f0val,df0dx,fval,dfdx,low,upp,a0,a,c,d,move)
    xold2 = xold1.copy()
    xold1 = xval.copy()
    x = xmma.copy()
    x = opti.filt(x);
    change = np.max(np.abs(x-xold1))
    print('loop = ',loop,'change = ',change,'obj = ',obj)

    # PLOT, clear the plot and plot the new one
    %matplotlib qt
    plt.clf()
    plt.plot(value_upper_ux+np.arange(0,nelx+1),value_upper_uy,'o-')
    plt.plot( u[dof_upper_ux].reshape(-1)+np.arange(0,nelx+1),u[dof_upper_uy].reshape(-1),'k-o',linewidth=1,markersize=6)
    # set y axis
    # plt.ylim(0,0.012)
    # plt.xlim(0,nelx+1)
    plt.ylabel('Uy')
    plt.xlabel('Ux')
    plt.legend(['target','optimized'])
    plt.pause(0.01)

In [None]:
# %% PLOT
# display the deformation field x = value_upper_ux,y = value_upper_uy

import matplotlib.pyplot as plt
import numpy as np
%matplotlib qt
plt.figure()
plt.imshow(x.reshape(nely,nelx),cmap='jet')
plt.colorbar()
plt.show()




