In [None]:
from typing import Optional

import torch
from loguru import logger as log

from modelforge.dataset.dataset import single_batch
# NOTE(review): module name "helper_functinos" looks misspelled — confirm it matches the file in modelforge.tests
from modelforge.tests.helper_functinos import setup_potential_for_test


In [None]:
# Setup input: one batch of data plus the potential under test.
potential_name = "SchNet"

# NOTE(review): dataset name 'QM0' — confirm this is intended (possibly 'QM9');
# 64 is presumably the batch size.
nnp_input = single_batch(64, "QM0").nnp_input
# Number of distinct molecules in the batch, counted via unique subsystem indices.
nr_of_mols = len(torch.unique(nnp_input.atomic_subsystem_indices))

# Build the potential through the test helper with a fixed potential_seed so
# repeated profiling runs start from the same parameters.
model = setup_potential_for_test(
    potential_name,
    "inference",
    potential_seed=42,
    use_training_mode_neighborlist=True,
    simulation_environment="PyTorch",
)


In [2]:
import socket
from datetime import datetime, timedelta
# strftime pattern used to timestamp snapshot files: month_day_hour_minute_second.
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"
# Upper bound on the number of alloc/free events kept in the recorded
# history leading up to a snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100_000
def start_record_memory_history() -> None:
    """Begin recording CUDA allocator alloc/free events for a later snapshot.

    At most ``MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT`` events are kept. When CUDA
    is unavailable this only logs a message and does nothing else.
    """
    if torch.cuda.is_available():
        log.info("Starting snapshot record_memory_history")
        torch.cuda.memory._record_memory_history(
            max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
        )
    else:
        log.info("CUDA unavailable. Not recording memory history")

def stop_record_memory_history() -> None:
    """Switch off CUDA allocator history recording.

    Passing ``enabled=None`` to ``torch.cuda.memory._record_memory_history``
    disables recording. When CUDA is unavailable this only logs a message.
    """
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        log.info("Stopping snapshot record_memory_history")
        torch.cuda.memory._record_memory_history(enabled=None)
        return
    log.info("CUDA unavailable. Not recording memory history")

def export_memory_snapshot() -> Optional[str]:
    """Dump the recorded CUDA memory history to ``<hostname>_<timestamp>.pickle``.

    The file is written relative to the current working directory. It can be
    inspected with PyTorch's memory visualizer (https://pytorch.org/memory_viz).

    Returns
    -------
    Optional[str]
        Path of the written snapshot file, or ``None`` if CUDA is unavailable
        or the dump failed (the failure is logged together with its traceback).
    """
    if not torch.cuda.is_available():
        log.info("CUDA unavailable. Not exporting memory snapshot")
        return None

    # Host + timestamp prefix keeps snapshots from different machines/runs apart.
    # The timestamp has one-second resolution, so two dumps within the same
    # second would write to the same file.
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_path = f"{host_name}_{timestamp}.pickle"

    log.info(f"Saving snapshot to local file: {file_path}")
    try:
        torch.cuda.memory._dump_snapshot(file_path)
    except Exception as e:
        # Best-effort: a profiling failure should not abort the notebook, but
        # log.exception keeps the traceback so the failure can be diagnosed.
        log.exception(f"Failed to capture memory snapshot: {e}")
        return None
    return file_path


In [None]:
# Record CUDA allocator history while running a few forward + backward passes,
# then dump it to a snapshot file.
# NOTE(review): only CUDA allocations are recorded, and nothing visible here
# moves `model` / `nnp_input` to a GPU — confirm device placement, otherwise
# the snapshot will not reflect the model's memory use.
N_PROFILING_ITERATIONS = 5

# torch.autograd.grad raises if the differentiated input does not require
# grad; requires_grad_ is a no-op when it is already enabled.
nnp_input.positions.requires_grad_(True)

# Start recording memory snapshot history
start_record_memory_history()
try:
    for _ in range(N_PROFILING_ITERATIONS):
        # Forward pass: per-molecule energies predicted by the model.
        r = model(nnp_input)["per_molecule_energy"]
        # Backward pass: gradient of the energies w.r.t. positions (used for
        # forces). retain_graph=False frees the graph after each iteration.
        grad = torch.autograd.grad(
            r,
            nnp_input.positions,
            grad_outputs=torch.ones_like(r),
            create_graph=False,
            retain_graph=False,
        )[0]
finally:
    # Always dump and stop, so an error mid-loop (e.g. CUDA OOM) still yields
    # a snapshot and does not leave history recording switched on.
    export_memory_snapshot()
    stop_record_memory_history()
