In [1]:
using Lux,
    Optimisers,
    Random,
    Printf,
    Statistics,
    MLUtils,
    Reactant,
    Enzyme

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

(::CPUDevice) (generic function with 1 method)

In [2]:
struct PINN{M} <: AbstractLuxWrapperLayer{:model}
    model::M
end

function PINN(; hidden_dims::Int=32)
    return PINN(
        Chain(
            Dense(3 => hidden_dims, tanh),
            Dense(hidden_dims => hidden_dims, tanh),
            Dense(hidden_dims => hidden_dims, tanh),
            Dense(hidden_dims => 1),
        ),
    )
end

PINN

In [3]:
function u(model::StatefulLuxLayer, xyt::AbstractArray)
    return model(xyt)
end

function ∂u_∂t(model::StatefulLuxLayer, xyt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][3, :]
end

function ∂u_∂x(model::StatefulLuxLayer, xyt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][1, :]
end

function ∂u_∂y(model::StatefulLuxLayer, xyt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ model, xyt)[1][2, :]
end

function ∂²u_∂x²(model::StatefulLuxLayer, xyt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂x, Enzyme.Const(model), xyt)[2][1, :]
end

function ∂²u_∂y²(model::StatefulLuxLayer, xyt::AbstractArray)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ ∂u_∂y, Enzyme.Const(model), xyt)[2][2, :]
end

∂²u_∂y² (generic function with 1 method)

In [4]:
function physics_informed_loss_function(model::StatefulLuxLayer, xyt::AbstractArray)
    return mean(abs2, ∂u_∂t(model, xyt) .- u(model, xyt) .* ∂²u_∂x²(model, xyt) .- ∂²u_∂y²(model, xyt))
end
function mse_loss_function(model::StatefulLuxLayer, target::AbstractArray, xyt::AbstractArray)
    return MSELoss()(model(xyt), target)
end
function loss_function(model, ps, st, (xyt, target_data, xyt_bc, target_bc))
    smodel = StatefulLuxLayer(model, ps, st)
    physics_loss = physics_informed_loss_function(smodel, xyt)
    data_loss = mse_loss_function(smodel, target_data, xyt)
    bc_loss = mse_loss_function(smodel, target_bc, xyt_bc)
    loss = physics_loss + data_loss + bc_loss
    return loss, smodel.st, (; physics_loss, data_loss, bc_loss)
end 

loss_function (generic function with 1 method)

In [5]:
analytical_solution(x, y, t) = @. exp(x + y) * cos(x + y + 4t)
analytical_solution(xyt) = analytical_solution(xyt[1, :], xyt[2, :], xyt[3, :])

analytical_solution (generic function with 2 methods)

In [6]:
begin
    grid_len = 16

    grid = range(0.0f0, 2.0f0; length=grid_len)
    xyt = stack([[elem...] for elem in vec(collect(Iterators.product(grid, grid, grid)))])

    target_data = reshape(analytical_solution(xyt), 1, :)

    bc_len = 512

    x = collect(range(0.0f0, 2.0f0; length=bc_len))
    y = collect(range(0.0f0, 2.0f0; length=bc_len))
    t = collect(range(0.0f0, 2.0f0; length=bc_len))

    xyt_bc = hcat(
        stack((x, y, zeros(Float32, bc_len)); dims=1),
        stack((zeros(Float32, bc_len), y, t); dims=1),
        stack((ones(Float32, bc_len) .* 2, y, t); dims=1),
        stack((x, zeros(Float32, bc_len), t); dims=1),
        stack((x, ones(Float32, bc_len) .* 2, t); dims=1),
    )
    target_bc = reshape(analytical_solution(xyt_bc), 1, :)

    min_target_bc, max_target_bc = extrema(target_bc)
    min_data, max_data = extrema(target_data)
    min_pde_val, max_pde_val = min(min_data, min_target_bc), max(max_data, max_target_bc)

    xyt = (xyt .- minimum(xyt)) ./ (maximum(xyt) .- minimum(xyt))
    xyt_bc = (xyt_bc .- minimum(xyt_bc)) ./ (maximum(xyt_bc) .- minimum(xyt_bc))
    target_bc = (target_bc .- min_pde_val) ./ (max_pde_val - min_pde_val)
    target_data = (target_data .- min_pde_val) ./ (max_pde_val - min_pde_val)
end

1×4096 Matrix{Float32}:
 0.511  0.512222  0.513394  0.514452  …  0.761361  0.839134  0.926969

In [7]:
rng = Random.default_rng()
Random.seed!(rng, 0)
pinn = PINN(; hidden_dims=10)
ps, st = Lux.setup(rng, pinn) |> xdev
train_state = Training.TrainState(pinn, ps, st, Adam(0.005f0))

xyt_batch = xyt[:, 1:40] |> xdev
target_data_batch = target_data[:, 1:40] |> xdev
xyt_bc_batch = xyt_bc[:, 1:40] |> xdev
target_bc_batch = target_bc[:, 1:40] |> xdev

_, loss, stats, train_state = Training.single_train_step!(
        AutoEnzyme(),
        loss_function,
        (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
        train_state;
        return_gradients=Val(false),
    );

I0000 00:00:1757270412.409325   11786 service.cc:163] XLA service 0x38914110 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757270412.409346   11786 service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3060 Ti, Compute Capability 8.6
I0000 00:00:1757270412.409669   11786 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757270412.409699   11786 gpu_helpers.cc:136] XLA backend allocating 6159925248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757270412.409725   11786 gpu_helpers.cc:177] XLA backend will use up to 2053308416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757270412.421045   11786 cuda_dnn.cc:463] Loaded cuDNN version 90800


In [7]:
function train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0, maxiters::Int=50000, hidden_dims::Int=128)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    pinn = PINN(; hidden_dims)
    ps, st = Lux.setup(rng, pinn) |> xdev

    bc_dataloader = DataLoader((xyt_bc, target_bc); batchsize=128, shuffle=true, partial=false) |> xdev
    pde_dataloader = DataLoader((xyt, target_data); batchsize=128, shuffle=true, partial=false) |> xdev

    train_state = Training.TrainState(pinn, ps, st, Adam(0.005f0))

    lr = i -> i < 5000 ? 0.005f0 : (i < 10000 ? 0.0005f0 : 0.00005f0)

    iter = 1
    for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in
        zip(Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader))
        Optimisers.adjust!(train_state, lr(iter))

        _, loss, stats, train_state = Training.single_train_step!(
            AutoEnzyme(),
            loss_function,
            (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch),
            train_state;
            return_gradients=Val(false),
        )

        if iter % 1000 == 0
            println("Iter: $(iter), loss: $(Float32(loss))")
        end
        
        iter += 1
        iter ≥ maxiters && break
    end

    return StatefulLuxLayer(pinn, cdev(train_state.parameters), cdev(train_state.states))
end

trained_model = train_model(xyt, target_data, xyt_bc, target_bc)         

Iter: 1000, loss: 0.016553883


I0000 00:00:1757266283.547096    7194 service.cc:163] XLA service 0x322e16c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757266283.547117    7194 service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3060 Ti, Compute Capability 8.6
I0000 00:00:1757266283.547477    7194 se_gpu_pjrt_client.cc:1338] Using BFC allocator.
I0000 00:00:1757266283.547566    7194 gpu_helpers.cc:136] XLA backend allocating 6159925248 bytes on device 0 for BFCAllocator.
I0000 00:00:1757266283.547635    7194 gpu_helpers.cc:177] XLA backend will use up to 2053308416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1757266283.558691    7194 cuda_dnn.cc:463] Loaded cuDNN version 90800


Iter: 2000, loss: 0.014384383
Iter: 3000, loss: 0.01110726
Iter: 4000, loss: 0.00680809
Iter: 5000, loss: 0.009741962
Iter: 6000, loss: 0.00096959283
Iter: 7000, loss: 0.0009281422
Iter: 8000, loss: 0.0013213546
Iter: 9000, loss: 0.0005148426
Iter: 10000, loss: 0.0009806744
Iter: 11000, loss: 0.00036486753
Iter: 12000, loss: 0.000426544
Iter: 13000, loss: 0.0002843058
Iter: 14000, loss: 0.0004991124
Iter: 15000, loss: 0.00025018654
Iter: 16000, loss: 0.00031640788
Iter: 17000, loss: 0.00039353993
Iter: 18000, loss: 0.00031272764
Iter: 19000, loss: 0.0002892317
Iter: 20000, loss: 0.00024611966
Iter: 21000, loss: 0.00022522072
Iter: 22000, loss: 0.00023999666
Iter: 23000, loss: 0.0003094984
Iter: 24000, loss: 0.0002671322
Iter: 25000, loss: 0.00018351548
Iter: 26000, loss: 0.00035562256
Iter: 27000, loss: 0.00019729715
Iter: 28000, loss: 0.0002448506
Iter: 29000, loss: 0.0002386901
Iter: 30000, loss: 0.00020195781
Iter: 31000, loss: 0.00017870276
Iter: 32000, loss: 0.00021166826
Iter: 33

StatefulLuxLayer{Val{true}()}(
    PINN(
        model = Chain(
            layer_1 = Dense(3 => 128, tanh),      [90m# 512 parameters[39m
            layer_(2-3) = Dense(128 => 128, tanh),  [90m# 33_024 (16_512 x 2) parameters[39m
            layer_4 = Dense(128 => 1),            [90m# 129 parameters[39m
        ),
    ),
) [90m        # Total: [39m33_665 parameters,
[90m          #        plus [39m0 states.

In [16]:
dat = [1.0f0 2.0f0 3.0f0 1.0f0;
       1.1f0 2.1f0 3.1f0 4.1f0;
       1.2f0 2.2f0 3.2f0 4.2f0]
trained_model(dat)

1×4 Matrix{Float32}:
 1.57821  1.97627  1.85081  1.96543