# More Advanced Python Workshop

### Problem Definition
Neutron stars are ultra-compact astronomical objects left behind after a core-collapse supernova. When massive stars are in binaries, two neutron stars can be left behind in a binary system. These binary neutron stars spiral around each other and eventually merge. Because they are so dense and compact, binary neutron stars emit gravitational waves as they spiral in, and we are able to observe these waves with detectors on Earth.

Kilonova is the name for the optical light we observe following a binary neutron star merger. Because a binary neutron star merger is a rapid event, we can watch the optical emission following the merger evolve over human time scales. This is called a "light curve" and is a plot of the light we see, or "flux", as a function of time.

In the same directory as this Jupyter notebook, I've provided two things that we will use throughout this tutorial:
1. A comma-separated file of the light curve of the only known kilonova, called AT2017gfo. This file is called `AT2017gfo-lightcurve.csv`
2. A directory containing a grid of kilonova light curve models. These models are from [Kasen+17](https://arxiv.org/abs/1710.05463)

In this notebook, we will reproduce the model fitting results of [Kasen+17](https://arxiv.org/abs/1710.05463) by searching the model grid for the model that best fits the AT2017gfo light curve! This will tell us properties of the kilonova, such as the mass, velocity, and lanthanide fraction of the material ejected in the merger, that are otherwise not measurable.

### Learning Goals
By the end of this Jupyter notebook, I hope you will have built the following skills:
1. A better understanding of reading code documentation and applying the functions you find in it
2. Improved plotting skills (take some time to make your matplotlib plots pretty!)
3. How to "pair code", or work with a partner to write code
4. How to use Astropy Tables and Units (these are super super useful for astronomy!)

### Imports

In [None]:
!pip install "numpy<2" matplotlib astropy h5py lightcurve-fitting scipy extinction pyarrow

In [None]:
import os # for operating system related commands
import glob # for gathering lots of files matching a common path

# numpy and matplotlib, cause duh
import numpy as np
import matplotlib.pyplot as plt

# some astropy things
from astropy.table import Table
from astropy.io.misc.pyarrow.csv import read_csv
from astropy.cosmology import Planck18 as cosmo
from astropy import units as u

# some things for parsing the model files
import h5py
from lightcurve_fitting.filters import filtdict
from lightcurve_fitting.lightcurve import LC
from scipy.signal import resample
import extinction

### Part 1: Reading in the AT2017gfo light curve

Here is the documentation for reading in a CSV file into an astropy Table: https://docs.astropy.org/en/stable/api/astropy.io.misc.pyarrow.csv.read_csv.html

In [None]:
at2017gfo_table = # YOUR CODE HERE

print(at2017gfo_table.columns)

### Part 2: Plotting the light curve

You should now have the AT2017gfo Table loaded into a variable called `at2017gfo_table`. Here is a description of the relevant column names (you can safely ignore the others):
* `MJD` --> The Modified Julian Date, or MJD for short
* `Phase` --> The time since discovery of the transient event
* `Instrument` --> The instrument used for measuring this flux value
* `Telescope` --> The telescope used for measuring this flux value
* `Band` --> The filter used on the telescope for measuring this flux. We will focus on the "r-band", or the "red" filter, for this notebook.
* `mag` --> The magnitude value for this flux. If you do not know what this is, you should take some time to read [this wikipedia page](https://en.wikipedia.org/wiki/Magnitude_(astronomy)) and/or ask one of the graduate coordinators.
* `e_mag` --> The error on the magnitude.
* `Ref` --> The reference for this flux value
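As a quick refresher on why a smaller magnitude means a brighter object: magnitudes are a logarithmic flux scale, $m = -2.5\log_{10}(F/F_{\rm ref})$, so the difference between two magnitudes depends only on the ratio of the two fluxes. The minimal numpy sketch below illustrates this; the helper names `delta_mag` and `flux_ratio` are hypothetical and not part of any package used in this notebook.

```python
import numpy as np

def delta_mag(flux_ratio_value):
    """Magnitude difference m1 - m2 for a flux ratio F1 / F2."""
    return -2.5 * np.log10(flux_ratio_value)

def flux_ratio(dmag):
    """Flux ratio F1 / F2 for a magnitude difference m1 - m2 (the inverse of delta_mag)."""
    return 10 ** (-0.4 * dmag)

# An object 100 times brighter has a magnitude 5 *smaller*
print(delta_mag(100.0))   # -5.0
# Fading by 1 magnitude means keeping only ~40% of the flux
print(flux_ratio(1.0))    # ~0.398
```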

To access a column in the astropy table you do the following:
```
<table-name>["<column-name>"]
```
So, if we want to access the `MJD` column we would do
```
at2017gfo_table["MJD"]
```

To make sure you understand, access the `mag` column in the following cell:

In [None]:
mag_column = # YOUR CODE HERE

Now that we know how to access columns, let's briefly discuss boolean indexing. This allows you to pass boolean arrays as indices to the table to access only rows that satisfy those conditions. 

For example, if we want to access all of the data that was taken with the "Magellan" telescope, we would do the following:
```
magellan_data = at2017gfo_table[at2017gfo_table["Telescope"] == "Magellan"]
```

We will now do something similar to access all of the rows where `Band == "r"`. Please do this in the following cell and save the result in a variable called `data`:

In [None]:
data = # YOUR CODE HERE

Finally, now that we only have the r-band data for AT2017gfo, we want to plot a light curve for this event. As a reminder, that will be the `mag` column as a function of the `Phase` column. Please create this plot in the next cell:

_Hint 1_: Remember that magnitudes are "backwards". In other words, a smaller magnitude means brighter. What should we do to the y-axis to make this plot more intuitive?

_Hint 2_: Remember that there is a column with errors on the magnitude, `e_mag`. You should plot these as errorbars. Check out the docs for matplotlib's `errorbar`.

_Bonus_: Plot the light curve colored by the telescope name with a nice legend :)
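If you get stuck, here is a minimal sketch of `errorbar` with an inverted y-axis and a per-telescope legend. It uses made-up numpy arrays and two hypothetical telescopes, "A" and "B", in place of your `data` table; with the real data you would use the `Phase`, `mag`, `e_mag`, and `Telescope` columns instead.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up light curve points from two hypothetical telescopes (illustration only)
phase = np.array([0.5, 1.0, 1.5, 2.5, 3.0])
mag = np.array([17.4, 17.8, 18.3, 19.0, 19.4])
e_mag = np.array([0.05, 0.04, 0.06, 0.08, 0.10])
telescope = np.array(["A", "B", "A", "B", "A"])

fig, ax = plt.subplots()
for name in np.unique(telescope):
    mask = telescope == name  # boolean indexing, same idea as in Part 2
    # fmt="o" draws markers without connecting lines; capsize adds caps to the errorbars
    ax.errorbar(phase[mask], mag[mask], yerr=e_mag[mask], fmt="o", capsize=3, label=name)

ax.set_xlabel("Days Since Discovery")
ax.set_ylabel("Magnitude")
ax.invert_yaxis()  # smaller (brighter) magnitudes at the top
ax.legend(title="Telescope")
```

Looping over `np.unique` gives each telescope its own `errorbar` call, so matplotlib assigns each one a distinct color and legend entry automatically.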

In [None]:
fig, ax = plt.subplots()

ax.set_xlabel("Days Since Discovery")
ax.set_ylabel("Magnitude")
ax.invert_yaxis()

# YOUR CODE HERE

##############################################################################################################
# PAUSE, CHECKPOINT 1
Please show your light curve plot from the previous part to at least one graduate coordinator. You should also send this in slack for everyone to see your hard work!
###############################################################################################################

### Part 3: Reading and plotting the light curve models

All of the models in the Kasen+17 simulation grid are stored in [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files. While knowing how to work with these files is an important skill, it is beyond the scope of this notebook. Furthermore, the Kasen+17 models are stored in three dimensions: each model gives the specific luminosity (which we convert to flux below) as a function of 1) frequency and 2) time. This can make it difficult to parse them into a usable light curve (it took me an entire afternoon when I was originally doing this!). As a result, I've provided a Python convenience function below for reading and parsing _a single model file_.

Reading code is an important skill, so you should take a few minutes to read this function and understand how it works. If you have any questions about how this function works, please ask a graduate coordinator.
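Before reading the function, it may help to see its core idea on a toy array. In the sketch below, `Lnu_all` is a 2D array with one row per time step and one column per frequency (the layout the function below assumes for the model's luminosity data). Weighting by the filter transmission and summing over the frequencies inside the filter collapses that 2D array into a single light curve. The frequency grid, transmission values, filter edges, and luminosities are all invented for illustration.

```python
import numpy as np

# Invented grids: 4 time steps (days) and 6 frequencies (THz)
times = np.array([1.0, 2.0, 3.0, 4.0])
nu = np.array([300.0, 350.0, 400.0, 450.0, 500.0, 550.0])

# Specific luminosity with shape (n_times, n_freqs): rows are times, columns are frequencies
Lnu_all = np.outer(np.array([4.0, 3.0, 2.0, 1.0]), np.ones(nu.size))

# A made-up filter transmission evaluated at each model frequency
trans = np.array([0.0, 0.5, 1.0, 1.0, 0.5, 0.0])

# Keep only frequencies inside a made-up filter window, like `where_filt` in the function
where_filt = np.where((nu > 320.0) & (nu < 520.0))[0]

# Weight by the transmission (broadcast along the frequency axis), then sum over frequency
Lnu = np.sum((Lnu_all * trans)[:, where_filt], axis=1)
print(Lnu.shape)  # (4,): one value per time step, i.e. a light curve
print(Lnu)        # values 12, 9, 6, 3
```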

In [None]:
def read_kasen17_model(model_file:str, filtname="r", redshift=0.00984, ebv=0.109) -> Table:
    """
    Takes in a path to a model file as a string and returns an astropy table with the following columns:
    - filter --> The filter name
    - mag --> The magnitude corresponding to this flux, including Milky Way extinction
    - _unextincted_mag --> The magnitude before applying Milky Way extinction
    - dt --> The days since discovery, usually called "phase"

    Example:
    To read in a file from Kasen+17 use
    ```
    lc = read_kasen17_model("kasen_kilonova_models_2017/knova_d1_n10_m0.020_vk0.10_Xlan1e-3.0.h5")
    ```

    Args:
        model_file (str): The path to the Kasen+17 model file
        filtname (str): The filter name to grab the data for. This is important for 
                        applying the proper transmission curve.
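        redshift (float): The redshift of the transient that we will fit these models to.
                          Default is 0.00984 for AT2017gfo (from
                          https://www.wis-tns.org/object/2017gfo)
        ebv (float): The Milky Way color excess, E(B-V), toward the transient. This is
                     used to "redden" the model magnitudes. Default is 0.109.

    Returns:
        An astropy Table with the model light curve in the requested filter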
    """

    # read in the file; the `with` block closes it automatically once we've copied the data out
    with h5py.File(model_file, 'r') as fin:
        # parse various columns in the file, including the frequency
        # (but convert Hz to THz, hence the factor of 1e-12)
        nu = np.array(fin['nu'], dtype='d')*1e-12

        # and the time (convert seconds to days)
        times = np.array(fin['time'])
        times = times/3600.0/24.0

        # and the specific luminosity, with shape (number of times, number of frequencies)
        Lnu_all = np.array(fin['Lnu'], dtype='d')

    # now generate a boolean array that only takes the data in the frequency range corresponding to `filtname`
    freq_range = (
        (filtdict[filtname].freq_eff - filtdict[filtname].dfreq).value, 
        (filtdict[filtname].freq_eff + filtdict[filtname].dfreq).value
    )
    where_filt = np.where(
        (nu > min(freq_range)) * 
        (nu < max(freq_range))
    )[0]

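    # we now need the transmission curve for this filter, evaluated at each model frequency:
    # for every value in `nu` we pick the transmission at the nearest frequency in the
    # filter's transmission curve (a nearest-neighbor lookup, so one value per entry of `nu`)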
    trans_curve = filtdict[filtname].trans
    trans_resampled_idx = np.array([
        np.argmin(
            np.abs(
                t-trans_curve["freq"]
            )
        ) for t in nu
    ])
    trans_resampled = trans_curve["T"][trans_resampled_idx]
    
    # finally, we integrate ("sum") along the frequency axis in the luminosity array
    # while simultaneously applying the transmission function for this filter
    Lnu = np.sum(
        (Lnu_all * trans_resampled)[:,where_filt], 
        axis=1
    )

    # this gave us the light curve in luminosity space, but our data is in magnitudes!
    # so, we must convert using the luminosity distance.
    # But that means we need to find the luminosity distance from the redshift passed in,
    # and this is where the astropy cosmology package is super useful!
    lumdist = cosmo.luminosity_distance(redshift)

    # now we can convert to magnitudes using the normal equations,
    # with the filter's zero point taken from the lightcurve_fitting package
    Fnu = (Lnu*u.erg/u.s/u.Hz /(4*np.pi*lumdist**2)).to(u.erg/u.s/u.cm**2/u.Hz)
    zp = filtdict[filtname].fnu*u.W/u.m**2/u.Hz
    lc = Table({
        "dt":times,
        "mag":-2.5*np.log10(Fnu/zp),
        "filter":[filtname]*len(times)
    })

    # "redden" the models to simulate MW extinction
    R_V = 3.1
    A_V = R_V * ebv

    filt_wave_eff = (filtdict[filtname].freq_eff).to(u.AA, equivalencies=u.spectral()).value
    dust_law_mw = extinction.ccm89(np.array([filt_wave_eff]), A_V, R_V)
    lc["_unextincted_mag"] = lc["mag"]
    lc["mag"] = lc["mag"] + dust_law_mw[0]
    
    return lc 
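The last few steps of `read_kasen17_model` turn a specific luminosity into an apparent magnitude: $F_\nu = L_\nu / (4\pi d_L^2)$, then $m = -2.5\log_{10}(F_\nu / F_{\rm zp})$, and finally the Milky Way extinction $A$ (in magnitudes, from `extinction.ccm89`) is added on. Here is a stand-alone numpy sketch of that arithmetic. It assumes an AB zero point of 3631 Jy purely for illustration (the function itself uses the zero point stored in `lightcurve_fitting`'s `filtdict`), and the helper `lnu_to_mag`, the luminosity, the distances, and the extinction value are all made up.

```python
import numpy as np

MPC_IN_CM = 3.0857e24   # 1 megaparsec in cm
AB_ZP_CGS = 3631e-23    # assumed AB zero point: 3631 Jy = 3631e-23 erg/s/cm^2/Hz

def lnu_to_mag(Lnu_cgs, dist_mpc, A_band=0.0):
    """Apparent magnitude of a source with specific luminosity Lnu_cgs (erg/s/Hz)
    at luminosity distance dist_mpc (Mpc), dimmed by A_band magnitudes of extinction."""
    d_cm = dist_mpc * MPC_IN_CM
    Fnu = Lnu_cgs / (4 * np.pi * d_cm**2)        # flux density in erg/s/cm^2/Hz
    return -2.5 * np.log10(Fnu / AB_ZP_CGS) + A_band

L_model = 1e26  # made-up specific luminosity in erg/s/Hz
m_near = lnu_to_mag(L_model, 10.0)
m_far = lnu_to_mag(L_model, 100.0)
# 10x farther -> 100x less flux -> 5 magnitudes fainter
print(m_far - m_near)  # ≈ 5.0
# extinction simply adds to the magnitude, so the source looks fainter
print(lnu_to_mag(L_model, 10.0, A_band=0.3) - m_near)  # ≈ 0.3
```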

Now, use this function to read in _at least one_ file in the `kasen_kilonova_models_2017` directory. 

_Hint 1_: Use Python's built-in `help()` function (e.g. `help(read_kasen17_model)`) to read the docstring for this function!

In [None]:
"""TODO: YOUR CODE HERE"""

### Part 4: Plot a Model

Now that you've read in a model, plot the light curve for it! As a reminder, this will be `mag` vs. `dt`. Since this is simulated data, the model magnitudes have no errors, so we don't need to plot any errorbars!

In [None]:
fig, ax = plt.subplots()

ax.set_xlabel("Days Since Discovery")
ax.set_ylabel("Magnitude")
ax.invert_yaxis()

# YOUR CODE HERE

##############################################################################################################
# PAUSE, CHECKPOINT 2
Please show your light curve plot from the previous part to at least one graduate coordinator. You should also send this in slack for everyone to see your hard work!
###############################################################################################################

### Part 5: Sum of the Squares Residual

Now that we know how to read in both the Kasen+17 model files and the AT2017gfo light curve, we need to write some code to find the best-fitting model automatically. The simplest way to do this is to use a least squares algorithm: https://en.wikipedia.org/wiki/Least_squares.

If you go to that wikipedia link, you will see the following equation:
$$
S = \sum_i r_i^2
$$
where $r_i = \text{data}_i - \text{model}_i$ is the residual for the $i$-th data point and $S$ is the sum of the squared residuals.
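As a quick worked example, take the same numbers used in the test cell below: data $= [1, 2, 3]$ and model $= [4, 5, 6]$. Every residual is $-3$, so

$$
S = (1-4)^2 + (2-5)^2 + (3-6)^2 = 9 + 9 + 9 = 27
$$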

Please complete the following `sum_square_residual` function that I've begun for you:

_Hint 1_: Assume that both of the inputs are numpy arrays. This means that you **should not** need to write any loops!

_Hint 2_: Check out the numpy docs to see if they have anything for sums

_Bonus_: Write a docstring for this function (see the Google-style docstring in `read_kasen17_model` above for an example).

In [None]:
def sum_square_residual(data:np.ndarray, model:np.ndarray) -> float:
    residual = # TODO
    residual_squared = # TODO
    sum_residual_squared = # TODO
    return sum_residual_squared

Since we need this function working for later parts, please run the following test code to make sure that it works as expected! These are tests: `assert` checks that the statement following it is True and raises an `AssertionError` otherwise.

If any of these assertions raises an error, return to the cell above and fix your function.
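If you haven't used `assert` before, here is a tiny sketch of what happens when one fails: Python raises an `AssertionError`, carrying the optional message written after the comma. The function `check_addition` is purely illustrative.

```python
def check_addition():
    # This assertion is deliberately false, so it raises an AssertionError
    assert 1 + 1 == 3, "expected 1 + 1 to equal 3"

failure_message = None
try:
    check_addition()
except AssertionError as err:
    failure_message = str(err)

print("Test failed:", failure_message)  # Test failed: expected 1 + 1 to equal 3
```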

In [None]:
# TODO: Make sure this cell does not throw any errors when run
assert sum_square_residual(np.array([1,2,3]), np.array([4,5,6])) == 27
assert sum_square_residual(np.array([4,5,6]), np.array([1,2,3])) == 27
assert sum_square_residual(np.array([1,2,3]), np.array([1,2,3])) == 0

### Part 6: Finding the best Kasen+17 model

Now that we have a function to compute the sum of the square residual, we need to apply this to all of the models in `kasen_kilonova_models_2017` and find the model with the _least_ square residual. Please do this below:

_REMINDER_: This is supposed to be hard! Give it a try, since we believe that coming up with algorithms like this is important. But we will reveal some more hints over time to help get you all started. 

_Hint 1_: You can use the built-in `glob` library (best library name in python IMO!) to get a list of all of the file names in the `kasen_kilonova_models_2017` directory
```
all_model_files = glob.glob("kasen_kilonova_models_2017/*h5")
```
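One detail to watch out for: the model light curves are sampled at different times than the observations, so before computing a residual you need the model magnitude _at each observed phase_. One way to do that (our assumption, not the only option) is linear interpolation with `np.interp`. The sketch below runs the whole search on invented arrays: the `models` dictionary is a hypothetical stand-in for looping over `all_model_files` with `read_kasen17_model`, and the inline residual line is a stand-in for your `sum_square_residual`.

```python
import numpy as np

# Invented "observations": phases (days) and magnitudes
obs_phase = np.array([1.0, 2.0, 3.0, 4.0])
obs_mag = np.array([17.6, 18.5, 19.4, 20.3])

# Hypothetical stand-ins for model light curves: name -> (dt, mag)
model_dt = np.linspace(0.1, 10.0, 100)
models = {
    "slow_fade": (model_dt, 17.0 + 0.5 * model_dt),
    "fast_fade": (model_dt, 16.7 + 0.9 * model_dt),
    "no_fade": (model_dt, np.full_like(model_dt, 18.0)),
}

best_name, best_score = None, np.inf
for name, (dt, mag) in models.items():
    # Skip non-finite model points (e.g. mag = inf wherever the model luminosity is zero)
    ok = np.isfinite(mag)
    # Evaluate the model at each observed phase; np.interp holds the end values
    # constant outside the model's time range instead of extrapolating
    model_at_obs = np.interp(obs_phase, dt[ok], mag[ok])
    score = np.sum((obs_mag - model_at_obs) ** 2)  # stand-in for sum_square_residual
    if score < best_score:
        best_name, best_score = name, score

print(best_name)  # fast_fade
```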

In [None]:
# TODO: Find the model in the kasen_kilonova_models_2017 that has the lowest least squares residual

### Part 7: Plot the best-fitting model on top of the light curve

This should basically just be a combination of your other plotting code from above. Take some time to make it pretty and then send it in slack!

In [None]:
fig, ax = plt.subplots()

ax.set_xlabel("Days Since Discovery")
ax.set_ylabel("Magnitude")
ax.invert_yaxis()

# YOUR CODE HERE
# TODO: Plot the light curve data
# TODO: Plot the model on the same axis, on top of the light curve data

##############################################################################################################
# PAUSE, CHECKPOINT 3
Please show your light curve plot from the previous part to at least one graduate coordinator. You should also send this in slack for everyone to see your hard work!
###############################################################################################################

### Extra Challenge

If you've made it this far, congratulations! You are clearly a wizard at programming :) Here's a challenge for you to ponder and attempt:

We are currently only finding the best-fit model to the r-band light curve. However, there are a bunch of other filters in the `AT2017gfo-lightcurve.csv` file. In reality, we should be jointly fitting the light curve in three-dimensional space (time, frequency, and flux/magnitude), using _all_ of the data in that file to find the best-fit model across the entire spectrum. 
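One natural way to extend the single-band objective from Part 5 (a suggestion, not the only choice) is to sum the squared residuals over every filter $b$ as well as every observation $i$ in that filter, evaluating the model in the matching filter at each observed phase $t_{b,i}$:

$$
S = \sum_b \sum_i \left[ m_{b,i}^{\rm data} - m_b^{\rm model}(t_{b,i}) \right]^2
$$

Dividing each term by the squared magnitude error $\sigma_{b,i}^2$ from the `e_mag` column turns $S$ into the familiar $\chi^2$ statistic, which gives more weight to the better-measured points.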

This is quite difficult, but I'd like you to try to do this! Good luck!