In [1]:
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

# os.environ['JAX_PLATFORM_NAME'] = "cpu"
# os.environ['JAX_PLATFORMS'] = "cpu"

In [2]:
import jax.numpy as jnp
import jax

import numpy as onp

import matplotlib.pyplot as plt
import random

from simulation.simulate_full import run_entire_simulation

from analyzers import defaultvalues as dv, database, loss as loss_anaylzer, gradients as grad_analyzer

database.set_filename("../data/grad_analyzer/nq_scan.npz")

  from IPython.utils import traitlets as _traitlets


In [3]:
N_GRADIENTS = 4
nq_values = onp.array([50, 100, 150, 200, 250])
print(nq_values)

[ 50 100 150 200 250]


In [4]:
def print_info(grads):
    amean, astd = grad_analyzer.analyze_gradients_absolute(grads)
    mmean, mstd = grad_analyzer.analyze_gradients_magnitudal(grads)

    print(f"Absolute mean: {amean}, Absolute std: {astd}")
    print(f"Magnitudal mean: {mmean}, Magnitudal std: {mstd}")

In [5]:
for nq in nq_values:
    existing_keys = database.get_existing_keys()
    if nq in existing_keys:
        print(f"Skipping {nq}, was already computed")
        continue

    grads = []
    print(f"==== for {nq} Q-values ====")
    for i in range(1, N_GRADIENTS + 1):
        key = random.randrange(0, 20000)
        def simulation_wrapper(LJ_SIGMA_OO: float) -> float:
            prediction = run_entire_simulation(
                LJ_SIGMA_OO, 
                dv.N_STEPS, 
                dv.N_MOLECULES_PER_AXIS, 
                dv.N_SNAPSHOTS, 
                nq, 
                key)
            reference = run_entire_simulation(
                dv.LJ_SIGMA_OO, 
                dv.N_STEPS, 
                dv.N_MOLECULES_PER_AXIS, 
                dv.N_SNAPSHOTS, 
                nq, 
                key)
            return loss_anaylzer.L1_loss(prediction, reference)
        grad_fn = jax.grad(simulation_wrapper)
        grad = grad_fn(3.1)
        print(grad)
        grads.append(grad)
    
    grads = onp.array(grads)
    print_info(grads)
    database.save_intermediate_result(nq, grads)

Skipping 50, was already computed
==== for 100 Q-values ====
1.2226200114762503e+50
-7.13525851591908e+48
7.904072015923501e+49
2.3994094390734792e+51
Absolute mean: 6.519618547240646e+50, Absolute std: 1.0097270621974956e+51
Magnitudal mean: 50.054664116007885, Magnitudal std: 0.8979781351219543
==== for 150 Q-values ====
-3.2694871359513996e+46
-1.5104427011508784e+48
1.0686528830868277e+49
1.635580471535717e+49
Absolute mean: 7.14636777968396e+48, Absolute std: 6.702881146901153e+48
Magnitudal mean: 48.23402311653473, Magnitudal std: 1.0666871866167937
==== for 200 Q-values ====
1.8320048865880734e+47
7.967663246639285e+46
3.485870125036165e+51
-1.4512791255254742e+48
Absolute mean: 8.71896070320704e+50, Absolute std: 1.5091787207765348e+51
Magnitudal mean: 48.467079937262575, Magnitudal std: 1.833836698370992
==== for 250 Q-values ====
-1.2324535541075086e+51
5.413137009916332e+52
-1.7246436812227224e+49
-4.761389088832081e+47
Absolute mean: 1.3845386557247985e+52, Absolute std: 2.

In [6]:
keys, values = database.load_result()
for key, grads in zip(keys, values):
    print(f"==== for {nq} Q-values ====")
    print_info(grads)

==== for 250 Q-values ====
Absolute mean: 7.354254607638597e+49, Absolute std: 9.273142762666753e+49
Magnitudal mean: 48.50409943979843, Magnitudal std: 1.6404806816391635
==== for 250 Q-values ====
Absolute mean: 6.519618547240646e+50, Absolute std: 1.0097270621974956e+51
Magnitudal mean: 50.054664116007885, Magnitudal std: 0.8979781351219543
==== for 250 Q-values ====
Absolute mean: 7.14636777968396e+48, Absolute std: 6.702881146901153e+48
Magnitudal mean: 48.23402311653473, Magnitudal std: 1.0666871866167937
==== for 250 Q-values ====
Absolute mean: 8.71896070320704e+50, Absolute std: 1.5091787207765348e+51
Magnitudal mean: 48.467079937262575, Magnitudal std: 1.833836698370992
==== for 250 Q-values ====
Absolute mean: 1.3845386557247985e+52, Absolute std: 2.3264487704242612e+52
Magnitudal mean: 50.18466315875472, Magnitudal std: 1.9039873136271448
