In [1]:
import sys
import os

# Make the parent directory of the notebook's working directory importable
# (presumably the project root, where `simulation` and `analyzers` live).
# The membership check keeps this cell idempotent: re-running it does not
# append the same path to sys.path again.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Uncomment to restrict JAX to the CPU backend (must run before JAX is imported).
# os.environ['JAX_PLATFORM_NAME'] = "cpu"
# os.environ['JAX_PLATFORMS'] = "cpu"

In [2]:
import jax.numpy as jnp
import jax

import numpy as onp

import matplotlib.pyplot as plt
import random

from simulation.simulate_full import run_entire_simulation

from analyzers import defaultvalues as dv, database, loss as loss_anaylzer, gradients as grad_analyzer

# File backing the `database` module for this n-molecules scan
# (presumably where get_existing_keys / save_intermediate_result / load_result
# read and write). The path is relative to the notebook's working directory.
NMOLECULES_SCAN_DB = "../data/grad_analyzer/nmolecules_scan.npz"
database.set_filename(NMOLECULES_SCAN_DB)

  from IPython.utils import traitlets as _traitlets


In [3]:
# Number of gradient samples (one per randomly drawn simulation key) per system size.
N_GRADIENTS = 4
# Cube roots of the molecule counts to scan: each run simulates nmolecules**3 molecules.
nmolecules_values = onp.array([3, 4, 5, 6, 7, 8])
# Seed Python's `random` (used to draw the simulation keys) so that a
# Restart & Run All draws the same keys and reproduces newly computed gradients.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)
print(nmolecules_values)

[3 4 5 6 7 8]


In [4]:
def print_info(grads):
    """Print summary statistics of a collection of gradient samples.

    Both statistic pairs are delegated to ``grad_analyzer``: the "absolute"
    pair from ``analyze_gradients_absolute`` and the "magnitudal" pair from
    ``analyze_gradients_magnitudal`` (presumably order-of-magnitude statistics;
    confirm against the analyzer).

    Args:
        grads: the gradient samples for one system size, e.g. the array built
            by the scan loop or an entry returned by ``database.load_result()``.
    """
    abs_mean, abs_std = grad_analyzer.analyze_gradients_absolute(grads)
    mag_mean, mag_std = grad_analyzer.analyze_gradients_magnitudal(grads)

    report_rows = (
        ("Absolute", abs_mean, abs_std),
        ("Magnitudal", mag_mean, mag_std),
    )
    for label, mean, std in report_rows:
        print(f"{label} mean: {mean}, {label} std: {std}")

In [5]:
# Value of LJ_SIGMA_OO at which every gradient is evaluated.
SIGMA_EVAL_POINT = 3.1
# Simulation keys are drawn from range(KEY_UPPER_BOUND).
KEY_UPPER_BOUND = 20000

for nmolecules in nmolecules_values:
    # Sizes already stored in the database are skipped, so the scan can be resumed.
    existing_keys = database.get_existing_keys()
    if nmolecules in existing_keys:
        print(f"Skipping {nmolecules}, was already computed")
        continue

    grads = []
    print(f"==== for {nmolecules}**3={nmolecules**3} molecules ====")
    # Draw without replacement so no two gradient samples share a simulation key.
    sim_keys = random.sample(range(KEY_UPPER_BOUND), N_GRADIENTS)
    for key in sim_keys:
        # The reference run does not depend on the parameter being
        # differentiated, so compute it once per key outside of jax.grad
        # instead of inside the differentiated closure.
        reference = run_entire_simulation(
            dv.LJ_SIGMA_OO,
            dv.N_STEPS,
            nmolecules,
            dv.N_SNAPSHOTS,
            dv.N_Q,
            key)

        def simulation_wrapper(LJ_SIGMA_OO: float) -> float:
            """L1 loss between a run at LJ_SIGMA_OO and the reference run for the same key."""
            prediction = run_entire_simulation(
                LJ_SIGMA_OO,
                dv.N_STEPS,
                nmolecules,
                dv.N_SNAPSHOTS,
                dv.N_Q,
                key)
            return loss_anaylzer.L1_loss(prediction, reference)

        # The closure is called immediately, so capturing the loop variables
        # `nmolecules`, `key` and `reference` by reference is safe here.
        grad = jax.grad(simulation_wrapper)(SIGMA_EVAL_POINT)
        print(grad)
        grads.append(grad)

    grads = onp.array(grads)
    print_info(grads)
    database.save_intermediate_result(nmolecules, grads)

Skipping 3, was already computed
Skipping 4, was already computed
Skipping 5, was already computed
==== for 6**3=216 molecules ====
-2.5359074460296814e+50
-2.7033722160245724e+50
-1.5886815063374054e+47
-1.52802545920439e+46
Absolute mean: 1.3102552865266278e+50, Absolute std: 1.3107225866841572e+50
Magnitudal mean: 48.555301668799096, Magnitudal std: 1.8971233166391877
==== for 7**3=343 molecules ====
1.967320961130483e+53
1.4035513149550273e+53
5.372182364263319e+49
-2.531719744258954e+53
Absolute mean: 1.4757823096452228e+53, Absolute std: 9.405024598641754e+52
Magnitudal mean: 52.39366747251086, Magnitudal std: 1.5404658863297365
==== for 8**3=512 molecules ====
3.0627956928982643e+53
3.6409251744937606e+53
-6.2788371987991e+51
7.553082923821574e+51
Absolute mean: 1.7105100171545578e+53, Absolute std: 1.6540346726547596e+53
Magnitudal mean: 52.68083331418978, Magnitudal std: 0.8437267673576795


In [8]:
# Summarise every result stored in the database, including system sizes that
# the scan above skipped because they had been computed in an earlier run.
keys, values = database.load_result()
for key, grads in zip(keys, values):
    n_root = int(key)
    n_total = int(key**3)
    print(f"==== for {n_root}**3={n_total} molecules ====")
    print_info(grads)

==== for 3**3=27 molecules ====
Absolute mean: 2.808480255894132e+48, Absolute std: 3.713909149339848e+48
Magnitudal mean: 45.70402646442572, Magnitudal std: 3.020062853441939
==== for 4**3=64 molecules ====
Absolute mean: 3.4976435941441275e+48, Absolute std: 5.996177702802229e+48
Magnitudal mean: 45.605430405466464, Magnitudal std: 2.8138802757444017
==== for 5**3=125 molecules ====
Absolute mean: 3.827117042634167e+47, Absolute std: 3.637767069908443e+47
Magnitudal mean: 47.15894721881544, Magnitudal std: 0.7511470226584356
==== for 6**3=216 molecules ====
Absolute mean: 1.3102552865266278e+50, Absolute std: 1.3107225866841572e+50
Magnitudal mean: 48.555301668799096, Magnitudal std: 1.8971233166391877
==== for 7**3=343 molecules ====
Absolute mean: 1.4757823096452228e+53, Absolute std: 9.405024598641754e+52
Magnitudal mean: 52.39366747251086, Magnitudal std: 1.5404658863297365
==== for 8**3=512 molecules ====
Absolute mean: 1.7105100171545578e+53, Absolute std: 1.6540346726547596e+5