In [1]:
using LinearAlgebra
using Dates

using Plots
using BenchmarkTools

BLAS.set_num_threads(1)

include("src/io.jl")
include("src/distances.jl")
include("src/optimizer.jl")
include("src/network.jl")
include("src/symmfunctions.jl")
include("src/base.jl")
include("src/pretraining.jl")

# Initialize the parameters
globalParms, MCParms, NNParms, preTrainParms, systemParmsList = parametersInit()

# Initialize the input data
inputs = inputInit(globalParms, NNParms, preTrainParms, systemParmsList)
if globalParms.mode == "training"
    model, opt, refRDFs = inputs
else
    model = inputs
end

Running ML-IMC in the training mode.
Building a model...
Chain(Dense(16 => 20, relu; bias=false), Dense(20 => 20, relu; bias=false), Dense(20 => 1; bias=false))
   Number of layers: 4 
   Number of neurons in each layer: [16, 20, 20, 1]


(Chain(Dense(16 => 20, relu; bias=false), Dense(20 => 20, relu; bias=false), Dense(20 => 1; bias=false)), AMSGrad(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()), Any[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.004, 1.005, 1.006, 1.005, 1.006, 1.006, 1.007, 1.007, 1.007, 1.007], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.003, 1.003, 1.003, 1.004, 1.004, 1.004, 1.005, 1.005, 1.004, 1.005], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.004, 1.004, 1.004, 1.005, 1.005, 1.005, 1.005, 1.005, 1.005, 1.005], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  1.004, 1.004, 1.005, 1.004, 1.006, 1.003, 1.004, 1.004, 1.005, 1.004]])

In [2]:
println("Using the following symmetry functions as the neural input for each atom:")
    if NNParms.G2Functions != []
        println("    G2 symmetry functions:")
        println("    eta, Å^-2; rcutoff, Å; rshift, Å")
        for G2Function in NNParms.G2Functions
            println("       ", G2Function)
        end
    end
    if NNParms.G3Functions != []
        println("    G3 symmetry functions:")
        println("    eta, Å^-2; lambda; zeta; rcutoff, Å; rshift, Å")
        for G3Function in NNParms.G3Functions
            println("       ", G3Function)
        end
    end
    if NNParms.G9Functions != []
        println("    G9 symmetry functions:")
        println("    eta, Å^-2; lambda; zeta; rcutoff, Å; rshift, Å")
        for G9Function in NNParms.G9Functions
            println("       ", G9Function)
        end
    end
    println("Maximum cutoff distance: $(NNParms.maxDistanceCutoff) Å")

Using the following symmetry functions as the neural input for each atom:
    G2 symmetry functions:
    eta, Å^-2; rcutoff, Å; rshift, Å
       G2(0.003, 6.0, 0.0)
       G2(0.03, 6.0, 0.0)
       G2(0.1, 6.0, 0.0)
       G2(0.2, 6.0, 0.0)
       G2(0.5, 6.0, 3.35)
       G2(1.0, 6.0, 3.35)
       G2(2.0, 6.0, 3.35)
       G2(5.0, 6.0, 3.35)
       G2(0.5, 6.0, 4.5)
       G2(1.0, 6.0, 4.5)
       G2(2.0, 6.0, 4.5)
       G2(5.0, 6.0, 4.5)
    G3 symmetry functions:
    eta, Å^-2; lambda; zeta; rcutoff, Å; rshift, Å
       G3(0.01, 1.0, 1.0, 6.0, 0.0)
       G3(0.01, -1.0, 1.0, 6.0, 0.0)
    G9 symmetry functions:
    eta, Å^-2; lambda; zeta; rcutoff, Å; rshift, Å
       G9(0.01, 1.0, 1.0, 6.0, 0.0)
       G9(0.01, -1.0, 1.0, 6.0, 0.0)
Maximum cutoff distance: 6.0 Å


In [3]:
systemParms = systemParmsList[1]
traj = readXTC(systemParms)
frame = read_step(traj, 1)
coordinates1 = deepcopy(positions(frame))
box = lengths(UnitCell(frame));

In [4]:
distanceMatrix = buildDistanceMatrix(frame);

In [5]:
G2Matrix = buildG2Matrix(distanceMatrix, NNParms.G2Functions);

In [6]:
#@btime buildG2Matrix($distanceMatrix, $NNParms.G2Functions);

In [7]:
G3Matrix = buildG3Matrix(distanceMatrix, coordinates1, box, NNParms.G3Functions)

512×2 Matrix{Float64}:
 0.0420455  0.0135732
 0.0565059  0.0433692
 0.15188    0.0444613
 0.161745   0.0489703
 0.075882   0.0259094
 0.0284074  0.0105423
 0.0778856  0.0189655
 0.0671209  0.0237215
 0.0750116  0.035236
 0.156217   0.0461813
 0.239378   0.0845712
 0.121736   0.046634
 0.143385   0.0631665
 ⋮          
 0.0770564  0.0289734
 0.110758   0.0398616
 0.0421053  0.0167195
 0.0471012  0.011058
 0.101296   0.036071
 0.145133   0.0528852
 0.0635515  0.0178394
 0.1905     0.0653554
 0.0974846  0.0318726
 0.187462   0.05528
 0.121085   0.032481
 0.108074   0.03024

In [8]:
#@btime buildG3Matrix($distanceMatrix, $coordinates1, $box, $NNParms.G3Functions);

In [9]:
pointIndex = 1;
distanceVector1 = distanceMatrix[:, pointIndex]
positions(frame)[:, pointIndex]

3-element Vector{Float64}:
 18.08000087738037
  4.760000109672546
 27.280001640319824

In [10]:
positions(frame)[:, pointIndex] .+= [1.0, 3.0, 1.0]

3-element view(::Chemfiles.ChemfilesArray, :, 1) with eltype Float64:
 19.08000087738037
  7.760000109672546
 28.280001640319824

In [11]:
point = positions(frame)[:, pointIndex];
distanceVector2 = computeDistanceVector(point, positions(frame), box);
distanceMatrix[pointIndex, :] = distanceVector2;
distanceMatrix[:, pointIndex] = distanceVector2;

In [12]:
coordinates2 = positions(frame);

In [13]:
coordinates2 == coordinates1

false

In [14]:
coordinates2[:, 2:end] == coordinates1[:, 2:end]

true

In [15]:
coordinates2[:, pointIndex] .- [1.0, 3.0, 1.0] == coordinates1[:, pointIndex]

true

In [16]:
distanceVector1 != distanceVector2

true

In [17]:
G3MatrixUpdated = buildG3Matrix(distanceMatrix, coordinates2, box, NNParms.G3Functions);

In [18]:
sum(G3MatrixUpdated .- G3Matrix)

1.2972781202556045

In [19]:
sum(G3MatrixUpdated[pointIndex, :] .- G3Matrix[pointIndex, :])

0.4324260400852014

In [20]:
G3MatrixOriginal = copy(G3Matrix);

In [22]:
G3Matrix == G3MatrixOriginal

true

In [97]:
include("src/symmfunctions.jl");

In [95]:
G3Matrix = deepcopy(G3MatrixOriginal);

In [23]:
updateG3Matrix!(G3Matrix, coordinates1, coordinates2, box, distanceVector1, distanceVector2, systemParms, NNParms.G3Functions, pointIndex);

In [24]:
sum(G3Matrix .- G3MatrixOriginal)

1.2972781202556045

In [25]:
sum(G3Matrix .- G3MatrixUpdated)

-1.6219664500383146e-16

In [27]:
updateG3Matrix!(G3Matrix, coordinates2, coordinates1, box, distanceVector2, distanceVector1, systemParms, NNParms.G3Functions, pointIndex);

In [28]:
sum(G3Matrix .- G3MatrixOriginal)

-3.469446951953614e-17

In [29]:
sum(G3Matrix .- G3MatrixUpdated)

-1.2972781202556045

In [30]:
@btime buildG3Matrix($distanceMatrix, $coordinates2, $box, $NNParms.G3Functions);

  23.213 ms (473433 allocations: 38.15 MiB)


In [32]:
@btime updateG3Matrix!($G3Matrix, $coordinates2, $coordinates1, $box, $distanceVector2, $distanceVector1, $systemParms, $NNParms.G3Functions, $pointIndex);

  1.944 ms (69692 allocations: 5.32 MiB)


It is a success! The updateG3Matrix! function works correctly and currently is roughly 20x faster than building the matrix from scratch. I know I can further optimize it a little bit. 2ms is too much of course.

Possible solutions: 
- optimize the update function by removing unnecessary computation/memory access
- reduce cutoff for the angular function
- use G9 function that is slightly faster than G3