# Gene analysis using SHAP
- This notebook explains what **SHAP** is and how to use scaLR's SHAP scorer to compute the weight of each gene/feature for every class of the model.
- scaLR supports early stopping in SHAP analysis.

# What is SHAP?

- SHAP (SHapley Additive exPlanations) is a game-theoretic approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions.

- Learn more: https://shap.readthedocs.io/en/latest/
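To make the Shapley idea concrete, here is a brute-force computation for a toy two-feature model. It is a minimal sketch: the `shapley_values` helper, the toy model, and the choice of 0 as the value of a "missing" feature are all illustrative assumptions. The shap library does not enumerate every coalition like this; it relies on more efficient approximations, and in scaLR the background samples (`background_tensor`) play roughly the role of the baseline used here.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Dict, FrozenSet, List


def shapley_values(
    players: List[str], value: Callable[[FrozenSet[str]], float]
) -> Dict[str, float]:
    """Exact Shapley values by enumerating every coalition (exponential cost)."""
    n = len(players)
    phi = {}
    for p in players:
        others = [q for q in players if q != p]
        total = 0.0
        for size in range(len(others) + 1):
            # Classic Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                # Marginal contribution of p when added to coalition s.
                total += weight * (value(s | {p}) - value(s))
        phi[p] = total
    return phi


# Hypothetical toy model f(x1, x2) = 2*x1 + 3*x2 + x1*x2, explained at x = (1, 1).
x = {"x1": 1.0, "x2": 1.0}


def toy_value(coalition: FrozenSet[str]) -> float:
    # Features outside the coalition are replaced by the baseline value 0.
    x1 = x["x1"] if "x1" in coalition else 0.0
    x2 = x["x2"] if "x2" in coalition else 0.0
    return 2 * x1 + 3 * x2 + x1 * x2


phi = shapley_values(["x1", "x2"], toy_value)
print(phi)  # → {'x1': 2.5, 'x2': 3.5}
```

The two values sum to 6.0, which is f(1, 1) minus f(0, 0): this additivity is the "Additive" in SHAP.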

# What is early stopping in SHAP?

- scaLR processes SHAP in batches. Processing SHAP in batches gives results similar to processing all the data at once.
- For each batch, scaLR lists the top N genes and compares them with the top N genes of the previous batch. If at least `threshold` genes are shared, the patience counter is incremented. Once the top genes have stayed this similar for `patience` consecutive batches (the configured patience number), processing stops.
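The early-stopping rule described above can be sketched in plain Python. This is an illustration of the idea, not scaLR's implementation: the function name `find_stop_batch` is hypothetical, and resetting the counter when the overlap drops below `threshold` is an assumption based on the "consecutive batches" wording.

```python
from typing import Iterable, Optional, Sequence


def find_stop_batch(
    batch_top_genes: Iterable[Sequence[str]],
    threshold: int,
    patience: int,
) -> Optional[int]:
    """Return the index of the batch at which processing would stop, or None."""
    previous = None
    counter = 0
    for i, top_genes in enumerate(batch_top_genes):
        current = set(top_genes)
        if previous is not None:
            # Number of top genes shared with the previous batch.
            overlap = len(current & previous)
            if overlap >= threshold:
                counter += 1
            else:
                counter = 0  # Assumption: a dissimilar batch resets the streak.
            if counter >= patience:
                return i
        previous = current
    return None


# Toy run with top-3 genes per batch, threshold=2, patience=2.
batches = [["a", "b", "c"], ["a", "b", "d"], ["a", "b", "e"], ["x", "y", "z"]]
print(find_stop_batch(batches, threshold=2, patience=2))  # → 2
```

With `top_n_genes=100` and `threshold=95`, as in the configuration below, at least 95 of the 100 top genes must match between consecutive batches for the counter to grow.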

# How to use SHAP from scaLR

## Imports

In [None]:
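import sys
from os import path

import pandas as pd

# Make the local scaLR checkout importable (replace with your own path).
sys.path.append('/path/to/scaLR/')

# Automatically reload modules that change on disk.
%reload_ext autoreload
%autoreload 2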

In [None]:
from scalr.feature.scoring import ShapScorer
from scalr.nn.model import build_model
from scalr.utils import read_data
from scalr.analysis import Heatmap
from scalr.feature.selector import build_selector

## Configuration

In [None]:
config = {
    "dataloader": {
        "name": "SimpleDataLoader",
        "params": {
            "batch_size": 10, # Number of samples processed per batch.
            "padding": 5000
        }
    },
    "top_n_genes": 100, # Number of top genes compared between batches for early stopping.
    "background_tensor": 20, # Number of training samples used as the SHAP background. See the official SHAP documentation for details.
    "early_stop": {
        "patience": 5, # Stop once the top genes stay similar (>= threshold) for this many consecutive batches.
        "threshold": 95 # Minimum number of top genes that must match between consecutive batches.
    },
    "device": 'cuda', # Device to run on: 'cpu' or 'cuda' (GPU).
    "samples_abs_mean": True, # Take the absolute value of each sample's scores before averaging.
    "logger": "FlowLogger" # Prints the logs to the output.
}
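Because `threshold` counts matching genes out of `top_n_genes`, some combinations can never trigger early stopping. A quick sanity check before running SHAP may save a long run; the helper `check_shap_config` below is hypothetical and only inspects keys that appear in the configuration above.

```python
from typing import Any, Dict, List


def check_shap_config(config: Dict[str, Any]) -> List[str]:
    """Return human-readable problems found in a SHAP config dict (empty if none)."""
    problems = []
    top_n = config["top_n_genes"]
    early_stop = config["early_stop"]
    # More matches than compared genes is impossible, so early stop would never fire.
    if early_stop["threshold"] > top_n:
        problems.append(
            f"threshold ({early_stop['threshold']}) cannot exceed top_n_genes ({top_n})"
        )
    if early_stop["patience"] < 1:
        problems.append("patience must be at least 1")
    if config["dataloader"]["params"]["batch_size"] < 1:
        problems.append("batch_size must be at least 1")
    device = config["device"]
    if device != "cpu" and not device.startswith("cuda"):
        problems.append(f"unexpected device {device!r}; use 'cpu' or 'cuda'")
    return problems


example = {
    "dataloader": {"params": {"batch_size": 10}},
    "top_n_genes": 100,
    "early_stop": {"patience": 5, "threshold": 95},
    "device": "cuda",
}
print(check_shap_config(example))  # → []
```

In the notebook you would call `check_shap_config(config)` on the dictionary defined above.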

## Read data & Model

In [None]:
# Provide either a directory path or a direct file path.
train_data = read_data("data/train.h5ad")
test_data = read_data("data/test.h5ad")

In [None]:
# Path to a model checkpoint directory generated by the scaLR platform.
model_checkpoint = "best_model"

model_config = read_data(path.join(model_checkpoint, 'model_config.yaml'))
model_weights = path.join(model_checkpoint, 'model.pt')
mappings = read_data(path.join(model_checkpoint, 'mappings.json'))

model, _ = build_model(model_config)
model.to(config['device'])
model.load_weights(model_weights)
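The loading cell above expects three files inside the checkpoint directory. A small pre-check gives a clearer error than a failure deep inside `read_data` or `load_weights`; the helper `missing_checkpoint_files` is hypothetical, and the file names are taken from the cell above.

```python
from pathlib import Path
from typing import List

# File names used by the model-loading cell above.
REQUIRED_CHECKPOINT_FILES = ("model_config.yaml", "model.pt", "mappings.json")


def missing_checkpoint_files(checkpoint_dir: str) -> List[str]:
    """Return the required checkpoint files that are not present in checkpoint_dir."""
    root = Path(checkpoint_dir)
    return [name for name in REQUIRED_CHECKPOINT_FILES if not (root / name).is_file()]
```

For example, running `missing = missing_checkpoint_files(model_checkpoint)` before loading and raising `FileNotFoundError` when `missing` is non-empty reports exactly which files are absent.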

## Run SHAP

In [None]:
shap_scorer = ShapScorer(**config)

In [None]:
# Name of the target (label) column the model was trained to predict.
target = "Cell_Type"
shap_values = shap_scorer.get_top_n_genes_weights(model, train_data, test_data, target, mappings)

In [None]:
shap_values

In [None]:
columns = train_data.var_names
class_labels = mappings[target]['id2label']
# Keep one score per real gene, dropping any extra (padded) columns.
all_scores = shap_values[:, :len(columns)]

score_matrix = pd.DataFrame(all_scores, columns=columns, index=class_labels)
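The slicing step above can be seen on a toy array. This sketch assumes `shap_values` is a NumPy array of shape `(n_classes, n_padded_features)`; the gene names, class labels, and scores are made up purely for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-ins (hypothetical): 2 classes, 3 real genes, padded to 5 feature columns.
gene_names = ["GeneA", "GeneB", "GeneC"]
class_labels = ["T cell", "B cell"]
shap_values = np.array([
    [0.5, 0.1, 0.3, 0.0, 0.0],
    [0.2, 0.4, 0.1, 0.0, 0.0],
])

# Drop the padded columns so each remaining column maps to a real gene.
all_scores = shap_values[:, :len(gene_names)]
score_matrix = pd.DataFrame(all_scores, columns=gene_names, index=class_labels)
print(score_matrix)
```

Each row of the resulting matrix is a class and each column a gene, which is the layout the selector and heatmap cells below consume.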

In [None]:
score_matrix

## Select top N features

In [None]:
selector_config = {
    "name": "ClasswisePromoters", # Top genes per class.
    # "name": "AbsMean", # Top genes across all classes.
    "params": {
        "k": 5000
    }
}
selector, _ = build_selector(selector_config)

In [None]:
top_features = selector.get_feature_list(score_matrix)
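To build intuition for the two selector options, here is a pandas sketch of what they plausibly compute, judging only from their names and comments. These functions are hypothetical re-implementations, not scaLR's code: "promoters" is assumed to mean the most positive scores per class, and `AbsMean` is assumed to rank genes by mean absolute score across classes.

```python
from typing import Dict, List

import pandas as pd


def classwise_top_features(score_matrix: pd.DataFrame, k: int) -> Dict[str, List[str]]:
    """Hypothetical: the k genes with the highest (most positive) score for each class."""
    return {
        str(label): row.nlargest(k).index.tolist()
        for label, row in score_matrix.iterrows()
    }


def abs_mean_top_features(score_matrix: pd.DataFrame, k: int) -> List[str]:
    """Hypothetical: the k genes with the largest mean absolute score across classes."""
    return score_matrix.abs().mean(axis=0).nlargest(k).index.tolist()


scores = pd.DataFrame(
    [[0.5, -0.9, 0.3], [0.2, 0.4, -0.1]],
    columns=["GeneA", "GeneB", "GeneC"],
    index=["T cell", "B cell"],
)
print(classwise_top_features(scores, k=1))  # → {'T cell': ['GeneA'], 'B cell': ['GeneB']}
print(abs_mean_top_features(scores, k=2))  # → ['GeneB', 'GeneA']
```

Note the design difference: `GeneB` has a large negative score for `T cell`, so it is ignored by the class-wise ranking for that class but counts fully toward the absolute-mean ranking.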

## Generate heatmaps
Heatmap of feature weights with respect to each class.

- If `top_features` is a list, a single heatmap is plotted with the top genes from all classes.
- If `top_features` is a dict (containing class-wise top features), one heatmap is plotted per class, showing that class's top features with respect to the other classes.
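The list-versus-dict behaviour described above amounts to choosing which sub-matrices to draw. The helper `heatmap_panels` below is a hypothetical sketch of that dispatch only; it does no plotting, and each returned sub-matrix could then be passed to a plotting library of your choice.

```python
from typing import Dict, List, Tuple, Union

import pandas as pd


def heatmap_panels(
    score_matrix: pd.DataFrame,
    top_features: Union[List[str], Dict[str, List[str]]],
) -> List[Tuple[str, pd.DataFrame]]:
    """Return (title, sub-matrix) pairs, one per heatmap to draw."""
    if isinstance(top_features, dict):
        # One panel per class: that class's top genes, scored for every class.
        return [
            (str(label), score_matrix.loc[:, genes])
            for label, genes in top_features.items()
        ]
    # A single panel with the pooled top genes from all classes.
    return [("all classes", score_matrix.loc[:, list(top_features)])]


scores = pd.DataFrame(
    [[0.5, -0.9, 0.3], [0.2, 0.4, -0.1]],
    columns=["GeneA", "GeneB", "GeneC"],
    index=["T cell", "B cell"],
)
for title, panel in heatmap_panels(scores, {"T cell": ["GeneA"], "B cell": ["GeneB"]}):
    print(title, panel.shape)
```

Keeping every class as a row in each per-class panel is what lets a heatmap show one class's top genes "with respect to the other classes".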

In [None]:
# save_plot=True saves the plots without displaying them.
heatmap = Heatmap(top_n_genes=100, save_plot=False)

In [None]:
heatmap.generate_analysis(
    score_matrix,
    top_features,
    dirpath=".",
)