ref: https://genentech.github.io/scimilarity/notebooks/cell_search_tutorial_1.html

In [0]:
# Notebook-scoped dependencies. Exact pins (==) are used for the packages whose
# versions are fixed for this notebook; the others only set a minimum version.
requirements = [
    "scimilarity==0.4.0",
    "typing_extensions>=4.14.0",
    "numpy==1.26.4",
    "pandas==1.5.3",
    "mlflow==2.22.0",
    "tbb>=2021.6.0",
    # "uv"
]

# Each requirement is installed by its own %pip invocation, in list order.
# NOTE(review): the specifiers containing ">=" are interpolated unquoted; if the
# magic hands the line to a shell, ">" could be read as a redirect -- confirm on
# the target runtime, or quote the specifier.
for package in requirements:
    %pip install {package} --upgrade -v

# Restart the Python process after installing (presumably so the freshly
# installed versions are the ones imported below -- TODO confirm).
dbutils.library.restartPython()

In [0]:
# numba is imported only after the install/restart cell above has run.
import numba

# Select numba's threading layer; per the original note this should be set at
# the beginning of the notebook/script.
numba.config.THREADING_LAYER = 'workqueue'  # chosen as the most compatible option
# Other layers: 'omp' (OpenMP) and 'tbb'.
# NOTE(review): the original comment calls 'tbb' the default -- confirm against
# the numba docs for the installed version.

In [0]:
## Notebook parameters. The defaults below are for notebook development; per the
## original note they are overwritten by the deployment arguments.
dbutils.widgets.text("catalog", "genesis_workbench", "Catalog")
dbutils.widgets.text("schema", "dev_mmt_core_test", "Schema")  # schema inside the catalog

dbutils.widgets.text("model_name", "SCimilarity", "Model Name") ## open question: use this as a prefix for the model name?

dbutils.widgets.text("cache_dir", "scimilarity", "Cache dir") ## doubles as the volume name / model family

# Read the (possibly overridden) widget values into module-level constants.
CATALOG = dbutils.widgets.get("catalog")
SCHEMA = dbutils.widgets.get("schema")

MODEL_NAME = dbutils.widgets.get("model_name")

CACHE_DIR = dbutils.widgets.get("cache_dir")

# The cache location is composed as /Volumes/<catalog>/<schema>/<cache_dir>.
print(f"Cache dir: {CACHE_DIR}")
cache_full_path = f"/Volumes/{CATALOG}/{SCHEMA}/{CACHE_DIR}"
print(f"Cache full path: {cache_full_path}")

In [0]:
# Aliases used by the rest of the notebook for the widget-derived settings.
# (CATALOG is used as-is; earlier hard-coded values were "mmt" for the catalog
# and "tests" | "genesiswb" for the schema.)
DB_SCHEMA = SCHEMA

# The cache dir doubles as the volume name / project, e.g. "scimilarity".
MODEL_FAMILY = CACHE_DIR

# Echo the effective configuration.
print(f"CATALOG : {CATALOG}")
print(f"DB_SCHEMA : {DB_SCHEMA}")
print(f"MODEL_FAMILY : {MODEL_FAMILY}")

In [0]:
# from scimilarity import CellQuery, CellEmbedding ## these are the functions we wrapped in MLflow PyFunc
from scimilarity import align_dataset, lognorm_counts

import scanpy as sc
from scipy import sparse

import numpy as np
import pandas as pd
import json
import requests

import os

from collections.abc import MutableMapping  # 

In [0]:
def create_serving_json(data, params=None):
    """Build a serving request payload in the ``instances`` layout.

    Args:
        data: Value placed under the ``"instances"`` key.
        params: Optional value placed under the ``"params"`` key. The key is
            omitted only when ``params`` is ``None``.

    Returns:
        dict: ``{"instances": data}``, plus ``"params"`` when provided.
    """
    payload = {"instances": data}
    if params is not None:
        payload["params"] = params
    return payload

def api_inference(databricks_instance, endpoint_name, data_input, params=None,
                  timeout=300, max_retries=3, backoff_seconds=5.0):
    """POST ``data_input`` to a serving endpoint and return the decoded response.

    The bearer token is read from the ``DATABRICKS_TOKEN`` environment variable.

    Args:
        databricks_instance: Workspace host name (no scheme), used to build the URL.
        endpoint_name: Name of the serving endpoint to invoke.
        data_input: A ``pd.DataFrame`` (sent as ``dataframe_split``) or any other
            JSON-serialisable value (sent via ``create_serving_json``).
        params: Optional parameters forwarded under the ``"params"`` key.
        timeout: Seconds to wait for each HTTP request before giving up on it.
        max_retries: Number of additional attempts made after a transient
            failure (HTTP 429/502/503/504 or a transport error). ``0`` disables
            retrying.
        backoff_seconds: Base delay between attempts; doubled after each retry.

    Returns:
        The decoded JSON body on HTTP 200. Otherwise a dict with an ``"error"``
        key and a ``"response_text"`` key describing the last failure. This
        function does not raise for HTTP or transport failures.
    """
    import time  # local import: not part of the notebook's import cell

    url = f'https://{databricks_instance}/serving-endpoints/{endpoint_name}/invocations'
    headers = {
        'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
        'Content-Type': 'application/json'
    }

    if isinstance(data_input, pd.DataFrame):
        ds_dict = {'dataframe_split': data_input.to_dict(orient='split')}
        if params:
            ds_dict['params'] = params
    else:
        ds_dict = create_serving_json(data_input, params)

    data_json = json.dumps(ds_dict, allow_nan=True)

    # Gateway/throttling statuses that are worth retrying (e.g. the
    # "HTTP 504 upstream request timeout" observed while testing).
    retryable_statuses = {429, 502, 503, 504}
    attempts = max(1, int(max_retries) + 1)
    last_error = None

    for attempt in range(attempts):
        try:
            response = requests.post(url, headers=headers, data=data_json, timeout=timeout)
        except requests.exceptions.RequestException as exc:
            # Timeouts / connection errors: keep the error-dict contract and retry.
            last_error = {"error": "Request failed", "response_text": str(exc)}
        else:
            if response.status_code == 200:
                try:
                    return response.json()
                except ValueError:
                    # ValueError covers both json's and simplejson's decode errors.
                    return {"error": "Invalid JSON response", "response_text": response.text}

            last_error = {"error": f"HTTP {response.status_code}", "response_text": response.text}
            if response.status_code not in retryable_statuses:
                # Client errors (4xx) and other statuses will not improve on retry.
                return last_error

        if attempt < attempts - 1:
            time.sleep(backoff_seconds * (2 ** attempt))

    return last_error

In [0]:
# def derive_gene_order(endpoint_name = "mmt_scimilarity_gene_order") -> list[str]:  
def derive_gene_order(databricks_instance, endpoint_name) -> list[str]:
    """Fetch the gene order from the gene-order serving endpoint.

    Args:
        databricks_instance: Workspace host name passed through to ``api_inference``.
        endpoint_name: Name of the gene-order serving endpoint.

    Returns:
        The value stored under ``"predictions"`` in the endpoint response.

    Raises:
        RuntimeError: If the response has no ``"predictions"`` entry, e.g. when
            ``api_inference`` returned an ``{"error": ...}`` dict.
    """
    # The endpoint is driven by a single-row frame carrying a fixed command string.
    example_input = pd.DataFrame({"input": ["get_gene_order"]})
    response = api_inference(databricks_instance, endpoint_name, data_input=example_input)

    # Surface endpoint failures explicitly instead of an opaque KeyError.
    if not isinstance(response, dict) or 'predictions' not in response:
        raise RuntimeError(
            f"Gene-order request to endpoint '{endpoint_name}' failed: {response}"
        )
    return response['predictions']

# gene_order = derive_gene_order()

def derive_embedding(databricks_instance, endpoint_name, data_input) -> list:
    """Request the embedding of the first row of ``data_input`` from an endpoint.

    Args:
        databricks_instance: Workspace host name passed through to ``api_inference``.
        endpoint_name: Name of the embedding serving endpoint.
        data_input: Expression values (``X_vals``). May be a sparse matrix
            exposing ``.toarray()`` or a dense array-like. Only the FIRST row
            is sent.

    Returns:
        The ``"embedding"`` entry of the first prediction in the response.
        (Presumably a list of floats, possibly nested -- TODO confirm against
        the endpoint.)

    Raises:
        ValueError: If ``data_input`` has no rows.
        RuntimeError: If the response has no ``"predictions"`` entry, e.g. when
            ``api_inference`` returned an ``{"error": ...}`` dict.
    """
    # Accept sparse and dense inputs alike; a 1-D vector is treated as one row.
    dense = data_input.toarray() if hasattr(data_input, "toarray") else np.asarray(data_input)
    dense = np.atleast_2d(dense)
    if dense.shape[0] == 0:
        raise ValueError("data_input contains no rows; nothing to embed")

    # Required input formatting: one record holding the row as a plain list.
    example_input = pd.DataFrame([{'subsample_query_array': dense[0].tolist()}])
    response = api_inference(databricks_instance, endpoint_name, data_input=example_input)

    # Surface endpoint failures explicitly instead of an opaque KeyError.
    if not isinstance(response, dict) or 'predictions' not in response:
        raise RuntimeError(
            f"Embedding request to endpoint '{endpoint_name}' failed: {response}"
        )
    return response['predictions'][0]['embedding']

# cell_embedding = derive_embedding(endpoint_name = "mmt_scimilarity_get_embedding", data_input = X_vals)

def search_nearest(databricks_instance, endpoint_name, embedding, params=None) -> dict:
    """Query the nearest-neighbour search endpoint with an embedding.

    Args:
        databricks_instance: Workspace host name passed through to ``api_inference``.
        endpoint_name: Name of the search serving endpoint.
        embedding: Embedding placed in the ``"embedding"`` column of the request frame.
        params: Request parameters. Defaults to ``{'k': 10}`` when omitted or ``None``.

    Returns:
        dict: Whatever ``api_inference`` returns for the request.
    """
    # Build the default per call so no dict object is shared between invocations
    # (a dict literal in the signature would be one shared, mutable object).
    if params is None:
        params = {'k': 10}
    example_input = pd.DataFrame({"embedding": embedding})
    return api_inference(databricks_instance, endpoint_name, data_input=example_input, params=params)

In [0]:
# Scratch/debug cell: issues the same request derive_embedding() builds, inline,
# so the raw endpoint response can be inspected.
#
# `X_vals` and `databricks_instance` are assigned by cells that appear further
# down this notebook, so on a fresh kernel ("Restart & Run All") this cell would
# stop the run with a NameError. It is therefore skipped, with a message, until
# both names exist; re-run it after the credentials and sample-data cells.
endpoint_name = "mmt_scimilarity_get_embedding"
if "X_vals" in globals() and "databricks_instance" in globals():
    data_input = X_vals
    example_input = pd.DataFrame([{'subsample_query_array': data_input.toarray()[0].tolist() }]) ## required input formatting
    get_embeddings = api_inference(databricks_instance, endpoint_name, data_input=example_input)
else:
    get_embeddings = None
    print("Skipping raw embedding request: run the credentials and sample-data cells first, then re-run this cell.")

In [0]:
get_embeddings
# Example failure seen here: {'error': 'HTTP 504', 'response_text': 'upstream request timeout'}
# TODO: transient timeouts like this need to be handled with retries.

In [0]:
import os  # NOTE(review): also imported in the import cell; redundant but harmless
# The API token is read from the "mmt" secret scope and exported through the
# environment, which is where api_inference() looks it up.
DATABRICKS_TOKEN = dbutils.secrets.get("mmt","databricks_token")
os.environ["DATABRICKS_TOKEN"] = DATABRICKS_TOKEN

# Workspace host used to build the serving-endpoint URLs.
# NOTE(review): hard-coded to one workspace -- consider making this a widget/parameter.
databricks_instance = "adb-830292400663869.9.azuredatabricks.net"

## Testing endpoints -- to be replaced with the gwb-app ones.
gene_order_endpoint = "mmt_scimilarity_gene_order"
get_embedding_endpoint = "mmt_scimilarity_get_embedding"
search_nearest_endpoint = "mmt_scimilarity_search_nearest"

In [0]:
# example_input = pd.DataFrame({"input": ["get_gene_order"]})
# get_gene_order = api_inference(databricks_instance, endpoint_name="mmt_SCimilarity_gene_order", data_input=example_input)
# get_gene_order['predictions']

# derive_gene_order(databricks_instance, endpoint_name = "mmt_scimilarity_gene_order")  
# derive_gene_order(databricks_instance, endpoint_name = gene_order_endpoint) 

In [0]:
# Location of the sample .h5ad file inside the configured volume.
sampledata_path = f"/Volumes/{CATALOG}/{DB_SCHEMA}/{MODEL_FAMILY}/data/adams_etal_2020/GSE136831_subsample.h5ad"

## Read the sample H5AD file, then align it to the gene order and log-normalise.
adams = sc.read(sampledata_path)

# The gene order is fetched from the gene-order serving endpoint.
gene_order = derive_gene_order(databricks_instance, endpoint_name = gene_order_endpoint)
aligned = align_dataset(adams, gene_order) # presumably matches the genes to gene_order -- TODO confirm

lognorm = lognorm_counts(aligned)

### Narrow the data step by step:
### Disease == "IPF" -> celltype_name == "myofibroblast cell" -> one sample.
adams_ipf = lognorm[lognorm.obs["Disease"] == "IPF"].copy()

adams_myofib = adams_ipf[
                          adams_ipf.obs["celltype_name"] == "myofibroblast cell"
                        ].copy()

## Keep only the cells of the reference sample.
subsample = adams_myofib[
                          adams_myofib.obs["sample"] == "DS000011735-GSM4058950" # sample reference id
                        ].copy()

## Pick one cell by its obs index label (other labels tried: "123942", "124332").
## NOTE(review): if the label is absent the selection is empty -- verify before querying.
query_cell = subsample[subsample.obs.index == "126138"]

## Expression matrix of the selected cell; this is the input later sent for embedding.
X_vals: sparse.csr_matrix = query_cell.X
X_vals

In [0]:
# from pyspark.sql import Row
# import numpy as np

# # Convert the sparse matrix to a dense format
# dense_X_vals = X_vals.todense()

# # Convert the dense matrix to a list of lists
# dense_X_vals_list = dense_X_vals.tolist()

# # Create a list of Rows
# rows = [Row(embedding=row) for row in dense_X_vals_list]

# # Create a Spark DataFrame from the list of Rows
# query_cell_spark_df = spark.createDataFrame(rows)

# # Display the Spark DataFrame
# display(query_cell_spark_df)

In [0]:
# Show the obs table of the cells kept for the selected sample.
subsample.obs

In [0]:
# Embed the query cell. The endpoint name comes from the endpoint configuration
# cell (get_embedding_endpoint) rather than a repeated string literal, matching
# how the search cell refers to its endpoint.
cell_embedding = derive_embedding(databricks_instance, endpoint_name = get_embedding_endpoint, data_input = X_vals)

cell_embedding

In [0]:
# Look up the 5 nearest neighbours of the query-cell embedding through the
# configured search endpoint.
knn_results = search_nearest(
    databricks_instance,
    endpoint_name=search_nearest_endpoint,
    embedding=cell_embedding,
    params={"k": 5},
)

knn_results

In [0]:
## TODO
# - Test batch inference -- the sample queries first need to be converted to a Spark DataFrame.
# - Test ai_query() -- initial testing ran into input-formatting issues.