# In [3]:
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import csv


# Run configuration: 'name' is the prefix of the output files written under data/.
args = dict(name='example')

def force_function(v):
    """Return the acceleration dv/dt for a given velocity ``v``.

    This force law depends on velocity only. A position-dependent law would
    need a second argument.

    Works element-wise on Python floats, numpy scalars and numpy arrays
    (``solve_ivp`` passes numpy values). It uses numpy's ``sin``/``cos``:
    the previous ``torch`` calls failed because ``torch`` is not imported,
    and ``torch.sin`` does not accept numpy inputs anyway.
    """
    # Base parabolic term, modulated by two small (+/-10 %) oscillations,
    # plus a quadratic drag-like term.
    modulation = 1 + np.sin(v / 20) / 10 + np.cos(v / 40) / 10
    return (v * (300 - v) / 1000) * modulation + (v / 70) ** 2


def f(t, z):
    """Right-hand side of the first-order system dz/dt = f(t, z).

    ``z`` is unpacked as (position, velocity). The returned derivative is
    ``[velocity, acceleration]``, with the acceleration obtained from
    ``force_function`` applied to the velocity alone. ``t`` is unused.
    """
    position, velocity = z  # position is unused by the velocity-only force law

    # To switch to a position-dependent force law, change this call to
    # force_function(position, velocity).
    acceleration = force_function(velocity)
    return [velocity, acceleration]


# In [None]:
def main():
    """Integrate the ODE system ``f`` from ``z0``, then save and plot the trajectory.

    All outputs go to the ``data/`` directory (created if missing) and are
    named after ``args['name']``:

    * ``data/<name>_parameters.txt``: N, dt, the time span and z0;
    * ``data/<name>.csv``: columns ``time, position, velocity``;
    * ``data/<name>.png``: position and velocity plotted against time.

    The parameters file is written before integrating. If ``solve_ivp``
    reports a non-zero status, its message is printed and neither the CSV
    nor the plot is produced.
    """
    import os

    # ``args`` is a dict, so the name must be read by key (``args.name``
    # raised AttributeError).
    name = args['name']
    # Create the output directory up front; open() would otherwise raise
    # FileNotFoundError when data/ does not exist.
    os.makedirs('data', exist_ok=True)

    # The end time is controlled by the number of points.
    N = 1000000
    # dt is hard-coded so the time step is always the same for the NN.
    dt = 1e-5

    z0 = [100.0, 0.0]  # initial condition: y=100, v=0

    t0 = 0.0
    tf = t0 + N * dt
    print(f't0: {t0}; tf: {tf}')

    t_span = [t0, tf]
    # Build the sample times from integer indices. np.arange with a float
    # step has no guaranteed length: it may add an extra point at (or, for
    # nonzero t0, slightly past) tf. This gives exactly N points, all <= tf.
    t_eval = t0 + dt * np.arange(N)

    with open(f'data/{name}_parameters.txt', 'w', newline='') as file:
        file.write(f'N: {N} \n')
        file.write(f'dt: {dt} \n')
        file.write(f't0: {t0}, tf= {tf} \n')
        file.write(f'z0: {z0} \n')

    # RK45 is solve_ivp's default method; it is spelled out here because
    # the plot labels below refer to it loosely as 'rk4'.
    sol = solve_ivp(fun=f, t_span=t_span, y0=z0, method='RK45',
                    t_eval=t_eval, vectorized=True)

    if sol.status != 0:
        print('Solver failed to converge!')
        print(sol.message)
        return

    y, v = sol.y  # rows of sol.y: position, velocity
    data = np.column_stack((sol.t, y, v))

    with open(f'data/{name}.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['time', 'position', 'velocity'])
        writer.writerows(data)

    # Start a fresh figure so repeated runs in one session don't overlay.
    plt.figure()
    plt.title(f'{name}')
    plt.plot(sol.t, y, label='rk4 - position')
    plt.plot(sol.t, v, label='rk4 - velocity')
    plt.xlabel('Time')
    plt.legend()
    plt.savefig(f'data/{name}.png')
    plt.show()

# def export limits


