In [155]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import scipy as sc
import sympy as sp
from IPython import display
from matplotlib.animation import Animation, FuncAnimation, writers
from mayavi import mlab
from mpl_toolkits.mplot3d import Axes3D

# Load the four acceleration expressions (consumed by Plot_Geodesic, which
# substitutes M and lambdifies them), in the order they were dumped.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
with open("equatorial_coordinate_geodesic.pkl", "rb") as pkl_file:
    dVt, dVr, dVth, dVp = (pickle.load(pkl_file) for _ in range(4))

dV = [dVt, dVr, dVth, dVp]

In [174]:
class Plot_Geodesic:
    """Integrate and visualise equatorial orbits around a black hole of mass ``M``.

    ``geodesics`` is a sequence of four sympy expressions ordered like
    ``dVt, dVr, dVth, dVp``, written in the symbols ``r, theta, phi, u1, u2, u3``
    and the mass symbol ``M``. The mass is substituted once and each expression
    is compiled with ``sympy.lambdify`` into ``f(r, theta, phi, u1, u2, u3)``.

    Orbits are integrated in a fixed plane: ``theta`` and its velocity are
    frozen, and only the ``r`` and ``phi`` accelerations drive the motion.
    The horizon radius used everywhere is ``2 * M``.
    """

    # Integration stops once r falls to the horizon plus this margin.
    HORIZON_MARGIN = 2e-15
    # The deepest point of the Mayavi embedding surface is drawn at this depth.
    EMBEDDING_DEPTH = 8

    def __init__(self, geodesics, M=1):
        self.mass = M
        self.geodesics = [
            geodesic.subs(sp.Symbol("M"), self.mass) for geodesic in geodesics
        ]

        # All compiled functions share the same positional argument order.
        state_symbols = sp.symbols("r theta phi u1 u2 u3")

        def compile_rhs(expression):
            return sp.lambdify(state_symbols, expression, "numpy")

        self.dvt = compile_rhs(self.geodesics[0])
        self.dvr = compile_rhs(self.geodesics[1])
        self.dvth = compile_rhs(self.geodesics[2])
        self.dvp = compile_rhs(self.geodesics[3])

    def __horizon_radius(self):
        """Return the horizon radius ``2 * M`` shared by every drawing/stop rule."""
        return 2 * self.mass

    def __black_hole(self):
        """Return the ``(x, y, z)`` mesh of a 101x101 sphere of radius ``2 * M``."""
        radius = self.__horizon_radius()

        polar, azimuth = np.mgrid[0 : np.pi : 101j, 0 : 2 * np.pi : 101j]
        bh_x = radius * np.sin(polar) * np.cos(azimuth)
        bh_y = radius * np.sin(polar) * np.sin(azimuth)
        bh_z = radius * np.cos(polar)

        return bh_x, bh_y, bh_z

    def __geodesic(self, t, state):
        """Right-hand side of the first-order planar system, for ``solve_ivp``.

        ``state`` is ``(r, theta, phi, V1, V2, V3)`` with ``dr = V1`` and
        ``dphi = V3``. ``theta`` and ``V2`` have zero derivatives, so the motion
        stays in the plane fixed by the initial ``theta``. ``t`` is unused: the
        compiled expressions do not take it as an argument.
        """
        r, theta, phi, V1, V2, V3 = state

        # Evaluate at the state's actual angles (any sin/cos(theta) terms must
        # see the orbit's plane); u2 = 0 matches the frozen theta.
        dV1 = self.dvr(r, theta, phi, V1, 0, V3)
        dV3 = self.dvp(r, theta, phi, V1, 0, V3)

        return np.array([V1, 0, V3, dV1, 0, dV3])

    def __orbit(self, y0, lim_t=10000):
        """Integrate from ``y0`` and return the orbit as Cartesian ``(x, y, z)``.

        The solution is sampled at ``int(lim_t) + 1`` evenly spaced points on
        ``[0, lim_t]`` and integration stops early (terminal event) once ``r``
        falls to ``2 * M + HORIZON_MARGIN``. ``z`` is all zeros.
        """
        from scipy.integrate import solve_ivp

        stop_radius = self.__horizon_radius() + self.HORIZON_MARGIN

        def reached_horizon(t, y):
            # Continuous signed distance: solve_ivp locates events as zero
            # crossings with a root finder, which a boolean step handles poorly.
            return y[0] - stop_radius

        reached_horizon.terminal = True
        reached_horizon.direction = -1  # trigger only while moving inwards

        solution = solve_ivp(
            self.__geodesic,
            (0.0, lim_t),
            y0,
            # SciPy raises any rtol below 100 * machine epsilon to that value
            # (with a warning); ask for the floor directly.
            rtol=100 * np.finfo(float).eps,
            atol=1e-15,
            method="RK45",
            t_eval=np.linspace(0, lim_t, int(lim_t) + 1),
            events=reached_horizon,
        )

        r = solution.y[0]
        phi = solution.y[2]
        return r * np.cos(phi), r * np.sin(phi), np.zeros_like(r)

    def __plot_black_hole_matplotlib(self):
        """Create a 3-D matplotlib figure containing the horizon sphere.

        Returns ``(fig, ax)``. Switches to the interactive ``widget`` (ipympl)
        backend first, so it must run inside an IPython kernel.
        """
        # Same call IPython generates for the `%matplotlib widget` line magic.
        get_ipython().run_line_magic("matplotlib", "widget")
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")

        bh_x, bh_y, bh_z = self.__black_hole()
        ax.plot_surface(bh_x, bh_y, bh_z, color="k")
        return fig, ax

    def plot_geodesic_matplotlib(self, y0, lim_t=10000, x_lim=10, y_lim=10):
        """Plot the orbit starting at ``y0`` around the horizon; return the 3-D axes.

        The axes span ``[-x_lim, x_lim]`` in x and ``[-y_lim, y_lim]`` in y.
        """
        _, ax = self.__plot_black_hole_matplotlib()
        x, y, z = self.__orbit(y0, lim_t)
        ax.plot(x, y, z, linewidth=0.2)

        ax.set_title("solve_ivp")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.view_init(elev=10, azim=-45)

        ax.grid(False)
        ax.set_xlim(-x_lim, x_lim)
        ax.set_ylim(-y_lim, y_lim)
        ax.set_aspect("equal")
        return ax

    def animate_geodesic_matplotlib(
        self, y0, frames=50, lim_t=10000, x_lim=10, y_lim=10
    ):
        """Display the first ``frames`` orbit samples as an embedded HTML5 video.

        A single blue marker steps through the orbit, one sample per frame,
        100 ms apart. ``frames`` is capped at the number of samples available,
        which is smaller than requested when the orbit reaches the horizon.
        ``to_html5_video`` needs a matplotlib movie writer (typically ffmpeg).
        """
        fig, ax = self.__plot_black_hole_matplotlib()
        x, y, z = self.__orbit(y0, lim_t)
        frames = min(frames, len(x))

        ax.grid(False)
        ax.set_xlim(-x_lim, x_lim)
        ax.set_ylim(-y_lim, y_lim)
        ax.set_aspect("equal")
        ax.view_init(elev=90, azim=90)

        (planet,) = ax.plot([x[0]], [y[0]], [z[0]], "bo")

        def animation_function(ii):
            # Move the one marker rather than adding a new line every frame.
            planet.set_data_3d([x[ii]], [y[ii]], [z[ii]])
            return (planet,)

        anim = FuncAnimation(
            fig, animation_function, frames=frames, repeat=False, interval=100
        )
        display.display(display.HTML(anim.to_html5_video()))

    def __embedding_scalar(self, r):
        """Return ``sqrt(1 / (1 - 2M / r))`` elementwise (inf at, NaN inside, the horizon)."""
        return np.sqrt(1 / (1 - self.__horizon_radius() / r))

    def __mayavi_surface(self, outer_radius=10):
        """Build the funnel-shaped surface drawn beneath the Mayavi orbit.

        Returns ``(x, y, z, scalar, scale)``:
        - ``x, y``: a 101 (angle) x 400 (radius) polar mesh from
          ``1.005 * 2M`` to ``outer_radius``;
        - ``scalar``: ``sqrt(1 / (1 - 2M / r))`` (non-finite values set to 0),
          used for colouring;
        - ``z``: ``(1 - scalar) * scale``;
        - ``scale``: chosen so the deepest point sits at ``-EMBEDDING_DEPTH``.
          Anything drawn on the surface must use this same ``scale``.
        """
        inner_radius = 1.005 * self.__horizon_radius()
        surface_radius, surface_phi = np.meshgrid(
            np.geomspace(inner_radius, outer_radius, num=400),
            np.linspace(0, 2 * np.pi, 101),
        )

        surface_x = surface_radius * np.cos(surface_phi)
        surface_y = surface_radius * np.sin(surface_phi)

        surface_scalar = self.__embedding_scalar(np.sqrt(surface_x**2 + surface_y**2))
        surface_scalar[~np.isfinite(surface_scalar)] = 0

        depth = 1 - surface_scalar
        scale = -self.EMBEDDING_DEPTH / depth.min()
        surface_z = depth * scale

        return surface_x, surface_y, surface_z, surface_scalar, scale

    def animate_geodesic_mayavi(self, y0, lim_t=10000, x_lim=10, y_lim=10):
        """Animate the orbit in Mayavi, lying on the embedding surface.

        The surface extends to ``max(x_lim, y_lim)``. The orbit's heights use
        the surface's own formula and scale (clamped to the surface floor).
        The tube is then revealed one sample per animation step
        (``delay=1000``) while a red point marks the current position.

        NOTE(review): the saved output of this notebook shows a macOS
        HIToolbox crash when this runs from Jupyter; Mayavi usually needs a
        GUI event loop in the kernel (e.g. ``%gui qt``) — confirm.
        """
        fig = mlab.gcf()
        mlab.clf()

        bh_x, bh_y, bh_z = self.__black_hole()
        surface_x, surface_y, surface_z, surface_scalar, scale = self.__mayavi_surface(
            outer_radius=max(x_lim, y_lim)
        )
        x, y, _ = self.__orbit(y0, lim_t)

        # Lift the orbit onto the surface with the identical height mapping.
        z = (1 - self.__embedding_scalar(np.sqrt(x**2 + y**2))) * scale
        floor = surface_z.min()
        z[z < floor] = floor

        mlab.mesh(bh_x, bh_y, bh_z, color=(0, 0, 0))
        mlab.mesh(
            surface_x,
            surface_y,
            surface_z,
            scalars=surface_scalar,
            opacity=0.4,
            colormap="blue-red",
        )
        orbit = mlab.plot3d(x, y, z, tube_radius=0.05, opacity=1, color=(0.3, 0.4, 0.5))
        planet = mlab.points3d(x[0], y[0], z[0], color=(1, 0.01, 0.01))

        @mlab.animate(delay=1000)
        def anim():
            fig = mlab.gcf()
            for ii in range(len(x)):
                planet.mlab_source.trait_set(x=x[ii], y=y[ii], z=z[ii])
                if ii > 0:
                    orbit.mlab_source.reset(x=x[:ii], y=y[:ii], z=z[:ii])
                yield
                fig.scene.reset_zoom()

        anim()
        mlab.view(0, 180)
        mlab.show()

In [175]:
# Plotter for a black hole of unit mass (horizon radius 2M = 2).
a = Plot_Geodesic(dV, M=1)

In [176]:
# State (r, theta, phi, dr/ds, dtheta/ds, dphi/ds), s = solver integration parameter.
initial_state = np.array([4, np.pi / 2, 0, -0.1, 0, 0.143498197])
a.animate_geodesic_mayavi(initial_state)

1   HIToolbox                           0x00007ff80d9d2726 _ZN15MenuBarInstance22EnsureAutoShowObserverEv + 102
2   HIToolbox                           0x00007ff80d9d22b8 _ZN15MenuBarInstance14EnableAutoShowEv + 52
3   HIToolbox                           0x00007ff80d941cd7 _ZN15MenuBarInstance21UpdateAggregateUIModeE21MenuBarAnimationStylehhh + 1113
4   HIToolbox                           0x00007ff80d9d2173 _ZN15MenuBarInstance19SetFullScreenUIModeEjj + 175
5   AppKit                              0x00007ff806e994b7 -[NSApplication _setPresentationOptions:instance:flags:] + 1145
6   AppKit                              0x00007ff806cee165 -[NSApplication _updateFullScreenPresentationOptionsForInstance:] + 582
7   CoreFoundation                      0x00007ff803acd6d6 __CFNOTIFICATIONCENTER_IS_CALLING_OUT_TO_AN_OBSERVER__ + 137
8   CoreFoundation                      0x00007ff803b66cbc ___CFXRegistrationPost_block_invoke + 86
9   CoreFoundation                      0x00007ff803b66c13 _CFXR