In [40]:
# configure the logger to print to console
import logging

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.base import BaseEstimator, clone
from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import StratifiedKFold

from yeastdnnexplorer.ml_models.lasso_modeling import (
    bootstrap_stratified_cv_modeling,
    examine_bootstrap_coefficients,
    generate_modeling_data,
    stratification_classification,
    stratified_cv_modeling,
)


# Set the root logger threshold to ERROR, so records below ERROR (e.g. the
# `logging.info(...)` call made later in this notebook) are not emitted by it.
# NOTE(review): the saved cell outputs further down show INFO records, so they were
# presumably produced with a lower level -- confirm which level is intended.
logging.basicConfig(level=logging.ERROR)

# Interactor Modeling Workflow

This tutorial describes a process of modeling perturbation response by binding data
with the goal of discovering a meaningful set of interactor terms. More specifically,
we start with the following model:

$$
tf_{perturbed} \sim tf_{perturbed} + tf_{perturbed}:tf_{1} + tf_{perturbed}:tf_{2} + \dots + \max(\text{non-perturbed binding})
$$

Where the response variable is the $tf_{perturbed}$ perturbation response, and the
predictor variables are binding data (e.g., calling card experiments). Predictor
terms such as $tf_{perturbed}:tf_{2}$ represent the interaction between the
$tf_{perturbed}$ binding and the binding of another transcription factor. The final
term, $\max(\text{non-perturbed binding})$, is defined as the maximum binding score
for each gene, excluding $tf_{perturbed}$. This term is included to mitigate the
effect of outlier genes which may have high binding scores across multiple
transcription factors, potentially distorting the model.

We assume that the actual relationship between the perturbation response and the
binding data is sparse and use the following steps to identify significant terms.
These terms represent a set of TFs which, when considered as interactors with the
perturbed TF, improve the inferred relationship between the binding and perturbation
data.


## Interactor sparse modeling

1. First, we apply bootstrapping to a 4-fold cross-validated Lasso model. The folds
are stratified based on the binding data domain of the perturbed TF, ensuring that
each fold better represents the domain structure.

    - We produce two variations of this model:
        
        1. A model trained using all available data.
        
        2. A model trained using only the top 10% of data based on the binding
        score of the perturbed TF.

1. For model `1.1`, we select coefficients whose 99.8% confidence interval does not
include zero. For model `1.2`, we select coefficients whose 90.0% confidence interval
does not include zero. We assume that, due to the non-linear relationship between
perturbation response and binding, interaction effects are more pronounced in the
top 10% of the data. By intersecting the coefficients from both models, we highlight
those that are predictive across the full dataset.

1. With this set of predictors, we next create an OLS model using the same 4-fold
stratified cross validation, from which we calculate an average $R^2$. Next, for each
interactor in the model, we produce two other cross validated OLS models, one by
replacing the interactor with its corresponding main effect, and another that
includes both the interaction term and the main effect. We note which of these
variants yields the best average $R^2$.

1. Finally, we report, as significant interactors, those interaction terms which, when 
retained in the model, improve the $R^2$.


***NOTE***: To generate the `response_df` and `predictors_df` below, see the first six 
cells in the LassoCV tutorial.

In [21]:
# Locations of the modeling inputs. See the first six cells of the LassoCV tutorial
# for how these two CSV files are generated.
response_csv_path = "~/htcf_local/lasso_bootstrap/erics_tfs/response_dataframe_20241105.csv"
predictors_csv_path = "~/htcf_local/lasso_bootstrap/erics_tfs/predictors_dataframe_20241105.csv"

# The first column of each CSV is used as the row index of the resulting frame.
response_df = pd.read_csv(response_csv_path, index_col=0)
predictors_df = pd.read_csv(predictors_csv_path, index_col=0)

### Step 1: Find significant predictors using all of the data

The function `get_significant_predictors()` is a wrapper of the lassoCV bootstrap 
protocol described in the LassoCV notebook. It allows using the same code to produce
both the 'all data' (step 1.1) and 'top 10%' models (step 1.2), and returns the 
significant coefficients as described in the protocol above.

In [None]:

# TODO: must have option to add max_lrb

# TODO: the top10% models likely should not have the same number of classes as the
# all data, possibly not stratified at all

# return at this point, for use later in the notebook, the response variable
# and the all data stratification classes

def get_significant_predictors(perturbed_tf, response_df, predictors_df, **kwargs):
    """
    Get the significant predictors for a given TF using the bootstrapped LassoCV
    protocol. It is capable of conducting steps 1.1 and 1.2 described above; which
    one is run depends on the keyword arguments supplied.

    :param perturbed_tf: str, the TF for which the significant predictors are to be
        identified
    :param response_df: pd.DataFrame, the response dataframe containing the response
        values
    :param predictors_df: pd.DataFrame, the predictors dataframe containing the
        predictor values
    :param kwargs: optional settings. The recognized keys are:

        - 'quantile_threshold': forwarded to generate_modeling_data()
          (default: None)
        - 'ci_percentile': forwarded to bootstrap_stratified_cv_modeling() and, as
          `ci_level`, to examine_bootstrap_coefficients() (default: 95.0)
        - 'n_bootstraps': forwarded to bootstrap_stratified_cv_modeling()
          (default: 10)

    :return sig_coef_dict: dict, a dictionary containing the significant predictors
        and their corresponding coefficients
    :raises TypeError: if `kwargs` contains a key other than the three listed above
    """

    # Reject unknown options up front. Without this check a misspelled key (e.g.
    # `ci_percentil=99.8`) would be ignored and the default silently used instead,
    # which changes which coefficients are reported as significant.
    recognized_kwargs = ("quantile_threshold", "ci_percentile", "n_bootstraps")
    unexpected_kwargs = sorted(set(kwargs) - set(recognized_kwargs))
    if unexpected_kwargs:
        raise TypeError(
            "get_significant_predictors() got unexpected keyword argument(s): "
            f"{', '.join(unexpected_kwargs)}. "
            f"Recognized keyword arguments are: {', '.join(recognized_kwargs)}"
        )

    # Resolve every option once so the same value is used wherever it is needed
    quantile_threshold = kwargs.get("quantile_threshold", None)
    ci_percentile = kwargs.get("ci_percentile", 95.0)
    n_bootstraps = kwargs.get("n_bootstraps", 10)

    y, X = generate_modeling_data(perturbed_tf,
                                  response_df,
                                  predictors_df,
                                  quantile_threshold=quantile_threshold,
                                  drop_intercept=True)

    # NOTE: fit_intercept is set to `true`
    lassoCV_estimator = LassoCV(
        fit_intercept=True,
        max_iter=10000,
        selection="random",
        random_state=42,
        n_jobs=4)

    # Fit the model to the data in order to extract the alphas_ which are generated
    # during the fitting process
    lasso_model = stratified_cv_modeling(y, X, lassoCV_estimator)

    # set the alphas_ attribute of the lassoCV_estimator to the alphas_ attribute of the
    # lasso_model fit on the whole data. This will allow the
    # bootstrap_stratified_cv_modeling function to use the same set of lambdas
    lassoCV_estimator.alphas_ = lasso_model.alphas_

    # NOTE: fit_intercept=True is passed to the internal Lasso model for bootstrap
    # iterations, along with some other settings
    logging.info("running bootstraps")
    bootstrap_lasso_output = bootstrap_stratified_cv_modeling(
        y=y,
        X=X,
        estimator=lassoCV_estimator,
        ci_percentile=ci_percentile,
        n_bootstraps=n_bootstraps,
        max_iter=10000,
        fit_intercept=True,
        selection="random",
        random_state=42)

    sig_coef_plt, sig_coef_dict = examine_bootstrap_coefficients(
        bootstrap_lasso_output,
        ci_level=ci_percentile)

    # Only the dictionary is returned; close the figure so it is not left open
    plt.close(sig_coef_plt)

    return sig_coef_dict

# Arguments shared by both step-1 models. Only 100 bootstraps are used here.
step1_shared_args = ("CBF1", response_df, predictors_df)
step1_n_bootstraps = 100

# Step 1.1: no quantile threshold, 99.8% confidence interval
all_data_sig_coef = get_significant_predictors(
    *step1_shared_args,
    ci_percentile=99.8,
    n_bootstraps=step1_n_bootstraps,
)

# Step 1.2: quantile threshold of 0.1, 90.0% confidence interval
top10_data_sig_coef = get_significant_predictors(
    *step1_shared_args,
    quantile_threshold=0.1,
    ci_percentile=90.0,
    n_bootstraps=step1_n_bootstraps,
)

## Step 2

We next need to intersect the significant coefficients (see definitions above) in both
models. In this case, a single interactor survives. Note that there are only 100
bootstraps in this example, in the interest of speed for the tutorial; we recommend no
fewer than 1000 in practice.

In [None]:
# Keep only the terms that appear among the significant coefficients of both models
intersect_coefficients = set(all_data_sig_coef) & set(top10_data_sig_coef)
print(f"The surviving coefficients are: {intersect_coefficients}")

The surviving coefficients are: {'CBF1:MET28'}


## Step 3

We next implement the method which searches alternative models, which include the
surviving interactor terms, with variations on including the main effect. In this case, 
we have only 1 term. But, we would do the following for each surviving interactor term,
if there is more than one. The goal of this process, remember, is to generate a set of
high confidence interactor terms for this TF. If the predictive power of the main effect
is equivalent or better than a model with the interactor, we consider that a low
confidence interactor effect.

In [None]:
def stratified_cv_r_squared(
    y: pd.DataFrame,
    X: pd.DataFrame,
    classes: list,
    estimator: BaseEstimator = LinearRegression(),
    skf: StratifiedKFold = StratifiedKFold(n_splits=4, shuffle=True, random_state=42),
    **kwargs
) -> float:
    """
    Calculate the average out-of-fold R-squared of an estimator using stratified
    cross validation.

    For each train/test split produced by `skf.split(X, classes)`, the estimator is
    fit on the training rows, used to predict the test rows, and scored with
    `r2_score`. The mean of the per-fold scores is returned.

    :param y: pd.DataFrame, the response values. Rows are selected positionally
        (`.iloc`) using the indices produced by `skf`
    :param X: pd.DataFrame, the predictor values. Rows are selected positionally
        (`.iloc`) using the indices produced by `skf`
    :param classes: list, the stratification labels passed to `skf.split()` as the
        target to stratify on
    :param estimator: BaseEstimator, the model to evaluate. It is cloned before
        fitting, so the object passed in is not fit by this function.
        Default: LinearRegression()
    :param skf: StratifiedKFold, the splitter used to generate the folds.
        Default: 4 shuffled splits with random_state=42
    :param kwargs: accepted but currently unused by this function

    :return: float, the mean of the per-fold R-squared values
    """

    # Work on a copy so that the caller's estimator (and the shared default
    # instance in the signature) is left unfitted
    estimator_local = clone(estimator)
    r2_scores = []

    for train_idx, test_idx in skf.split(X, classes):
        # Use train and test indices to split X and y
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        # Fit the model on this fold's training rows
        model = estimator_local.fit(X_train, y_train)

        # Predict on the held-out rows
        y_pred = model.predict(X_test)

        # Calculate R-squared and append to r2_scores
        r2_scores.append(r2_score(y_test, y_pred))

    # Return a plain Python float, matching the annotated return type
    return float(np.mean(r2_scores))

# NOTE: the CBF1 main effect is added to the formula so that its column is available
# for building the stratification classes below. The column is dropped from the
# predictors again before the model is scored.
input_model_formula = f"CBF1_LRR ~ {' + '.join(intersect_coefficients)} + CBF1"

input_model_y, input_model_X = generate_modeling_data(
    "CBF1",
    response_df,
    predictors_df,
    formula=input_model_formula,
    drop_intercept=False,
)

# TODO: for chase -- I am going to refactor the way the lasso functions work so that
# they also take 'classes' as an argument.
cbf1_predictor_column = input_model_X["CBF1"].squeeze()
cbf1_response_column = input_model_y["CBF1"].squeeze()
classes = stratification_classification(cbf1_predictor_column, cbf1_response_column)

# Remove the main effect now that the classes have been computed
input_model_X = input_model_X.drop(columns=["CBF1"])

input_model_avg_rsquared = stratified_cv_r_squared(input_model_y, input_model_X, classes)

print(input_model_avg_rsquared)

INFO:main:Removing CBF1 from the data rows (removing the perturbed TF)
INFO:main:Number of rows in the merged response/predictors: 6149
INFO:main:Generating modeling data with formula: CBF1_LRR ~ CBF1:MET28 + CBF1


0.010517208487239943




In [None]:
# set up a function that iterates over interactors and compares the 'main_effect' and
# 'main_effect + interactor' models' avg r-squared values to the input_model_avg_rsquared.
# return value is going to be a dictionary of predictors where the value denotes significance?

def get_high_confidence_interactors(
        response_variable,
        predictors,
        response_df,
        predictors_df):

    
    # at this point, for loop over the predictors. If the predictor is an interaction
    # term, then run a model where it is replaced by the main effect and replaced with
    # the main effect + interactor. After this is done (in the loop)