In [3]:
import Pkg; Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`


In [4]:
using Clapeyron
import Clapeyron: a_res

using MolecularGraph, Graphs
using Plots

using Flux
using Flux: onecold, onehotbatch, logitcrossentropy
using Flux: DataLoader
using GraphNeuralNetworks
using MLDatasets
using MLUtils
using OneHotArrays
using LinearAlgebra, Random, Statistics



In [5]:
# atom_symbol(mol): atom letters as a symbol e.g. :C, :O and :N
# charge(mol): electric charge of the atom. only integer charge is allowed in the model
# multiplicity(mol): 1: no unpaired electron(default), 2: radical, 3: biradical
# lone_pair(mol): number of lone pair on the atom
# implicit_hydrogens(mol): number of implicit hydrogens that are not appear as graph vertices but automatically calculated, drawn in image and used for calculation of other descriptors.
# valence(mol): number of atom valence, specific to each atom species and considering electric charge. Implicit number of hydrogens is obtained by subtracting the degree of the vertex from the valence.
# is_aromatic(mol): whether the atom is aromatic or not. only binary aromaticity is allowed in the model.
# pi_electron(mol): number of pi electrons
# hybridization(mol): orbital hybridization e.g. sp, sp2 and sp3

In [6]:
function make_graph_from_smiles(smiles::String)
    molgraph = smilestomol(smiles)

    g = Graph(nv(molgraph))
    for e in edges(molgraph)
        add_edge!(g, e.src, e.dst)
    end

    # Should number of hydrogens be one-hot encoded?
    f(vec, enc) = hcat(map(x -> onehot(x, enc), vec)...)
    num_h = f(implicit_hydrogens(molgraph), [0, 1, 2, 3, 4])
    hybrid = f(hybridization(molgraph), [:sp, :sp2, :sp3])
    atoms = f(atom_symbol(molgraph), [:C, :O, :N])

    # Node data should be matrix (num_features, num_nodes)
    # Matrix has num_nodes columns, num_features rows
    ndata = Float32.(vcat(num_h, hybrid, atoms))

    g = GNNGraph(g, ndata = ndata, edata = nothing)
    return g
end

g = make_graph_from_smiles("CC=CC(CC=O)")

GNNGraph:
  num_nodes: 7
  num_edges: 12
  ndata:
	x = 11×7 Matrix{Float32}

In [7]:
g.ndata.x

11×7 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  1.0
 0.0  1.0  1.0  0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0  1.0  0.0  0.0
 1.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  1.0  1.0  0.0  0.0  1.0  1.0
 1.0  0.0  0.0  1.0  1.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  1.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0

In [67]:
# Iterate over molecules in dataset and build graph for each one
# Initially sample data for hydrocarbons
#! isobutane, isopentane not defined for SAFTVRMie
species = [
    "methane",
    "ethane",
    "propane",
    "butane",
    "pentane",
    "hexane",
    "heptane",
    "octane",
    "nonane",
    # "decane",
]

# Define smiles map
smiles_map = Dict(
    "methane" => "C",
    "ethane" => "CC",
    "propane" => "CCC",
    "butane" => "CCCC",
    "isobutane" => "CC(C)C",
    "pentane" => "CCCCC",
    "isopentane" => "CC(C)CC",
    "hexane" => "CCCCCC",
    "heptane" => "CCCCCCC",
    "octane" => "CCCCCCCC",
    "nonane" => "CCCCCCCCC",
    "decane" => "CCCCCCCCCC",
)

# X data contains graph, V, T
# Y data contains a_res
#* Sampling data along saturation curve
T = Float32
# X_data = Vector{Tuple{typeof(g),T,T, String}}([])
# Y_data = Vector{Vector{T}}()

T = GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}
graphs = T[]
states = Tuple{Float32, Float32, String}[]
Y_data = Float32[]

n = 30
for s in species
    # model = GERG2008([s])
    model = SAFTVRMie([s])
    Tc, pc, Vc = crit_pure(model)
    smiles = smiles_map[s]

    # fingerprint = make_fingerprint(smiles)
    g = make_graph_from_smiles(smiles)

    T_range = range(0.5 * Tc, 0.99 * Tc, n)
    # V_range = range(0.5 * Vc, 1.5 * Vc, n) # V could be sampled from a logspace
    for T in T_range
        (p₀, V_vec...) = saturation_pressure(model, T)
        for V in V_vec
            push!(graphs, g)
            push!(states, (V, T, s))
            
            a = a_res(model, V, T, [1.0])
            @assert !isnan(a) "a is NaN at $(V), $(T), $s"
            push!(Y_data, a)
        end
    end
end

In [68]:
n = 500
display(graphs[n])
display(states[n])
display(Y_data[n])
nothing

GNNGraph:
  num_nodes: 9
  num_edges: 16
  ndata:
	x = 11×9 Matrix{Float32}

(0.07785356f0, 392.35257f0, "nonane")

-0.020128421f0

In [69]:
train_data, test_data = splitobs((graphs, states, Y_data), at = 0.8, shuffle = true) |> getobs

Random.seed!(0)
train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)

# Testing if batching works. This will be used when training
# This should produce a single GNNGraph object with a matrix of ndata
vec_gs, _ = first(train_loader)
MLUtils.batch(vec_gs)

GNNGraph:
  num_nodes: 163
  num_edges: 262
  num_graphs: 32
  ndata:
	x = 11×163 Matrix{Float32}

In [82]:
function create_model(nin, nh, nout)
    GNNChain(
        GraphConv(nin => nh, relu),
        # GraphConv(nh => nh, relu),
        # GraphConv(nh => nh),
        GlobalPool(mean), # Average the node features
        Dropout(0.2),
        Dense(nh, nout),
    )
end

function eval_loss(model, data_loader, device)
    loss = 0.0
    for (g, states, y) in data_loader
        g, y = MLUtils.batch(g) |> device, y |> device
        ŷ = model(g, g.ndata.x)'
        loss += Flux.mse(ŷ, y)
    end
    return loss / length(data_loader)
end

function train!(model; epochs=12, η=1e-3, infotime=3)
    # device = Flux.gpu # uncomment this for GPU training
    device = Flux.cpu
    model = model |> device
    opt = Flux.setup(Adam(1e-3), model)

    function report(epoch)
        train = eval_loss(model, train_loader, device)
        test = eval_loss(model, test_loader, device)
        @info (; epoch, train, test)
    end

    report(0)
    for epoch in 1:epochs
        for (g, states, y) in train_loader
            g, y = MLUtils.batch(g) |> device, y |> device
            loss = 0.0
            grads = Flux.gradient(model) do model
                ŷ = model(g, g.ndata.x)'
                loss = Flux.mse(ŷ, y)
            end
            Flux.update!(opt, model, grads[1])
        end
        epoch % infotime == 0 && report(epoch)
    end
end

train! (generic function with 1 method)

In [83]:
nin = 11
nh = 16
nout = 1
model = create_model(nin, nh, nout)
train!(model)

┌ Info: (epoch = 0, train = 10.289446728570121, test = 8.44406795501709)
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:31
┌ Info: (epoch = 3, train = 8.590846572603498, test = 7.037330150604248)
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:31


┌ Info: (epoch = 6, train = 7.50623973778316, test = 6.122561097145081)
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:31
┌ Info: (epoch = 9, train = 6.5303173405783514, test = 5.111987233161926)
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:31


┌ Info: (epoch = 12, train = 5.40480945791517, test = 4.426450252532959)
└ @ Main /home/luc/SAFT_ML/4_gnn_saft.ipynb:31
