In [181]:
from __future__ import print_function

#Utility
import math
import numpy as np
import scipy.integrate

#Plotting & animation
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FFMpegWriter

#Astro
import astropy.constants as const
import astropy.units as u

#GUI
import ipywidgets as widgets
from IPython.display import display
from ipywidgets import interact, interactive, fixed, interact_manual
from ipywidgets import GridspecLayout

# %matplotlib osx 
# ^ UNCOMMENT THIS LINE IF USING MAC

%matplotlib qt 
# ^ USE THIS LINE IF USING WINDOWS (comment it out if you uncomment the macOS line above)

### Class design (work in progress; revisions probably needed; the classes are implemented below)

Things we will need to represent:
- Planets
- Stars
- Solar system (contains a star and a list of planets? That structure would also make it easier to extend later to a binary system or other systems with multiple stars)
- Transits (should just be a list representing flux over a time interval)
- Vectors (data abstraction -> make a class!)
- User inputs (keep this organized though)
- GUI: Widgets seems like possible solution? Tutorial: https://towardsdatascience.com/bring-your-jupyter-notebook-to-life-with-interactive-widgets-bc12e03f0916

- How to handle time? (By days? Can users speed up/slow it down?)

In [3]:
class Vector:
    '''A simple 2-D vector with Cartesian components x and y.'''

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        return f'<{self.x}, {self.y}>'

    def length(self):
        '''Return the Euclidean length (magnitude) of the vector.'''
        # math.hypot avoids the OverflowError that squaring very large
        # components with pow() would raise.
        return math.hypot(self.x, self.y)

    def angle(self):
        '''Return the angle from the +x axis in radians, in [-pi, pi].'''
        return math.atan2(self.y, self.x)

In [38]:
# Physics stuff. Functions for useful equations go here

def gravitation(rv, t, mass_star, mass_planet):
    '''
    Right-hand side of the two-body equation of motion, in the form expected
    by scipy.integrate.odeint.

    The state is the planet's position and velocity relative to the star, in
    simulation units where the gravitational constant G is 1:

        d(r)/dt = v
        d(v)/dt = -(mass_star + mass_planet) * r / |r|**3

    rv: state vector [x, y, vx, vy]
    t: time (unused, since the system is autonomous; odeint passes it anyway)
    mass_star, mass_planet: the two masses. Only their sum enters the
        equation, so the result does not depend on the order in which
        they are passed.

    Returns the time derivative [vx, vy, ax, ay] as a numpy array.
    '''
    state = np.asarray(rv, dtype=float)
    r1 = state[:2]
    v1 = state[2:4]
    r12 = np.linalg.norm(r1)

    # The relative acceleration in a two-body system depends on the total mass.
    dv1bydt = (mass_star + mass_planet) * (-r1) / r12**3
    dr1bydt = v1

    return np.concatenate((dr1bydt, dv1bydt))

In [177]:
class Planet:
    '''
    One planet of a SolarSystem.

    mass:          planet mass (originally documented as Earth masses;
                   NOTE(review): the setup cells pass values like 3e-5, which
                   look more like solar masses -- confirm the intended unit)
    radius:        planet radius in Earth radii (used for the transit depth)
    distance:      initial position relative to the star, [x, y]
                   (originally documented as au)
    init_velocity: initial velocity [vx, vy], stored as ``velocity``
                   (originally documented as m/s; NOTE(review): values such as
                   0.75 suggest simulation units instead -- confirm)
    color:         matplotlib color used when drawing the planet
    implemented:   whether the planet takes part in the simulation
    '''

    def __init__(self, mass, radius, distance, init_velocity, color, implemented=True):
        self.mass, self.distance, self.velocity = mass, distance, init_velocity
        self.implemented, self.color, self.radius = implemented, color, radius

    def __str__(self):
        fields = (self.mass, self.radius, self.distance, self.velocity)
        return 'Planet: Mass={}, Radius={}, Distance={}, Velocity={}'.format(*fields)

    # Plain mutators, kept for callers that prefer methods to attribute access.
    def set_mass(self, mass):
        self.mass = mass

    def set_distance(self, distance):
        self.distance = distance

    def set_implemented(self, status):
        self.implemented = status
    

In [178]:
class Star:
    '''
    The central star of a SolarSystem.

    mass:     mass of the star in solar masses
    radius:   radius of the star in Earth radii (e.g. about 109 for the Sun)
    position: [x, y] position of the star; defaults to the origin
    velocity: [vx, vy] velocity of the star; defaults to [0, 0] (stationary)
    name:     display name used by __str__
    '''

    def __init__(self, mass, radius, position=None, velocity=None, name='Sol'):
        self.mass = mass
        # Create fresh lists per instance. A mutable default argument such as
        # [0, 0] would be shared by every Star built without explicit values.
        self.position = [0, 0] if position is None else position
        self.velocity = [0, 0] if velocity is None else velocity
        self.name = name
        self.radius = radius

    def __str__(self):
        return f'Star {self.name}. Mass={self.mass}, Radius={self.radius}'

In [179]:
class SolarSystem:
    '''
    A star together with the planets that orbit it.

    star:    the central Star
    planets: list of Planet objects. Only planets whose ``implemented`` flag
             is true are considered by the query methods below.
    '''

    def __init__(self, star, planets=None):
        self.star = star
        # Avoid a shared mutable default: each system gets its own empty list.
        self.planets = [] if planets is None else planets

    def _implemented_planets(self):
        '''Return the planets currently marked as implemented, in list order.'''
        return [planet for planet in self.planets if planet.implemented]

    @staticmethod
    def _radial_distance(planet):
        '''Magnitude of a planet's [x, y] position, i.e. its distance from the origin where the star sits.'''
        return math.hypot(planet.distance[0], planet.distance[1])

    def get_planet_masses(self):
        '''Masses of the implemented planets, in list order.'''
        return [planet.mass for planet in self._implemented_planets()]

    def get_planet_distances(self):
        '''
        Radial distances of the implemented planets, in list order.

        Uses the magnitude of the [x, y] position rather than only its x
        component, so a planet that starts on the -x side or off the x axis
        still reports a positive distance.
        '''
        return [self._radial_distance(planet) for planet in self._implemented_planets()]

    def most_massive_planet(self):
        '''The first implemented planet with the largest mass, or None if no planet is implemented.'''
        return max(self._implemented_planets(), key=lambda planet: planet.mass, default=None)

    def farthest_planet(self):
        '''The first implemented planet farthest from the star, or None if no planet is implemented.'''
        return max(self._implemented_planets(), key=self._radial_distance, default=None)
            
        

## SETTING UP THE SOLAR SYSTEM

In [280]:
p1 = Planet(mass=3 * 10**(-5), radius=1.0, distance=[1, 0], init_velocity=[0, 0.75], color='blue')  # starts at x = 1

In [281]:
p2 = Planet(mass=5 * 10**(-5), radius=3.0, distance=[2, 0], init_velocity=[0, 0.75], color='green')  # starts at x = 2

In [282]:
p3 = Planet(mass=7 * 10**(-5), radius=6.0, distance=[2.5, 0], init_velocity=[0, 0.75], color='purple')  # starts at x = 2.5

In [283]:
p4 = Planet(mass=10 * 10**(-5), radius=8.0, distance=[3, 0], init_velocity=[0, 0.75], color='cyan')  # starts at x = 3

In [284]:
p5 = Planet(mass=13 * 10**(-5), radius=10.0, distance=[7, 0], init_velocity=[0, 0.5], color='red', implemented=True)  # starts at x = 7

In [285]:
sun = Star(mass=1, radius=109.0)  # one solar mass, about 109 Earth radii

In [286]:
# Assemble the star and the five planets into a single system object
test_system = SolarSystem(star=sun, planets=[p1, p2, p3, p4, p5])

## SET PLANET PROPERTIES

In [188]:
# Controls for the five planets: an enable checkbox plus mass, velocity and
# distance sliders per planet. The widgets are built by small factories
# instead of 20 copy-pasted constructor calls.

def _planet_checkbox(number):
    '''Unchecked checkbox labelled "Planet <number>" that enables that planet.'''
    # Checkbox has no continuous_update option; the original passed a
    # misspelled `continous_update` kwarg, which had no effect.
    return widgets.Checkbox(
        value=False,
        description=f'Planet {number}',
        disabled=False,
        indent=False,
    )


def _planet_slider(value, low, high, step, description, readout_format):
    '''Horizontal FloatSlider that reports its value only when released.'''
    return widgets.FloatSlider(
        value=value,
        min=low,
        max=high,
        step=step,
        description=description,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format=readout_format,
    )


def _mass_slider():
    '''Mass slider, 0.5 to 50.0 in steps of 0.5.'''
    return _planet_slider(0.5, 0.5, 50.0, 0.5, 'Mass:', '.1f')


def _velocity_slider():
    '''Velocity slider, 0.1 to 1.0 in steps of 0.05.'''
    # The value is explicit here; previously it was left unset and clamped to min.
    return _planet_slider(0.1, 0.1, 1.0, 0.05, 'Velocity:', '.2f')


def _distance_slider():
    '''Distance slider, 0.5 to 3.0 in steps of 0.1.'''
    return _planet_slider(1, 0.5, 3.0, 0.1, 'Distance From The Sun:', '.1f')


planet_1, planet_2, planet_3, planet_4, planet_5 = (_planet_checkbox(n) for n in range(1, 6))
mass_1, mass_2, mass_3, mass_4, mass_5 = (_mass_slider() for n in range(5))
velocity_1, velocity_2, velocity_3, velocity_4, velocity_5 = (_velocity_slider() for n in range(5))
distance_1, distance_2, distance_3, distance_4, distance_5 = (_distance_slider() for n in range(5))

In [189]:
# Lay the controls out as a 5 x 4 table: one row per planet, with columns
# [enable checkbox | mass | velocity | distance]. Cells are filled column by column.
_control_columns = (
    (planet_1, planet_2, planet_3, planet_4, planet_5),
    (mass_1, mass_2, mass_3, mass_4, mass_5),
    (velocity_1, velocity_2, velocity_3, velocity_4, velocity_5),
    (distance_1, distance_2, distance_3, distance_4, distance_5),
)

grid = GridspecLayout(5, 4, height='300px')
for _col, _column_widgets in enumerate(_control_columns):
    for _row, _control in enumerate(_column_widgets):
        grid[_row, _col] = _control

In [203]:
def f(planet_1, planet_2, planet_3, planet_4, planet_5):
    '''Print the on/off states of the five planet checkboxes as one tuple.'''
    print(tuple([planet_1, planet_2, planet_3, planet_4, planet_5]))

# interactive_output passes each control's value to f under the matching
# keyword, so these keys must equal f's parameter names.
_checkbox_keys = ('planet_1', 'planet_2', 'planet_3', 'planet_4', 'planet_5')
_checkbox_widgets = (planet_1, planet_2, planet_3, planet_4, planet_5)
out = widgets.interactive_output(f, dict(zip(_checkbox_keys, _checkbox_widgets)))

display(grid, out)

GridspecLayout(children=(Checkbox(value=True, description='Planet 1', indent=False, layout=Layout(grid_area='w…

Output()

In [223]:
# Inspect the current state of the "Planet 1" checkbox (debugging aid).
planet_1.value

True

In [204]:
# Copy the checkbox states onto the planets: planet k takes part in the
# simulation only if the "Planet k" checkbox is ticked.
for _idx, _checkbox in enumerate((planet_1, planet_2, planet_3, planet_4, planet_5)):
    test_system.planets[_idx].implemented = _checkbox.value

## SIMULATION SETUP

In [287]:
# Set up & solve ODE for motion of each planet in the solar system
# Automatically set timespan to run the simulation, so that we get at least 1 full orbit of each planet
# Sets plot dimensions automatically for simulation display

# todo: the time scale factor and the number of points could be chosen more
# carefully instead of being fixed.
TIME_PER_UNIT_DISTANCE = 300
N_TIME_STEPS = 5000

planet_distances = test_system.get_planet_distances()
if not planet_distances:
    raise ValueError('No planets are implemented; enable at least one planet before running the simulation.')
max_distance = max(planet_distances)

# time to run simulation should depend on how far the farthest planet is
time_span = np.linspace(0, max_distance * TIME_PER_UNIT_DISTANCE, N_TIME_STEPS)
# NOTE(review): the limits are asymmetric (-10x to +5x the farthest distance); confirm this is intended.
axes_limits = [-max_distance * 10, max_distance * 5]
lightcurve = []
time_dim = []

# Keep exactly one entry per planet in test_system.planets (None for planets
# that are not implemented). The animation looks solutions up by the planet's
# index, so skipping entries would misalign them.
all_orbit_solutions = []
for planet in test_system.planets:
    if not planet.implemented:
        all_orbit_solutions.append(None)
        continue
    init_params = np.concatenate((planet.distance, planet.velocity)).astype(float)  # [x, y, vx, vy]
    # NOTE(review): args are (planet.mass, star.mass), the reverse of
    # gravitation()'s (mass_star, mass_planet) parameter names; check
    # gravitation() before changing this order.
    sol = scipy.integrate.odeint(gravitation, init_params, time_span,
                                 args=(planet.mass, test_system.star.mass))
    all_orbit_solutions.append(sol[:, :2])  # keep only the x, y columns

In [288]:
# Initialize writer
metadata = dict(title='Orbit Test', artist='Matplotlib')
# NOTE(review): FFMpegWriter's bitrate is in kilobits per second, so 200000 is
# roughly 200 Mbit/s. Confirm this is intended.
writer = FFMpegWriter(fps=50, metadata=metadata, bitrate=200000) # change fps for different frame rates
# The figure is created in the animation cell (plt.subplots) and saved at
# dpi=200 there. Creating one here as well only produced an extra figure that
# was never drawn on.

## ANIMATION

In [289]:
# SAVE AS MP4 (will be saved in whatever directory you are working in)

def _draw_orbit_frame(ax, system, solutions, frame, limits):
    '''Draw each implemented planet's path up to `frame`, its current position, and the star.'''
    ax.clear()
    for planet_idx, planet in enumerate(system.planets):
        if not planet.implemented:
            continue
        # solutions is indexed by the planet's position in system.planets.
        planet_sols = solutions[planet_idx]
        ax.plot(planet_sols[:frame, 0], planet_sols[:frame, 1], color=planet.color, alpha=0.5)  # path
        ax.scatter(planet_sols[frame, 0], planet_sols[frame, 1], color=planet.color, marker="o", s=20, zorder=5)  # planet
    ax.scatter(0, 0, color="orange", marker="*", s=50, zorder=5)  # star
    ax.set_xlim(limits[0], limits[1])
    ax.set_ylim(limits[0], limits[1])
    ax.set_title('Solar System Animation')


def _flux_at_frame(system, solutions, frame, tolerance=0.1):
    '''Relative stellar flux at `frame`: 1.0 minus the depth of every planet counted as transiting.'''
    # 1. A transit is counted when a planet is back near its starting position
    #    [orig_distance, 0], within +/- tolerance in x and y.
    # 2. The flux is 1.0 when nothing transits.
    # 3. Each transit removes (planet radius / star radius)^2.
    flux_level = 1.0
    for planet_idx, planet in enumerate(system.planets):
        if not planet.implemented:
            continue
        planet_sols = solutions[planet_idx]
        x, y = planet_sols[frame, 0], planet_sols[frame, 1]
        if abs(x - planet.distance[0]) < tolerance and abs(y) < tolerance:
            flux_level -= (planet.radius / system.star.radius) ** 2
    return flux_level


fig, (ax1, ax2) = plt.subplots(2, 1)

# Start a fresh lightcurve on every run, so re-running this cell does not
# append to data left over from a previous run.
lightcurve = []
time_dim = []

with writer.saving(fig, "orbit_test_18.mp4", dpi=200):
    for i in range(len(time_span)):
        # Animation of orbiting planets
        _draw_orbit_frame(ax1, test_system, all_orbit_solutions, i, axes_limits)

        # Animation of lightcurve with transits
        ax2.clear()
        flux_level = _flux_at_frame(test_system, all_orbit_solutions, i)
        lightcurve.append(flux_level)
        time_dim.append(i)  # the x axis is the frame index, not physical time

        ax2.scatter(i, flux_level, color='orange', marker='o', s=10, zorder=5)
        ax2.plot(time_dim, lightcurve, color='red', alpha=0.3)

        ax2.set_xlim(0, len(time_span))
        ax2.set_ylim(0.98, 1.005) # TODO: set this automatically based on planet radii, instead of fixed
        ax2.set_xlabel('Time')
        ax2.set_ylabel('Flux')
        ax2.set_title('Lightcurve')

        plt.draw()
        plt.pause(0.01)
        writer.grab_frame()