# SCimilarity Served Model Endpoints Usage 

## A Guided Example

This notebook is a guided walkthrough of how to use SCimilarity functions served as endpoints to make inferences.

It adapts the official [SCimilarity `cell_search` tutorial](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html): wherever SCimilarity modules have been wrapped with [MLflow Custom PyFunc](https://mlflow.org/docs/latest/ml/traditional-ml/tutorials/creating-custom-pyfunc/part2-pyfunc-components/index.html) and [served as endpoints](https://docs.databricks.com/aws/en/machine-learning/model-serving/create-manage-serving-endpoints), we call these [custom model endpoints](https://docs.databricks.com/aws/en/machine-learning/model-serving/score-custom-model-endpoints) instead of the library functions.


<!-- Requires MLDBR_14.3LTSgpu? -->

#### `start_stop_check_endpoints`
Before you begin, please manually kick off a run of the notebook `./start_stop_check_endpoints_v1`.  
_This wakes up the relevant SCimilarity endpoints so that they do not need to scale up from zero when we make inference calls, which could otherwise cause time-outs._
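You can also poll an endpoint's state yourself before making inference calls. The sketch below is hedged and assumes the Databricks serving-endpoints REST API (`GET /api/2.0/serving-endpoints/{name}`), whose response carries a `state.ready` field equal to `"READY"` when the endpoint can serve. `endpoint_is_ready`, `get_endpoint_state` and `wait_until_ready` are hypothetical helpers, not part of `start_stop_check_endpoints_v1`. Note that an endpoint with scale-to-zero enabled may still report `READY` while scaled down, so the first request can remain slow.

```python
import json
import time
import urllib.request


def endpoint_is_ready(payload: dict) -> bool:
    """True if a serving-endpoint GET response reports state.ready == "READY"."""
    return payload.get("state", {}).get("ready") == "READY"


def get_endpoint_state(databricks_instance: str, endpoint_name: str, token: str) -> dict:
    """Fetch endpoint metadata from the serving-endpoints REST API (assumed path)."""
    url = f"https://{databricks_instance}/api/2.0/serving-endpoints/{endpoint_name}"
    request = urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"})
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.loads(response.read().decode("utf-8"))


def wait_until_ready(databricks_instance, endpoint_name, token,
                     timeout_s=1200, poll_s=30, fetch=get_endpoint_state):
    """Poll until the endpoint reports READY (True) or the timeout elapses (False)."""
    deadline = time.monotonic() + timeout_s
    while True:
        if endpoint_is_ready(fetch(databricks_instance, endpoint_name, token)):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

For example, once the configuration cell below has run: `wait_until_ready(databricks_instance, gene_order_endpoint, os.environ["DATABRICKS_TOKEN"])`. The `fetch` parameter only exists so the polling logic can be exercised without a live workspace.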

### Notebook Setup + requirements installation

In [0]:
requirements = [
    "scimilarity==0.4.0",
    "typing_extensions>=4.14.0",
    "scanpy==1.11.2",
    "numcodecs==0.13.1",
    "numpy==1.26.4",
    "pandas==1.5.3",
    "mlflow==2.22.0",
    "cloudpickle==2.0.0",
    "tbb>=2021.6.0",
    # "uv"
]

for package in requirements:
    %pip install {package} --upgrade -q #-v

dbutils.library.restartPython()

In [0]:
# Import numba after installing the required packages
import numba

# Set this at the beginning of your notebook/script
numba.config.THREADING_LAYER = 'workqueue'  # built-in layer with no external dependency (not thread-safe)
# Other options include 'tbb' and 'omp' (OpenMP); the default setting tries tbb, then omp, then workqueue

In [0]:
## for nb devs -- these get overwritten wrt deployment args
dbutils.widgets.text("catalog", "genesis_workbench", "Catalog")
dbutils.widgets.text("schema", "dev_mmt_core_test", "Schema") 

dbutils.widgets.text("model_name", "SCimilarity", "Model Name") ## use this as a prefix for the model name 

dbutils.widgets.text("cache_dir", "scimilarity", "Cache dir") ## VOLUME NAME | MODEL_FAMILY 

CATALOG = dbutils.widgets.get("catalog")
SCHEMA = dbutils.widgets.get("schema")

MODEL_NAME = dbutils.widgets.get("model_name")

CACHE_DIR = dbutils.widgets.get("cache_dir")

print(f"Cache dir: {CACHE_DIR}")
cache_full_path = f"/Volumes/{CATALOG}/{SCHEMA}/{CACHE_DIR}"
print(f"Cache full path: {cache_full_path}")

In [0]:
CATALOG = CATALOG 
DB_SCHEMA = SCHEMA 

# VOLUME_NAME | PROJECT 
MODEL_FAMILY = CACHE_DIR ## CACHE_DIR #"scimilarity"

# MODEL_NAME #"SCimilarity" 

print("CATALOG :", CATALOG)
print("DB_SCHEMA :", DB_SCHEMA)
print("MODEL_FAMILY :", MODEL_FAMILY)

In [0]:
### specific to scimilarity tutorial example
# from scimilarity import CellQuery, CellEmbedding ## these are the functions we wrapped in MLflow PyFunc
from scimilarity import align_dataset, lognorm_counts
import scanpy as sc

### viz-related
from matplotlib import pyplot as plt
sc.set_figure_params(dpi=100)
plt.rcParams["figure.figsize"] = [6, 4]

import warnings
warnings.filterwarnings("ignore")

from scipy import sparse
import numpy as np
import pandas as pd
import json
import requests

import os

from collections.abc import MutableMapping  # 

### Specify Configs & Parameters

We need to specify a few configurations:
- [DATABRICKS TOKEN](https://docs.databricks.com/aws/en/dev-tools/auth/pat), read here from a Databricks secret scope

The following are set in the cell below (widget definitions are included there but commented out):
- `databricks_instance`: the workspace host name without the `https://` prefix, which the `api_inference` helper adds, e.g.  
  `<region-instance>.databricks.com`, `<region>.azuredatabricks.net`, `<region>.gcp.databricks.com`
- SCimilarity endpoints to use:
  - `gene_order_endpoint`
  - `get_embedding_endpoint`
  - `search_nearest_endpoint`
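Because `api_inference` builds its URL as `https://{databricks_instance}/serving-endpoints/{endpoint_name}/invocations`, passing a full URL such as `https://<region>.azuredatabricks.net/` would produce a malformed address. The hypothetical helpers below (`normalize_instance`, `invocation_url`; not part of the original notebook) sketch one way to accept either a bare host or a full workspace URL, assuming standard `https://` workspace URLs.

```python
from urllib.parse import urlparse


def normalize_instance(instance: str) -> str:
    """Reduce a workspace URL or host to a bare host name (no scheme, path, or trailing slash)."""
    instance = instance.strip()
    # urlparse only fills `netloc` when a scheme is present, so add one if it is missing
    if "://" not in instance:
        instance = "https://" + instance
    return urlparse(instance).netloc


def invocation_url(instance: str, endpoint_name: str) -> str:
    """Build the same invocation URL that api_inference uses."""
    return f"https://{normalize_instance(instance)}/serving-endpoints/{endpoint_name}/invocations"
```

For example, `normalize_instance("https://adb-123.4.azuredatabricks.net/")` and `normalize_instance("adb-123.4.azuredatabricks.net")` both return `"adb-123.4.azuredatabricks.net"`.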

In [0]:
## to update as widgets input
import os 
DATABRICKS_TOKEN = dbutils.secrets.get("mmt","databricks_token") 
# DATABRICKS_TOKEN = dbutils.secrets.get("<SCOPE>","<SECRET_KEY>") ## to update with APP SP/PAT token?
os.environ["DATABRICKS_TOKEN"] = DATABRICKS_TOKEN

# databricks_instance = dbutils.widgets.text("databricks_instance_url", "<region>.{azure}databricks.net", "databricks_instance")
databricks_instance = "adb-830292400663869.9.azuredatabricks.net" ## update to your workspace instance

## User to specify in Widgets -- to replace with gwb-app ones?... 
# gene_order_endpoint = dbutils.widgets.text("gene_order_endpoint_name", "mmt_scimilarity_gene_order", "gene_order_endpoint")
# get_embedding_endpoint = dbutils.widgets.text("get_embedding_endpoint_name", "mmt_scimilarity_get_embedding_v2test", "get_embedding_endpoint")
# search_nearest_endpoint = dbutils.widgets.text("search_nearest_endpoint_name", "mmt_scimilarity_search_nearest", "search_nearest_endpoint")

gene_order_endpoint = "mmt_scimilarity_gene_order"
get_embedding_endpoint = "mmt_scimilarity_get_embedding_v2test" 
search_nearest_endpoint = "mmt_scimilarity_search_nearest"

### Define Functions for Endpoint Inference and SCimilarity Calls via Endpoints

In [0]:
# Function to create Serving JSON format
def create_serving_json(data, params=None):
    return {"instances": data} if params is None else {"instances": data, "params": params}


def api_inference(databricks_instance, endpoint_name, data_input, params=None):
    url = f'https://{databricks_instance}/serving-endpoints/{endpoint_name}/invocations'
    headers = {
        'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
        'Content-Type': 'application/json'
    }
    
    if isinstance(data_input, pd.DataFrame):
        ds_dict = {'dataframe_split': data_input.to_dict(orient='split')}
        if params:
            ds_dict['params'] = params
    else:
        ds_dict = create_serving_json(data_input, params)
    
    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.post(url, headers=headers, data=data_json)
    
    if response.status_code == 200:
        try:
            return response.json()
        except json.JSONDecodeError:
            return {"error": "Invalid JSON response", "response_text": response.text}
    else:
        return {"error": f"HTTP {response.status_code}", "response_text": response.text}
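Even after warming up the endpoints, a request can occasionally fail while an endpoint scales or is busy. Below is a minimal retry sketch, assuming that HTTP 429, 503 and 504 are worth retrying; it relies only on the `{"error": "HTTP <code>", ...}` dictionaries that `api_inference` returns on failure. `call_with_retries` and its parameters are hypothetical names, and exceptions raised by `requests` itself (e.g. connection errors) are not handled here.

```python
import time


def call_with_retries(call, max_attempts=4, base_delay_s=5.0,
                      retryable=("HTTP 429", "HTTP 503", "HTTP 504"), sleep=time.sleep):
    """Invoke `call()` (e.g. a lambda wrapping api_inference) and retry on retryable errors.

    Waits base_delay_s, 2*base_delay_s, 4*base_delay_s, ... between attempts.
    Returns the last result, which may still contain an "error" key.
    """
    result = None
    for attempt in range(max_attempts):
        result = call()
        error = result.get("error") if isinstance(result, dict) else None
        if error not in retryable:
            return result  # success, or an error that retrying will not fix
        if attempt < max_attempts - 1:
            sleep(base_delay_s * (2 ** attempt))
    return result
```

For example: `call_with_retries(lambda: api_inference(databricks_instance, gene_order_endpoint, pd.DataFrame({"input": ["get_gene_order"]})))`. Exponential backoff gives a scaling endpoint progressively more time instead of hammering it at a fixed rate.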

In [0]:
def derive_gene_order(databricks_instance, endpoint_name) -> list[str]:  
    example_input = pd.DataFrame({"input": ["get_gene_order"]})
    get_gene_order = api_inference(databricks_instance, endpoint_name, data_input=example_input)
    return get_gene_order['predictions']

## gene_order = derive_gene_order(databricks_instance, endpoint_name="mmt_scimilarity_gene_order")


## updated version 
def derive_embedding(databricks_instance, endpoint_name, data_input) -> dict:
    ## data_input: AnnData with aligned + log-normalized counts in .X and cell metadata in .obs

    # densify .X (handles sparse and dense matrices) and keep one expression vector per cell
    X = data_input.X
    X_vals_dense = (X.toarray() if sparse.issparse(X) else np.asarray(X)).tolist()
    celltype_subsample_pdf = pd.DataFrame(
        [{'celltype_subsample': row} for row in X_vals_dense],
        index=data_input.obs.index,
    )

    # single-row request: both tables are serialized as split-orient JSON strings
    example_input = pd.DataFrame([{
        "celltype_sample": celltype_subsample_pdf.to_json(orient='split'),
        "celltype_sample_obs": data_input.obs.to_json(orient='split'),
    }])

    get_embeddings = api_inference(databricks_instance, endpoint_name, data_input=example_input)
    return get_embeddings  # on success, the embeddings are under get_embeddings['predictions']

## cell_embedding = derive_embedding(databricks_instance, endpoint_name="mmt_scimilarity_get_embedding", data_input=query_cell)


def search_nearest(databricks_instance, endpoint_name, embedding, params=None) -> dict:
    # avoid a mutable default argument; search k=10 neighbours unless specified
    if params is None:
        params = {'k': 10}
    example_input = pd.DataFrame({"embedding": embedding})
    return api_inference(databricks_instance, endpoint_name, data_input=example_input, params=params)

## knn = search_nearest(databricks_instance, endpoint_name, embedding, params = {'k': 10})
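To make the request body built by `derive_embedding` more concrete, here is a stdlib-only sketch of the same nesting on a toy 2-cell, 3-gene matrix. It assumes pandas' `orient='split'` layout (`columns`, `index`, `data` keys) as produced by `DataFrame.to_json(orient='split')` and `to_dict(orient='split')`. The cell IDs, values and the `sample` column are invented, and `split_json` / `build_embedding_request` are hypothetical names.

```python
import json


def split_json(index, columns, rows):
    """Serialize a table in pandas' orient='split' layout."""
    return json.dumps({"columns": columns, "index": index, "data": rows})


def build_embedding_request(cell_ids, dense_rows, obs_columns, obs_rows):
    """Mirror derive_embedding: one row whose two cells are split-orient JSON strings."""
    # "celltype_sample": a single column holding each cell's full expression vector
    celltype_sample = split_json(cell_ids, ["celltype_subsample"], [[row] for row in dense_rows])
    # "celltype_sample_obs": the per-cell metadata (AnnData .obs)
    celltype_sample_obs = split_json(cell_ids, obs_columns, obs_rows)
    # api_inference wraps a DataFrame as {"dataframe_split": df.to_dict(orient="split")}
    return {
        "dataframe_split": {
            "index": [0],
            "columns": ["celltype_sample", "celltype_sample_obs"],
            "data": [[celltype_sample, celltype_sample_obs]],
        }
    }


body = build_embedding_request(
    ["c1", "c2"], [[0.0, 1.2, 0.0], [0.5, 0.0, 2.0]], ["sample"], [["s1"], ["s1"]]
)
```

Because every cell's full aligned gene vector is embedded as text, the payload grows with the number of query cells; sending a smaller selection (for example `celltype_subsample` rather than `celltype_sample`) keeps requests small.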

### DATA Overview

To simplify this walkthrough of the SCimilarity functions served as endpoints, we [pre-computed embeddings on the aligned and log-normalized Adams et al. dataset](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#2.-Compute-embeddings), which the SCimilarity model was not trained on (`previously unseen`), and saved the result as an `.h5ad` file that we read in here.    

_We observe that the author annotations roughly cluster together in SCimilarity embedding space._ 

<!-- - ~/REPO/genesis-workbench/modules/scimilarity/scimilarity_v0.4.0_weights_v1.1/notebooks/0x_BatchInference_GetEmbedding  -->


In [0]:
# To save wait time and for tutorial visualization, `adams0` was pre-processed with `cell_query.get_embeddings`:
# the Adams et al. sample was gene-order aligned, log-normalized, and embedded with SCimilarity, then given a UMAP for display.

adams0 = sc.read_h5ad(f"/Volumes/{CATALOG}/{DB_SCHEMA}/{MODEL_FAMILY}/data/adams_etal_2020/adams0_alignedNlognormed_Xscim_umap.h5ad")

# UMAP of the aligned + log-normalized data in SCimilarity embedding space
sc.pl.umap(adams0, color="celltype_raw", legend_fontsize=10)

### DATA Preprocessing 

The [dataset](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#0.-Required-software-and-data) (with embedding clustering visualized above) has cells sourced from _Idiopathic Pulmonary Fibrosis_ (IPF) patients, _Chronic Obstructive Pulmonary Disease_ (COPD) patients, and healthy individuals. 

The data processing steps follow the original SCimilarity `cell_search` tutorial as described in [Import-and-normalize-data](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#1.-Prepare-for-SCimilarity:-Import-and-normalize-data).


In [0]:
# sampledata_path = f"/Volumes/{CATALOG}/{DB_SCHEMA}/{MODEL_FAMILY}/data/adams_etal_2020/GSE136831_subsample.h5ad"
sampledata_path = f"/Volumes/{CATALOG}/{DB_SCHEMA}/{MODEL_FAMILY}/data/adams_etal_2020/adams.h5ad" # gzip-compressed

## READ sample czi dataset H5AD file --> align + lognorm 
adams = sc.read(sampledata_path)

To align the `adams` data, we use the **`gene_order`** serving endpoint and then log-normalize the aligned counts.

In [0]:
## here we will use the gene_order endpoint

gene_order = derive_gene_order(databricks_instance, endpoint_name = gene_order_endpoint) 
aligned = align_dataset(adams, gene_order) #

lognorm = lognorm_counts(aligned)
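Conceptually, the two preprocessing calls above reorder genes and normalize counts. The following is a simplified, pure-Python illustration under stated assumptions, not SCimilarity's implementation: it assumes `align_dataset` reorders genes to the model's `gene_order` and zero-fills genes missing from the query, and that `lognorm_counts` scales each cell to 10,000 total counts before applying `log1p`. Consult the SCimilarity API documentation for the exact behavior (for example gene-overlap checks and which layer holds raw counts). `align_counts` and `lognorm_row` are hypothetical names.

```python
import math


def align_counts(query_genes, query_row, gene_order):
    """Reorder one cell's counts to `gene_order`, filling absent genes with 0 and dropping extras."""
    lookup = dict(zip(query_genes, query_row))
    return [lookup.get(gene, 0.0) for gene in gene_order]


def lognorm_row(counts, target_sum=1e4):
    """Scale counts to `target_sum` in total, then apply natural log(1 + x)."""
    total = sum(counts)
    if total == 0:
        return [0.0 for _ in counts]
    return [math.log1p(c * target_sum / total) for c in counts]


aligned_row = align_counts(["B", "A", "Z"], [3.0, 1.0, 9.0], ["A", "B", "C"])
normed = lognorm_row(aligned_row)
```

Here `aligned_row` is `[1.0, 3.0, 0.0]`: gene `Z` is not in the model's gene order and is dropped, while `C` is absent from the query and filled with zero. This is why the query must be aligned to the served model's gene order before embedding.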


Per the [SCimilarity tutorial](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#), **our interest is in studying the myofibroblasts in IPF patients**, and we would like to understand **which other diseases exhibit similar myofibroblasts**. 

To ensure we pick a cell that we are confident is a myofibroblast cell type, we will first filter for IPF samples and check the expression of some canonical fibroblast and myofibroblast markers across different samples. 

In [0]:
#### Filter sample data to "Disease" | celltype: "myofibroblast cell" | sample_ref: "DS000011735-GSM4058950"  

disease_name = "IPF"
celltype_name = "myofibroblast cell" 
sample_refid = "DS000011735-GSM4058950"
subsample_refid = "123942" #[123942, 124332, 126138]

diseasetype_ipf = lognorm[lognorm.obs["Disease"] == disease_name].copy()

celltype_myofib = diseasetype_ipf[
                                  diseasetype_ipf.obs["celltype_name"] == celltype_name
                                 ].copy()

## Extract list for sample_ref
celltype_sample = celltype_myofib[
                                  celltype_myofib.obs["sample"] == sample_refid 
                                 ].copy()

## extract specific index in celltype_sample 
celltype_subsample = celltype_sample[celltype_sample.obs.index == subsample_refid]

query_cell = celltype_sample 
# query_cell = celltype_subsample

## extract subsample query (1d array or list)
# X_vals: sparse.csr_matrix = query_cell.X
# print(X_vals)

# X_vals_dense = X_vals.toarray()
# print(X_vals_dense)

In [0]:
query_cell.obs

### Cell Types Filtered to Disease `IPF` and Study Sample Ref.
 
Per [SCimilarity tutorial](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#), we selected `cell 123942` from `IPF sample DS000011735-GSM4058950` for our query.    
We also check canonical gene markers to help ensure a high confidence myofibroblast is selected.

In [0]:
marker_list = {
    "Fibroblast": ["COL1A1", "COL3A1", "DCN", "FBLN1", "FN1", "LUM", "THY1"],
    "Myofibroblast": ["CDH11", "COMP", "CTHRC1", "ELN", "POSTN", "TNC"],
    "Smooth Muscle": ["ACTA2", "ACTG2", "DES", "MYH11", "MYL9", "TAGLN"],
}
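One simple numeric complement to the dot plot is to average each cell's log-normalized expression within each marker group. The sketch below is a hypothetical heuristic (`marker_group_means` is not part of SCimilarity or scanpy; `sc.tl.score_genes` is a more principled scanpy alternative) operating on a plain `{gene: expression}` mapping for one cell; the expression values and the reduced marker lists are invented for illustration.

```python
def marker_group_means(expression, markers):
    """Mean expression per marker group; genes absent from `expression` are skipped."""
    means = {}
    for group, genes in markers.items():
        values = [expression[g] for g in genes if g in expression]
        means[group] = sum(values) / len(values) if values else float("nan")
    return means


# Toy values for a single cell (invented for illustration)
cell = {"COL1A1": 3.0, "POSTN": 2.5, "CTHRC1": 2.0, "ACTA2": 1.0, "DES": 0.0}
toy_markers = {
    "Fibroblast": ["COL1A1", "COL3A1"],
    "Myofibroblast": ["POSTN", "CTHRC1"],
    "Smooth Muscle": ["ACTA2", "DES"],
}
scores = marker_group_means(cell, toy_markers)
```

Comparing these group means across the candidate cells gives a quick tabular view of the same pattern the dot plot shows.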

In [0]:
celltype_sample.obs["cell_id"] = celltype_sample.obs.index
# celltype_subsample.obs["cell_id"] = celltype_subsample.obs.index

dot = sc.pl.dotplot(
    celltype_sample,
    var_names=marker_list,
    groupby="cell_id",
    var_group_rotation=0,
    title="CellType & Gene Markers",
    show=False,          # keep the figure hidden for now
)

dot['mainplot_ax'].set_xlabel("Associated Genes")   # Access the main Axes from the dictionary
dot['mainplot_ax'].set_ylabel("Cell RefID")
# dot['mainplot_ax'].figure.tight_layout()            # Access the underlying Figure using dot['mainplot_ax'].figure
plt.show()

We can also visualize where this query cell is located relative to the rest of the dataset in SCimilarity embedding space.  

In [0]:
used_in_query = adams0.obs.index == "123942"
adams0.obs["used_in_query"] = used_in_query.astype(int)  # flag the query cell on the embedded dataset plotted below

# Define the colormap once
from matplotlib.colors import LinearSegmentedColormap
grey_to_red = LinearSegmentedColormap.from_list(
    "grey_to_red",
    ["lightgrey", "red"]
)

# Plot UMAP with a larger marker size and a distinct marker type for the query
f = sc.pl.umap(
    adams0, ## using the pre-computed dataset embeddings with scimilarity
    color=["used_in_query"], 
    cmap=grey_to_red, #"YlOrRd", 
    # size=100,  # Increase the marker size
    return_fig=True
)

# Add an arrow to highlight the specific point
f.axes[0].arrow(5.5, 8.75, 1, -1, head_width=0.5, head_length=0.5, color='black')  
# f.axes[0].arrow(6.1, 10.5, 1, -1, head_width=0.5, head_length=0.5); # original

# Change the marker type for the used_in_query points
used_in_query_points = adams0[adams0.obs["used_in_query"] == 1].obsm["X_umap"]
f.axes[0].scatter(used_in_query_points[:, 0], used_in_query_points[:, 1], color='red', marker='X', s=12)

plt.show()


### Search for the `k` most similar cells 

We take a similar approach to the SCimilarity tutorial and [perform searches for the N most similar cells across the 22.7M-cell reference and extract metadata for each cell](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#3.-Perform-cell-search).

<!-- For sub-second searches, with less frills, use `cq.get_nearest_neighbors`.

[Note, many of the search results will be from the Adams et al. dataset, since the most similar cells will come from the same sample or study.] -->

In this notebook, we use our **`search_nearest`** serving endpoint, which wraps `cq.search_nearest()` in an MLflow custom PyFunc, where [`CellQuery` is from `SCimilarity`](https://genentech.github.io/scimilarity/modules/cell_query.html#module-scimilarity.cell_query) and `cq = CellQuery(model_path="/opt/data/model")`.

- Input for `search_nearest()`

    - Model embedding, which we compute with our **`get_embedding`** endpoint; it similarly wraps `ce.get_embeddings()` in an MLflow custom PyFunc, where [`CellEmbedding` is from `SCimilarity`](https://genentech.github.io/scimilarity/modules/cell_embedding.html) and `ce = CellEmbedding(model_path="/opt/data/model")`.
    - **`k`**: the number of nearest neighbours to search for.

- Output of `search_nearest()`

    - **`nn_idxs`**: `neighbour cells' indices` in the `SCimilarity` reference.
    - **`nn_dists`**: the `distance` between neighbour cells and the query.
    - **`metadata`**: a dataframe containing the `metadata` associated with each neighbour cell (returned by our endpoint under the `results_metadata` key).
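To make `nn_idxs` and `nn_dists` concrete, here is a brute-force k-nearest-neighbour search over a toy reference. This is only an illustration: it assumes cosine distance (1 minus cosine similarity) and searches exhaustively, whereas the served `search_nearest` endpoint relies on SCimilarity's own prebuilt nearest-neighbour index over the full reference. `cosine_distance` and `knn_search` are hypothetical names, and the vectors are invented.

```python
import math


def cosine_distance(a, b):
    """1 - cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def knn_search(query, reference, k):
    """Return (nn_idxs, nn_dists) for the k reference rows closest to `query`."""
    dists = [(cosine_distance(query, row), i) for i, row in enumerate(reference)]
    dists.sort()  # ascending distance; ties broken by index
    top = dists[:k]
    return [i for _, i in top], [d for d, _ in top]


reference = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
nn_idxs, nn_dists = knn_search([1.0, 0.1], reference, k=2)
```

Here `nn_idxs` is `[0, 2]`: the query points almost along the first axis, so row 0 is closest and the diagonal row 2 comes next. The metadata output is then just the reference metadata looked up at these indices.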

#### **`get_embedding`**

Let's first derive the embeddings of our cell type sample, which serve as inputs to the **`search_nearest`** endpoint.

In [0]:
### derives the embedding for the given query cell using the specified endpoint.

cell_embedding = derive_embedding(databricks_instance, endpoint_name = get_embedding_endpoint, data_input = query_cell)
# cell_embedding

cell_embedding_pdf = pd.DataFrame(cell_embedding['predictions'])
cell_embedding_pdf  # the embedding column is the right-most column

#### **`search_nearest`**

Now we can perform our `knn` search using the **`search_nearest`** endpoint with the derived cell type sample `embedding(s)` and the specified `k` parameter.

In [0]:
### searches for the nearest neighbors using the derived embedding and specified endpoint.

knn = search_nearest(
                     databricks_instance,                    
                     endpoint_name = search_nearest_endpoint,
                     embedding=cell_embedding_pdf[cell_embedding_pdf.celltype_sample_index == '123942'].embedding.values[0],                             
                     params={'k': 1000}
                   )
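Because `api_inference` returns an `{"error": ...}` dictionary instead of raising on failure, it can help to check `knn` before indexing into it. A minimal sketch, assuming the response shape used in the cells that follow (`knn["predictions"][0]["results_metadata"]`); `check_knn_response` is a hypothetical helper, and the exact layout of `results_metadata` is whatever the endpoint returns.

```python
def check_knn_response(knn: dict) -> dict:
    """Return the first prediction of a search_nearest response, or raise with the endpoint's error."""
    if "error" in knn:
        detail = str(knn.get("response_text", ""))[:500]  # keep error messages short
        raise RuntimeError(f"search_nearest failed: {knn['error']}: {detail}")
    predictions = knn.get("predictions")
    if not predictions:
        raise ValueError("Response has no 'predictions'")
    first = predictions[0]
    if "results_metadata" not in first:
        raise KeyError(f"'results_metadata' missing; got keys {sorted(first)}")
    return first
```

For example, `first_prediction = check_knn_response(knn)` fails early with the endpoint's message (such as an `HTTP 504` time-out) rather than with a confusing `KeyError` further down.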

We can also check that the `knn` `predictions` contain the expected outputs (`nn_idxs`, `nn_dists`, and the metadata, returned under `results_metadata`) and inspect the associated metadata dataframe for a random sample of rows. 

In [0]:
knn["predictions"][0].keys()

In [0]:
### inspect prediction results from KNN search + samples 10 random rows from the results metadata of the KNN predictions.

pd.DataFrame(knn["predictions"][0]['results_metadata']).sample(10, random_state=963)

### Making sense of search results

With the most similar cells returned from our query, we can investigate the studies and conditions in which these cells are present. 

We define two helper functions to derive and visualize disease proportions. 

In [0]:
def calculate_disease_proportions(metadata):
    study_proportions = metadata.disease.value_counts()
    return 100 * study_proportions / study_proportions.sum()


def plot_proportions(df, title=None):
    ax = df.plot(
                 kind="barh", xlabel="percent of cells", title=title, grid=False, figsize=(8, 5), color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']  # Add a list of colors
                 # kind="bar", ylabel="Percent of cells", title=title, grid=False, figsize=(8, 6)
                )
    ax.tick_params(axis="y", labelsize=8)
    ax.invert_yaxis()  # Invert the y-axis
    from matplotlib.ticker import PercentFormatter  # values are already percentages (0-100)
    ax.xaxis.set_major_formatter(PercentFormatter(xmax=100, decimals=0))
    
    # ax.tick_params(axis="x", labelsize=8)
    # plt.xticks(rotation=60, ha='right')  # Rotate x-tick labels by X degrees
    # ax.set_yticklabels([f"{int(tick)}%" for tick in ax.get_yticks()])

    plt.tight_layout()
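The arithmetic inside `calculate_disease_proportions` is the percentage of hits per disease, 100 × count / total. A stdlib equivalent on a toy list of disease labels (invented values) shows the computation; `disease_percentages` is a hypothetical name.

```python
from collections import Counter


def disease_percentages(diseases):
    """Percent of entries per disease, most common first (like value_counts scaled to 100)."""
    counts = Counter(diseases)
    total = sum(counts.values())
    return {disease: 100 * n / total for disease, n in counts.most_common()}


toy_hits = ["IPF", "IPF", "COVID-19", "healthy"]
percentages = disease_percentages(toy_hits)
# {'IPF': 50.0, 'COVID-19': 25.0, 'healthy': 25.0}
```

Since the values are percentages, the later `query_disease_frequencies > 0.1` filter keeps diseases that account for more than 0.1% of the filtered hits.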

### Omit self-referencing results
To get a clearer view of the results, self-referencing hits are often excluded.    
Let's filter out results for cells from the same study before deriving the disease proportions where the cells are present.
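In this dataset the sample reference appears to start with the study ID (`DS000011735-GSM4058950` for study `DS000011735`), so the study to exclude could be derived rather than typed by hand. A sketch assuming this `<study>-<sample>` naming holds; `study_from_sample_ref` and `drop_same_study` are hypothetical helpers, and the toy hit records are invented.

```python
def study_from_sample_ref(sample_ref: str) -> str:
    """Take the part before the first '-' as the study ID (assumed naming convention)."""
    return sample_ref.split("-", 1)[0]


def drop_same_study(hits, query_study):
    """Keep only hits whose 'study' differs from the query's study."""
    return [hit for hit in hits if hit.get("study") != query_study]


query_study = study_from_sample_ref("DS000011735-GSM4058950")
toy_hits = [{"study": "DS000011735"}, {"study": "OTHER_STUDY"}]
kept = drop_same_study(toy_hits, query_study)
```

On the pandas results, the same filter is the boolean mask `df["study"] != query_study` used in the next cells.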

In [0]:
query_study = "DS000011735"

results_metadata = pd.DataFrame(knn["predictions"][0]['results_metadata'])
filtered_df = results_metadata.loc[results_metadata["study"] != query_study]
display(filtered_df.head())

In [0]:
query_study = "DS000011735"

results_metadata = pd.DataFrame(knn["predictions"][0]['results_metadata'])
filtered_result_metadata = results_metadata.loc[results_metadata["study"] != query_study]
query_disease_frequencies = calculate_disease_proportions(filtered_result_metadata)
query_disease_frequencies = query_disease_frequencies[query_disease_frequencies > 0.1]
plot_proportions(
    query_disease_frequencies, title="Disease Proportions for Most Similar Cells"
)

In [0]:
query_disease_frequencies

Compared to healthy samples, the hits for this query show higher proportions of cells from multiple diseases, including COVID-19, ulcerative colitis, ILDs, and cancers. However, _**results can be skewed by the imbalanced abundances of diseases and tissues in the reference**._ 

We can gauge how enriched these cells are by counting the **number of predicted myofibroblasts across diseases, for both our query hits and the full reference**. The reference predictions are [precomputed in `cq.cell_metadata` from the `CellQuery` module in SCimilarity](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#Get-reference-cell-metadata) and saved as a Delta table for easier reading in this notebook.

In [0]:
## read in scimilarity_cq_ref_metadata (reference cell metadata saved as a Delta table)
uc_path = f"{CATALOG}.{DB_SCHEMA}.scimilarity_cq_ref_metadata"
ref_metadata = spark.read.table(uc_path).toPandas()

# ref_metadata.info() #~23million rows

### Background comparison of `myofibroblasts` across diseases

_`Per SCimilarity search_nearest tutorial`_
>   To assess how enriched our results are for a disease of interest, we can visualize this imbalance:
> 
>   - Subset the full reference metadata to cells that are predicted to be `myofibroblasts`.
>   - Derive cell counts by disease state.
>   - Visualize the proportion of `myofibroblasts` within the reference collection in different diseases.
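The comparison described above can be summarized as a fold enrichment per disease: the disease's percentage among the filtered query hits divided by its percentage among reference myofibroblasts. A minimal sketch; this ratio is a common heuristic rather than a step of the SCimilarity tutorial, `fold_enrichment` is a hypothetical name, and the toy percentages are invented.

```python
def fold_enrichment(query_pct, reference_pct):
    """Ratio of query-hit percentage to reference percentage per disease.

    Diseases absent from the reference (or at 0%) are skipped to avoid division by zero.
    """
    return {
        disease: pct / reference_pct[disease]
        for disease, pct in query_pct.items()
        if reference_pct.get(disease, 0) > 0
    }


toy_query = {"IPF": 30.0, "healthy": 5.0, "rare_condition": 1.0}
toy_reference = {"IPF": 10.0, "healthy": 50.0}
ratios = fold_enrichment(toy_query, toy_reference)
# {'IPF': 3.0, 'healthy': 0.1}
```

Values above 1 mean a disease is over-represented among the query hits relative to the reference. If the query-hit and reference proportions are kept as two separate pandas Series (the notebook reuses the name `query_disease_frequencies` for both), dividing one by the other aligns on the disease index and yields NaN for diseases missing from either side.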

In [0]:
myofib_meta = ref_metadata[ref_metadata.prediction.isin(["myofibroblast cell"])]
query_disease_frequencies = calculate_disease_proportions(myofib_meta)
plot_proportions(
    query_disease_frequencies[:20],
    title="Disease Proportions for Reference Myofibroblasts",
)


_`Per SCimilarity search_nearest tutorial`_

> Over 50% of the `myofibroblasts` in the reference are from healthy tissues, while in our query they make up less than 5% of the cells most similar to our query. This [observation highlights] that these cells are in fact found in multiple conditions as [seen in the plot] earlier while they are more rare in healthy samples.

### NOTE:    
For further information on data requirements in using modules from the SCimilarity package, please refer to [guidelines](https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html#Conclusion) as well as to relevant [tutorials](https://genentech.github.io/scimilarity/tutorials.html) on the [official SCimilarity documentation](https://genentech.github.io/scimilarity/index.html). 

In [0]:
# for endpoint in endpoints:
#     # start_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)
#     stop_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)
