In [1]:
import sys
import os

# Make the parent of the current working directory importable; presumably this
# is the project root holding the local `simulation` and `analyzers` packages.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:  # guard keeps re-running this cell idempotent
    sys.path.append(PROJECT_ROOT)

# Optional: uncomment to force JAX onto the CPU backend. These only take effect
# if set before JAX initializes its backend, i.e. before the import cell runs.
# os.environ['JAX_PLATFORM_NAME'] = "cpu"
# os.environ['JAX_PLATFORMS'] = "cpu"

In [2]:
import jax.numpy as jnp
import jax

import numpy as onp

import matplotlib.pyplot as plt
import random

from simulation.simulate_full import run_entire_simulation

from analyzers import defaultvalues as dv, database, loss as loss_anaylzer, gradients as grad_analyzer

# Cache file for this runtime scan (relative to the notebook's working directory).
RESULTS_PATH = "../data/grad_analyzer/runtime_scan.npz"
database.set_filename(RESULTS_PATH)

  from IPython.utils import traitlets as _traitlets


In [3]:
# Scan configuration: how many gradient samples per runtime, and which runtimes.
N_GRADIENTS = 4  # gradient samples (each with its own simulation key) per runtime
FS_PER_PS = 1000
TIMESTEP_FS = 2  # MD timestep implied by the original `* 1000 // 2` conversion — TODO confirm against the simulation settings

runtime_values = onp.array([12, 20, 28, 36, 44])  # simulated runtimes, in ps
nsteps_intervals = runtime_values * FS_PER_PS // TIMESTEP_FS  # integration steps per runtime
print(nsteps_intervals)

[ 6000 10000 14000 18000 22000]


In [4]:
def print_info(grads):
    """Print summary statistics for a batch of gradient samples.

    Reports the (mean, std) pair returned by
    ``grad_analyzer.analyze_gradients_absolute`` and by
    ``grad_analyzer.analyze_gradients_magnitudal`` for ``grads``.
    """
    # NOTE(review): at 44 ps the absolute std printed as `inf` alongside a numpy
    # overflow warning; presumably the analyzer squares values near 1e157. Confirm.
    abs_mean, abs_std = grad_analyzer.analyze_gradients_absolute(grads)
    mag_mean, mag_std = grad_analyzer.analyze_gradients_magnitudal(grads)

    rows = (
        ("Absolute", abs_mean, abs_std),
        ("Magnitudal", mag_mean, mag_std),
    )
    for label, mean, std in rows:
        print(f"{label} mean: {mean}, {label} std: {std}")

In [5]:
# Gradient scan settings.
SIGMA_EVAL = 3.1   # LJ_SIGMA_OO value at which d(loss)/d(sigma) is evaluated
KEY_RANGE = 20000  # simulation keys are drawn from range(KEY_RANGE)
KEY_SEED = 0       # keys are not stored in the database, so seed them to keep runs reproducible


def gradient_sample(nsteps, key, sigma=SIGMA_EVAL):
    """Return d(L1 loss)/d(LJ_SIGMA_OO) evaluated at ``sigma`` for one simulation key.

    The loss compares a simulation run with the trial sigma against a reference
    run with ``dv.LJ_SIGMA_OO``; both runs use the same ``nsteps`` and ``key``.
    """
    # The reference does not depend on the differentiated argument, so compute
    # it once, outside the function handed to jax.grad.
    reference = run_entire_simulation(
        dv.LJ_SIGMA_OO,
        nsteps,
        dv.N_MOLECULES_PER_AXIS,
        dv.N_SNAPSHOTS,
        dv.N_Q,
        key)

    def simulation_wrapper(LJ_SIGMA_OO: float) -> float:
        prediction = run_entire_simulation(
            LJ_SIGMA_OO,
            nsteps,
            dv.N_MOLECULES_PER_AXIS,
            dv.N_SNAPSHOTS,
            dv.N_Q,
            key)
        return loss_anaylzer.L1_loss(prediction, reference)

    return jax.grad(simulation_wrapper)(sigma)


for nsteps in nsteps_intervals:
    existing_keys = database.get_existing_keys()
    if nsteps in existing_keys:
        print(f"Skipping {nsteps}, was already computed")
        continue

    print(f"==== for {nsteps//500}ps ====")
    # Seed per runtime so a given nsteps always gets the same keys regardless of
    # which other runtimes are already cached; sample() guarantees distinct keys
    # (duplicate keys would silently repeat a gradient sample).
    rng = random.Random(KEY_SEED + int(nsteps))
    sim_keys = rng.sample(range(KEY_RANGE), N_GRADIENTS)

    grads = []
    for sim_key in sim_keys:
        grad = gradient_sample(nsteps, sim_key)
        print(grad)
        grads.append(grad)

    grads = onp.array(grads)
    print_info(grads)
    database.save_intermediate_result(nsteps, grads)

Skipping 6000, was already computed
Skipping 10000, was already computed
Skipping 14000, was already computed
==== for 36ps ====
-2.1608042764066084e+127
-1.7476490290659488e+127
-2.914285968125017e+124
8.09813867019835e+122
Absolute mean: 9.778621432068461e+126, Absolute std: 9.872313894022748e+126
Magnitudal mean: 125.48749675629713, Magnitudal std: 1.8834805562821817
==== for 44ps ====
4.444220911965936e+157
8.815746956963995e+154
7.532722785399551e+152
7.386824164600847e+152
Absolute mean: 1.1132964635980998e+157, Absolute std: inf
Magnitudal mean: 154.58461612520153, Magnitudal std: 1.9605132264016372


  x = um.multiply(x, x, out=x)


In [6]:
# Re-summarize every runtime stored in the results database.
keys, values = database.load_result()
for saved_nsteps, saved_grads in zip(keys, values):
    # Stored keys come back as floats (printed as e.g. "12.0ps"); cast so the
    # label matches the compute cell's integer "36ps" style.
    print(f"==== for {int(saved_nsteps)//500}ps ====")
    print_info(saved_grads)

==== for 12.0ps ====
Absolute mean: 9.554242244927689e+44, Absolute std: 1.2404851623432239e+45
Magnitudal mean: 44.6107374872459, Magnitudal std: 0.5559015695655483
==== for 20.0ps ====
Absolute mean: 7.368123095909171e+71, Absolute std: 1.0443981999086219e+72
Magnitudal mean: 70.95524023655834, Magnitudal std: 1.0888363453547292
==== for 28.0ps ====
Absolute mean: 9.550657164808937e+96, Absolute std: 1.6114563713485141e+97
Magnitudal mean: 94.79984378843099, Magnitudal std: 2.016615849115076
==== for 36.0ps ====
Absolute mean: 9.778621432068461e+126, Absolute std: 9.872313894022748e+126
Magnitudal mean: 125.48749675629713, Magnitudal std: 1.8834805562821817
==== for 44.0ps ====
Absolute mean: 1.1132964635980998e+157, Absolute std: inf
Magnitudal mean: 154.58461612520153, Magnitudal std: 1.9605132264016372
