In [1]:
import sys
import os

# Make the parent of the notebook's working directory importable; the
# `simulation` and `analyzers` packages imported further down presumably live there.
# The membership check keeps sys.path free of duplicates when this cell is re-run.
PARENT_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)

# Uncomment to pin JAX to the CPU backend (kept ahead of the `import jax` cell
# so the setting is in place before JAX is imported).
# os.environ['JAX_PLATFORM_NAME'] = "cpu"
# os.environ['JAX_PLATFORMS'] = "cpu"

In [2]:
import jax.numpy as jnp
import jax
from jax import lax

import numpy as onp

import matplotlib.pyplot as plt
import random

from simulation.simulate_full import run_entire_simulation

from analyzers import defaultvalues as dv, database, loss as loss_anaylzer, gradients as grad_analyzer

# Archive the `database` helper is pointed at from here on; the first scan loop
# below runs while this filename is active (presumably its results land here).
RESULTS_FILE = "../data/grad_analyzer/nsimulations_scan.npz"
database.set_filename(RESULTS_FILE)

2024-10-11 22:14:14.559220: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
  from IPython.utils import traitlets as _traitlets


In [3]:
# Number of gradient estimates collected for every batch size.
N_GRADIENTS = 4

# Batch sizes to scan: 1, 2, ..., 9 simulations per gradient estimate.
nsimulations_values = onp.arange(1, 10)
print(nsimulations_values)

[1 2 3 4 5 6 7 8 9]


In [4]:
def print_info(grads):
    """Print summary statistics for a collection of gradient samples.

    Calls ``grad_analyzer.analyze_gradients_absolute`` and
    ``grad_analyzer.analyze_gradients_magnitudal`` on ``grads``; each is expected
    to return a ``(mean, std)`` pair, which is printed on its own line.

    Args:
        grads: gradient samples, passed unchanged to both analyzers.
    """
    absolute_stats = grad_analyzer.analyze_gradients_absolute(grads)
    amean, astd = absolute_stats
    magnitudal_stats = grad_analyzer.analyze_gradients_magnitudal(grads)
    mmean, mstd = magnitudal_stats

    print(f"Absolute mean: {amean}, Absolute std: {astd}")
    print(f"Magnitudal mean: {mmean}, Magnitudal std: {mstd}")

In [None]:
# Scan over batch sizes with jax.vmap: for every entry of nsimulations_values,
# collect N_GRADIENTS gradient estimates of the batch-averaged loss with respect
# to LJ_SIGMA_OO and store them under that batch size.

LJ_SIGMA_OO_EVAL = 3.1   # value of LJ_SIGMA_OO at which the gradient is evaluated
KEY_UPPER_BOUND = 20000  # simulation keys are drawn from [0, KEY_UPPER_BOUND)
KEY_SEED = 0             # base seed so the drawn keys are reproducible


def simulation_wrapper(LJ_SIGMA_OO: float, key: int) -> float:
    """Return the L1 loss between a trial simulation and the reference simulation.

    Both simulations use the same ``key`` and the same default settings from
    ``dv``; they differ only in their first argument: ``LJ_SIGMA_OO`` for the
    prediction and ``dv.LJ_SIGMA_OO`` for the reference.
    """
    prediction = run_entire_simulation(
        LJ_SIGMA_OO,
        dv.N_STEPS,
        dv.N_MOLECULES_PER_AXIS,
        dv.N_SNAPSHOTS,
        dv.N_Q,
        key)
    reference = run_entire_simulation(
        dv.LJ_SIGMA_OO,
        dv.N_STEPS,
        dv.N_MOLECULES_PER_AXIS,
        dv.N_SNAPSHOTS,
        dv.N_Q,
        key)
    return loss_anaylzer.L1_loss(prediction, reference)


# Map over the keys only; LJ_SIGMA_OO is shared by every simulation of the batch.
v_simulation_wrapper = jax.vmap(simulation_wrapper, in_axes=(None, 0))


def multiple_simulation_wrapper(LJ_SIGMA_OO: float, keys) -> float:
    """Return the mean of ``simulation_wrapper`` over all ``keys`` (vmapped)."""
    losses = v_simulation_wrapper(LJ_SIGMA_OO, keys)
    return jnp.mean(losses)


# Differentiate with respect to LJ_SIGMA_OO (argument 0) only. The wrappers do
# not depend on the loop variables, so they are built once instead of per sample.
grad_fn = jax.grad(multiple_simulation_wrapper, 0)

for nsims in nsimulations_values:
    # Query the stored keys on every pass so finished batch sizes are skipped.
    existing_keys = database.get_existing_keys()
    if nsims in existing_keys:
        print(f"Skipping {nsims}, was already computed")
        continue

    # One seeded generator per batch size: the drawn keys are reproducible and
    # do not depend on which other batch sizes were skipped.
    key_rng = random.Random(KEY_SEED + int(nsims))

    grad_samples = []
    print(f"==== for {nsims} simulations ====")
    for _ in range(N_GRADIENTS):
        keys = jnp.array([key_rng.randrange(0, KEY_UPPER_BOUND) for _ in range(nsims)])
        grad = grad_fn(LJ_SIGMA_OO_EVAL, keys)
        print(grad)
        grad_samples.append(grad)

    # Host-side NumPy array, matching how the scan-based cell stores its results.
    grads = onp.array(grad_samples)
    print_info(grads)
    database.save_intermediate_result(nsims, grads)

Skipping 1, was already computed
Skipping 2, was already computed
Skipping 3, was already computed
Skipping 4, was already computed
Skipping 5, was already computed
Skipping 6, was already computed
Skipping 7, was already computed
==== for 8 simulations ====


In [None]:
# Same batch-size scan as the vmap-based cell, but the simulations of a batch are
# evaluated through lax.scan, and the results go to a separate archive.
# NOTE(review): the recorded output for 2 simulations shows a gradient of ~1.2e49;
# confirm whether this is a numerical blow-up before relying on these values.
database.set_filename("../data/grad_analyzer/nsimulations_scan2.npz")

LJ_SIGMA_OO_EVAL = 3.1   # value of LJ_SIGMA_OO at which the gradient is evaluated
KEY_UPPER_BOUND = 20000  # simulation keys are drawn from [0, KEY_UPPER_BOUND)
KEY_SEED = 0             # base seed so the drawn keys are reproducible


def simulation_wrapper(LJ_SIGMA_OO: float, key: int) -> float:
    """Return the L1 loss between a trial simulation and the reference simulation.

    Both simulations use the same ``key`` and the same default settings from
    ``dv``; they differ only in their first argument: ``LJ_SIGMA_OO`` for the
    prediction and ``dv.LJ_SIGMA_OO`` for the reference.
    """
    prediction = run_entire_simulation(
        LJ_SIGMA_OO,
        dv.N_STEPS,
        dv.N_MOLECULES_PER_AXIS,
        dv.N_SNAPSHOTS,
        dv.N_Q,
        key)
    reference = run_entire_simulation(
        dv.LJ_SIGMA_OO,
        dv.N_STEPS,
        dv.N_MOLECULES_PER_AXIS,
        dv.N_SNAPSHOTS,
        dv.N_Q,
        key)
    return loss_anaylzer.L1_loss(prediction, reference)


def multiple_simulation_wrapper(LJ_SIGMA_OO: float, keys) -> float:
    """Return the mean of ``simulation_wrapper`` over ``keys``, evaluated via lax.scan.

    ``keys`` is passed in explicitly (instead of being captured from the
    enclosing loop) so the function can be defined once and has no hidden state.
    """
    def scan_step(carry, key):
        # The carry is an unused placeholder; the per-key loss is the scan output.
        return carry, simulation_wrapper(LJ_SIGMA_OO, key)

    _, losses = lax.scan(scan_step, 0, keys)
    return jnp.mean(losses)


# jax.grad differentiates with respect to the first argument (LJ_SIGMA_OO) only.
grad_fn = jax.grad(multiple_simulation_wrapper)

for nsims in nsimulations_values:
    # Query the stored keys on every pass so finished batch sizes are skipped.
    existing_keys = database.get_existing_keys()
    if nsims in existing_keys:
        print(f"Skipping {nsims}, was already computed")
        continue

    # One seeded generator per batch size: the drawn keys are reproducible and
    # do not depend on which other batch sizes were skipped.
    key_rng = random.Random(KEY_SEED + int(nsims))

    grad_samples = []
    print(f"==== for {nsims} simulations ====")
    for _ in range(N_GRADIENTS):
        keys = jnp.array([key_rng.randrange(0, KEY_UPPER_BOUND) for _ in range(nsims)])
        grad = grad_fn(LJ_SIGMA_OO_EVAL, keys)
        print(grad)
        grad_samples.append(grad)

    grads = onp.array(grad_samples)
    print_info(grads)
    database.save_intermediate_result(nsims, grads)

Skipping 1, was already computed
==== for 2 simulations ====
1.2006835602501748e+49


In [None]:
# Summarise every stored entry of the archive currently selected in `database`
# (the most recent set_filename call above).
# The entries were saved under the batch size (number of simulations), so the
# header reports that number directly instead of converting it to a time in ps.
keys, values = database.load_result()
for nsims, grads in zip(keys, values):
    print(f"==== for {nsims} simulations ====")
    print_info(grads)