In [1]:
import Pkg; Pkg.activate(".")
using Clapeyron
includet("./saftvrmienn.jl")
# These are functions we're going to overload for SAFTVRMieNN
import Clapeyron: a_res, saturation_pressure, pressure

using Flux
using Plots, Statistics
using ForwardDiff, DiffResults

using Zygote, ChainRulesCore
using ImplicitDifferentiation

using CSV, DataFrames
using MLUtils
using RDKitMinimalLib
using JLD2

# Multithreaded loss
using Zygote: bufferfrom
using Base.Threads: @spawn

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`




In [2]:
# Sanity check: differentiate volume_NN with respect to its parameter vector using Zygote.
# NOTE(review): X presumably follows the same layout as `saft_input` built in SAFT_head,
# i.e. [Mw, m, σ, λ_a, λ_r, ϵ] — confirm against saftvrmienn.jl.
X = [16.04, 1.0, 3.737, 6.0, 12.504, 152.58]
# The 2nd and 3rd arguments are passed as (p, T) where SAFT_head calls volume_NN.
V = volume_NN(X, 1e7, 100.0)
# Returns a 1-tuple holding the gradient with respect to X.
∂V∂X = Zygote.gradient(X -> volume_NN(X, 1e7, 100.0), X)

([-0.0, 2.260005096581295e-5, 2.1863843008154135e-5, 1.46022174535494e-6, 1.6493904176981177e-7, -6.011351535625539e-8],)

In [3]:
# Central finite-difference check of the volume derivative, for comparison with the
# Zygote gradient computed in the previous cell.
X = [16.04, 1.0, 3.737, 6.0, 12.504, 152.58]
f_V(X) = volume_NN(X, 1e7, 100.0)[1]
# Perturb only the last entry of X.
dX = [0.0, 0.0, 0.0, 0.0, 0.0, 1e-6]
# NOTE(review): `2dX` is a vector, so this divides a number by a vector. The recorded
# output is a 1×6 row whose only non-zero entry is the last one; this appears to rely on
# `number / vector` being evaluated through the pseudo-inverse of the vector — confirm,
# and consider dividing by the scalar step (2e-6) so the result is a plain number.
f_∂V∂X(X) = (f_V(X .+ dX) - f_V(X .- dX))/(2dX)
f_∂V∂X(X)

1×6 transpose(::Vector{Float64}) with eltype Float64:
 -0.0  -0.0  -0.0  -0.0  -0.0  -5.78614e-8

In [93]:
# Generate training set of liquid volumes (evaluated at 10x the saturation pressure)
"""
    create_data(; batch_size=16, n_points=25)

Build the liquid-volume data set from reference SAFT-VR Mie models and wrap it in
`DataLoader`s.

The parameter CSV is read, filtered to rows whose `family` contains "Alkane", and
truncated to the first molecule. For that molecule, `n_points` temperatures between
0.5 Tc and 0.975 Tc are sampled; at each temperature the liquid volume is evaluated at
ten times the saturation pressure.

# Keywords
- `batch_size`: number of samples per batch in the returned loaders.
- `n_points`: number of temperatures sampled per molecule.

# Returns
`(train_loader, test_loader)`. Each sample is `((fp, p, T, Mw), [Vₗ])`, where `fp` is
currently the constant feature vector `[1.0]`.
"""
function create_data(; batch_size=16, n_points=25)
    # Create training & validation data
    df = CSV.read("./pcpsaft_params/SI_pcp-saft_parameters.csv", DataFrame, header=1)
    filter!(row -> occursin("Alkane", row.family), df)
    df = first(df, 1) #* Take only first molecule in dataframe
    @show df.common_name
    mol_data = zip(df.common_name, df.isomeric_smiles, df.molarweight)
    println("Generating data for $(length(mol_data)) molecules...")

    # Morgan fingerprint plus a few RDKit descriptors for a SMILES string.
    # Not used at the moment: the loop below feeds the constant feature `[1.0]` instead.
    function make_fingerprint(s::String)::Vector{Float64}
        mol = get_mol(s)
        isnothing(mol) && throw(ArgumentError("could not build a molecule from SMILES: " * s))

        fp = Float64[]
        # NOTE(review): the original comment calls this "approximately ECFP4"; ECFP4 is
        # usually associated with Morgan radius 2, so radius 4 may be larger than
        # intended — confirm.
        nbits = 256
        rad = 4

        fp_details = Dict{String,Any}("nBits" => nbits, "radius" => rad)
        fp_str = get_morgan_fp(mol, fp_details)
        append!(fp, [parse(Float64, string(c)) for c in fp_str])

        desc = get_descriptors(mol)
        relevant_keys = [
            "CrippenClogP",
            "NumHeavyAtoms",
            "amw",
            "FractionCSP3",
        ]
        relevant_desc = [desc[k] for k in relevant_keys]
        append!(fp, last.(relevant_desc))

        return fp
    end

    # Named F (not T) so that it is not confused with the temperature loop variable.
    F = Float64
    X_data = Tuple{Vector{F},F,F,F}[]
    Y_data = Vector{F}[]

    for (name, _smiles, Mw) in mol_data
        try
            # Alternative reference model tried earlier: PPCSAFT([name])
            saft_model = SAFTVRMie([name])
            Tc, _, _ = crit_pure(saft_model)

            # Constant placeholder feature; swap in make_fingerprint(_smiles) to use
            # structural features.
            fp = [1.0]

            T_range = range(0.5 * Tc, 0.975 * Tc, n_points)
            for T in T_range
                p₀ = first(saturation_pressure(saft_model, T))

                # Evaluate the liquid well above the saturation pressure.
                p = p₀ * 10

                Vₗ = volume(saft_model, p, T; phase=:liquid)
                push!(X_data, (fp, p, T, Mw))
                push!(Y_data, [Vₗ])
            end
        catch e
            # Best effort: skip molecules the reference model cannot handle. Points
            # generated before the failure are kept.
            println("Data generation failed for $name, $e")
        end
    end

    # A previous revision removed fingerprint columns that were all-zero or all-one
    # across the data set at this point; it is not needed for the constant feature.

    # at=1.0 assigns every observation to the training split.
    train_data, test_data = splitobs((X_data, Y_data), at=1.0, shuffle=false)

    train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=false)
    test_loader = DataLoader(test_data, batchsize=batch_size, shuffle=false)
    println("n_batches = $(length(train_loader)), batch_size = $batch_size")
    flush(stdout)
    return train_loader, test_loader
end


"""
    create_ff_model(nfeatures)

Return a Flux `Chain` holding a single bias-free linear layer that maps `nfeatures`
inputs to 3 outputs. The weights start at zero, so the initial output is all zeros.

NOTE(review): judging by how `SAFT_head` consumes the outputs, they are presumably the
scaled offsets for (m, σ, ϵ) — confirm.
"""
function create_ff_model(nfeatures)
    n_outputs = 3
    linear_layer = Dense(nfeatures, n_outputs, x -> x; bias=false, init=zeros32)
    return Chain(linear_layer)
end

"""
    get_idx_from_iterator(iterator, idx)

Return the `idx`-th element (1-based) produced by `iterator`, using only the `iterate`
protocol, so it also works for iterators that do not support indexing.

Any `idx` smaller than 1 yields the first element.

# Throws
- `BoundsError` if `iterator` produces fewer than `idx` elements (including the case of
  an empty iterator).
"""
function get_idx_from_iterator(iterator, idx)
    step = iterate(iterator)
    for _ in 1:idx-1
        # Stop early once the iterator is exhausted instead of indexing into `nothing`.
        step === nothing && break
        step = iterate(iterator, step[2])
    end
    step === nothing && throw(BoundsError(iterator, idx))
    return step[1]
end


# function SAFT_head(model, X; b=[3.0, 3.5, 7.0, 12.5, 250.0], c=10.0)
# function SAFT_head(model, X; b=[3.0, 3.5], c=10.0)
"""
    SAFT_head(model, X; b=[2.0, 4.0, 250.0], c=Float64[1, 1, 100])

Map one sample to a predicted liquid volume.

`X` is the tuple `(fp, p, T, Mw)`. The network output `model(fp)` is rescaled
element-wise as `output * c + b`. The three rescaled values are combined with `Mw` and
the fixed exponents `λ_a = 6.0`, `λ_r = 13.65` into the parameter vector handed to
`volume_NN` together with `p` and `T`.

Returns the value from `volume_NN`, or the constant `1e3` when that value is NaN.

NOTE(review): the three predicted values presumably correspond to (m, σ, ϵ) — confirm
against `volume_NN`.
"""
function SAFT_head(model, X; b=[2.0, 4.0, 250.0], c=Float64[1, 1, 100])
    features, p, T, Mw = X

    # Scale the raw network output and shift it by the bias.
    raw_output = model(features)
    scaled = raw_output .* c .+ b

    # The Mie exponents are held fixed rather than predicted.
    fixed_exponents = [6.0, 13.65]
    saft_input = vcat(Mw, scaled[1:2], fixed_exponents, scaled[3])

    Vₗ = volume_NN(saft_input, p, T)

    # Replace a failed volume evaluation with a large constant.
    if isnan(Vₗ)
        return 1e3
    end
    return Vₗ
end

"""
    eval_loss(X_batch, y_batch, metric, model)

Average `metric(y, ŷ)` over a batch, where `ŷ = SAFT_head(model, X)` and `y` is the
first entry of the matching target vector.

Samples for which `SAFT_head` returns `nothing` are skipped. If no sample contributes
(for example an empty batch) the result is `0.0`.
"""
function eval_loss(X_batch, y_batch, metric, model)
    total = 0.0
    n_valid = 0
    for (X, target) in zip(X_batch, y_batch)
        y = target[1]
        prediction = SAFT_head(model, X)
        if prediction !== nothing
            total += metric(y, prediction)
            n_valid += 1
        end
    end

    # Only normalise when at least one sample was evaluated.
    return n_valid > 0 ? total / n_valid : total
end

"""
    eval_loss_par(X_batch, y_batch, metric, model, n_chunks)

Evaluate the batch loss in parallel: the batch is split into contiguous chunks, one
task is spawned per chunk to run `eval_loss`, and the partial results are combined.

Each partial mean is weighted by its chunk length, so the result equals the serial
`eval_loss` over the whole batch when no samples are skipped, even if the chunks have
different sizes. The number of chunks is clamped to `1:length(X_batch)` so that no
empty chunk is created. An empty batch yields `0.0`.

# Arguments
- `X_batch`, `y_batch`: inputs and targets, indexable and of equal length.
- `metric`: function `(y, ŷ) -> loss`.
- `model`: model forwarded to `eval_loss`.
- `n_chunks`: requested number of parallel chunks.
"""
function eval_loss_par(X_batch, y_batch, metric, model, n_chunks)
    n = length(X_batch)
    n == 0 && return 0.0

    # Never use more chunks than samples (and at least one).
    k = clamp(n_chunks, 1, n)
    chunk_size = n ÷ k

    # Zygote buffer so that the per-chunk writes stay differentiable.
    partial = bufferfrom(zeros(k))

    # Creating views for each chunk; the last chunk absorbs the remainder.
    X_chunks = vcat([view(X_batch, (i-1)*chunk_size+1:i*chunk_size) for i in 1:k-1], [view(X_batch, (k-1)*chunk_size+1:n)])
    y_chunks = vcat([view(y_batch, (i-1)*chunk_size+1:i*chunk_size) for i in 1:k-1], [view(y_batch, (k-1)*chunk_size+1:n)])

    @sync begin
        for i = 1:k
            @spawn begin
                # Store the chunk's summed loss (mean times chunk length).
                partial[i] = length(X_chunks[i]) * eval_loss(X_chunks[i], y_chunks[i], metric, model)
            end
        end
    end
    return sum(partial) / n # sample-weighted average of the partial losses
end

"""
    percent_error(y, ŷ)

Absolute percentage error of the prediction `ŷ` relative to the reference `y`:
`100 * |y - ŷ| / |y|`.

The magnitude of `y` is used in the denominator so the result is non-negative for
negative references as well. It is not finite when `y` is zero.
"""
function percent_error(y, ŷ)
    return 100 * abs(y - ŷ) / abs(y)
end

"""
    mse(y, ŷ)

Squared residual of a single prediction divided by the reference value:
`(y - ŷ)^2 / y`.
"""
function mse(y, ŷ)
    residual = y - ŷ
    return abs2(residual) / y
end

"""
    finite_diff_grads(model, x, y; ϵ=1e-5, metric=percent_error)

Estimate the gradient of `eval_loss(x, y, metric, model)` with respect to every entry of
every array in `Flux.params(model)` using central finite differences.

Each entry is perturbed in place by `±ϵ` and restored afterwards, also when the loss
evaluation throws. The difference quotient uses the perturbation that was actually
stored in the parameter array, which differs from `2ϵ` when the parameters have lower
precision than `ϵ` (for example Float32 weights).

# Keywords
- `ϵ`: nominal perturbation size.
- `metric`: per-sample loss forwarded to `eval_loss`.

# Returns
A vector with one `Float64` array per parameter array, each of the parameter's size.
"""
function finite_diff_grads(model, x, y; ϵ=1e-5, metric=percent_error)
    params = Flux.params(model)
    grads = Array{Float64}[zeros(size(p)) for p in params]

    for (i, p) in enumerate(params)
        for j in eachindex(p)
            original = p[j]
            try
                p[j] = original + ϵ
                upper = p[j] # value after rounding to the parameter's element type
                loss_upper = eval_loss(x, y, metric, model)

                p[j] = original - ϵ
                lower = p[j]
                loss_lower = eval_loss(x, y, metric, model)

                width = Float64(upper) - Float64(lower)
                # A zero width means ϵ is below the parameter's resolution.
                grads[i][j] = width == 0 ? 0.0 : (loss_upper - loss_lower) / width
            finally
                # Always leave the model exactly as it was found.
                p[j] = original
            end
        end
    end
    return grads
end

"""
    train_model!(model, train_loader, test_loader; epochs=10, lr=0.001, check_grads=false)

Train `model` in place with Adam on the batches of `train_loader`, minimising the mean
`percent_error` computed by `eval_loss`. The mean batch loss is printed after each
epoch.

# Arguments
- `model`: Flux model, updated in place.
- `train_loader`: iterable of `(X_batch, y_batch)` pairs.
- `test_loader`: accepted for interface compatibility; not used by the training loop.

# Keywords
- `epochs`: number of passes over `train_loader`.
- `lr`: Adam learning rate.
- `check_grads`: when `true`, also compute finite-difference gradients for every batch
  and print them next to the automatic-differentiation gradients. This is expensive
  (two loss evaluations per parameter per batch), so it is off by default.

Returns `nothing`.
"""
function train_model!(model, train_loader, test_loader; epochs=10, lr=0.001, check_grads=false)
    optim = Flux.setup(Flux.Adam(lr), model) # 1e-3 usually safe starting LR

    println("training on $(Threads.nthreads()) threads")
    flush(stdout)

    n_batches = length(train_loader)
    for epoch in 1:epochs
        batch_loss = 0.0
        for (X_batch, y_batch) in train_loader
            # The closure only returns the loss; it does not assign to an outer variable.
            # Multithreaded alternative:
            # eval_loss_par(X_batch, y_batch, percent_error, m, Threads.nthreads())
            loss, grads = Flux.withgradient(model) do m
                eval_loss(X_batch, y_batch, percent_error, m)
            end
            @assert !isnan(loss) "loss is NaN in epoch $epoch"
            batch_loss += loss

            if check_grads
                grads_fd = finite_diff_grads(model, X_batch, y_batch)
                @show grads[1]
                @show grads_fd
            end

            Flux.update!(optim, model, grads[1])
        end
        # Average over batches; skip the division when there are no batches.
        if n_batches > 0
            batch_loss /= n_batches
        end
        println("epoch $epoch: batch_loss = $batch_loss")
        flush(stdout)
    end
    return nothing
end

"""
    main(; epochs=15)

Generate the data set, build the single-layer model, train it for `epochs` epochs and
return the trained model.
"""
function main(; epochs=15)
    train_loader, test_loader = create_data(n_points=21, batch_size=7) # 21 points in batches of 7 -> 3 batches per epoch for one molecule
    # Feature count = length of the feature vector of the first sample in the first batch.
    @show n_features = length(first(train_loader)[1][1][1])

    model = create_ff_model(n_features)
    # NOTE(review): `model([1.0])` assumes n_features == 1 — confirm if fingerprints are re-enabled.
    @show model.layers[1].weight, model([1.0])
    train_model!(model, train_loader, test_loader; epochs=epochs)
    return model
end

main (generic function with 1 method)

In [94]:
# Run the full pipeline for 50 epochs; `m` is the trained model returned by main.
m = main(;epochs=50)

df.common_name = ["n-butane"]
Generating data for 1 molecules...
n_batches = 3, batch_size = 7


n_features = length((((first(train_loader))[1])[1])[1]) = 1
training on 1 threads




epoch 1: batch_loss = 5.360681053341593


epoch 2: batch_loss = 5.095254476969392


epoch 3: batch_loss = 4.897990291885889


epoch 4: batch_loss = 4.716068695442626


epoch 5: batch_loss = 4.5406649955639145


epoch 6: batch_loss = 4.368539442838148


epoch 7: batch_loss = 4.198127423822174


epoch 8: batch_loss = 4.02855431249443


epoch 9: batch_loss = 3.8592855096452303


epoch 10: batch_loss = 3.689975341095186


epoch 11: batch_loss = 3.520392129478276


epoch 12: batch_loss = 3.3503771152134068


epoch 13: batch_loss = 3.1798206537278144


epoch 14: batch_loss = 3.0086468179930477


epoch 15: batch_loss = 2.8368034388344303


epoch 16: batch_loss = 2.6642561814038066


epoch 17: batch_loss = 2.4909829114810775


epoch 18: batch_loss = 2.316970810965202


epoch 19: batch_loss = 2.142213757826628


epoch 20: batch_loss = 1.9667107931264542


epoch 21: batch_loss = 1.790464466729297


epoch 22: batch_loss = 1.613480559258684


epoch 23: batch_loss = 1.4357663158448926


epoch 24: batch_loss = 1.2698875354190393


epoch 25: batch_loss = 1.132949117650889


epoch 26: batch_loss = 1.0274738021408802


epoch 27: batch_loss = 0.9516407550856668


epoch 28: batch_loss = 0.9003618821088573


epoch 29: batch_loss = 0.8686134160576152


epoch 30: batch_loss = 0.8478455096292242


epoch 31: batch_loss = 0.8350654164006324


epoch 32: batch_loss = 0.8280746287812502


epoch 33: batch_loss = 0.8252577543607602


epoch 34: batch_loss = 0.8224166940669874


epoch 35: batch_loss = 0.8173909195652144


epoch 36: batch_loss = 0.8091421425004336


epoch 37: batch_loss = 0.7983932960186033


epoch 38: batch_loss = 0.7812298723085802


epoch 39: batch_loss = 0.7553670525877793


epoch 40: batch_loss = 0.7273135875886059


epoch 41: batch_loss = 0.7017588832703124


epoch 42: batch_loss = 0.6776388395646258


epoch 43: batch_loss = 0.6544290014914161


epoch 44: batch_loss = 0.6317444110134941


epoch 45: batch_loss = 0.6093020702975974


epoch 46: batch_loss = 0.5868938421666781


epoch 47: batch_loss = 0.5643667161345426


epoch 48: batch_loss = 0.5416079787954167


epoch 49: batch_loss = 0.518535083296794


epoch 50: batch_loss = 0.4950869015724746


Chain(
  Dense(1 => 3, #447; bias=false),      # 3 parameters
) 

In [95]:
# Inspect the trained layer: list its fields, then display the learned weight matrix
# (only the last expression is shown as the cell output).
fieldnames(typeof(m.layers[1]))
m.layers[1].weight

3×1 Matrix{Float32}:
  0.0070747477
 -0.02944446
  0.12534449

In [97]:
# Apply the same offset `b` and scale `c` that SAFT_head uses by default to the trained
# model's output for the constant feature [1.0], giving the predicted parameter values.
b = [2.0, 4.0, 250.0]
c = Float64[1, 1, 100]
m([1.0]) .* c .+ b

3-element Vector{Float64}:
   2.007074747700244
   3.970555540174246
 262.53444850444794

In [92]:
# Look up the reference SAFT-VR Mie parameters of the same molecule used for training
# (same CSV filtering as in create_data) to compare with the values predicted above.
df = CSV.read("./pcpsaft_params/SI_pcp-saft_parameters.csv", DataFrame, header=1)
filter!(row -> occursin("Alkane", row.family), df)
df = first(df, 1) #* Take only first molecule in dataframe
mol_data = zip(df.common_name, df.isomeric_smiles, df.molarweight)
saft_model = SAFTVRMie([first(mol_data)[1]])
# Tuple of (segment, sigma * 1e10, epsilon).
# NOTE(review): the 1e10 factor presumably converts sigma from metres to Ångström — confirm.
saft_model.params.segment.values[1], saft_model.params.sigma.values[1]*1e10, saft_model.params.epsilon.values[1]#, saft_model.params.lambda_a.values, saft_model.params.lambda_r.values, saft_model.params.epsilon.values

(1.8514, 4.0887, 273.64)