<a href="https://colab.research.google.com/github/dont-have-a-name/Project_IV/blob/main/edited_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pyelastica



In [None]:
import numpy as np

# import modules
from elastica.modules import BaseSystemCollection, Constraints, Forcing, CallBacks, Damping

# Import Boundary Condition Classes
from elastica.boundary_conditions import OneEndFixedRod, FreeRod
from elastica.external_forces import EndpointForces

# import rod class, damping and forces to be applied
from elastica.rod.cosserat_rod import CosseratRod
from elastica.dissipation import AnalyticalLinearDamper
from elastica.external_forces import GravityForces, NoForces #, MuscleTorques
from elastica.interaction import AnisotropicFrictionalPlane

# import timestepping functions
from elastica.timestepper.symplectic_steppers import PositionVerlet
from elastica.timestepper import integrate

# import call back functions
from elastica.callback_functions import CallBackBaseClass
from collections import defaultdict

#extra stuff to edit muscles
from elastica._linalg import _batch_matvec
from elastica.typing import SystemType, RodType
from elastica.utils import _bspline
from numba import njit
from elastica._linalg import _batch_product_i_k_to_ik

class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks, Damping):
    """Simulation container assembled purely from PyElastica mixins.

    It adds no behavior of its own. It exists only to combine the mixins, which
    provide the ``append``, ``dampen``, ``add_forcing_to``,
    ``collect_diagnostics`` and ``finalize`` calls made on ``snake_sim``.
    """

# Container that holds the rod plus every damping / forcing / callback
# registration made below.
snake_sim = SnakeSimulator()

# Define rod parameters
n_elem = 50  # number of rod elements
start = np.array([0.0, 0.0, 0.0]) # first node at the origin; axes as seen in the 2D video: (up, out of page, right)
direction = np.array([0.0, 0.0, 1.0]) # rod initially lies along +z
normal = np.array([0.0, 1.0, 0.0]) # +y; reused below as the friction-plane normal and (negated) as the muscle-torque axis
base_length = 0.4 # total rod length (previously 0.35)
base_radius = base_length * 0.011 # radius is 1.1% of the length
base_area = np.pi * base_radius ** 2    # NOTE(review): not used anywhere below; author wanted a per-node area here
density = 1000
nu = 2e-3  # damping constant for AnalyticalLinearDamper (distinct from the friction plane's nu=1e-6)
E = 1e6  # Young's modulus
poisson_ratio = 0.5
# NOTE(review): the isotropic relation is G = E / (2 * (1 + poisson_ratio)); this
# expression gives twice that value. Kept as-is; confirm it is intentional.
shear_modulus = E / (poisson_ratio + 1.0)

# Create rod
shearable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)

# Add rod to the snake system
snake_sim.append(shearable_rod)

# Time step; also used below to derive total_steps for the integrator.
dt = 1e-4
snake_sim.dampen(shearable_rod).using(
    AnalyticalLinearDamper,
    damping_constant=nu,
    time_step=dt,
)

# Add gravitational forces (gravity points along -y)
gravitational_acc = -9.80665
snake_sim.add_forcing_to(shearable_rod).using(
    GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
)
print("Gravity now acting on shearable rod")

# Define muscle torque parameters
period = 2.0  # muscle-wave period; also reused as ramp_up_time and to set final_time
wave_length = 1.0  # converted to wave_number = 2*pi / wave_length when registering the torque
# Amplitude coefficients handed to MuscleTorques' B-spline (with_spline=True).
# Alternatives tried: [5e-3, 4.5e-3, 4e-3, 3.5e-3, 3e-3, 2.5e-3] and [3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
b_coeff = np.array([3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]) #([5e-3, 4.5e-3, 4e-3, 3.5e-3, 3e-3, 2.5e-3]) #[3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3]
b_coeff2 = b_coeff[::-1]  # reversed coefficients; currently only referenced in a comment (unused)
# Original author's note: "bending goes down exponentially".
# NOTE(review): neither coefficient set above decreases exponentially — confirm intent.
# Add muscle torques to the rod
class MuscleTorques(NoForces):
    """Travelling-wave muscle torques along the rod (local, editable variant).

    This is defined locally instead of using
    ``elastica.external_forces.MuscleTorques`` (that import is commented out at
    the top of the file), so the torque law can be modified here.

    At time ``t``, the torque magnitude at normalized arc-length position ``s``
    is::

        factor(t) * beta(s) * sin(angular_frequency * t - wave_number * s + phase_shift)

    Here ``factor(t) = min(1, t / ramp_up_time)`` ramps the torque in linearly.
    ``beta(s)`` is the spline built from ``b_coeff`` when ``with_spline=True``,
    and 1 everywhere otherwise. The torque vector is ``direction`` scaled by
    that magnitude.
    """

    def __init__(
        self,
        base_length,
        b_coeff,
        period,
        wave_number,
        phase_shift,
        direction,
        rest_lengths,
        ramp_up_time,
        with_spline=False,
    ):
        """Precompute the arc-length grid and amplitude profile.

        Parameters
        ----------
        base_length : float
            Accepted for interface compatibility; not used.
        b_coeff : numpy.ndarray
            Spline coefficients for the amplitude ``beta(s)``; must be
            non-empty when ``with_spline`` is True.
        period : float
            Temporal period of the wave; ``angular_frequency = 2*pi/period``.
        wave_number : float
            Spatial wave number applied to the normalized position ``s``.
        phase_shift : float
            Constant phase added inside the sine.
        direction : numpy.ndarray
            3-vector the torque is directed along.
        rest_lengths : numpy.ndarray
            Rest lengths of the rod's elements; their normalized cumulative sum
            gives ``s``.
        ramp_up_time : float
            Time over which the torque ramps from 0 to full strength; must be > 0.
        with_spline : bool
            Use the ``b_coeff`` spline for ``beta(s)`` instead of a constant 1.

        Raises
        ------
        ValueError
            If ``ramp_up_time`` is not positive, or if ``with_spline`` is True
            and ``b_coeff`` is empty.
        """
        super().__init__()

        self.direction = direction  # Direction torque applied
        self.angular_frequency = 2.0 * np.pi / period
        self.wave_number = wave_number
        self.phase_shift = phase_shift

        # Explicit check rather than `assert`, so validation survives `python -O`.
        # Written as `not (... > 0)` so NaN is rejected, as the old assert did.
        if not (ramp_up_time > 0.0):
            raise ValueError(f"ramp_up_time must be positive, got {ramp_up_time!r}")
        self.ramp_up_time = ramp_up_time

        # s is the position of nodes on the rod, we go from node=1 to node=nelem-1, because there is no
        # torques applied by first and last node on elements. Reason is that we cannot apply torque in an
        # infinitesimal segment at the beginning and end of rod, because there is no additional element
        # (at element=-1 or element=n_elem+1) to provide internal torques to cancel out an external
        # torque. This coupled with the requirement that the sum of all muscle torques has
        # to be zero results in this condition.
        # Normalized so the last entry is exactly 1.
        self.s = np.cumsum(rest_lengths)
        self.s /= self.s[-1]

        if with_spline:
            if b_coeff.size == 0:
                raise ValueError("Beta spline coefficient array (b_coeff) is empty")
            my_spline, _ctr_pts, _ctr_coeffs = _bspline(b_coeff)
            self.my_spline = my_spline(self.s)
        else:
            # Uniform amplitude profile when no spline is requested.
            self.my_spline = np.ones(self.s.shape)

    def apply_torques(self, rod: RodType, time: np.float64 = 0.0):
        """Add this step's muscle torques into ``rod.external_torques`` in place."""
        self.compute_muscle_torques(
            time,
            self.my_spline,
            self.s,
            self.angular_frequency,
            self.wave_number,
            self.phase_shift,
            self.ramp_up_time,
            self.direction,
            rod.director_collection,
            rod.external_torques,
        )

    @staticmethod
    @njit(cache=True)
    def compute_muscle_torques(
        time,
        my_spline,
        s,
        angular_frequency,
        wave_number,
        phase_shift,
        ramp_up_time,
        direction,
        director_collection,
        external_torques,
    ):
        """Numba kernel: accumulate the muscle torques into ``external_torques``.

        ``inplace_addition`` / ``inplace_substraction`` are module-level njit
        functions defined after this class. Numba resolves them when this
        function is first compiled (on first call), so they must exist before
        the simulation runs.
        """
        # Ramp up the muscle torque
        factor = min(1.0, time / ramp_up_time)
        # From the node 1 to node nelem-1
        # Magnitude of the torque. Am = beta(s) * sin(2pi*t/T + 2pi*s/lambda + phi)
        # There is an inconsistency with paper and Elastica cpp implementation. In paper sign in
        # front of wave number is positive, in Elastica cpp it is negative.
        torque_mag = (
            factor
            * my_spline
            * np.sin(angular_frequency * time - wave_number * s + phase_shift)
        )
        # Head and tail of the snake is opposite compared to elastica cpp. We need to iterate torque_mag
        # from last to first element.
        torque = _batch_product_i_k_to_ik(direction, torque_mag[::-1])
        # Element k (k >= 1) receives +Q_k @ torque_k and element k-1 receives
        # -Q_{k-1} @ torque_k: each torque is applied with opposite signs on two
        # adjacent elements, and torque_0 is never applied.
        inplace_addition(
            external_torques[..., 1:],
            _batch_matvec(director_collection, torque)[..., 1:],
        )
        inplace_substraction(
            external_torques[..., :-1],
            _batch_matvec(director_collection[..., :-1], torque[..., 1:]),
        )
#addition to the addition
@njit(cache=True)
def inplace_addition(external_force_or_torque, force_or_torque):
    """Element-wise ``external_force_or_torque += force_or_torque``, in place.

    Only the first 3 rows and the first ``force_or_torque.shape[1]`` columns
    of the target are written. Explicit loops are used inside this njit kernel.
    """
    n_cols = force_or_torque.shape[1]
    for row in range(3):
        dst_row = external_force_or_torque[row]
        src_row = force_or_torque[row]
        for col in range(n_cols):
            dst_row[col] += src_row[col]


@njit(cache=True)
def inplace_substraction(external_force_or_torque, force_or_torque):
    """Element-wise ``external_force_or_torque -= force_or_torque``, in place.

    Only the first 3 rows and the first ``force_or_torque.shape[1]`` columns
    of the target are written. Explicit loops are used inside this njit kernel.
    """
    n_cols = force_or_torque.shape[1]
    for row in range(3):
        dst_row = external_force_or_torque[row]
        src_row = force_or_torque[row]
        for col in range(n_cols):
            dst_row[col] -= src_row[col]
#end of addition to addition
#end of addition

# Register the travelling-wave muscle torque. The wave period is also used as
# the ramp-up time, and the torque axis is -normal (i.e. -y).
snake_sim.add_forcing_to(shearable_rod).using(
    MuscleTorques,
    base_length=base_length,  # accepted but unused by MuscleTorques
    b_coeff=b_coeff, # alternative: b_coeff2 (reversed profile)
    period=period,
    wave_number=2.0 * np.pi / (wave_length),
    phase_shift=0.0,
    rest_lengths=shearable_rod.rest_lengths,
    ramp_up_time=period,
    direction= -normal,
    with_spline=True
)
print("Muscle torques added to the rod")

# Define friction force parameters
# Plane sits one radius below the rod's axis (the rod starts at y = 0), with the
# same +y normal as the rod, so the rod initially rests on it.
origin_plane = np.array([0.0, -base_radius, 0.0]) # [0, -base_radius, 0]
normal_plane = normal
slip_velocity_tol = 1e-8
froude = 0.1
# Friction scale chosen from the Froude number: Fr = L / (T^2 * |g| * mu).
mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
kinetic_mu_array = np.array(
    [1.0 * mu, 1.5 * mu, 2.0 * mu]
)  # [forward, backward, sideways]
static_mu_array = 2 * kinetic_mu_array

# Add friction forces to the rod
# k / nu presumably are the plane contact stiffness / damping — confirm against
# the AnisotropicFrictionalPlane docs. This nu=1e-6 is unrelated to the
# module-level damping constant `nu`.
snake_sim.add_forcing_to(shearable_rod).using(
    AnisotropicFrictionalPlane,
    k=1.0,
    nu=1e-6,
    plane_origin=origin_plane,
    plane_normal=normal_plane,
    slip_velocity_tol=slip_velocity_tol,
    static_mu_array=static_mu_array,
    kinetic_mu_array=kinetic_mu_array,
)
print("Friction forces added to the rod")

# Disabled: clamp the first node/director of the rod in place.
#snake_sim.constrain(shearable_rod).using(
#    OneEndFixedRod, constrained_position_idx=(0,), constrained_director_idx=(0,)
#)



# Add call backs
class ContinuumSnakeCallBack(CallBackBaseClass):
    """Record the rod's state every ``step_skip`` time steps.

    Each recorded sample appends one entry to the lists stored under the keys
    ``"time"``, ``"step"``, ``"position"``, ``"velocity"``, ``"avg_velocity"``,
    ``"center_of_mass"`` and ``"curvature"`` of ``callback_params``. In this
    notebook it is used with ``pp_list = defaultdict(list)``, so missing keys
    are created on demand.
    """

    def __init__(self, step_skip: int, callback_params: dict):
        super().__init__()
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):
        """Append a snapshot of ``system`` when ``current_step`` is a multiple of ``every``."""
        if current_step % self.every != 0:
            return None

        record = self.callback_params
        record["time"].append(time)
        record["step"].append(current_step)
        # Copy the arrays: the simulator mutates them in place on later steps.
        record["position"].append(system.position_collection.copy())
        record["velocity"].append(system.velocity_collection.copy())
        record["avg_velocity"].append(system.compute_velocity_center_of_mass())
        record["center_of_mass"].append(system.compute_position_center_of_mass())
        record["curvature"].append(system.kappa.copy())
        return None


# Storage for ContinuumSnakeCallBack; each key becomes a list of samples.
pp_list = defaultdict(list)
# With dt = 1e-4, step_skip=100 records one sample every 0.01 s of simulated time.
snake_sim.collect_diagnostics(shearable_rod).using(
    ContinuumSnakeCallBack, step_skip=100, callback_params=pp_list
)
print("Callback function added to the simulator")

# Must be called after all rods, damping, forcing and callbacks are registered.
snake_sim.finalize()

# Simulate five muscle-wave periods.
final_time = 5.0 * period
total_steps = int(final_time / dt)
print("Total steps", total_steps)

timestepper = PositionVerlet()

integrate(timestepper, snake_sim, final_time, total_steps)

from IPython.display import Video
from tqdm import tqdm


def plot_video_2D(
    plot_params: dict,
    video_name="video.mp4",
    margin=0.2,
    fps=15,
    xlim=(0.0, 0.6),
    ylim=(-0.5, 0.5),
):
    """Render a top-down 2D animation of the recorded rod and save it as a video.

    The rod's z coordinates (``position[2]``) go on the horizontal axis and its
    x coordinates (``position[0]``) on the vertical axis; y is not shown.

    Parameters
    ----------
    plot_params : dict
        Callback output. Must contain ``"time"`` (sample times) and
        ``"position"`` (one array per sample, indexed ``position[axis][node]``).
    video_name : str
        Output path given to the ffmpeg writer.
    margin : float
        Padding added on each side of ``xlim`` / ``ylim``.
    fps : int
        Frames per second of the output video. Samples are strided to aim for
        about ``fps`` frames per simulated second; there are fewer frames if
        there are not enough samples.
    xlim, ylim : tuple of float
        Data range of the horizontal / vertical axis before ``margin`` is
        applied. The defaults reproduce the previously hard-coded view.

    Raises
    ------
    ValueError
        If ``plot_params["time"]`` is empty.
    """
    t = np.asarray(plot_params["time"])
    if t.size == 0:
        raise ValueError("plot_params['time'] is empty; nothing to render")
    positions_over_time = np.asarray(plot_params["position"])

    # Sample stride between frames. It is clamped to >= 1 for two reasons:
    # a zero stride would make range() raise, and a run shorter than ~1 s gives
    # total_frames == 0, which would otherwise divide by zero.
    total_time = int(np.around(t[-1], 1))
    total_frames = fps * total_time
    step = max(1, round(len(t) / total_frames)) if total_frames > 0 else 1

    # Imported lazily so the notebook/module loads without matplotlib.
    import contextlib

    import matplotlib.animation as manimation
    from matplotlib import pyplot as plt

    print("creating video -- this can take a few minutes")
    FFMpegWriter = manimation.writers["ffmpeg"]
    metadata = dict(title="Movie Test", artist="Matplotlib", comment="Movie support!")
    writer = FFMpegWriter(fps=fps, metadata=metadata)

    # "seaborn-whitegrid" was renamed "seaborn-v0_8-whitegrid" in matplotlib 3.6
    # and the old name was removed in 3.8; use whichever one is installed.
    style_name = next(
        (
            name
            for name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
            if name in plt.style.available
        ),
        None,
    )
    style_ctx = (
        plt.style.context(style_name) if style_name else contextlib.nullcontext()
    )

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        plt.axis("equal")
        rod_line = ax.plot(
            positions_over_time[0][2], positions_over_time[0][0], linewidth=3
        )[0]
        ax.set_xlim([xlim[0] - margin, xlim[1] + margin])
        ax.set_ylim([ylim[0] - margin, ylim[1] + margin])
        with writer.saving(fig, video_name, dpi=100), style_ctx:
            # Start at sample 0 so the initial configuration appears in the video.
            for frame_idx in range(0, len(t), step):
                rod_line.set_xdata(positions_over_time[frame_idx][2])
                rod_line.set_ydata(positions_over_time[frame_idx][0])
                writer.grab_frame()
    finally:
        # Release the figure even if ffmpeg or drawing fails.
        plt.close(fig)


# Render the recorded trajectory to MP4, show it inline, and offer it for download.
# Video is already imported above; importing it together with display() keeps
# this block self-contained.
from IPython.display import Video, display

filename_video = "continuum_snake.mp4"
plot_video_2D(pp_list, video_name=filename_video, margin=0.2, fps=125)

# A bare `Video(...)` expression is only rendered when it is the last expression
# of a notebook cell, so call display() explicitly. embed=True inlines the file;
# a relative <video src> pointing at the Colab VM's filesystem does not load.
display(Video(filename_video, embed=True))

from google.colab import files
files.download(filename_video)

# Disabled 3D rendering variant, kept for reference.
# It used to be stored as a bare triple-quoted string. Because that string was
# the last expression of the notebook cell, it was echoed as cell output.
# NOTE(review): if re-enabled, `step` can round to 0 (range() would then raise)
# when there are fewer samples than frames, and the "seaborn-whitegrid" style
# name was removed in matplotlib 3.8.
#
# def plot_video(plot_params: dict, video_name="video.mp4", margin=0.2, fps=15):
#     from matplotlib import pyplot as plt
#     import matplotlib.animation as manimation
#     from mpl_toolkits import mplot3d
#
#     t = np.array(plot_params["time"])
#     positions_over_time = np.array(plot_params["position"])
#     total_time = int(np.around(t[..., -1], 1))
#     total_frames = fps * total_time
#     step = round(len(t) / total_frames)
#     print("creating video -- this can take a few minutes")
#     FFMpegWriter = manimation.writers["ffmpeg"]
#     metadata = dict(title="Movie Test", artist="Matplotlib", comment="Movie support!")
#     writer = FFMpegWriter(fps=fps, metadata=metadata)
#     fig = plt.figure()
#     ax = fig.add_subplot(111, projection="3d")
#     ax.set_xlim(0 - margin, 3 + margin)
#     ax.set_ylim(-1.5 - margin, 1.5 + margin)
#     ax.set_zlim(0, 1)
#     ax.view_init(elev=20, azim=-80)
#     rod_lines_3d = ax.plot(
#         positions_over_time[0][2],
#         positions_over_time[0][0],
#         positions_over_time[0][1],
#         linewidth=3,
#     )[0]
#     with writer.saving(fig, video_name, dpi=100):
#         with plt.style.context("seaborn-whitegrid"):
#             for time in range(1, len(t), step):
#                 rod_lines_3d.set_xdata(positions_over_time[time][2])
#                 rod_lines_3d.set_ydata(positions_over_time[time][0])
#                 rod_lines_3d.set_3d_properties(positions_over_time[time][1])
#
#                 writer.grab_frame()
#     plt.close(fig)
#
#
# filename_video = "continuum_snake_3d.mp4"
# plot_video(pp_list, video_name=filename_video, margin=0.2, fps=60)
#
# Video("continuum_snake_3d.mp4")
#
# from google.colab import files
# files.download("continuum_snake_3d.mp4")
#
# #print("Time List:", pp_list["time"])

Gravity now acting on shearable rod
Muscle torques added to the rod
Friction forces added to the rod
Callback function added to the simulator
Total steps 100000


100%|██████████| 100000/100000 [00:23<00:00, 4261.36it/s]


Final time of simulation is :  9.999999999983364
creating video -- this can take a few minutes


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

'\ndef plot_video(plot_params: dict, video_name="video.mp4", margin=0.2, fps=15):\n    from matplotlib import pyplot as plt\n    import matplotlib.animation as manimation\n    from mpl_toolkits import mplot3d\n\n    t = np.array(plot_params["time"])\n    positions_over_time = np.array(plot_params["position"])\n    total_time = int(np.around(t[..., -1], 1))\n    total_frames = fps * total_time\n    step = round(len(t) / total_frames)\n    print("creating video -- this can take a few minutes")\n    FFMpegWriter = manimation.writers["ffmpeg"]\n    metadata = dict(title="Movie Test", artist="Matplotlib", comment="Movie support!")\n    writer = FFMpegWriter(fps=fps, metadata=metadata)\n    fig = plt.figure()\n    ax = fig.add_subplot(111, projection="3d")\n    ax.set_xlim(0 - margin, 3 + margin)\n    ax.set_ylim(-1.5 - margin, 1.5 + margin)\n    ax.set_zlim(0, 1)\n    ax.view_init(elev=20, azim=-80)\n    rod_lines_3d = ax.plot(\n        positions_over_time[0][2],\n        positions_over_tim