In [40]:
import Pkg; Pkg.activate(".")
using Clapeyron
includet("./saftvrmienn.jl")
# These are functions we're going to overload for SAFTVRMieNN
import Clapeyron: a_res, saturation_pressure, pressure

using Flux
using Plots, Statistics
using ForwardDiff, DiffResults

using Zygote, ChainRulesCore
using ImplicitDifferentiation

using CSV, DataFrames
using MLUtils
using RDKitMinimalLib
using JLD2

# Multithreaded loss
using Zygote: bufferfrom
using Base.Threads: @spawn
using Plots

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`


In [42]:
X = [16.04, 1.0, 3.737, 6.0, 12.504, 152.58]
V = volume_NN(X, 1e7, 100.0)
∂V∂X = Zygote.gradient(X -> volume_NN(X, 1e7, 100.0), X)

([-0.0, 2.260005096581295e-5, 2.1863843008154135e-5, 1.46022174535494e-6, 1.6493904176981177e-7, -6.011351535625539e-8],)

In [48]:
#* I want to define a rrule for the backwards pass of Vₗ & pressure simultaneously
#* It's wasteful to run the volume solver so many times unnecessarily
X = [16.04, 1.0, 3.737, 6.0, 12.504, 152.58]
T = 150.0
p = saturation_pressure_NN(X, T)
# ∂p∂X,  = 
Zygote.gradient(X -> saturation_pressure_NN(X, T), X)

([-0.0, -4.32664026029552e6, -838766.9527486291, 1.2246980149044194e6, 256678.70349442933, -39357.47330351631],)

In [18]:
X = [16.04, 1.0, 3.737, 6.0, 12.504, 152.58]
f_V(X) = volume_NN(X, 1e7, 100.0)[1]
dX = [0.0, 0.0, 0.0, 0.0, 0.0, 1e-6]
f_∂V∂X(X) = (f_V(X .+ dX) - f_V(X .- dX))/(2dX)
f_∂V∂X(X)

1×6 transpose(::Vector{Float64}) with eltype Float64:
 -0.0  -0.0  -0.0  -0.0  -0.0  -5.78614e-8

In [49]:
# Generate training set for liquid density and saturation pressure
function create_data(; batch_size=16, n_points=25)
    # Create training & validation data
    df = CSV.read("./pcpsaft_params/SI_pcp-saft_parameters.csv", DataFrame, header=1)
    filter!(row -> occursin("Alkane", row.family), df)
    df = first(df, 1) #* Take only first molecule in dataframe
    @show df.common_name
    mol_data = zip(df.common_name, df.isomeric_smiles, df.molarweight)
    println("Generating data for $(length(mol_data)) molecules...")

    function make_fingerprint(s::String)::Vector{Float64}
        mol = get_mol(s)
        @assert !isnothing(mol)

        fp = []
        # for (nbits, rad) in [(256, 256), (1, 3)]
        #* Approximately ECFP4 fingerprint
        nbits = 256
        rad = 4

        fp_details = Dict{String,Any}("nBits" => nbits, "radius" => rad)
        fp_str = get_morgan_fp(mol, fp_details)
        append!(fp, [parse(Float64, string(c)) for c in fp_str])
        # end

        desc = get_descriptors(mol)
        relevant_keys = [
            "CrippenClogP",
            "NumHeavyAtoms",
            "amw",
            "FractionCSP3",
        ]
        relevant_desc = [desc[k] for k in relevant_keys]
        append!(fp, last.(relevant_desc))

        return fp
    end

    T = Float64
    # X_data = Vector{Tuple{Vector{T},T,T,T}}([])
    X_data = Vector{Tuple{Vector{T},T,T}}([])
    Y_data = Vector{Vector{T}}()

    # n = 0
    for (name, smiles, Mw) in mol_data
        # if n < 20
        try
            saft_model = PPCSAFT([name])
            # saft_model = SAFTVRMie([name])
            Tc, pc, Vc = crit_pure(saft_model)

            # fp = make_fingerprint(smiles)
            fp = [1.0]
            # append!(fp, Mw)

            T_range = range(0.5 * Tc, 0.975 * Tc, n_points)
            for T in T_range
                (p_sat, Vₗ_sat, Vᵥ_sat) = saturation_pressure(saft_model, T)

                # p = p_sat * 5.0

                # Vₗ = volume(saft_model, p, T; phase=:liquid)
                push!(X_data, (fp, T, Mw))
                push!(Y_data, [Vₗ_sat, p_sat])
            end
            # n += 1
        catch e
            println("Fingerprint generation failed for $name, $e")
        end
        # else
        # break
        # end
    end

    #* Remove columns from fingerprints
    # Identify zero & one columns
    # num_cols = length(X_data[1][1])
    # zero_cols = trues(num_cols)
    # for (vec, _, _) in X_data
    #     zero_cols .&= (vec .== 0)
    # end
    # keep_cols = .!zero_cols # Create a Mask
    # X_data = [(vec[keep_cols], vals...) for (vec, vals...) in X_data] # Apply Mask

    # num_cols = length(X_data[1][1])
    # one_cols = trues(num_cols)
    # for (vec, _, _) in X_data
    #     one_cols .&= (vec .== 1)
    # end
    # keep_cols = .!one_cols # Create a Mask
    # X_data = [(vec[keep_cols], vals...) for (vec, vals...) in X_data] # Apply Mask

    train_data, test_data = splitobs((X_data, Y_data), at=1.0, shuffle=false)

    train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=false)
    test_loader = DataLoader(test_data, batchsize=batch_size, shuffle=false)
    println("n_batches = $(length(train_loader)), batch_size = $batch_size")
    flush(stdout)
    return train_loader, test_loader
end


function create_ff_model(nfeatures)
    # Base NN architecture from "Fitting Error vs Parameter Performance"
    nout = 4
    model = Chain(
        Dense(nfeatures, nout, x -> x; bias=false, init=zeros32),
    )
    #* glorot_uniform default initialisation
    # model = Chain(
    #     Dense(nfeatures, nout * 8, tanh, init=zeros32),
    #     Dense(nout * 8, nout * 4, tanh, init=zeros32),
    #     Dense(nout * 4, nout * 2, tanh, init=zeros32),
    #     Dense(nout * 2, nout, x -> x, init=zeros32), # Allow unbounded negative outputs; parameter values physically bounded in SAFT layer
    # )
    # model(x) = m, σ, λ_a, λ_r, ϵ

    # return nn_model, unbounded_model
    return model
end

function get_idx_from_iterator(iterator, idx)
    data_iterator = iterate(iterator)
    for _ in 1:idx-1
        data_iterator = iterate(iterator, data_iterator[2])
    end
    return data_iterator[1]
end


# function SAFT_head(model, X; b=[3.0, 3.5, 7.0, 12.5, 250.0], c=10.0)
# function SAFT_head(model, X; b=[2.0, 4.0], c=[1.0, 1.0])
function SAFT_head(model, X; b=[2.5, 3.5, 12.0, 250.0], c=Float64[1, 1, 10, 100])
    fp, T, Mw = X

    # m = 1.8514
    # σ = 4.0887
    λ_a = 6.0
    # λ_r = 13.65
    # ϵ = 273.64
    # fp, p, T, Mw = X
    pred_params = model(fp)

    # Add bias and scale
    biased_params = @. pred_params * c + b

    # Can also fix lambda_a
    #! How to do this in AD compatible way, can't do in-place modification
    # if biased_params[1] < 1.0
    #     biased_params[1] = ones(length(biased_params[1]))
    # end

    saft_input = vcat(Mw, biased_params[1:2], [λ_a], biased_params[3:4])

    # saft_input = vcat(Mw, biased_params[1:2], [λ_a, λ_r], ϵ)
    # Vₗ = volume_NN(saft_input, p, T)

    # ŷ_1 = !isnan(Vₗ) ? Vₗ : 1e3

    Tc = ignore_derivatives() do
        critical_temperature_NN(saft_input)
    end
    # todo include saturation volumes in loss
    if T < Tc
        p_sat = saturation_pressure_NN(saft_input, T)
        # @show saturation_NN(saft_input, T)
        # p_sat, Vₗ_sat, Vᵥ_sat = saturation_NN(saft_input, T)
        if !isnan(p_sat)
            Vₗ_sat = volume_NN(saft_input, p_sat, T)
            # ŷ_2 = sat_p
            ŷ = [Vₗ_sat, p_sat]
        else
            # println("sat_p is NaN at T = $T, saft_input = $saft_input")
            ŷ = [nothing, nothing]
        end
    else
        ŷ = [nothing, nothing]
    end

    return ŷ
end

function eval_loss(X_batch, y_batch, metric, model)
    batch_loss = 0.0
    n = 0
    for (X, y_vec) in zip(X_batch, y_batch)
        # y = y_vec[1]
        ŷ_vec = SAFT_head(model, X)

        for (ŷ, y) in zip(ŷ_vec, y_vec)
            if !isnothing(ŷ)
                batch_loss += metric(y, ŷ)
                n += 1
            end
        end

    end
    if n > 0
        batch_loss /= n
    end
    # penalize batch_loss depending on how many failed
    # batch_loss += length(y_batch) - n

    return batch_loss
end

function eval_loss_par(X_batch, y_batch, metric, model, n_chunks)
    n = length(X_batch)
    chunk_size = n ÷ n_chunks

    p = bufferfrom(zeros(n_chunks))

    # Creating views for each chunk
    X_chunks = vcat([view(X_batch, (i-1)*chunk_size+1:i*chunk_size) for i in 1:n_chunks-1], [view(X_batch, (n_chunks-1)*chunk_size+1:n)])
    y_chunks = vcat([view(y_batch, (i-1)*chunk_size+1:i*chunk_size) for i in 1:n_chunks-1], [view(y_batch, (n_chunks-1)*chunk_size+1:n)])

    @sync begin
        for i = 1:n_chunks
            @spawn begin
                p[i] = eval_loss(X_chunks[i], y_chunks[i], metric, model)
            end
        end
    end
    return sum(p) / n_chunks # average partial losses
end

function percent_error(y, ŷ)
    return 100 * abs(y - ŷ) / y
end

function mse(y, ŷ)
    return ((y - ŷ) / y)^2
end

function finite_diff_grads(model, x, y; ϵ=1e-8)
    grads = []
    for p in Flux.params(model)
        push!(grads, zeros(size(p)))
    end

    for (i, p) in enumerate(Flux.params(model))
        for j in eachindex(p)
            tmp = p[j]
            p[j] = tmp + ϵ
            J1 = eval_loss(x, y, mse, model)
            p[j] = tmp - ϵ
            J2 = eval_loss(x, y, mse, model)
            p[j] = tmp
            grads[i][j] = (J1 - J2) / (2 * ϵ)
        end
    end
    return grads
end

function train_model!(model, train_loader, test_loader; epochs=10)
    optim = Flux.setup(Flux.Adam(0.01), model) # 1e-3 usually safe starting LR
    # optim = Flux.setup(Descent(0.001), model)

    println("training on $(Threads.nthreads()) threads")
    flush(stdout)

    for epoch in 1:epochs
        batch_loss = 0.0
        for (X_batch, y_batch) in train_loader

            loss, grads = Flux.withgradient(model) do m
                # loss = eval_loss_par(X_batch, y_batch, percent_error, m, Threads.nthreads())
                loss = eval_loss(X_batch, y_batch, mse, m)
                loss
            end
            batch_loss += loss
            @assert !isnan(loss)

            # grads_fd = finite_diff_grads(model, X_batch, y_batch)
            # @show grads[1]
            # @show grads_fd      # Show FD gradients

            Flux.update!(optim, model, grads[1])
        end
        batch_loss /= length(train_loader)
        epoch % 1 == 0 && println("epoch $epoch: batch_loss = $batch_loss")
        flush(stdout)
    end
end

function main(; epochs=15)
    train_loader, test_loader = create_data(n_points=20, batch_size=20) # Should make 5 batches / epoch. 256 / 8 gives 32 evaluations per thread
    @show n_features = length(first(train_loader)[1][1][1])

    model = create_ff_model(n_features)
    # @show model.layers[1].weight, model([1.0])
    train_model!(model, train_loader, test_loader; epochs=epochs)
    return model
end

main (generic function with 1 method)

In [50]:
m = main(;epochs=100)

df.common_name = ["n-butane"]
Generating data for 1 molecules...
n_batches = 1, batch_size = 20




n_features = length((((first(train_loader))[1])[1])[1]) = 1
training on 1 threads




epoch 1: batch_loss = 0.24360634636006565


epoch 2: batch_loss = 0.21856524198932484


epoch 3: batch_loss = 0.19263854094403393


epoch 4: batch_loss = 0.16619345097130594


epoch 5: batch_loss = 0.13975286024385647


epoch 6: batch_loss = 0.11404064450074192


epoch 7: batch_loss = 0.09003176568119306


epoch 8: batch_loss = 0.06899334499720985


epoch 9: batch_loss = 0.05247330314489913


epoch 10: batch_loss = 0.042122478994858555


epoch 11: batch_loss = 0.0391454677641264


epoch 12: batch_loss = 0.043319216356720966


epoch 13: batch_loss = 0.05202216759451302


epoch 14: batch_loss = 0.06022449516294287


epoch 15: batch_loss = 0.06327247884927467


epoch 16: batch_loss = 0.060784066247526546


epoch 17: batch_loss = 0.055225912400765645


epoch 18: batch_loss = 0.04895451794010682


epoch 19: batch_loss = 0.043618980411431404


epoch 20: batch_loss = 0.04010837696824411


epoch 21: batch_loss = 0.03856948063716571


epoch 22: batch_loss = 0.038633299998410905


epoch 23: batch_loss = 0.03969377243365802


epoch 24: batch_loss = 0.04112327294322722


epoch 25: batch_loss = 0.04239919555803306


epoch 26: batch_loss = 0.04315670889680262


epoch 27: batch_loss = 0.04319422484917013


epoch 28: batch_loss = 0.042454953748414515


epoch 29: batch_loss = 0.04099959072430672


epoch 30: batch_loss = 0.038977455239669535


epoch 31: batch_loss = 0.03659793492772055


epoch 32: batch_loss = 0.03410143623679026


epoch 33: batch_loss = 0.03172833268222639


epoch 34: batch_loss = 0.02968562352066449


epoch 35: batch_loss = 0.02811362015532907


epoch 36: batch_loss = 0.027058483149606793


epoch 37: batch_loss = 0.02645943041875386


epoch 38: batch_loss = 0.0261596395675034


epoch 39: batch_loss = 0.025945237169700215


epoch 40: batch_loss = 0.025606605471235964


epoch 41: batch_loss = 0.025003358830701417


epoch 42: batch_loss = 0.024106244268971917


epoch 43: batch_loss = 0.022995793192818165


epoch 44: batch_loss = 0.021819498230467894


epoch 45: batch_loss = 0.020731274271898685


epoch 46: batch_loss = 0.01984154977902969


epoch 47: batch_loss = 0.019193999730944638


epoch 48: batch_loss = 0.01876899217295455


epoch 49: batch_loss = 0.018504268912652562


epoch 50: batch_loss = 0.018321426634539912


epoch 51: batch_loss = 0.018149429519253635


epoch 52: batch_loss = 0.01794038242589676


epoch 53: batch_loss = 0.017676194738658398


epoch 54: batch_loss = 0.017366967765553104


epoch 55: batch_loss = 0.017043072526420868


epoch 56: batch_loss = 0.016743419032307787


epoch 57: batch_loss = 0.016502717597357698


epoch 58: batch_loss = 0.016340681520199702


epoch 59: batch_loss = 0.016255948265741706


epoch 60: batch_loss = 0.016226615573163874


epoch 61: batch_loss = 0.016217505650797152


epoch 62: batch_loss = 0.01619187215250882


epoch 63: batch_loss = 0.016123276806015822


epoch 64: batch_loss = 0.016003045856894885


epoch 65: batch_loss = 0.015840657593047626


epoch 66: batch_loss = 0.015657648162162398


epoch 67: batch_loss = 0.015478366263005899


epoch 68: batch_loss = 0.015321619514544383


epoch 69: batch_loss = 0.015196069566646746


epoch 70: batch_loss = 0.015100090878134933


epoch 71: batch_loss = 0.015025033222085126


epoch 72: batch_loss = 0.014959933099479655


epoch 73: batch_loss = 0.014895779010354309


epoch 74: batch_loss = 0.014828072457162656


epoch 75: batch_loss = 0.014757210106147601


epoch 76: batch_loss = 0.01468700101934973


epoch 77: batch_loss = 0.01462210501296175


epoch 78: batch_loss = 0.014565424182038805


epoch 79: batch_loss = 0.014516384314042813


epoch 80: batch_loss = 0.014470682040736


epoch 81: batch_loss = 0.0144215217061829


epoch 82: batch_loss = 0.014361806541016205


epoch 83: batch_loss = 0.014286400802941857


epoch 84: batch_loss = 0.01419358939329855


epoch 85: batch_loss = 0.014085258122301025


epoch 86: batch_loss = 0.01396587049628567


epoch 87: batch_loss = 0.013840790724252297


epoch 88: batch_loss = 0.013714664201992846


epoch 89: batch_loss = 0.013590417224626302


epoch 90: batch_loss = 0.01346907934736311


epoch 91: batch_loss = 0.013350279601798349


epoch 92: batch_loss = 0.013233048855726598


epoch 93: batch_loss = 0.013116537281279001


epoch 94: batch_loss = 0.013000380395721833


epoch 95: batch_loss = 0.012884641750578952


epoch 96: batch_loss = 0.012769449281515733


epoch 97: batch_loss = 0.01265455819183437


epoch 98: batch_loss = 0.012539071724648099


epoch 99: batch_loss = 0.012421466753638448


epoch 100: batch_loss = 0.012299923169666882


Chain(
  Dense(1 => 4, #488; bias=false),      [90m# 4 parameters[39m
) 

In [8]:
# Todo: code to create saturation envelopes for given envelopes

In [37]:
fieldnames(typeof(m.layers[1]))
m.layers[1].weight

4×1 Matrix{Float32}:
 -0.093466766
 -0.089043505
  0.089826174
 -0.093140006

In [38]:
b = [2.5, 3.5, 12, 250.0]
# b = [2.0, 4.0, 250.0]
c = Float64[1, 1, 10, 100]
# b = [2.0, 4.0]
# c = Float64[1, 1]
params = m([1.0]) .* c .+ b

4-element Vector{Float64}:
   2.406533233821392
   3.410956494510174
  12.898261740803719
 240.68599939346313

In [1]:
Mw = 58.12
λ_a = 6.0

plot(;box=:on, dpi=400, xlabel="log10(V / m³)", ylabel="T / K")
function f(model, c, label)
    # Create saturation envelope
    Tc, pc, Vc = crit_pure(model)
    T_range = range(0.5 * Tc, 0.9999 * Tc, 500)

    p_sat =  Float64[]
    Vl_sat = Float64[]
    Vv_sat = Float64[]
    for T in T_range
        (p, Vl, Vv) = saturation_pressure(model, T)
        push!(p_sat, p)
        push!(Vl_sat, Vl)
        push!(Vv_sat, Vv)
    end

    # plot!(T_range, p_sat, label="saturation pressure")
    plot!(log10.(Vl_sat), T_range, label=label, lw=2, color=c)
    plot!(log10.(Vv_sat), T_range, label="", lw=2, color=c)
    scatter!([log10(Vc)], [Tc], label="", color=c)
end

pred_model = make_model(vcat(Mw, params[1:2], [λ_a], params[3:4])...)
base_model = SAFTVRMie(["n-butane"])
f(pred_model, 1, "4 params regressed")
f(base_model, 2, "Base SAFT-VR-Mie")
#* Plot nominal parameters

UndefVarError: UndefVarError: plot not defined

In [12]:
df = CSV.read("./pcpsaft_params/SI_pcp-saft_parameters.csv", DataFrame, header=1)
filter!(row -> occursin("Alkane", row.family), df)
df = first(df, 1) #* Take only first molecule in dataframe
mol_data = zip(df.common_name, df.isomeric_smiles, df.molarweight)
saft_model = SAFTVRMie([first(mol_data)[1]])
saft_model.params.segment.values[1], saft_model.params.sigma.values[1]*1e10, saft_model.params.epsilon.values[1], saft_model.params.lambda_r.values[1]

(1.8514, 4.0887, 273.64, 13.65)