# Cell and Tissue Search Tutorial
KRONOS embeddings can be used to search for and retrieve cell or tissue samples with similar phenotypic or spatial patterns.  

## Prerequisites

To follow this tutorial, ensure you have the following data prepared:

1. **Query cell/patch embeddings**: A folder of query cell/patch embeddings, where each embedding is a NumPy array stored in its own `.npy` file.
2. **Support cell/patch embeddings**: A folder of support cell/patch embeddings (the set that is searched), organized the same way as the query folder.


**Note**: See the **[2 - Cell-phenotyping.ipynb](https://github.com/mahmoodlab/KRONOS/blob/main/tutorials/2%20-%20Cell-phenotyping.ipynb)** and **[3 - Patch-phenotyping.ipynb](https://github.com/mahmoodlab/KRONOS/blob/main/tutorials/3%20-%20Patch-phenotyping.ipynb)** tutorials for how to extract cell and patch embeddings.
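You may want to try the pipeline before you have extracted real embeddings. The minimal sketch below writes randomly generated placeholder embeddings in the expected layout: one 1-D NumPy array per `.npy` file, in separate query and support folders. The helper name `make_toy_embeddings`, the file names and the embedding dimension are hypothetical. Because the vectors are random, the retrievals they produce are not meaningful.

```python
import os
import tempfile

import numpy as np


def make_toy_embeddings(root, n_query=3, n_support=20, dim=16, seed=0):
    """Hypothetical helper: write random 1-D float32 embeddings, one per .npy file."""
    rng = np.random.default_rng(seed)
    for split, n in (("query", n_query), ("support", n_support)):
        folder = os.path.join(root, split)
        os.makedirs(folder, exist_ok=True)
        for i in range(n):
            emb = rng.standard_normal(dim).astype(np.float32)  # placeholder, not a KRONOS embedding
            np.save(os.path.join(folder, f"{split}_tissue_{i}.npy"), emb)
    return os.path.join(root, "query"), os.path.join(root, "support")


# Example: create the two toy folders inside a fresh temporary directory
query_dir, support_dir = make_toy_embeddings(tempfile.mkdtemp())
```

The returned paths can be used as `query_folder` and `support_folder` in the configuration.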
 

## Step 1: Import Required Packages

We begin by importing the necessary libraries and modules for the workflow.

In [1]:
import os
import torch
import pandas as pd
import numpy as np
import torch.nn.functional as F

## Step 2: Experiment Configuration

Configure the retrieval experiment settings.

In [2]:
# Configuration dictionary containing all parameters for the pipeline
config = {
    'query_folder': '/data/query',  # Replace with your query embeddings folder
    'support_folder': '/data/support',  # Replace with your support embeddings folder
    'results_path': 'retrieval_results.csv',  # Replace with the path where the results CSV should be saved
    'device': 'cuda:0',  # Device to run on, e.g. 'cuda:0' or 'cpu'
    # Retrieval-related parameters
    'topk': 5,  # Number of support samples to retrieve for each query
    'similarity': 'l2',  # 'l2' (negative Euclidean distance; the default) or 'cosine'
    'centering': True,  # Subtract the mean support embedding (followed by L2 normalization) before retrieval
}
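It can help to sanity-check the configuration before running the pipeline. The sketch below uses a hypothetical helper, `check_config`, which is not part of KRONOS. It rejects unsupported `similarity` values and non-positive `topk`. It also falls back to the CPU when a CUDA device is requested but `torch.cuda.is_available()` returns `False`.

```python
import torch


def check_config(cfg):
    """Hypothetical helper: return a checked copy of a retrieval config dict."""
    cfg = dict(cfg)  # do not modify the caller's dictionary
    if cfg.get("similarity", "l2") not in ("l2", "cosine"):
        raise ValueError(f"similarity must be 'l2' or 'cosine', got {cfg['similarity']!r}")
    if int(cfg.get("topk", 5)) < 1:
        raise ValueError("topk must be a positive integer")
    # Fall back to CPU when a CUDA device is requested but none is available
    if str(cfg.get("device", "cpu")).startswith("cuda") and not torch.cuda.is_available():
        cfg["device"] = "cpu"
    return cfg
```

For example, `config = check_config(config)` returns a dictionary with the same keys, so it can still be unpacked into the retrieval function.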

## Step 3: Define the Retrieval Function

We define a function that takes the configuration above as input, performs the retrieval, and saves the top-k retrieved support samples for each query to a CSV file.
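The scores computed by the function can be summarized as follows; the notation is ours, not from KRONOS. Let $k_1,\dots,k_N$ be the support embeddings. With `centering=True`, let $\mu = \frac{1}{N}\sum_{j=1}^{N} k_j$ be their mean. Each query $q$ and support embedding $k$ is then centered and L2-normalized:

$$\hat q = \frac{q-\mu}{\lVert q-\mu\rVert_2}, \qquad \hat k = \frac{k-\mu}{\lVert k-\mu\rVert_2}.$$

With `centering=False`, $\mu$ is dropped, and the normalization is applied only when `similarity='cosine'`. The two scores are $s_{\cos}(q,k) = \hat q^\top \hat k$ and $s_{\ell_2}(q,k) = -\lVert \hat q - \hat k\rVert_2$. For unit vectors,

$$\lVert \hat q - \hat k\rVert_2^2 = 2 - 2\,\hat q^\top \hat k \quad\Longrightarrow\quad s_{\ell_2} = -\sqrt{2 - 2\,s_{\cos}},$$

so $s_{\ell_2}$ increases monotonically with $s_{\cos}$. Whenever normalization is applied, the `l2` and `cosine` settings therefore return the same top-k ranking. Only with `centering=False` and `similarity='l2'` are the raw, unnormalized embeddings compared, and then the ranking can differ.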

In [None]:
def retrieval(query_folder, support_folder, results_path, device='cuda:0', topk=5, similarity='l2', centering=True):
    """Retrieve the top-k most similar support embeddings for every query and save them as a CSV."""
    if similarity not in ('cosine', 'l2'):
        raise ValueError(f'similarity {similarity} not supported!')

    # Load query and support embeddings (one 1-D .npy array per file); sort names for a deterministic order
    query_filenames = sorted(file for file in os.listdir(query_folder) if file.endswith('.npy'))
    queries = np.stack([np.load(os.path.join(query_folder, file)) for file in query_filenames])
    support_filenames = sorted(file for file in os.listdir(support_folder) if file.endswith('.npy'))
    keys = np.stack([np.load(os.path.join(support_folder, file)) for file in support_filenames])

    # Move both queries and supports to the device as float32 tensors
    queries = torch.from_numpy(queries).float().to(device)
    keys = torch.from_numpy(keys).float().to(device)

    # Optional centering: subtract the mean support embedding from both sets
    if centering:
        means = keys.mean(dim=0, keepdim=True)
        keys = keys - means
        queries = queries - means
    # L2-normalize for cosine similarity, and also whenever centering is applied
    if similarity == 'cosine' or centering:
        queries = F.normalize(queries, dim=1)
        keys = F.normalize(keys, dim=1)

    # Similarity scores of shape (n_queries, n_support); higher means more similar
    if similarity == 'cosine':
        sim_scores = torch.matmul(queries, keys.T)
    else:
        sim_scores = -torch.cdist(queries, keys, p=2)  # negate distances so they act as similarities

    # Indices of the top-k support samples per query, most similar first (k cannot exceed the support size)
    k = min(topk, keys.shape[0])
    _, topk_ids = torch.topk(sim_scores, k, dim=1)
    topk_filenames = np.array(support_filenames)[topk_ids.cpu().numpy()]  # index on CPU, not with a CUDA tensor

    # Build the result table: one row per query, columns top1 ... topk
    topk_df = pd.DataFrame(topk_filenames, columns=[f'top{i + 1}' for i in range(k)], index=query_filenames)
    topk_df.to_csv(results_path)
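The function computes the full score matrix at once, which needs memory proportional to the number of queries times the number of support samples. That can become large when there are many cells. The minimal sketch below scores the queries in chunks instead. It assumes `queries` and `keys` are already preprocessed (centered and/or normalized) float tensors on the same device. `batched_topk` and `batch_size` are hypothetical names, not part of KRONOS. Each query's top-k is independent of the other queries, so chunking should not change the result.

```python
import torch
import torch.nn.functional as F


def batched_topk(queries, keys, topk=5, similarity="l2", batch_size=1024):
    """Hypothetical helper: top-k support indices per query, scoring `batch_size` queries at a time.

    Expects float tensors of shape (n_queries, dim) and (n_support, dim) on the same device.
    """
    k = min(topk, keys.shape[0])  # cannot retrieve more items than the support set holds
    chunks = []
    for start in range(0, queries.shape[0], batch_size):
        q = queries[start:start + batch_size]
        if similarity == "cosine":
            scores = q @ keys.T
        elif similarity == "l2":
            scores = -torch.cdist(q, keys, p=2)  # negate so larger means closer
        else:
            raise ValueError(f"similarity {similarity} not supported!")
        chunks.append(torch.topk(scores, k, dim=1).indices)
    return torch.cat(chunks, dim=0)


# Example on random unit vectors (placeholder data, not KRONOS embeddings)
torch.manual_seed(0)
q = F.normalize(torch.randn(10, 8), dim=1)
s = F.normalize(torch.randn(50, 8), dim=1)
ids = batched_topk(q, s, topk=5, similarity="cosine", batch_size=4)
```

In this sketch `batch_size` only trades speed for peak memory. Smaller chunks lower the size of the score matrix held at any one time.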

## Step 4: Perform Actual Retrieval

We now run the retrieval with the configuration defined above. The resulting CSV file is saved to the `results_path` given in `config`. For each query (row), it lists the file names of the top-k retrieved support samples in columns `top1` to `topk`, most similar first. It will look similar to:

||top1|top2|top3|top4|top5|
|--|--|--|--|--|--|
|query_tissue_1.npy|support_tissue_10.npy|support_tissue_2.npy|support_tissue_5.npy|support_tissue_26.npy|support_tissue_23.npy|
|...|...|...|...|...|...|
|query_tissue_n.npy|support_tissue_50.npy|support_tissue_36.npy|support_tissue_25.npy|support_tissue_46.npy|support_tissue_33.npy|

In [None]:
retrieval(**config)
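To inspect the results afterwards, the CSV can be read back with pandas, using its first column as the query index. The sketch below wraps this in a hypothetical helper, `load_retrieval_results`. It demonstrates a round trip on a tiny hand-made table in the same layout. The file names in that table are placeholders, not real results.

```python
import os
import tempfile

import pandas as pd


def load_retrieval_results(results_path):
    """Hypothetical helper: read the retrieval CSV into a DataFrame indexed by query file name."""
    return pd.read_csv(results_path, index_col=0)


# Round-trip demo with placeholder file names in the same layout as the retrieval output
demo = pd.DataFrame(
    [["support_a.npy", "support_b.npy"]],
    columns=["top1", "top2"],
    index=["query_x.npy"],
)
demo_path = os.path.join(tempfile.mkdtemp(), "demo_results.csv")
demo.to_csv(demo_path)

results = load_retrieval_results(demo_path)
print(results.loc["query_x.npy"].tolist())  # ['support_a.npy', 'support_b.npy']
```

With real output, `load_retrieval_results(config['results_path']).loc['<query file name>']` would give that query's retrieved support files in rank order.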