In [None]:
# configure the logger to print to console
from typing import Union
import logging
from matplotlib import pyplot as plt
import re
import random

import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import r2_score



from yeastdnnexplorer.ml_models.lasso_modeling import (
    generate_modeling_data,
    stratification_classification,
    stratified_cv_modeling,
    bootstrap_stratified_cv_modeling,
    examine_bootstrap_coefficients)


# Only surface ERROR-level (and above) log records. Note that this also hides the
# logging.info() progress messages emitted later in this notebook.
logging.basicConfig(level=logging.ERROR)

# Seed both the stdlib and the NumPy global RNGs so that reruns are reproducible.
random.seed(42)
np.random.seed(42)

# Interactor Modeling Workflow

This tutorial describes a process of modeling perturbation response by binding data
with the goal of discovering a meaningful set of interactor terms. More specifically,
we start with the following model:

$$
tf_{perturbed} \sim tf_{perturbed} + tf_{perturbed}:tf_{1} + tf_{perturbed}:tf_{2} + \dots + \max(\text{non-perturbed binding})
$$

Where the response variable is the $tf_{perturbed}$ perturbation response, and the
predictor variables are binding data (e.g., calling card experiments). Predictor
terms such as $tf_{perturbed}:tf_{2}$ represent the interaction between the
$tf_{perturbed}$ binding and the binding of another transcription factor. The final
term, $\max(\text{non-perturbed binding})$, is defined as the maximum binding score
for each gene, excluding $tf_{perturbed}$. This term is included to mitigate the
effect of outlier genes which may have high binding scores across multiple
transcription factors, potentially distorting the model.

We assume that the actual relationship between the perturbation response and the
binding data is sparse and use the following steps to identify significant terms.
These terms represent a set of TFs which, when considered as interactors with the
perturbed TF, improve the inferred relationship between the binding and perturbation
data.


## Interactor sparse modeling

1. First, we apply bootstrapping to a 4-fold cross-validated Lasso model. The folds
are stratified based on the binding data domain of the perturbed TF, ensuring that
each fold better represents the domain structure.

    - We produce two variations of this model:
        
        1. A model trained using all available data.
        
        2. A model trained using only the top 10% of data based on the binding
        score of the perturbed TF.

1. For model `1.1`, we select coefficients whose 99.8% confidence interval does not
include zero. For model `1.2`, we select coefficients whose 90.0% confidence interval
does not include zero. We assume that, due to the non-linear relationship between
perturbation response and binding, interaction effects are more pronounced in the
top 10% of the data. By intersecting the coefficients from both models, we highlight
those that are predictive across the full dataset.

1. With this set of predictors, we next fit an OLS model using the same 4-fold
stratified cross validation and calculate its average $R^2$. Next, for each
interactor in the model, we produce two other cross validated OLS models, one by
replacing the interactor with its corresponding main effect, and another that
includes both the interaction term and the main effect. We note which of these
variants yields the best average $R^2$.

1. Finally, we report, as significant interactors, those interaction terms which, when 
retained in the model, improve the $R^2$.


***NOTE***: To generate the `response_df` and `predictors_df` below, see the first six 
cells in the LassoCV tutorial.

In [3]:
# Load the precomputed response and predictor tables; the first CSV column becomes
# the row index of each frame.
# NOTE(review): these paths point into a user-specific home directory -- adjust them
# to wherever the outputs of the LassoCV tutorial were written.
response_df = pd.read_csv("~/htcf_local/lasso_bootstrap/erics_tfs/response_dataframe_20241105.csv", index_col=0)
predictors_df = pd.read_csv("~/htcf_local/lasso_bootstrap/erics_tfs/predictors_dataframe_20241105.csv", index_col=0)

### Step 1: Find significant predictors using all of the data

The function `get_significant_predictors()` is a wrapper of the lassoCV bootstrap 
protocol described in the LassoCV notebook. It allows using the same code to produce
both the 'all data' (step 1.1) and 'top 10%' models (step 1.2), and returns the 
significant coefficients as described in the protocol above.

In [4]:

# TODO: must have option to add max_lrb

# TODO: the top10% models likely should not have the same number of classes as the
# all data, possibly not stratified at all

# return at this point, for use later in the notebook, the response variable
# and the all data stratification classes

def get_significant_predictors(perturbed_tf, response_df, predictors_df, **kwargs):
    """
    Identify the significant predictors for a single perturbed TF using the
    bootstrapped, stratified-CV LassoCV protocol. Depending on the keyword arguments
    supplied, this carries out either step 1.1 or step 1.2 described above.

    :params perturbed_tf: str, the TF for which significant predictors are sought
    :params response_df: pd.DataFrame, the dataframe holding the response values
    :params predictors_df: pd.DataFrame, the dataframe holding the predictor values
    :params kwargs: optional settings read by this function:
        - 'quantile_threshold' (default None), forwarded to generate_modeling_data()
        - 'ci_percentile' (default 95.0), forwarded to
          bootstrap_stratified_cv_modeling() and, as `ci_level`, to
          examine_bootstrap_coefficients()
        - 'n_bootstraps' (default 10, a value only suitable for testing),
          forwarded to bootstrap_stratified_cv_modeling()

    :return: a tuple `(sig_coef_dict, y, stratification_classes)` where
        `sig_coef_dict` is the second value returned by
        examine_bootstrap_coefficients(), `y` is the response returned by
        generate_modeling_data(), and `stratification_classes` is the result of
        stratification_classification() on the perturbed TF's predictor column
    """
    # Read the optional settings once, up front.
    quantile_threshold = kwargs.get("quantile_threshold", None)
    ci_percentile = kwargs.get("ci_percentile", 95.0)
    n_bootstraps = kwargs.get("n_bootstraps", 10)

    # The predictor column for the perturbed TF is its name without a replicate
    # suffix such as "_rep1".
    predictor_variable = re.sub(r"_rep\d+", "", perturbed_tf)

    y, X = generate_modeling_data(
        perturbed_tf,
        response_df,
        predictors_df,
        quantile_threshold=quantile_threshold,
        drop_intercept=True,
    )

    stratification_classes = stratification_classification(
        X[predictor_variable].squeeze(), y.squeeze()
    )

    # NOTE: the intercept is fit by the estimator (fit_intercept=True)
    lassoCV_estimator = LassoCV(
        fit_intercept=True,
        max_iter=10000,
        selection="random",
        random_state=42,
        n_jobs=4,
    )

    # A first fit on the data is done only to obtain the alphas_ grid generated
    # during fitting. Copying it onto the estimator lets the bootstrap procedure
    # reuse the same set of lambdas.
    initial_fit = stratified_cv_modeling(
        y, X, stratification_classes, lassoCV_estimator
    )
    lassoCV_estimator.alphas_ = initial_fit.alphas_

    # NOTE: fit_intercept=True and the other settings below are passed on to the
    # internal Lasso model used in the bootstrap iterations.
    # NOTE(review): stratification_classes is not passed to
    # bootstrap_stratified_cv_modeling() here -- confirm that it derives its own
    # stratification internally.
    logging.info("running bootstraps")
    bootstrap_lasso_output = bootstrap_stratified_cv_modeling(
        y=y,
        X=X,
        estimator=lassoCV_estimator,
        ci_percentile=ci_percentile,
        n_bootstraps=n_bootstraps,
        max_iter=10000,
        fit_intercept=True,
        selection="random",
        random_state=42,
    )

    sig_coef_plt, sig_coef_dict = examine_bootstrap_coefficients(
        bootstrap_lasso_output, ci_level=ci_percentile
    )

    # The figure is not needed here; close it so it is neither rendered nor kept
    # open.
    plt.close(sig_coef_plt)

    return sig_coef_dict, y, stratification_classes

# Step 1.1: model on all of the data, keeping coefficients whose 99.8% confidence
# interval excludes zero. Only 100 bootstraps are used to keep the tutorial fast.
all_data_sig_coef, all_y, all_stratification_classes = get_significant_predictors(
    "CBF1",
    response_df,
    predictors_df,
    ci_percentile=99.8,
    n_bootstraps=100)

# Step 1.2: the "top 10%" model (quantile_threshold=0.1), keeping coefficients whose
# 90.0% confidence interval excludes zero.
top10_data_sig_coef, top10_y, top10_stratification_classes = get_significant_predictors(
    "CBF1",
    response_df,
    predictors_df,
    quantile_threshold=0.1,
    ci_percentile=90.0,
    n_bootstraps=100)

Significant coefficients for 99.8, where intervals are entirely above or below ±0.0:
CBF1:SWI6: (-0.1485865745733975, -0.013427250137445276)
CBF1:RGM1: (0.019074716568397616, 0.17308781555585018)
CBF1:ARG81: (-0.2192689594801929, -0.03363015605202929)
CBF1:MET28: (0.08239014045383972, 0.21725137611340745)
CBF1:SUT1: (0.0013204470995899372, 0.20979309187315287)
CBF1:AZF1: (-0.15859562883092804, -0.01911590208260141)
CBF1:GAL4: (0.08855249921344015, 0.30769402877789037)
CBF1:MSN2: (0.09345201116818, 0.2768407901563067)
Significant coefficients for 90.0, where intervals are entirely above or below ±0.0:
CBF1:MET28: (0.0702883319201939, 0.21243861704095554)


## Step 2

We next need to intersect the significant coefficients (see definitions above) in both
models. In this case, a single interactor survives. Note that there are only 100
bootstraps in this example in the interest of speed for the tutorial; we recommend no
fewer than 1000 in practice.

In [5]:
# Keep only the terms that were called significant in BOTH the all-data model and
# the top-10% model.
intersect_coefficients = (
    set(all_data_sig_coef.keys()) & set(top10_data_sig_coef.keys())
)
print(f"The surviving coefficients are: {intersect_coefficients}")

The surviving coefficients are: {'CBF1:MET28'}


## Step 3

We next implement the method which searches alternative models, which include the
surviving interactor terms, with variations on including the main effect. In this case, 
we have only 1 term. But, we would do the following for each surviving interactor term,
if there is more than one. The goal of this process, remember, is to generate a set of
high confidence interactor terms for this TF. If the predictive power of the main effect
is equivalent or better than a model with the interactor, we consider that a low
confidence interactor effect.

In [None]:
def stratified_cv_r2(
    y: pd.DataFrame,
    X: pd.DataFrame,
    classes: np.ndarray,
    estimator: BaseEstimator = LinearRegression(), 
    skf: StratifiedKFold = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)
) -> float:
    """
    Compute the mean out-of-fold r-squared of an estimator under stratified
    cross validation. With the defaults this is a 4-fold stratified CV of a
    LinearRegression model.

    :param y: The response variable. See generate_modeling_data()
    :param X: The predictor variables. See generate_modeling_data()
    :param classes: the stratification classes for the data, used to build the folds
    :param estimator: the estimator to fit in each fold; it is cloned, so the object
        passed in is not fitted. Defaults to a LinearRegression() model.
    :param skf: the StratifiedKFold splitter. Defaults to a 4-fold stratified CV
        with shuffle=True and random_state=42.

    :return: the r-squared averaged over the CV folds
    """
    # Work on a clone so the caller's (or the shared default) estimator is untouched.
    fold_estimator = clone(estimator)

    def _fold_r2(train_rows, test_rows) -> float:
        """Fit on the training rows and score r-squared on the held-out rows."""
        fitted = fold_estimator.fit(X.iloc[train_rows], y.iloc[train_rows])
        held_out_predictions = fitted.predict(X.iloc[test_rows])
        return r2_score(y.iloc[test_rows], held_out_predictions)

    return np.mean(
        [_fold_r2(train_rows, test_rows)
         for train_rows, test_rows in skf.split(X, classes)]
    )


def test_interactor_variants(
        intersect_coefficients,
        interactor,
        **kwargs) -> dict[str | Union[str, float]]:
    """
    For a given interactor, replace the term in the formula with two variants:
        1. the main effect
        2. the main effect + interactor
    For each of the variants, calculate the average stratified CV r-squared with 
    stratified_cv_r2().

    :param intersect_coefficients: the set of coefficients that are determined to be
        significant, expected to be from either a bootstrap procedure on a LassoCV
        model on a full partition of the data and the top 10% by perturbed binding, or
        LassoCV followed by backwards selection by adj-rsquared.
    :param interactor: the interactor term to be tested
    :param kwargs: additional arguments to be passed to the function. Expected
        arguments are 'y', 'X', and 'stratification_classes'. See stratified_cv_r2()
        for more information.

    :return: a list with three dict entries, each with key
        'interactor', 'variant', 'avg_r2'
    """
    y = kwargs.get("y")
    if y is None:
        raise ValueError("y must be passed as a keyword argument")
    X = kwargs.get("X")
    if X is None:
        raise ValueError("X must be passed as a keyword argument")
    stratification_classes = kwargs.get("stratification_classes")
    if stratification_classes is None:
        raise ValueError("stratification_classes must be passed as a keyword argument")
    
    main_effect = interactor.split(":")[1]
    interactor_formula_variants = [main_effect, [main_effect, interactor]]

    output = []
    for variant in interactor_formula_variants:
        # replace the interactor term in the formula with the variant
        variant_predictors = ([term for term in intersect_coefficients if term != interactor] 
                              + [variant] if isinstance(variant, str) else variant)
        # conduct the stratified CV r-squared calculation with the formula variant
        input_model_avg_rsquared = stratified_cv_r2(
            y, 
            X.loc[:,variant_predictors],
            stratification_classes)
        
        # append the results to the output list
        output.append({
            "interactor": interactor,
            "variant": variant,
            "avg_r2": input_model_avg_rsquared
        })
        
    return output

def get_interactor_importance(
        y: pd.DataFrame,
        full_X: pd.DataFrame,
        stratification_classes: np.ndarray,
        intersect_coefficients: set,
        )-> tuple[float, list[dict[str, Union[str, list, float]]]]:
    """
    For each interactor in the intersect_coefficients, run test_interactor_variants()
    and compare the variants' avg_rsquared to the input_model_avg_rsquared. If the
    best variant of the interactor term beats the input model, record it. Return the
    `intersect_coefficients` model's avg R-squared and the list of interaction
    alternatives that, when that alternative replaces a single interaction term,
    improve the rsquared.

    :param y: the response variable
    :param full_X: the full predictor matrix. It must contain a column for every
        term in `intersect_coefficients` and for the main effect of every interactor
    :param stratification_classes: the stratification classes for the data
    :param intersect_coefficients: the set of coefficients that are determined to be
        significant, expected to be from either a bootstrap procedure on a LassoCV
        model on a full partition of the data and the top 10% by perturbed binding, or
        LassoCV followed by backwards selection by adj-rsquared.
    
    :return: a tuple with the first element being the input_model_avg_rsquared and the
        second element being a list of dictionaries with keys 'interactor', 'variant',
        and 'avg_r2'. The list is empty when no variant improves on the input model.
    """
    # pandas does not accept a set as a `.loc` indexer (deprecated, then an error),
    # so materialize the terms as a list once and use it throughout
    model_terms = list(intersect_coefficients)

    input_model_avg_rsquared = stratified_cv_r2(
        y,
        full_X.loc[:, model_terms],
        stratification_classes)

    # for each interactor in the intersect_coefficients, run
    # test_interactor_variants() and compare the variants' avg_rsquared to the
    # input_model_avg_rsquared. Record the best performing variant, but only when
    # it strictly improves on the input model.
    interactor_results = []
    for interactor in model_terms:
        # terms without ':' are main effects; there is nothing to test for them
        if ":" not in interactor:
            continue

        interactor_variant_results = test_interactor_variants(
            model_terms,
            interactor,
            y=y,
            X=full_X,
            stratification_classes=stratification_classes)

        best_variant = max(interactor_variant_results, key=lambda x: x["avg_r2"])

        if best_variant["avg_r2"] > input_model_avg_rsquared:
            interactor_results.append(best_variant)

    return input_model_avg_rsquared, interactor_results


The full model avg r-squared is 0.010517208487239915
The interactor results are: []


  full_X.loc[:,intersect_coefficients],


## Conduct the analysis

The functions above will be moved into source code once we settle on final forms.
Below is how I imagine carrying out the analysis on a single TF.

In [None]:

# Derive the main effect behind every surviving term: an interaction "a:b"
# contributes "b", while a term without ":" contributes itself.
main_effects = [
    term.split(":")[1] if ":" in term else term
    for term in intersect_coefficients
]

# The candidate columns are the surviving terms followed by those main effects.
interactor_terms_and_main_effects = [*intersect_coefficients, *main_effects]

# Build one model matrix holding all candidate columns. This full matrix is never
# modeled directly -- only subsets of its columns are.
_, full_X = generate_modeling_data(
    'CBF1',
    response_df,
    predictors_df,
    formula = f"~ {' + '.join(interactor_terms_and_main_effects)}",
    drop_intercept=False
)

# For every interaction term among the surviving coefficients, two variants are
# scored: the main effect alone, and the main effect plus the interaction. A variant
# is returned only if its avg r-squared exceeds that of the intersect model. In
# this example nothing is returned, i.e. the original intersect_coefficients are
# the best model.
full_avg_rsquared, x = get_interactor_importance(
    all_y,
    full_X,
    all_stratification_classes,
    intersect_coefficients
)

print(f"The full model avg r-squared is {full_avg_rsquared}")
print(f"The interactor results are: {x}")