## Allen cahn PINN FOR 2d

In [59]:
import jax.numpy as jnp
from jax import grad, jit, random
import matplotlib.pyplot as plt


In [60]:
# Define the neural network
class SimpleNN:
    def __init__(self, layers, key):
        self.layers = layers
        self.params = self.initialize_params(layers, key)

    def initialize_params(self, layers, key):
        keys = random.split(key, len(layers) - 1)
        params = []
        for i, k in enumerate(keys):
            w = random.normal(k, (layers[i], layers[i + 1])) * jnp.sqrt(2.0 / layers[i])
            b = jnp.zeros(layers[i + 1])
            params.append((w, b))
        return params

    def forward(self, params, x):
        for i, (w, b) in enumerate(params[:-1]):
            x = jnp.tanh(jnp.dot(x, w) + b)
        w, b = params[-1]
        out = jnp.dot(x, w) + b
        return out[0]

In [61]:
@jit
def allen_cahn_pde(x, y, grad_y, hessian_y):
    dy_t = grad_y[2]  # Partial derivative with respect to time
    dy_xx = hessian_y[0, 0]  # Second partial derivative with respect to x
    dy_yy = hessian_y[1, 1]  # Second partial derivative with respect to y
    return dy_t - 0.001 * (dy_xx + dy_yy) - 5 * (x - x**3)

# Define the physics loss
def physics_loss(params, model, x,N):
    y = model.forward(params, x)
   
    y_hat = jnp.fft.fft(y)
  
    k = jnp.fft.fftfreq(N, d = 1) * 2 * jnp.pi
    grad_y_hat = 1j * k * y_hat
    grad_yy_hat = -k ** 2 * y_hat
    non_linear_term = 5 * (y -y**3)
    pde_residual_hat = grad_y_hat - 0.001 * grad_yy_hat - non_linear_term
    pde_residual = jnp.fft.ifft(pde_residual_hat).real
    # # print(f"y shape: {y.shape}")
    # grad_y = grad(lambda x: model.forward(params, x))(x)
    # hessian_y = jnp.array([
    #     [grad(lambda x: grad(lambda x: model.forward(params, x))(x)[i])(x) for i in range(x.shape[0])]
    #     for _ in range(x.shape[0])
    # ])
    # pde_residual = allen_cahn_pde(x, y, grad_y, hessian_y)
    # # print(f"pde_residual shape: {pde_residual.shape}")
    return jnp.mean(pde_residual**2)

In [62]:

# Training loop
def train(model, params, key, losses_record, epoches_record, epochs=50, lr=1e-3 , N=128):
    opt_state = params
    for epoch in range(epochs):
        x = jnp.linspace(0, 1, N)
        # print(x.shape)
        loss = physics_loss(opt_state, model,x,N )  # (x, y, t)
        losses_record.append(loss)
        epoches_record.append(epoch)
        grads = grad(lambda params: physics_loss(params, model, x, N))(opt_state)
        opt_state = [(w - lr * gw, b - lr * gb) for (w, b), (gw, gb) in zip(opt_state, grads)]
        if epoch % 5 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")
    return opt_state

In [63]:


losses_record = []
epoches_record = []
# Main program
if __name__ == "__main__":
    layers = [128, 20, 20, 1]  # Input (x, y, t), hidden layers, output
    key = random.PRNGKey(0)
 
    model = SimpleNN(layers, key)
    # print_model_architecture(model)
    
    params = model.params
    trained_params = train(model, params, key, losses_record, epoches_record)


plt.plot(epoches_record, losses_record)
plt.xlabel("Epoch")
plt.legend(["Training Loss"])

ValueError: axis -1 is out of bounds for array of dimension 0