In [1]:
import numpy as np
import torch

from pina.problem import SpatialProblem, TimeDependentProblem
from pina.operators import nabla, grad, div, curl, advection
from pina import Condition, Span, LabelTensor


class QGE(SpatialProblem, TimeDependentProblem):
    """Quasi-geostrophic equation (QGE) problem definition for PINA.

    Unknowns are ``q`` and ``si`` (``si`` is used as a streamfunction: its
    curl is the velocity that advects ``q`` in ``eq1``). They are defined on
    ``x in [0, 1]``, ``y in [-1, 1]`` and ``t in [0, 100]``.

    Each condition function below takes ``(input_, output_)`` and returns a
    residual tensor that training drives towards zero:

    * ``eq1``        -- dq/dt + advection(q; curl(si)) - (1/Re) * nabla(q) - sin(pi*y)
    * ``eq2``        -- q + r0 * nabla(si) - y
    * ``continuity`` -- div(curl(si))
    * ``initial``    -- si - 0   (imposed at t = 0)
    * ``si``         -- si - 0   (imposed on the four walls)
    * ``zeta``       -- q - y    (imposed on the four walls)
    """

    # Domain bounds, defined once so that the problem domains and every
    # initial/boundary/interior condition below stay consistent if edited.
    _X_RANGE = [0, 1]
    _Y_RANGE = [-1, 1]
    _T_RANGE = [0, 100]

    output_variables = ['q', 'si']
    spatial_domain = Span({'x': _X_RANGE, 'y': _Y_RANGE})
    temporal_domain = Span({'t': _T_RANGE})

    def rand_choice_integer_Data(self):
        """No-op placeholder; not referenced in the visible cells.

        Kept only so the public interface of the class is unchanged.
        """
        pass

    # The functions below take no ``self``: they are plain functions that are
    # referenced directly as condition callbacks in ``conditions``.

    def eq1(input_, output_):
        """Residual of the q-transport equation (interior condition)."""
        nu = 0.0022
        Re = 100 / nu  # 1/Re scales the diffusive term below

        # convective term: q advected by the velocity field curl(si)
        si_curl = curl(output_.extract(['si']), input_, d=['x', 'y'])
        convective_ = advection(output_.extract(['q']), input_, velocity_field=si_curl, d=['x', 'y'])

        # diffusive term
        diffusive_ = nabla(output_.extract(['q']), input_, d=['x', 'y'])

        # transient term: grad returns all partials of q; keep only d/dt
        du = grad(output_.extract(['q']), input_)
        transient_ = du.extract(['dqdt'])

        # forcing term
        force_ = torch.sin(torch.pi * input_.extract(['y']))

        return transient_ + convective_ - (1 / Re) * diffusive_ - force_

    def eq2(input_, output_):
        """Residual q + r0 * nabla(si) - y linking q and si (interior condition)."""
        r0 = 0.0036

        output = output_.extract(['q']) + r0 * nabla(output_.extract(['si']), input_, d=['x', 'y']) - input_.extract(['y'])
        return output

    def continuity(input_, output_):
        """Residual div(curl(si)) (interior condition).

        NOTE(review): for a smooth 2-D scalar si this is identically zero
        analytically, so this condition likely adds autograd cost without
        constraining training -- consider dropping it.
        """
        si_curl = curl(output_.extract(['si']), input_, d=['x', 'y'])
        return div(si_curl, input_, d=['x', 'y'])

    def initial(input_, output_):
        """Residual si - 0 (used as the t = 0 condition)."""
        value = 0.0
        return output_.extract(['si']) - value

    def zeta(input_, output_):
        """Residual q - y (used on the walls)."""
        value = input_.extract(['y'])
        return output_.extract(['q']) - value

    def si(input_, output_):
        """Residual si - 0 (used on the walls)."""
        si_expected = 0.0
        return output_.extract(['si']) - si_expected

    conditions = {
        # initial condition over the whole spatial domain at the first time
        't0': Condition(Span({'x': _X_RANGE, 'y': _Y_RANGE, 't': _T_RANGE[0]}), initial),

        # walls (y = top, x = left, x = right, y = bottom), for every t
        'upper': Condition(Span({'x': _X_RANGE, 'y': _Y_RANGE[1], 't': _T_RANGE}), [si, zeta]),
        'fixedWall1': Condition(Span({'x': _X_RANGE[0], 'y': _Y_RANGE, 't': _T_RANGE}), [si, zeta]),
        'fixedWall2': Condition(Span({'x': _X_RANGE[1], 'y': _Y_RANGE, 't': _T_RANGE}), [si, zeta]),
        'fixedWall3': Condition(Span({'x': _X_RANGE, 'y': _Y_RANGE[0], 't': _T_RANGE}), [si, zeta]),

        # interior: governing equations
        'D': Condition(Span({'x': _X_RANGE, 'y': _Y_RANGE, 't': _T_RANGE}), [eq1, eq2, continuity]),
    }


In [2]:
import sys
import numpy as np
import torch
from torch.nn import ReLU, Tanh, Softplus

from pina import PINN, LabelTensor, Plotter
from pina.model import FeedForward

#args
# Run configuration -- edit these values to change what this cell does.
id_run = 0
save = True  # True: sample points, train and save; False: load the saved state and plot it
seed = id_run  # one seed per run id: each run is reproducible, distinct run ids differ

epochs = 100
print_every = 10  # second positional argument of pinn.train (loss-print frequency, in epochs)
state_path = 'pina.qge'
history_path = 'qge_history_{}.txt'.format(id_run)

# Seed the RNGs that may affect weight initialisation and point sampling,
# so Restart & Run All reproduces the same training run.
torch.manual_seed(seed)
np.random.seed(seed)

qge_problem = QGE()
model = FeedForward(
    layers=[10, 10, 10, 10],
    output_variables=qge_problem.output_variables,
    input_variables=qge_problem.input_variables,
    func=Softplus,
)
pinn = PINN(
    qge_problem,
    model,
    lr=0.006,
    error_norm='mse',
    regularizer=1e-8)

if save:
    # Grid-sample collocation points in t, x and y for every condition location.
    pinn.span_pts(
        {'n': 100, 'mode': 'grid', 'variables': 't'},
        {'n': 4, 'mode': 'grid', 'variables': 'x'},
        {'n': 8, 'mode': 'grid', 'variables': 'y'},
        locations=['t0', 'upper', 'fixedWall1', 'fixedWall2', 'fixedWall3', 'D'])

    pinn.train(epochs, print_every)

    # One line per entry of pinn.history_loss: "<key> <summed loss>".
    with open(history_path, 'w') as file_:
        for i, losses in pinn.history_loss.items():
            file_.write('{} {}\n'.format(i, sum(losses)))
    pinn.save_state(state_path)
else:
    pinn.load_state(state_path)
    plotter = Plotter()
    plotter.plot(pinn, components='si')
    plotter.plot_loss(pinn)


              sum          t0initial    uppersi      upperzeta    fixedWall1si fixedWall1ze fixedWall2si fixedWall2ze fixedWall3si fixedWall3ze Deq1         Deq2         Dcontinuity  
[epoch 00000] 1.069388e+01 1.047776e+00 1.463900e+00 1.208108e+00 1.451941e+00 4.462402e-01 1.460613e+00 4.470644e-01 1.448370e+00 8.356909e-01 4.375242e-01 4.466504e-01 1.891934e-22 
              sum          t0initial    uppersi      upperzeta    fixedWall1si fixedWall1ze fixedWall2si fixedWall2ze fixedWall3si fixedWall3ze Deq1         Deq2         Dcontinuity  
[epoch 00001] 7.857171e+00 8.288423e-01 8.192036e-01 8.542905e-01 8.075427e-01 4.388306e-01 8.157874e-01 4.381738e-01 8.038858e-01 1.174595e+00 4.375193e-01 4.385004e-01 1.960187e-22 
              sum          t0initial    uppersi      upperzeta    fixedWall1si fixedWall1ze fixedWall2si fixedWall2ze fixedWall3si fixedWall3ze Deq1         Deq2         Dcontinuity  
[epoch 00010] 4.183349e+00 9.235127e-02 7.958755e-02 1.161971e+00 8.143102e-02 4