### argparse: Python's built-in module for parsing command-line arguments; it lets users customize the simulation without editing the code
### PyTorch: GPU-accelerated tensor math for ODE solving, plus autograd (automatic differentiation) for gradient-based optimization
#### Purpose: Imports the libraries needed for argument parsing (argparse), numerical computation (torch), ODE solving (torchdiffeq), and plotting (matplotlib).

In [2]:
import argparse
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint, odeint_event

### The comments in the code below explain what each part does

In [None]:
class BouncingBallExample(nn.Module):
    """Bouncing ball in free fall whose ground collisions are handled as ODE events.

    Every physical quantity is stored as an ``nn.Parameter``. ``nn.Parameter``
    is a ``torch.Tensor`` subclass that registers itself with the enclosing
    ``nn.Module``, so it appears in ``model.parameters()``, is visible to
    optimizers (e.g. Adam, SGD) and can receive gradients.

    The ODE state is the tuple ``(pos, vel, log_radius)``. The radius is kept
    in log-space (``exp(log_radius)`` is always positive) and is carried in
    the state with a zero derivative so that the event function can read it.
    """

    def __init__(self, radius=0.2, gravity=9.8, adjoint=False):
        """Create the model.

        Args:
            radius: Ball radius; stored internally as ``log(radius)``.
            gravity: Gravitational acceleration (applied as ``-gravity``).
            adjoint: If true, ``odeint_adjoint`` is used as the solver
                interface for event finding; otherwise plain ``odeint``.
        """
        super().__init__()
        # Integer inputs (e.g. gravity=10) would give an int64 tensor, which
        # cannot require gradients; promote them to a floating dtype first.
        self.gravity = nn.Parameter(self._as_float_tensor([gravity]))
        self.log_radius = nn.Parameter(torch.log(self._as_float_tensor([radius])))
        self.t0 = nn.Parameter(torch.tensor([0.0]))
        self.init_pos = nn.Parameter(torch.tensor([10.0]))
        self.init_vel = nn.Parameter(torch.tensor([0.0]))
        self.absorption = nn.Parameter(torch.tensor([0.2]))
        self.odeint = odeint_adjoint if adjoint else odeint

    @staticmethod
    def _as_float_tensor(value):
        """Convert ``value`` to a tensor, promoting non-float dtypes to the default float dtype."""
        tensor = torch.as_tensor(value)
        if not (tensor.is_floating_point() or tensor.is_complex()):
            tensor = tensor.to(torch.get_default_dtype())
        return tensor

    def forward(self, t, state):
        """Return the time derivative of ``state`` (free-fall dynamics).

        Args:
            t: Current time (unused; the dynamics are autonomous).
            state: Tuple ``(pos, vel, log_radius)``.

        Returns:
            Tuple ``(dpos, dvel, dlog_radius)`` where ``dpos = vel``,
            ``dvel = -gravity`` and ``dlog_radius`` is all zeros.
        """
        pos, vel, log_radius = state
        dpos = vel
        dvel = -self.gravity
        return dpos, dvel, torch.zeros_like(log_radius)

    def event_fn(self, t, state):
        """Return ``pos - radius``: positive in mid-air, negative inside the ground."""
        pos, _, log_radius = state
        return pos - torch.exp(log_radius)

    def get_initial_state(self):
        """Return ``(t0, (init_pos, init_vel, log_radius))``."""
        state = (self.init_pos, self.init_vel, self.log_radius)
        return self.t0, state

    def state_update(self, state):
        """Return the post-collision state for the given pre-collision ``state``.

        The velocity is reversed and scaled by ``1 - absorption``; the radius
        is passed through unchanged.
        """
        pos, vel, log_radius = state
        # Nudge the position by a small eps so the event function is not
        # triggered again immediately when integration restarts.
        pos = pos + 1e-7
        vel = -vel * (1 - self.absorption)
        return (pos, vel, log_radius)

    def get_collision_times(self, nbounces=1):
        """Return a list with the times of the first ``nbounces`` collisions.

        Each collision time is found with ``odeint_event`` using
        ``self.event_fn``; integration then restarts from the updated state.
        """
        event_times = []

        t0, state = self.get_initial_state()

        for _ in range(nbounces):
            event_t, solution = odeint_event(
                self,
                state,
                t0,
                event_fn=self.event_fn,
                reverse_time=False,
                atol=1e-8,
                rtol=1e-8,
                odeint_interface=self.odeint,
            )
            event_times.append(event_t)

            # The last entry of each solution component is used as the state
            # at the event; apply the bounce and continue from there.
            state = self.state_update(tuple(s[-1] for s in solution))
            t0 = event_t

        return event_times

    def simulate(self, nbounces=1):
        """Simulate ``nbounces`` bounces and return the dense path.

        Note that the dense path is always integrated with the module-level
        ``odeint``, regardless of the ``adjoint`` flag given at construction.

        Returns:
            Tuple ``(times, trajectory, velocity, event_times)``: the first
            three are flattened tensors, ``event_times`` is the list returned
            by ``get_collision_times``.
        """
        event_times = self.get_collision_times(nbounces)

        # get dense path
        t0, state = self.get_initial_state()
        trajectory = [state[0][None]]
        velocity = [state[1][None]]
        times = [t0.reshape(-1)]
        for event_t in event_times:
            start, end = float(t0), float(event_t)
            # Roughly 50 samples per unit of time. The grid is created with
            # t0's dtype/device so it can be concatenated with t0 and event_t
            # even after the module has been moved or cast.
            interior = torch.linspace(
                start,
                end,
                int((end - start) * 50),
                dtype=t0.dtype,
                device=t0.device,
            )[1:-1]
            # float() yields plain Python numbers, so the linspace endpoints
            # are dropped and the t0 / event_t tensors themselves are used.
            tt = torch.cat([t0.reshape(-1), interior, event_t.reshape(-1)])
            solution = odeint(self, state, tt, atol=1e-8, rtol=1e-8)

            # Skip index 0: it duplicates the last point already recorded.
            trajectory.append(solution[0][1:])
            velocity.append(solution[1][1:])
            times.append(tt[1:])

            state = self.state_update(tuple(s[-1] for s in solution))
            t0 = event_t

        return (
            torch.cat(times),
            torch.cat(trajectory, dim=0).reshape(-1),
            torch.cat(velocity, dim=0).reshape(-1),
            event_times,
        )
