In [None]:
import sys

import torch

from modelforge.tests.helper_functions import setup_potential_for_test
from modelforge.utils.profiling import (
    start_record_memory_history,
    export_memory_snapshot,
    stop_record_memory_history,
    setup_waterbox_testsystem,
)


In [None]:
# ---------------------------------------------------------- #
# Demonstration: recording the memory usage of a modelforge  #
# potential while it runs.                                   #
# ---------------------------------------------------------- #
# Profiling configuration: which potential, numeric precision, target device.
potential_name, precision, device = "AimNet2", torch.float32, "cuda"

# Test input: a water box created directly on `device` with `precision`.
# NOTE(review): the meaning of the leading 2.5 (presumably a box size) is
# defined by setup_waterbox_testsystem — confirm there before changing it.
nnp_input = setup_waterbox_testsystem(2.5, device=device, precision=precision)

# Training setup for the chosen potential; only its "trainer" entry is kept,
# and the cells below refer to it as `model`.
model = setup_potential_for_test(
    potential_name,
    "training",
    simulation_environment="PyTorch",
    use_training_mode_neighborlist=True,
    potential_seed=42,
)["trainer"]

# Alternative workload: repeated forward passes plus position gradients.
# It has its own name because the training workload below is the one bound to
# `loop_to_record`; sharing that name meant this version was silently
# shadowed and never ran. Call it explicitly to profile inference + forces.
def forward_and_forces_loop_to_record(n_iterations: int = 5) -> None:
    """Run ``n_iterations`` forward passes and gradient evaluations.

    Each iteration predicts ``per_system_energy`` for the global ``nnp_input``
    using the global ``model`` and differentiates it with respect to
    ``nnp_input.positions``. The gradient itself is discarded; only the work
    (and memory traffic) of computing it is of interest when profiling.

    NOTE(review): ``model`` is the ``"trainer"`` entry returned by
    ``setup_potential_for_test``; confirm it is callable on an NNP input and
    that ``nnp_input.positions`` requires grad before profiling this loop.

    Parameters
    ----------
    n_iterations : int
        Number of forward/gradient repetitions (default 5, as originally).
    """
    for _ in range(n_iterations):
        # Forward pass through the model.
        energies = model(nnp_input)["per_system_energy"]
        # Gradient of the energies w.r.t. positions (forces, up to sign).
        # The result is intentionally not stored; the call is what is profiled.
        torch.autograd.grad(
            energies,
            nnp_input.positions,
            grad_outputs=torch.ones_like(energies),
            create_graph=False,
            retain_graph=False,
        )

def loop_to_record():
    """Workload executed while memory history is being recorded.

    Runs ``train_potential()`` on the global ``model`` (the ``"trainer"``
    entry from the setup cell) and returns ``None``.
    """
    trainer = model
    trainer.train_potential()

In [None]:
# Record memory history while the profiled workload runs, then write it to a
# snapshot file. Export and stop live in `finally` blocks so that the history
# is still dumped when the workload fails (e.g. runs out of memory, which is
# exactly when a snapshot is most useful) and recording is always switched
# off, even if `loop_to_record` or the export itself raises.
start_record_memory_history()
try:
    loop_to_record()
finally:
    try:
        # Create the memory snapshot file
        file_name = export_memory_snapshot()
    finally:
        # Stop recording memory snapshot history
        stop_record_memory_history()
print(file_name)


In [None]:
# the memory snapshot is a pickle file, to visualize this 
# create a snapshot.html file using the following command
!python _memory_viz.py trace_plot a7srv5.pch.univie.ac.at_Nov_09_21_18_29.pickle -o snapshot.html