In [1]:
# Activate the project environment located in the current directory.
import Pkg; Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`


In [2]:
using Revise
using Clapeyron
includet("./saftvrmienn.jl")
import Clapeyron: a_res

using Flux
using Plots
using ForwardDiff, DiffResults

using Zygote, ChainRulesCore



In [3]:
# Reference Clapeyron SAFT-VR Mie model for pure methane.
model = SAFTVRMie(["methane"])

# Print the stored values of every parameter on the model. Each field is read
# with `getproperty` instead of building a string of code and `eval`-ing it,
# which is slower, only works at global scope, and hides errors until runtime.
for name in fieldnames(typeof(model.params))
    println(name, ": ", getproperty(model.params, name).values)
end

Mw: [16.04]
segment: [1.0]
sigma: [3.737e-10;;]
lambda_a: [6.0;;]
lambda_r: [12.504;;]
epsilon: [152.58;;]
epsilon_assoc: Clapeyron.Compressed4DMatrix{Float64, Vector{Float64}}Float64[]
bondvol: Clapeyron.Compressed4DMatrix{Float64, Vector{Float64}}Float64[]


In [4]:
# Hand-built SAFTVRMieNN model using the same numeric parameter values that the
# Clapeyron methane model reports (Mw, segment, sigma, lambda_a, lambda_r,
# epsilon). The association parameters are not passed here.
x = let
    methane = SAFTVRMieNNParams(;
        Mw = [16.04],
        segment = [1.0],
        sigma = [3.737e-10],
        lambda_a = [6.0],
        lambda_r = [12.504],
        epsilon = [152.58],
    )
    SAFTVRMieNN(; params = methane)
end

SAFTVRMieNN with 1 component:
 "methane"
Contains parameters: Mw, segment, sigma, lambda_a, lambda_r, epsilon, epsilon_assoc, bondvol

In [5]:
# Evaluate a_res for the hand-built model `x` and the reference `model` with
# identical arguments so the two printed values can be compared directly.
# (Arguments are presumably volume, temperature and composition, following
# Clapeyron's a_res(model, V, T, z) convention — confirm.)
@show a_res(x, 1e-4, 300.0, [1.0])
@show a_res(model, 1e-4, 300.0, [1.0])


a_res(x, 0.0001, 300.0, [1.0]) = -0.2883849961140881
a_res(model, 0.0001, 300.0, [1.0]) = -0.2883849961140881


-0.2883849961140881

In [6]:
"""
    differentiable_saft(X)

Build a single-component `SAFTVRMieNN` model from the parameter vector `X` and
return `a_res(model, 1e-4, 300.0, [1.0])`.

`X` is read positionally: `X[1]` is used as `sigma`, `X[2]` as `lambda_a`,
`X[3]` as `lambda_r` and `X[4]` as `epsilon`. `Mw` and `segment` are fixed at
`16.04` and `1.0`, and the association parameters are passed as empty vectors.

The function maps one vector to one scalar, so it can be handed directly to
`ForwardDiff.gradient` or `Zygote.gradient`.
"""
function differentiable_saft(X)
    model = SAFTVRMieNN(
        params = SAFTVRMieNNParams(
            Mw=[16.04],
            segment=[1.0],
            sigma=[X[1]],
            lambda_a=[X[2]],
            lambda_r=[X[3]],
            epsilon=[X[4]],
            epsilon_assoc=Float64[],
            bondvol=Float64[],
        )
    )
    return a_res(model, 1e-4, 300.0, [1.0])
end

# Sanity check: evaluate at the same sigma, lambda_a, lambda_r and epsilon
# values that were used to build `x`.
differentiable_saft([3.737e-10, 6.0, 12.504, 152.58])

-0.2883849961140881

In [21]:
# Gradient of differentiable_saft with respect to [sigma, lambda_a, lambda_r,
# epsilon], computed with ForwardDiff.
g1 = ForwardDiff.gradient(differentiable_saft, [3.737e-10, 6.0, 12.504, 152.58])

4-element Vector{Float64}:
 -1.2753186700741215e9
  0.23877186863500233
  0.053846521804507934
 -0.006372733864435557

In [22]:
# Same gradient computed with Zygote; the result is a 1-tuple whose only
# element is the gradient vector.
g2 = Zygote.gradient(differentiable_saft, [3.737e-10, 6.0, 12.504, 152.58])

([-1.2753186700741205e9, 0.23877186863500235, 0.0538465218045079, -0.006372733864435554],)

In [25]:
# Check that the Zygote and ForwardDiff gradients agree to floating-point tolerance.
isapprox(first(g2), g1)

true

In [9]:
# Thin wrapper around `differentiable_saft`. Having its own function type lets
# a custom `ChainRulesCore.rrule` be attached to `f` without affecting
# `differentiable_saft` itself.
f(X) = differentiable_saft(X)

# Custom reverse-mode rule for `f`: the primal value is computed normally, and
# the pullback obtains the gradient from ForwardDiff rather than letting the
# reverse-mode AD trace through `f`.
#
# Returns the tuple `(y, f_pullback)`, where `f_pullback(Δy)` yields
# `(NoTangent(), ∂x)`: no tangent for the function object itself, and the
# cotangent for `x`.
function ChainRulesCore.rrule(::typeof(f), x)
    y = f(x)

    function f_pullback(Δy)
        # The incoming cotangent may arrive as a lazy thunk; materialise it
        # before using it in arithmetic.
        ȳ = unthunk(Δy)
        # `f` returns a scalar, so the cotangent of `x` is the ForwardDiff
        # gradient scaled by the scalar ȳ.
        ∂x = ForwardDiff.gradient(f, x) .* ȳ
        return (NoTangent(), ∂x)
    end

    return y, f_pullback
end

# Test the gradient computation: differentiating `f` with Zygote should go
# through the custom rrule defined for `f`.
# NOTE(review): the output recorded for this cell differs from g1/g2 in the
# third significant figure. The cell counter (In [9] vs In [21]/[22]) suggests
# it was executed earlier, possibly against a different revision of the
# included model file — re-run and confirm it matches g1.
Zygote.gradient(f, [3.737e-10, 6.0, 12.504, 152.58])

([-1.2769666505152988e9, 0.23822451028384095, 0.053698344981408117, -0.006493795105749014],)