In [1]:
import Pkg; Pkg.activate(".")
using Clapeyron
includet("./saftvrmienn.jl")
# These are functions we're going to overload for SAFTVRMieNN
import Clapeyron: a_res, saturation_pressure, pressure

using Flux
using Plots, Statistics
using ForwardDiff, DiffResults

using Zygote, ChainRulesCore
using ImplicitDifferentiation

using CSV, DataFrames
using MLUtils
using RDKitMinimalLib
using JLD2

# Multithreaded loss
using Zygote: bufferfrom
using Base.Threads: @spawn

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`




In [2]:
X = [16.04, 1.0, 3.737, 6.0, 12.504, 152.58]
V = volume_NN(X, 1e7, 100.0)
∂V∂X = Zygote.gradient(X -> volume_NN(X, 1e7, 100.0), X)

([-0.0, 2.260005096581295e-5, 2.1863843008154135e-5, 1.46022174535494e-6, 1.6493904176981177e-7, -6.011351535625539e-8],)

In [3]:
X = [16.04, 1.0, 3.737, 6.0, 12.504, 152.58]
f_V(X) = volume_NN(X, 1e7, 100.0)[1]
dX = [0.0, 0.0, 0.0, 0.0, 0.0, 1e-6]
f_∂V∂X(X) = (f_V(X .+ dX) - f_V(X .- dX))/(2dX)
f_∂V∂X(X)

1×6 transpose(::Vector{Float64}) with eltype Float64:
 -0.0  -0.0  -0.0  -0.0  -0.0  -5.78614e-8

In [99]:
# Generate training set for liquid density and saturation pressure
function create_data(; batch_size=16, n_points=25)
    # Create training & validation data
    df = CSV.read("./pcpsaft_params/SI_pcp-saft_parameters.csv", DataFrame, header=1)
    filter!(row -> occursin("Alkane", row.family), df)
    df = first(df, 1) #* Take only first molecule in dataframe
    mol_data = zip(df.common_name, df.isomeric_smiles, df.molarweight)
    println("Generating data for $(length(mol_data)) molecules...")

    function make_fingerprint(s::String)::Vector{Float64}
        mol = get_mol(s)
        @assert !isnothing(mol)

        fp = []
        # for (nbits, rad) in [(256, 256), (1, 3)]
        #* Approximately ECFP4 fingerprint
        nbits = 256
        rad = 4

        fp_details = Dict{String,Any}("nBits" => nbits, "radius" => rad)
        fp_str = get_morgan_fp(mol, fp_details)
        append!(fp, [parse(Float64, string(c)) for c in fp_str])
        # end

        desc = get_descriptors(mol)
        relevant_keys = [
            "CrippenClogP",
            "NumHeavyAtoms",
            "amw",
            "FractionCSP3",
        ]
        relevant_desc = [desc[k] for k in relevant_keys]
        append!(fp, last.(relevant_desc))

        return fp
    end

    T = Float64
    X_data = Vector{Tuple{Vector{T},T,T,T}}([])
    Y_data = Vector{Vector{T}}()

    for (name, smiles, Mw) in mol_data
        try
            # saft_model = PPCSAFT([name])
            saft_model = SAFTVRMie([name])
            Tc, pc, Vc = crit_pure(saft_model)

            # fp = make_fingerprint(smiles)
            fp = [1.0]
            # append!(fp, Mw)

            T_range = range(0.5 * Tc, 0.975 * Tc, n_points)
            for T in T_range
                (p₀, V_vec...) = saturation_pressure(saft_model, T)
                p = p₀ * 1.5
                Vₗ = volume(saft_model, p, T; phase=:liquid)
                push!(X_data, (fp, p, T, Mw))
                push!(Y_data, [Vₗ])
            end
        catch e
            println("Fingerprint generation failed for $name, $e")
        end
    end

    #* Remove columns from fingerprints
    # Identify zero & one columns
    # num_cols = length(X_data[1][1])
    # zero_cols = trues(num_cols)
    # for (vec, _, _) in X_data
    #     zero_cols .&= (vec .== 0)
    # end
    # keep_cols = .!zero_cols # Create a Mask
    # X_data = [(vec[keep_cols], vals...) for (vec, vals...) in X_data] # Apply Mask

    # num_cols = length(X_data[1][1])
    # one_cols = trues(num_cols)
    # for (vec, _, _) in X_data
    #     one_cols .&= (vec .== 1)
    # end
    # keep_cols = .!one_cols # Create a Mask
    # X_data = [(vec[keep_cols], vals...) for (vec, vals...) in X_data] # Apply Mask

    train_data, test_data = splitobs((X_data, Y_data), at=0.8, shuffle=false)

    train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=false)
    test_loader = DataLoader(test_data, batchsize=batch_size, shuffle=false)
    println("n_batches = $(length(train_loader)), batch_size = $batch_size")
    flush(stdout)
    return train_loader, test_loader
end


function create_ff_model(nfeatures)
    # Base NN architecture from "Fitting Error vs Parameter Performance"
    nout = 5
    model = Chain(
        Dense(nfeatures, nout*8, relu),
        Dense(nout*8, nout*4, relu),
        Dense(nout*4, nout*2, relu),
        Dense(nout*2, nout, relu),
        # Dense(16, nout, relu),
        # Dense(16, nout, x -> x),
        # Dense(1024, 512, relu),
        # Dense(512, 256, relu),
        # Dense(32, 32, relu),
        # Dense(32, 32, relu),
        # Dense(32, nout, relu),
    )
    # model(x) = m, σ, λ_a, λ_r, ϵ

    # return nn_model, unbounded_model
    return model
end

function get_idx_from_iterator(iterator, idx)
    data_iterator = iterate(iterator)
    for _ in 1:idx-1
        data_iterator = iterate(iterator, data_iterator[2])
    end
    return data_iterator[1]
end


function SAFT_head(model, X; b=[3.0, 3.5, 7.0, 12.5, 250.0], c=10.0)
    fp, p, T, Mw = X
    pred_params = model(fp)
    # Add bias and scale
    biased_params = pred_params / c + b

    saft_input = vcat(Mw, biased_params)
    Vₗ = volume_NN(saft_input, p, T)
    ŷ = Vₗ
    # ŷ = sum(saft_input)
    # Tc = ignore_derivatives() do
    #     critical_temperature_NN(saft_input)
    # end
    # if T < Tc
    #     sat_p = saturation_pressure_NN(saft_input, T)
    #     if !isnan(sat_p)
    #         ŷ = sat_p
    #     else
    #         # println("sat_p is NaN at T = $T, saft_input = $saft_input")
    #         ŷ = nothing
    #     end
    # else
    #     ŷ = nothing
    # end

    return ŷ
end

function eval_loss(X_batch, y_batch, metric, model)
    batch_loss = 0.0
    n = 0
    for (X, y_vec) in zip(X_batch, y_batch)
        y = y_vec[1]
        ŷ = SAFT_head(model, X)
        if !isnothing(ŷ)
            batch_loss += metric(y, ŷ)
            n += 1
        end
    end
    if n > 0 
        batch_loss /= n
    end
    # penalize batch_loss depending on how many failed
    batch_loss += length(y_batch) - n

    return batch_loss
end

function eval_loss_par(X_batch, y_batch, metric, model, n_chunks)
    n = length(X_batch)
    chunk_size = n ÷ n_chunks

    p = bufferfrom(zeros(n_chunks))

    # Creating views for each chunk
    X_chunks = vcat([view(X_batch, (i-1)*chunk_size+1:i*chunk_size) for i in 1:n_chunks-1], [view(X_batch, (n_chunks-1)*chunk_size+1:n)])
    y_chunks = vcat([view(y_batch, (i-1)*chunk_size+1:i*chunk_size) for i in 1:n_chunks-1], [view(y_batch, (n_chunks-1)*chunk_size+1:n)])

    @sync begin
        for i = 1:n_chunks
            @spawn begin 
                p[i] = eval_loss(X_chunks[i], y_chunks[i], metric, model)
            end
        end
    end
    return sum(p) / n_chunks # average partial losses
end

function percent_error(y, ŷ)
    return 100 * abs(y - ŷ) / y
end

function mse(y, ŷ)
    return (y - ŷ)^2
end

function train_model!(model, train_loader, test_loader; epochs=10)
    optim = Flux.setup(Flux.Adam(0.001), model) # 1e-3 usually safe starting LR

    println("training on $(Threads.nthreads()) threads")
    flush(stdout)

    for epoch in 1:epochs
        batch_loss = 0.0
        for (X_batch, y_batch) in train_loader

            loss, grads = Flux.withgradient(model) do m
                # loss = eval_loss_par(X_batch, y_batch, percent_error, m, Threads.nthreads())
                loss = eval_loss(X_batch, y_batch, percent_error, m)
                loss
            end
            batch_loss += loss
            @assert !isnan(loss)

            Flux.update!(optim, model, grads[1])
        end
        batch_loss /= length(train_loader)
        epoch % 25 == 0 && println("epoch $epoch: batch_loss = $batch_loss")
        flush(stdout)
    end
end

function main(;epochs=15)
    train_loader, test_loader = create_data(n_points=40, batch_size=32) # Should make 5 batches / epoch. 256 / 8 gives 32 evaluations per thread
    n_features = length(first(train_loader)[1][1][1])

    model = create_ff_model(n_features)
    train_model!(model, train_loader, test_loader; epochs=epochs)
    return model
end

main (generic function with 1 method)

In [100]:
m = main(;epochs=1000)

Generating data for 1 molecules...
n_batches = 1, batch_size = 32


training on 1 threads


epoch 25: batch_loss = 9.242274525578019


epoch 50: batch_loss = 9.218041463446953


epoch 75: batch_loss = 9.1754327757208


epoch 100: batch_loss = 9.09752602689555


epoch 125: batch_loss = 8.9506153816173


epoch 150: batch_loss = 8.779983405275496


epoch 175: batch_loss = 8.522244616063816


epoch 200: batch_loss = 8.147972954979284


epoch 225: batch_loss = 7.688389637370618


epoch 250: batch_loss = 7.065613981395191


epoch 275: batch_loss = 6.265253465256205


epoch 300: batch_loss = 5.282424080702349


epoch 325: batch_loss = 4.208476258923702


epoch 350: batch_loss = 2.9770374781779974


epoch 375: batch_loss = 1.6247487279772053


epoch 400: batch_loss = 0.6792526953527332


epoch 425: batch_loss = 0.6199522504326581


epoch 450: batch_loss = 0.6195588667338581


epoch 475: batch_loss = 0.6195577069538356


epoch 500: batch_loss = 0.6196034368776722


epoch 525: batch_loss = 0.6195600547982286


epoch 550: batch_loss = 0.6195750107928558


epoch 575: batch_loss = 0.6195887851861731


epoch 600: batch_loss = 0.6195602653554133


epoch 625: batch_loss = 0.6195599676672674


epoch 650: batch_loss = 0.619559439366473


epoch 675: batch_loss = 0.6195596397317972


epoch 700: batch_loss = 0.6195599221086124


epoch 725: batch_loss = 0.619559918196734


epoch 750: batch_loss = 0.6195604061569904


epoch 775: batch_loss = 0.619573531612702


epoch 800: batch_loss = 0.6195603085423035


epoch 825: batch_loss = 0.6195600496306792


epoch 850: batch_loss = 0.6195602140529343


epoch 875: batch_loss = 0.6195601420358637


epoch 900: batch_loss = 0.6195800491765063


epoch 925: batch_loss = 0.6195737627352352


epoch 950: batch_loss = 0.619560358227653


epoch 975: batch_loss = 0.6195844865576995


epoch 1000: batch_loss = 0.6195983065842929


Chain(
  Dense(1 => 40, relu),                 [90m# 80 parameters[39m
  Dense(40 => 20, relu),                [90m# 820 parameters[39m
  Dense(20 => 10, relu),                [90m# 210 parameters[39m
  Dense(10 => 5, relu),                 [90m# 55 parameters[39m
) [90m                  # Total: 8 arrays, [39m1_165 parameters, 5.051 KiB.

In [101]:
train_loader, test_loader = create_data(n_points=40, batch_size=32) # Should make 5 batches / epoch. 256 / 8 gives 32 evaluations per thread
mol = first(train_loader)[1][1]
fp, p, T, Mw = mol

b = [3.0, 3.5, 7.0, 12.5, 250.0]
c = 10.0
pred_params = m(fp)
biased_params = pred_params / c + b

biased_params

Generating data for 1 molecules...


n_batches = 1, batch_size = 32


5-element Vector{Float64}:
   3.0
   3.5
   7.0
  12.5
 280.8042358398437

In [102]:
df = CSV.read("./pcpsaft_params/SI_pcp-saft_parameters.csv", DataFrame, header=1)
filter!(row -> occursin("Alkane", row.family), df)
df = first(df, 1) #* Take only first molecule in dataframe
mol_data = zip(df.common_name, df.isomeric_smiles, df.molarweight)
saft_model = SAFTVRMie([first(mol_data)[1]])
saft_model.params.segment.values, saft_model.params.sigma.values, saft_model.params.lambda_a.values, saft_model.params.lambda_r.values, saft_model.params.epsilon.values

([1.8514], [4.0887000000000006e-10;;], [6.0;;], [13.65;;], [273.64;;])