# PINN Implementation of Ashourvan & Diamond Paper

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad

import mlflow
import mlflow.pytorch

from tqdm.notebook import tqdm

import imageio.v2 as imageio
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import os

## Define Parameters

In [2]:
l = 1.0     # not provided by paper, need to check relation with Λ 
α = 6.0
D_c = 0.78
C_χ = 0.95
a_u = 1.0
μ_c = 0.78
β = 0.1
Λ = 4000.0
ϵ_c = 6.25

In [3]:
g_i = 5.1
ϵ_i = 0.002

In [4]:
physical_params = {
    'l': l,
    'α': α,
    'D_c': D_c,
    'C_χ': C_χ,
    'a_u': a_u,
    'μ_c': μ_c,
    'β': β,
    'Λ': Λ,
    'ϵ_c': ϵ_c,
    'g_i': g_i,
    'ϵ_i': ϵ_i
}

## Define PDEs

### Define repeated calculations

In [5]:
def compute_w(C_χ, l, ϵ, α, a_u, u):
    return C_χ * l**2 * ϵ / torch.sqrt(α**2 + a_u * u**2)

### Dynamic equation for mean density


In [6]:
def pde_mean_density(x, n_t, n_x, n_xx, ϵ, l, α, D_c):    
    intermediate = l**2 * ϵ * n_x / α
    intermediate_x = grad(intermediate, x, grad_outputs=torch.ones_like(intermediate), retain_graph=True, create_graph=True)[0]
    
    return (n_t - intermediate_x - D_c * n_xx)**2

### Dynamic equation for mean vorticity

In [7]:
def pde_mean_vorticity(x, ϵ, n_x, u_t, u_xx, l, α, μ_c, w):    
    intermediate = (l**2 * ϵ / α - w) * n_x
    intermediate_x = grad(intermediate, x, grad_outputs=torch.ones_like(intermediate), retain_graph=True, create_graph=True)[0]
    
    return (u_t - intermediate_x - w * u_xx - μ_c * u_xx)**2

### Dynamic equation for turbulent potential entrosphy

In [8]:
def pde_tpe(x, ϵ, n_x, u_x, ϵ_t, ϵ_x, l, β, Λ, ϵ_c, w):
    intermediate = l**2 * torch.sqrt(ϵ) * ϵ_x
    intermediate_x = grad(intermediate, x, grad_outputs=torch.ones_like(intermediate), retain_graph=True, create_graph=True)[0]
    
    return (ϵ_t - β * intermediate_x - Λ * (w * (n_x - u_x)**2 - ϵ**(3/2) / ϵ_c**0.5 + ϵ))**2

## Define Initial Conditions

In [9]:
def n_initial_cond(t, x):
    return -g_i * x

def u_initial_cond(t, x):
    return torch.zeros(x.shape, device=x.device.type)

def ϵ_initial_cond(t, x):
    return torch.full(x.shape, ϵ_i, device=x.device.type)

## Define Boundary Conditions

In [10]:
def n_boundary_cond(t, x):
    out = torch.full(x.shape, -g_i, device=x.device.type)
    out = out * x
    
    return out

def u_boundary_cond(t, x):
    return torch.zeros(x.shape, device=x.device.type)


def ϵ_x_boundary_cond(t, x):
    return torch.zeros(x.shape, device=x.device.type)

## PINN Implementation

In [11]:
class CustomOutputLayer(nn.Module):
    def __init__(self, pars):
        super(CustomOutputLayer, self).__init__()
        self.n = nn.Sequential(nn.Linear(pars['width'], 1))
        self.u = nn.Sequential(nn.Linear(pars['width'], 1))
        self.ϵ = nn.Sequential(nn.Linear(pars['width'], 1))
        
    def forward(self, x):
        return torch.hstack((self.n(x), self.u(x), torch.abs(self.ϵ(x) + 1e-2)))

In [12]:
class PINN(nn.Module):
    def __init__(self, pars: dict):
        super().__init__()
        self.pars = pars
        
        self.modules = [nn.BatchNorm1d(2), nn.Linear(2, self.pars['width'])] # nn.LayerNorm(2)
        for i in range(self.pars['layers'] - 1):
            # self.modules.append(nn.LayerNorm(self.pars['width']))
            self.modules.append(nn.GELU())
            self.modules.append(nn.Linear(self.pars['width'], self.pars['width']))
        
        # self.modules.append(nn.LayerNorm(self.pars['width']))
        self.modules.append(CustomOutputLayer(pars))
        
        self.model = nn.Sequential(*self.modules)
        self.model.to(self.pars['device'])
        
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.pars['lr'])
        self.num_params = sum([len(params) for params in [p for p in self.model.parameters()]])
        
        self.epoch = 0
        
        t = np.linspace(self.pars['t_min'], self.pars['t_max'], 100)
        x = np.linspace(self.pars['x_min'], self.pars['x_max'], 100)

        self.eval_t, self.eval_x = np.meshgrid(t, x)
        self.eval_t = torch.Tensor(self.eval_t).reshape(-1, 1).to(self.pars['device'])
        self.eval_x = torch.Tensor(self.eval_x).reshape(-1, 1).to(self.pars['device'])
        
        self.eval_t.requires_grad_()
        self.eval_x.requires_grad_()
        
        eval_t_interior, eval_x_interior = np.meshgrid(t[1:-1], x[1:-1])
        eval_t_interior = torch.Tensor(eval_t_interior).reshape(-1, 1).to(self.pars['device'])
        eval_x_interior = torch.Tensor(eval_x_interior).reshape(-1, 1).to(self.pars['device'])
        
        self.eval_X_interior = torch.hstack((eval_t_interior, eval_x_interior))
        self.eval_X_interior.requires_grad_()
        
        eval_t_initial = torch.zeros_like(self.eval_x)
        self.eval_X_initial = torch.hstack((eval_t_initial, self.eval_x))
        self.eval_X_initial.requires_grad_()
        
        eval_t_boundary = torch.Tensor(np.vstack([t, t])).reshape(-1, 1).to(self.pars['device'])
        eval_x_boundary = torch.Tensor(np.vstack([x[0] * np.ones_like(t), x[-1] * np.ones_like(t)])).reshape(-1, 1).to(self.pars['device'])
        
        self.eval_X_boundary = torch.hstack((eval_t_boundary, eval_x_boundary))
        self.eval_X_boundary.requires_grad_()
        
        self.plot_t = self.eval_t.view(100, 100).detach().cpu().numpy()
        self.plot_x = self.eval_x.view(100, 100).detach().cpu().numpy()
        
        self.plot_files = {}
        
        if 'plot_vars_list' in self.pars:
            for varname in self.pars['plot_vars_list']:
                self.plot_files[varname] = []
        
    def __call__(self, X):
        return self.model(X)
        
    def sample_interior_points(self):
        t = torch.empty((self.pars['interior_batch_size'], 1), device=self.pars['device']).uniform_(self.pars['t_min'], self.pars['t_max'])
        x = torch.empty((self.pars['interior_batch_size'], 1), device=self.pars['device']).uniform_(self.pars['x_min'], self.pars['x_max'])
        X_interior = torch.cat((t, x), 1)
        X_interior.requires_grad_()
        
        return X_interior
    
    def sample_initial_points(self):
        t = torch.zeros(self.pars['initial_batch_size'], 1, device=self.pars['device'])
        x = torch.empty((self.pars['initial_batch_size'], 1), device=self.pars['device']).uniform_(self.pars['x_min'], self.pars['x_max'])
        X_initial = torch.cat((t, x), 1)
        X_initial.requires_grad_()
        
        return X_initial
    
    def sample_boundary_points(self):
        options = torch.tensor([self.pars['x_min'], self.pars['x_max']], device=self.pars['device'])
        
        t = torch.empty((self.pars['boundary_batch_size'], 1), device=self.pars['device']).uniform_(self.pars['t_min'], self.pars['t_max'])
        x = options[torch.randint(0, 2, (self.pars['boundary_batch_size'], 1), device=self.pars['device'])]
        X_boundary = torch.cat((t, x), 1)
        X_boundary.requires_grad_()
        
        return X_boundary
    
    def forward(self, X_interior, X_initial, X_boundary):
        # X shape: (batch_size, 2), where 2nd dimension is [t, x]
        # Y shape: (batch_size, 3), where 2nd dimension is [n, u, ϵ]
        
        t_interior = X_interior[:, 0].reshape(-1, 1)
        x_interior = X_interior[:, 1].reshape(-1, 1)
        t_initial = X_initial[:, 0].reshape(-1, 1)
        x_initial = X_initial[:, 1].reshape(-1, 1)
        t_boundary = X_boundary[:, 0].reshape(-1, 1)
        x_boundary = X_boundary[:, 1].reshape(-1, 1)
        
        # forward pass
        Y_interior = self.model(torch.hstack((t_interior, x_interior)))
        Y_initial = self.model(torch.hstack((t_initial, x_initial)))
        Y_boundary = self.model(torch.hstack((t_boundary, x_boundary)))
        
        n_interior = Y_interior[:, 0].reshape(-1, 1)
        u_interior = Y_interior[:, 1].reshape(-1, 1)
        ϵ_interior = Y_interior[:, 2].reshape(-1, 1)
        
        n_initial = Y_initial[:, 0].reshape(-1, 1)
        u_initial = Y_initial[:, 1].reshape(-1, 1)
        ϵ_initial = Y_initial[:, 2].reshape(-1, 1)
        
        n_boundary = Y_boundary[:, 0].reshape(-1, 1)
        u_boundary = Y_boundary[:, 1].reshape(-1, 1)
        ϵ_boundary = Y_boundary[:, 2].reshape(-1, 1)
        
        n_x_interior = grad(n_interior, x_interior, grad_outputs=torch.ones_like(n_interior), retain_graph=True, create_graph=True)[0]
        n_t_interior = grad(n_interior, t_interior, grad_outputs=torch.ones_like(n_interior), retain_graph=True, create_graph=True)[0]
        
        n_xx_interior = grad(n_x_interior, x_interior, grad_outputs=torch.ones_like(n_x_interior), retain_graph=True, create_graph=True)[0]
        
        u_x_interior = grad(u_interior, x_interior, grad_outputs=torch.ones_like(u_interior), retain_graph=True, create_graph=True)[0]
        u_t_interior = grad(u_interior, t_interior, grad_outputs=torch.ones_like(u_interior), retain_graph=True, create_graph=True)[0]
        
        u_xx_interior = grad(u_x_interior, x_interior, grad_outputs=torch.ones_like(u_x_interior), retain_graph=True, create_graph=True)[0]
        
        ϵ_x_interior = grad(ϵ_interior, x_interior, grad_outputs=torch.ones_like(ϵ_interior), retain_graph=True, create_graph=True)[0]
        ϵ_t_interior = grad(ϵ_interior, t_interior, grad_outputs=torch.ones_like(ϵ_interior), retain_graph=True, create_graph=True)[0]
        
        ϵ_x_boundary = grad(ϵ_boundary, x_boundary, grad_outputs=torch.ones_like(ϵ_boundary), retain_graph=True, create_graph=True)[0]
        
        w = compute_w(C_χ, l, ϵ_interior, α, a_u, u_interior)
        
        density_loss = pde_mean_density(x_interior, n_t_interior, n_x_interior, n_xx_interior, ϵ_interior, l, α, D_c).mean()
        vorticity_loss = pde_mean_vorticity(x_interior, ϵ_interior, n_x_interior, u_t_interior, u_xx_interior, l, α, μ_c, w).mean()
        tpe_loss = pde_tpe(x_interior, ϵ_interior, n_x_interior, u_x_interior, ϵ_t_interior, ϵ_x_interior, l, β, Λ, ϵ_c, w).mean()
        interior_loss = (density_loss + vorticity_loss + tpe_loss)/3
        
        mse = nn.MSELoss()
        
        initial_n_loss = mse(n_initial_cond(t_initial, x_initial), n_initial)
        initial_u_loss = mse(u_initial_cond(t_initial, x_initial), u_initial)
        initial_ϵ_loss = mse(ϵ_initial_cond(t_initial, x_initial), ϵ_initial)
        initial_loss = (initial_n_loss + initial_u_loss + initial_ϵ_loss)/3
        
        boundary_n_loss = mse(n_boundary_cond(t_boundary, x_boundary), n_boundary)
        boundary_u_loss = mse(u_boundary_cond(t_boundary, x_boundary), u_boundary)
        boundary_ϵ_loss = mse(ϵ_x_boundary_cond(t_boundary, x_boundary), ϵ_x_boundary)
        boundary_loss = (boundary_n_loss + boundary_u_loss + boundary_ϵ_loss)/3
        total_loss = interior_loss + initial_loss + boundary_loss
        
        return total_loss, density_loss, vorticity_loss, tpe_loss, initial_n_loss, initial_u_loss, initial_ϵ_loss, boundary_n_loss, boundary_u_loss, boundary_ϵ_loss
    
    def train(self):
        self.initialise_starting_surface()
        
        cwd = os.getcwd()
        plot_dirs = os.path.join(cwd, f'plots/{self.pars["experiment_name"]}')

        if not os.path.isdir(plot_dirs):
            os.makedirs(plot_dirs)
        
        mlflow.set_experiment(self.pars['experiment_name'])
        mlflow.start_run()
        
        mlflow.log_param("physical_params", physical_params)
        mlflow.log_param("model_params", self.pars)
        
        for epoch in tqdm(range(self.pars['epochs']), position=0, leave=True, desc='Training...'): 
            self.epoch = epoch
            
            # eval
            if epoch % self.pars['eval_interval'] == 0 or epoch == self.pars['epochs'] - 1:
                
                loss, density_loss, vorticity_loss, tpe_loss, initial_n_loss, initial_u_loss, initial_ϵ_loss, boundary_n_loss, boundary_u_loss, boundary_ϵ_loss \
                    = self.forward(self.eval_X_interior, self.eval_X_initial, self.eval_X_boundary)
                
                print()
                print(f'Epoch: {self.epoch}, Loss: {loss.item():,.4e}')
                print(f"density_loss: {density_loss.item():.4e}, vorticity_loss: {vorticity_loss.item():.4e}, tpe_loss: {tpe_loss.item():.4e}")
                print(f"initial_n_loss: {initial_n_loss.item():.4e}, initial_u_loss: {initial_u_loss.item():.4e}, initial_ϵ_loss: {initial_ϵ_loss.item():.4e}")
                print(f"boundary_n_loss: {boundary_n_loss.item():.4e}, boundary_u_loss: {boundary_u_loss.item():.4e}, boundary_ϵ_loss: {boundary_ϵ_loss.item():.4e}")
                
                mlflow.log_metric("total_loss", loss.item(), step=self.epoch)
                mlflow.log_metric("density_loss", density_loss.item(), step=self.epoch)
                mlflow.log_metric("vorticity_loss", vorticity_loss.item(), step=self.epoch)
                mlflow.log_metric("tpe_loss", tpe_loss.item(), step=self.epoch)
                mlflow.log_metric("initial_n_loss", initial_n_loss.item(), step=self.epoch)
                mlflow.log_metric("initial_u_loss", initial_u_loss.item(), step=self.epoch)
                mlflow.log_metric("initial_ϵ_loss", initial_ϵ_loss.item(), step=self.epoch)
                mlflow.log_metric("boundary_n_loss", boundary_n_loss.item(), step=self.epoch)
                mlflow.log_metric("boundary_u_loss", boundary_u_loss.item(), step=self.epoch)
                mlflow.log_metric("boundary_ϵ_loss", boundary_ϵ_loss.item(), step=self.epoch)
                
                mlflow.pytorch.log_model(self.model, f"{self.pars['experiment_name']}_model_epoch_{self.epoch}")
                
                if self.pars['plot_training_outputs']:
                    self.plot_outputs()
            
            # training step
            self.optimizer.zero_grad()
            X_interior = self.sample_interior_points()
            X_initial = self.sample_initial_points()
            X_boundary = self.sample_boundary_points()
            loss, density_loss, vorticity_loss, tpe_loss, initial_n_loss, initial_u_loss, initial_ϵ_loss, boundary_n_loss, boundary_u_loss, boundary_ϵ_loss = self.forward(X_interior, X_initial, X_boundary)
            loss.backward()
            self.optimizer.step()
            
                
            if loss.isnan():
                print(f'Epoch: {self.epoch}, Loss: {loss.item():,.4e}')
                print(f"density_loss: {density_loss.item():.4e}, vorticity_loss: {vorticity_loss.item():.4e}, tpe_loss: {tpe_loss.item():.4e}")
                print(f"initial_n_loss: {initial_n_loss.item():.4e}, initial_u_loss: {initial_u_loss.item():.4e}, initial_ϵ_loss: {initial_ϵ_loss.item():.4e}")
                print(f"boundary_n_loss: {boundary_n_loss.item():.4e}, boundary_u_loss: {boundary_u_loss.item():.4e}, boundary_ϵ_loss: {boundary_ϵ_loss.item():.4e}")
                print("loss is NaN, stopping training...")
                break
            
        mlflow.pytorch.log_model(self.model, f"{self.pars['experiment_name']}_model_final")
        
        self.save_gif()
        
        mlflow.end_run()
    
    def plot_outputs(self):
        Y = self(torch.hstack((self.eval_t, self.eval_x)))

        n = Y[:, 0].view(-1, 1)
        u = Y[:, 1].view(-1, 1)
        ϵ = Y[:, 2].view(-1, 1)
        
        n_x = grad(n, self.eval_x, grad_outputs=torch.ones_like(n), retain_graph=True, create_graph=True)[0]
        n_t = grad(n, self.eval_t, grad_outputs=torch.ones_like(n), retain_graph=True)[0]
        
        n_xx = grad(n_x, self.eval_x, grad_outputs=torch.ones_like(n_x), retain_graph=True)[0]
        
        u_x = grad(u, self.eval_x, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0]
        u_t = grad(u, self.eval_t, grad_outputs=torch.ones_like(u), retain_graph=True)[0]
        
        u_xx = grad(u_x, self.eval_x, grad_outputs=torch.ones_like(u_x), retain_graph=True)[0]
        
        ϵ_x = grad(ϵ, self.eval_x, grad_outputs=torch.ones_like(ϵ), retain_graph=True)[0]
        ϵ_t = grad(ϵ, self.eval_t, grad_outputs=torch.ones_like(ϵ), retain_graph=True)[0]
        
        n = n.view(100, 100).detach().cpu().numpy()
        n_x = n_x.view(100, 100).detach().cpu().numpy()
        n_t = n_t.view(100, 100).detach().cpu().numpy()
        n_xx = n_xx.view(100, 100).detach().cpu().numpy()
        
        u = u.view(100, 100).detach().cpu().numpy()
        u_x = u_x.view(100, 100).detach().cpu().numpy()
        u_t = u_t.view(100, 100).detach().cpu().numpy()
        u_xx = u_xx.view(100, 100).detach().cpu().numpy()
        
        ϵ = ϵ.view(100, 100).detach().cpu().numpy()
        ϵ_x = ϵ_x.view(100, 100).detach().cpu().numpy()
        ϵ_t = ϵ_t.view(100, 100).detach().cpu().numpy()
        
        if 'n' in self.pars['plot_vars_list']:
            self.plot_var(n, 'n')
        
        if 'n_x' in self.pars['plot_vars_list']:
            self.plot_var(n_x, 'n_x')
        
        if 'n_t' in self.pars['plot_vars_list']:
            self.plot_var(n_t, 'n_t')
        
        if 'n_xx' in self.pars['plot_vars_list']:
            self.plot_var(n_xx, 'n_xx')
        
        if 'u' in self.pars['plot_vars_list']:
            self.plot_var(u, 'u')
        
        if 'u_x' in self.pars['plot_vars_list']:
            self.plot_var(u_x, 'u_x')
        
        if 'u_t' in self.pars['plot_vars_list']:
            self.plot_var(u_t, 'u_t')
        
        if 'u_xx' in self.pars['plot_vars_list']:
            self.plot_var(u_xx, 'u_xx')
        
        if 'ϵ' in self.pars['plot_vars_list']:
            self.plot_var(ϵ, 'ϵ')
        
        if 'ϵ_x' in self.pars['plot_vars_list']:
            self.plot_var(ϵ_x, 'ϵ_x')
        
        if 'ϵ_t' in self.pars['plot_vars_list']:
            self.plot_var(ϵ_t, 'ϵ_t')
        
    def plot_var(self, var, varname):
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(self.plot_t, self.plot_x, var, cmap='viridis')
        ax.set_xlabel('t')
        ax.set_ylabel('x')
        ax.set_zlabel(varname)
        plt.title(f'Model output {varname}, epoch {self.epoch}')
        filename = f"plots/{self.pars['experiment_name']}/plot_{varname}_epoch_{self.epoch}.png"
        self.plot_files[varname].append(filename)
        fig.savefig(filename)
        mlflow.log_artifact(filename)
        plt.close()
        
    def save_gif(self):
        for varname in self.pars['plot_vars_list']:
            gif_filename = f"plots/{self.pars['experiment_name']}/plot_{varname}.gif"
            with imageio.get_writer(gif_filename, mode='I') as writer:
                for filename in self.plot_files[varname]:
                    image = imageio.imread(filename)
                    writer.append_data(image)
                
            mlflow.log_artifact(gif_filename)
        
    def initialise_starting_surface(self):
        T_init, X_init, n_init, u_init, ϵ_init = self.starting_surface()
        T_init.requires_grad_()
        X_init.requires_grad_()
        n_init.requires_grad_()
        u_init.requires_grad_()
        ϵ_init.requires_grad_()
        
        print(f"T_init shape: {T_init.shape}, X_init shape: {X_init.shape}, n_init shape: {n_init.shape}, u_init shape: {u_init.shape}, ϵ_init shape: {ϵ_init.shape}")
        
        for i in tqdm(range(self.pars['initialisation_epochs'])):
            self.optimizer.zero_grad()
            Y = self.model(torch.hstack((T_init, X_init)))
            
            n_pred = Y[:, 0].reshape(-1, 1)
            u_pred = Y[:, 1].reshape(-1, 1)
            ϵ_pred = Y[:, 2].reshape(-1, 1)
            
            mse = nn.MSELoss()
            
            loss_n = mse(n_pred, n_init)
            loss_u = mse(u_pred, u_init)
            loss_ϵ = mse(ϵ_pred, ϵ_init)
            
            loss = (loss_n + loss_u + loss_ϵ)/3
            
            loss.backward()
            self.optimizer.step()
            
            if i % self.pars['initialisation_eval_interval'] == 0:
                print(f"Initialization epoch: {i}, loss: {loss.item():,.4e}")
            
        print(f"Initialization complete, initial surface loss: {loss.item():,.4e}")
    
    def starting_surface(self):
        t = torch.linspace(self.pars['t_min'], self.pars['t_max'], 100, device=self.pars['device'])
        x = torch.linspace(self.pars['x_min'], self.pars['x_max'], 100, device=self.pars['device'])
        
        x_range = self.pars['x_max'] - self.pars['x_min']
        t_range = self.pars['t_max'] - self.pars['t_min']
        
        T, X = torch.meshgrid(t, x)
        n = torch.sin((7.5+torch.rand(1, device=self.pars['device']))*torch.pi*X/x_range + torch.rand(1, device=self.pars['device'])) * torch.sin((7.5+torch.rand(1, device=self.pars['device']))*torch.pi*T/t_range + torch.rand(1, device=self.pars['device']))
        u = torch.sin((7.5+torch.rand(1, device=self.pars['device']))*torch.pi*X/x_range + torch.rand(1, device=self.pars['device'])) * torch.sin((7.5+torch.rand(1, device=self.pars['device']))*torch.pi*T/t_range + torch.rand(1, device=self.pars['device']))
        ϵ = 1 + 1e-2 + torch.sin((7.5+torch.rand(1, device=self.pars['device']))*torch.pi*X/x_range + torch.rand(1, device=self.pars['device'])) * torch.sin((7.5+torch.rand(1, device=self.pars['device']))*torch.pi*T/t_range + torch.rand(1, device=self.pars['device']))
        
        return T.reshape(-1, 1), X.reshape(-1, 1), n.reshape(-1, 1), u.reshape(-1, 1), ϵ.reshape(-1, 1)
    
        

## Experiments

In [13]:
pars = {
    'experiment_name': 'double_sine_initialization_pinn_v1',
    'layers': 4,
    'width': 64,
    'lr': 1e-4,
    'epochs': 100000,
    'eval_interval': 2500,
    'interior_batch_size': 2048,
    'initial_batch_size': 2048,
    'boundary_batch_size': 2048,
    'x_min': 0.0,
    'x_max': 1.0,
    't_min': 0.0,
    't_max': 10000,
    'device': 'cuda',
    'plot_training_outputs': True,
    'plot_vars_list': ['n', 'n_t', 'n_x', 'n_xx', 'u', 'u_t', 'u_x', 'u_xx', 'ϵ', 'ϵ_t', 'ϵ_x'],
    'initialisation_epochs': 100000,
    'initialisation_eval_interval': 5000
}
pinn = PINN(pars)


In [14]:
pinn.train()

T_init shape: torch.Size([10000, 1]), X_init shape: torch.Size([10000, 1]), n_init shape: torch.Size([10000, 1]), u_init shape: torch.Size([10000, 1]), ϵ_init shape: torch.Size([10000, 1])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


  0%|          | 0/100000 [00:00<?, ?it/s]

Initialization epoch: 0, loss: 5.2059e-01
Initialization epoch: 5000, loss: 2.0518e-01
Initialization epoch: 10000, loss: 1.7913e-02
Initialization epoch: 15000, loss: 5.9602e-03
Initialization epoch: 20000, loss: 6.4587e-03
Initialization epoch: 25000, loss: 3.1275e-03
Initialization epoch: 30000, loss: 2.4259e-03
Initialization epoch: 35000, loss: 1.9004e-03
Initialization epoch: 40000, loss: 1.5358e-03
Initialization epoch: 45000, loss: 1.2958e-03
Initialization epoch: 50000, loss: 1.3525e-03
Initialization epoch: 55000, loss: 3.0262e-03
Initialization epoch: 60000, loss: 8.9628e-04
Initialization epoch: 65000, loss: 8.1552e-04
Initialization epoch: 70000, loss: 7.5152e-04
Initialization epoch: 75000, loss: 6.9313e-04
Initialization epoch: 80000, loss: 1.1475e-03
Initialization epoch: 85000, loss: 5.9910e-04
Initialization epoch: 90000, loss: 3.6674e-03
Initialization epoch: 95000, loss: 5.3595e-04
Initialization complete, initial surface loss: 5.4950e-04


Training...:   0%|          | 0/100000 [00:00<?, ?it/s]


Epoch: 0, Loss: 1.3002e+10
density_loss: 2.2763e+03, vorticity_loss: 7.5739e+00, tpe_loss: 3.9007e+10
initial_n_loss: 8.7258e+00, initial_u_loss: 2.7545e-01, initial_ϵ_loss: 1.2090e+00
boundary_n_loss: 1.3325e+01, boundary_u_loss: 1.5841e-01, boundary_ϵ_loss: 3.1398e+01





Epoch: 2500, Loss: 2.3537e+07
density_loss: 3.1705e+04, vorticity_loss: 2.1963e+04, tpe_loss: 7.0558e+07
initial_n_loss: 3.4611e+01, initial_u_loss: 1.6938e+02, initial_ϵ_loss: 4.0109e+01
boundary_n_loss: 4.0810e+01, boundary_u_loss: 3.3361e+02, boundary_ϵ_loss: 5.8926e+01





Epoch: 5000, Loss: 6.4727e+06
density_loss: 1.9355e+04, vorticity_loss: 1.2088e+04, tpe_loss: 1.9386e+07
initial_n_loss: 2.6927e+01, initial_u_loss: 1.5091e+02, initial_ϵ_loss: 2.3738e+01
boundary_n_loss: 3.2203e+01, boundary_u_loss: 3.0356e+02, boundary_ϵ_loss: 5.6966e+01





Epoch: 7500, Loss: 3.2581e+06
density_loss: 1.5023e+04, vorticity_loss: 9.8865e+03, tpe_loss: 9.7488e+06
initial_n_loss: 2.5797e+01, initial_u_loss: 1.6436e+02, initial_ϵ_loss: 2.1755e+01
boundary_n_loss: 4.0559e+01, boundary_u_loss: 3.5744e+02, boundary_ϵ_loss: 5.0185e+01





Epoch: 10000, Loss: 1.0281e+07
density_loss: 4.2858e+04, vorticity_loss: 3.3105e+04, tpe_loss: 3.0766e+07
initial_n_loss: 3.4494e+01, initial_u_loss: 9.7818e+01, initial_ϵ_loss: 2.6464e+01
boundary_n_loss: 6.7055e+01, boundary_u_loss: 2.8816e+02, boundary_ϵ_loss: 1.1215e+02





Epoch: 12500, Loss: 9.1536e+06
density_loss: 3.3600e+04, vorticity_loss: 2.3764e+04, tpe_loss: 2.7403e+07
initial_n_loss: 2.6233e+01, initial_u_loss: 4.4903e+01, initial_ϵ_loss: 3.5021e+01
boundary_n_loss: 5.2207e+01, boundary_u_loss: 1.8195e+02, boundary_ϵ_loss: 1.6821e+02





Epoch: 15000, Loss: 8.8121e+06
density_loss: 3.7990e+04, vorticity_loss: 1.9333e+04, tpe_loss: 2.6378e+07
initial_n_loss: 2.9354e+01, initial_u_loss: 2.3360e+01, initial_ϵ_loss: 5.8816e+01
boundary_n_loss: 3.9482e+01, boundary_u_loss: 1.2236e+02, boundary_ϵ_loss: 2.1690e+02





Epoch: 17500, Loss: 8.4904e+06
density_loss: 1.6075e+05, vorticity_loss: 1.1315e+05, tpe_loss: 2.5196e+07
initial_n_loss: 3.0635e+01, initial_u_loss: 6.0890e+01, initial_ϵ_loss: 3.9372e+01
boundary_n_loss: 1.1237e+02, boundary_u_loss: 2.6990e+02, boundary_ϵ_loss: 5.5552e+02





Epoch: 20000, Loss: 7.6260e+07
density_loss: 7.8035e+04, vorticity_loss: 2.6782e+04, tpe_loss: 2.2867e+08
initial_n_loss: 3.1902e+01, initial_u_loss: 1.8664e+01, initial_ϵ_loss: 8.9626e+01
boundary_n_loss: 2.8663e+01, boundary_u_loss: 8.0911e+01, boundary_ϵ_loss: 3.6251e+02





Epoch: 22500, Loss: 1.4608e+07
density_loss: 4.7185e+04, vorticity_loss: 1.9277e+04, tpe_loss: 4.3756e+07
initial_n_loss: 2.8573e+01, initial_u_loss: 1.1089e+01, initial_ϵ_loss: 7.3033e+01
boundary_n_loss: 2.9278e+01, boundary_u_loss: 8.2221e+01, boundary_ϵ_loss: 2.7923e+02





Epoch: 25000, Loss: 3.5155e+06
density_loss: 4.0318e+04, vorticity_loss: 1.3010e+04, tpe_loss: 1.0493e+07
initial_n_loss: 2.2614e+01, initial_u_loss: 1.8562e+01, initial_ϵ_loss: 5.0519e+01
boundary_n_loss: 3.8615e+01, boundary_u_loss: 1.0351e+02, boundary_ϵ_loss: 2.7266e+02





Epoch: 27500, Loss: 1.3789e+08
density_loss: 2.2401e+05, vorticity_loss: 8.9116e+04, tpe_loss: 4.1336e+08
initial_n_loss: 7.2593e+01, initial_u_loss: 1.7997e+01, initial_ϵ_loss: 1.5142e+02
boundary_n_loss: 1.2060e+02, boundary_u_loss: 1.4739e+02, boundary_ϵ_loss: 5.3775e+02





Epoch: 30000, Loss: 2.6112e+07
density_loss: 3.1365e+05, vorticity_loss: 1.8437e+05, tpe_loss: 7.7836e+07
initial_n_loss: 3.7971e+01, initial_u_loss: 1.7693e+01, initial_ϵ_loss: 1.0502e+02
boundary_n_loss: 1.4453e+02, boundary_u_loss: 2.2879e+02, boundary_ϵ_loss: 6.6366e+02





Epoch: 32500, Loss: 1.8779e+07
density_loss: 3.0292e+05, vorticity_loss: 1.7468e+05, tpe_loss: 5.5857e+07
initial_n_loss: 3.6635e+01, initial_u_loss: 1.6270e+01, initial_ϵ_loss: 1.1187e+02
boundary_n_loss: 1.4004e+02, boundary_u_loss: 2.2352e+02, boundary_ϵ_loss: 6.3907e+02





Epoch: 35000, Loss: 1.4213e+07
density_loss: 3.2127e+05, vorticity_loss: 1.8894e+05, tpe_loss: 4.2127e+07
initial_n_loss: 3.5767e+01, initial_u_loss: 2.0753e+01, initial_ϵ_loss: 1.0061e+02
boundary_n_loss: 1.4619e+02, boundary_u_loss: 2.3883e+02, boundary_ϵ_loss: 6.8309e+02





Epoch: 37500, Loss: 2.8214e+07
density_loss: 3.5219e+05, vorticity_loss: 2.0631e+05, tpe_loss: 8.4083e+07
initial_n_loss: 3.1076e+01, initial_u_loss: 2.8544e+01, initial_ϵ_loss: 8.0960e+01
boundary_n_loss: 1.4559e+02, boundary_u_loss: 2.4398e+02, boundary_ϵ_loss: 6.4769e+02





Epoch: 40000, Loss: 1.1730e+07
density_loss: 4.3921e+05, vorticity_loss: 2.3103e+05, tpe_loss: 3.4519e+07
initial_n_loss: 3.8204e+01, initial_u_loss: 3.5532e+01, initial_ϵ_loss: 8.7608e+01
boundary_n_loss: 1.3881e+02, boundary_u_loss: 2.3796e+02, boundary_ϵ_loss: 8.0043e+02





Epoch: 42500, Loss: 5.0841e+07
density_loss: 4.4532e+05, vorticity_loss: 2.4176e+05, tpe_loss: 1.5183e+08
initial_n_loss: 2.1389e+01, initial_u_loss: 2.9578e+01, initial_ϵ_loss: 9.4275e+01
boundary_n_loss: 1.1975e+02, boundary_u_loss: 2.5697e+02, boundary_ϵ_loss: 6.3401e+02





Epoch: 45000, Loss: 6.9541e+07
density_loss: 5.4686e+05, vorticity_loss: 2.4052e+05, tpe_loss: 2.0784e+08
initial_n_loss: 3.5276e+01, initial_u_loss: 1.9515e+01, initial_ϵ_loss: 1.6921e+02
boundary_n_loss: 1.3018e+02, boundary_u_loss: 2.5919e+02, boundary_ϵ_loss: 7.1532e+02





Epoch: 47500, Loss: 4.7250e+07
density_loss: 7.6038e+05, vorticity_loss: 5.2117e+05, tpe_loss: 1.4047e+08
initial_n_loss: 4.1825e+01, initial_u_loss: 8.1100e+01, initial_ϵ_loss: 1.0827e+02
boundary_n_loss: 2.2243e+02, boundary_u_loss: 4.3957e+02, boundary_ϵ_loss: 9.1554e+02





Epoch: 50000, Loss: 4.3011e+07
density_loss: 7.3289e+05, vorticity_loss: 4.2327e+05, tpe_loss: 1.2788e+08
initial_n_loss: 4.1421e+01, initial_u_loss: 5.0571e+01, initial_ϵ_loss: 1.4785e+02
boundary_n_loss: 1.6024e+02, boundary_u_loss: 3.4613e+02, boundary_ϵ_loss: 9.4434e+02





Epoch: 52500, Loss: 2.9991e+07
density_loss: 5.7078e+05, vorticity_loss: 2.4548e+05, tpe_loss: 8.9156e+07
initial_n_loss: 4.1605e+01, initial_u_loss: 1.8491e+01, initial_ϵ_loss: 1.9818e+02
boundary_n_loss: 9.9600e+01, boundary_u_loss: 2.1815e+02, boundary_ϵ_loss: 6.1650e+02





Epoch: 55000, Loss: 1.5092e+07
density_loss: 6.6339e+05, vorticity_loss: 3.5828e+05, tpe_loss: 4.4254e+07
initial_n_loss: 2.3890e+01, initial_u_loss: 3.3044e+01, initial_ϵ_loss: 1.0589e+02
boundary_n_loss: 1.6371e+02, boundary_u_loss: 3.4201e+02, boundary_ϵ_loss: 5.8009e+02





Epoch: 57500, Loss: 2.4564e+07
density_loss: 7.0786e+05, vorticity_loss: 4.6633e+05, tpe_loss: 7.2515e+07
initial_n_loss: 3.0221e+01, initial_u_loss: 8.1375e+01, initial_ϵ_loss: 3.9794e+01
boundary_n_loss: 2.4800e+02, boundary_u_loss: 4.7857e+02, boundary_ϵ_loss: 8.1511e+02





Epoch: 60000, Loss: 6.0598e+07
density_loss: 5.3230e+05, vorticity_loss: 2.0615e+05, tpe_loss: 1.8105e+08
initial_n_loss: 4.8121e+01, initial_u_loss: 1.1728e+01, initial_ϵ_loss: 2.3794e+02
boundary_n_loss: 1.1922e+02, boundary_u_loss: 2.6409e+02, boundary_ϵ_loss: 5.1707e+02





Epoch: 62500, Loss: 9.6478e+07
density_loss: 6.3842e+05, vorticity_loss: 1.7621e+05, tpe_loss: 2.8862e+08
initial_n_loss: 7.8121e+01, initial_u_loss: 9.7771e+00, initial_ϵ_loss: 3.0894e+02
boundary_n_loss: 9.9886e+01, boundary_u_loss: 2.0126e+02, boundary_ϵ_loss: 5.7871e+02





Epoch: 65000, Loss: 1.2387e+08
density_loss: 4.0690e+05, vorticity_loss: 1.0480e+05, tpe_loss: 3.7109e+08
initial_n_loss: 9.5010e+01, initial_u_loss: 1.0183e+01, initial_ϵ_loss: 4.0336e+02
boundary_n_loss: 7.3150e+01, boundary_u_loss: 1.0964e+02, boundary_ϵ_loss: 7.4205e+02





Epoch: 67500, Loss: 7.2940e+07
density_loss: 3.5433e+05, vorticity_loss: 7.1426e+04, tpe_loss: 2.1839e+08
initial_n_loss: 6.8979e+01, initial_u_loss: 4.7731e+00, initial_ϵ_loss: 3.3560e+02
boundary_n_loss: 6.7850e+01, boundary_u_loss: 1.0944e+02, boundary_ϵ_loss: 6.9780e+02





Epoch: 70000, Loss: 7.2841e+07
density_loss: 2.1909e+05, vorticity_loss: 9.9723e+04, tpe_loss: 2.1820e+08
initial_n_loss: 6.7011e+01, initial_u_loss: 4.0930e+00, initial_ϵ_loss: 3.5085e+02
boundary_n_loss: 4.9133e+01, boundary_u_loss: 6.7153e+01, boundary_ϵ_loss: 5.7596e+02





Epoch: 72500, Loss: 1.7269e+07
density_loss: 2.0012e+05, vorticity_loss: 7.5058e+04, tpe_loss: 5.1530e+07
initial_n_loss: 4.3830e+01, initial_u_loss: 1.5202e+00, initial_ϵ_loss: 2.3155e+02
boundary_n_loss: 4.2159e+01, boundary_u_loss: 7.5302e+01, boundary_ϵ_loss: 5.0818e+02





Epoch: 75000, Loss: 5.2342e+06
density_loss: 1.8923e+05, vorticity_loss: 5.7180e+04, tpe_loss: 1.5455e+07
initial_n_loss: 2.3594e+01, initial_u_loss: 4.0104e+00, initial_ϵ_loss: 1.4773e+02
boundary_n_loss: 3.5682e+01, boundary_u_loss: 8.2718e+01, boundary_ϵ_loss: 4.6286e+02





Epoch: 77500, Loss: 3.7475e+06
density_loss: 1.7472e+05, vorticity_loss: 5.4704e+04, tpe_loss: 1.1012e+07
initial_n_loss: 1.4958e+01, initial_u_loss: 7.3222e+00, initial_ϵ_loss: 1.1704e+02
boundary_n_loss: 2.7598e+01, boundary_u_loss: 6.1228e+01, boundary_ϵ_loss: 4.6982e+02





Epoch: 80000, Loss: 1.9240e+07
density_loss: 4.3323e+05, vorticity_loss: 1.7798e+05, tpe_loss: 5.7107e+07
initial_n_loss: 3.3816e+01, initial_u_loss: 7.4434e+01, initial_ϵ_loss: 4.8894e+01
boundary_n_loss: 1.2261e+02, boundary_u_loss: 2.3634e+02, boundary_ϵ_loss: 4.3793e+02





Epoch: 82500, Loss: 2.8778e+07
density_loss: 5.3752e+05, vorticity_loss: 3.6831e+05, tpe_loss: 8.5426e+07
initial_n_loss: 6.0838e+01, initial_u_loss: 1.3425e+02, initial_ϵ_loss: 6.0922e+01
boundary_n_loss: 2.6051e+02, boundary_u_loss: 4.8423e+02, boundary_ϵ_loss: 5.1345e+02





Epoch: 85000, Loss: 3.2190e+07
density_loss: 4.9271e+05, vorticity_loss: 3.8147e+05, tpe_loss: 9.5694e+07
initial_n_loss: 9.1005e+01, initial_u_loss: 2.0228e+02, initial_ϵ_loss: 9.5939e+01
boundary_n_loss: 3.2058e+02, boundary_u_loss: 5.9967e+02, boundary_ϵ_loss: 5.1009e+02





Epoch: 87500, Loss: 1.1575e+07
density_loss: 4.7058e+05, vorticity_loss: 2.9680e+05, tpe_loss: 3.3957e+07
initial_n_loss: 6.4266e+01, initial_u_loss: 1.5540e+02, initial_ϵ_loss: 6.4905e+01
boundary_n_loss: 2.4931e+02, boundary_u_loss: 4.9094e+02, boundary_ϵ_loss: 5.0377e+02





Epoch: 90000, Loss: 9.7954e+06
density_loss: 4.8314e+05, vorticity_loss: 3.0257e+05, tpe_loss: 2.8599e+07
initial_n_loss: 5.5107e+01, initial_u_loss: 1.2754e+02, initial_ϵ_loss: 6.0749e+01
boundary_n_loss: 2.3808e+02, boundary_u_loss: 4.7678e+02, boundary_ϵ_loss: 5.5184e+02





Epoch: 92500, Loss: 8.1693e+06
density_loss: 5.0271e+05, vorticity_loss: 3.2650e+05, tpe_loss: 2.3677e+07
initial_n_loss: 5.2316e+01, initial_u_loss: 1.3007e+02, initial_ϵ_loss: 5.9158e+01
boundary_n_loss: 2.1795e+02, boundary_u_loss: 4.5793e+02, boundary_ϵ_loss: 6.4820e+02





Epoch: 95000, Loss: 6.6932e+06
density_loss: 3.9336e+05, vorticity_loss: 1.4799e+05, tpe_loss: 1.9537e+07
initial_n_loss: 3.8263e+01, initial_u_loss: 7.5349e+01, initial_ϵ_loss: 6.1167e+01
boundary_n_loss: 1.5092e+02, boundary_u_loss: 3.7703e+02, boundary_ϵ_loss: 8.7008e+02





Epoch: 97500, Loss: 3.9698e+07
density_loss: 5.4555e+05, vorticity_loss: 3.4499e+05, tpe_loss: 1.1820e+08
initial_n_loss: 5.4710e+01, initial_u_loss: 1.5003e+02, initial_ϵ_loss: 5.2556e+01
boundary_n_loss: 2.4850e+02, boundary_u_loss: 5.6967e+02, boundary_ϵ_loss: 6.2832e+02





Epoch: 99999, Loss: 1.2631e+07
density_loss: 6.0010e+05, vorticity_loss: 2.5930e+05, tpe_loss: 3.7031e+07
initial_n_loss: 4.7248e+01, initial_u_loss: 8.4115e+01, initial_ϵ_loss: 6.6181e+01
boundary_n_loss: 1.7316e+02, boundary_u_loss: 4.0018e+02, boundary_ϵ_loss: 8.9518e+02




## Debug

In [None]:
list(pinn.model.parameters())

In [None]:
t = torch.linspace(pars['t_min'], pars['t_max'], 100, device=pars['device'], requires_grad=True)
x = torch.linspace(pars['x_min'], pars['x_max'], 100, device=pars['device'], requires_grad=True)

x_range = pars['x_max'] - pars['x_min']
t_range = pars['t_max'] - pars['t_min']

T, X = torch.meshgrid(t, x)
n = torch.sin((5+torch.rand(1, device=pars['device']))*torch.pi*X/x_range + torch.rand(1, device=pars['device'])) * torch.sin((5+torch.rand(1, device=pars['device']))*torch.pi*T/t_range + torch.rand(1, device=pars['device']))
u = torch.sin((5+torch.rand(1, device=pars['device']))*torch.pi*X/x_range + torch.rand(1, device=pars['device'])) * torch.sin((5+torch.rand(1, device=pars['device']))*torch.pi*T/t_range + torch.rand(1, device=pars['device']))
ϵ = torch.sin((5+torch.rand(1, device=pars['device']))*torch.pi*X/x_range + torch.rand(1, device=pars['device'])) * torch.sin((5+torch.rand(1, device=pars['device']))*torch.pi*T/t_range + torch.rand(1, device=pars['device']))

In [None]:
T = T.reshape(-1, 1)
X = X.reshape(-1, 1)

torch.hstack((T, X))

In [None]:
# Create a 3D plot
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# Plot the surface
surf = ax.plot_surface(T.detach().cpu(), X.detach().cpu(), u.detach().cpu(), cmap='viridis', edgecolor='none')

# Add labels and title
ax.set_xlabel('t')
ax.set_ylabel('x')
ax.set_zlabel('u')
ax.set_title('Double Sine Surface Plot')

# Add a color bar
fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)

# Show the plot
plt.show()