### Install Package

In [None]:
# install dependency libraries
!pip install pytest-cov 
!pip install pytest
!pip install torch 
!pip install astroquery 
!pip install astropy
!pip install scipy
!pip install requests

In [None]:
# install our library
!pip install -i https://test.pypi.org/simple/ skywalker-team23==0.0.8

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
from skywalker_team23.data_retrieval import GetData
from skywalker_team23.preprocessing import CleanSpectralData
from skywalker_team23.data_augmentation import DataAugmentation
from skywalker_team23.cross_match import CrossMatch_Gaia
from skywalker_team23.classification import SpectralClassifier
from skywalker_team23.visualization import SpectralVisualizer

### Demo of Data Retrieval and Cross-Matching

First, we use the GetData class from the data_retrieval module to query SDSS and download data. Its retrieve_sdss_data method downloads the fluxes and also returns metadata such as class, RA, and Dec.
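
Each SDSS spectrum is identified by the plate, MJD, and fiber ID selected in the query below, and run2d names the pipeline reduction version. As a small sketch, this shows the SDSS convention for naming the per-spectrum FITS file from those three numbers. Whether `retrieve_sdss_data` builds file names this way internally is an assumption, and `sdss_spec_filename` is a hypothetical helper.

```python
def sdss_spec_filename(plate: int, mjd: int, fiberid: int) -> str:
    """Hypothetical helper: SDSS per-spectrum file name spec-PLATE-MJD-FIBERID.fits."""
    # Plate and fiber ID are zero-padded to 4 digits, MJD to 5 digits.
    return f"spec-{plate:04d}-{mjd:05d}-{fiberid:04d}.fits"


# Example identifiers; in practice they come from the plate, mjd, fiberid columns of df.
print(sdss_spec_filename(plate=266, mjd=51630, fiberid=1))  # -> spec-0266-51630-0001.fits
```

Padding widens automatically for five-digit plate numbers, so the same format string also covers later plates.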

In [None]:
data_retriever = GetData()
df = data_retriever.retrieve_sdss_data(
    sql_query="""
                SELECT TOP 20 s.fiberid, s.plate, s.mjd, s.run2d, s.class
                FROM PhotoObj AS p
                JOIN SpecObj AS s ON s.bestobjid = p.objid
                """
)
# Uncomment to cache this table as CSV (the classifier demo below reads ./demo_data.csv)
# df.to_csv("demo_data.csv")


In [None]:
print(df.columns)
display(df)

Users can also run custom SQL queries:

In [None]:
data_retriever.execute_custom_query("""SELECT TOP 10 objID FROM PhotoObj AS p""")

With this dataframe, we can also cross-match a patch of sky with Gaia using the cross_match module:

In [None]:
xmatch = CrossMatch_Gaia()

# Take the RA and Dec of one SDSS object returned by GetData
target_ra = df["ra"].iloc[1]
target_dec = df["dec"].iloc[1]
angular_range = 20

# Look up Gaia source IDs within the given angular distance of the target
print(f"target ra: {target_ra}, target dec: {target_dec}")
match_df, source_ids = xmatch.match_coords(target_ra, target_dec, angular_range)
display(match_df)
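
`match_coords` returns Gaia sources within `angular_range` of the target. The notebook does not state the unit of `angular_range` or how the search is carried out. The sketch below only illustrates the geometry of such a cone search, using the haversine great-circle distance and assuming a radius in arcseconds. `angular_separation_deg` and `cone_search` are hypothetical helpers.

```python
import numpy as np


def angular_separation_deg(ra1, dec1, ra2, dec2):
    """Great-circle separation in degrees (haversine formula); inputs in degrees."""
    ra1, dec1, ra2, dec2 = map(np.radians, (ra1, dec1, ra2, dec2))
    sin_ddec = np.sin((dec2 - dec1) / 2.0)
    sin_dra = np.sin((ra2 - ra1) / 2.0)
    a = sin_ddec**2 + np.cos(dec1) * np.cos(dec2) * sin_dra**2
    # Clip guards against tiny floating-point excursions outside [0, 1].
    return np.degrees(2.0 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0))))


def cone_search(target_ra, target_dec, cat_ra, cat_dec, radius_arcsec):
    """Indices of catalog sources within radius_arcsec of the target (assumed unit)."""
    sep_arcsec = angular_separation_deg(target_ra, target_dec, cat_ra, cat_dec) * 3600.0
    return np.flatnonzero(sep_arcsec <= radius_arcsec)


# Toy catalog: one source 10 arcsec north of the target, one about 1 degree away.
cat_ra = np.array([150.0, 151.0])
cat_dec = np.array([2.0 + 10.0 / 3600.0, 2.0])
idx = cone_search(150.0, 2.0, cat_ra, cat_dec, radius_arcsec=20)
print(idx)  # -> [0]
```

The haversine form stays numerically stable for the very small separations typical of cross-matching.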

We can then use those matches to retrieve relevant astrophysical parameters from Gaia for that part of the sky.

In [None]:
astro_params = xmatch.get_astrophysical_params(source_ids)
display(astro_params)

ceph_params = xmatch.get_cepheid_star_param(source_ids)
display(ceph_params)


### Demo of Pairing Retrieved Data with Preprocessing

To make this data usable, we pass it to the preprocessing module, which converts the raw query dataframe into a dataframe better suited to computational work. The same pipeline forms the backbone of the classifier code.
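
SDSS spectra do not all share the same wavelength sampling, so they must be resampled onto one grid before they can be stacked into arrays. This is what `align_wavelengths(num_wl=...)` and `interpolate_flux(lam)` in the next cell do. A minimal sketch of the idea follows, using linear interpolation (`np.interp`) over the wavelength range common to all spectra. The interpolation scheme and the choice of the overlap range are assumptions, not necessarily what `CleanSpectralData` does, and `align_to_common_grid` is a hypothetical helper.

```python
import numpy as np


def align_to_common_grid(lams, fluxes, num_wl=1000):
    """Resample spectra with different wavelength sampling onto one shared grid.

    The grid spans only the overlap of all spectra, so nothing is extrapolated.
    """
    lo = max(lam.min() for lam in lams)
    hi = min(lam.max() for lam in lams)
    grid = np.linspace(lo, hi, num_wl)
    # np.interp requires each wavelength array to be increasing.
    aligned = np.vstack([np.interp(grid, lam, flux) for lam, flux in zip(lams, fluxes)])
    return grid, aligned


# Two toy spectra with different, partially overlapping wavelength coverage.
lam_a = np.linspace(3800, 9200, 4000)
lam_b = np.linspace(3600, 10400, 4600)
grid, aligned = align_to_common_grid(
    [lam_a, lam_b], [np.ones_like(lam_a), 2 * np.ones_like(lam_b)], num_wl=1000
)
print(grid[0], grid[-1], aligned.shape)  # -> 3800.0 9200.0 (2, 1000)
```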

In [None]:
# CleanSpectralData can be initialized with the retrieved dataframe or with a path to a saved CSV file.
# Here we pass the dataframe directly.
cleaned_data = CleanSpectralData(dataframe=df)

# First, get the initial data: aligned onto a common wavelength grid and reformatted, but otherwise unmodified
_ = cleaned_data.align_wavelengths(num_wl=1000)
# Alignment can also be done by interpolating onto an explicit wavelength grid:
# lam = np.linspace(4000, 6000, 100)
# cleaned_data.interpolate_flux(lam)

init_data = cleaned_data.data.copy()
init_lam = init_data["lam"]
init_flux = init_data["flux"]

# Remove flux outliers with the interquartile-range (IQR) method
data2 = cleaned_data.remove_flux_outliers_iqr().copy()
lam2 = data2["lam"]
flux2 = data2["flux"]

# Apply a redshift correction
data3 = cleaned_data.correct_redshift()
lam3 = data3["lam"]
flux3 = data3["flux"]

# Get the normalized flux and an inferred continuum
normalized_fluxes = cleaned_data.get_normalize_flux(update_df=True)
inferred_cont = cleaned_data.get_inferred_continuum()

# Plot the first two spectra after each processing step
fig, ax = plt.subplots(1, 5, figsize=(20, 4))
for i in range(2):
    ax[0].plot(init_lam[i], init_flux[i], '-')
    ax[1].plot(lam2[i], flux2[i], '-')
    ax[2].plot(lam3[i], flux3[i], '-')
    ax[3].plot(lam3[i], normalized_fluxes[i], '-')
    ax[4].plot(lam3[i], inferred_cont[i], '-')

titles = ["Initial data (aligned)", "Outliers removed", "Redshift corrected", "Normalized", "Inferred continuum (normalized)"]
for axi, title in zip(ax, titles):
    axi.set_xlabel("Wavelength (Å)")
    axi.set_title(title)
ax[0].set_ylabel(r"Flux (10$^{-17}$ erg/cm$^2$/s/Å)")
plt.tight_layout()
plt.show()
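
The cleaning steps above correspond to standard operations, sketched below in NumPy under stated assumptions. The IQR rule flags pixels outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR]; here they are refilled by linear interpolation so the grid stays intact. The redshift correction maps observed to rest-frame wavelengths with lambda_rest = lambda_obs / (1 + z). A running median stands in as the continuum estimate used for normalization. The exact thresholds, fill strategy, and continuum model behind `remove_flux_outliers_iqr`, `correct_redshift`, and `get_inferred_continuum` are not shown in this notebook, and every function name here is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def clip_outliers_iqr(lam, flux, k=1.5):
    """Refill pixels outside [Q1 - k*IQR, Q3 + k*IQR] by linear interpolation."""
    q1, q3 = np.percentile(flux, [25, 75])
    iqr = q3 - q1
    good = (flux >= q1 - k * iqr) & (flux <= q3 + k * iqr)
    # Interpolating from the unflagged pixels keeps the wavelength grid unchanged.
    return np.interp(lam, lam[good], flux[good])


def to_rest_frame(lam_obs, z):
    """Rest-frame wavelengths: lam_rest = lam_obs / (1 + z)."""
    return lam_obs / (1.0 + z)


def running_median_continuum(flux, window=51):
    """Crude continuum estimate: running median over an odd-sized window."""
    half = window // 2
    padded = np.pad(flux, half, mode="edge")  # repeat edge values at both ends
    return np.median(sliding_window_view(padded, window), axis=-1)


# Toy spectrum: linear continuum rising from 4 to 6 with one cosmic-ray-like spike.
toy_lam = np.linspace(4000, 6000, 1000)
toy_flux = 4.0 + (toy_lam - 4000) / 1000
toy_flux[300] = 500.0

clean = clip_outliers_iqr(toy_lam, toy_flux)  # spike replaced by its neighbours' trend
continuum = running_median_continuum(clean)
normalized = clean / continuum                # featureless toy, so close to 1 everywhere
lam_rest = to_rest_frame(toy_lam, z=0.1)
print(clean[300], normalized.min(), normalized.max())
```

A single global IQR cut can also clip genuine strong emission or absorption lines, which is one reason to compare the first two plotted panels.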

## Module for classification

Next, we use the classification module to distinguish among stars, galaxies, and quasars (QSOs).

In [None]:
# We want more training data, so let's load more.
# The classifier can query the data on its own, for example:
# classifier = SpectralClassifier(datapath=None, num_spectra=10, num_wl=500, classifier_layers=[32, 32])

# but it is better to load a larger, previously saved CSV file
classifier = SpectralClassifier(datapath="./demo_data.csv", num_spectra=10, num_wl=500, classifier_layers=[32, 32])
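
The `classifier_layers=[32, 32]` argument suggests a fully connected network with two hidden layers of 32 units acting on the `num_wl=500` resampled fluxes, with one output per class. The sketch below shows the forward pass of such a network in plain NumPy: ReLU hidden layers, a softmax output giving class probabilities, and an argmax giving the label, analogous to the `(prob, predicted_label)` pair returned by `transform_predict` later. The architecture, activations, and class ordering are assumptions. The notebook installs PyTorch, but this sketch uses untrained random weights and leaves training out; `init_mlp` and `predict_proba` are hypothetical helpers.

```python
import numpy as np

CLASSES = np.array(["GALAXY", "QSO", "STAR"])  # SDSS SpecObj class labels (order assumed)


def init_mlp(num_wl, hidden_layers, num_classes, seed=0):
    """Random (untrained) weights for num_wl -> hidden_layers... -> num_classes."""
    rng = np.random.default_rng(seed)
    sizes = [num_wl, *hidden_layers, num_classes]
    return [
        (rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def predict_proba(params, x):
    """Forward pass: ReLU hidden layers, softmax output; x has shape (n, num_wl)."""
    for w, b in params[:-1]:
        x = np.maximum(x @ w + b, 0.0)
    w, b = params[-1]
    logits = x @ w + b
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


mlp_params = init_mlp(num_wl=500, hidden_layers=[32, 32], num_classes=3)
toy_spectra = np.random.default_rng(1).normal(size=(4, 500))  # 4 toy input spectra
toy_prob = predict_proba(mlp_params, toy_spectra)
toy_labels = CLASSES[toy_prob.argmax(axis=1)]
print(toy_prob.shape, toy_labels)
```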


In [None]:
# Train the classifier, then plot its training accuracy
classifier.train(epochs=100, verbose=True)
classifier.plot_train_accuracy()
plt.show()

In [None]:
# Try out prediction on newly retrieved, real SDSS spectra rather than made-up curves.
# Note: TOP without ORDER BY gives no guarantee that these rows differ from the training data.
data_retriever = GetData()
load_num_spec = 20
df = data_retriever.retrieve_sdss_data(
    sql_query=f"""
        SELECT TOP {load_num_spec} s.fiberid, s.plate, s.mjd, s.run2d, s.class
        FROM PhotoObj AS p
        JOIN SpecObj AS s ON s.bestobjid = p.objid
        """
)
lam = df["lam"].values
flux = df["flux"].values

prob, predicted_label = classifier.transform_predict(lam, flux)
print("Predicted Labels: ", predicted_label)

In [None]:
true_labels = df["class"].values
correct = [p == t for p, t in zip(predicted_label, true_labels)]
print("Correct Prediction Boolean: ", correct)
print("Fraction correct: ", np.mean(correct))
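
A per-element boolean list becomes hard to read as the number of spectra grows. A confusion matrix summarizes the same comparison by counting, for each true class, how often each class was predicted; its trace divided by the total is the accuracy. Below is a minimal sketch with toy labels (not real results), assuming the predicted labels are strings of the same form as `df["class"]`. If `transform_predict` returns integer indices instead, they would need mapping first. `confusion_matrix` here is a hypothetical helper, not the scikit-learn function of the same name.

```python
import numpy as np


def confusion_matrix(true_labels, predicted_labels, classes):
    """Count matrix: rows index the true class, columns the predicted class."""
    index = {c: i for i, c in enumerate(classes)}
    cm = np.zeros((len(classes), len(classes)), dtype=int)
    for t, p in zip(true_labels, predicted_labels):
        cm[index[t], index[p]] += 1
    return cm


classes = ["GALAXY", "QSO", "STAR"]
toy_true = ["GALAXY", "GALAXY", "STAR", "QSO"]  # toy labels, not real results
toy_pred = ["GALAXY", "STAR", "STAR", "QSO"]
cm = confusion_matrix(toy_true, toy_pred, classes)
accuracy = np.trace(cm) / cm.sum()
print(cm)
print(f"accuracy = {accuracy:.2f}")  # -> accuracy = 0.75
```

Off-diagonal entries show which classes get confused with each other, which the overall accuracy alone hides.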

## Module for Data Augmentation

Here, we present the data augmentation module, which computes derivatives and fractional derivatives of each preprocessed spectrum and appends them as new features. These features can be used in further analysis of the spectral data.
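
The notebook does not say how these derivatives are computed, so the following is a sketch under that caveat. An ordinary first derivative dF/dlambda can be taken with `np.gradient`. One common definition of a fractional derivative of order alpha on a uniform grid with spacing h is the Grünwald-Letnikov sum, D^alpha f(x_n) ≈ h^(-alpha) * sum_{k=0..n} w_k f(x_{n-k}), with w_0 = 1 and w_k = w_{k-1} * (1 - (alpha + 1)/k). Whether `DataAugmentation` uses this definition, and which orders it computes, are assumptions; the helper names are hypothetical. The Grünwald-Letnikov sum needs evenly spaced samples, which a grid built with `np.linspace` provides.

```python
import numpy as np


def first_derivative(lam, flux):
    """dF/dlambda via second-order central differences (one-sided at the ends)."""
    return np.gradient(flux, lam)


def gl_fractional_derivative(flux, h, alpha):
    """Grunwald-Letnikov derivative of order alpha for samples spaced h apart."""
    n = len(flux)
    w = np.empty(n)
    w[0] = 1.0
    for k in range(1, n):
        w[k] = w[k - 1] * (1.0 - (alpha + 1.0) / k)
    # Full convolution gives out[i] = sum_k w[k] * flux[i - k]; keep the first n terms.
    return np.convolve(flux, w)[:n] / h**alpha


toy_lam = np.linspace(4000, 6000, 1000)  # uniform grid
toy_flux = np.sin(toy_lam / 100.0)
h = toy_lam[1] - toy_lam[0]
d1 = first_derivative(toy_lam, toy_flux)
d_half = gl_fractional_derivative(toy_flux, h, alpha=0.5)
d_one = gl_fractional_derivative(toy_flux, h, alpha=1.0)  # reduces to a backward difference
```

As sanity checks, alpha = 0 returns the input unchanged and alpha = 1 gives the backward difference (f_n - f_{n-1}) / h for n >= 1; intermediate orders blend in a decaying memory of earlier samples.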

In [None]:
# DataAugmentation
augmentor = DataAugmentation(dataframe=cleaned_data.data)
augmentor.process_data()
augmented_data = augmentor.data

# Print out the data frame after augmentation
augmented_data.head()

## Module for visualization

Finally, the visualization module provides an interactive plot that lets users select wavelength regions and quantify the flux of spectral lines.

In [None]:
# Visualization
visualizer = SpectralVisualizer(dataframe=augmented_data)
visualizer.plot_spectral_visualization()
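
Quantifying a spectral line usually means integrating the continuum-subtracted flux over the selected wavelength window, F_line = ∫ (F_lambda - C_lambda) dlambda. How `plot_spectral_visualization` does this internally is not shown. The sketch below assumes a known continuum and uses the trapezoid rule, and `line_flux` is a hypothetical helper. With SDSS fluxes in units of 10^-17 erg/s/cm^2/Å, the result is in 10^-17 erg/s/cm^2.

```python
import numpy as np


def line_flux(lam, flux, continuum, lam_min, lam_max):
    """Trapezoid-rule integral of (flux - continuum) over [lam_min, lam_max]."""
    sel = (lam >= lam_min) & (lam <= lam_max)
    x = lam[sel]
    y = flux[sel] - continuum[sel]
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


# Toy emission line: Gaussian (amplitude 10, sigma 3 Å) on a flat continuum of 2.
toy_lam = np.linspace(6400, 6700, 3001)
toy_cont = np.full_like(toy_lam, 2.0)
toy_flux = toy_cont + 10.0 * np.exp(-0.5 * ((toy_lam - 6563.0) / 3.0) ** 2)

f = line_flux(toy_lam, toy_flux, toy_cont, 6533.0, 6593.0)
print(f"{f:.2f}")  # -> 75.20 (analytic value: 10 * 3 * sqrt(2*pi) ≈ 75.199)
```

The window here spans ±10 sigma, so truncation error is negligible; a narrower user-selected region would recover correspondingly less of the line.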