# FEM Tutorial using JAX Implementation

This tutorial demonstrates how to use the JAX-based implementation of our finite element framework.

In [1]:
import jax.numpy as jnp
from src.jax.frame import Frame
from src.jax.element import Element
import matplotlib.pyplot as plt

## Creating a Simple Beam Element

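First, let's define a 2D beam (frame) element class that inherits from our base `Element` class. Each of its two nodes has three degrees of freedom — axial translation, transverse translation and rotation — so both the element stiffness matrix and the local-to-global transformation matrix are 6×6.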

In [2]:
class BeamElement(Element):
    """Two-node 2D frame element: axial bar stiffness plus Euler-Bernoulli bending.

    Per-node DOFs in local (element-aligned) axes are ordered
    (axial translation, transverse translation, rotation), giving 6x6 matrices.
    The material/section properties ``E``, ``A`` and ``I`` are read from the
    instance; presumably they are stored by the ``Element`` base-class
    constructor from its keyword arguments -- confirm in ``src.jax.element``.
    """

    def local_stiffness(self, coords):
        """Return the 6x6 stiffness matrix in local element axes.

        Args:
            coords: 2x2 array indexed ``coords[node, axis]`` with node 0/1 and
                axis 0 = x, 1 = y.

        Note: a zero-length element (coincident nodes) divides by zero here;
        no guard is applied.
        """
        delta_x = coords[1, 0] - coords[0, 0]
        delta_y = coords[1, 1] - coords[0, 1]
        length = jnp.sqrt(delta_x**2 + delta_y**2)

        young, area, inertia = self.E, self.A, self.I

        # Each term keeps the original operand order so results are bit-identical.
        axial = young * area / length                # EA/L
        shear = 12 * young * inertia / length**3     # 12EI/L^3
        coupling = 6 * young * inertia / length**2   # 6EI/L^2
        near_end = 4 * young * inertia / length      # 4EI/L (same-node rotation)
        far_end = 2 * young * inertia / length       # 2EI/L (opposite-node rotation)

        return jnp.array([
            [axial,         0,          0, -axial,          0,          0],
            [0,         shear,   coupling,      0,     -shear,   coupling],
            [0,      coupling,   near_end,      0,  -coupling,    far_end],
            [-axial,        0,          0,  axial,          0,          0],
            [0,        -shear,  -coupling,      0,      shear,  -coupling],
            [0,      coupling,    far_end,      0,  -coupling,   near_end],
        ])

    def transformation_matrix(self, coords):
        """Return the 6x6 matrix rotating global nodal DOFs into local axes.

        The same 3x3 rotation (about the out-of-plane axis, rotation DOF
        unchanged) is placed on the diagonal once per node.
        """
        delta_x = coords[1, 0] - coords[0, 0]
        delta_y = coords[1, 1] - coords[0, 1]
        length = jnp.sqrt(delta_x**2 + delta_y**2)
        cos_a = delta_x / length
        sin_a = delta_y / length

        node_rotation = jnp.array([[cos_a, sin_a, 0],
                                   [-sin_a, cos_a, 0],
                                   [0, 0, 1]])
        off_diagonal = jnp.zeros((3, 3))
        return jnp.block([[node_rotation, off_diagonal],
                          [off_diagonal, node_rotation]])

## Example: Cantilever Beam

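Let's analyze a simple cantilever beam using our JAX implementation: a 2 m steel beam, fully fixed at its left end (node 0) and loaded by a 1 kN downward point load at its free tip (node 1).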

In [3]:
# Model data (SI units)
BEAM_LENGTH = 2.0   # m, distance from the clamped support to the free tip
P = 1000.0          # N, magnitude of the downward tip load (1 kN)

frame = Frame()

# Node 0 is the clamped support, node 1 the free tip; both lie on the x-axis.
for node_id, x_position in ((0, 0.0), (1, BEAM_LENGTH)):
    frame.add_node(node_id, x_position, 0.0)

beam = BeamElement(
    nodes=(0, 1),
    E=200e9,    # Young's modulus of steel (Pa)
    A=0.01,     # cross-sectional area (m^2)
    I=8.33e-6,  # second moment of area (m^4)
)
frame.add_element(beam)

# Clamp node 0: global DOFs 0, 1, 2 = x-translation, y-translation, rotation.
for clamped_dof in (0, 1, 2):
    frame.fix_dof(clamped_dof)

# Three DOFs per node, so the tip's y-translation is global DOF 3*1 + 1 = 4.
loads = jnp.zeros(6).at[4].set(-P)

# NOTE(review): the saved output shows this call raising JAX's "non-tuple
# sequence for multidimensional indexing" TypeError. The failing index is
# presumably inside Frame.solve (e.g. a jax array indexed with a Python list
# of free DOFs) -- confirm and fix in src/jax/frame.py.
displacements = frame.solve(loads)

TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/jax-ml/jax/issues/4564 for more information.

## Results Analysis

In [4]:
# Tip (node 1) response: global DOFs 3, 4, 5 = x-translation, y-translation, rotation.
tip_ux, tip_uy, tip_rz = displacements[3], displacements[4], displacements[5]
print("Displacements at free end:")
print(f"x-displacement: {tip_ux:.6f} m")
print(f"y-displacement: {tip_uy:.6f} m")
print(f"rotation: {tip_rz:.6f} rad")

# Euler-Bernoulli tip deflection of a cantilever under an end point load:
#   delta = P * L^3 / (3 * E * I)
# The length is taken from the node coordinates rather than re-typed, so the
# check cannot silently drift from the model definition.
dx = frame.nodes[1][0] - frame.nodes[0][0]
dy = frame.nodes[1][1] - frame.nodes[0][1]
L = (dx**2 + dy**2) ** 0.5  # beam length
theoretical_deflection = P * L**3 / (3 * beam.E * beam.I)
numerical_deflection = abs(tip_uy)
relative_error = abs(numerical_deflection - theoretical_deflection) / theoretical_deflection

print(f"\nTheoretical tip deflection: {theoretical_deflection:.6f} m")
print(f"Numerical tip deflection: {numerical_deflection:.6f} m")
print(f"Relative error: {relative_error:.2e}")

Displacements at free end:


NameError: name 'displacements' is not defined

## Visualizing the Beam

In [None]:
def plot_beam(frame, displacements, scale=1.0, node_ids=(0, 1), ax=None,
              title='Cantilever Beam Deformation'):
    """Plot the undeformed and scaled deformed shape of a chain of frame nodes.

    Args:
        frame: model whose ``frame.nodes[n]`` gives ``(x, y, ...)`` for node ``n``.
        displacements: global displacement vector. Node ``n`` is assumed to own
            DOFs ``3*n`` (x) and ``3*n + 1`` (y), as in the cantilever example
            where node 1's y-DOF is index 4.
        scale: magnification applied to the displacements for visibility.
        node_ids: nodes joined by straight lines, in order. The default
            ``(0, 1)`` draws the single cantilever element.
        ax: optional existing Axes to draw on. When omitted, a new figure is
            created and shown.
        title: axes title.

    Returns:
        The matplotlib Axes that was drawn on.

    The deformed shape is drawn as straight chords between nodes, so bending
    curvature inside an element and nodal rotations are not visualized.
    """
    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots(figsize=(10, 6))

    x_undeformed = [frame.nodes[n][0] for n in node_ids]
    y_undeformed = [frame.nodes[n][1] for n in node_ids]
    ax.plot(x_undeformed, y_undeformed, 'k--', label='Undeformed')

    x_deformed = [frame.nodes[n][0] + scale * displacements[3 * n] for n in node_ids]
    y_deformed = [frame.nodes[n][1] + scale * displacements[3 * n + 1] for n in node_ids]
    ax.plot(x_deformed, y_deformed, 'r-', label='Deformed')

    ax.grid(True)
    ax.axis('equal')
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')

    # Only display figures we own; a caller passing `ax` controls display.
    if created_figure:
        plt.show()
    return ax

# Plot the beam with exaggerated deformation (trailing ';' hides the Axes repr)
plot_beam(frame, displacements, scale=100);

## Performance Comparison

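Let's compare the performance of the JAX implementation with and without JIT compilation (`jax.jit`). JAX dispatches work asynchronously, so every timed call waits for its result, and the one-off compilation cost is reported separately. **TODO:** add the comparison against the original NumPy implementation.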

In [None]:
import time
from functools import partial

from jax import block_until_ready, jit


# `frame` is (presumably) a regular Python object, not a registered JAX pytree,
# so it cannot be a traced argument; mark it static instead. Static arguments
# must be hashable (plain objects hash by identity). The compiled function
# bakes in the frame's state at trace time: mutating the same frame afterwards
# is NOT picked up without re-jitting.
@partial(jit, static_argnums=0)
def solve_jitted(frame, loads):
    """JIT-compiled wrapper around ``frame.solve(loads)``."""
    return frame.solve(loads)


def _total_time(solve_once, n_runs):
    """Return total wall-clock seconds for ``n_runs`` calls of ``solve_once()``.

    Each result is waited on with ``block_until_ready`` because JAX dispatches
    asynchronously; without it only the dispatch time would be measured.
    """
    start = time.perf_counter()
    for _ in range(n_runs):
        block_until_ready(solve_once())
    return time.perf_counter() - start


n_runs = 1000

# Compile once up front so compilation is not folded into the per-run average.
start = time.perf_counter()
block_until_ready(solve_jitted(frame, loads))
compile_time = time.perf_counter() - start

regular_time = _total_time(lambda: frame.solve(loads), n_runs)
jit_time = _total_time(lambda: solve_jitted(frame, loads), n_runs)

print(f"Regular execution time (average over {n_runs} runs): {regular_time/n_runs*1000:.3f} ms")
print(f"JIT execution time (average over {n_runs} runs): {jit_time/n_runs*1000:.3f} ms")
print(f"JIT compilation + first call: {compile_time*1000:.3f} ms")