In [5]:
import numpy as np
import torch

from pina.problem import SpatialProblem, TimeDependentProblem
from pina.operators import nabla, grad, div, curl, advection
from pina import Condition, Span, LabelTensor


class QGE(SpatialProblem, TimeDependentProblem):
    """Quasi-geostrophic-equation (QGE) problem definition for a PINN.

    The network predicts two fields, ``q`` and ``si`` (``si`` is used as a
    stream function: its curl provides the advecting velocity), on the
    spatial rectangle x in [0, 1], y in [-1, 1] and time interval t in [0, 100].

    Each residual function below has the signature ``f(input_, output_)``
    expected by ``Condition``: ``input_`` holds the sampled points (labels
    include ``'x'``, ``'y'``, ``'t'``) and ``output_`` holds the network
    prediction (labels ``'q'``, ``'si'``). A residual of zero means the
    corresponding equation/constraint is satisfied at those points.
    """

    output_variables = ['q', 'si']
    spatial_domain = Span({'x': [0, 1], 'y': [-1, 1]})
    temporal_domain = Span({'t': [0, 100]})

    # Equation coefficients, previously hard-coded inside the residuals.
    # They are bound as default arguments below (defaults are evaluated in
    # the class body, so these names are visible there).
    NU = 0.0022      # coefficient used to form the Reynolds number in eq1
    RE_SCALE = 100   # numerator of Re = RE_SCALE / NU (TODO: confirm physical meaning)
    R0 = 0.0036      # coefficient of the Laplacian of si in eq2

    def rand_choice_integer_Data(self):
        """Unimplemented placeholder; does nothing and returns None."""
        pass

    def eq1(input_, output_, nu=NU, re_scale=RE_SCALE):
        """Residual of the q evolution equation.

        residual = dq/dt + advection(q, curl(si)) - (1/Re) * nabla(q) - sin(pi*y),
        with Re = re_scale / nu. ``nu`` and ``re_scale`` default to the class
        constants and may be overridden to reuse the residual with other values.
        """
        reynolds = re_scale / nu
        q = output_.extract(['q'])

        # Convective term: q advected by the velocity field curl(si).
        si_curl = curl(output_.extract(['si']), input_, d=['x', 'y'])
        convective = advection(q, input_, velocity_field=si_curl, d=['x', 'y'])

        # Diffusive term (spatial derivatives only).
        diffusive = nabla(q, input_, d=['x', 'y'])

        # Transient term: grad labels its columns 'dq' + 'd<var>', so the
        # time derivative is the 'dqdt' column.
        transient = grad(q, input_).extract(['dqdt'])

        # Forcing term, depending on y only.
        force = torch.sin(torch.pi * input_.extract(['y']))

        return transient + convective - (1 / reynolds) * diffusive - force

    def eq2(input_, output_, r0=R0):
        """Residual linking q and si: q + r0 * nabla(si) - y.

        ``r0`` defaults to the class constant ``R0``.
        """
        laplacian_si = nabla(output_.extract(['si']), input_, d=['x', 'y'])
        return output_.extract(['q']) + r0 * laplacian_si - input_.extract(['y'])

    def continuity(input_, output_):
        """Residual div(curl(si)) over the x-y plane.

        NOTE(review): for a smooth scalar stream function this expression is
        identically zero by construction (mixed partials cancel), which matches
        the ~1e-21 'Dcontinuity' loss in the training log. It probably adds no
        training signal and could be dropped — TODO confirm before removing.
        """
        si_curl = curl(output_.extract(['si']), input_, d=['x', 'y'])
        return div(si_curl, input_, d=['x', 'y'])

    def initial(input_, output_):
        """Initial-condition residual for si: si(t=0) - 0."""
        value = 0.0
        return output_.extract(['si']) - value

    def zeta(input_, output_):
        """Initial-condition residual for q: q(t=0) - y."""
        value = input_.extract(['y'])
        return output_.extract(['q']) - value

    # Condition name -> (sampling region, residual function(s)).
    # 't0' and 'zeta_condition' are both imposed on the t = 0 slice;
    # 'D' imposes the governing equations on the full space-time domain.
    conditions = {
        't0': Condition(Span({'x': [0, 1], 'y': [-1, 1], 't': 0}), initial),
        'zeta_condition': Condition(Span({'x': [0, 1], 'y': [-1, 1], 't': 0}), zeta),
        'D': Condition(Span({'x': [0, 1], 'y': [-1, 1], 't': [0, 100]}), [eq1, eq2, continuity]),
    }


In [4]:
import sys
import numpy as np
import torch
from torch.nn import ReLU, Tanh, Softplus

from pina import PINN, LabelTensor, Plotter
from pina.model import FeedForward

# ---- Run configuration ------------------------------------------------------
# id_run tags the loss-history file and seeds torch, so runs are reproducible.
id_run = 0
# save=True  -> sample points, train from scratch, write history and state.
# save=False -> load the previously saved state and plot the results.
save = True

HIDDEN_LAYERS = [10, 10, 10, 10]
LEARNING_RATE = 0.006
REGULARIZER = 1e-8
N_T, N_X, N_Y = 100, 20, 20   # grid points per variable used by span_pts
N_EPOCHS = 100
# The captured log prints epochs 0, 1, 10, 20, ..., consistent with this being
# the logging frequency — TODO confirm against PINN.train's signature.
LOG_EVERY = 10
HISTORY_FILE = 'qge_history_{}.txt'.format(id_run)
# NOTE(review): unlike HISTORY_FILE, the state file is not tagged with id_run,
# so different runs overwrite each other. Kept unchanged so existing
# checkpoints still load.
STATE_FILE = 'pina.qge'

# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(id_run)

# ---- Problem, network and solver --------------------------------------------
qge_problem = QGE()
model = FeedForward(
    layers=HIDDEN_LAYERS,
    output_variables=qge_problem.output_variables,
    input_variables=qge_problem.input_variables,
    func=Softplus,
)
pinn = PINN(
    qge_problem,
    model,
    lr=LEARNING_RATE,
    error_norm='mse',
    regularizer=REGULARIZER,
)

# ---- Train-and-save, or load-and-plot ---------------------------------------
if save:
    pinn.span_pts(
        {'n': N_T, 'mode': 'grid', 'variables': 't'},
        {'n': N_X, 'mode': 'grid', 'variables': 'x'},
        {'n': N_Y, 'mode': 'grid', 'variables': 'y'},
        locations=['t0', 'zeta_condition', 'D'],
    )

    pinn.train(N_EPOCHS, LOG_EVERY)

    # One line per history_loss entry: "<key> <sum of the per-condition losses>".
    with open(HISTORY_FILE, 'w') as file_:
        for i, losses in pinn.history_loss.items():
            file_.write('{} {}\n'.format(i, sum(losses)))
    pinn.save_state(STATE_FILE)
else:
    pinn.load_state(STATE_FILE)
    plotter = Plotter()
    plotter.plot(pinn, components='si')
    plotter.plot_loss(pinn)


              sum          t0initial    zeta_conditi Deq1         Deq2         Dcontinuity  
[epoch 00000] 3.167565e+00 2.628742e-01 1.187286e+00 4.750442e-01 1.242361e+00 5.966882e-21 
              sum          t0initial    zeta_conditi Deq1         Deq2         Dcontinuity  
[epoch 00001] 2.714045e+00 1.894580e-01 1.014704e+00 4.750258e-01 1.034857e+00 5.397433e-21 
              sum          t0initial    zeta_conditi Deq1         Deq2         Dcontinuity  
[epoch 00010] 1.231908e+00 1.761013e-02 3.703571e-01 4.749822e-01 3.689583e-01 3.265997e-21 
              sum          t0initial    zeta_conditi Deq1         Deq2         Dcontinuity  
[epoch 00020] 1.260655e+00 9.735573e-05 3.997174e-01 4.749905e-01 3.858493e-01 1.039424e-21 
              sum          t0initial    zeta_conditi Deq1         Deq2         Dcontinuity  
[epoch 00030] 1.221339e+00 5.372635e-03 3.674332e-01 4.750073e-01 3.735260e-01 5.570385e-22 
              sum          t0initial    zeta_conditi Deq1         Deq2