In [1]:
import torch
import time
import scipy
from memory_profiler import memory_usage
import cProfile
import pstats

from torchfem import Solid
from torchfem.materials import Isotropic

In [2]:
def get_cube(N):
    """Build the benchmark problem: a unit cube meshed with hexahedra.

    The cube [0, 1]^3 is sampled by ``N`` nodes per edge, giving ``N**3``
    nodes and ``(N - 1)**3`` eight-node elements with an isotropic material
    (E=1000.0, nu=0.3).

    Boundary conditions:
    - Nodes on the face x == 0 are constrained in every direction.
    - Nodes on the face x == 1 are constrained in x only and get a prescribed
      x-displacement of 0.1.
    - Nodal forces are zero but created with ``requires_grad=True`` so a
      backward pass can propagate gradients to them.

    Args:
        N: Number of nodes along each edge of the cube.

    Returns:
        The configured ``Solid`` instance.
    """
    # Node coordinates. With "ij" indexing, x varies along grid dim 0, so grid
    # node (i, j, k) gets the flat id i * N**2 + j * N + k after raveling.
    axis = torch.linspace(0, 1, N)
    grid_x, grid_y, grid_z = torch.meshgrid(axis, axis, axis, indexing="ij")
    nodes = torch.stack([grid_x.ravel(), grid_y.ravel(), grid_z.ravel()]).T

    # Corner offsets (di, dj, dk) of one hexahedron: the bottom face (dk=0)
    # counter-clockwise in the x-y plane, then the top face (dk=1) likewise.
    corner_offsets = (
        (0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0),
        (0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1),
    )
    node_id = torch.arange(N**3).reshape((N, N, N))
    last = N - 1
    # For each corner, the ids of that corner across all (N - 1)**3 cells.
    corners = [
        node_id[di : di + last, dj : dj + last, dk : dk + last].ravel()
        for di, dj, dk in corner_offsets
    ]
    elements = torch.stack(corners).T

    cube = Solid(nodes, elements, Isotropic(E=1000.0, nu=0.3))

    # Boundary faces selected by exact x-coordinate.
    at_x0 = nodes[:, 0] == 0.0
    at_x1 = nodes[:, 0] == 1.0
    cube.forces = torch.zeros_like(nodes, requires_grad=True)
    cube.constraints[at_x0, :] = True
    cube.constraints[at_x1, 0] = True
    cube.displacements[at_x1, 0] = 0.1
    return cube

In [3]:
# Profile building a 30-nodes-per-edge cube plus one forward solve, and write
# the raw statistics to the file "stats" for the pstats cell below.
profiled_statement = "get_cube(30).solve()"
cProfile.run(profiled_statement, filename="stats")

In [4]:
# Load the profile written by the cProfile cell and list the ten functions with
# the largest internal ("tottime") time.
# strip_dirs() drops the directory part of file names, so absolute local paths
# do not end up in the saved notebook output.
# The trailing ";" suppresses the Stats repr returned by print_stats().
p = pstats.Stats("stats")
p.strip_dirs().sort_stats("tottime").print_stats(10);

Fri Oct 25 19:07:31 2024    stats

         27780 function calls (27384 primitive calls) in 7.358 seconds

   Ordered by: internal time
   List reduced from 596 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    2.286    1.143    2.286    1.143 {method 'coalesce' of 'torch._C.TensorBase' objects}
        4    1.877    0.469    1.877    0.469 {built-in method torch.matmul}
        4    1.375    0.344    1.375    0.344 {built-in method torch.isin}
      150    0.634    0.004    0.634    0.004 {built-in method scipy.sparse._sparsetools.csr_matvec}
       12    0.387    0.032    0.387    0.032 {built-in method torch.einsum}
       45    0.168    0.004    0.168    0.004 {built-in method torch.stack}
        2    0.111    0.055    3.828    1.914 /Users/meyernil/Code/torch-fem/src/torchfem/base.py:96(assemble_stiffness)
        4    0.077    0.019    0.077    0.019 {method 'sum' of 'torch._C.TensorBase' objects}
        1    0.0

<pstats.Stats at 0x11aedf160>

In [5]:
# Benchmark forward (solve) and backward (autograd) passes on increasingly
# fine cube meshes. `results` maps N to a tuple:
#   (DOFs, forward time [s], forward peak memory [MB],
#    backward time [s], backward peak memory [MB]).
# NOTE: runtime grows quickly with N (N=50 already took ~30 s forward in the
# recorded run); trim N_VALUES for a quick run.
N_VALUES = [10, 20, 30, 40, 50, 60, 70, 80, 90]


def profile_call(fn):
    """Call ``fn()`` once while memory_profiler samples process memory.

    Args:
        fn: Zero-argument callable to measure.

    Returns:
        A tuple ``(result, elapsed_seconds, peak_increase_mb)``:
        - the return value of ``fn``;
        - its wall-clock duration, measured with ``time.perf_counter``
          around the call only (the start-up of memory_profiler's sampling
          process is excluded);
        - the peak sampled memory minus the first (baseline) sample.
    """

    def timed():
        start = time.perf_counter()
        out = fn()
        return out, time.perf_counter() - start

    samples, (out, elapsed) = memory_usage(timed, retval=True, interval=0.1)
    # Peak over the baseline sample rather than max - min. If the call frees
    # memory (e.g. autograd releasing its buffers), max - min would count that
    # drop as usage.
    return out, elapsed, max(samples) - samples[0]


results = {}
for N in N_VALUES:
    print(f"Running N={N}")
    box = get_cube(N)
    dofs = box.n_dofs

    # Forward pass. Outputs stay differentiable (presumably w.r.t. box.forces,
    # which get_cube creates with requires_grad=True).
    (u, f, sigma, epsilon, state), fwd_time, fwd_mem_usage = profile_call(box.solve)
    print(f"  ... forward pass with {dofs} DOFs done in {fwd_time:.2f}s.")

    # Backward pass. The graph is traversed exactly once, so it is not
    # retained. Autograd frees its saved buffers here instead of carrying them
    # into the next, larger iteration.
    bwd_time, bwd_mem_usage = profile_call(lambda: u.sum().backward())[1:]
    print(f"  ... backward pass with {dofs} DOFs done in {bwd_time:.2f}s.")

    results[N] = (
        dofs,
        fwd_time,
        fwd_mem_usage,
        bwd_time,
        bwd_mem_usage,
    )

Running N=10
  ... forward pass with 3000 DOFs done in 0.45s.
  ... backward pass with 3000 DOFs done in 0.61.
Running N=20
  ... forward pass with 24000 DOFs done in 4.15s.
  ... backward pass with 24000 DOFs done in 2.77.
Running N=30
  ... forward pass with 81000 DOFs done in 5.78s.
  ... backward pass with 81000 DOFs done in 1.47.
Running N=40
  ... forward pass with 192000 DOFs done in 14.70s.
  ... backward pass with 192000 DOFs done in 4.11.
Running N=50
  ... forward pass with 375000 DOFs done in 31.28s.
  ... backward pass with 375000 DOFs done in 9.31.
Running N=60


KeyboardInterrupt: 

In [None]:
# Render `results` (filled by the benchmark loop) as a Markdown table with one
# row per mesh size N, then print SciPy's build configuration.
table_lines = [
    "|  N  |    DOFs | FWD Time |  FWD Memory | BWD Time |  BWD Memory |",
    "| --- | ------- | -------- | ----------- | -------- | ----------- |",
]
for N, (dofs, fwd_t, fwd_mem, bwd_t, bwd_mem) in results.items():
    row = (
        f"| {N:3d} | {dofs:7d} |"
        f" {fwd_t:7.2f}s |  {fwd_mem:7.2f} MB |"
        f" {bwd_t:7.2f}s |  {bwd_mem:7.2f} MB |"
    )
    table_lines.append(row)
print("\n".join(table_lines))


scipy.show_config()