Generating libraries


In [77]:
import jax
import numpy as np
import jax.random as random
import jax.numpy as jnp
import jax.numpy.fft as jfft
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from skimage import measure
from numpy import sqrt
from numpy import round
from matplotlib import pyplot as plt
from matplotlib import contour
from jax.numpy.fft import fft2, ifft2
from jax.numpy.fft import fftn, ifftn
from numpy import real
from jax.example_libraries.stax import serial, Gelu
from jax.example_libraries.optimizers import optimizer, make_schedule
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter
from jax.random import PRNGKey, uniform
from jax import grad

## Design the simple neural netweork

In [78]:
class SimpleNN:
    
    # define the initilization of the class
    
    def __init__(self, layers, key):
        self.layers = layers
        self.params = self.initialize_params(layers, key)


    

    def initialize_params(self, layers, key):
        keys = random.split(key, len(layers) - 1)
        params = []
        for i, k in enumerate(keys):
            w = random.normal(k, (layers[i], layers[i + 1])) * jnp.sqrt(2.0 / layers[i]) # javiar initilization
            b = jnp.zeros(layers[i + 1])
            params.append((w, b))
        return params


    

    def forward(self, params, x):
        for i, (w, b) in enumerate(params[:-1]):
            x = jnp.tanh(jnp.dot(x, w) + b)
        w, b = params[-1]
        out = jnp.dot(x, w) + b
        return out

##  Defining the physical loss function

In [79]:
def physical_loss(params, model, input_condition, Lx, Ly, Nx, Ny, dt, epsillon):

    u = model.forward(params, input_condition) # initial input data(real space input data)
     

    p = jnp.concatenate([2 * jnp.pi / Lx * jnp.arange(0, Nx//2), 2 * jnp.pi / Lx * jnp.arange(-Nx//2  , 0)]) # wavenumber in x direction
    q = jnp.concatenate([2 * jnp.pi / Ly * jnp.arange(0, Ny//2), 2 * jnp.pi / Ly * jnp.arange(-Ny//2 , 0)])

    p2 = p**2 # square of wavenumber in x direction
    q2 = q**2
    pp2, qq2 = jnp.meshgrid(p2, q2)

    cahn = epsillon**2
    s_hat = jfft.fft2(cahn * u - dt*(u**3 - 3* u))
    v_hat = s_hat / (cahn + dt * (2 + cahn * (pp2+ qq2)))

    u = jnp.fft.ifft2(v_hat) # v_hat gives the u_(k+1)_hat (on the fourier space) and this inverse to the real space and we get u_k+1

    return jnp.mean(u**2) # return the mean square error of the u_k+1



## Defining the training data 

In [80]:
def train_fourier(model, input_condition, params, key, Lx=2*jnp.pi, Ly=2*jnp.pi, Nx=128, Ny=128, epochs=100, lr=1e-3):
    opt_state = params
    losses = []

    for epoch in range(epochs):
        loss = physical_loss(opt_state, model, input_condition, Lx, Ly, Nx, Ny, dt=0.0001, epsillon=0.05)
        losses.append(loss) 

        grads = grad(lambda params: physical_loss(params, model, input_condition, Lx, Ly, Nx, Ny, dt=0.0001, epsillon=0.05))(opt_state) # calculate the gradient

        opt_state = [(w - lr * gw, b - lr * gb) for (w, b), (gw, gb) in zip(opt_state, grads)] # update the paramters value (weight value)

        if epoch % 50 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")
    
    return opt_state, losses

## Initilize the models, parameter and training setup

In [81]:
def init_model(key, layers, N=128):
    params = []
    for i in range(len(layers) - 1):
        w_key, b_key = jax.random.split(key)
        weight = jax.random.normal(w_key, (layers[i], layers[i + 1])) * jnp.sqrt(2. / layers[i])
        bias = jnp.zeros(layers[i + 1])
        params.append((weight, bias))
    return params


## Trianing the data 

In [82]:
if __name__ == "__main__":

    key = PRNGKey(42)

    layers = [2, 64, 64, 1]  # 2 inputs data  (x, y), 2 hidden layers of size 64, 1 output
    model = SimpleNN(layers,key)

    N = 128
    Lx = 2.0 * jnp.pi
    Ly = 2.0 * jnp.pi
    L  = 2.0 * jnp.pi
    
    hx = Lx / N
    hy = Ly / N

    x = jnp.linspace(-0.5*Lx +hx, 0.5*Lx, N)
    y = jnp.linspace(-0.5*Ly +hy, 0.5*Ly, N)

    xx, yy = jnp.meshgrid(x, y)

    epsillon = 0.05
    input_condition = jnp.tanh(( 2- jnp.sqrt(xx**2 + yy**2)) / jnp.sqrt(2 * epsillon))


    input_condition_reshaped = jnp.stack([xx.ravel(), yy.ravel()], axis=-1) # reshape the input condition size (flatted the two dimension into one dimension)

    params = model.params # initalize the parameters

    trained_params, losses = train_fourier(model, input_condition, params, key)


    # plot the input data

    plt.figure(figsize=(18, 6)) 
    plt.subplot(1, 3, 1)
    plt.title("Initial Condition (u_k)")
    plt.imshow(input_condition, cmap="viridis")
    plt.colorbar()

    # model outpou in real space
    model_output = model.forward(trained_params, input_condition_reshaped)  # Pass the trained parameters
    model_output_real = jfft.ifft2(model_output.reshape(N, N)).real  # Transform Fourier output back to real space
    plt.subplot(1, 3, 2)
    plt.title("Evolved condition (u_k+1)")
    plt.imshow(model_output_real, cmap="viridis")
    plt.colorbar()


    # Training loss
    plt.subplot(1, 3, 3)
    plt.title("Training Loss")
    plt.plot(losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()






TypeError: dot_general requires contracting dimensions to have the same shape, got (128,) and (2,).

3d of AC equation

In [58]:

# # defining parameter for 3d allen cahn

# Lx = 1.2  #length of the domain in x
# Ly = 1.2  #length of the domain in y
# Lz = 1.2  
# hx = Lx / Nx #spatial step size in coordinate x
# hy = Ly / Ny #spatial step size in coordinate y
# hz = Lz / Nz
# dt = 0.001 #time step size
# T = 0.5 #final time
# Nt = int(jnp.round(T/dt)) #number of time steps
# ns = Nt / 10 #number of snapshots


# # defining the function of (x,y,z) direction
# def x_gridpoint_3d(Nx, Lx, hx):
#     x = jnp.linspace(-0.5*Lx+hx,0.5*Lx,Nx)
#     return x
# x = x_gridpoint_3d(Nx, Lx, hx) #number of grid points in x direction and step size and limitation on x  axis
# # print(x.shape)

# def y_gridpoint_3d(Ny, Ly, hy):
#     y = jnp.linspace(-0.5*Ly+hy,0.5*Ly,Ny)
#     return y
# y = y_gridpoint_3d(Ny, Ly, hy) #number of grid points in y direction and step size and limitation on y  axis
# # print(y.shape) 

# def z_gridpoint_3d(Nz, Lz, hz):
#     z = jnp.linspace(-0.5*Lz+hz,0.5*Lz,Nz)
#     return z

# z = z_gridpoint_3d(Nz, Lz, hz) #number of grid points in z direction and step size and limitation on z  axis
# # print(z.shape)

# # creating the meshgrid in x, y, and z direction

# xx, yy, zz = jnp.meshgrid(x, y, z) #creating meshgrid in x, y, and z direction 
# # print(xx.shape , yy.shape, zz.shape)

# # defining the small parameter and cahn number

# epsillon = hx #small parameter
# cahn = epsillon**2 #cahn number 


# # initial condition of allen cahn equation in 3D
# u = np.random.rand(Nx, Ny, Nz)- 0.5 #initial condition of allen cahn equation

# # defining the wavenumber in x, y, and z direction
# kx = jnp.concatenate([2 * jnp.pi / Lx * jnp.arange(0, Nx//2), 2 * jnp.pi / Lx * jnp.arange(-Nx//2 , 0)]) # wavenumber in x direction
# ky = jnp.concatenate([2 * jnp.pi / Ly * jnp.arange(0, Ny//2), 2 * jnp.pi / Ly * jnp.arange(-Ny//2 , 0)]) # wavenumber in y direction
# kz = jnp.concatenate([2 * jnp.pi / Lz * jnp.arange(0, Nz//2), 2 * jnp.pi / Lz * jnp.arange(-Nz//2 , 0)]) # wavenumber in z direction

# # square of wavenumber in x, y, and z direction
# k2x = kx**2 # square of wavenumber in x direction
# k2y = ky**2 # square of wavenumber in y direction
# k2z = kz**2 # square of wavenumber in z direction
# print(k2x.shape)
# print(k2y.shape)
# print(k2z.shape)
# # creating meshgrid in x, y, and z direction for square of wavenumber
# kxx, kyy, kzz = jnp.meshgrid(k2x, k2y, k2z, indexing= 'ij')
# print(kxx.shape,kyy.shape,kzz.shape) # creating meshgrid in x, y, and z direction for square of wavenumber


# # Visualize the initial condition
# u_numpy = np.array(u)  # Convert JAX array to NumPy array


# #simulation loop for 3D Allen-Cahn equation
# for iter in range(1, Nt):
#     # Real part of the solution
#     u = jnp.real(u)
#     print(u)
   
    
#     # Fourier transform of the solution
#     s_hat = jfft.fftn(cahn * u - dt * (u**3 - 3 * u))
#     v_hat = s_hat / (cahn + dt * (2 + cahn * (kxx + kyy + kzz)))
#     u = jfft.ifftn(v_hat)  #inversre the fourier space to the real space

 
    
# #     if iter % ns == 0:
# #         u_numpy = np.array(u)  # Convert JAX array to numpy array
# #         min_val, max_val = np.min(u_numpy), np.max(u_numpy)
# #         print(f"Min value: {min_val}, Max value: {max_val}")
# #         if min_val < 0 < max_val:
# #             verts, faces, _, _ = measure.marching_cubes(u_numpy, level=0)
# #             p1 = Poly3DCollection(verts[faces], facecolor='g', edgecolor='none')
# #             fig = plt.figure()
# #             ax = fig.add_subplot(111, projection='3d')
# #             ax.add_collection3d(p1)
# #             ax.set_box_aspect([1,1,1])
# #             ax.view_init(elev=45, azim=45)
# #             ax.set_xlabel('X Direction')
# #             ax.set_ylabel('Y Direction')
# #             ax.set_zlabel('Z Direction')

# # plt.show()





            