ref: https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html

<!-- Requires MLDBR_13.3LTSgpu? -->

In [0]:
import requests

# Define the endpoint names | use generic names from deployment+registration
# These are the serving endpoints that the start/stop/status helpers below operate on.
endpoints = [
             "mmt_scimilarity_gene_order",
            #  "mmt_scimilarity_get_embedding",
             "mmt_scimilarity_get_embedding_v2test", #updated signature
             "mmt_scimilarity_search_nearest"
            ]

# Define the Databricks instance and token
# The instance is the workspace host name only (no scheme); the helpers prepend "https://".
databricks_instance = "adb-830292400663869.9.azuredatabricks.net"
# Token is read from the "mmt" secret scope, key "databricks_token".
DATABRICKS_TOKEN = dbutils.secrets.get("mmt", "databricks_token") ## update to user's SP/PAT 

# Function to start an endpoint
def start_endpoint(databricks_instance, endpoint_name, token, timeout=60):
    """Send a POST to the endpoint's ``config:start`` URL and print the outcome.

    Args:
        databricks_instance: Workspace host name without scheme.
        endpoint_name: Name of the serving endpoint.
        token: Bearer token placed in the ``Authorization`` header.
        timeout: Seconds to wait for the HTTP call. ``requests`` applies no
            timeout by default, so without this a stalled connection would
            block the notebook indefinitely.

    Returns:
        None. HTTP 200 is reported as success; any other status code is
        printed together with the response body.
    """
    url = f"https://{databricks_instance}/api/2.0/serving-endpoints/{endpoint_name}/config:start"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    response = requests.post(url, headers=headers, timeout=timeout)
    if response.status_code == 200:
        print(f"Successfully started endpoint: {endpoint_name}")
    else:
        print(f"Failed to start endpoint: {endpoint_name}, Status Code: {response.status_code}, Response: {response.text}")


# Function to stop an endpoint
def stop_endpoint(databricks_instance, endpoint_name, token, timeout=60):
    """Send a POST to the endpoint's ``config:stop`` URL and print the outcome.

    Args:
        databricks_instance: Workspace host name without scheme.
        endpoint_name: Name of the serving endpoint.
        token: Bearer token placed in the ``Authorization`` header.
        timeout: Seconds to wait for the HTTP call. ``requests`` applies no
            timeout by default, so without this a stalled connection would
            block the notebook indefinitely.

    Returns:
        None. HTTP 200 is reported as success; any other status code is
        printed together with the response body.
    """
    url = f"https://{databricks_instance}/api/2.0/serving-endpoints/{endpoint_name}/config:stop"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    response = requests.post(url, headers=headers, timeout=timeout)
    if response.status_code == 200:
        print(f"Successfully stopped endpoint: {endpoint_name}")
    else:
        print(f"Failed to stop endpoint: {endpoint_name}, Status Code: {response.status_code}, Response: {response.text}")


# Function to check the status of an endpoint
def check_endpoint_status(databricks_instance, endpoint_name, token, timeout=60):
    """Fetch a serving endpoint's description and return its ``state.ready`` value.

    Args:
        databricks_instance: Workspace host name without scheme.
        endpoint_name: Name of the serving endpoint.
        token: Bearer token placed in the ``Authorization`` header.
        timeout: Seconds to wait for the HTTP call before giving up.

    Returns:
        The value found at ``state.ready`` in the JSON response, or the string
        ``"Unknown"`` when the key is missing, the request fails at the
        transport level (including a timeout), or the status code is not 200.
    """
    url = f"https://{databricks_instance}/api/2.0/serving-endpoints/{endpoint_name}"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        # Treat transport errors like any other failed check so that a polling
        # caller can simply retry instead of crashing on a transient glitch.
        print(f"Failed to check status of endpoint: {endpoint_name}, Error: {exc}")
        return "Unknown"

    if response.status_code == 200:
        status = response.json().get("state", {}).get("ready", "Unknown")
        print(f"Endpoint: {endpoint_name}, Status: {status}")
        return status
    else:
        print(f"Failed to check status of endpoint: {endpoint_name}, Status Code: {response.status_code}, Response: {response.text}")
        return "Unknown"
    

## Example USAGE
# Loop through the endpoints and start/stop them OR check their status
# for endpoint in endpoints:
#     start_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)
    # stop_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)
    # check_endpoint_status(databricks_instance, endpoint, DATABRICKS_TOKEN)

In [0]:
# Stop every endpoint listed in `endpoints`; swap the commented line to start them instead.
# NOTE(review): as written this cell stops the endpoints while the next cell waits for
# them to report READY -- presumably the start call should be the active one when the
# notebook is run top to bottom; confirm.
for endpoint in endpoints:
    # start_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)
    stop_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)

In [0]:
# Endpoints can take ~15-20+ minutes (sometimes considerably longer) to become ready.

## TODO: scaled-to-zero endpoints need to be pinged (sent a request) instead.

import time

# Function to wait until all endpoints are ready
def wait_until_endpoints_ready(databricks_instance, endpoints, token, check_interval=60, max_wait=None):
    """Poll the endpoints until every one of them reports the status ``"READY"``.

    Each pass checks the endpoints in order and stops at the first one that is
    not ready; the function then sleeps ``check_interval`` seconds and starts a
    new pass.

    Args:
        databricks_instance: Workspace host name without scheme.
        endpoints: Iterable of endpoint names to check.
        token: Bearer token forwarded to ``check_endpoint_status``.
        check_interval: Seconds to sleep between passes.
        max_wait: Optional upper bound, in seconds, on the total waiting time.
            ``None`` (the default) keeps the original behaviour of waiting
            indefinitely, which never returns if an endpoint cannot become
            ready (e.g. bad token or endpoint name).

    Raises:
        TimeoutError: If ``max_wait`` is set and elapses before all endpoints
            are ready.
    """
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while True:
        pending = None
        for endpoint in endpoints:
            status = check_endpoint_status(databricks_instance, endpoint, token)
            if status != "READY":
                pending = endpoint
                break

        if pending is None:
            break

        # Give up before sleeping again once the optional time budget is spent.
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Endpoint {pending} was not ready after {max_wait} seconds"
            )

        print(f"Endpoint {pending} is not ready yet. Checking again in {check_interval} seconds...")
        time.sleep(check_interval)

    print("All endpoints are ready!")

### Block until every endpoint's status is "READY", sleeping 60 seconds between passes.
wait_until_endpoints_ready(databricks_instance, endpoints, DATABRICKS_TOKEN, check_interval=60)

TODO: move the endpoint start/stop/wait cells above into a separate notebook.

In [0]:
# !cat /Volumes/genesis_workbench/dev_mmt_core_test/scimilarity/mlflow_requirements/SCimilarity_Search_Nearest/requirements.txt

In [0]:
## used the saved requirements 
# /Volumes/genesis_workbench/dev_mmt_core_test/scimilarity/mlflow_requirements/SCimilarity_Get_Embedding/requirements.txt
# /Volumes/genesis_workbench/dev_mmt_core_test/scimilarity/mlflow_requirements/SCimilarity_Search_Nearest/requirements.txt

# Pinned package specifiers to install into this notebook's Python environment.
requirements = [
    "scimilarity==0.4.0",
    "typing_extensions>=4.14.0",
    "scanpy==1.11.2", #
    "numcodecs==0.13.1", #
    "numpy==1.26.4",
    "pandas==1.5.3",
    "mlflow==2.22.0",
    "cloudpickle==2.0.0", #
    "tbb>=2021.6.0",    
    # "uv"
]

# Install each specifier with its own `%pip install ... --upgrade -v` invocation.
# NOTE(review): separate --upgrade installs may let a later install change a version
# pinned by an earlier one; a single %pip call with all specifiers may be safer -- confirm.
for package in requirements:
    %pip install {package} --upgrade -v

# Restart the Python process after installing the packages.
dbutils.library.restartPython()

In [0]:
# Import numba after installing the required packages
import numba

# Set this at the beginning of your notebook/script
# Selects numba's 'workqueue' threading layer for this session.
numba.config.THREADING_LAYER = 'workqueue'  # Most compatible option
# Other options: 'omp' (OpenMP) or 'tbb' (default)

In [0]:
## for nb devs -- these get overwritten wrt deployment args
# Declare text widgets (name, default value, label) for the notebook parameters.
dbutils.widgets.text("catalog", "genesis_workbench", "Catalog")
dbutils.widgets.text("schema", "dev_mmt_core_test", "Schema") 

dbutils.widgets.text("model_name", "SCimilarity", "Model Name") ## use this as a prefix for the model name ?

dbutils.widgets.text("cache_dir", "scimilarity", "Cache dir") ## VOLUME NAME | MODEL_FAMILY 

# Read the current widget values into module-level constants.
CATALOG = dbutils.widgets.get("catalog")
SCHEMA = dbutils.widgets.get("schema")

MODEL_NAME = dbutils.widgets.get("model_name")

CACHE_DIR = dbutils.widgets.get("cache_dir")

# The cache location is the volume path /Volumes/<catalog>/<schema>/<cache_dir>.
print(f"Cache dir: {CACHE_DIR}")
cache_full_path = f"/Volumes/{CATALOG}/{SCHEMA}/{CACHE_DIR}"
print(f"Cache full path: {cache_full_path}")

In [0]:
# Aliases used by the path templates further down. The CATALOG line is a self-assignment
# (no effect); the trailing comments record earlier hard-coded values.
CATALOG = CATALOG #"mmt"
DB_SCHEMA = SCHEMA #"tests" | "genesiswb"

# VOLUME_NAME | PROJECT 
# MODEL_FAMILY is the widget-provided cache dir and is used as the volume name in paths.
MODEL_FAMILY = CACHE_DIR ## CACHE_DIR #"scimilarity"

# MODEL_NAME #"SCimilarity" 

print("CATALOG :", CATALOG)
print("DB_SCHEMA :", DB_SCHEMA)
print("MODEL_FAMILY :", MODEL_FAMILY)

In [0]:
# from scimilarity import CellQuery, CellEmbedding ## these are the functions we wrapped in MLflow PyFunc
from scimilarity import align_dataset, lognorm_counts

import scanpy as sc
from scipy import sparse

import numpy as np
import pandas as pd
import json
import requests

import os

from collections.abc import MutableMapping  # 

In [0]:
# import scanpy as sc
# (scanpy is already imported as `sc` in an earlier cell)
from matplotlib import pyplot as plt

# Figure defaults for scanpy and matplotlib plots.
sc.set_figure_params(dpi=100)
plt.rcParams["figure.figsize"] = [6, 4]

import warnings

# Silences ALL warnings for the rest of the session, including ones that may
# point at real problems; remove while debugging.
warnings.filterwarnings("ignore")

In [0]:
def create_serving_json(data, params=None):
    """Wrap ``data`` in the ``{"instances": ...}`` serving payload.

    A ``"params"`` entry is added whenever ``params`` is not ``None`` (an empty
    dict is therefore still included).
    """
    payload = {"instances": data}
    if params is not None:
        payload["params"] = params
    return payload


def api_inference(databricks_instance, endpoint_name, data_input, params=None, timeout=None):
    """POST ``data_input`` to a serving endpoint's ``/invocations`` URL.

    The bearer token is read from the ``DATABRICKS_TOKEN`` environment variable.

    Args:
        databricks_instance: Workspace host name without scheme.
        endpoint_name: Name of the serving endpoint.
        data_input: A pandas DataFrame (sent as ``dataframe_split``) or any
            other JSON-serialisable object (sent as ``instances``).
        params: Optional parameters added to the payload under ``"params"``.
            For DataFrame input they are only added when truthy.
        timeout: Optional timeout in seconds passed to ``requests.post``;
            ``None`` (the default) waits without limit, as before.

    Returns:
        The decoded JSON body on HTTP 200. Otherwise a dict with an ``"error"``
        description and the raw ``"response_text"``.
    """
    url = f'https://{databricks_instance}/serving-endpoints/{endpoint_name}/invocations'
    headers = {
        'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
        'Content-Type': 'application/json'
    }

    if isinstance(data_input, pd.DataFrame):
        ds_dict = {'dataframe_split': data_input.to_dict(orient='split')}
        if params:
            ds_dict['params'] = params
    else:
        ds_dict = create_serving_json(data_input, params)

    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.post(url, headers=headers, data=data_json, timeout=timeout)

    if response.status_code != 200:
        return {"error": f"HTTP {response.status_code}", "response_text": response.text}

    try:
        return response.json()
    except ValueError:
        # The decode error raised by Response.json() is not always a subclass of
        # the stdlib json.JSONDecodeError (e.g. when requests is backed by
        # simplejson), but it is always a ValueError.
        return {"error": "Invalid JSON response", "response_text": response.text}

In [0]:
# def derive_gene_order(endpoint_name = "mmt_scimilarity_gene_order") -> list[str]:  
def derive_gene_order(databricks_instance, endpoint_name) -> list[str]:
    """Query the gene-order endpoint and return the ``'predictions'`` entry of its response.

    The request is a one-row DataFrame whose ``input`` column holds the literal
    ``"get_gene_order"``.
    """
    request_df = pd.DataFrame({"input": ["get_gene_order"]})
    response = api_inference(
        databricks_instance,
        endpoint_name,
        data_input=request_df,
    )
    return response['predictions']

# gene_order = derive_gene_order()

## updated version 
def derive_embedding(databricks_instance, endpoint_name, data_input) -> dict:
    """Send the cells in ``data_input`` to the embedding endpoint.

    Args:
        databricks_instance: Workspace host name without scheme.
        endpoint_name: Name of the embedding serving endpoint.
        data_input: Object exposing ``.X`` (expression matrix, sparse or dense)
            and ``.obs`` (a DataFrame indexed by cell), e.g. an AnnData.

    Returns:
        Whatever ``api_inference`` returns for the request, i.e. the raw
        response dict (not a DataFrame).
    """
    expression = data_input.X
    # Sparse matrices must be densified before they can be JSON-encoded;
    # dense arrays are used as they are.
    if hasattr(expression, "toarray"):
        expression = expression.toarray()
    dense_rows = expression.tolist()

    # One row per cell, each holding that cell's full expression vector.
    celltype_subsample_pdf = pd.DataFrame(
        [{'celltype_subsample': row} for row in dense_rows],
        index=data_input.obs.index,
    )

    # Both frames travel as 'split'-oriented JSON strings inside a one-row DataFrame.
    example_input = pd.DataFrame([{
        "celltype_sample": celltype_subsample_pdf.to_json(orient='split'),
        "celltype_sample_obs": data_input.obs.to_json(orient='split'),
    }])

    return api_inference(databricks_instance, endpoint_name, data_input=example_input)

# cell_embedding = derive_embedding(endpoint_name = "mmt_scimilarity_get_embedding", data_input = X_vals)


def search_nearest(databricks_instance, endpoint_name, embedding, params=None) -> dict:
    """Send an embedding to the nearest-neighbour search endpoint.

    Args:
        databricks_instance: Workspace host name without scheme.
        endpoint_name: Name of the search serving endpoint.
        embedding: Values placed in the ``embedding`` column of the request frame.
        params: Optional search parameters. When omitted, ``{'k': 10}`` is used;
            a fresh dict is built on every call so the default is never shared
            between calls.

    Returns:
        The result of ``api_inference`` for the request.
    """
    if params is None:
        params = {'k': 10}
    example_input = pd.DataFrame({"embedding": embedding})
    return api_inference(databricks_instance, endpoint_name, data_input=example_input, params=params)

## knn_output = search_nearest(databricks_instance, endpoint_name, embedding, params = {'k': 10})

In [0]:
import os 
# Read the token from the "mmt" secret scope and export it, because api_inference
# takes its bearer token from the DATABRICKS_TOKEN environment variable.
DATABRICKS_TOKEN = dbutils.secrets.get("mmt","databricks_token") ## to update with APP SP/PAT token?
os.environ["DATABRICKS_TOKEN"] = DATABRICKS_TOKEN

# Workspace host name (no scheme); same value as set in the first cell.
databricks_instance = "adb-830292400663869.9.azuredatabricks.net"

## testing endpoints -- to replace with gwb-app ones... 
gene_order_endpoint = "mmt_scimilarity_gene_order"
# get_embedding_endpoint = "mmt_scimilarity_get_embedding" 
get_embedding_endpoint = "mmt_scimilarity_get_embedding_v2test" 
search_nearest_endpoint = "mmt_scimilarity_search_nearest"

In [0]:
from scimilarity import CellQuery
# Build the model location from the widget-derived names, like the sample-data path,
# so the notebook follows the catalog/schema/volume it was deployed with. With the
# widget defaults this resolves to
# /Volumes/genesis_workbench/dev_mmt_core_test/scimilarity/model/model_v1.1/
model_path = f"/Volumes/{CATALOG}/{DB_SCHEMA}/{MODEL_FAMILY}/model/model_v1.1/"
cq = CellQuery(model_path)

#### DATA related


In [0]:
# Sample dataset location inside the configured volume.
sampledata_path = f"/Volumes/{CATALOG}/{DB_SCHEMA}/{MODEL_FAMILY}/data/adams_etal_2020/GSE136831_subsample.h5ad"

# Read the data, align it to the local model's gene order, then log-normalise.
adams0 = sc.read(sampledata_path)
adams0 = align_dataset(adams0, cq.gene_order)
adams = lognorm_counts(adams0)

# Embeddings are computed locally with the CellQuery model and stored on adams0.
# NOTE(review): the embedding, neighbours and UMAP below are attached to `adams0`,
# while `adams` is what gets plotted afterwards. This only lines up if
# lognorm_counts modifies and returns the same object it was given -- confirm.
adams0.obsm["X_scimilarity"] = cq.get_embeddings(adams0.X)

sc.pp.neighbors(adams0, use_rep="X_scimilarity")
sc.tl.umap(adams0)

In [0]:
sc.pl.umap(adams, color="celltype_raw", legend_fontsize=8)

In [0]:
# Same sample dataset as above, this time prepared with the endpoint-provided gene order.
sampledata_path = f"/Volumes/{CATALOG}/{DB_SCHEMA}/{MODEL_FAMILY}/data/adams_etal_2020/GSE136831_subsample.h5ad"

## READ sample czi dataset H5AD file + align + lognorm 
# NOTE(review): this rebinds `adams` to a fresh read of the file, so anything computed
# on the earlier `adams` object (e.g. the UMAP) is only available if it is stored in
# the file itself -- confirm before plotting `adams` again further down.
adams = sc.read(sampledata_path)

# gene_order = derive_gene_order(databricks_instance, endpoint_name = "mmt_scimilarity_gene_order")
# Gene order comes from the serving endpoint rather than the local model.
gene_order = derive_gene_order(databricks_instance, endpoint_name = gene_order_endpoint)
aligned = align_dataset(adams, gene_order) #

lognorm = lognorm_counts(aligned)

In [0]:
#### Filter sample data to Disease "IPF" | celltype "myofibroblast cell" | sample_ref "DS000011735-GSM4058950"

disease_name = "IPF"
celltype_name = "myofibroblast cell"
sample_refid = "DS000011735-GSM4058950"
subsample_refid = "123942"

# Each step narrows the previous selection and takes a copy so it is an
# independent object rather than a view.
diseasetype_ipf = lognorm[lognorm.obs["Disease"] == disease_name].copy()

celltype_myofib = diseasetype_ipf[
                                  diseasetype_ipf.obs["celltype_name"] == celltype_name
                                 ].copy()

## Extract the cells belonging to sample_refid
celltype_sample = celltype_myofib[
                                  celltype_myofib.obs["sample"] == sample_refid
                                 ].copy()

## Extract the single cell whose index equals subsample_refid.
## Copied as well: a later cell adds an obs column to it, and writing to an
## AnnData view triggers an implicit copy with a warning.
celltype_subsample = celltype_sample[celltype_sample.obs.index == subsample_refid].copy()

## Other single-cell candidates | test batch inference?
# query_cell = celltype_sample[celltype_sample.obs.index == "123942"]
# query_cell = celltype_sample[celltype_sample.obs.index == "124332"]
# query_cell = celltype_sample[celltype_sample.obs.index == "126138"]


# The query sent to the embedding endpoint is the whole sample (batch of cells);
# switch to the commented line to send only the single cell.
query_cell = celltype_sample
# query_cell = celltype_subsample

## extract subsample query (1d array or list)
# X_vals: sparse.csr_matrix = query_cell.X
# print(X_vals)

# X_vals_dense = X_vals.toarray()
# print(X_vals_dense)

In [0]:
query_cell.obs

In [0]:
# # celltype_subsample_pdf = pd.DataFrame([{'celltype_subsample': row} for row in X_vals_dense ], 
# #                                       index=celltype_sample.obs.index
# #                                      )

# celltype_sample_obs_json = query_cell.obs.to_json(orient='split')

# celltype_subsample_pdf = pd.DataFrame([{'celltype_subsample': row} for row in X_vals_dense ], 
#                                       index=query_cell.obs.index
#                                      )

# # model_input = pd.DataFrame([{
# #     "celltype_sample": celltype_subsample_pdf.iloc[0:1].to_json(orient='split'), 
# #     "celltype_sample_obs": query_cell.obs.iloc[0:1].to_json(orient='split')
# # }])

# model_input = pd.DataFrame([{
#     "celltype_sample": celltype_subsample_pdf.to_json(orient='split'), 
#     "celltype_sample_obs": query_cell.obs.to_json(orient='split')
# }])

# model_input

In [0]:
## ADD VIZ

In [0]:
# Marker genes grouped by label; passed as `var_names` to the dot plots below,
# where each key becomes a labelled gene group.
marker_list = {
    "Fibroblast": ["COL1A1", "COL3A1", "DCN", "FBLN1", "FN1", "LUM", "THY1"],
    "Myofibroblast": ["CDH11", "COMP", "CTHRC1", "ELN", "POSTN", "TNC"],
    "Smooth Muscle": ["ACTA2", "ACTG2", "DES", "MYH11", "MYL9", "TAGLN"],
}

In [0]:
# Expose the obs index as a column so it can serve as the dot plot's grouping key.
celltype_subsample.obs["cell_id"] = celltype_subsample.obs.index

# Dot plot of the marker genes for the single selected cell, drawn on a wide axis.
fig, axes = plt.subplots(1, 1, figsize=(12, 4))
sc.pl.dotplot(
    celltype_subsample,
    var_names=marker_list,
    groupby="cell_id",
    var_group_rotation=0,
    ax=axes,
    show=False,
);

In [0]:
# Expose the obs index as a column so it can serve as the dot plot's grouping key.
celltype_sample.obs["cell_id"] = celltype_sample.obs.index

# Dot plot of the marker genes with one group per cell of the selected sample.
fig, axes = plt.subplots(1, 1, figsize=(12, 4))
sc.pl.dotplot(
    celltype_sample,
    var_names=marker_list,
    groupby="cell_id",
    var_group_rotation=0,
    ax=axes,
    show=False,
);

In [0]:
# Flag the query cell (1) versus every other cell (0) and colour the UMAP by that flag.
# Uses subsample_refid (defined with the filter settings) instead of repeating the id.
used_in_query = adams.obs.index == subsample_refid
adams.obs["used_in_query"] = used_in_query.astype(int)
f = sc.pl.umap(adams, color=["used_in_query"], cmap="YlOrRd", return_fig=True)
# NOTE(review): the arrow position is hard-coded; presumably it pointed at the query
# cell in one specific UMAP layout -- confirm it still does for the current embedding.
f.axes[0].arrow(6.1, 10.5, 1, -1, head_width=0.5, head_length=0.5)

In [0]:
# Request embeddings for the query cells from the embedding endpoint; the result is
# the raw response returned by api_inference and is displayed as-is.
cell_embedding = derive_embedding(databricks_instance, endpoint_name = get_embedding_endpoint, data_input = query_cell)
cell_embedding

In [0]:
# Tabulate the 'predictions' entry of the response.
# NOTE(review): raises KeyError if the call above returned an error dict -- check first.
cell_embedding_pdf = pd.DataFrame(cell_embedding['predictions'])
cell_embedding_pdf

In [0]:
## ADD VIZ

In [0]:
# cell_embedding_pdf.loc[cell_embedding_pdf.celltype_sample_index == '123942'].embedding

In [0]:
# Look up the embedding of the query cell (subsample_refid) and search for its
# 1000 nearest neighbours through the search endpoint.
query_embedding = cell_embedding_pdf[
    cell_embedding_pdf.celltype_sample_index == subsample_refid
].embedding.values[0]

knn = search_nearest(
                             databricks_instance,
                             # endpoint_name = "mmt_scimilarity_search_nearest",
                             endpoint_name = search_nearest_endpoint,
                             embedding=query_embedding,
                             params={'k': 1000}
                            )

knn

In [0]:
knn["predictions"][0].keys()

In [0]:
def calculate_disease_proportions(metadata):
    """Return the percentage of rows per value of ``metadata.disease``.

    The result is the ``value_counts`` of the ``disease`` column rescaled so
    that the entries sum to 100.
    """
    counts = metadata.disease.value_counts()
    total = counts.sum()
    percentages = (100 * counts) / total
    return percentages


def plot_proportions(df, title=None):
    """Draw ``df`` as a horizontal bar chart with percent-labelled x ticks.

    Args:
        df: Pandas object with a ``.plot`` method holding the percentages.
        title: Optional chart title.
    """
    axis = df.plot(
        kind="barh",
        xlabel="percent of cells",
        title=title,
        grid=False,
        figsize=(8, 5),
    )
    axis.tick_params(axis="y", labelsize=8)

    # Relabel the current x ticks as whole-number percentages.
    percent_labels = []
    for tick in axis.get_xticks():
        percent_labels.append(f"{int(tick)}%")
    axis.set_xticklabels(percent_labels)

    plt.tight_layout()


In [0]:
pd.DataFrame(knn["predictions"][0]['results_metadata'])

In [0]:
# Study the query cell comes from; neighbours from this study are excluded.
# Defined here because this cell runs before the one that originally set it.
query_study = "DS000011735"

# Build the neighbour-metadata frame once and reuse it for the mask and the selection.
results_metadata_df = pd.DataFrame(knn["predictions"][0]['results_metadata'])
filtered_df = results_metadata_df.loc[results_metadata_df["study"] != query_study]
display(filtered_df)

In [0]:
# Study of the query cell; its own cells are excluded from the neighbour statistics.
query_study = "DS000011735"

# Metadata of the nearest neighbours, restricted to studies other than the query study.
neighbour_metadata = pd.DataFrame(knn["predictions"][0]['results_metadata'])
from_other_studies = neighbour_metadata["study"] != query_study
filtered_result_metadata = neighbour_metadata.loc[from_other_studies]

# Disease percentages among those neighbours, keeping only entries above 0.1 %.
query_disease_frequencies = calculate_disease_proportions(filtered_result_metadata)
query_disease_frequencies = query_disease_frequencies[query_disease_frequencies > 0.1]

plot_proportions(
    query_disease_frequencies,
    title="disease proportions for most similar cells",
)

In [0]:
query_disease_frequencies

In [0]:
ref_metadata = cq.cell_metadata

In [0]:
# Reference cells whose `prediction` is "myofibroblast cell".
myofib_meta = ref_metadata[ref_metadata.prediction.isin(["myofibroblast cell"])]
# NOTE: this rebinds query_disease_frequencies, replacing the nearest-neighbour result
# computed earlier with the proportions for the reference myofibroblasts.
query_disease_frequencies = calculate_disease_proportions(myofib_meta)
# Plot only the first 15 entries.
plot_proportions(
    query_disease_frequencies[:15],
    title="disease proportions for reference myofibroblasts",
)

In [0]:
# for endpoint in endpoints:
#     # start_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)
#     stop_endpoint(databricks_instance, endpoint, DATABRICKS_TOKEN)
 