# Burgers' Equation Solution using PINNs

In [1]:
# %pip install torch
# %pip install matplotlib
# %pip install tqdm


In [2]:
import torch
from torch import nn
from torch import optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import math
import numpy as np

import matplotlib.pyplot as plt
import tqdm
from IPython.display import clear_output


### Neural Network structure

In [3]:
class PINN(nn.Module):
  """Fully connected physics-informed network for the viscous Burgers' equation.

  The network maps a batch of points ``state`` whose column 0 is ``x`` and
  column 1 is ``t`` to one value ``u(x, t)`` per row. Besides the forward
  pass it provides the loss terms used for training: the initial condition,
  the boundary condition and the PDE residual.
  """

  def __init__(self, fc_list, viscosity):
    """Build the layer stack.

    Args:
      fc_list: layer widths; first entry is the input size, last entry is
        the output size.
      viscosity: coefficient multiplying ``u_xx`` in the PDE residual.

    Raises:
      ValueError: if ``fc_list`` has fewer than two entries (no layer could
        be built and ``forward`` would fail later).
    """
    super().__init__()

    if len(fc_list) < 2:
      raise ValueError("fc_list must contain at least an input and an output size")

    # for some reason Tanh is the only one that works
    self.activation = nn.Tanh()

    # viscosity coefficient
    self.l = viscosity

    self.loss_function = nn.MSELoss()

    self.fc_list = fc_list # first, last = input_size, output_size
    self.network = nn.ModuleList(
      [nn.Linear(n_in, n_out) for n_in, n_out in zip(fc_list[:-1], fc_list[1:])]
    )

    # works better than nothing or nn.init.xavier_normal_
    for layer in self.network:
      nn.init.kaiming_normal_(layer.weight.data)

    # Detached value of the most recent PDE loss; set by error_pde().
    self.last_pde_error = None

  def forward(self, x):
    """Return the network output for a batch of ``(x, t)`` rows."""
    output = x

    # Every layer except the last one is followed by the activation.
    for layer in self.network[:-1]:
      output = self.activation(layer(output))

    # even though the result is between -1 and 1 it looks like it's better
    # not to apply Tanh on the output layer
    return self.network[-1](output)

  # t = 0 -> F(x, t) = sin(-pi * x)
  def error_initial(self, state):
    """MSE between the prediction and ``sin(-pi * x)`` for the given points."""
    target = torch.sin(-math.pi * state[:, [0]]).detach()
    return self.loss_function(self.forward(state), target)

  # x = ±1 -> F(x, t) = 0
  # also F(0, t) = 0 but it's more fair not to include x=0 here
  def error_border(self, state):
    """MSE between the prediction and zero for the given points."""
    prediction = self.forward(state)
    # zeros_like keeps the target on the prediction's device and dtype.
    return self.loss_function(prediction, torch.zeros_like(prediction))

  # ∂t*u + u*∂x*u - λ*∂xx*u tells how far is model from truth
  def error_pde(self, state):
    """MSE of the residual ``u_t + u * u_x - viscosity * u_xx``.

    The loss (detached from the graph) is also stored in
    ``self.last_pde_error`` for logging.
    """
    # Fresh leaf tensor so derivatives with respect to (x, t) can be taken.
    state_ = state.clone().detach().requires_grad_(True)

    u = self.forward(state_)

    # First derivatives: column 0 is du/dx, column 1 is du/dt.
    first = torch.autograd.grad(
      u, state_,
      torch.ones_like(u),
      create_graph=True
    )[0]
    u_x = first[:, [0]]
    u_t = first[:, [1]]

    # Differentiate u_x alone. Differentiating the whole first-derivative
    # tensor with an all-ones weight would sum the contributions of u_x and
    # u_t and yield u_xx + u_tx instead of u_xx.
    u_xx = torch.autograd.grad(
      u_x, state_,
      torch.ones_like(u_x),
      create_graph=True
    )[0][:, [0]]

    residue = u_t + u * u_x - self.l * u_xx
    loss = self.loss_function(residue, torch.zeros_like(residue))
    # Keep only the value for reporting so the graph is not held alive.
    self.last_pde_error = loss.detach()
    return loss

  def error_total(self, state_beginning, state_border, state_general):
    """Sum of the initial-condition, boundary and PDE losses."""
    return (
      self.error_initial(state_beginning)
      + self.error_border(state_border)
      + self.error_pde(state_general)
    )

### Network construction

In [4]:
# Training hyper-parameters.
EPOCHS = 10000  # number of passes over the training loader in train()
LEARNING_RATE = 0.0015  # initial learning rate handed to Adam
SCHEDULER_RATE = 20  # scheduler.step() is called once every this many epochs
GAMMA = 0.9999  # decay factor of the ExponentialLR scheduler
BATCH_SIZE = 1000  # batch size of both data loaders

# Dataset sizing: CustomDataset draws this many points for each of the
# initial-condition, boundary and interior sets.
DATASET_SIZE = 50000
TRAINING_FRACTION = 0.95  # share of the samples used for training; the rest is validation

# Layer widths of the network: 2 inputs (x, t) and 1 output (u).
FC_LAYERS = [2, 16, 32, 32, 32, 32, 16, 1]

# Coefficient of the u_xx term in the PDE residual.
VISCOSITY = 0.01 / math.pi

# Global model instance used by the plotting helpers and the training cell.
solver = PINN(FC_LAYERS, VISCOSITY)


### Visualizations

In [5]:
# TODO: the plotting helpers below are quick-and-dirty and should be rewritten from scratch.

def plot_stats(train_loss: list[float], valid_loss: list[float]):
  """Draw the training and validation loss curves on one figure.

  Args:
    train_loss: per-epoch training loss values.
    valid_loss: per-epoch validation loss values.
  """
  plt.figure(figsize=(16, 8))

  curves = (
    (train_loss, 'Training loss'),
    (valid_loss, 'Validation loss'),
  )
  for values, caption in curves:
    plt.plot(values, label=caption)

  plt.xlabel("Epoch")
  plt.ylabel("Loss")
  plt.legend()

  plt.show()

def plot_fun(x):
  """Evaluate the global ``solver`` on ``x`` and return the result as a NumPy array.

  Singleton dimensions of the network output are squeezed away before the
  conversion.
  """
  return solver(x).detach().squeeze().numpy()

def plot_t(times=(0.0, 0.25, 0.5, 0.75), num_points=201):
  """Plot the network's solution over x in [-1, 1] at several fixed times.

  Args:
    times: time values at which a curve u(x, t) is drawn; the defaults are
      the four snapshots the notebook has always shown.
    num_points: number of evenly spaced x samples between -1 and 1.
  """
  palette = ('#fd8a8a', '#ffcbcb', '#a8d1d1', '#9ea1d4')

  # linspace gives exactly num_points samples including both end points;
  # a float-step arange has a length that depends on rounding.
  x = np.linspace(-1.0, 1.0, num_points)
  x_column = torch.tensor(x, dtype=torch.float32).unsqueeze(1)

  for index, t in enumerate(times):
    # Network input rows are (x, t) with t held constant for this curve.
    inputs = torch.cat([x_column, torch.full_like(x_column, t)], dim=1)
    plt.plot(
      x, plot_fun(inputs),
      marker='', linestyle='solid',
      color=palette[index % len(palette)],
      label=f"t = {t}",
    )

  plt.xlabel('x')
  plt.ylabel('u')
  plt.legend()

  plt.show()


### Dataset generation

In [6]:
class CustomDataset(Dataset):
  """Random sample points for the three loss terms.

  Each item is a triple of ``(x, t)`` rows: one point at ``t = 0``, one
  point at ``x = -1`` or ``x = 1``, and one point with ``x`` in ``[-1, 1)``
  and ``t`` in ``[0, 1)``.
  """

  def __init__(self, num_samples):
    super().__init__()

    self.num_samples = num_samples
    count = int(num_samples)

    # NOTE: the order of the random draws below is kept fixed so that a
    # seeded run reproduces the same points.

    # Interior points: x uniform in [-1, 1), t uniform in [0, 1).
    self.random_x = torch.rand((count, 1)).mul(2).sub(1)
    self.random_t = torch.rand((count, 1))
    self.random_all = torch.cat((self.random_x, self.random_t), dim=1)

    # Boundary points: x is either -1 or 1, t uniform in [0, 1).
    side = torch.randint(0, 2, (count, 1), dtype=torch.float32)
    self.border_x = side * 2 - 1
    self.border_t = torch.rand((count, 1))
    self.border_all = torch.cat((self.border_x, self.border_t), dim=1)

    # Initial-condition points: x uniform in [-1, 1), t fixed at 0.
    self.beginning_x = torch.rand((count, 1)).mul(2).sub(1)
    self.beginning_t = torch.zeros(count, 1)
    self.beginning_all = torch.cat((self.beginning_x, self.beginning_t), dim=1)

  def __len__(self):
    return self.num_samples

  def __getitem__(self, item):
    """Return the (initial, boundary, interior) points stored at ``item``."""
    groups = (self.beginning_all, self.border_all, self.random_all)
    return tuple(points[item] for points in groups)

# Generate the sample points once, then carve off a held-out validation subset.
dataset = CustomDataset(DATASET_SIZE)

training_size = int(len(dataset) * TRAINING_FRACTION)
validation_size = len(dataset) - training_size

# A fixed generator seed makes the train/validation split reproducible.
split_generator = torch.Generator().manual_seed(238)
training_dataset, validation_dataset = random_split(
  dataset,
  (training_size, validation_size),
  generator=split_generator,
)

# Only the training batches are reshuffled every epoch.
training_loader = DataLoader(
  training_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False
)
validation_loader = DataLoader(
  validation_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False
)


### Network training

In [7]:
def train_batch(model, batch_beginning, batch_border, batch_general, optimizer):
  """Run one optimisation step on a batch and return its loss as a float.

  Args:
    model: object exposing ``error_total(beginning, border, general)``.
    batch_beginning: points for the initial-condition loss.
    batch_border: points for the boundary loss.
    batch_general: points for the PDE-residual loss.
    optimizer: optimizer holding the model parameters.
  """
  optimizer.zero_grad()
  loss = model.error_total(batch_beginning, batch_border, batch_general)
  # The graph is used for exactly one backward pass, so it is not retained;
  # retaining it would keep every batch's intermediate buffers alive.
  loss.backward()
  optimizer.step()

  return loss.item()

def validate_batch(model, batch_beginning, batch_border, batch_general):
  """Return the total loss of ``model`` on one batch as a float, without updating it."""
  # Gradient tracking stays enabled on purpose: the PDE loss term is built
  # with torch.autograd.grad and cannot be evaluated under no_grad.
  return model.error_total(batch_beginning, batch_border, batch_general).item()


In [8]:
def train(model, training_loader, validation_loader, optimizer, scheduler):
  """Train ``model`` for ``EPOCHS`` epochs and return the loss histories.

  Args:
    model: network exposing ``error_total`` and ``last_pde_error``.
    training_loader: loader yielding (beginning, border, general) batches
      used for optimisation.
    validation_loader: loader yielding batches of the same layout, used
      only for evaluation.
    optimizer: optimizer stepping the model parameters.
    scheduler: learning-rate scheduler, stepped every ``SCHEDULER_RATE``
      epochs.

  Returns:
    ``(training_loss_history, validation_loss_history)``: two lists with
    the mean batch loss per epoch. An epoch with no batches in a loader
    records ``nan`` for that loader instead of dividing by zero.
  """
  training_loss_history, validation_loss_history = [], []

  for epoch in range(EPOCHS):
    training_loss, training_cnt = 0.0, 0
    validation_loss, validation_cnt = 0.0, 0

    for batch_beginning, batch_border, batch_general in training_loader:
      training_loss += train_batch(model, batch_beginning, batch_border, batch_general, optimizer)
      training_cnt += 1

    for batch_beginning, batch_border, batch_general in validation_loader:
      validation_loss += validate_batch(model, batch_beginning, batch_border, batch_general)
      validation_cnt += 1

    training_loss_history.append(
      training_loss / training_cnt if training_cnt else float("nan")
    )
    validation_loss_history.append(
      validation_loss / validation_cnt if validation_cnt else float("nan")
    )

    if epoch % 10 == 0:
      # clear_output(wait=True)
      # plot_stats(training_loss_history, validation_loss_history)
      # plot_t()
      # last_pde_error belongs to the most recently evaluated batch, i.e.
      # the last validation batch whenever the validation loader is non-empty.
      print(f"Epoch {epoch}, PDE MSE Error: {model.last_pde_error.item()}")

    if epoch % SCHEDULER_RATE == 0:
      scheduler.step()

  return training_loss_history, validation_loss_history

In [None]:
# Adam drives the parameter updates; its learning rate decays exponentially
# each time the training loop steps the scheduler.
optimizer = optim.Adam(params=solver.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=GAMMA)

train(
  model=solver,
  training_loader=training_loader,
  validation_loader=validation_loader,
  optimizer=optimizer,
  scheduler=scheduler,
)


In [None]:
# Plot the solver's predicted u(x, t) curves at the fixed time snapshots.
plot_t()