In [1]:
using LinearAlgebra
using Dates

using Plots
using BenchmarkTools

BLAS.set_num_threads(1)

include("src/io.jl")
include("src/distances.jl")
include("src/optimizer.jl")
include("src/network.jl")
include("src/symmfunctions.jl")
include("src/base.jl")
include("src/pretraining.jl")

# Initialize the parameters
globalParms, MCParms, NNParms, preTrainParms, systemParmsList = parametersInit()

# Initialize the input data
inputs = inputInit(globalParms, NNParms, preTrainParms, systemParmsList)
if globalParms.mode == "training"
    model, opt, refRDFs = inputs
else
    model = inputs
end

Running ML-IMC in the training mode.
Building a model...
Chain(Dense(14 => 20, relu; bias=false), Dense(20 => 20, relu; bias=false), Dense(20 => 1; bias=false))
   Number of layers: 4 
   Number of neurons in each layer: [14, 20, 20, 1]


(Chain(Dense(14 => 20, relu; bias=false), Dense(20 => 20, relu; bias=false), Dense(20 => 1; bias=false)), AMSGrad(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()), Any[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.004, 1.005, 1.006, 1.005, 1.006, 1.006, 1.007, 1.007, 1.007, 1.007], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.003, 1.003, 1.003, 1.004, 1.004, 1.004, 1.005, 1.005, 1.004, 1.005], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.004, 1.004, 1.004, 1.005, 1.005, 1.005, 1.005, 1.005, 1.005, 1.005], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.004, 1.004, 1.005, 1.004, 1.006, 1.003, 1.004, 1.004, 1.005, 1.004]])

In [2]:
    println("Using the following symmetry functions as the neural input for each atom:")
    if NNParms.G2Functions != []
        println("    G2 symmetry functions:")
        println("    eta, Å^-2; rcutoff, Å; rshift, Å")
        for G2Function in NNParms.G2Functions
            println("       ", G2Function)
        end
    end
    if NNParms.G9Functions != []
        println("    G9 symmetry functions:")
        println("    eta, Å^-2; lambda; zeta; rcutoff, Å; rshift, Å")
        for G9Function in NNParms.G9Functions
            println("       ", G9Function)
        end
    end

Using the following symmetry functions as the neural input for each atom:
    G2 symmetry functions:
    eta, Å^-2; rcutoff, Å; rshift, Å
       G2(0.003, 6.0, 0.0)
       G2(0.03, 6.0, 0.0)
       G2(0.1, 6.0, 0.0)
       G2(0.2, 6.0, 0.0)
       G2(0.5, 6.0, 3.35)
       G2(1.0, 6.0, 3.35)
       G2(2.0, 6.0, 3.35)
       G2(5.0, 6.0, 3.35)
       G2(0.5, 6.0, 4.5)
       G2(1.0, 6.0, 4.5)
       G2(2.0, 6.0, 4.5)
       G2(5.0, 6.0, 4.5)
    G9 symmetry functions:
    eta, Å^-2; lambda; zeta; rcutoff, Å; rshift, Å
       G9(0.125, 1, 1.0, 6.0, 0.0)
       G9(0.125, -1, 1.0, 6.0, 0.0)


In [3]:
systemParms = systemParmsList[1];

In [28]:
traj = readXTC(systemParms)
frame = read_step(traj, 1)
coordinates = positions(frame)
box = lengths(UnitCell(frame));

In [29]:
distanceMatrix = buildDistanceMatrix(frame)

512×512 Matrix{Float64}:
  0.0       7.09189  15.855    17.4607   …   9.68317  15.981    20.8271
  7.09189   0.0      19.4916   14.7913      14.8578   14.2792   17.9268
 15.855    19.4916    0.0      11.6814      19.1468   16.8733   23.2905
 17.4607   14.7913   11.6814    0.0         14.4029   18.7578   16.7659
 12.0278    5.86375  16.1165   13.7371      18.4128   16.9396   16.4657
  6.82342   7.67516  16.2445   13.3721   …  12.6422   17.2571   20.5049
 13.0989   18.7139   17.984    24.9762      11.0264   17.6202   18.7331
 11.3263   10.5723   11.0429   10.3943      14.228    21.6119   20.0425
 19.6562   18.8416   18.4182   13.9998      13.8353   11.5905    6.22577
 15.2346   17.5955   23.1113   24.493       16.9502   11.3124   10.3767
  9.03241  12.3969   11.9791   16.8481   …  15.5818   15.1848   18.55
 15.0417   17.157    15.4381   15.8653      13.3199   15.7629   12.5644
 15.2687   12.0559    9.79153  15.5865      16.0364   18.6892   24.1867
  ⋮                                     

In [30]:
G2Matrix = buildG2Matrix(distanceMatrix, NNParms.G2Functions)

512×12 Matrix{Float64}:
 2.35274  1.88407  1.31836  1.08508  …  0.821824  0.615639  0.475041
 2.57931  2.11925  1.48816  1.16131     0.554418  0.296164  0.174102
 2.88216  2.22165  1.4174   1.0981      1.31263   0.986246  0.628015
 2.95723  2.28464  1.46019  1.12002     1.2689    0.963069  0.621101
 2.66877  2.05593  1.34333  1.0794      1.27734   1.0632    0.780847
 2.42679  1.94255  1.35277  1.10012  …  0.776183  0.543296  0.367214
 2.55469  1.93678  1.263    1.04644     1.35606   1.15287   0.808196
 2.77007  2.19655  1.48137  1.15654     0.917392  0.67184   0.464642
 2.497    2.0022   1.38295  1.10922     0.780579  0.50403   0.290404
 2.7829   2.21422  1.46328  1.12337     0.98522   0.602758  0.244447
 3.19812  2.48002  1.5701   1.16437  …  1.18      0.784615  0.437206
 2.8682   2.26575  1.5035   1.15549     0.950804  0.66364   0.407923
 3.05587  2.41323  1.5789   1.18409     1.00831   0.727419  0.51327
 ⋮                                   ⋱            ⋮         
 2.62699  2.05608  

In [8]:
@btime buildG2Matrix($distanceMatrix, $NNParms.G2Functions);

  22.548 ms (514 allocations: 2.11 MiB)


In [31]:
pointIndex = 1;

In [32]:
distanceVector1 = distanceMatrix[:, pointIndex]

512-element Vector{Float64}:
  0.0
  7.091892808138942
 15.854950124730799
 17.4607272015029
 12.027773352854163
  6.82341543296718
 13.098859274949335
 11.326277833108653
 19.656198560684448
 15.23458176314022
  9.032409199465086
 15.041669488657645
 15.268684775800134
  ⋮
 18.560362730062472
 14.736974614013517
 20.77740909565999
 16.196676724150503
 15.990429100693698
 11.823270083877945
 13.597919877031464
 13.102531054745459
 14.870004024270196
  9.683171701468305
 15.980994682509369
 20.827134860816255

In [33]:
positions(frame)[:, pointIndex]

3-element Vector{Float64}:
 18.08000087738037
  4.760000109672546
 27.280001640319824

In [34]:
positions(frame)[:, pointIndex] .+= [0.5, 1.0, 0.25]

3-element view(::Chemfiles.ChemfilesArray, :, 1) with eltype Float64:
 18.58000087738037
  5.760000109672546
 27.530001640319824

In [35]:
point = positions(frame)[:, pointIndex];

In [36]:
distanceVector2 = computeDistanceVector(point, positions(frame), box);

In [37]:
distanceMatrix[pointIndex, :] = distanceVector2
distanceMatrix[:, pointIndex] = distanceVector2

512-element Vector{Float64}:
  0.0
  7.851738727625623
 15.78803169183056
 16.631638615967013
 12.54249185439215
  6.837141116555358
 12.68552778681243
 11.09071589699985
 19.345507005851747
 15.436234584607027
  9.221546281023974
 15.05338241196147
 15.068249564958611
  ⋮
 17.57360159287757
 14.979349826898726
 20.85332870075784
 17.159463758743524
 16.639601056092584
 11.524418177173946
 12.706137202826636
 12.251152215277234
 15.737360666709632
  8.59396958441395
 17.080458098351816
 21.14079528696429

In [38]:
G2MatrixUpdated = buildG2Matrix(distanceMatrix, NNParms.G2Functions)

512×12 Matrix{Float64}:
 2.53595  2.07024  1.44638  1.14218  …  0.726923  0.500109  0.339733
 2.57931  2.11925  1.48816  1.16131     0.554418  0.296164  0.174102
 2.88216  2.22165  1.4174   1.0981      1.31263   0.986246  0.628015
 2.95723  2.28464  1.46019  1.12002     1.2689    0.963069  0.621101
 2.66877  2.05593  1.34333  1.0794      1.27734   1.0632    0.780847
 2.42679  1.94255  1.35277  1.10012  …  0.776183  0.543296  0.367214
 2.55469  1.93678  1.263    1.04644     1.35606   1.15287   0.808196
 2.77007  2.19655  1.48137  1.15654     0.917392  0.67184   0.464642
 2.497    2.0022   1.38295  1.10922     0.780579  0.50403   0.290404
 2.7829   2.21422  1.46328  1.12337     0.98522   0.602758  0.244447
 3.19812  2.48002  1.5701   1.16437  …  1.18      0.784615  0.437206
 2.8682   2.26575  1.5035   1.15549     0.950804  0.66364   0.407923
 3.05587  2.41323  1.5789   1.18409     1.00831   0.727419  0.51327
 ⋮                                   ⋱            ⋮         
 2.62699  2.05608  

In [39]:
sum(G2MatrixUpdated .- G2Matrix)

2.066260676296225

In [40]:
sum(G2MatrixUpdated[pointIndex, :] .- G2Matrix[pointIndex, :])

1.0331303381481123

In [41]:
G2MatrixOriginal = copy(G2Matrix);

In [42]:
G2Matrix == G2MatrixOriginal

true

In [43]:
updateG2Matrix!(G2Matrix, distanceVector1, distanceVector2, systemParms, NNParms.G2Functions, pointIndex);

In [44]:
G2Matrix == G2MatrixOriginal

false

In [45]:
sum(G2MatrixUpdated .- G2MatrixOriginal)

2.066260676296225

In [46]:
sum(G2MatrixUpdated .- G2Matrix)

-1.27675647831893e-15

In [25]:
@btime updateG2Matrix!($G2Matrix, $distanceVector1, $distanceVector2, $systemParms, $NNParms.G2Functions, $pointIndex);

  107.760 μs (0 allocations: 0 bytes)


In [48]:
updateG2Matrix!(G2Matrix, distanceVector2, distanceVector1, systemParms, NNParms.G2Functions, pointIndex);

In [50]:
sum(G2MatrixOriginal .- G2Matrix)

-1.6653345369377348e-16

In [6]:
model

Chain(
  Dense(20 => 20, relu; bias=false),    [90m# 400 parameters[39m
  Dense(20 => 20, relu; bias=false),    [90m# 400 parameters[39m
  Dense(20 => 1; bias=false),           [90m# 20 parameters[39m
) [90m                  # Total: 3 arrays, [39m820 parameters, 6.664 KiB.

In [65]:
inputdata = rand(Float64, (512, 14));
inputlayer = inputdata[1, :];

In [67]:
@btime atomicEnergy($inputlayer, $model)

  275.132 ns (6 allocations: 1.00 KiB)


0.296728100444809

In [92]:
@btime totalEnergyScalar(inputdata, model)

  154.099 μs (3585 allocations: 600.02 KiB)


387.64894878504174

In [87]:
atomicEnergy(inputdata[1, :], model)

0.296728100444809

In [7]:
energyGradients = computeEnergyGradients(G2Matrix, model)

3-element Vector{Matrix{Float64}}:
 [2415.920543679994 2547.4565519338016 … 2191.36532301734 1912.3621881613617; 0.0 0.0 … 0.0 0.0; … ; 3034.7183653662732 3199.945132032516 … 2752.6470638218057 2402.181921432444; 0.0 0.0 … 0.0 0.0]
 [-1627.2656754914422 -0.0 … -1116.289223062292 -0.0; 2578.3771711348572 0.0 … 1768.7429240824147 0.0; … ; 1373.4965021267865 0.0 … 942.2059140864284 0.0; 0.0 0.0 … 0.0 0.0]
 [5444.763771149569 7445.66566094606 … 9993.053626551657 0.0]

In [8]:
@btime computeEnergyGradients($G2Matrix, $model)

  6.449 ms (20057 allocations: 88.38 MiB)


3-element Vector{Matrix{Float64}}:
 [2415.920543679994 2547.4565519338016 … 2191.36532301734 1912.3621881613617; 0.0 0.0 … 0.0 0.0; … ; 3034.7183653662732 3199.945132032516 … 2752.6470638218057 2402.181921432444; 0.0 0.0 … 0.0 0.0]
 [-1627.2656754914422 -0.0 … -1116.289223062292 -0.0; 2578.3771711348572 0.0 … 1768.7429240824147 0.0; … ; 1373.4965021267865 0.0 … 942.2059140864284 0.0; 0.0 0.0 … 0.0 0.0]
 [5444.763771149569 7445.66566094606 … 9993.053626551657 0.0]

In [5]:
cutoff = 6.0 # Å
rs = 0.0 # Å 
eta = 0.1 # Å^-2
lambda = 1.0
zeta = 1.0;

In [6]:
distanceVector = distanceMatrix[:, 1];

In [7]:
function computeCosAngle(coordinates, i, j, k, distance_ij, distance_ik)::Float64
    @assert i != j && i != k && k != j 
    vector_0i = coordinates[:, i]
    vector_ij = coordinates[:, j] .- vector_0i
    vector_ik = coordinates[:, k] .- vector_0i
    cosAngle = dot(vector_ij, vector_ik) / (distance_ij * distance_ik)
    return (cosAngle)
end

computeCosAngle (generic function with 1 method)

In [8]:
@btime computeCosAngle($coordinates, 1, 461, 483, distanceVector[461], distanceVector[483])

  142.165 ns (9 allocations: 480 bytes)


0.7169325422695667

In [9]:
function G9(cosAngle, distance_ij, distance_ik, cutoff, eta, zeta, lambda=1.0, rshift=0.0)::Float64
    return (
        (1.0 + lambda * cosAngle)^zeta * 
        exp(-eta * (
                (distance_ij - rshift)^2 + 
                (distance_ik - rshift)^2)
                ) * 
        distanceCutoff(distance_ij, cutoff) * 
        distanceCutoff(distance_ik, cutoff))
end        

G9 (generic function with 3 methods)

In [10]:
function G3(cosAngle, distance_ij, distance_ik, distance_kj, cutoff, eta, zeta, lambda=1.0, rshift=0.0)::Float64
    return (
        (1.0 + lambda * cosAngle)^zeta * 
        exp(-eta * (
                (distance_ij - rshift)^2 + 
                (distance_ik - rshift)^2 +
                (distance_kj - rshift)^2)) * 
        distanceCutoff(distance_ij, cutoff) * 
        distanceCutoff(distance_ik, cutoff) *
        distanceCutoff(distance_kj, cutoff))
end        

G3 (generic function with 3 methods)

In [11]:
function G9total(i, coordinates, distanceMatrix, cutoff, eta, zeta, lambda=1.0, rshift=0.0)::Float64
    sum = 0.0
    distanceVector = distanceMatrix[:, i];
    N = length(distanceVector)
    @inbounds for k in eachindex(distanceVector)
        distance_ik = distanceVector[k]
        @inbounds for j in 1:k-1
            distance_ij = distanceVector[j]
            if 0 < distance_ij < cutoff && 0 < distance_ik < cutoff
                cosAngle = computeCosAngle(coordinates, i, j, k, distance_ij, distance_ik)
                sum += G9(cosAngle, distance_ij, distance_ik, cutoff, eta, zeta, lambda, rshift)
            end
        end
    end
    return (2.0^(1.0 - zeta) * sum)
end

G9total (generic function with 3 methods)

In [12]:
@btime G9total(5, $coordinates, $distanceMatrix, cutoff, eta, 1.0, 1.0)

  11.758 μs (333 allocations: 29.95 KiB)


0.6491223630240601

In [13]:
G9total(5, coordinates, distanceMatrix, cutoff, eta, 1.0, 1.0)

0.6491223630240601

I will implement the G3total function later (because an additional distance has to be computed)

In [15]:
@btime buildDistanceMatrix_old($frame);

  6.737 ms (2 allocations: 2.00 MiB)


In [16]:
@btime buildDistanceMatrix($frame);

  5.273 ms (3080 allocations: 12.27 MiB)


In [17]:
distanceMatrix1 = buildDistanceMatrix_old(frame)
distanceMatrix2 = buildDistanceMatrix(frame)
distanceMatrix1 == distanceMatrix2

true

In [18]:
distanceVector = distanceMatrix[:, 1];

In [19]:
@btime $distanceMatrix[:, $1];

  435.050 ns (1 allocation: 4.12 KiB)


In [31]:
coordinates = positions(frame);
r1 = coordinates[:, 1]
r2 = coordinates[:, 2]

3-element Vector{Float64}:
 15.830000638961792
 32.47000217437744
 31.760001182556152

In [35]:
distanceVector1 = distanceMatrix[:, 2];

In [37]:
distanceVector2 = computeDistanceVector(r2, coordinates, box);

In [33]:
distanceVector1 == distanceVector2

true

In [40]:
@btime updateDistance!($frame, $distanceVector, $1);

  10.216 μs (0 allocations: 0 bytes)


In [41]:
@btime computeDistanceVector($r1, $coordinates, $box);

  6.400 μs (5 allocations: 20.45 KiB)
