# Adapting GalacticFlow code

This notebook is meant to provide some instructions and tips on how to adapt the code for your own purposes.

In many cases the existing code may suffice, but in some cases you may want to make changes.

Although this code was not originally designed to be adaptable (e.g. dependencies are not perfectly separated), some effort has been made to make it easier to adapt.

## The basic structure

At its core, this model is of course a **Normalizing Flow**, most likely implemented as a PyTorch `nn.Module` subclass. However, many other steps are involved for GalacticFlow to properly learn the galaxies.

As this model works with physical data, there is a lot to process and prepare:

- Data loading/cleaning
- Choosing the right subset of data (Training/Validation, Milky Ways/All Galaxies, Star properties)
- **Normalizing the data**:
    - Denormalizing the data after sampling
    - Normalizing any conditions (=galactic parameters) before sampling
    - Respecting the Jacobian of the Normalization when evaluating the pdf

For this reason a **processor** class is used for all the models, to automate these steps.

The flow and the processor are then combined in the GalacticFlow API, which we recommend using (see below).
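To make the normalization steps listed above concrete, here is a minimal numpy sketch of the bookkeeping a processor has to do, assuming a simple per-component standardization (the actual normalization in GalacticFlow may differ). `StandardScalerSketch` and `logdet_log10_sketch` are hypothetical names; the latter only mirrors the idea behind the `"logdet_log10"` entry of the definition dict used later in this notebook.

```python
import numpy as np


class StandardScalerSketch:
    """Hypothetical per-component standardization, y = (x - mean) / std."""

    def fit(self, x):
        self.mean = x.mean(axis=0)
        self.std = x.std(axis=0)
        return self

    def normalize(self, x):
        # Applied to the training data and to conditions before sampling
        return (x - self.mean) / self.std

    def denormalize(self, y):
        # Inverse map, applied to samples drawn from the flow
        return y * self.std + self.mean

    def log_abs_det_jacobian(self):
        # dy_i/dx_i = 1/std_i, so log|det J| = -sum_i log(std_i) (constant in x)
        return -np.sum(np.log(self.std))

    def log_pdf_data(self, log_pdf_normalized):
        # Change of variables: log p_x(x) = log p_y(y) + log|det dy/dx|
        return log_pdf_normalized + self.log_abs_det_jacobian()


def logdet_log10_sketch(x):
    # For y = log10(x): dy/dx = 1 / (x ln 10), so log|dy/dx| = -ln(x) - ln(ln 10)
    return -np.log(x) - np.log(np.log(10))
```

Because the Jacobian of a standardization is constant, it only shifts the log-pdf, whereas a non-linear transformation such as `log10` contributes a term that depends on `x`.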

### The GalacticFlow API

For a detailed documentation of the API see `API_Workflow.ipynb`.

The API combines the **processor** and the **flow** into a single object, which can easily be loaded, saved, trained and evaluated.

Therefore all parameters (including the arguments to the processor and the flow) must be given at initialization and are saved with the model (see "definition dict").
The parameters stored there depend on the processor and flow used. If you change either of them, you may need to change the definition dict and also update their dedicated API loading/saving methods.
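The round trip through the definition dict can be pictured roughly as follows. This is a sketch under assumptions: that `func_handle` maps the strings in the definition dict to classes, that the flow's `load_API` receives the `"flow_hyper"` entries, and that saved weights are stored next to them under `"flow_dict"` together with a `"was_saved"` flag (the keys the `BetterFlow` example below reads). `DummyProcessor`, `DummyFlow`, `build` and `save` are hypothetical and not the actual code in `API.py`.

```python
import copy


class DummyProcessor:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class DummyFlow:
    def __init__(self, dim_notcond, dim_cond, n_layers):
        self.dim_notcond, self.dim_cond, self.n_layers = dim_notcond, dim_cond, n_layers
        self.weights = None  # stands in for a state_dict

    @classmethod
    def load_API(cls, definition):
        flow = cls(definition["dim_notcond"], definition["dim_cond"], definition["n_layers"])
        if definition.get("was_saved", False):
            flow.weights = definition["flow_dict"]
        return flow

    def save_API(self):
        return {"flow_dict": self.weights}


# Assumed registry: string names used in the definition dict -> classes
func_handle = {"DummyProcessor": DummyProcessor, "DummyFlow": DummyFlow}


def build(definition):
    processor = func_handle[definition["processor"]](**definition["processor_args"])
    flow = func_handle[definition["flow"]].load_API(definition["flow_hyper"])
    return processor, flow


def save(definition, flow):
    # Work on a copy so the original definition stays untouched
    saved = copy.deepcopy(definition)
    saved["flow_hyper"].update(flow.save_API())
    saved["flow_hyper"]["was_saved"] = True
    return saved
```

The design point is that the hyperparameters live in the definition dict from the start, so `save_API` only has to add what cannot be recomputed from them (here, the weights).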

We now describe how to adapt the normalizing flow and the processor, especially in the context of the GalacticFlow API.

### The Normalizing Flow

As the oldest building block, the flow is the least well thought out in terms of adaptability.

If you are not interested in the API, you probably want to write your own flow from scratch and use it as you see fit.
The only exception is if you just want to use a different base network or a different coupling layer: these can simply be passed as arguments.

To use a custom flow with the API, you will most likely also want to write it from scratch, or at least change most of its methods.
The methods needed for the API are:
- forward(x, x_cond): The forward pass of the flow (normalizing direction), returning a tuple of:
    1. The transformed data
    2. The log-determinant of the Jacobian
    3. The prior log-probability
- backward(x, x_cond): The backward pass of the flow (sampling direction), returning a tuple of:
    1. The transformed data
    2. The log-determinant of the Jacobian
- sample_Flow(number, x_cond): Draws `number` samples from the flow with the given conditions and returns them.
- to(device): Although implemented in the base class, you may want to override this to ensure everything is on the right device (e.g. if using `torch.distributions.Normal`).
- classmethod load_API(definition): A method to load and build the flow from a definition dictionary of the API. Returns an instance of the flow that is properly set up and usable (e.g. the state_dict is loaded if given).
- save_API(): A method that creates a dictionary from which the flow can be loaded again. Note that saving in the API already saves the hyperparameters that were used in load_API, so you don't need to save them again; together with those parameters, however, the flow should be recreatable. (This may just be a wrapper around the `state_dict` method.)
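To see how the return values of `forward`, `backward` and `sample_Flow` fit together, here is a toy conditional affine flow in plain numpy (the hypothetical `AffineFlowSketch`, not a GalacticFlow class). It assumes a standard-normal prior and that `backward` returns the log-determinant of the backward map; the negative log-likelihood is then the usual `-(prior + logdet)` from the forward pass.

```python
import numpy as np


class AffineFlowSketch:
    """Toy flow: z = (x - c @ W) * exp(-log_s), with a standard-normal prior on z."""

    def __init__(self, n_dim, n_cond, rng):
        self.n_dim = n_dim
        self.rng = rng
        self.W = rng.normal(size=(n_cond, n_dim))  # condition-dependent shift
        self.log_s = rng.normal(size=n_dim)  # per-component log-scale

    def forward(self, x, x_cond):
        # Normalizing direction: data -> latent
        z = (x - x_cond @ self.W) * np.exp(-self.log_s)
        logdet = np.full(x.shape[0], -self.log_s.sum())
        prior = -0.5 * (z**2).sum(axis=1) - 0.5 * self.n_dim * np.log(2 * np.pi)
        return z, logdet, prior

    def backward(self, z, x_cond):
        # Sampling direction: latent -> data; its log-determinant is +sum(log_s)
        x = z * np.exp(self.log_s) + x_cond @ self.W
        logdet = np.full(z.shape[0], self.log_s.sum())
        return x, logdet

    def sample_Flow(self, number, x_cond):
        # Draw `number` latent samples from the prior and push them backward
        z = self.rng.standard_normal((number, self.n_dim))
        return self.backward(z, x_cond)[0]


def negative_log_likelihood(flow, x, x_cond):
    # log p(x | c) = prior log-prob of z + log|det dz/dx|
    _, logdet, prior = flow.forward(x, x_cond)
    return -(prior + logdet).mean()
```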

After defining a new class, you also need to update the definition dict:

In [None]:
import copy
import torch
import torch.nn as nn

#Example flow:
class BetterFlow(nn.Module):
    def __init__(self, n_dim, n_cond, n_layers) -> None:
        super().__init__()
        self.n_dim = n_dim
        self.n_cond = n_cond
        self.n_layers = n_layers
        ...

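    def forward(self, x, x_cond):
        #Normalizing direction. Dummy: identity transform with zero logdet and zero prior log-prob.
        dummy_result = x
        dummy_logdet = torch.zeros(x.shape[0], device=x.device)
        dummy_prior = torch.zeros(x.shape[0], device=x.device)
        return dummy_result, dummy_logdet, dummy_prior
    
    def backward(self, x, x_cond):
        #Sampling direction. Must return (transformed data, logdet). Dummy: identity transform.
        dummy_result = x
        dummy_logdet = torch.zeros(x.shape[0], device=x.device)
        return dummy_result, dummy_logdet

    def sample_Flow(self, number, x_cond):
        #Dummy: draw `number` standard-normal latent samples and push them through backward().
        z = torch.randn(number, self.n_dim, device=x_cond.device)
        return self.backward(z, x_cond)[0]

    @classmethod
    def load_API(cls, definition):
        n_dim = definition["dim_notcond"]
        n_cond = definition["dim_cond"]
        n_layers = definition["n_layers"]
        flow = cls(n_dim, n_cond, n_layers)

        #If this flow was saved before by the API, a state_dict should be present:
        is_loaded = definition.get("was_saved", False)

        if is_loaded:
            flow.load_state_dict(definition["flow_dict"])

        #Return the fully set-up flow:
        return flow

    def save_API(self):
        #The hyperparameters are already saved by the API, so only the weights are needed.
        return {"flow_dict": self.state_dict()}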
    

#Then the original definition dict:
example_definition = {
    #Processor to use (str as registered in func_handle)
    "processor": "Processor_cond",
    #Processor init args
    "processor_args": {},
    #Processor get_data args here folder name with data
    "processor_data": {"folder": "all_sims"},
    #Processor cleaning_function args
    "processor_clean": {"N_min":500},
    #Flow to use (str as registered in func_handle)
    "flow": "NSFlow",
    #Flow hyperparameters
    "flow_hyper": {"n_layers":14, "dim_notcond": 10, "dim_cond": 4, "CL":"NSF_CL2", "K": 10, "B":3, "network":"MLP", "network_args":torch.tensor([128,4,0.2])},
    #Parameters for choosing the subset of the data to use:
    #cond_fn: The function that computes/determines the condition for each galaxy. (See Processor_cond.choose_subset() for details.)
    #use_fn_constructor: The function that constructs the subset of the data to use. (See Processor_cond.choose_subset() for details.)
    #Will be called with leavout_key and leavout_vals as kwargs. I.e. will leavout galaxies that have galaxy["galaxy"][leavout_key] in leavout_vals.
    #The remaining galaxies are used for training.
    #use_fn_constructor is also once called with leavout_vals=[] to construct the full dataset for comparing (i.e. include validation set)
    "subset_params": {"cond_fn": "cond_M_stars_2age_avZ", "use_fn_constructor": "construct_all_galaxies_leavout", "leavout_key": "id", "leavout_vals": [66, 20, 88, 48, 5]},
    #Parameters to processor.Data_to_flow
    #transformation_components[i] will be transformed with transformation_functions[i] and the corresponding inverse transformation is given by inverse_transformations[i]
    #transformation_logdets[i] is the logdet of the transformation_functions[i], needed in case of pdf evaluation.
    "data_prep_args": {"transformation_functions":("np.log10",), "transformation_components":(["M_stars"],), "inverse_transformations":("10**x",), "transformation_logdets":("logdet_log10",)}
}

#Then the definition dict for the adapted flow:
adapted_definition = copy.deepcopy(example_definition)
adapted_definition["flow"] = "BetterFlow"
#For this to work, the BetterFlow class needs to be registered in func_handle:
#Key: "BetterFlow", Value: BetterFlow

#NOTE: the keys "dim_notcond" and "dim_cond" are required for the API to work properly, regardless of whether your flow uses them in this format.
flow_hyper_params = {"n_layers": 14, "dim_notcond": 10, "dim_cond": 4}
adapted_definition["flow_hyper"] = flow_hyper_params
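The leave-out mechanism described in the comments of `"subset_params"` can be sketched as follows. `construct_leavout_sketch` is a hypothetical stand-in for `construct_all_galaxies_leavout`; in particular it assumes that the returned `use_fn` takes a single galaxy dict and returns a bool, which may not match the real signature.

```python
def construct_leavout_sketch(leavout_key, leavout_vals):
    # Hypothetical: build a use_fn that keeps a galaxy unless
    # galaxy["galaxy"][leavout_key] is one of leavout_vals
    leavout = set(leavout_vals)

    def use_fn(galaxy):
        return galaxy["galaxy"][leavout_key] not in leavout

    return use_fn


# Toy galaxies with ids 0..5 (star data omitted)
Galaxies = [{"stars": None, "galaxy": {"id": i}} for i in range(6)]

train_fn = construct_leavout_sketch("id", [1, 4])
full_fn = construct_leavout_sketch("id", [])  # leavout_vals=[] -> full dataset incl. validation

train = [g for g in Galaxies if train_fn(g)]
validation = [g for g in Galaxies if not train_fn(g)]
full = [g for g in Galaxies if full_fn(g)]
```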

### The Processor

The processor supplies a lot of functionality that should be fairly universal and worth reusing even if you do not use the API.
To get a working processor for your own data, you may proceed as follows:

- Subclass the `Processor_cond` class

The next steps depend on what you are trying to do.
Most likely you just want to change the `get_data` method to load your data, and probably the `constraindata` method that cleans the data.
(The latter may be rewritten in the future to be more flexible, with the cleaning supplied as a function argument.)

This should be fairly easy, but a few things are necessary for the processor to work properly:

- Use the right data structure (see `Base_Processor_Workflow.ipynb`, `API_Workflow.ipynb` and the `Processor_cond` class)
- Properly register component and condition names, as the processor keeps track of which components the data it operates on has.
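The two requirements above can be checked mechanically. The helper below is hypothetical (it is not part of `processing.py`) and assumes that `component_names` is a dict keyed by the data kind, as in the `MyProcessor` example, and that the registered names must equal the DataFrame columns in the same order.

```python
import pandas as pd


def check_galaxies_sketch(Galaxies, component_names, kind="stars"):
    # Hypothetical helper: checks the list-of-dicts structure and that the
    # registered names match the DataFrame columns (order-sensitive, an assumption).
    expected = list(component_names[kind])
    for i, galaxy in enumerate(Galaxies):
        if kind not in galaxy or "galaxy" not in galaxy:
            raise KeyError(f"galaxy {i} needs the keys {kind!r} and 'galaxy'")
        if not isinstance(galaxy[kind], pd.DataFrame):
            raise TypeError(f"galaxy {i}: {kind!r} must be a pandas DataFrame")
        columns = galaxy[kind].columns.to_list()
        if columns != expected:
            raise ValueError(f"galaxy {i}: columns {columns} do not match registered {expected}")
    return True
```

Calling it at the end of a custom `get_data`, e.g. `check_galaxies_sketch(Galaxies, self.component_names)`, would catch a forgotten registration early.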

See the `Processor_cond` code or the example below for more details.

In [None]:
import processing
import numpy as np
import pandas as pd
import copy

#Suppose the data we want to load is given as a csv file where the columns are the stellar properties and the rows are the stars:
# x, y, z, vx, vy, vz, mass, age, Teff, logg, feh
# <star 1>
# <star 2>
# ...
#And we have a numpy array of star numbers per galaxy, and we need to split the data into galaxies (As .npy file).
#We have an array of same length with the DM halo mass of each galaxy (As .npy file).

#Very likely you will need to change the constraindata method, as (right now) the component names and how to clean them are hard coded.

#Here is an example of how to do this, with all necessary steps:
class MyProcessor(processing.Processor_cond):
    def __init__(self):
        super().__init__()

    def get_data(self, stars, N_stars, halo_masses):

        #First load the data e.g. with pandas:
        data = pd.read_csv(stars)

        N_stars = np.load(N_stars)
        halo_masses = np.load(halo_masses)

        #Now we need to have the right data structure:
        # List of dicts one for each galaxy, keys "stars", "galaxy"
        # Stars is a pd.DataFrame with the stellar properties as already loaded
        # Galaxy is a dict with other information about the galaxy

        Galaxies = []

        for i, (n_star, halo_mass, n_star_already_loaded) in enumerate(zip(N_stars, halo_masses, np.cumsum(N_stars))):
            #Get the stars for this galaxy:
            stars = data.iloc[n_star_already_loaded - n_star : n_star_already_loaded]

            #Get the galaxy properties, e.g.:
            #Get the total mass of the stars (Series.sum() already returns a scalar):
            total_mass = stars["mass"].sum()
            #Assign a unique id to the galaxy:
            unique_id = i
            galaxy = {"M_dm": halo_mass, "N_stars": n_star, "M_stars": total_mass, "id": unique_id}
            #You are free to choose whatever you want to store in the galaxy dict.
            #Note, however, that the standard cleaning function assumes that the galaxy dict contains the keys "N_stars" and "M_stars".
            #And the standard use_fn for choosing the subset assumes an id is assigned for the train/test selection, but you can just pass a different use_fn to the method.

            #Maybe we decide that we are not actually interested in logg, so we drop it:
            stars = stars.drop(columns=["logg"])

            #Now append the galaxy to the list of galaxies:
            Galaxies.append({"stars": stars, "galaxy": galaxy})

        #Now there just is one more important step:
        # Register the component names:
        used_component_names = data.drop(columns=["logg"]).columns.to_list()
        self.component_names["stars"] = used_component_names

        #If you were to also use e.g. gas you can just do the same as with stars but replacing "stars" with e.g. "gas" everywhere.

        #Now we are done and can return the list of galaxies:
        return Galaxies

    def constraindata(self, Galaxies, m_max, M_dm_interval, make_copy=True):
        #Say e.g. we are only interested in stars with mass < m_max
        #and galaxies with DM halo mass inside M_dm_interval = (M_dm_min, M_dm_max)
        M_dm_min, M_dm_max = M_dm_interval

        Galaxies_cleaned = []

        for galaxy in Galaxies:

            #Clean out unwanted galaxies:
            galaxy_is_valid = M_dm_min < galaxy["galaxy"]["M_dm"] < M_dm_max

            if galaxy_is_valid:
                galaxy_cleaned = copy.deepcopy(galaxy) if make_copy else galaxy

                #Clean the stars:
                stars_valid = galaxy_cleaned["stars"]["mass"] < m_max
                galaxy_cleaned["stars"] = galaxy_cleaned["stars"][stars_valid]

                #You may need to update some parameters
                galaxy_cleaned["galaxy"]["N_stars"] = int(stars_valid.sum())
                galaxy_cleaned["galaxy"]["M_stars"] = galaxy_cleaned["stars"]["mass"].sum()

                #You may also want to rotate the galaxies, such that their preferred axes are aligned
                galaxy_cleaned["stars"] = processing.rotate_galaxy_xy(galaxy_cleaned["stars"], quant=0.9)

                Galaxies_cleaned.append(galaxy_cleaned)

            #(Note: You may need to swap galaxy cleaning with star cleaning if e.g. you want to clean out galaxies with too few stars.)

        return Galaxies_cleaned
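The cumulative-sum slicing in `get_data` is easy to get off by one, so it can help to check it on a toy table. The sketch below (toy data, not taken from the original) compares that slicing with an equivalent `np.split` over row positions, and also shows that `Series.sum()` returns a scalar directly.

```python
import numpy as np
import pandas as pd

# Toy input: 6 stars belonging to three galaxies with 2, 3 and 1 stars
data = pd.DataFrame({"mass": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]})
N_stars = np.array([2, 3, 1])
assert N_stars.sum() == len(data), "star counts must cover the whole table"

# Slice version, as in get_data: galaxy i ends at cumsum(N_stars)[i]
ends = np.cumsum(N_stars)
slices = [data.iloc[end - n : end] for n, end in zip(N_stars, ends)]

# Equivalent version: split the row positions at the inner boundaries
index_chunks = np.split(np.arange(len(data)), ends[:-1])
chunks = [data.iloc[idx] for idx in index_chunks]

# One scalar total mass per galaxy
total_masses = [chunk["mass"].sum() for chunk in chunks]
```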

If you want to use the API, you will need to update the definition dict:

In [None]:
import copy
import torch
#Then the original definition dict:
example_definition = {
    #Processor to use (str as registered in func_handle)
    "processor": "Processor_cond",
    #Processor init args
    "processor_args": {},
    #Processor get_data args here folder name with data
    "processor_data": {"folder": "all_sims"},
    #Processor cleaning_function args
    "processor_clean": {"N_min":500},
    #Flow to use (str as registered in func_handle)
    "flow": "NSFlow",
    #Flow hyperparameters
    "flow_hyper": {"n_layers":14, "dim_notcond": 10, "dim_cond": 4, "CL":"NSF_CL2", "K": 10, "B":3, "network":"MLP", "network_args":torch.tensor([128,4,0.2])},
    #Parameters for choosing the subset of the data to use:
    #cond_fn: The function that computes/determines the condition for each galaxy. (See Processor_cond.choose_subset() for details.)
    #use_fn_constructor: The function that constructs the subset of the data to use. (See Processor_cond.choose_subset() for details.)
    #Will be called with leavout_key and leavout_vals as kwargs. I.e. will leavout galaxies that have galaxy["galaxy"][leavout_key] in leavout_vals.
    #The remaining galaxies are used for training.
    #use_fn_constructor is also once called with leavout_vals=[] to construct the full dataset for comparing (i.e. include validation set)
    "subset_params": {"cond_fn": "cond_M_stars_2age_avZ", "use_fn_constructor": "construct_all_galaxies_leavout", "leavout_key": "id", "leavout_vals": [66, 20, 88, 48, 5]},
    #Parameters to processor.Data_to_flow
    #transformation_components[i] will be transformed with transformation_functions[i] and the corresponding inverse transformation is given by inverse_transformations[i]
    #transformation_logdets[i] is the logdet of the transformation_functions[i], needed in case of pdf evaluation.
    "data_prep_args": {"transformation_functions":("np.log10",), "transformation_components":(["M_stars"],), "inverse_transformations":("10**x",), "transformation_logdets":("logdet_log10",)}
}

#becomes: (e.g.)
new_definition = copy.deepcopy(example_definition)
new_definition["processor"] = "MyProcessor"
new_definition["processor_args"] = {}
new_definition["processor_data"] = {"stars": "all_stars.csv", "N_stars": "N_stars.npy", "halo_masses": "halo_masses.npy"}
new_definition["processor_clean"] = {"m_max": 1e4, "M_dm_interval": (1e11, 1e12)}

#However for this to work you need to add the new processor to the func_handle dict in API.py:
#Key = "MyProcessor", Value = MyProcessor

If you change other things, e.g. introduce parameters that need to be saved, you may also need to modify the `save_API` and `load_API` methods of the processor.

See above for the analogous example of the flow.
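As an illustration of what saving additional parameters could mean for a processor, here is a hypothetical stand-in, `MyProcessorSketch`. It does not subclass `processing.Processor_cond`, and its `save_API`/`load_API` signatures are assumptions modeled on the flow example, not the actual processor interface; check `Processor_cond` before adapting it.

```python
import copy


class MyProcessorSketch:
    """Hypothetical processor with one init argument and some data-derived state."""

    def __init__(self, quant=0.9):
        self.quant = quant  # init argument, already stored in "processor_args"
        self.norm_stats = None  # state computed from the data, NOT in the definition dict

    def fit(self, values):
        # Toy stand-in for computing normalization statistics from the data
        self.norm_stats = {"mean": sum(values) / len(values)}

    def save_API(self):
        # Only return what the definition dict does not already contain
        return {"norm_stats": copy.deepcopy(self.norm_stats)}

    @classmethod
    def load_API(cls, processor_args, saved_state=None):
        processor = cls(**processor_args)
        if saved_state is not None:
            processor.norm_stats = saved_state["norm_stats"]
        return processor
```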

**NOTE:** The key "subset_params" with the entries "leavout_key" and "leavout_vals" is required regardless of the processor used. The `.prepare` method of the API additionally requires the full "subset_params" dict as above, as well as "processor_data" and "processor_clean", which are passed to the `get_data` and the `constraindata` method respectively.
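Before calling `.prepare`, a quick check for these keys can save a confusing error later. `check_definition_sketch` is a hypothetical helper, not part of the API; it only encodes the requirements stated in this notebook (the processor and flow names, the keys above, and `"dim_notcond"`/`"dim_cond"` in `"flow_hyper"`).

```python
REQUIRED_TOP_LEVEL = ("processor", "flow", "processor_data", "processor_clean", "subset_params", "flow_hyper")
REQUIRED_SUBSET = ("cond_fn", "use_fn_constructor", "leavout_key", "leavout_vals")
REQUIRED_FLOW_HYPER = ("dim_notcond", "dim_cond")


def check_definition_sketch(definition):
    # Returns the list of missing keys; an empty list means the check passed
    missing = [k for k in REQUIRED_TOP_LEVEL if k not in definition]
    subset = definition.get("subset_params", {})
    missing += [f"subset_params.{k}" for k in REQUIRED_SUBSET if k not in subset]
    flow_hyper = definition.get("flow_hyper", {})
    missing += [f"flow_hyper.{k}" for k in REQUIRED_FLOW_HYPER if k not in flow_hyper]
    return missing
```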



If you want to change anything more fundamental, you may need to change the API as well, e.g. the `.prepare` method, which is responsible for how the data is obtained and handled.

You can again subclass the `GalacticFlow` class and override the corresponding methods.