In [1]:
# Set up a custom chain rule to evaluate the loss function at saturation without
# propagating derivatives through the saturation solver.
# This was done in the Language of Molecules paper.

# 1_differentiable_saft works when no solvers are involved, i.e. for properties specified at a given (V, T).
# If, instead, we want to solve for saturation conditions, then we hit the old problem of
# propagating derivatives through the iterative saturation solver.

In [2]:
import Pkg; Pkg.activate(".")

using Revise
using Clapeyron
includet("./saftvrmienn.jl")
# These are functions we're going to overload for SAFTVRMieNN
import Clapeyron: a_res, saturation_pressure, pressure

using Flux
using Plots
using ForwardDiff, DiffResults

using Zygote, ChainRulesCore
using ImplicitDifferentiation

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`




In [3]:
"""
    make_NN_model(Mw, m, σ, λ_a, λ_r, ϵ)

Build a single-component `SAFTVRMieNN` model from scalar parameters.

Each scalar is wrapped in a one-element vector before being handed to
`SAFTVRMieNNParams`; the association parameters (`epsilon_assoc`, `bondvol`)
are passed as empty `Float32` vectors.
"""
function make_NN_model(Mw, m, σ, λ_a, λ_r, ϵ)
    nn_params = SAFTVRMieNNParams(;
        Mw = [Mw],
        segment = [m],
        sigma = [σ],
        lambda_a = [λ_a],
        lambda_r = [λ_r],
        epsilon = [ϵ],
        epsilon_assoc = Float32[],
        bondvol = Float32[],
    )
    return SAFTVRMieNN(; params = nn_params)
end

# Hack to get around Clapeyron model construction API
"""
    make_model(Mw, m, σ, λ_a, λ_r, ϵ)

Return a Clapeyron `SAFTVRMie` model whose first-component parameters have been
overwritten with the supplied values.

The model is first constructed for `"methane"` and then the first entry of each
of `Mw`, `segment`, `sigma`, `lambda_a`, `lambda_r` and `epsilon` in
`model.params` is replaced in place.
"""
function make_model(Mw, m, σ, λ_a, λ_r, ϵ)
    model = SAFTVRMie(["methane"])
    params = model.params

    # (field name, new value) pairs, written in the same order as the arguments.
    overrides = (
        (:Mw, Mw),
        (:segment, m),
        (:sigma, σ),
        (:lambda_a, λ_a),
        (:lambda_r, λ_r),
        (:epsilon, ϵ),
    )
    for (field, value) in overrides
        getproperty(params, field)[1] = value
    end

    return model
end


make_model (generic function with 1 method)

In [49]:
"""
    saturation_pressure_NN(X, T)

Return the saturation pressure at temperature `T` for the model built from the
parameter collection `X`.

`X` is splatted into `make_model`, so it must hold the six values
`(Mw, m, σ, λ_a, λ_r, ϵ)` in that order. Only the pressure from
`saturation_pressure` is returned; the two volumes it also yields are dropped.
"""
function saturation_pressure_NN(X, T)
    return first(saturation_pressure(make_model(X...), T))
end

"""
    ChainRulesCore.rrule(::typeof(saturation_pressure_NN), X, T)

Reverse-mode rule for `saturation_pressure_NN` that avoids differentiating
through the saturation solver.

The forward pass runs `saturation_pressure` on the model from `make_model` and
keeps the returned pressure `p` and the two volumes `Vₗ`, `Vᵥ`. The pullback
then differentiates, with ForwardDiff, the closed-form expression

    p2 = -(eos(model, Vᵥ, T) - eos(model, Vₗ, T)) / (Vᵥ - Vₗ)

evaluated on the model from `make_NN_model`, holding `Vₗ` and `Vᵥ` fixed at the
values found in the forward pass.

# Returns
- `p`: the saturation pressure from the forward solve.
- `f_pullback`: maps the output cotangent `Δy` to
  `(NoTangent(), ∂X, ∂T)`, where `∂X` and `∂T` are lazy thunks.
"""
function ChainRulesCore.rrule(::typeof(saturation_pressure_NN), X, T)
    # Forward pass: plain (non-differentiated) saturation solve.
    p, Vₗ, Vᵥ = saturation_pressure(make_model(X...), T)

    #* Newton step from perfect initialisation
    # Pressure expression with the volumes frozen at the forward-pass solution.
    # Defined once here rather than on every pullback call, with argument names
    # that do not shadow the outer `X` and `T`.
    # NOTE(review): presumably `eos` is the total Helmholtz energy, which would
    # make this the equal-area expression for the saturation pressure — confirm.
    function p_fixed_volumes(params, temperature)
        nn_model = make_NN_model(params...)
        ΔA = eos(nn_model, Vᵥ, temperature) - eos(nn_model, Vₗ, temperature)
        return -ΔA / (Vᵥ - Vₗ)
    end

    function f_pullback(Δy)
        # AD backends may pass the cotangent as a thunk; materialise it before use.
        Δ = unthunk(Δy)

        ∂X = @thunk(ForwardDiff.gradient(x -> p_fixed_volumes(x, T), X) .* Δ)
        ∂T = @thunk(ForwardDiff.derivative(t -> p_fixed_volumes(X, t), T) * Δ)
        return (NoTangent(), ∂X, ∂T)
    end

    return p, f_pullback
end

# Parameter vector in the order make_model expects: [Mw, m, σ, λ_a, λ_r, ϵ].
X = [16.04, 1.0, 3.737e-10, 6.0, 12.504, 152.58]

# NOTE(review): `T` is not defined in any cell visible here; it presumably came
# from an earlier notebook session. Define it explicitly before running this cell.
p = saturation_pressure_NN(X, T)
# Zygote.gradient returns one gradient per argument, so despite its name `∂X`
# holds the tuple (∂p/∂X, ∂p/∂T).
∂X = Zygote.gradient(saturation_pressure_NN, X, T)

([-0.0, -235775.88358427823, -2.738594599593772e14, 61346.898616495506, 12147.672399827, -2130.8750918885917], 3592.4268158264795)