[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neonine2/morpheus-spatial/blob/master/examples/tutorial.ipynb)

# Morpheus tutorial with example data set

In this tutorial, we will demonstrate the complete Morpheus pipeline using an example data set from [Wang et al. (2023)](https://doi.org/10.1016/j.cmet.2023.04.013). This data set contains 209 tumor images taken from 30 patients with colorectal cancer.

If you are running this notebook in Google Colab, select 'Runtime' -> 'Change runtime type', then set 'Runtime type' to Python 3 and 'Hardware accelerator' to 'GPU'.

## Step 0: Download data set and set seed

In [None]:
# install morpheus if not already installed
!pip install morpheus-spatial requests

In [None]:
import requests
import zipfile
import json
import os
import morpheus as mp

from lightning.pytorch import seed_everything

seed_everything(42)  # Optional: sets seed for pytorch, numpy, python.random

We will download the input data from an online data repository. The data consist of an input CSV file with single-cell data and a text file containing the channel names.

For reproducibility, a trained model and a patient split are also downloaded and used in this notebook.

In [None]:
def download_and_unzip(record_id, filename, save_path):
    # Skip the download if the data have already been extracted
    if not os.path.exists(save_path):
        url = f"https://data.caltech.edu/records/{record_id}/files/{filename}"
        response = requests.get(url, timeout=300)
        response.raise_for_status()  # fail loudly on HTTP errors
        # Save the zip archive locally, then extract it into save_path
        with open(filename, "wb") as f:
            f.write(response.content)
        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(save_path)
        print(f"Downloaded {filename} and extracted it to {save_path}")
    else:
        print(f"Data already exists in {save_path}")


# Download input data from the Caltech Data Portal
download_and_unzip("pr14s-wgk05", "crc_tutorial.zip", save_path="crc_tutorial")

# To reproduce the paper's results: load the patient split and the trained model
model_path = "crc_tutorial/model/unet.ckpt"
with open("crc_tutorial/patient_split.json", "r") as file:
    patient_split = json.load(file)
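
Before moving on, it can help to confirm that the archive unpacked to the paths the later cells read from. The sketch below uses a hypothetical helper, `missing_files` (not part of Morpheus), that lists which of those paths are absent. The paths are the ones used in this tutorial, and the check assumes the notebook runs from the directory that contains `crc_tutorial/`.

```python
import os

# Paths that later cells in this tutorial read from
EXPECTED_FILES = [
    "crc_tutorial/singlecell.csv",
    "crc_tutorial/channel_names.txt",
    "crc_tutorial/patient_split.json",
    "crc_tutorial/model/unet.ckpt",
]


def missing_files(paths, base_dir="."):
    """Return the entries of `paths` (relative to `base_dir`) that are not existing files."""
    return [p for p in paths if not os.path.isfile(os.path.join(base_dir, p))]


missing = missing_files(EXPECTED_FILES)
if missing:
    print("Missing files:", missing)
else:
    print("All expected tutorial files are present.")
```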

## Step 1: Creating a SpatialDataset Object

Start by creating a `SpatialDataset` object, which will hold all relevant information about the dataset we will be working with. 

### Prerequisites

To create a `SpatialDataset` object, you will need:
- The path to the input CSV file containing all single-cell expression information
- A list of channel names

### CSV File Structure

The expected structure of the CSV file is as follows:
- Each row corresponds to a single cell
- One column per channel, named after the channel and containing its expression values
- Five additional columns with the following names and information:

| Column Name         | Description                               | Datatype    |
|---------------------|-------------------------------------------|-------------|
| `ImageNumber`       | Unique ID for each image                  | Integer     |
| `PatientID`         | Unique ID for each patient                | Str/Integer |
| `CellType`          | Cell type<sup>†</sup>                     | Str         |
| `Location_Center_X` | X coordinate of the cell center in microns | Float       |
| `Location_Center_Y` | Y coordinate of the cell center in microns | Float       |

**<sup>†</sup>Important**: in the `CellType` column, cytotoxic T cells must be labeled as `Tcytotoxic` and tumor cells must be labeled as `Tumor`. Additional cell types and metadata columns beyond those listed here will not be used in this tutorial.
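
A quick pre-flight check of your own CSV can catch a missing column or a misspelled cell-type label before building the dataset. The sketch below is a hypothetical helper (`check_singlecell_csv` is not part of Morpheus) that encodes only the requirements listed above: the five metadata columns, one column per channel, and at least one cell labeled `Tcytotoxic` and one labeled `Tumor`. Channel names are passed as a plain list here; how you read them from your channel file is up to you.

```python
import csv
import io

# Metadata columns required in addition to one column per channel
REQUIRED_COLUMNS = [
    "ImageNumber",
    "PatientID",
    "CellType",
    "Location_Center_X",
    "Location_Center_Y",
]
# Cell-type labels this tutorial relies on (exact spelling matters)
REQUIRED_CELL_TYPES = {"Tcytotoxic", "Tumor"}


def check_singlecell_csv(handle, channel_names):
    """Return a list of problems found in a single-cell CSV (empty list means OK).

    `handle` is any open text file object; `channel_names` lists the
    expression channels expected as columns.
    """
    reader = csv.DictReader(handle)
    header = set(reader.fieldnames or [])
    problems = [
        f"missing column: {col}"
        for col in [*REQUIRED_COLUMNS, *channel_names]
        if col not in header
    ]
    if "CellType" in header:
        seen = {row["CellType"] for row in reader}
        problems += [f"no cells labeled {t!r}" for t in sorted(REQUIRED_CELL_TYPES - seen)]
    return problems


# Tiny in-memory example with two channels (hypothetical values)
example = io.StringIO(
    "ImageNumber,PatientID,CellType,Location_Center_X,Location_Center_Y,CD44,PD1\n"
    "1,P01,Tumor,10.5,20.0,0.8,0.1\n"
    "1,P01,Tcytotoxic,40.2,22.1,0.2,0.9\n"
)
print(check_singlecell_csv(example, ["CD44", "PD1"]))  # → []
```

For a real file, open it with `open("crc_tutorial/singlecell.csv", newline="")` and pass the handle instead of the in-memory example.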

To create a `SpatialDataset` object, specify the path to the single-cell CSV file and the path to a text file listing the corresponding channel names.

In [None]:
dataset = mp.SpatialDataset(
    input_path="crc_tutorial/singlecell.csv",
    channel_path="crc_tutorial/channel_names.txt",
)

## Step 2: Patch images and mask cells

Next, we will generate image patches of a specified size and resolution from the spatial data set, and then mask out the cytotoxic T cells.

In [None]:
patch_size = 16  # Patch size in pixels
pixel_size = 3  # Pixel size in microns
cell_types = ["Tcytotoxic", "Tumor"]  # Specify the cell types of interest
mask_cell_types = ["Tcytotoxic"]
dataset.generate_masked_patch(
    cell_to_mask=mask_cell_types,
    cell_types=cell_types,
    patch_size=patch_size,
    pixel_size=pixel_size,
)
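
With `patch_size = 16` pixels and `pixel_size = 3` microns, each patch covers 16 × 3 = 48 microns on a side. The sketch below is an illustrative assumption of how such a grid could be laid over cell coordinates; it is not Morpheus's implementation, and `assign_to_patches` is a hypothetical name. It bins the cells of one image into non-overlapping square patches, records which cell types each patch contains (the basis for a label such as `Contains_Tcytotoxic`), and drops the masked cell type. A real patch would hold per-pixel channel expression; this sketch only tracks cell positions.

```python
import math
from collections import defaultdict

PATCH_SIZE = 16  # pixels per patch side (as in the tutorial)
PIXEL_SIZE = 3  # microns per pixel (as in the tutorial)
PATCH_SIDE_UM = PATCH_SIZE * PIXEL_SIZE  # 48 microns per patch side


def assign_to_patches(cells, cell_to_mask=("Tcytotoxic",)):
    """Illustrative only: bin the cells of ONE image into a grid of square patches.

    `cells` is a list of (x_um, y_um, cell_type) tuples. Returns
    {patch_index: {"contains": set_of_types, "kept": [(px, py, type), ...]}},
    where "contains" records every cell type present before masking and
    "kept" lists the unmasked cells with their pixel position in the patch.
    """
    patches = defaultdict(lambda: {"contains": set(), "kept": []})
    for x, y, cell_type in cells:
        patch = (math.floor(x / PATCH_SIDE_UM), math.floor(y / PATCH_SIDE_UM))
        # Pixel coordinates of the cell within its patch (0 .. PATCH_SIZE - 1)
        px = int((x % PATCH_SIDE_UM) // PIXEL_SIZE)
        py = int((y % PATCH_SIDE_UM) // PIXEL_SIZE)
        patches[patch]["contains"].add(cell_type)
        if cell_type not in cell_to_mask:
            patches[patch]["kept"].append((px, py, cell_type))
    return dict(patches)


cells = [(10.0, 5.0, "Tumor"), (30.0, 40.0, "Tcytotoxic"), (60.0, 5.0, "Tumor")]
result = assign_to_patches(cells)
for idx in sorted(result):
    info = result[idx]
    print(idx, "Contains_Tcytotoxic =", int("Tcytotoxic" in info["contains"]), info["kept"])
```

The T cell at (30, 40) sets the label of patch (0, 0) but is absent from that patch's kept cells, which is the property the classifier in Step 4 relies on: it must infer T-cell presence from the remaining cells.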

## Step 3: Generate data splits for model training

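Next, we generate train, validation, and test splits for model training. The split is done at the patient level, so all patches from a given patient end up in the same split. Splits are stratified by the label we want to predict, which is specified through the `stratify_by` parameter. Here we also pass the downloaded `patient_split` through `specify_split` to reproduce the paper's split.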

In [None]:
colname = "Contains_Tcytotoxic"
dataset.generate_data_splits(stratify_by=colname, specify_split=patient_split)
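
To make the patient-level idea concrete, the sketch below assigns whole patients, and therefore all of their patches, to exactly one split, so the model is never evaluated on patients it saw during training. It is a hypothetical stand-in, not how `generate_data_splits` works internally: the function name, the split names `"train"`/`"val"`/`"test"`, the 60/20/20 fractions, and the greedy balancing rule are all assumptions for illustration.

```python
from collections import Counter

SPLITS = (("train", 0.6), ("val", 0.2), ("test", 0.2))  # hypothetical names and fractions


def patient_level_split(patches, splits=SPLITS):
    """Assign each patient (and so all of its patches) to exactly one split.

    `patches` is a list of (patient_id, label) pairs with label in {0, 1}.
    Greedy heuristic: visit patients with the most positive patches first and
    put each one in the split that is furthest below its target share of
    positive patches and of all patches. Returns {patient_id: split_name}.
    """
    n_pos, n_all = Counter(), Counter()
    for patient, label in patches:
        n_pos[patient] += label
        n_all[patient] += 1
    total_pos = max(sum(n_pos.values()), 1)  # avoid division by zero
    total_all = max(sum(n_all.values()), 1)
    got_pos = {name: 0 for name, _ in splits}
    got_all = {name: 0 for name, _ in splits}

    def shortfall(split):
        name, frac = split
        return (frac - got_pos[name] / total_pos) + (frac - got_all[name] / total_all)

    assignment = {}
    for patient in sorted(n_all, key=lambda p: (-n_pos[p], -n_all[p], str(p))):
        name = max(splits, key=shortfall)[0]
        assignment[patient] = name
        got_pos[name] += n_pos[patient]
        got_all[name] += n_all[patient]
    return assignment


# Hypothetical (patient, contains-T-cell label) pairs, one per patch
example = [("P1", 1), ("P1", 0), ("P2", 1), ("P3", 0), ("P3", 0), ("P4", 1), ("P5", 0)]
print(patient_level_split(example))
```

Because the output maps patients, not patches, to splits, no patient can appear in two splits by construction; the balancing rule only affects how evenly the labels are spread.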

## Step 4: Train a classifier model

After generating the data splits, we train a U-Net model to predict the presence of cytotoxic T cells from masked patches. A model instance is first created with the `PatchClassifier` class and then trained by calling `mp.train`.

Feel free to skip this step and proceed directly to Step 5, since a trained model has already been downloaded.

In [None]:
# initialize model
model_arch = "unet"
n_channels = dataset.n_channels
img_size = dataset.img_size

model = mp.PatchClassifier(n_channels, img_size, model_arch)

In [None]:
# train model
trainer_params = {
    "max_epochs": 30,
    "accelerator": "auto",
    "logger": False,
}
model = mp.train(
    model=model,
    dataset=dataset,
    predict_label=colname,
    trainer_params=trainer_params,
)

## Step 5: Generate counterfactuals using the trained classifier

Since we are interested in perturbations that could drive T cells to infiltrate tumors, we only need to generate counterfactuals (cf) for training patches that contain tumor cells but no T cells. We pass these patches (the selected rows of `dataset.metadata`) to `get_counterfactual` through the `images` parameter, along with the path to a trained classifier through the `model_path` parameter. By default, the pre-downloaded U-Net is used for reproducibility; you can point `model_path` to your own trained model from Step 4 instead.

During cf generation, a k-d tree is first built from the training patches and then used to generate a counterfactual for each patch independently. Because each patch is processed independently, cf generation can (and, at scale, needs to) be parallelized for a large speedup.

Generating a counterfactual for a single instance may take on the order of minutes. For a large number of input instances, setting `num_workers` greater than 1 enables parallelization with Ray. To finish cf generation in hours rather than days, distribute the instances across a large cluster; follow [these instructions](https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html) for using Ray with Slurm. An HPC version of this tutorial notebook that includes Slurm job submission will be available soon; feel free to open a GitHub issue if this would be helpful to you.
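
The arithmetic behind "hours rather than days" is simple: with independent instances, wall-clock time is roughly the number of rounds (instances divided by workers, rounded up) times the per-instance time. The sketch below uses Python's standard `ProcessPoolExecutor` as a stand-in for Ray, which is what Morpheus actually uses; `estimate_wall_clock_hours`, `run_independently`, and the 2,000-patch, 5-minute figures are hypothetical.

```python
import math
from concurrent.futures import ProcessPoolExecutor


def estimate_wall_clock_hours(n_instances, minutes_per_instance, num_workers):
    """Rough wall-clock estimate when every instance is processed independently."""
    rounds = math.ceil(n_instances / max(num_workers, 1))
    return rounds * minutes_per_instance / 60


def run_independently(instances, generate_one, num_workers=1):
    """Apply `generate_one` to each instance, in parallel processes if num_workers > 1.

    When num_workers > 1, `generate_one` must be a picklable module-level function.
    """
    if num_workers <= 1:
        return [generate_one(x) for x in instances]
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(generate_one, instances))


# Hypothetical workload: 2,000 patches at about 5 minutes each
for workers in (1, 8, 200):
    hours = estimate_wall_clock_hours(2000, 5, workers)
    print(f"{workers} workers: {hours:.1f} hours")
```

A single worker needs roughly a week for this hypothetical workload, while a few hundred workers bring it under an hour, which is why a cluster is recommended for large cohorts.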

In [None]:
# select tumor patches that do not contain T cells from training cohort to generate counterfactuals
dataset.get_split_info()
select_metadata = dataset.metadata[
    (dataset.metadata["Contains_Tumor"] == 1)
    & (dataset.metadata["Contains_Tcytotoxic"] == 0)
    & (dataset.metadata["splits"] == "train")
]

# example of selected instances to generate counterfactuals
print(f"Number of selected instances: {len(select_metadata)}")
print(select_metadata.head())

In [None]:
# Parameters for counterfactual generation
optimization_param = {
    "use_kdtree": True,
    "theta": 50.0,
    "kappa": -0.34,
    "learning_rate_init": 0.1,
    "beta": 80.0,
    "max_iterations": 1000,
    "c_init": 1000.0,
    "c_steps": 5,
    "channel_to_perturb": [
        "Glnsynthetase",
        "CCR4",
        "PDL1",
        "LAG3",
        "CD105endoglin",
        "TIM3",
        "CXCR4",
        "PD1",
        "CYR61",
        "CD44",
        "IL10",
        "CXCL12",
        "CXCR3",
        "Galectin9",
        "YAP",
    ],
}

# Generate counterfactuals using trained model
mp.get_counterfactual(
    images=select_metadata,
    dataset=dataset,
    target_class=1,
    model_path=model_path,  # for paper reproduction purpose, set to pre-downloaded model
    optimization_params=optimization_param,
    save_dir=f"{dataset.root_dir}/cf/example/",
    device="cpu",
    num_workers=os.cpu_count(),  # set to greater than 1 for cpu parallelization with Ray
    verbosity=0,
)
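
The parameter names above (`use_kdtree`, `theta`, `kappa`, `beta`, `c_init`, `c_steps`, `learning_rate_init`, `max_iterations`) mirror those of Alibi's `CounterfactualProto`, which suggests a prototype-guided search where each counterfactual is pulled toward the nearest training example of the target class. That reading is an assumption, not something this tutorial states. The sketch below illustrates only the prototype lookup, using brute-force nearest-neighbour search as a stand-in for a k-d tree query (same answer, slower); `nearest_prototype` and the toy feature vectors are hypothetical.

```python
import math


def nearest_prototype(query, train_features, train_labels, target_class):
    """Return the training feature vector of `target_class` closest to `query`.

    Brute-force Euclidean search; a k-d tree returns the same neighbour faster.
    """
    candidates = [f for f, y in zip(train_features, train_labels) if y == target_class]
    if not candidates:
        raise ValueError(f"no training instances of class {target_class}")
    return min(candidates, key=lambda f: math.dist(query, f))


# Toy 3-channel feature vectors (hypothetical values)
train_x = [(0.1, 0.2, 0.9), (0.8, 0.1, 0.3), (0.7, 0.2, 0.4), (0.2, 0.9, 0.1)]
train_y = [0, 1, 1, 0]  # 1 = patch contains cytotoxic T cells
query = (0.3, 0.2, 0.6)  # a tumor patch without T cells

proto = nearest_prototype(query, train_x, train_y, target_class=1)
# Under the assumption above, this per-channel difference is the direction
# a prototype loss term would pull the counterfactual toward
print(proto, [round(p - q, 2) for p, q in zip(proto, query)])
```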

### Frequently Asked Questions

**Code in this section is not meant to be executed; it is for illustration only!**

1. How do I use my own classification model?

In [None]:
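# Step 1: Import your custom model class.
# IMPORTANT: the class must be importable as a Python module, because
# torch.save(model) pickles the whole object and loading it back
# requires the class definition to be importable.
from crc_tutorial.model.my_model import MyOwnModel
import torch

# Step 2: Initialize the predictor
predictor = MyOwnModel(img_size=img_size, in_channels=n_channels)

# Step 3: Train the predictor with your own training loop (not shown).
# Note that predictor.train() only puts the module in training mode;
# it does not run any training by itself.

# Step 4: Save the entire model (architecture and weights)
torch.save(predictor, "my_model.pth")

# Step 5: Pass the model path to the get_counterfactual function
mp.get_counterfactual(
    ...,
    model_path="my_model.pth",
)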