In [1]:
import Pkg; Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`


In [2]:
using Clapeyron
includet("./saftvrmienn.jl")
import Clapeyron: a_res

using MolecularGraph, Graphs
using Plots

using Flux
# using Flux: onecold, onehotbatch, logitcrossentropy
using Flux: DataLoader
using GraphNeuralNetworks
using ForwardDiff, Zygote, ChainRulesCore

using MLUtils
using OneHotArrays
# using LinearAlgebra, Random, Statistics
using Statistics, Random





In [3]:
# atom_symbol(mol): atom letters as a symbol e.g. :C, :O and :N
# charge(mol): electric charge of the atom. only integer charge is allowed in the model
# multiplicity(mol): 1: no unpaired electron(default), 2: radical, 3: biradical
# lone_pair(mol): number of lone pair on the atom
# implicit_hydrogens(mol): number of implicit hydrogens that are not appear as graph vertices but automatically calculated, drawn in image and used for calculation of other descriptors.
# valence(mol): number of atom valence, specific to each atom species and considering electric charge. Implicit number of hydrogens is obtained by subtracting the degree of the vertex from the valence.
# is_aromatic(mol): whether the atom is aromatic or not. only binary aromaticity is allowed in the model.
# pi_electron(mol): number of pi electrons
# hybridization(mol): orbital hybridization e.g. sp, sp2 and sp3

In [111]:
function make_graph_from_smiles(smiles::String)
    molgraph = smilestomol(smiles)

    g = SimpleGraph(nv(molgraph))
    for e in edges(molgraph)
        add_edge!(g, e.src, e.dst)
    end

    # Should number of hydrogens be one-hot encoded?
    f(vec, enc) = hcat(map(x -> onehot(x, enc), vec)...)
    num_h = f(implicit_hydrogens(molgraph), [0, 1, 2, 3, 4])
    hybrid = f(hybridization(molgraph), [:sp, :sp2, :sp3])
    atoms = f(atom_symbol(molgraph), [:C, :O, :N])

    # Node data should be matrix (num_features, num_nodes)
    # Matrix has num_nodes columns, num_features rows
    ndata = Float32.(vcat(num_h, hybrid, atoms))

    # h(vec, enc) = hcat(map(x -> onehot(x, enc), vec)...)
    # @show bond_order(molgraph), is_rotatable(molgraph), is_aromatic(molgraph), collect(edges(molgraph))
    b_order = Float32.(f(bond_order(molgraph), [1, 2, 3]))
    # @show b_order
    # rotatable = f(is_rotatable(molgraph), [false, true])
    # edata = Matrix{Float32}(vcat(b_order, rotatable))
    # edata = Matrix{Float32}(b_order)
    edata = nothing
    
    g = GNNGraph(g, ndata = ndata, edata = edata)
    return g
end

g = make_graph_from_smiles("C")

GNNGraph:
  num_nodes: 1
  num_edges: 0
  ndata:
	x = 11×1 Matrix{Float32}

In [109]:
model = SAFTVRMie(["methane"])
fieldnames(typeof(model.params.Mw))
model.params.Mw.values[1]

16.04

In [119]:
# Iterate over molecules in dataset and build graph for each one
# Initially sample data for hydrocarbons
#! isobutane, isopentane not defined for SAFTVRMie
all_species = [
    "methane",
    "ethane",
    "propane",
    "butane",
    "pentane",
    "hexane",
    "heptane",
    "octane",
    "nonane",
    "decane",
]

# Define smiles map
smiles_map = Dict(
    "methane" => "C",
    "ethane" => "CC",
    "propane" => "CCC",
    "butane" => "CCCC",
    "isobutane" => "CC(C)C",
    "pentane" => "CCCCC",
    "isopentane" => "CC(C)CC",
    "hexane" => "CCCCCC",
    "heptane" => "CCCCCCC",
    "octane" => "CCCCCCCC",
    "nonane" => "CCCCCCCCC",
    "decane" => "CCCCCCCCCC",
)

# Create training data, currently sampled along saturation curve

T = GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}
graphs = T[]
states = Vector{Float32}[]
species = String[] # For checking parameter similarity
Y_data = Float32[]

n = 15
for s in all_species
    # model = GERG2008([s])
    model = SAFTVRMie([s])
    Tc, pc, Vc = crit_pure(model)
    smiles = smiles_map[s]

    # fingerprint = make_fingerprint(smiles)
    g = make_graph_from_smiles(smiles)

    T_range = range(0.5 * Tc, 0.99 * Tc, n)
    # V_range = range(0.5 * Vc, 1.5 * Vc, n) # V could be sampled from a logspace
    for T in T_range
        (p₀, V_vec...) = saturation_pressure(model, T)
        if !any(isnan.(V_vec))
            for V in V_vec
                push!(graphs, g)
                push!(species, s)

                Mw = model.params.Mw.values[1]
                m = model.params.segment.values[1]
                push!(states, [V, T, Mw, m])

                a = a_res(model, V, T, [1.0])
                @assert !isnan(a) "a is NaN at (V,T) = ($(V),$(T)) for $s"
                push!(Y_data, a)
            end
        else
            @warn "NaN found in V_vec at T = $T for $s"
        end
    end
end

└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:69


In [120]:
m = SAFTgammaMie(["decane"])
crit_pure(m)

(630.7620907026924, 2.2684498427389706e6, 0.0006551521374012365)

In [121]:
n = 10
display(graphs[n])
display(states[n])
display(Y_data[n])
nothing

GNNGraph:
  num_nodes: 1
  num_edges: 0
  ndata:
	x = 11×1 Matrix{Float32}

4-element Vector{Float32}:
   0.0036027166
 125.08929
  16.04
   1.0

-0.058258507f0

In [122]:
train_data, test_data = splitobs((graphs, states, species, Y_data), at = 0.8, shuffle = true) |> getobs

Random.seed!(0)
train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)

# Testing if batching works. This will be used when training
# This should produce a single GNNGraph object with a matrix of ndata
vec_gs, _ = first(train_loader)
MLUtils.batch(vec_gs)

GNNGraph:
  num_nodes: 162
  num_edges: 260
  num_graphs: 32
  ndata:
	x = 11×162 Matrix{Float32}

In [123]:
function differentiable_saft(X::AbstractVector{T}, Vol, Temp, Mw, m) where {T<:Real}
    model = SAFTVRMieNN(
        params=SAFTVRMieNNParams(
            Mw=[Mw],
            segment=[m], # (C - 4)/(3) + 1
            sigma=[X[1]] * 1f-10,
            lambda_a=[6.0], # Fixing at 6. Simple molecules interacting through London dispersion -> Should have λₐ = 6.
            lambda_r=[X[2]],
            epsilon=[X[3]],
            # Required for association
            epsilon_assoc=Float32[],
            bondvol=Float32[],
        )
    )
    return a_res(model, Vol, Temp, [1.0])
end

function ChainRulesCore.rrule(::typeof(differentiable_saft), x, V, T, Mw, m)
    y = differentiable_saft(x, V, T, Mw, m)

    function f_pullback(Δy)
        # Use ForwardDiff to compute the gradient
        #? ForwardDiff through nonlinear SAFT solvers not ideal.
        ∂x = @thunk(ForwardDiff.gradient(x -> differentiable_saft(x, V, T, Mw, m), x) .* Δy)
        ∂V = @thunk(ForwardDiff.derivative(V -> differentiable_saft(x, V, T, Mw, m), V) * Δy)
        ∂T = @thunk(ForwardDiff.derivative(T -> differentiable_saft(x, V, T, Mw, m), T) * Δy)
        ∂Mw = @thunk(ForwardDiff.derivative(Mw -> differentiable_saft(x, V, T, Mw, m), Mw) * Δy)
        ∂m = @thunk(ForwardDiff.derivative(m -> differentiable_saft(x, V, T, Mw, m), m) * Δy)
        return (NoTangent(), ∂x, ∂V, ∂T, ∂Mw, ∂m)
    end

    return y, f_pullback
end

In [124]:
#? Works for evaluation, fails for gradients
X = [3.737, 12.504, 152.58]
@show differentiable_saft(X, 1e-4, 300, 14.0, 1.0)
@show ForwardDiff.gradient(x -> differentiable_saft(x, 1e-4, 300, 14.0, 1.0), X)
@show Zygote.gradient(x -> differentiable_saft(x, 1e-4, 300, 14.0, 1.0), X)

differentiable_saft(X, 0.0001, 300, 14.0, 1.0) = -0.2883850024854241
ForwardDiff.gradient((x->begin
            #= /home/luc/SAFT_ML/4_gnn_saft.ipynb:4 =#
            differentiable_saft(x, 0.0001, 300, 14.0, 1.0)
        end), X) = 

[-0.12769665985165313, 0.053698346998990246, -0.0064937953559046635]


Zygote.gradient((x->begin
            #= /home/luc/SAFT_ML/4_gnn_saft.ipynb:5 =#
            differentiable_saft(x, 0.0001, 300, 14.0, 1.0)
        end), X) = ([-0.12769665985165313, 0.053698346998990246, -0.0064937953559046635],)


([-0.12769665985165313, 0.053698346998990246, -0.0064937953559046635],)

In [130]:
function create_graphconv_model(nin, nh; nout=3, nhlayers=1, afunc=relu)
    GNNChain(
        GraphConv(nin => nh, afunc),
        [GraphConv(nh => nh, afunc) for _ in 1:nhlayers]...,
        GlobalPool(mean), # Average the node features
        Dropout(0.2),
        Dense(nh, nh),
        Dense(nh, nout),
    )
end

function create_graphattention_model(nin, ein, nh; nout=3, nhlayers=1, afunc=relu)
    GNNChain(
        GATv2Conv((nin, ein) => nh, afunc),
        [GATv2Conv(nh => nh, afunc) for _ in 1:nhlayers]...,
        GlobalPool(mean),
        Dropout(0.2),
        Dense(nh, nh),
        Dense(nh, nout),
    )
end

function bound_output(X, lb, ub, b=10.0)
    return @. lb + (ub - lb) * 0.5 * (tanh(1 / b * (X - lb) / (ub - lb)) + 1)
end

function predict_a_res(X, V, T, Mw, m)
    # Bound output
    bounds = Tuple{Float32,Float32}[
        (2.5, 5), # σ
        (10, 20), # λ_r
        (100, 1000), # ϵ #? What should this be bounded to for SAFTVRMie?
    ]
    # ŷ = mean(X)
    X_bounded = bound_output(X, first.(bounds), last.(bounds))
    ŷ = differentiable_saft(X_bounded, V, T, Mw, m)
    return ŷ + mean(X)
end

function eval_loss(model, data_loader, device)
    loss = 0.0
    acc = 0.0
    for (g, state, species, y) in data_loader
        g, state, y = MLUtils.batch(g) |> device, state |> device, y |> device
        X = model(g, g.ndata.x)
        for (Xᵢ, stateᵢ, yᵢ) in zip(eachcol(X), state, y)
            V, T, Mw, m = stateᵢ
            ŷ = predict_a_res(Xᵢ, V, T, Mw, m)
            loss += ((ŷ - yᵢ) / yᵢ)^2
            acc += abs((ŷ - yᵢ) / yᵢ)
            @assert loss isa Real "Loss is not a real number, got $(typeof(loss)), X_pred = $X_pred"
            @assert !isnan(loss) "Loss is NaN, X_pred = $X_pred"
        end
        loss /= length(state)
        acc /= length(state)
    end
    loss /= length(data_loader)
    acc /= length(data_loader)
    # return loss, 100 * sqrt(loss)
    return (loss = round(loss, digits = 4),
            acc = round(100 * acc, digits = 2))
end

function train!(model; epochs=50, η=1e-2, infotime=10, log_loss=false)
    # device = Flux.gpu # uncomment this for GPU training
    device = Flux.cpu
    model = model |> device
    opt = ADAM()

    function report(epoch)
        train = eval_loss(model, train_loader, device)
        test = eval_loss(model, test_loader, device)
        @info (; epoch, train, test)
    end

    epoch_loss_vec = Float32[]
    report(0)
    for epoch in 1:epochs
        epoch_loss = 0.0
        for (g, state, species, y) in train_loader
            g, state, y = MLUtils.batch(g) |> device, state |> device, y |> device

            batch_loss = 0.0
            loss_fn() = begin
                X = model(g, g.ndata.x)
                for (Xᵢ, stateᵢ, yᵢ) in zip(eachcol(X), state, y)
                    V, T, Mw, m = stateᵢ
                    ŷ = predict_a_res(Xᵢ, V, T, Mw, m)
                    batch_loss += ((ŷ - yᵢ) / yᵢ)^2
                    @assert batch_loss isa Real "Loss is not a real number, got $(typeof(loss)), X_pred = $X_pred"
                    @assert !isnan(batch_loss) "Loss is NaN, X_pred = $X_pred"
                end
                batch_loss /= length(state)
            end

            grads = Zygote.gradient(Flux.params(model)) do
                loss_fn()
            end
            epoch_loss += batch_loss
            Flux.update!(opt, Flux.params(model), grads)
        end
        epoch_loss /= length(train_loader)
        push!(epoch_loss_vec, epoch_loss)
        
        epoch % infotime == 0 && report(epoch)
    end
    return epoch_loss_vec
end

train! (generic function with 1 method)

In [135]:
nin = 11
for nh ∈ [16, 64, 128, 256]
    @info "nh = $nh"
    model = create_graphattention_model(nin, nh, nhlayers=1)
    epoch_loss_vec = train!(model, epochs=50, infotime=50)
end
# nh = 16
# model = create_graphattention_model(nin, nh, nhlayers=1)
# model = create_graphconv_model(nin, nh; nhlayers=1)
# epoch_loss_vec = train!(model, epochs=50, infotime=50)
nothing

┌ Info: nh = 16
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 67.5765, acc = 108.01), test = (loss = 169.113, acc = 335.32))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


┌ Info: (epoch = 50, train = (loss = 2.2443, acc = 43.53), test = (loss = 6.4945, acc = 131.55))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: nh = 64
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 26.691, acc = 87.3), test = (loss = 89.645, acc = 310.24))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: (epoch = 50, train = (loss = 1.17, acc = 29.72), test = (loss = 3.9623, acc = 110.99))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


┌ Info: nh = 128
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 210.3582, acc = 210.27), test = (loss = 579.042, acc = 528.27))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


┌ Info: (epoch = 50, train = (loss = 1.4218, acc = 32.77), test = (loss = 11.0201, acc = 149.28))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: nh = 256
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 10.1809, acc = 46.13), test = (loss = 20.1314, acc = 160.87))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: (epoch = 50, train = (loss = 1.3892, acc = 26.78), test = (loss = 16.9619, acc = 143.71))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


In [136]:
nin = 11
for nh ∈ [16, 64, 128, 256]
    @info "nh = $nh"
    model = create_graphconv_model(nin, nh; nhlayers=1)
    epoch_loss_vec = train!(model, epochs=50, infotime=50)
end
nothing

┌ Info: nh = 16
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 17661.8113, acc = 1517.87), test = (loss = 9717.8597, acc = 2010.68))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


┌ Info: (epoch = 50, train = (loss = 9.6704, acc = 56.25), test = (loss = 53.5042, acc = 200.69))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: nh = 64
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 14.7587, acc = 58.99), test = (loss = 50.9815, acc = 226.55))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: (epoch = 50, train = (loss = 11.2679, acc = 38.86), test = (loss = 64.9709, acc = 244.51))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


┌ Info: nh = 128
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 2.8101, acc = 36.28), test = (loss = 15.331, acc = 152.46))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


┌ Info: (epoch = 50, train = (loss = 1.6628, acc = 22.24), test = (loss = 13.1399, acc = 136.3))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: nh = 256
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:3
┌ Info: (epoch = 0, train = (loss = 28.203, acc = 67.42), test = (loss = 1460.2671, acc = 809.78))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72
┌ Info: (epoch = 50, train = (loss = 2.8036, acc = 33.82), test = (loss = 4.8567, acc = 114.86))
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:72


In [13]:
# Abstract this into a function that takes a model and returns the X_bounded output
g, state, species, y = first(train_loader)
g = MLUtils.batch(g)
X = model(g, g.ndata.x)
X = first(eachcol(X))
@show first(state), first(species), first(y)
bounds = Tuple{Float32,Float32}[
    (2.5, 5),
    (10, 18),
    (150, 300),
]
# ŷ = mean(X)
X_bounded = bound_output(X, first.(bounds), last.(bounds))
m = SAFTVRMie([first(species)])
X_nominal = [m.params.sigma.values[1]*1e10, m.params.lambda_r.values[1], m.params.epsilon.values[1]]
display(X_bounded)
display(X_nominal)

3-element Vector{Float64}:
   3.620209658064277
  13.453103363749996
 217.57003215248915

3-element Vector{Float64}:
   4.0887
  13.65
 273.64

(first(state), first(species), first(y)) = (Float32[9.589684f-5, 267.38336, 58.12, 1.8514], "butane", -4.662166f0)


In [14]:
m = SAFTVRMie(["methane"])
X_nominal = [m.params.sigma.values[1], m.params.lambda_r.values[1], m.params.epsilon.values[1]]

3-element Vector{Float64}:
   3.737e-10
  12.504
 152.58

In [15]:
plot(epoch_loss_vec, label = "Training loss")

UndefVarError: UndefVarError: epoch_loss_vec not defined

In [16]:
plot(epoch_loss_vec, label = "Training loss", ylims=(0.0, 0.02))

UndefVarError: UndefVarError: epoch_loss_vec not defined

In [17]:
# Next:
# - Take notebooks, extract to .jl single-functions and run hyperparameter sweep
# - Evaluate AAD on saturation pressure
# - Stratify test/train data by molecule
