In [1]:
using Lux, Reactant, Random, Optimisers, Enzyme, Statistics

const cdev = cpu_device()
const xdev = reactant_device(; force=true);

In [2]:
const L = 10.0f0
const t_end = 5.0f0
const ν = 0.1f0

const u_mean = 1.0f0
const u_amplitude = 0.5f0
const k = 2.0f0*pi/L
u0(x) = Float32(u_mean + u_amplitude * cos(k*x))

const λ_pde = 3.0f0
const λ_ic = 1.0f0
const λ_bc = 1.0f0;

In [3]:
struct PINN{M} <: AbstractLuxWrapperLayer{:model}
    model::M
end

function PINN(; hidden_dims::Int=32)
    return PINN(
        Chain(
            Dense(2 => hidden_dims, tanh),
            Dense(hidden_dims => hidden_dims, tanh),
            Dense(hidden_dims => hidden_dims, tanh),
            Dense(hidden_dims => 1),
        ),
    )
end;

In [4]:
function u(model::StatefulLuxLayer, xt::AbstractArray)
    return model(xt)
end

function ∂u_∂t(model::StatefulLuxLayer, xt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xt)[1][2, :]
end

function ∂u_∂x(model::StatefulLuxLayer, xt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xt)[1][1, :]
end
function ∂²u_∂x²(model::StatefulLuxLayer, xt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂x, Enzyme.Const(model), xt)[2][1, :]
end;

In [5]:
function physics_informed_loss_function(model::StatefulLuxLayer, xt::AbstractArray)
    return mean(abs2, ∂u_∂t(model, xt) .+ u(model, xt) .* ∂u_∂x(model, xt) .- (ν .* ∂²u_∂x²(model, xt)))
end
function bc_loss_function(model::StatefulLuxLayer, xt_i::AbstractArray, xt_f::AbstractArray)
    u_i = model(xt_i)
    u_f = model(xt_f)
    mean(abs2, u_i .- u_f)
end
function ic_loss_function(model::StatefulLuxLayer, xt::AbstractArray, target::AbstractArray)
    u = model(xt)
    mean(abs2, u .- target)
end
function loss_function(model, ps, st, (xt, xt_bc_i, xt_bc_f, xt_ic, target_ic))
    smodel = StatefulLuxLayer(model, ps, st)
    physics_loss = physics_informed_loss_function(smodel, xt)
    bc_loss = bc_loss_function(smodel, xt_bc_i, xt_bc_f)
    ic_loss = ic_loss_function(smodel, xt_ic, target_ic)
    loss = λ_pde * physics_loss + λ_bc * bc_loss + λ_ic * ic_loss
    return loss, smodel.st, (; physics_loss, bc_loss, ic_loss, total_loss=loss)
end;

In [6]:
function get_data(rng, batch_size_pde, batch_size_bc, batch_size_ic)
    xt = rand(rng, Float32, (2, batch_size_pde))
    xt[1, :] .*= L
    xt[2, :] .*= t_end

    rand_t_bc = rand(rng, Float32, batch_size_bc) .* t_end
    xt_bc_i = zeros(Float32, (2, batch_size_bc))
    xt_bc_i[2, :] = rand_t_bc

    xt_bc_f = fill(L, (2, batch_size_bc))
    xt_bc_f[2, :] = rand_t_bc
    
    rand_x_ic = rand(rng, Float32, batch_size_ic) .* L
    xt_ic = zeros(Float32, (2, batch_size_ic))
    xt_ic[1, :] = rand_x_ic

    target_ic = reshape(u0.(rand_x_ic), 1, :)
    return (xt, xt_bc_i, xt_bc_f, xt_ic, target_ic)
end;

In [7]:
function train_model(; seed::Int=0, maxiters::Int=10000, hidden_dims::Int=64, 
        batch_size_pde::Int=256, batch_size_bc::Int=64, batch_size_ic::Int=64)
    
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.005f0))

    lr = i -> i < 5000 ? 0.005f0 : (i < 10000 ? 0.0005f0 : 0.00005f0)

    for iter in 1:maxiters
        batch = get_data(rng, batch_size_pde, batch_size_bc, batch_size_ic) |> xdev

        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(AutoEnzyme(),
                                                                  loss_function,
                                                                  batch,
                                                                  train_state;
                                                                  return_gradients=Val(false),
        )

        if iter % 1000 == 0
            println("Iter: $(iter), loss: $(Float32(loss))")
        end
    end

    return StatefulLuxLayer(pinn, cdev(train_state.parameters), cdev(train_state.states))
end;

In [8]:
train_model()

Iter: 1000, loss: 0.009160802


I0000 00:00:1757272670.672039   14750 service.cc:163] XLA service 0x2c158190 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757272670.672057   14750 service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3060 Ti, Compute Capability 8.6
I0000 00:00:1757272670.672332   14750 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757272670.672359   14750 gpu_helpers.cc:136] XLA backend allocating 6159925248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757272670.672385   14750 gpu_helpers.cc:177] XLA backend will use up to 2053308416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757272670.682119   14750 cuda_dnn.cc:463] Loaded cuDNN version 90800
I0000 00:00:1757272670.689569   14750 cuda_executor.cc:517] failed to allocate 5.74GiB (6159925248 bytes) from device: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
I0000 00:00:1757272670.689661   14750 cuda_executor.cc:517] failed to allocate 5.

Iter: 2000, loss: 0.0064483285
Iter: 3000, loss: 0.00772568
Iter: 4000, loss: 0.0074739833
Iter: 5000, loss: 0.0064942813
Iter: 6000, loss: 0.0064157755
Iter: 7000, loss: 0.0061363177
Iter: 8000, loss: 0.006158486
Iter: 9000, loss: 0.005519745
Iter: 10000, loss: 0.0063517764
Iter: 11000, loss: 0.005922146
Iter: 12000, loss: 0.006265676
Iter: 13000, loss: 0.006338785
Iter: 14000, loss: 0.005706429
Iter: 15000, loss: 0.0058926856
Iter: 16000, loss: 0.006232427
Iter: 17000, loss: 0.006073523
Iter: 18000, loss: 0.0055466415
Iter: 19000, loss: 0.0061833877
Iter: 20000, loss: 0.0058193626
Iter: 21000, loss: 0.005880187
Iter: 22000, loss: 0.0059887934
Iter: 23000, loss: 0.0063040177
Iter: 24000, loss: 0.006266666
Iter: 25000, loss: 0.0063183843
Iter: 26000, loss: 0.006076712
Iter: 27000, loss: 0.0062433663
Iter: 28000, loss: 0.0057777427
Iter: 29000, loss: 0.0061264215
Iter: 30000, loss: 0.0055093113
Iter: 31000, loss: 0.0061090044
Iter: 32000, loss: 0.006177628
Iter: 33000, loss: 0.006438334

StatefulLuxLayer{Val{true}()}(
    PINN(
        model = Chain(
            layer_1 = Dense(2 => 64, tanh),       [90m# 192 parameters[39m
            layer_(2-3) = Dense(64 => 64, tanh),  [90m# 8_320 (4_160 x 2) parameters[39m
            layer_4 = Dense(64 => 1),             [90m# 65 parameters[39m
        ),
    ),
) [90m        # Total: [39m8_577 parameters,
[90m          #        plus [39m0 states.