In [None]:
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly_resampler import FigureResampler
import plotly.io as pio
# Render Plotly figures with the "notebook_connected" renderer by default.
pio.renderers.default = "notebook_connected"
from mpl_toolkits import mplot3d
import os
from IPython.display import clear_output
import csv

# Global NumPy seed. None seeds from fresh OS entropy on every run (non-reproducible);
# set SEED to an int to make any stochastic results repeatable.
SEED = None
np.random.seed(SEED)

In [None]:
def create_sun(size, clr, dist=0, opacity=1, resolution=100):
    """
    Create a single-colour sphere surface trace (used to draw the sun).

    Parameters
    ----------
    size: float
        radius of the sphere
    clr: str
        colour applied to the whole surface (used at both ends of the colourscale)
    dist: float
        offset of the sphere's centre along the x axis (default 0)
    opacity: float
        surface opacity passed to plotly (default 1)
    resolution: int
        number of samples for each of the two angles theta and phi
        (default 100, the previously hard-coded value)

    Returns
    -------
    go.Surface
        the sphere trace, with its colour bar hidden
    """
    # Azimuthal (theta) and polar (phi) angle samples
    theta = np.linspace(0, 2 * np.pi, resolution)
    phi = np.linspace(0, np.pi, resolution)

    # Parametric sphere; each outer product is a (resolution, resolution) grid
    x0 = dist + size * np.outer(np.cos(theta), np.sin(phi))
    y0 = size * np.outer(np.sin(theta), np.sin(phi))
    z0 = size * np.outer(np.ones(resolution), np.cos(phi))

    # Both colourscale stops use the same colour, so a colour bar would carry no information
    return go.Surface(
        x=x0,
        y=y0,
        z=z0,
        colorscale=[[0, clr], [1, clr]],
        opacity=opacity,
        showscale=False,
    )

def create_photon(x, y, z, clr='white', wdth=2):
    """
    Build a 3-D scatter trace that draws a photon's track.

    Parameters
    ----------
    x: np.ndarray
        successive photon positions along the x dim
    y: np.ndarray
        successive photon positions along the y dim
    z: np.ndarray
        successive photon positions along the z dim
    clr: str
        line colour (default 'white')
    wdth: int
        line width (default 2)

    Returns
    -------
    go.Scatter3d
        the track trace; markers are given size 0.1, presumably so they stay
        visually negligible next to the line
    """
    line_style = {"color": clr, "width": wdth}
    marker_style = {"size": 0.1}
    return go.Scatter3d(x=x, y=y, z=z, line=line_style, marker=marker_style)

def plot_photon_and_sun(photon: list, cut_down_size=None, mode=None, plot_sun=True, sun_size=0, static=None,
                        axis_padding=1_000_000):
    """
    Plot a photon's path, optionally with the sun, as a 3-D Plotly figure.

    Parameters
    ----------
    photon: list | np.ndarray
        the photon's position history; it is converted with np.array and
        indexed as [:, 0], [:, 1], [:, 2], so every entry needs x, y, z
    cut_down_size: tuple(int, int) | None
        (start, stop) slice of the history to plot, e.g. (0, 1000) plots the
        first 1000 positions; None (or any falsy value) plots everything
    mode: str | None
        passed as `mode` to FigureResampler.show_dash, e.g. 'inline'
    plot_sun: bool
        whether to overlay a translucent yellow sphere for the sun
    sun_size: float
        radius of the sun; each axis spans +/-(sun_size + axis_padding).
        0 means "use the global R"
    static: str | None
        if given, write the figure to '{static}.png' instead of showing it
    axis_padding: float
        extra extent added beyond sun_size on every axis (default 1_000_000)
    """
    photon_track = np.array(photon)

    if sun_size == 0:
        # NOTE(review): R is a module-level global not defined in the cells visible here;
        # presumably the solar radius — confirm it is set before calling.
        sun_size = R

    if cut_down_size:
        photon_track = photon_track[cut_down_size[0]:cut_down_size[1]]

    # Same symmetric range on all three scene axes
    axis_range = [-sun_size - axis_padding, sun_size + axis_padding]
    scene = go.layout.Scene(
        xaxis=dict(nticks=8, range=axis_range),
        yaxis=dict(nticks=8, range=axis_range),
        zaxis=dict(nticks=8, range=axis_range),
    )

    layout = go.Layout(
        autosize=False,
        width=700,
        height=700,
        margin=go.layout.Margin(l=50, r=50, b=100, t=100, pad=4),
        scene=scene,
    )

    # create figure resampler, for more efficient plotting
    fig = FigureResampler(go.Figure(layout=layout))

    if plot_sun:
        fig.add_trace(create_sun(sun_size, '#ffff00', opacity=0.2))

    fig.add_trace(
        create_photon(photon_track[:, 0], photon_track[:, 1], photon_track[:, 2], clr='red')
    )

    if static is not None:
        # Only PNG is written; HTML export is disabled. Re-enable with
        # fig.write_html(f"{static}.html") if an interactive file is needed.
        print("Exporting to png...")
        fig.write_image(f"{static}.png", 'png')
    else:
        fig.show_dash(mode=mode)

def pyplot_photon_path(photon: "Photon"):
    """
    Plot the photon path as a static matplotlib 3-D scatter.
    Not interactive, but easier to render for large datasets.

    Parameters
    ----------
    photon: Photon
        object exposing a `history` sequence of positions; after np.array
        conversion, columns 0-2 are plotted as x, y, z
    """
    # The annotation is a string because Photon is not imported or defined in the
    # cells above; an unquoted name would raise NameError when this def runs.
    photon_track = np.array(photon.history)

    # Create exactly one 3-D axes on a fresh figure
    _fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    ax.scatter3D(photon_track[:, 0], photon_track[:, 1], photon_track[:, 2])
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    plt.show()

In [None]:
# Plot average model — TODO: not implemented yet (this cell is currently empty)


In [None]:
# Plot the linearly decreasing density model
data = np.load("photon_histories/linear_density.npy")

# Fail fast with a clear message if the saved history is not a 2-D array with
# at least x, y, z columns, since plot_photon_and_sun indexes columns 0-2.
assert data.ndim == 2 and data.shape[1] >= 3, f"unexpected photon history shape {data.shape}"

# The largest cut-down window used so far has been just under 6_000_000 points.
# To export a PNG instead of showing the interactive view, pass
# static="images/large-scale-test" (an earlier export used cut_down_size=(0, 100_000)).
plot_photon_and_sun(data, plot_sun=True, cut_down_size=(0, 5_500_000))

In [None]:
# Plot the discrete model over the first 5_500_000 recorded positions.
# NOTE(review): p_2 is not defined in any cell visible here; it must already exist
# in the kernel for this cell to run — confirm before Restart & Run All.
plot_photon_and_sun(
    p_2.history,
    cut_down_size=(0, 5_500_000),
    plot_sun=True,
)