# `multidms.model_collection` fitting pipeline

<!-- A key feature of global epistasis-like modeling, is the interpretable parameter values associated with any given mutation -- most commonly referred to as _mutations effects_. The joint-modeling approach in `multidms` provides mutation effect parameters for a given reference experiment, in addition to a respective set of shift parameters for _each_ non-reference experiment in the training data. 

As with many machine learning tasks it is beneficial to perform hyper-parameter sweeps and train on multiple datasets for cross validation. 
Here we introduce an interface to training `Model` objects  -->

In the [previous example notebook](https://matsengrp.github.io/multidms/fit_delta_BA1_example.html), we introduced the `Data` and `Model` classes for fitting a single model and visualizing its results. Here, we will see how to use the `ModelCollection` class and associated utilities to fit multiple models (in parallel, using `multiprocessing`) and to aggregate and compare the results across fits.

Two very common use cases for this interface include:

1. Shrinkage analysis of lasso coefficient values
2. Training on distinct replicate training datasets

To give an example of each below, we use the `multidms.fit_models` function to get a collection of 
fits (in the form of a `pandas.DataFrame` object) spanning two replicate datasets, and a range of lasso coefficient values.
We then instantiate a `multidms.ModelCollection` object from these fits to aggregate and visualize the results from the fits.

Note
----

This module functionally wraps the `Model` interface for convenience. If you're training on CPUs and have more than one core on your machine, then this is definitely the way to go. Currently, the code doesn't do anything clever to optimize GPU usage when many models train in parallel -- Thar be dragons.


If you would like to use GPUs for training, it is probably better to train each model individually using the `fit_one_model` function in this module.

In [1]:
# import notebook dependencies
import pandas as pd
import multidms
%matplotlib inline

## Load functional scores 

In the previous example, we showed data from two conditions, and fit a single model to the data. Here, we'll load multiple replicates of that same data from [three deep mutational scanning experiments](https://github.com/dms-vep) across Delta, Omicron BA.1, and BA.2 Spike protein.

In [2]:
# load scores, and fill wt values with empty strings
func_score_df = pd.read_csv("Delta_BA1_BA2_func_score_df.csv").fillna("")
# split condition and replicate
func_score_df = func_score_df.assign(
    replicate = func_score_df["condition"].apply(lambda x: x.split("-")[-1]),
    condition = func_score_df["condition"].apply(lambda x: "-".join(x.split("-")[:-1])) 
)
func_score_df.sample(5)


Unnamed: 0,func_score,aa_substitutions,condition,replicate
498508,-0.7886,A846V,Omicron_BA.2,1
311113,-0.364,,Omicron_BA.1,2
453923,1.2999,K182N,Omicron_BA.2,1
7208,-0.0029,G504S Q677G,Delta,1
317963,0.1091,S446P Y501A A688S F981V,Omicron_BA.1,2
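
The split in the cell above relies on each raw `condition` label ending in `-<replicate>`. As a minimal standalone sketch of the same idea, assuming labels shaped like `Omicron_BA.1-2` (the labels below are illustrative, not read from the CSV), `str.rsplit` with `maxsplit=1` does the job of the split/join pair in one step:

```python
# Hypothetical raw labels, assumed to follow "<condition>-<replicate>".
raw_labels = ["Delta-1", "Omicron_BA.1-2", "Omicron_BA.2-1"]


def split_label(label):
    """Split on the last hyphen only, so hyphens inside the condition name survive."""
    condition, replicate = label.rsplit("-", 1)
    return condition, replicate


pairs = [split_label(label) for label in raw_labels]
print(pairs)
```

Either form keeps the replicate as a string, which is why the replicates are queried as `"1"` and `"2"` later on.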


In [3]:
for condition, cfs in func_score_df.groupby('condition'):
    print(f"{condition} replicates:\n\t{cfs.replicate.unique()}\n")

Delta replicates:
	['1' '2' '3' '4']

Omicron_BA.1 replicates:
	['1' '2' '3']

Omicron_BA.2 replicates:
	['1' '2']



## Instantiate `multidms.Data` objects for fitting

We would like to create two replicate training datasets, each of which should consist of one replicate from each of the three experiments. For simplicity, we'll group the three experiments deriving from replicate '1' together, and similarly for replicate '2' -- keeping in mind there is no significance to the replicate names in this case.
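
Grouping by replicate label only works for labels that every condition has. A quick check, using the replicate labels printed above, suggests why we use '1' and '2': they are the only labels common to all three conditions.

```python
# Replicate labels per condition, copied from the printed output above.
replicates = {
    "Delta": ["1", "2", "3", "4"],
    "Omicron_BA.1": ["1", "2", "3"],
    "Omicron_BA.2": ["1", "2"],
}

# Labels present in every condition.
shared = sorted(set.intersection(*(set(r) for r in replicates.values())))
print(shared)  # ['1', '2']
```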

We'll create the `Data` objects, as we've done before, but this time we'll create independent `Data` objects for each replicate. Keep in mind that when comparing across replicate datasets using the `multidms.ModelCollection` interface, it is best to keep the reference, and non-reference conditions consistent among datasets.

In [4]:
data_replicates = [
    multidms.Data(
        func_score_df.query("replicate == @rep").sample(10000),
        alphabet = multidms.AAS_WITHSTOP_WITHGAP,
        collapse_identical_variants = "mean",
        reference = "Delta",
        verbose = False,
        nb_workers=4,
        name = f"Replicate {rep}"
    )
    for rep in ["1", "2"]
]

I0000 00:00:1697663963.641265   30997 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Fit one model with `multidms.fit_one_model`

The `model_collection` module offers a simple interface to create and fit `Model` objects. First, let's fit a single model to one of the `Data` replicates above. To do this, we simply need to define the model parameters.

In [5]:
single_set_of_params = {
    "dataset": data_replicates[0], # only one replicate dataset
    "num_training_steps" : 1,
    "iterations_per_step": 5, # Small number of iterations for purposes of this example
    "scale_coeff_lasso_shift": 1e-5,
}


For a full list and descriptions of available hyperparameters, see: 
```python
help(multidms.model_collection.fit_one_model)
``` 

With these, we can now fit a single model.

In [6]:
fit = multidms.model_collection.fit_one_model(**single_set_of_params)
fit

model                        <multidms.model.Model object at 0x7f82583af1d0>
dataset_name                                                     Replicate 1
step_loss                                               [3.1401935958042104]
epistatic_model                                                      Sigmoid
output_activation                                                   Identity
scale_coeff_lasso_shift                                              0.00001
scale_coeff_ridge_beta                                                     0
scale_coeff_ridge_shift                                                    0
scale_coeff_ridge_gamma                                                    0
scale_coeff_ridge_alpha_d                                                  0
scale_coeff_huber                                                          1
gamma_corrected                                                        False
alpha_d                                                                False

Now we have the `Model` object along with the associated hyperparameters that were used to fit the model to the replicate dataset. Let's take a look at the betas ($\beta_m$) from this fit using the `Model.mut_param_heatmap` method.


In [7]:
fit.model.mut_param_heatmap(mut_param="beta")


Next, we'll see how to fit multiple models in parallel.

## Fit multiple models (in parallel) with `multidms.fit_models`

Currently, the `model_collection` interface offers two public functions: `fit_one_model`, as we saw above, and `fit_models`. The latter wraps the former and allows multiple models to be fit in parallel by spawning child processes using `multiprocessing`. The `fit_models` function takes a single dictionary which defines the parameter space of all models you wish to run. Each value in the dictionary must be a list of values, even in the case of singletons. This function will compute all combinations of the parameter space and pass each combination to `multidms.utils.fit_wrapper` to be run in parallel, thus only key-value pairs which match the `fit_one_model` kwargs are allowed.
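
To make "all combinations of the parameter space" concrete, here is a minimal sketch using `itertools.product`. The helper `explode_params` and the placeholder dataset names are hypothetical and only illustrate the idea; this is not the library's own implementation.

```python
from itertools import product


def explode_params(params):
    """Return one kwargs dict per combination of the listed values."""
    keys = list(params)
    return [dict(zip(keys, combo)) for combo in product(*(params[k] for k in keys))]


toy_params = {
    "dataset": ["rep1", "rep2"],  # placeholders standing in for Data objects
    "num_training_steps": [1],  # singletons must still be wrapped in a list
    "scale_coeff_lasso_shift": [0.0, 1e-5, 1e-3],
}

combos = explode_params(toy_params)
print(len(combos))  # 6
print(combos[0])
```

The number of combinations is the product of the list lengths (here 2 x 1 x 3), so adding values to several keys at once grows the number of fits quickly.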

To exemplify this, let's again define the hyperparameters, but this time, we'll specify each value as a list of values to be fit in parallel.

In [8]:
# every value is a list; each combination of values defines one model to fit
collection_params = {
    "dataset": data_replicates,
    "num_training_steps" : [1],
    "iterations_per_step": [2000],
    "output_activation" : ["Softplus"],
    "lower_bound" : [-3.5],
    "scale_coeff_lasso_shift": [0.0, 1e-5, 1e-3],
}

Before we fit the models, let's take a look at the collection of models we're specifying with this dictionary by calling the "private" function `multidms.model_collection._explode_params_dict`. As the "private" label implies, this step is hidden from the user and is performed internally when calling `fit_models`.

In [9]:
from pprint import pprint
pprint(multidms.model_collection._explode_params_dict(collection_params))

[{'dataset': <multidms.data.Data object at 0x7f825b7b79d0>,
  'iterations_per_step': 2000,
  'lower_bound': -3.5,
  'num_training_steps': 1,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 0.0},
 {'dataset': <multidms.data.Data object at 0x7f825b7b79d0>,
  'iterations_per_step': 2000,
  'lower_bound': -3.5,
  'num_training_steps': 1,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 1e-05},
 {'dataset': <multidms.data.Data object at 0x7f825b7b79d0>,
  'iterations_per_step': 2000,
  'lower_bound': -3.5,
  'num_training_steps': 1,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 0.001},
 {'dataset': <multidms.data.Data object at 0x7f82a9b26190>,
  'iterations_per_step': 2000,
  'lower_bound': -3.5,
  'num_training_steps': 1,
  'output_activation': 'Softplus',
  'scale_coeff_lasso_shift': 0.0},
 {'dataset': <multidms.data.Data object at 0x7f82a9b26190>,
  'iterations_per_step': 2000,
  'lower_bound': -3.5,
  'num_training_steps': 1,
  'output_a

What is produced is a list of `**kwargs` dictionaries to pass to `fit_one_model`. In this case there are 6 total models to fit (2 replicate datasets x 3 lasso strengths). To fit these models, we simply pass `collection_params` to `fit_models` and specify the number of threads available to run the model fits in parallel.

In [10]:
n_fit, n_failed, fit_models = multidms.model_collection.fit_models(collection_params, n_threads=4)

I0000 00:00:1697663981.668964   31356 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0000 00:00:1697663981.920269   31355 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0000 00:00:1697663982.089017   31357 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0000 00:00:1697663982.177453   31358 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [11]:
print(f"Of the 6 model fits, {n_fit} succeeded and {n_failed} failed")

Of the 6 model fits, 6 succeeded and 0 failed


The third object returned by `fit_models` is a `pandas.DataFrame` object which contains the results from each model fit by stacking the `pd.Series` objects as returned by `fit_one_model`. 

In [12]:
fit_models

Unnamed: 0,model,dataset_name,step_loss,epistatic_model,output_activation,scale_coeff_lasso_shift,scale_coeff_ridge_beta,scale_coeff_ridge_shift,scale_coeff_ridge_gamma,scale_coeff_ridge_alpha_d,...,gamma_corrected,alpha_d,init_beta_naught,lock_beta_naught_at,tol,num_training_steps,iterations_per_step,n_hidden_units,lower_bound,PRNGKey
0,<multidms.model.Model object at 0x7f82584afbd0>,Replicate 1,[0.13720360151936598],Sigmoid,Softplus,0.0,0,0,0,0,...,False,False,0.0,,0.0001,1,2000,5,-3.5,0
1,<multidms.model.Model object at 0x7f82582132d0>,Replicate 1,[0.14086653344690003],Sigmoid,Softplus,1e-05,0,0,0,0,...,False,False,0.0,,0.0001,1,2000,5,-3.5,0
2,<multidms.model.Model object at 0x7f82541cd950>,Replicate 1,[0.39089300306376057],Sigmoid,Softplus,0.001,0,0,0,0,...,False,False,0.0,,0.0001,1,2000,5,-3.5,0
3,<multidms.model.Model object at 0x7f82541f8110>,Replicate 2,[0.10422796879495169],Sigmoid,Softplus,0.0,0,0,0,0,...,False,False,0.0,,0.0001,1,2000,5,-3.5,0
4,<multidms.model.Model object at 0x7f8274946010>,Replicate 2,[0.10614017366336062],Sigmoid,Softplus,1e-05,0,0,0,0,...,False,False,0.0,,0.0001,1,2000,5,-3.5,0
5,<multidms.model.Model object at 0x7f8257e61a50>,Replicate 2,[0.36275755545409216],Sigmoid,Softplus,0.001,0,0,0,0,...,False,False,0.0,,0.0001,1,2000,5,-3.5,0


This `DataFrame` is all that's necessary to create a `multidms.ModelCollection` object. 

**Note** 

If you wanted to use a pipeline to farm out the fitting processes independently, the same `DataFrame` could be obtained by collecting the individual `Series` objects returned by `fit_one_model`,
and then concatenating them with the `multidms.model_collection.stack_fit_models` utility.
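
As a rough illustration of what stacking means here, the sketch below builds a `DataFrame` with one row per fit from plain `pandas.Series` objects. The field values are made up, and how `stack_fit_models` is actually implemented is not shown in this notebook; this only demonstrates the underlying pandas idea.

```python
import pandas as pd

# Stand-ins for the Series returned by separate fit_one_model calls;
# the values are placeholders for illustration only.
fit_a = pd.Series(
    {"model": "<Model placeholder A>", "dataset_name": "Replicate 1", "scale_coeff_lasso_shift": 0.0}
)
fit_b = pd.Series(
    {"model": "<Model placeholder B>", "dataset_name": "Replicate 2", "scale_coeff_lasso_shift": 0.0}
)

# Each Series becomes one row; the Series index becomes the columns.
stacked = pd.DataFrame([fit_a, fit_b]).reset_index(drop=True)
print(stacked)
```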

## `ModelCollection` Object

The `ModelCollection` class is simply a nice interface for applying split-apply-combine operations to the model attributes, such as the mutations dataframes and `variants_df`, contained within the `pandas.DataFrame` object returned by `fit_models`. To instantiate a `ModelCollection` object, we simply pass the dataframe to the constructor.

In [16]:
mc = multidms.ModelCollection(fit_models)

To get the raw data in a tidy format, use `ModelCollection.split_apply_combine_muts`. This function applies the [split-apply-combine](https://pandas.pydata.org/docs/user_guide/groupby.html) paradigm to the collection of individual mutational effects tables (our example currently has 6) while keeping the fit hyperparameters of interest tied to the data.

In [17]:
combined_lasso_strengths = mc.split_apply_combine_muts(
    groupby=["dataset_name", "scale_coeff_lasso_shift"]
)
combined_lasso_strengths.head()

dataset_name,scale_coeff_lasso_shift,mutation,beta,shift_Omicron_BA.1,shift_Omicron_BA.2,predicted_func_score_Delta,predicted_func_score_Omicron_BA.1,predicted_func_score_Omicron_BA.2,times_seen_Delta,times_seen_Omicron_BA.1,times_seen_Omicron_BA.2
Replicate 1,0.0,A1015S,-0.120579,-0.790163,0.0,-0.047522,-0.127383,0.348299,0.0,2.0,0.0
Replicate 1,0.0,A1015T,1.911051,0.0,-1.036721,0.350125,0.708258,0.628419,1.0,0.0,2.0
Replicate 1,0.0,A1016Q,0.018529,0.0,1.676144,0.006909,0.365041,0.729946,0.0,0.0,1.0
Replicate 1,0.0,A1016S,-0.435612,0.0,0.0,-0.194362,0.16377,0.201459,2.0,0.0,0.0
Replicate 1,0.0,A1016T,1.558712,1.180128,-1.656069,0.322231,0.74604,0.357804,3.0,3.0,2.0


The fit collection groupby features (`scale_coeff_lasso_shift` and `dataset_name` in this case) are set as a [multiindex](https://pandas.pydata.org/docs/user_guide/advanced.html). The index then easily distinguishes _fit groups_ from _mutation features_, and is more memory efficient. If `groupby = None` (default), then we group by all available fit attributes. Also note that by default, only mutations shared by all datasets are returned, but this can be changed by setting `inner_merge_dataset_muts=False`.
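
To make the role of the multiindex concrete, here is a small sketch on a toy table. The mutation names and beta values are made up; only the column names mirror the output above. It shows how one fit group can be selected by its index levels, and how values can then be aggregated across replicate datasets.

```python
import pandas as pd

# Toy rows: (dataset_name, scale_coeff_lasso_shift, mutation, beta).
rows = [
    ("Replicate 1", 0.0, "A1S", 1.0),
    ("Replicate 1", 0.0, "B2T", -2.0),
    ("Replicate 1", 0.001, "A1S", 0.5),
    ("Replicate 1", 0.001, "B2T", -1.0),
    ("Replicate 2", 0.0, "A1S", 3.0),
    ("Replicate 2", 0.0, "B2T", -4.0),
    ("Replicate 2", 0.001, "A1S", 1.5),
    ("Replicate 2", 0.001, "B2T", -3.0),
]
columns = ["dataset_name", "scale_coeff_lasso_shift", "mutation", "beta"]
toy = pd.DataFrame(rows, columns=columns).set_index(
    ["dataset_name", "scale_coeff_lasso_shift"]
)

# Select a single fit group; query can refer to index levels by name.
one_fit = toy.query(
    "dataset_name == 'Replicate 1' and scale_coeff_lasso_shift == 0.001"
)

# Average beta across replicate datasets, separately for each lasso strength.
mean_beta = toy.groupby(["scale_coeff_lasso_shift", "mutation"])["beta"].mean()
print(mean_beta)
```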

## Mutational parameter heatmaps

Just as you might use `Model.mut_param_heatmap` to visualize the mutation effects from a single model, you can use `ModelCollection.mut_param_heatmap` to visualize the aggregated mutation effects from a collection of models fit to multiple replicate datasets.

Using all defaults this would be called as follows:

```python
heatmap_chart = mc.mut_param_heatmap()
```

However, our current example fit collection has 3 different lasso strengths, which don't make sense to aggregate over. Thus, this call will result in:

```python
ValueError: invalid query, more than one unique hyper-parameter besides dataset_name
```

To fix this, we must subset our model collection such that we are only aggregating across different training datasets.
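
The condition behind that error can be pictured with a small check like the one below. This is only a sketch of the idea under our own assumptions (the helper `varying_hyperparams` is hypothetical, and the library's actual check may differ): besides `dataset_name`, no fit attribute should take more than one value in the subset being aggregated.

```python
# Toy stand-ins for the per-fit attributes in the collection.
fits = [
    {"dataset_name": d, "scale_coeff_lasso_shift": s, "output_activation": "Softplus"}
    for d in ("Replicate 1", "Replicate 2")
    for s in (0.0, 1e-5, 1e-3)
]


def varying_hyperparams(fits, ignore=("dataset_name",)):
    """Names of attributes, other than those ignored, that take more than one value."""
    keys = [k for k in fits[0] if k not in ignore]
    return [k for k in keys if len({f[k] for f in fits}) > 1]


print(varying_hyperparams(fits))  # ['scale_coeff_lasso_shift']

# After subsetting to one lasso strength, only the dataset differs.
subset = [f for f in fits if f["scale_coeff_lasso_shift"] == 1e-5]
print(varying_hyperparams(subset))  # []
```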

In [18]:
chart = mc.mut_param_heatmap(
    query="scale_coeff_lasso_shift == 1e-5",
    mut_param="beta"
)
chart

Here, we visualized the model's beta ($\beta_m$) parameters, but we can also visualize the respective shift parameters for each non-reference condition ($\Delta_{d,m}$) by setting `mut_param="shift"`.

In [19]:
chart = mc.mut_param_heatmap(
    query="scale_coeff_lasso_shift == 1e-5",
    mut_param="shift"  # shift parameters for each non-reference condition
)
chart

Or, we can visualize the mutation predictions ($\hat{y}_{m, d}$). Note that by default the predictions are shown with phenotype as effect, that is, as the difference from the (non-zero) wildtype prediction.

In [20]:
chart = mc.mut_param_heatmap(
    query="scale_coeff_lasso_shift == 1e-5",
    mut_param="predicted_func_score",
    phenotype_as_effect=True
)
chart

## Trace charts for mutational shrinkage

Another common reason you might fit a collection of models is to test multiple lasso strength coefficients. When you have a few mutations of interest, you might want to see how the lasso strength affects the shrinkage of the mutation effects. To do this, we can use the `ModelCollection.mut_param_traceplot` method. 

Begin by selecting a subset of mutations to visualize. Here, we'll select the top 10 mutations by absolute value of the beta parameter across all fits.

In [21]:
combined_lasso_strengths["abs_beta"] = combined_lasso_strengths["beta"].abs()
muts_of_interest = combined_lasso_strengths.sort_values("abs_beta", ascending=False).head(10).mutation.values
mc.mut_param_traceplot(mutations = muts_of_interest, mut_param="shift")
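
One thing to keep in mind with the selection above: the combined table has one row per fit and mutation, so sorting rows and taking the first 10 can return the same mutation more than once. If distinct mutations are wanted, one option is to rank mutations by their largest absolute beta across fits. The sketch below shows the difference on made-up values; whether duplicates matter for `mut_param_traceplot` is not something this notebook establishes.

```python
import pandas as pd

# Toy table: each mutation appears once per fit (two fits here).
toy = pd.DataFrame(
    {
        "mutation": ["A1S", "A1S", "B2T", "B2T", "C3V", "C3V"],
        "beta": [2.0, -2.5, 0.1, 0.2, -1.0, 0.9],
    }
)

# Row-wise ranking, as in the cell above: the same mutation can repeat.
naive = (
    toy.assign(abs_beta=toy["beta"].abs())
    .sort_values("abs_beta", ascending=False)
    .head(2)["mutation"]
    .tolist()
)
print(naive)  # ['A1S', 'A1S']

# Rank each mutation once, by its largest absolute beta across fits.
top_two = (
    toy.groupby("mutation")["beta"]
    .agg(lambda s: s.abs().max())
    .nlargest(2)
    .index.tolist()
)
print(top_two)  # ['A1S', 'C3V']
```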

In [23]:
import altair as alt

In [None]:
df = mc.split_apply_combine_muts(
    
)