### Importing all necessary packages

In [None]:
using Flux
using CPUTime
using NeuralPDE
using GalacticOptim
using Optim
using DiffEqFlux
using Quadrature
using CUDA
using Cuba
using QuasiMonteCarlo
using ModelingToolkit #it contains @variables and parameters

### Defining the Equation, Domains, and Initial/Boundary Conditions

In [None]:
@parameters x y t
@variables u(..)

# Differential operators: first order in time, second order in each space direction.
Dt  = Differential(t)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2-D heat (diffusion) equation with unit diffusivity: u_t = u_xx + u_yy
eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))

# Closed-form solution, used below as initial/boundary data and later for validation.
# It satisfies the PDE: u_t = -4 e^(x+y) sin(x+y+4t) and
# u_xx = u_yy = -2 e^(x+y) sin(x+y+4t).
analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)

# Space-time box [0, 2]^3, listed in the same (t, x, y) order as u's arguments.
domains = [t ∈ IntervalDomain(0.0, 2.0),
           x ∈ IntervalDomain(0.0, 2.0),
           y ∈ IntervalDomain(0.0, 2.0)]

# Initial condition at t = 0, then Dirichlet data on the four spatial edges
# (x = 0, x = 2, y = 0, y = 2), all taken from the analytic solution.
bcs = [u(0.0, x, y) ~ analytic_sol_func(0.0, x, y),
       u(t, 0.0, y) ~ analytic_sol_func(t, 0.0, y),
       u(t, 2.0, y) ~ analytic_sol_func(t, 2.0, y),
       u(t, x, 0.0) ~ analytic_sol_func(t, x, 0.0),
       u(t, x, 2.0) ~ analytic_sol_func(t, x, 2.0)]

### Creating the NN and Defining the PDE problem

In [None]:
# Fully connected network: 3 inputs (t, x, y) -> three hidden layers of 10 sigmoid
# units -> 1 output (no activation on the last layer).
# Trainable parameters (weights + biases):
#   (3*10 + 10) + (10*10 + 10) + (10*10 + 10) + (10*1 + 1) = 40 + 110 + 110 + 11 = 271
chain = FastChain(FastDense(3, 10, Flux.σ),
                  FastDense(10, 10, Flux.σ),
                  FastDense(10, 10, Flux.σ),
                  FastDense(10, 1))

# Initial parameter vector, moved to the GPU via Flux's `gpu`.
# NOTE(review): `gpu` presumably changes the element type/precision of the
# parameters — confirm this matches what the training strategy expects.
initθ = DiffEqFlux.initial_params(chain) |> gpu

# Training strategy: 3000 quasi-random collocation points drawn with
# `UniformSample()`; `minibatch = 50` is passed through to NeuralPDE's
# QuasiRandomTraining (see its documentation for the exact semantics).
strategy = NeuralPDE.QuasiRandomTraining(3000; #points
                                         sampling_alg = UniformSample(),
                                         minibatch = 50)
# Discretization: physics-informed NN built from the chain, the strategy and initθ
discretization = NeuralPDE.PhysicsInformedNN(chain,
                                             strategy;
                                             init_params = initθ)
# Problem formulation: symbolic PDE system -> optimisation problem over θ
@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u])
prob = NeuralPDE.discretize(pde_system, discretization)

In [None]:
# Optimiser callback: print the loss after every iteration and return `false`,
# so training never stops early.
cb = (p, l) -> begin
    println("Current loss is: $l")
    false
end

# Three-stage ADAM schedule with decaying learning rate (0.1 → 0.01 → 0.001),
# 1500 iterations each. Stages 2 and 3 restart from the previous stage's
# minimiser; the final result is left in the global `res`.
for (stage, η) in enumerate((0.1, 0.01, 0.001))
    global prob, res
    if stage > 1
        prob = remake(prob, u0 = res.minimizer)
    end
    @time @CPUtime res = GalacticOptim.solve(prob, ADAM(η); cb = cb, maxiters = 1500)
end

In [None]:
# Trained trial solution: phi(input, θ) evaluates the network at an input point.
phi = discretization.phi

# Evaluation grid with spacing 0.1 along t, x and y; bounds come from `domains`.
ts, xs, ys = map(d -> d.domain.lower:0.1:d.domain.upper, domains)

# Exact and predicted values flattened into vectors, with t outermost and y
# innermost (Iterators.product varies its first argument fastest).
u_real = vec([analytic_sol_func(t, x, y)
              for (y, x, t) in Iterators.product(ys, xs, ts)])
u_predict = vec([first(Array(phi([t, x, y], res.minimizer)))
                 for (y, x, t) in Iterators.product(ys, xs, ts)])

### Plotting the animation

In [None]:
using Plots
using Printf

"""
    plot_(res; times = 0:0.05:2.0, fps = 10, path = "3pde.gif")

Animate the trained solution over `times` and save the animation as a GIF at `path`.

Each frame shows three panels over the global `xs` × `ys` grid:
- the network prediction, as a surface;
- the analytic solution, as a surface;
- a filled contour of their absolute difference.

`res` is the optimisation result; its `minimizer` holds the trained parameters.
The network is evaluated through the global `phi`.
"""
function plot_(res; times = 0:0.05:2.0, fps = 10, path = "3pde.gif")
    anim = @animate for (i, t) in enumerate(times)
        @info "Animating frame $i..."
        # Rows index y, columns index x: the (length(ys), length(xs)) layout that
        # Plots expects for z in plot(xs, ys, z). This stays correct even when the
        # x and y grids differ in length.
        u_real = [analytic_sol_func(t, x, y) for y in ys, x in xs]
        u_predict = [Array(phi([t, x, y], res.minimizer))[1] for y in ys, x in xs]
        u_error = abs.(u_predict .- u_real)
        p1 = plot(xs, ys, u_predict; st = :surface, label = "",
                  title = @sprintf("predict t = %.3f", t))
        p2 = plot(xs, ys, u_real; st = :surface, label = "", title = "real")
        p3 = plot(xs, ys, u_error; st = :contourf, label = "", title = "error")
        plot(p1, p2, p3)
    end
    return gif(anim, path; fps = fps)
end

plot_(res)

### Checking the solution at t = 1.0 s

In [None]:
# Compare prediction and analytic solution on the (xs, ys) grid at a single time.
# Use `t_check` rather than `t`: assigning to `t` would overwrite the symbolic
# parameter from @parameters, and re-running the PDE/discretization cells would
# then fail.
t_check = 1.0

# Rows index y, columns index x: the (length(ys), length(xs)) layout that Plots
# expects for z in plot(xs, ys, z).
u_real = [analytic_sol_func(t_check, x, y) for y in ys, x in xs]
u_predict = [Array(phi([t_check, x, y], res.minimizer))[1] for y in ys, x in xs]
u_error = abs.(u_predict .- u_real)

# Save each panel explicitly, rather than relying on the implicit "current plot".
p1 = plot(xs, ys, u_predict; st = :surface, label = "",
          title = @sprintf("predict t = %.2f sec", t_check))
savefig(p1, "2D-PDE-pred.pdf")
p2 = plot(xs, ys, u_real; st = :surface, label = "", title = "analytic")
savefig(p2, "2D-PDE-anal.pdf")
p3 = plot(xs, ys, u_error; st = :contourf, label = "", title = "error")
savefig(p3, "2D-PDE-err.pdf")
plot(p1, p2, p3)