In [1]:
import sys
import os

# Make the parent of the current working directory importable
# (presumably where the `simulation` and `analyzers` packages live — confirm).
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)

# Optional overrides to pin JAX to the CPU backend; left disabled.
# NOTE(review): presumably these must be set before jax is first imported — confirm.
# os.environ['JAX_PLATFORM_NAME'] = "cpu"
# os.environ['JAX_PLATFORMS'] = "cpu"

In [2]:
import jax.numpy as jnp
import jax

import numpy as onp

import matplotlib.pyplot as plt
import random

from simulation.simulate_full import run_entire_simulation

from analyzers import defaultvalues as dv, database, loss as loss_anaylzer, gradients as grad_analyzer

# Archive file handed to the `database` helper for this scan's results.
NSNAPS_SCAN_RESULTS_FILE = "../data/grad_analyzer/nsnaps_scan.npz"
database.set_filename(NSNAPS_SCAN_RESULTS_FILE)

2024-10-11 20:27:59.852341: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
  from IPython.utils import traitlets as _traitlets


In [3]:
# Number of gradient samples computed for each snapshot count.
N_GRADIENTS = 4
# Snapshot counts to scan: 50, 100, ..., 250 (the stop value 251 is exclusive).
nsnaps_values = onp.arange(50, 251, 50)
print(nsnaps_values)

[ 50 100 150 200 250]


In [4]:
def print_info(grads):
    """Print summary statistics for a collection of gradient samples.

    Passes ``grads`` unchanged to
    ``grad_analyzer.analyze_gradients_absolute`` and
    ``grad_analyzer.analyze_gradients_magnitudal`` and prints the
    (mean, std) pair that each of them returns, one line per analyzer.

    Args:
        grads: Gradient samples to summarise.
    """
    amean, astd = grad_analyzer.analyze_gradients_absolute(grads)
    mmean, mstd = grad_analyzer.analyze_gradients_magnitudal(grads)

    # Assemble the report first, then emit it line by line.
    report_lines = (
        f"Absolute mean: {amean}, Absolute std: {astd}",
        f"Magnitudal mean: {mmean}, Magnitudal std: {mstd}",
    )
    for report_line in report_lines:
        print(report_line)

In [5]:
for nsnaps in nsnaps_values:
    # Query the store on every pass so snapshot counts that already have a
    # saved result are not recomputed.
    if nsnaps in database.get_existing_keys():
        print(f"Skipping {nsnaps}, was already computed")
        continue

    print(f"==== for {nsnaps} snapshots ====")
    grads = []
    for sample_index in range(N_GRADIENTS):
        # Random integer in [0, 20000) handed to both runs below as `key`
        # (presumably an RNG seed — confirm against run_entire_simulation).
        key = random.randrange(20000)
        # Arguments shared by the prediction and the reference run.
        shared_args = (dv.N_STEPS, dv.N_MOLECULES_PER_AXIS, nsnaps, dv.N_Q, key)

        def simulation_wrapper(LJ_SIGMA_OO: float) -> float:
            """Return the L1 loss between a run at ``LJ_SIGMA_OO`` and one at ``dv.LJ_SIGMA_OO``."""
            prediction = run_entire_simulation(LJ_SIGMA_OO, *shared_args)
            reference = run_entire_simulation(dv.LJ_SIGMA_OO, *shared_args)
            return loss_anaylzer.L1_loss(prediction, reference)

        # Derivative of the loss with respect to LJ_SIGMA_OO, evaluated at 3.1.
        grad = jax.grad(simulation_wrapper)(3.1)
        print(grad)
        grads.append(grad)

    # Summarise this snapshot count and persist it before moving on.
    grads = onp.array(grads)
    print_info(grads)
    database.save_intermediate_result(nsnaps, grads)

==== for 50 snapshots ====
1.0682571381189187e+51
3.6116808405700227e+49
6.388790338950187e+49
3.5521318653325285e+49
Absolute mean: 3.009457921418615e+50, Absolute std: 4.431556402772743e+50
Magnitudal mean: 49.985573221139035, Magnitudal std: 0.6109182353322882
==== for 100 snapshots ====
7.977811366993957e+48
-1.568379530775108e+52
-1.7122023968295645e+48
3.4814759459220925e+49
Absolute mean: 3.932075020243531e+51, Absolute std: 6.784870264801117e+51
Magnitudal mean: 49.71816335712697, Magnitudal std: 1.5032004092682942
==== for 150 snapshots ====
1.9738864952038912e+46
6.444583748853121e+47
-7.272063785885953e+48
-3.370201461158549e+45
Absolute mean: 1.9849078067961156e+48, Absolute std: 3.0634622275327867e+48
Magnitudal mean: 47.12345764741994, Magnitudal std: 1.2965306194621937
==== for 200 snapshots ====
3.2767264370563838e+47
1.453417854244644e+50
-4.872166936306448e+52
-1.2232716412626552e+46
Absolute mean: 1.2216837763462264e+52, Absolute std: 2.1076157679418553e+52
Magnituda

In [6]:
# Reload every stored result and print its summary statistics.
keys, values = database.load_result()
for key, grads in zip(keys, values):
    # Label each entry with its own stored key, not the `nsnaps` global left
    # over from the scan loop (which would give every entry the same label).
    print(f"==== for {key} snapshots ====")
    print_info(grads)

==== for 250 snapshots ====
Absolute mean: 3.009457921418615e+50, Absolute std: 4.431556402772743e+50
Magnitudal mean: 49.985573221139035, Magnitudal std: 0.6109182353322882
==== for 250 snapshots ====
Absolute mean: 3.932075020243531e+51, Absolute std: 6.784870264801117e+51
Magnitudal mean: 49.71816335712697, Magnitudal std: 1.5032004092682942
==== for 250 snapshots ====
Absolute mean: 1.9849078067961156e+48, Absolute std: 3.0634622275327867e+48
Magnitudal mean: 47.12345764741994, Magnitudal std: 1.2965306194621937
==== for 250 snapshots ====
Absolute mean: 1.2216837763462264e+52, Absolute std: 2.1076157679418553e+52
Magnitudal mean: 49.11326893598928, Magnitudal std: 2.529108915897649
==== for 250 snapshots ====
Absolute mean: 2.9012189752554035e+51, Absolute std: 4.429702625413683e+51
Magnitudal mean: 50.86118445581357, Magnitudal std: 0.7040144559290721
