In [24]:
import Pkg; Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`


In [25]:
using CUDA
using Flux
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy
using Flux.Data: DataLoader
using GeometricFlux
using GeometricFlux.Datasets
using GraphSignals
using Graphs
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using Statistics
using Random

using RDKitMinimalLib
using MolecularGraph
using Plots, GraphPlot
using OneHotArrays

using Clapeyron
import Clapeyron: a_res

In [26]:
# Per-atom descriptors from MolecularGraph (each takes the molecule `mol`):
# atom_symbol(mol): atom element as a symbol, e.g. :C, :O and :N
# charge(mol): electric charge of the atom. Only integer charges are allowed in the model.
# multiplicity(mol): 1: no unpaired electron (default), 2: radical, 3: biradical
# lone_pair(mol): number of lone pairs on the atom
# implicit_hydrogens(mol): number of implicit hydrogens, which do not appear as graph vertices but are calculated automatically, drawn in images and used to calculate other descriptors.
# valence(mol): atom valence, specific to each atom species and taking electric charge into account. The number of implicit hydrogens is obtained by subtracting the vertex degree from the valence.
# is_aromatic(mol): whether the atom is aromatic. Only binary aromaticity is allowed in the model.
# pi_electron(mol): number of pi electrons
# hybridization(mol): orbital hybridization, e.g. sp, sp2 and sp3

In [27]:
"""
    concatenate_row(row) -> Vector{Float32}

Flatten an iterable of vectors (or scalars) into one `Vector{Float32}`, converting every
element to `Float32` and keeping the original order. An empty `row` gives `Float32[]`.
"""
function concatenate_row(row)
    # Lazily convert each piece, then grow a single Float32 buffer with append!.
    converted = (Float32.(part) for part in row)
    return foldl(append!, converted; init = Float32[])
end

"""
    make_graph_from_smiles(smiles; h_labels, hybrid_labels, atom_labels) -> FeaturedGraph

Parse `smiles` with `smilestomol` and convert the molecule into a GraphSignals
`FeaturedGraph` for use with GeometricFlux layers.

Node features (`nf`) are a `Matrix{Float32}` of size
`(length(h_labels) + length(hybrid_labels) + length(atom_labels), n_atoms)`. Column `j`
stacks, in this order, the one-hot encodings of atom `j`'s implicit-hydrogen count,
hybridization and element symbol. With the default labels this is 4 + 3 + 3 = 10
features per atom.

Edge features (`ef`) are a `1 × n_bonds` row of bond orders, in the order
`edges(molgraph)` yields the bonds.

# Keywords
- `h_labels`: allowed implicit-hydrogen counts. The default `[1, 2, 3, 4]` contains no
  `0`, so an atom without implicit hydrogens (e.g. a quaternary carbon) raises an
  error; pass `0:4` to support such molecules (this changes the feature dimension).
- `hybrid_labels`: allowed hybridization symbols (default `[:sp, :sp2, :sp3]`).
- `atom_labels`: allowed element symbols (default `[:C, :O, :N]`).

# Throws
An error from `onehotbatch` if any atom's value is not among the corresponding labels.
"""
function make_graph_from_smiles(
    smiles::String;
    h_labels = [1, 2, 3, 4],
    hybrid_labels = [:sp, :sp2, :sp3],
    atom_labels = [:C, :O, :N],
)::FeaturedGraph
    molgraph = smilestomol(smiles)

    # Copy the molecular connectivity into a plain Graphs.jl graph for GraphSignals.
    g = Graph(nv(molgraph))
    for e in edges(molgraph)
        add_edge!(g, e.src, e.dst)
    end

    # One (labels × n_atoms) Float32 block per descriptor; vcat stacks the blocks so
    # that column j is the full feature vector of atom j (n_features × n_nodes).
    onehot_rows(values, labels) = Matrix{Float32}(onehotbatch(values, labels))
    nf = vcat(
        onehot_rows(implicit_hydrogens(molgraph), h_labels),
        onehot_rows(hybridization(molgraph), hybrid_labels),
        onehot_rows(atom_symbol(molgraph), atom_labels),
    )

    # NOTE(review): bond orders follow the iteration order of edges(molgraph); confirm
    # this matches the edge indexing GraphSignals assigns to `g` (branched molecules).
    order = [molgraph.eprops[e][:order] for e in edges(molgraph)]
    ef = order'

    return FeaturedGraph(g, nf = nf, ef = ef)
end

# Smoke test of the featurizer on but-1-ene ("C=CCC"). `g` is also used later via
# typeof(g) to declare the graph element type of X_data.
g = make_graph_from_smiles("C=CCC")

FeaturedGraph:
	Undirected graph with (#V=4, #E=3) in adjacency matrix
	Node feature:	ℝ^10 <GraphSignals.NodeSignal{Matrix{Float32}}>
	Edge feature:	ℝ^1 <GraphSignals.EdgeSignal{LinearAlgebra.Adjoint{Int64, Vector{Int64}}}>

In [23]:
# Molecules whose saturation curves are sampled below; one graph is built per molecule.
# Start with the linear alkanes C1–C10.
#! isobutane, isopentane not defined for SAFTVRMie
species = [
    "methane", "ethane", "propane", "butane", "pentane",
    "hexane", "heptane", "octane", "nonane", "decane",
]

# Species name => SMILES string. The branched isomers are kept here for later use even
# though they are not in `species` yet.
smiles_map = Dict{String, String}(
    "methane"    => "C",
    "ethane"     => "CC",
    "propane"    => "CCC",
    "butane"     => "CCCC",
    "isobutane"  => "CC(C)C",
    "pentane"    => "CCCCC",
    "isopentane" => "CC(C)CC",
    "hexane"     => "CCCCCC",
    "heptane"    => "CCCCCCC",
    "octane"     => "CCCCCCCC",
    "nonane"     => "CCCCCCCCC",
    "decane"     => "CCCCCCCCCC",
)

# X_data[i] = (graph, V, T, species name); Y_data[i] = [a_res(model, V, T, [1.0])].
# Y data contains a_res
#* Sampling data along saturation curve
T = Float32
X_data = Vector{Tuple{typeof(g), T, T, String}}()
Y_data = Vector{Vector{T}}()

n = 30  # number of temperatures sampled per species
for s in species
    # model = GERG2008([s])
    model = SAFTVRMie([s])
    Tc, pc, Vc = crit_pure(model)
    smiles = smiles_map[s]

    # The graph depends only on the species, so build it once per species.
    fg = make_graph_from_smiles(smiles)

    # n temperatures from 0.5*Tc to 0.99*Tc. The loop variable is `Tsat` rather than `T`
    # so it does not shadow the Float32 alias used in X_data's element type.
    # V could alternatively be sampled from a range/logspace around Vc.
    for Tsat in range(0.5 * Tc, 0.99 * Tc, n)
        p_sat, V_liq, V_vap = saturation_pressure(model, Tsat)
        # Clapeyron reports a failed saturation solve as NaN; skip the point instead of
        # feeding NaN volumes and targets into the dataset.
        if !(isfinite(p_sat) && isfinite(V_liq) && isfinite(V_vap))
            @warn "saturation_pressure failed; skipping point" s Tsat
            continue
        end
        # One sample per coexisting phase: liquid volume, then vapour volume.
        for V in (V_liq, V_vap)
            push!(X_data, (fg, V, Tsat, s))
            a = a_res(model, V, Tsat, [1.0])
            push!(Y_data, Float32[a])
        end
    end
end

In [21]:
# Inspect the first sample: (graph, V, T, species name).
X_data[1]

(FeaturedGraph:
	Undirected graph with (#V=1, #E=0) in adjacency matrix
	Node feature:	ℝ^10 <GraphSignals.NodeSignal{Matrix{Float32}}>
	Edge feature:	ℝ^1 <GraphSignals.EdgeSignal{LinearAlgebra.Adjoint{Int64, Vector{Int64}}}>, 3.6714497f-5, 97.726006f0, "methane")

In [15]:
# Hyperparameters for `train`; keyword constructor via Parameters.jl, e.g. Args(epochs = 50).
# Fields are concretely typed so they are not stored as `Any`.
@with_kw mutable struct Args
    η::Float64 = 0.001          # learning rate
    batch_size::Int = 8         # batch size
    epochs::Int = 20            # number of epochs
    seed::Int = 0               # random seed
    cuda::Bool = false          # use GPU (train also requires CUDA.has_cuda())
    heads::Int = 2              # attention heads
    # Number of node features per atom; make_graph_from_smiles yields 10 (4 + 3 + 3)
    # with its default labels.
    input_dim::Int = 10
    hidden_dim::Int = 2         # hidden dimension per attention head
    # Output dimension. NOTE(review): 7 looks inherited from the Cora example, while
    # Y_data holds a single a_res value per sample — confirm before training.
    target_dim::Int = 7
    # NOTE(review): leftover from the Cora example; not used by train as written.
    dataset = Cora
end

Args

In [16]:
"""
    train(data; args = Args(), fg = first(first(data)))

Select the device, seed the RNG and build the graph-attention model, then return it.
No optimisation loop is run yet.

# Arguments
- `data`: training samples; only used to supply the default `fg`.

# Keywords
- `args`: hyperparameters (default `Args()`).
- `fg`: the `FeaturedGraph` bound to both `GATConv` layers via `WithGraph`. The default
  is the first element of the first sample, matching the `(graph, V, T, species)`
  tuples stored in `X_data`. NOTE(review): confirm callers pass data in that layout.

# Returns
The `Chain` model, moved to the selected device.
"""
function train(data; args = Args(), fg = first(first(data)))
    Random.seed!(args.seed)

    if args.cuda && CUDA.has_cuda()
        device = gpu
        CUDA.allowscalar(false)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # The first layer concatenates its `heads` outputs (concat=true), so the second
    # layer's input width is hidden_dim * heads.
    first_layer = GATConv(
        args.input_dim => args.hidden_dim, heads=args.heads, concat=true,
    )
    second_layer = GATConv(
        args.hidden_dim * args.heads => args.target_dim, heads=args.heads, concat=false,
    )
    model = Chain(
        WithGraph(fg, first_layer),
        Dropout(0.2),
        WithGraph(fg, second_layer),
    ) |> device

    # NOTE(review): WithGraph binds one graph for every forward pass, whereas X_data
    # holds a different graph per species — revisit before adding the training loop.
    return model
end

train (generic function with 1 method)