# Object Oriented Programming: Quick-RAIL, A Simplified RAIL Pipeline
## 2024 University of Sydney Hunstead Tutorial 4
### Bryan Scott, CIERA | Northwestern University

This exercise was originally presented as part of LSST-DA Data Science Fellowship Program Session 21: Software Engineering and Databases, held at the University of Illinois Urbana-Champaign, Illinois, United States. It was originally produced by Olivia Lynn, LINCC Frameworks Software Engineer, and has been modified for the Hunstead Series by Bryan Scott (CIERA | Northwestern).

**R**edshift **A**ssessment and **I**nfrastructure **L**ayers (**RAIL**) is a [LINCC Frameworks](https://lsstdiscoveryalliance.org/programs/lincc-frameworks/) project developing comprehensive analysis infrastructure for validating galaxy redshifts estimated from photometry, for use in the LSST Dark Energy Science Collaboration (DESC) and more broadly within the Vera C. Rubin Observatory community.

As mentioned in Lecture 1, the choice of metrics and the sensitivity of results to errors in classification or parameter estimation are significant outstanding problems in data-driven astronomy. This notebook walks you through writing code that implements an analysis and validation pipeline.

This notebook draws heavily on RAIL's Degradation Demo notebook, which can be found [rendered on ReadTheDocs](https://rail-hub.readthedocs.io/projects/rail-notebooks/en/latest/rendered/creation_examples/degradation-demo.html) and in [notebook form on GitHub](https://github.com/LSSTDESC/rail/blob/main/examples/creation_examples/degradation-demo.ipynb).

In this notebook we will work with simulated galaxy redshifts and the corresponding photometric magnitudes in each of the six LSST bands (u, g, r, i, z, y). The goal is to understand how object-oriented programming works and how its core concepts can be used to write pipelines for working with realistic data.

In [None]:
# imports:
from numbers import Number
import os
import pickle  # to load the pre-generated truth table

import matplotlib.pyplot as plt  # for the error-vs-magnitude plot

## Read in truth data generated by PZFlow

pzflow is a package developed by John Franklin Crenshaw (University of Washington) and collaborators for modelling probability distributions with normalizing flows. By sampling from a flow that models the joint distribution of galaxy redshifts and photometry, one can generate a realistic mock catalog of redshifts and the corresponding magnitudes. Since pzflow does not currently support ARM processors, I have generated these samples for you and saved them as a truth table that you can read in here:

In [None]:
with open('samples_truth.pickle', 'rb') as handle:
    samples_truth = pickle.load(handle)

In [None]:
samples_truth
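We assume here that the truth table is a pandas DataFrame with a `redshift` column and one magnitude column per LSST band, which is the layout that the column reordering in `BadErrorModel` below relies on. Colors are simply differences between magnitudes in adjacent bands (for example u - g). The minimal sketch below computes them for a made-up two-galaxy table in that assumed layout; the values are illustrative, not drawn from `samples_truth`.

```python
import pandas as pd

# Hypothetical two-galaxy table with the same (assumed) columns as samples_truth
toy = pd.DataFrame({
    "redshift": [0.5, 1.2],
    "u": [25.0, 26.1], "g": [24.2, 25.0], "r": [23.5, 24.3],
    "i": [23.1, 23.8], "z": [22.9, 23.5], "y": [22.8, 23.3],
})

bands = "ugrizy"
# a color is the magnitude difference between adjacent bands: u-g, g-r, ...
colors = pd.DataFrame({
    f"{blue}-{red}": toy[blue] - toy[red]
    for blue, red in zip(bands[:-1], bands[1:])
})
print(colors.round(2))
```

Running the same dictionary comprehension on `samples_truth` should give one color column per adjacent band pair, provided the table really has this layout.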

## QUAIL Base Classes

We'll be using a few highly simplified versions of RAIL classes, namely:
- QuailStage (accompanied by a NothingStage, which demonstrates how our other degrader stages will inherit from it)
- DataStore
- Pipeline

In [None]:
from abc import ABC, abstractmethod

class QuailStage(ABC):
    """Abstract base class for all QUAIL pipeline stages."""

    def __init__(self, name):
        """Constructor.
        
        Parameters
        ----------
        name : str
            The (human-readable) name of the stage (this will be used in the Pipeline's __repr__).
        """
        self.name = name
        self.data_in = None
        self.data_out = None
        
    # @abstractmethod is a decorator: decorators wrap a function to change its behavior
    # without editing its source code. This one marks run as abstract, so QuailStage itself
    # cannot be instantiated; every subclass must implement its own run method.
    @abstractmethod
    def run(self):
        """To be implemented in subclasses. This should set the data_out attribute."""
        pass
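To see what `@abstractmethod` buys us, here is a small self-contained sketch. `Stage`, `Forgetful` and `Doubler` are hypothetical names invented for this illustration: `Stage` is a cut-down copy of `QuailStage`, `Forgetful` is a subclass that forgets to implement `run`, and `Doubler` implements it. Python refuses to instantiate the first two with a `TypeError`.

```python
from abc import ABC, abstractmethod


class Stage(ABC):
    """Hypothetical cut-down copy of QuailStage, so this cell runs on its own."""

    def __init__(self, name):
        self.name = name
        self.data_in = None
        self.data_out = None

    @abstractmethod
    def run(self):
        """Subclasses must override this and set self.data_out."""


class Forgetful(Stage):
    """A subclass that forgets to implement run()."""


class Doubler(Stage):
    def run(self):
        # sets data_out rather than returning it, as the base-class docstring asks
        self.data_out = [2 * x for x in self.data_in]


# Neither the abstract base nor an incomplete subclass can be instantiated
for cls in (Stage, Forgetful):
    try:
        cls("oops")
    except TypeError as err:
        print(f"{cls.__name__}: {err}")

doubler = Doubler("doubler")
doubler.data_in = [1, 2, 3]
doubler.run()
print(doubler.data_out)  # [2, 4, 6]
```

The exact wording of the `TypeError` message differs between Python versions, but in every case the error appears at construction time rather than later, when the pipeline tries to call the missing `run`.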

In [None]:
class NothingStage(QuailStage):
    """A stage that does nothing."""
    def __init__(self, name):
        """Constructor.
        
        Parameters
        ----------
        name : str
            The name of the stage.
        """
        super().__init__(name)

    def run(self):
        """Run the stage."""
        self.data_out = self.data_in

## Degrader: ErrorModel

We start by creating an incredibly naive error model, `BadErrorModel`, as a stand-in for RAIL's LSSTErrorModel. Following the NothingStage example above, the cell below already provides a constructor; you implement the `apply_errors` method. `apply_errors` should loop over the band columns (u, g, r, i, z, y) and, for each band, add an error column equal to a uniform fraction of that band's magnitude in `data_in` (for example, the same fixed percentage for every band). Remember that `data_in` is an attribute of the QuailStage class, which `BadErrorModel` inherits from. This will generate `u_err`, `g_err`, etc. columns for us, but there's no scientific basis for the values it generates.

The `run` method, which was left abstract in QuailStage, is filled in for you. As the docstring of the abstract method requires, it does not return anything: it sets the `data_out` attribute, here to the output of `apply_errors`.
In [None]:
class BadErrorModel(QuailStage):
    """A stage that applies bad errors."""
    def __init__(self, name):
        """Constructor.
        
        Parameters
        ----------
        name : str
            The name of the stage.
        """
        super().__init__(name)

    def apply_errors(self):
        """Apply the errors to the input data.
        
        Returns
        -------
        errors : pandas.DataFrame
            The input data with errors applied.
        """
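        # create a copy of the input data, so that data_in itself is left unchanged
        errors = self.data_in.copy()

        # you implement this: for each band, add a "<band>_err" column to errors
        for col in "ugrizy":
            # your code goes here (then delete the line below)
            raise NotImplementedError(f"apply_errors: add the {col}_err column")

        # reorder the columns so each band sits next to its error
        errors = errors[["redshift", "u", "u_err", "g", "g_err", "r", "r_err", "i", "i_err", "z", "z_err", "y", "y_err"]]
        return errors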

    def run(self):
        """Run the stage."""
        self.data_out = self.apply_errors()

    def plot(self):
        """Plot the truth data and the errors."""
        if self.data_out is None:
            raise ValueError("You must run the stage first.")
        
        fig, ax = plt.subplots(figsize=(5, 4), dpi=100)

        for band in "ugrizy":
            # pull out the magnitudes and errors
            mags = self.data_out[band].to_numpy()
            errs = self.data_out[band + "_err"].to_numpy()

            # sort both arrays by magnitude so each curve is drawn left to right
            order = mags.argsort()
            mags, errs = mags[order], errs[order]

            # plot errs vs mags
            ax.plot(mags, errs, label=band)

        ax.legend()
        ax.set(xlabel="Magnitude (AB)", ylabel="Error (mags)")
        plt.show()
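If you get stuck on `apply_errors`, here is a minimal sketch of one possible approach, written as a standalone function so that it runs on its own. It assumes that "uniform fraction" means one fixed fraction for every band; the function name `apply_fractional_errors` and the value `frac=0.1` are hypothetical choices with no physical basis. Inside `BadErrorModel`, the same loop would read from `self.data_in`.

```python
import pandas as pd


def apply_fractional_errors(data_in, frac=0.1, bands="ugrizy"):
    """Toy error model: each band's error is a fixed fraction of its magnitude.

    frac=0.1 is an arbitrary, hypothetical value with no physical basis.
    """
    errors = data_in.copy()  # work on a copy so the caller's DataFrame is untouched
    for col in bands:
        errors[col + "_err"] = frac * errors[col]
    # put each band next to its error column, after the redshift
    ordered = ["redshift"] + [c for band in bands for c in (band, band + "_err")]
    return errors[ordered]


toy = pd.DataFrame({
    "redshift": [0.5, 1.2],
    "u": [25.0, 26.0], "g": [24.0, 25.0], "r": [23.0, 24.0],
    "i": [22.5, 23.5], "z": [22.0, 23.0], "y": [21.5, 22.5],
})
with_errors = apply_fractional_errors(toy)
print(with_errors.columns.tolist())
```

Because the error grows linearly with magnitude, the plot method would draw straight lines, which is a quick visual check that the stage ran, and also a reminder of why this model is "bad".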

## Making a DataStore Object

At each stage of the pipeline, we will want to keep track of the state of the data we are working with. In this example, we will not modify the colors or redshifts directly, but in general, error model stages will modify the data so that estimators can be applied to determine how each source of error impacts scientific results. 

Write a DataStore class.

This should have a `__getattr__` method that lets you access the stored items as if they were attributes (`DS.key` instead of `DS["key"]`); I've filled this in for you. You write the `__repr__` method, which should return useful information about an instantiated object, such as each key and the shape of the data stored under it.

In [None]:
class DataStore(dict):
    def __init__(self):
        dict.__init__(self)
    
    def __getattr__(self, key): # Code copied from RAIL's DataStore 
        """Allow attribute-like parameter access"""
        try:
            return self.__getitem__(key)
        except KeyError as msg:
            # Kludge to get docstrings to work
            if key in ["__objclass__"]:  # pragma: no cover
                return None
            raise KeyError from msg
    
    # you fill this in 
    
    def __repr__(self): 
        """Return the DataStore keys and shapes in a human-readable format."""
        s = "DataStore\n"
        for key in self.keys():
            s += f" # you fill this in \n"
        return s
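Here is a minimal sketch of one possible `__repr__`, shown on a hypothetical dict subclass `ReprStore` so that it runs on its own. It assumes the stored values are pandas DataFrames or NumPy arrays, which expose a `.shape` attribute, and falls back to `None` for anything else.

```python
import numpy as np
import pandas as pd


class ReprStore(dict):
    """Hypothetical dict subclass showing one way to write a readable __repr__."""

    def __repr__(self):
        s = "DataStore\n"
        for key, value in self.items():
            # DataFrames and arrays expose .shape; other objects fall back to None
            shape = getattr(value, "shape", None)
            s += f"  {key}: {type(value).__name__}, shape={shape}\n"
        return s


store = ReprStore()
store["input_data"] = pd.DataFrame({"redshift": [0.1, 0.2], "u": [25.0, 24.0]})
store["flags"] = np.zeros(3)
print(store)
```

Printing the type and shape rather than the data itself keeps the summary short even when a DataFrame holds hundreds of thousands of galaxies.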
    


## Making the Pipeline

Next we will write a Pipeline class. Its constructor takes a DataStore object and a list of stage instances, each carrying its own name. Its run method should apply each stage in order, including the `BadErrorModel` stage we have written, and store the final result under the `output_data` key of the DataStore. It also has methods for inspecting the pipeline: a `__repr__` that lists the stages, and `get_stage`, which returns a stage by name.

To write the run method, loop over the stages stored by the constructor. For each stage, set its `data_in` to the `current_data` variable, run the stage, and store its `data_out` in the DataStore under the key `f"{stage.name}_data"` (for example, `bad_error_model_data`). Then set `current_data` to that stage's `data_out` before the next iteration of the loop.

In [None]:
class Pipeline:
    def __init__(self, data_store, stages):
        self.data_store = data_store
        self.stages = stages

    def run(self):
        current_data = self.data_store["input_data"]
        
        # you implement this 
        
        for stage in self.stages:
            print(f"Running stage: {stage.name}")
            
            # your code goes here 
            
        self.data_store["output_data"] = current_data

    def __repr__(self):
        s = "Pipeline\n"
        
        # you implement this 
        
        for stage in self.stages:
            s += "  # your code goes here\n"  # replace with a line describing `stage`
        return s
    
    def get_stage(self, stage_name):
        for stage in self.stages:
            if stage.name == stage_name:
                return stage
        print(f"Stage {stage_name} not found in pipeline.")
        return None

    def plot(self):
        pass

    def save(self):
        pass
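The run loop described above can be sketched on its own as follows. `AddOne`, `run_pipeline` and `describe` are hypothetical names for this illustration, the data store is a plain dict, and each toy stage adds 1 to a list; the stage interface (`name`, `data_in`, `data_out`, `run`) matches QuailStage.

```python
class AddOne:
    """Hypothetical toy stage with the same interface as a QuailStage subclass."""

    def __init__(self, name):
        self.name = name
        self.data_in = None
        self.data_out = None

    def run(self):
        # like QuailStage.run, this sets data_out instead of returning a value
        self.data_out = [x + 1 for x in self.data_in]


def run_pipeline(data_store, stages):
    """Sketch of the loop described above, using a plain dict as the data store."""
    current_data = data_store["input_data"]
    for stage in stages:
        stage.data_in = current_data                       # feed in the previous result
        stage.run()                                        # the stage sets data_out
        data_store[f"{stage.name}_data"] = stage.data_out  # keep every intermediate result
        current_data = stage.data_out                      # pass it on to the next stage
    data_store["output_data"] = current_data


def describe(stages):
    """Sketch of a Pipeline.__repr__ body: one line per stage."""
    s = "Pipeline\n"
    for stage in stages:
        s += f"  {stage.name} ({type(stage).__name__})\n"
    return s


stages = [AddOne("first"), AddOne("second")]
store = {"input_data": [0, 1, 2]}
run_pipeline(store, stages)
print(store)
print(describe(stages))
```

Storing each intermediate result under its own key is what later makes `DS.bad_error_model_data` available; the last stage's result is also stored under `output_data`.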


## Instantiate and run the pipeline

In [None]:
DS = DataStore()
DS["input_data"] = samples_truth

# you implement the stages definition

stages = [
         # what goes in this list? 
]

pipeline = Pipeline(DS, stages)

In [None]:
# Our DataStore before running the pipeline

DS

In [None]:
pipeline.run()

In [None]:
# Our DataStore after running the pipeline

DS

## Looking at the data

Poke around and see the changes you've made! Note that overriding `__getattr__` in our DataStore class lets us access stored data with `DS.key` syntax instead of `DS["key"]`.
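One subtlety: Python calls `__getattr__` only after ordinary attribute lookup has failed. `DS.key` therefore works for most keys, but a key that shares its name with a dict method (such as `keys`, `items` or `copy`) is only reachable as `DS["key"]`. The hypothetical `AttrDict` below mirrors the DataStore idea to show this. It raises `AttributeError` rather than `KeyError` for missing keys, a choice made for this sketch so that `hasattr()` returns `False`; the DataStore above, copied from RAIL, raises `KeyError`.

```python
class AttrDict(dict):
    """Hypothetical minimal dict with attribute access, mirroring DataStore."""

    def __getattr__(self, key):
        # only reached when normal attribute lookup (including dict methods) fails
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err


d = AttrDict(redshift=[0.3, 1.1])
d["keys"] = "data stored under 'keys'"

print(d.redshift)        # found via __getattr__
print(d["keys"])         # bracket access always reaches the stored value
print(callable(d.keys))  # True: dict.keys is found first, so __getattr__ never runs
```

This is also why stage names that double as DataStore keys should avoid dict method names.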

In [None]:
DS.keys()

In [None]:
DS.bad_error_model_data

## Plotting

We can retrieve our error model stage by name with `get_stage` and call its plot method like so:

In [None]:
pipeline.get_stage("bad_error_model").plot()  # Compare this to the plot in the RAIL degrader notebook,
                                              # where we see very different errors!