# Morpheus tutorial with example data set

In this tutorial, we will demonstrate the complete Morpheus pipeline using an example data set from [Wang et al. (2023)](https://doi.org/10.1016/j.cmet.2023.04.013). This data set contains 209 tumor images taken from 30 patients with colorectal cancer.

## Step 0: Download data set and set seed

In [None]:
import requests
import zipfile
import json
import os
import morpheus as mp

from lightning.pytorch import seed_everything
seed_everything(42)  # Optional: seeds Python's random, NumPy, and PyTorch for reproducibility

%reload_ext autoreload
%autoreload 2
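`seed_everything` makes random draws repeatable by seeding each random number generator. The sketch below shows the idea for Python's `random` module and NumPy only (the Lightning helper also seeds PyTorch, which is left out here). `seed_python_and_numpy` is a hypothetical helper, not part of Morpheus or Lightning.

```python
import random

import numpy as np


def seed_python_and_numpy(seed: int) -> None:
    """Hypothetical helper: seed Python's `random` and NumPy's global generator."""
    random.seed(seed)
    np.random.seed(seed)


# Reseeding with the same value reproduces exactly the same draws
seed_python_and_numpy(42)
first = (random.random(), np.random.rand())
seed_python_and_numpy(42)
second = (random.random(), np.random.rand())
print(first == second)  # True
```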

We will download the input data from the Caltech Data repository. The download consists of an input CSV file and a TXT file containing the channel names. For reproducibility, it also includes the trained model checkpoint and the patient split used in the paper, both of which are used in this notebook.

In [None]:
def download_and_unzip(record_id, filename, save_path):
    """Download `filename` from a Caltech Data record and extract it into `save_path`.

    Skips the download if `save_path` already exists.
    """
    if not os.path.exists(save_path):
        url = f"https://data.caltech.edu/records/{record_id}/files/{filename}"
        response = requests.get(url)
        response.raise_for_status()  # stop on HTTP errors (e.g. 404)
        # Save the zip archive in the working directory, then extract it
        with open(filename, "wb") as f:
            f.write(response.content)
        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(save_path)
        print(f"Downloaded {filename} to {save_path}")
    else:
        print(f"Data already exists in {save_path}")


# Download input data from the Caltech Data Portal
download_and_unzip("465sy-9g558", "crc_input.zip", save_path="crc")

# load channel names
with open("crc/channel_names.txt", "r") as f:
    channel_names = f.read().splitlines()

# For reproducing the paper: path to the trained model checkpoint, and the patient split
model_path = "crc/model/unet.ckpt"
with open("crc/patient_split.json", "r") as file:
    patient_split = json.load(file)

## Step 1: Create a `SpatialDataset` object

Start by creating a `SpatialDataset` object, which will hold all relevant information about the dataset we will be working with. 

### Prerequisites

To create a `SpatialDataset` object, you will need:
- The path to the input CSV file containing all single-cell expression information
- A list of channel names

### CSV File Structure

The expected structure of the CSV file is as follows:
- Each row corresponds to a single cell
- One column per channel, named as in the channel name list and holding that channel's expression values
- Five additional columns with the following names and information:

| Column Name         | Description                               | Datatype    |
|---------------------|-------------------------------------------|-------------|
| `ImageNumber`       | Unique ID for each image                  | Integer     |
| `PatientID`         | Unique ID for each patient                | Str/Integer |
| `CellType`          | Cell type                                 | Str         |
| `Location_Center_X` | X coordinate of the cell center in microns | Float       |
| `Location_Center_Y` | Y coordinate of the cell center in microns | Float       |

Note: Additional metadata columns beyond these will not be used in this tutorial.
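Before building the dataset, it can save time to check that your CSV has every column listed above. A minimal pandas sketch; the helper name `missing_columns` is ours and is not part of Morpheus:

```python
import pandas as pd

# Columns required by the table above, in addition to one column per channel
REQUIRED_COLUMNS = [
    "ImageNumber",
    "PatientID",
    "CellType",
    "Location_Center_X",
    "Location_Center_Y",
]


def missing_columns(df: pd.DataFrame, channel_names: list[str]) -> list[str]:
    """Hypothetical helper: return required or channel columns absent from `df`."""
    expected = REQUIRED_COLUMNS + list(channel_names)
    return [col for col in expected if col not in df.columns]


# Toy single-cell table with one channel column (values are illustrative only)
toy = pd.DataFrame(
    {
        "ImageNumber": [1, 1, 2],
        "PatientID": ["P01", "P01", "P02"],
        "CellType": ["Tumor", "Tcytotoxic", "Tumor"],
        "Location_Center_X": [10.5, 42.0, 7.25],
        "Location_Center_Y": [3.0, 18.5, 60.0],
        "CD44": [0.8, 0.1, 0.5],
    }
)
print(missing_columns(toy, ["CD44", "PD1"]))  # ['PD1']
```

On real data, you would call it on `pd.read_csv("crc/singlecell.csv")` with the `channel_names` list loaded in Step 0; an empty list means all expected columns are present.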

To create the object, pass the path to the single-cell CSV file and the corresponding list of channel names:

In [None]:
dataset = mp.SpatialDataset(
    input_path="crc/singlecell.csv",  # change to your own file path
    channel_names=channel_names,
)

## Step 2: Patch images and mask cells

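Next, we generate image patches of a specified size and resolution from the spatial data set and then mask out the cytotoxic T cells. `patch_size` is the side length of a patch in pixels and `pixel_size` is the side length of one pixel in microns, so with the values below each patch covers 16 × 3 = 48 µm per side.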

In [None]:
patch_size = 16  # Patch size in pixels
pixel_size = 3  # Pixel size in microns
cell_types = ["Tcytotoxic", "Tumor"]  # Specify the cell types of interest
mask_cell_types = ["Tcytotoxic"]
dataset.generate_masked_patch(
    cell_to_mask=mask_cell_types,
    cell_types=cell_types,
    patch_size=patch_size,
    pixel_size=pixel_size,
    save=True,
)
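To make `patch_size` and `pixel_size` concrete, here is a minimal NumPy sketch of one way to rasterize cells into a multi-channel patch and mask a cell type. It rests on assumptions of ours (cells are binned into square pixels by their center coordinates, expression is summed per pixel, and masked cells are simply left out while their presence is recorded as a binary label); it is an illustration, not Morpheus's implementation, and `rasterize_patch` is a hypothetical helper.

```python
import numpy as np


def rasterize_patch(x_um, y_um, expr, cell_types, patch_size=16, pixel_size=3.0,
                    cell_to_mask=("Tcytotoxic",)):
    """Hypothetical sketch: bin cells into a (channels, patch_size, patch_size) image.

    x_um, y_um: cell-center coordinates in microns, relative to the patch corner.
    expr: array of shape (n_cells, n_channels) with per-cell expression.
    Cells whose type is in `cell_to_mask` are left out of the image; the returned
    label is 1 if any such cell lies inside the patch, else 0.
    """
    x_um, y_um = np.asarray(x_um, dtype=float), np.asarray(y_um, dtype=float)
    expr = np.asarray(expr, dtype=float)
    cell_types = np.asarray(cell_types)
    extent = patch_size * pixel_size  # patch side length in microns (16 * 3 = 48)
    inside = (x_um >= 0) & (x_um < extent) & (y_um >= 0) & (y_um < extent)
    masked = np.isin(cell_types, cell_to_mask)
    label = int(np.any(inside & masked))
    keep = inside & ~masked
    # Pixel index of each kept cell: column from x, row from y
    cols = (x_um[keep] // pixel_size).astype(int)
    rows = (y_um[keep] // pixel_size).astype(int)
    image = np.zeros((expr.shape[1], patch_size, patch_size))
    for ch in range(expr.shape[1]):
        # np.add.at accumulates correctly when several cells share a pixel
        np.add.at(image[ch], (rows, cols), expr[keep, ch])
    return image, label


img, has_t = rasterize_patch(
    x_um=[1.0, 2.0, 40.0],
    y_um=[1.0, 2.5, 47.0],
    expr=[[1.0, 0.0], [2.0, 1.0], [5.0, 5.0]],
    cell_types=["Tumor", "Tumor", "Tcytotoxic"],
)
print(img.shape, has_t)  # (2, 16, 16) 1
print(img[0, 0, 0])      # 3.0 (both tumor cells fall in pixel (0, 0))
```

The T cell is removed from the image but still determines the label, which is the setup Step 4 relies on: the classifier must infer T-cell presence from the remaining tissue.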

## Step 3: Generate data splits for model training

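Next, we generate train, validation, and test splits for model training. Splitting is done at the patient level, so all patches from a given patient end up in the same split. We stratify the splits by the label we want to predict using the `stratify_by` parameter. To reproduce the paper's results, we also pass the published patient split through `specify_split`.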

In [None]:
colname = "Contains_Tcytotoxic"
dataset.generate_data_splits(
    stratify_by=colname,
    specify_split=patient_split,
)
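For intuition about what a patient-level, stratified split involves, here is a minimal pandas sketch. It is not Morpheus's algorithm: we assume a simple scheme that orders patients by the fraction of their patches that are positive for the label and deals them into splits following a fixed 60/20/20 pattern. The helper `patient_level_split` and the split names `"validate"` and `"test"` are hypothetical.

```python
import numpy as np
import pandas as pd


def patient_level_split(metadata, label_col, patient_col="PatientID",
                        pattern=("train", "train", "train", "validate", "test"),
                        seed=42):
    """Hypothetical sketch: assign whole patients to splits, roughly stratified.

    Patients are sorted by their positive rate for `label_col` and dealt into
    splits following `pattern`, so each split gets patients from across the
    range of positive rates. Returns one split name per row of `metadata`.
    """
    rate = metadata.groupby(patient_col)[label_col].mean()
    patients = rate.index.to_numpy().copy()
    np.random.default_rng(seed).shuffle(patients)  # random tie-breaking
    patients = sorted(patients, key=lambda p: rate[p])  # stable sort by positive rate
    assignment = {p: pattern[i % len(pattern)] for i, p in enumerate(patients)}
    return metadata[patient_col].map(assignment)


# Toy metadata: 10 patients with 4 patches each
toy = pd.DataFrame(
    {
        "PatientID": np.repeat([f"P{i:02d}" for i in range(10)], 4),
        "Contains_Tcytotoxic": np.tile([0, 1, 0, 0], 10),
    }
)
toy["split"] = patient_level_split(toy, "Contains_Tcytotoxic")
# With 10 patients and a 5-slot pattern, 6 land in train, 2 in validate, 2 in test
print(toy.groupby("split")["PatientID"].nunique())
```

Because whole patients are assigned, no patient's patches can leak between training and evaluation.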

## Step 4: Train the classifier model

After generating the data splits, we train a U-Net model to predict the presence of cytotoxic T cells from the masked patches. We first create a model instance with the `PatchClassifier` class and then train it with the `mp.train` function.

In [None]:
# Model settings: architecture, number of input channels, and image size
model_arch = "unet"
n_channels = dataset.n_channels
img_size = dataset.img_size

In [None]:
model = mp.PatchClassifier(n_channels, img_size, model_arch)

# train model
trainer_params = {
    "max_epochs": 30,
    "accelerator": "auto",
    "logger": False,
}
model = mp.train(
    model=model,
    dataset=dataset,
    predict_label=colname,
    trainer_params=trainer_params,
)
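Once the model is trained, you will likely want to check how well it separates patches with and without cytotoxic T cells on the test split. Assuming you have collected the true labels and predicted probabilities for the test patches (extracting them from Morpheus is not covered here), the AUROC can be computed in a few lines of NumPy using the Mann-Whitney formulation:

```python
import numpy as np


def auroc(labels, scores):
    """Area under the ROC curve via the Mann-Whitney U statistic.

    Counts, over all (positive, negative) pairs, how often the positive example
    gets the higher score; ties count as half. Memory is O(n_pos * n_neg), which
    is fine for a quick check on a test split of modest size.
    """
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUROC needs at least one positive and one negative example")
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / (pos.size * neg.size)


print(auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

AUROC does not depend on a probability cutoff, so it complements, rather than replaces, the choice of `threshold` made for counterfactual generation.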

## Step 5: Generate counterfactuals using a trained classifier

Now we select the training patches that contain tumor cells but no cytotoxic T cells. These patches will be used to generate counterfactuals. Note that the first call to `get_counterfactual` builds a k-d tree from the training patches; this is done only once. Counterfactual generation is parallelized with Ray, but it is still slow. We recommend running it on CPUs (rather than GPUs) so that it can be parallelized across a large cluster; we will be releasing Slurm scripts to help with this. Also note that `get_counterfactual` below uses the trained checkpoint at `model_path` (downloaded in Step 0) rather than the model trained in Step 4, for reproducibility.
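A k-d tree makes nearest-neighbour lookups over many training patches fast, which is why building it once and reusing it pays off. As a rough illustration only, the sketch below uses SciPy's `KDTree` to find, for a query patch, the closest training patch of the target class. We assume each patch is summarized by its per-channel mean expression; the features Morpheus actually indexes are not shown here, and `build_target_tree` and `nearest_target_patch` are hypothetical helpers.

```python
import numpy as np
from scipy.spatial import KDTree


def build_target_tree(train_patches, train_labels, target_class=1):
    """Hypothetical sketch: k-d tree over training patches of `target_class`.

    Patches have shape (n, channels, height, width); each one is summarized by its
    per-channel mean, so the tree lives in a `channels`-dimensional space.
    """
    train_patches = np.asarray(train_patches, dtype=float)
    candidates = np.flatnonzero(np.asarray(train_labels) == target_class)
    features = train_patches[candidates].mean(axis=(2, 3))  # (n_candidates, channels)
    return KDTree(features), candidates


def nearest_target_patch(tree, candidates, query_patch):
    """Index (into the training set) of the target-class patch closest to `query_patch`."""
    query = np.asarray(query_patch, dtype=float).mean(axis=(1, 2))  # (channels,)
    _, pos = tree.query(query)  # single nearest neighbour
    return int(candidates[pos])


# Toy data: 20 random 3-channel 16x16 patches with alternating labels 0/1
rng = np.random.default_rng(0)
patches = rng.random((20, 3, 16, 16))
labels = np.array([0, 1] * 10)

tree, cand = build_target_tree(patches, labels)  # built once ...
idx = nearest_target_patch(tree, cand, patches[0])  # ... then queried as often as needed
print(labels[idx])  # 1
```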

In [None]:
# select tumor patches that do not contain T cells from training cohort to generate counterfactuals
dataset.get_split_info()
select_metadata = dataset.metadata[
    (dataset.metadata["Contains_Tumor"] == 1)
    & (dataset.metadata["Contains_Tcytotoxic"] == 0)
    & (dataset.metadata["splits"] == "train")
]

# example of selected instances to generate counterfactuals
print(f"Number of selected instances: {len(select_metadata)}")
print(select_metadata.head())

In [None]:
# Parameters for counterfactual generation
optimization_param = {
    "use_kdtree": True,
    "theta": 50.0,
    "kappa": -0.34,
    "learning_rate_init": 0.1,
    "beta": 80.0,
    "max_iterations": 1000,
    "c_init": 1000.0,
    "c_steps": 5,
    "threshold": 0.33,  # probability cutoff for classification
    "channel_to_perturb": [
        "Glnsynthetase",
        "CCR4",
        "PDL1",
        "LAG3",
        "CD105endoglin",
        "TIM3",
        "CXCR4",
        "PD1",
        "CYR61",
        "CD44",
        "IL10",
        "CXCL12",
        "CXCR3",
        "Galectin9",
        "YAP",
    ],
}

# Generate counterfactuals using trained model
mp.get_counterfactual(
    images=select_metadata,
    dataset=dataset,
    target_class=1,
    model_path=model_path,
    optimization_params=optimization_param,
    save_dir=f"{dataset.root_dir}/cf/",
    device="cpu",
    num_workers=max(1, (os.cpu_count() or 2) - 1),  # leave one core free, but use at least one worker
    verbosity=0,
    model_kwargs={"in_channels": n_channels, "img_size": img_size},
)
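A typo in one of the `channel_to_perturb` names is easy to miss, so it is worth checking them against `channel_names` before launching a long run. The sketch below does that check and also builds a 0/1 per-channel mask of the kind that could restrict perturbations to those channels. The mask is an illustration; how Morpheus represents this internally is not documented here, and `perturbation_mask` is a hypothetical helper.

```python
import numpy as np


def perturbation_mask(channel_names, channels_to_perturb):
    """Hypothetical helper: 1.0 for channels allowed to change, 0.0 otherwise.

    Raises ValueError if any requested channel is not in `channel_names`.
    """
    unknown = sorted(set(channels_to_perturb) - set(channel_names))
    if unknown:
        raise ValueError(f"Unknown channel names: {unknown}")
    wanted = set(channels_to_perturb)
    return np.array([1.0 if name in wanted else 0.0 for name in channel_names])


print(perturbation_mask(["CD3", "PD1", "CD44", "DAPI"], ["PD1", "CD44"]))  # [0. 1. 1. 0.]
```

In this notebook, the check would be `perturbation_mask(channel_names, optimization_param["channel_to_perturb"])`, run after the parameter cell.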