In [71]:
# Load the necessary libraries
import numpy as np
import sympy as sp
from sympy.physics.mechanics import dynamicsymbols
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

# This is for pretty printing
import IPython.display as disp

Constants

In [72]:
m1_crank_val, m2_val, m3_val = 1.0, 2.0, 0.5  # Masses (kg) - m1_crank_val is mass of crank part only
L1_val, L2_val = 0.5, 1.5                   # Lengths (m)
I1_crank_val, I2_val = 0.1, 1.5             # Moments of Inertia (kg*m^2) - I1_crank_val is for crank part only
F0_val = 50                                 # Force amplitude (N)
k_val = 1                                   # Damping coefficient for crank (Nms/rad)
g_val = 9.81                                # Gravity (m/s^2)

# --- MODIFICATION: Flywheel Properties ---
m_flywheel_val = 50.0  # Mass of the flywheel (kg) - Let's make it significant
r_flywheel_val = 0.25 # Radius of the flywheel (m) - Assuming a solid disk
I_flywheel_val = 0.5 * m_flywheel_val * r_flywheel_val**2 # Moment of inertia of a solid disk (kg*m^2)
# --- END MODIFICATION ---

In [73]:
t = sp.symbols('t') # Time variable
x1, x2, y1, y2, theta1, theta2, x3, y3 = dynamicsymbols('x1 x2 y1 y2 theta1 theta2 x3 y3') # Generalized coordinates as SymPy dynamicsymbols
q = sp.Matrix([x1, y1, theta1, x2, y2, theta2, x3, y3]) # Creates vector q containing all the generalized coordinates
dq = q.diff(t) # Creates vector dq containing the time derivatives of the generalized coordinates

# Define SymPy matrices for the 2D position vector of the CoM of each body
x_com_1 = sp.Matrix([x1, y1]) 
x_com_2 = sp.Matrix([x2, y2])
x_com_3 = sp.Matrix([x3, y3])

R = lambda theta: sp.Matrix([[sp.cos(theta), -sp.sin(theta)], [sp.sin(theta), sp.cos(theta)]]) # Lambda function R that takes angle theta and returns a 2D rotation matrix

# --- MODIFICATION: Update mass and inertia for body 1 (crank + flywheel) ---
m1_total_val = m1_crank_val + m_flywheel_val  # Total mass of crank + flywheel
I1_total_val = I1_crank_val + I_flywheel_val  # Total moment of inertia of crank + flywheel

M_np = np.diag([m1_total_val, m1_total_val, I1_total_val,  # Body 1 (crank + flywheel)
                m2_val, m2_val, I2_val,                  # Body 2 (conrod)
                m3_val, m3_val])                         # Body 3 (piston) - still 2 translational DoFs in M
# --- END MODIFICATION ---

W_np = np.linalg.inv(M) # Inverse of the mass matrix

# --- MODIFICATION: Update gravitational force for body 1 ---
Q = sp.Matrix([0,
               -m1_total_val * g_val,  # Gravity acts on the total mass of body 1
               -k_val * theta1.diff(t),
               0,
               -m2_val * g_val,
               0,
               0,
               -m3_val * g_val + F0_val * sp.cos(theta1)])
# --- END MODIFICATION ---

In [74]:
i_cap = sp.Matrix([1, 0]) # Unit vector in the x-direction
j_cap = sp.Matrix([0, 1]) # Unit vector in the y-direction

constraint_1 = x_com_1 + R(theta1) @ sp.Matrix([-L1/2, 0]) # Calculates the global position of one end of the crank (the pivot point)
C1 = constraint_1.dot(i_cap)
C2 = constraint_1.dot(j_cap)

constraint_2 = x_com_1 - x_com_2 + R(theta1) @ sp.Matrix([L1/2, 0]) - R(theta2) @ sp.Matrix([-L2/2, 0])
C3 = constraint_2.dot(i_cap)
C4 = constraint_2.dot(j_cap)

constraint_3 = x_com_2 + R(theta2) @ sp.Matrix([L2/2, 0]) - x_com_3
C5 = constraint_3.dot(i_cap)
C6 = constraint_3.dot(j_cap)

constraint_4 = x_com_3[0]
C7 = constraint_4

C = sp.Matrix([C1, C2, C3, C4, C5, C6, C7])

In [75]:
# Cell 6: Formulate terms for DAE solution and lambdify
J = C.jacobian(q)
# dq is already q.diff(t)

# Velocity level constraints and their derivative term
dC_expr = J @ dq  # This is the symbolic expression for dC/dt
dJ_times_dq_dt_expr = dC_expr.jacobian(q) @ dq # This is (d(J@dq)/dq) @ dq

# Symbolic inverse mass matrix placeholder
W_sym = sp.MatrixSymbol('W_matrix', M_np.shape[0], M_np.shape[1])

# Define RHS symbolically using the expressions Q, C, dC_expr
# Q, C are already defined as symbolic matrices/vectors above based on q, dq, and constants.
RHS_sym = -dJ_times_dq_dt_expr - J @ W_sym @ Q - 10 * C - 10 * dC_expr

# Lambdify: Convert SymPy expressions to fast numerical functions
# The main arguments for these functions will be the state (q, dq) and the numerical M_inv (W_np)
JWJT_fn = sp.lambdify(args=(q, dq, W_sym), expr=(J @ W_sym @ J.T), modules=['numpy', 'sympy'])
RHS_fn = sp.lambdify(args=(q, dq, W_sym), expr=RHS_sym, modules=['numpy', 'sympy'])

# Numerical functions for C, J, Q, and dC_expr needed for the solver loop or checks
C_fn_num = sp.lambdify(args=(q,), expr=C, modules=['numpy', 'sympy'])
J_fn_num = sp.lambdify(args=(q,), expr=J, modules=['numpy', 'sympy'])
Q_fn_num = sp.lambdify(args=(q, dq), expr=Q, modules=['numpy', 'sympy'])
dC_fn_num = sp.lambdify(args=(q, dq), expr=dC_expr, modules=['numpy', 'sympy']) # Lambdify the expression J@dq

In [76]:
dtheta1_init = 0.5 # rad/s

q_init_guess = np.array([
    0, L1_val/2, np.pi/2,                      # x1, y1, theta1
    0, L1_val + L2_val/2, np.pi/2,             # x2, y2, theta2
    0, L1_val + L2_val                          # x3, y3
])

C_at_init_pos = C_fn_num(q_init_guess).flatten() # Use C_fn_num
print(f"Initial position constraint violation C(q_init_guess): {C_at_init_pos}")
assert np.allclose(C_at_init_pos, 0, atol=1e-6), "Chosen initial positions DO NOT satisfy constraints C=0!"


def velocity_constraint_solver(b_vel, q_pos, known_dtheta1):
    dq_full = np.array([b_vel[0], b_vel[1], known_dtheta1,  # dx1, dy1, dtheta1
                        b_vel[2], b_vel[3], b_vel[4],  # dx2, dy2, dtheta2
                        b_vel[5], b_vel[6]])           # dx3, dy3
    J_num = J_fn_num(q_pos) # Use J_fn_num
    dC_val = (J_num @ dq_full).flatten() 
    return dC_val

import scipy.optimize as opt
# Initial guess for the unknown velocity components
b_vel_guess = np.zeros(7)
# Find roots of velocity_constraint_solver(b_vel) = 0
solution = opt.root(velocity_constraint_solver, b_vel_guess, args=(q_init_guess, dtheta1_init), method='hybr')

if not solution.success:
    print("Warning: Initial velocity optimization might not have converged.")
    print(solution.message)

b_vel_sol = solution.x
dq_init_consistent = np.array([b_vel_sol[0], b_vel_sol[1], dtheta1_init,
                               b_vel_sol[2], b_vel_sol[3], b_vel_sol[4],
                               b_vel_sol[5], b_vel_sol[6]])

# Verify velocity constraints with the solved dq_init_consistent
# J_num_init = J_fn_num(q_init_guess) # J_fn_num used inside velocity_constraint_solver
dC_at_init_vel = dC_fn_num(q_init_guess, dq_init_consistent).flatten() # Use dC_fn_num
print(f"Initial velocity constraint violation dC(q_init, dq_consistent): {dC_at_init_vel}")
assert np.allclose(dC_at_init_vel, 0, atol=1e-6), "Solved initial velocities DO NOT satisfy dC=0!"

x0 = np.concatenate((q_init_guess, dq_init_consistent))
print("\nConsistent initial state vector x0:")
print(x0)


# initial_position_body_1 = np.array([0, L1/2, np.pi/2])
# initial_position_body_2 = np.array([0, L1 + L2/2, np.pi/2])
# initial_position_body_3 = np.array([0, L1 + L2])
# initial_velocity_body_1 = np.array([0, 0, dtheta1]) # To start the engine
# initial_velocity_body_2 = np.array([0, 0, 0])
# initial_velocity_body_3 = np.array([0, 0])
# x0 = np.concatenate((initial_position_body_1, initial_position_body_2, initial_position_body_3,
#                     initial_velocity_body_1, initial_velocity_body_2, initial_velocity_body_3))

Initial position constraint violation C(q_init_guess): [-1.5308085e-17  0.0000000e+00  6.1232340e-17  0.0000000e+00
  4.5924255e-17  0.0000000e+00  0.0000000e+00]
Initial velocity constraint violation dC(q_init, dq_consistent): [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00 -1.89911355e-65]

Consistent initial state vector x0:
[ 0.00000000e+00  2.50000000e-01  1.57079633e+00  0.00000000e+00
  1.25000000e+00  1.57079633e+00  0.00000000e+00  2.00000000e+00
 -1.25000000e-01  7.65404249e-18  5.00000000e-01 -1.25000000e-01
  7.65404249e-18 -1.66666667e-01 -1.89911355e-65  1.43492963e-42]


Calculate initial conditions for the system

In [77]:
# import scipy.optimize as opt

# x, _ = np.split(x0, 2)
# def optimiser(b):
#     dx1, dy1, dx2, dy2, dtheta2, dx3, dy3 = b
#     dq = np.array([dx1, dy1, dtheta1, dx2, dy2, dtheta2, dx3, dy3])
#     val = dC_fn(x, dq).flatten()
#     return val

# initial_guess = np.array([0, 0, 0, 0, 0, 0, 0])
# result = opt.root(optimiser, initial_guess)
# print(result)

# b = result.x
# dx = np.array([b[0], b[1], dtheta1, b[2], b[3], b[4], b[5], b[6]])

# C_val = C_fn(x, dx)
# dC_val = dC_fn(x, dx)

# print(f'Position constraint: {C_val}')
# print(f'Velocity constraint: {dC_val}')
# assert np.allclose(C_val, 0), "Initial position constraint violated"
# assert np.allclose(dC_val, 0), "Initial velocity constraint violated"
# x0 = np.concatenate((x, dx))
# x0

In [78]:
# Cell 8: Define the function for the DAE solver
def piston_engine_flywheel(t, state):
    q_num, dq_num = np.split(state, 2) # Current q and dq from solver

    # W_np is the global numerical inverse mass matrix
    
    # Calculate terms for the solver
    # JWJT_fn and RHS_fn will use the numerical q_num, dq_num, and W_np
    # The expressions for Q, C, dC_expr embedded in RHS_sym will be evaluated by lambdify
    
    JWJT_val = JWJT_fn(q_num, dq_num, W_np) 
    RHS_val = RHS_fn(q_num, dq_num, W_np)

    # Solve for Lagrange multipliers (lambda)
    try:
        lam = np.linalg.solve(JWJT_val, RHS_val)
    except np.linalg.LinAlgError:
        print(f"Singular matrix encountered at t={t}. JWJT_val: {JWJT_val}")
        lam = np.linalg.pinv(JWJT_val) @ RHS_val # Use pseudo-inverse as a fallback

    # Calculate constraint forces/torques
    J_val_num = J_fn_num(q_num) # Still need J_fn_num for Qhat
    Qhat = J_val_num.T @ lam

    # Calculate generalized accelerations (ddq)
    Q_val_num = Q_fn_num(q_num, dq_num) # Still need Q_fn_num for ddq calculation
    ddq = W_np @ (Q_val_num + Qhat)
    ddq = ddq.flatten()

    return np.concatenate((dq_num, ddq))

# Test run
initial_derivatives = piston_engine_flywheel(0, x0)
print("\nInitial derivatives (dq, ddq) from test run:")
print(initial_derivatives)


Initial derivatives (dq, ddq) from test run:
[-1.25000000e-01  7.65404249e-18  5.00000000e-01 -1.25000000e-01
  7.65404249e-18 -1.66666667e-01 -1.89911355e-65  1.43492963e-42
  1.46341463e-01 -6.25000000e-02 -5.85365854e-01  1.46341463e-01
 -1.45833333e-01  1.95121951e-01  1.38777878e-17 -1.66666667e-01]


In [79]:
# Cell 9: Run the numerical simulation
t_start = 0
t_end = 10 # Reduced time for quicker test, original was 30
num_points = 500 # Original was 500

t_span = (t_start, t_end)
t_eval = np.linspace(*t_span, num_points)

print("\nStarting numerical integration...")
sol = solve_ivp(piston_engine_flywheel, t_span, x0,
                atol=1e-7, rtol=1e-7, method='BDF', t_eval=t_eval)
print("Integration finished.")
print(f"Solver status: {sol.message}")
if not sol.success:
    print("Warning: Solver did not terminate successfully.")

# Explanation:
# t_start, t_end, num_points: Parameters for the simulation time.
# t_span: Tuple defining the start and end time of the simulation.
# t_eval: Array of time points where the solution should be stored.
# solve_ivp: Solves the system of DAEs.
#    - piston_engine_flywheel: The function defining the system dynamics.
#    - t_span: Integration interval.
#    - x0: Consistent initial state.
#    - atol, rtol: Absolute and relative error tolerances for the solver.
#    - method='BDF': Uses the Backward Differentiation Formula, good for stiff DAEs.
#    - t_eval: Times at which to store the solution.
# sol: The solution object containing time points (sol.t) and state vectors (sol.y).
# Print statements provide feedback on the integration process.


Starting numerical integration...
Integration finished.
Solver status: The solver successfully reached the end of the integration interval.


Animation

In [80]:
# Class for drawing the box
class Box:
    def __init__(self, width, height, color='b'):
        self.width = width
        self.height = height
        self.color = color
        self.offset = -np.array([width/2, height/2])

    def first_draw(self, ax):
        corner = np.array([0, 0])
        self.patch = plt.Rectangle(corner, 0, 0, angle=0, 
                        rotation_point='center', color=self.color, animated=True)
        ax.add_patch(self.patch)
        self.ax = ax
        return self.patch
    
    def set_data(self, x, y, theta):
        self.x = x
        self.y = y
        self.theta = theta

    def update(self, i):
        x, y, theta = self.x[i], self.y[i], self.theta[i]
        theta = np.rad2deg(theta)

        # The rectangle is drawn from the left bottom corner
        # So, we need to calculate the corner position
        corner = np.array([x, y]) + self.offset

        # Update the values for the rectangle
        self.patch.set_width(self.width)
        self.patch.set_height(self.height)
        self.patch.set_xy(corner)
        self.patch.set_angle(theta)
        return self.patch
    
# --- MODIFICATION: Corrected Circle class for the flywheel WITH a rotating line ---
from matplotlib.patches import Circle
from matplotlib.lines import Line2D

class FlywheelVisual:
    def __init__(self, radius, color='dimgray', line_color='black', line_angle_offset_rad=np.pi/2): # Added offset
        self.radius = radius
        self.color = color
        self.line_color = line_color
        self.line_angle_offset_rad = line_angle_offset_rad # Store the offset
        self.patch = None 
        self.line = None  

    def first_draw(self, ax):
        self.patch = Circle((0,0), self.radius, fc=self.color, animated=True, zorder=1)
        ax.add_patch(self.patch)
        
        # Initial line position will be set by the first call to update via animate_frame
        self.line = Line2D([0, 0], [0, 0], color=self.line_color, lw=2, animated=True, zorder=2) 
        ax.add_line(self.line)
        
        return [self.patch, self.line]

    def set_data(self, center_x_array, center_y_array, theta_crank_array):
        self.center_x_array = center_x_array
        self.center_y_array = center_y_array
        self.theta_crank_array = theta_crank_array

    def update(self, i):
        cx = self.center_x_array[i]
        cy = self.center_y_array[i]
        self.patch.center = (cx, cy)
        
        # --- MODIFICATION: Apply angle offset for the flywheel line ---
        th_crank_effective = self.theta_crank_array[i] + self.line_angle_offset_rad
        # --- END MODIFICATION ---
        
        line_start_x = cx
        line_start_y = cy
        line_end_x = cx + self.radius * np.cos(th_crank_effective) # Use effective angle
        line_end_y = cy + self.radius * np.sin(th_crank_effective) # Use effective angle
        
        self.line.set_data([line_start_x, line_end_x], [line_start_y, line_end_y])
        
        return [self.patch, self.line]
# --- END MODIFICATION ---

In [81]:
# Cell 12: Create and display the animation
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

fig_anim, ax_anim = plt.subplots(figsize=(6,8))
plt.close(fig_anim)

ax_anim.set_ylim(- (L1_val + L2_val + r_flywheel_val + 0.2), (L1_val + L2_val + r_flywheel_val + 0.2) )
ax_anim.set_xlim(- (L1_val + r_flywheel_val + 0.2), (L1_val + r_flywheel_val + 0.2) )
ax_anim.set_aspect('equal')
ax_anim.grid(True)

x1_sol, y1_sol, theta1_sol = sol.y[0,:], sol.y[1,:], sol.y[2,:]
x2_sol, y2_sol, theta2_sol = sol.y[3,:], sol.y[4,:], sol.y[5,:]
x3_sol, y3_sol             = sol.y[6,:], sol.y[7,:]
theta3_sol = np.zeros_like(x3_sol)

crank_visual = Box(L1_val, 0.05, color='blue') 
conrod_visual = Box(L2_val, 0.05, color='red')
piston_visual = Box(0.2, 0.4, color='green') 

# --- MODIFICATION: Instantiate FlywheelVisual with a default 90-degree offset for its line ---
# The default offset is np.pi/2 (90 degrees) in the class definition.
# You could also pass it here: line_angle_offset_rad=np.pi/2
flywheel_visual_obj = FlywheelVisual(r_flywheel_val, color='dimgray', line_color='black')
# --- END MODIFICATION ---

crank_visual.set_data(x1_sol, y1_sol, theta1_sol)
conrod_visual.set_data(x2_sol, y2_sol, theta2_sol)
piston_visual.set_data(x3_sol, y3_sol, theta3_sol)

flywheel_center_x_data = np.zeros_like(theta1_sol) 
flywheel_center_y_data = np.zeros_like(theta1_sol) 
flywheel_visual_obj.set_data(flywheel_center_x_data, flywheel_center_y_data, theta1_sol)

artists_to_animate = [flywheel_visual_obj, piston_visual, conrod_visual, crank_visual]

def init_animation():
    ax_anim.set_title("t=0.00 sec", fontsize=15)
    all_patches_flat = []
    for artist_obj in artists_to_animate:
        # Ensure first_draw actually draws something visible for the first frame if blit=True is used later.
        # Or, ensure the first call to animate_frame(0) correctly sets all data.
        # The line in FlywheelVisual is initialized to (0,0)-(0,0) and then updated by the first animate_frame.
        patches_from_obj = artist_obj.first_draw(ax_anim)
        if isinstance(patches_from_obj, list):
            all_patches_flat.extend(patches_from_obj)
        else:
            all_patches_flat.append(patches_from_obj)
    # Call animate_frame(0) once in init if blit=True is problematic for the very first frame
    # This ensures all artists are in their correct initial state before blitting starts caching.
    # return animate_frame(0) # Alternative for blit=True issues
    return all_patches_flat


def animate_frame(i):
    ax_anim.set_title(f"t={sol.t[i]:.2f} sec", fontsize=15)
    updated_patches_flat = []
    for artist_obj in artists_to_animate:
        patches_from_obj = artist_obj.update(i)
        if isinstance(patches_from_obj, list):
            updated_patches_flat.extend(patches_from_obj)
        else:
            updated_patches_flat.append(patches_from_obj)
    return updated_patches_flat

sim_dt = sol.t[1] - sol.t[0] if len(sol.t) > 1 else 0.04 
animation_interval_ms = 1000 * sim_dt 

# --- MODIFICATION: Try blit=False first to remove stationary line ---
anim_piston = FuncAnimation(fig_anim, animate_frame, frames=len(sol.t),
                            init_func=init_animation, blit=False, interval=max(20, animation_interval_ms)) 
# --- END MODIFICATION ---

html_video = HTML(anim_piston.to_html5_video())
disp.display(html_video)