<a href="https://colab.research.google.com/github/dasamitansu159/Physic-Informed-Neural-Networks/blob/main/Thin_fims_copy_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
# Film parameters.  NOTE(review): presumably d0/l0 are disjoining-pressure
# length scales, s a spreading coefficient, gama the surface tension and mu
# the viscosity — confirm against the model derivation.
d0 = 0.158
s = -0.015
l0 = 0.137
h0 = 1
gama = 0.0001
mu = 0.01


def phi_h(h):
  """Return 6 s d0^2 / h^4 * (1 - 3 (l0 / h)^6) at film height ``h``.

  This is the h-derivative of phi(h) = -2 s d0^2 / h^3 * (1 - (l0 / h)^6),
  the form used for phi in ``PINN.exp_a``.
  """
  prefactor = 6 * s * (d0**2 / h ** 4)
  repulsion = 1 - 3 * ((l0 / h) ** 6)
  return prefactor * repulsion


phi_h0 = phi_h(h0)
print(phi_h0)
# Wavelength 2*pi / k_c with k_c = sqrt(-phi_h0 / gama); real only when phi_h0 < 0.
_k_c = np.sqrt(-phi_h0 / gama)
lambda_n = 2 * np.pi / _k_c

In [None]:
import numpy as onp
import jax.numpy as np
from jax import random, grad, vmap, jit, jacfwd, jacrev
from jax.example_libraries import optimizers
from jax.nn import relu
from jax import lax
from jax.flatten_util import ravel_pytree
import itertools
from functools import partial
from torch.utils import data
from tqdm import trange

In [None]:
# Dewetting time estimate: 12 * mu * gama / (h0**3 * phi_h0**2) * ln(h0 / 0.001).
# NOTE(review): presumably the linear-stability time for a perturbation of
# amplitude 0.001 to grow to h0 — confirm.  When cells run in order, ``np``
# here is jax.numpy (rebound by the import cell above).
t_dwet = 12 * mu * gama / ((h0 ** 3) * (phi_h0 ** 2)) * np.log(h0/0.001)
print(t_dwet)

In [None]:
# Notebook display of lambda_n = 2*pi / sqrt(-phi_h0 / gama).
lambda_n

In [None]:
# Characteristic scales used to nondimensionalise x and t:
#   x_star = h0**2 / d0 * sqrt(gama / (6 |s|))
#   t_star = h0**5 / d0**4 * mu * gama / (12 s**2)
x_star = ((np.square(h0)) / d0) * np.sqrt(gama/ (6 * abs(s)))
t_star = ((np.power(h0, 5))/np.power(d0, 4)) * ((mu * gama) / (12 * np.square(s)))
print(x_star, t_star)
# Nondimensional spatial period (one wavelength lambda_n); used as the x-domain
# length by the data generator, boundary loss and initial condition.
domain = lambda_n / x_star
# Nondimensional dewetting time.  NOTE(review): the training cell in view
# samples t from [t0, t1] rather than [0, t_domain] — confirm intended.
t_domain = t_dwet / t_star

In [None]:
# Notebook display of the nondimensional spatial and temporal extents.
domain, t_domain

In [None]:
def MLP(layers, L=1.0, M=1, activation=np.tanh):
    """Build a plain fully connected network as an ``(init, apply)`` pair.

    Args:
        layers: Layer widths ``[d_in, h_1, ..., d_out]``.  The input encoding
            is just ``[t, x]``, so ``d_in`` is 2 when t and x are scalars.
        L, M: Accepted for signature compatibility; unused by this encoding.
        activation: Nonlinearity applied after every hidden layer (not after
            the output layer).

    Returns:
        ``init(rng_key) -> params``: list of ``(W, b)`` per layer, Glorot-normal
        weights and zero biases.
        ``apply(params, inputs) -> outputs``: ``inputs[0]`` is t and
        ``inputs[1]`` is x; for scalar t and x the output has shape ``(d_out,)``.
    """

    def input_encoding(t, x):
        # Identity encoding: concatenate the raw coordinates.
        return np.hstack([t, x])

    def init(rng_key):
        def init_layer(key, d_in, d_out):
            # Only the first subkey is consumed, matching the original split.
            w_key = random.split(key)[0]
            glorot_stddev = 1.0 / np.sqrt((d_in + d_out) / 2.)
            W = glorot_stddev * random.normal(w_key, (d_in, d_out))
            b = np.zeros(d_out)
            return W, b

        # len(layers) keys: the first is discarded, one per weight matrix remains.
        _, *keys = random.split(rng_key, len(layers))
        return list(map(init_layer, keys, layers[:-1], layers[1:]))

    def apply(params, inputs):
        H = input_encoding(inputs[0], inputs[1])
        for W, b in params[:-1]:
            H = activation(np.dot(H, W) + b)
        W, b = params[-1]
        return np.dot(H, W) + b

    return init, apply


In [None]:
# def modified_MLP(layers, L=1.0, M=1, activation = relu):
#   def xavier_init(key, d_in, d_out):
#     glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
#     W = glorot_stddev * random.normal(key, (d_in, d_out))
#     b = np.zeros(d_out)
#     return W, b

#   # Define input encoding function
#   def input_encoding(t, x):
#       w = 2 * np.pi / L
#       k = np.arange(1, M + 1)
#       out = np.hstack([t, 1,
#                          np.cos(k * w * x), np.sin(k * w * x)])
#       return out


#   def init(rng_key):
#       U1, b1 =  xavier_init(random.PRNGKey(12345), layers[0], layers[1])
#       U2, b2 =  xavier_init(random.PRNGKey(54321), layers[0], layers[1])
#       def init_layer(key, d_in, d_out):
#           k1, k2 = random.split(key)
#           W, b = xavier_init(k1, d_in, d_out)
#           return W, b
#       key, *keys = random.split(rng_key, len(layers))
#       params = list(map(init_layer, keys, layers[:-1], layers[1:]))
#       return (params, U1, b1, U2, b2)

#   def apply(params, inputs):
#       params, U1, b1, U2, b2 = params

#       t = inputs[0]
#       x = inputs[1]
#       inputs = input_encoding(t, x)
#       U = activation(np.dot(inputs, U1) + b1)
#       V = activation(np.dot(inputs, U2) + b2)
#       for W, b in params[:-1]:
#           outputs = activation(np.dot(inputs, W) + b)
#           inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V)
#       W, b = params[-1]
#       outputs = np.dot(inputs, W) + b
#       return outputs
#   return init, apply
def modified_MLP(layers, L=1.0, M=1, activation = relu):
    """Build a gated ("modified") MLP as an ``(init, apply)`` pair.

    Two encoder branches ``U = act(X U1 + b1)`` and ``V = act(X U2 + b2)`` are
    computed from the encoded input X = [t, x].  Every hidden layer then mixes
    them: ``H <- Z * U + (1 - Z) * V`` with ``Z = act(H W + b)``.

    Args:
        layers: ``[d_in, h, h, ..., d_out]``.  All hidden widths must equal
            ``layers[1]`` because each hidden output is multiplied by U and V.
        L, M: Accepted for signature compatibility; unused by this encoding.
        activation: Nonlinearity for the encoders and hidden layers.

    Returns:
        ``init(rng_key) -> (params, U1, b1, U2, b2)`` and
        ``apply(params, inputs) -> outputs`` where ``inputs[0]`` is t and
        ``inputs[1]`` is x.

    Raises:
        ValueError: if the hidden widths differ from ``layers[1]``.
    """
    hidden = layers[1:-1]
    if any(width != layers[1] for width in hidden):
        raise ValueError(
            f"modified_MLP requires equal hidden widths, got {list(hidden)}")

    def xavier_init(key, d_in, d_out):
        glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
        W = glorot_stddev * random.normal(key, (d_in, d_out))
        b = np.zeros(d_out)
        return W, b

    def input_encoding(t, x):
        # Identity encoding: concatenate the raw coordinates.
        return np.hstack([t, x])

    def init(rng_key):
        # NOTE: the gate encoders use fixed seeds, so they do not depend on
        # rng_key; kept as-is to reproduce previously trained models.
        U1, b1 = xavier_init(random.PRNGKey(12345), layers[0], layers[1])
        U2, b2 = xavier_init(random.PRNGKey(54321), layers[0], layers[1])

        def init_layer(key, d_in, d_out):
            # Only the first subkey is consumed, matching the original split.
            return xavier_init(random.split(key)[0], d_in, d_out)

        # len(layers) keys: the first is discarded, one per weight matrix remains.
        _, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return (params, U1, b1, U2, b2)

    def apply(params, inputs):
        params, U1, b1, U2, b2 = params
        H = input_encoding(inputs[0], inputs[1])
        U = activation(np.dot(H, U1) + b1)
        V = activation(np.dot(H, U2) + b2)
        for W, b in params[:-1]:
            Z = activation(np.dot(H, W) + b)
            H = np.multiply(Z, U) + np.multiply(1 - Z, V)
        W, b = params[-1]
        return np.dot(H, W) + b

    return init, apply

## Dataset generator

In [None]:
class DataGenertor(data.Dataset):
    """Endless stream of random collocation batches ``(t_r, x_r)``.

    Every ``dataset[i]`` call (the index is ignored) advances the internal PRNG
    key and returns a fresh batch:

    * ``t_r``: ``n_t`` uniform samples in ``[t0, t1)``, sorted ascending
      (causal weighting relies on this order);
    * ``x_r``: ``n_x`` uniform samples in ``[0, x_max)``.

    ``__getitem__`` never raises, so ``iter(dataset)`` never terminates.
    """

    def __init__(self, t0, t1, n_t = 10, n_x = 64, rng_key=random.PRNGKey(1234), x_max=None):
        """Store the sampling ranges and sizes.

        Args:
            t0, t1: Time interval.
            n_t, n_x: Number of time and space samples per batch.
            rng_key: Initial PRNG key.
            x_max: Upper end of the x interval.  ``None`` uses the module-level
                ``domain``, read when the sampler is first traced.
        """
        self.t0 = t0
        self.t1 = t1
        self.n_t = n_t
        self.n_x = n_x
        self.key = rng_key
        self.x_max = x_max

    def __getitem__(self, index):
        self.key, subkey = random.split(self.key)
        return self.__data_generation(subkey)

    # ``self`` is static: its attributes are baked in at trace time.
    @partial(jit, static_argnums=(0,))
    def __data_generation(self, key):
        x_max = domain if self.x_max is None else self.x_max
        t_key, x_key = random.split(key, 2)
        t_r = random.uniform(t_key, shape=(self.n_t,), minval=self.t0, maxval=self.t1).sort()
        x_r = random.uniform(x_key, shape=(self.n_x,), minval=0, maxval=x_max)
        return (t_r, x_r)


# Correctly spelled alias; the original name is kept for existing callers.
DataGenerator = DataGenertor

## PINN model

In [None]:
class PINN:
    """Causally weighted physics-informed network for a 1-D thin-film equation.

    The network h(t, x) is trained so that the residual computed by
    ``residual_net`` vanishes:

        h_t + d/dx[ (h**3 / h0**3) * d/dx( h_xx - |phi(h) / phi_star| ) ] = 0,

    together with two further conditions:

    * periodicity ``h(t, 0) = h(t, domain)``;
    * the initial condition ``h(0, x) = 1 + 0.001 cos(2 pi x / domain)``.

    NOTE(review): ``s``, ``d0``, ``l0``, ``h0``, ``domain``, ``modified_MLP`` and
    ``trange`` are module-level names defined elsewhere in the notebook.
    """

    def __init__(self, key, layers, M_x, t0, t1, n_t, n_x, tol=1.0):
        """Build the network, the Adam optimiser and the vmapped evaluators.

        Args:
            key: PRNG key for ``modified_MLP``'s ``init``.
            layers: Layer widths passed to ``modified_MLP``.
            M_x: Passed to ``modified_MLP`` as ``M``.
            t0, t1, n_x: Accepted for interface compatibility; unused here.
            n_t: Time samples per batch; sizes the causal matrix ``M``.
            tol: Causality tolerance in ``w = exp(-tol * cumulative loss)``.
        """
        self.tol = tol
        # Strictly lower-triangular ones: (M @ L)[i] = sum_{j < i} L[j], the
        # accumulated residual of all earlier (sorted) time slices.
        self.M = np.triu(np.ones((n_t, n_t)), k=1).T

        self.init, self.apply = modified_MLP(layers, L = 2.0, M = M_x, activation=np.tanh)
        params = self.init(rng_key = key)

        self.opt_init, self.opt_update, self.get_params = optimizers.adam(
            optimizers.exponential_decay(1e-3, decay_steps=10000, decay_rate=0.9))
        self.opt_state = self.opt_init(params)
        _, self.unravel = ravel_pytree(params)

        # Tensor-grid evaluators f(params, t, x) -> shape (len(x), len(t)):
        # rows index x, columns index t.
        self.h_pred_fn = vmap(vmap(self.neural_net, (None, 0, None)), (None, None, 0))
        self.r_pred_fn = vmap(vmap(self.residual_net, (None, 0, None)), (None, None, 0))

        # Appended to every 1000 iterations by ``train``.
        self.loss_log = []
        self.loss_res_log = []

        self.itercount = itertools.count()

    def neural_net(self, params, t, x):
        """Return the predicted film height h(t, x), clamped below at 0.137."""
        z = np.stack([t, x])
        outputs = self.apply(params, z)
        h = outputs[0]
        # 0.137 equals l0; values below it are clipped (zero gradient there).
        return np.maximum(h, 0.137)

    def h_x(self, params, t, x):
        """Return dh/dx at (t, x)."""
        return np.sum(grad(self.neural_net, argnums=2)(params, t, x))

    def h_xx(self, params, t, x):
        """Return d2h/dx2 at (t, x)."""
        return np.sum(grad(self.h_x, argnums=2)(params, t, x))

    def h_xxx(self, params, t, x):
        """Return d3h/dx3 at (t, x)."""
        return np.sum(grad(self.h_xx, argnums=2)(params, t, x))

    def exp_a(self, params, t, x):
        """Return h_xx - |phi(h) / phi_star| at (t, x)."""
        h = (self.neural_net(params, t, x))
        phi = -2 * s * (np.square(d0)/ h **3) * (1 - ((np.power(l0, 6)/ h ** 6)))
        phi_star = 6 * h0 * np.square((d0/h0)) * abs(s)
        # NOTE(review): ``abs`` discards the sign of phi; confirm this is intended.
        phi_dim = abs(phi / phi_star)
        return np.sum(grad(self.h_x, argnums=2)(params, t, x)  - phi_dim)

    def exp(self, params, t, x):
        """Return the flux (h**3 / h0**3) * d/dx[exp_a] at (t, x)."""
        h = (self.neural_net(params, t, x))
        a_x = grad(self.exp_a, argnums=2)(params, t, x)
        return np.sum((h ** 3 / np.power(h0, 3)) * a_x)

    def residual_net(self, params, t, x):
        """Return the PDE residual h_t + d/dx[exp] at (t, x)."""
        h_t = grad(self.neural_net, argnums=1)(params, t, x)
        fin_x = grad(self.exp, argnums=2)(params, t, x)
        return h_t + fin_x

    @partial(jit, static_argnums=(0,))
    def residuals_and_weights(self, params, batch, tol):
        """Return the per-time residual loss ``L_t`` and its causal weights.

        ``L_t[j]`` is the mean squared residual over all x samples at time
        ``t_r[j]``.  ``w[j] = exp(-tol * sum_{i<j} L_t[i])`` (gradient-stopped)
        down-weights later times until earlier ones are resolved.  ``t_r`` must be
        sorted ascending.
        """
        t_r, x_r = batch
        r_pred = self.r_pred_fn(params, t_r, x_r)  # shape (n_x, n_t)
        # Average over x (axis 0) so the losses are indexed by time and match M.
        L_t = np.mean(r_pred ** 2, axis=0)
        w = lax.stop_gradient(np.exp(-tol * (self.M @ L_t)))
        return L_t, w

    @partial(jit, static_argnums=(0,))
    def boundary_loss(self, params, batch):
        """Return mean((h(t, 0) - h(t, domain))**2) over the batch times."""
        t, _ = batch
        h_along_t = vmap(self.neural_net, (None, 0, None))
        left = h_along_t(params, t, 0.0)
        right = h_along_t(params, t, domain)
        return np.mean((left - right) ** 2)

    def initial_condition(self, params, t, x):
        """Return the target h(0, x) = 1 + 0.001 cos(2 pi x / domain)."""
        return 1 + 0.001 * np.cos(2 * np.pi * x / domain)

    @partial(jit, static_argnums=(0,))
    def ics_loss(self, params, batch):
        """Return mean((h(0, x) - initial_condition(x))**2) over the batch x samples."""
        _, x = batch
        h_at_t0 = vmap(self.neural_net, (None, None, 0))(params, 0.0, x)
        target = self.initial_condition(params, 0.0, x)
        return np.mean((h_at_t0 - target) ** 2)

    @partial(jit, static_argnums=(0,))
    def loss(self, params, batch, tol=None):
        """Return the weighted residual loss plus the boundary and initial losses.

        ``tol`` defaults to ``self.tol``.  Pass it explicitly for a value that
        may change: ``self`` is static under jit, so a value read from it is
        frozen at the first trace.
        """
        if tol is None:
            tol = self.tol
        L_t, w = self.residuals_and_weights(params, batch, tol)
        loss_bc = self.boundary_loss(params, batch)
        loss_ics = self.ics_loss(params, batch)
        return np.mean(L_t * w) + loss_bc + loss_ics

    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch, tol=None):
        """Apply one Adam update at iteration ``i`` and return the new state."""
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, batch, tol)
        return self.opt_update(i, g, opt_state)

    def train(self, dataset, nIter=10000):
        """Run ``nIter`` optimisation steps on batches drawn from ``dataset``.

        Every 1000 iterations, the total loss and the mean unweighted residual
        loss are appended to ``loss_log`` and ``loss_res_log``.  The progress
        bar also shows the losses.
        """
        res_data = iter(dataset)
        pbar = trange(nIter)

        for it in pbar:
            batch = next(res_data)

            self.current_count = next(self.itercount)
            # self.tol is passed as a traced argument so updates to it take effect.
            self.opt_state = self.step(self.current_count, self.opt_state, batch, self.tol)

            if it % 1000 == 0:
                params = self.get_params(self.opt_state)

                loss_value = self.loss(params, batch, self.tol)
                boundary_loss = self.boundary_loss(params, batch)
                ics_loss = self.ics_loss(params, batch)
                L_t, _ = self.residuals_and_weights(params, batch, self.tol)

                self.loss_log.append(loss_value)
                self.loss_res_log.append(np.mean(L_t))
                pbar.set_postfix({"Loss": loss_value,
                                  "Loss_ics": ics_loss,
                                  "Loss_boundary": boundary_loss})



In [None]:
# PRNG key used to initialise every model built below.
key = random.PRNGKey(1234)

# Passed to PINN as M_x (forwarded to modified_MLP as M).
M_x = 5
# Time interval sampled by DataGenertor.
t0 = 0
t1 = 1.5
# Collocation samples per batch in t and x; n_t also sizes PINN's causal matrix.
n_t = 150
n_x = 150

# NOTE(review): ``d`` is not used in the visible cells — confirm before removing.
d = 2 * M_x + 2
# Causality tolerances to train with, in sequence.
tol_list = [1e-3]
# First width 2 matches the [t, x] input encoding; last width 1 is h.
layers = [2, 128, 128, 128, 1]
dataset = DataGenertor(t0, t1, n_t, n_x)

N = 1
for k in range(N):
  # NOTE(review): the printed "Final Time" scales with k, but the dataset's
  # time interval [t0, t1] does not.
  print("Final Time: {}".format((k+1) * t1))
  model = PINN(key, layers, M_x, t0, t1, n_t, n_x)

  for tol in tol_list:
    model.tol = tol
    print("tol: ", model.tol)

    model.train(dataset, nIter = 50000)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

def plot_u_2d(model, x_range=(0, t1), y_range=(0, domain), n_points=1000):
    """Draw a filled contour plot of the predicted film height h(t, x).

    Despite the parameter names, ``x_range`` is the time interval (horizontal
    axis) and ``y_range`` the spatial interval (vertical axis).

    The default time interval is ``[0, t1]``, the range DataGenertor samples t
    from during training.  ``t1`` is used directly by the sampler, so dividing
    it by ``t_star`` (as before) extrapolated beyond the trained interval.

    Args:
        model: A trained ``PINN``.
        x_range: ``(t_min, t_max)`` for the horizontal axis.
        y_range: ``(x_min, x_max)`` for the vertical axis.
        n_points: Grid resolution along each axis.
    """
    t_values = np.linspace(x_range[0], x_range[1], n_points)
    x_values = np.linspace(y_range[0], y_range[1], n_points)

    params = model.get_params(model.opt_state)

    # h_pred_fn returns shape (len(x), len(t)): rows follow the vertical axis
    # and columns the horizontal axis, which is the layout contourf expects.
    H = model.h_pred_fn(params, t_values, x_values).reshape(n_points, n_points)

    plt.figure(figsize=(8, 6))
    plt.contourf(t_values, x_values, H, levels=50, cmap='viridis')
    plt.colorbar(label='Predicted h')
    plt.xlabel('t')
    plt.ylabel('x')
    plt.title('Predicted h over 2D grid')
    plt.grid(True)
    plt.show()

# Call the plot function
# Contour plot of the trained model's h(t, x) over the default ranges.
plot_u_2d(model)


In [None]:
def plot_u_vs_y_at_x(model, x_values=(0.5,), y_min=0, y_max=domain, num_points=100):
    """Plot the predicted film profile h(t, x) against x at several fixed times.

    Despite the parameter names, each entry of ``x_values`` is a time t.  The
    horizontal axis (``y_min`` .. ``y_max``, ``num_points`` samples) is x.
    The figure is saved as ``"Thin_fim"`` (matplotlib's default savefig
    format) and then shown.

    Args:
        model: A trained ``PINN`` (uses ``get_params``, ``opt_state`` and
            ``neural_net``).
        x_values: Iterable of times at which to slice the solution.
        y_min, y_max: Spatial interval.
        num_points: Number of x samples.
    """
    params = model.get_params(model.opt_state)
    y_values = np.linspace(y_min, y_max, num_points)

    # Evaluate h only at the paired points (t_k, x_k).  Building the full
    # num_points x num_points grid and taking its diagonal gives the same
    # values at num_points times the cost.
    h_at_pairs = vmap(model.neural_net, (None, 0, 0))

    plt.figure(figsize=(10, 6))
    for x_value in x_values:
        t_column = np.full_like(y_values, x_value)
        u_values = h_at_pairs(params, t_column, y_values)
        plt.plot(y_values, u_values, label=f'h(t={x_value}, x)')

    plt.xlabel('x')
    plt.ylabel('h')
    plt.title('Relationship between h and x at different t values')
    plt.legend()
    plt.grid(True)
    plt.savefig("Thin_fim")
    plt.show()


# Plot h versus x at several times t (the ``x_values`` argument holds times).
plot_u_vs_y_at_x(model, x_values=[0.1, 0.4, 0.8, 1.0, 1.2, 1.4, 1.5])