# Example IMC analysis with Morpheus

## Step 0 (optional): Set a random seed for reproducibility

In [None]:
from lightning.pytorch import seed_everything
seed_everything(42)

%reload_ext autoreload
%autoreload 2

## Step 1: Creating a SpatialDataset Object

In this tutorial, we will start by creating a `SpatialDataset` object, which will hold all relevant information about the dataset we will be working with. 

### Prerequisites

To create a `SpatialDataset` object, you will need:
- The path to the CSV file containing all single-cell expression information
- A list of channel names

### CSV File Structure

The expected structure of the CSV file is as follows:
- Each row corresponds to a single cell
- One column per channel, named after the channel and containing that channel's expression values
- Five additional columns with the following names and information:

| Column Name         | Description                               |
|---------------------|-------------------------------------------|
| `ImageNumber`       | Unique ID for each image                  |
| `PatientID`         | Unique ID for each patient                |
| `CellType`          | Cell type                                 |
| `Location_Center_X` | X coordinate of cell center in microns    |
| `Location_Center_Y` | Y coordinate of cell center in microns    |

Note: Additional columns beyond these will not be used for the analysis performed in this tutorial.

To create a `SpatialDataset` object, run the following code. Remember to replace the value of `data_path` with the actual path to your CSV file.

In [None]:
import os

import morpheus as mp

# Path to the single-cell CSV. Override it by setting the MORPHEUS_DATA_PATH
# environment variable, or edit the fallback path below to point at your own file.
data_path = os.environ.get(
    "MORPHEUS_DATA_PATH",
    "/groups/mthomson/zwang2/IMC/output/cedarsLiver_sz48_pxl3_nc44/temp/singlecell.csv",
)
# Fail early with a readable message rather than an error raised deep inside the library.
if not os.path.isfile(data_path):
    raise FileNotFoundError(
        f"Single-cell CSV not found at {data_path!r}; set MORPHEUS_DATA_PATH or edit data_path."
    )
dataset = mp.SpatialDataset(input_path=data_path)

## Step 2: Patch images and mask cells

In [None]:
patch_size = 16  # Patch size in pixels
pixel_size = 3  # Pixel size in microns
cell_types = ["Tcytotoxic", "Tumor"]  # Specify the cell types of interest
mask_cell_types = ["Tcytotoxic"]  # Cell types to mask in the generated patches

# Cut the images into patches and mask the selected cell types; save=True
# writes the result to disk.
dataset.generate_masked_patch(
    cell_to_mask=mask_cell_types,
    cell_types=cell_types,
    patch_size=patch_size,
    pixel_size=pixel_size,
    save=True,
)

# Example metadata: the last expression is rendered as a rich table.
dataset.metadata.head()

## Step 3: Generate data splits for model training

Next, we generate train, validation, and test splits for model training. The splits are stratified by the label we want to predict, so each split has a similar class balance.

In [None]:
# Binary label the classifier will predict, also used to stratify the splits.
# The column must already exist in dataset.metadata (it is read again in Step 5).
colname = "Contains_Tcytotoxic"
if colname not in dataset.metadata.columns:
    raise KeyError(
        f"{colname!r} not found in dataset.metadata; check that Step 2 ran with this cell type."
    )
dataset.generate_data_splits(stratify_by=colname)

## Step 4: Train the classifier model

In [None]:
# Set TRAIN_MODEL = True to train a classifier from scratch. It is off by default
# because training is slow; Step 5 loads a previously saved checkpoint instead.
TRAIN_MODEL = False

if TRAIN_MODEL:
    # initialize model
    model_arch = "unet"
    n_channels = dataset.n_channels
    img_size = dataset.img_size
    model = mp.PatchClassifier(n_channels, img_size, model_arch)

    # train model
    trainer_params = {
        "max_epochs": 100,  # set to >60 for better performance
        "accelerator": "auto",
        "logger": False,
    }
    model = mp.train(
        model=model,
        dataset=dataset,
        label_name=colname,
        trainer_params=trainer_params,
    )

## Step 5: Generate counterfactuals using the trained classifier

In [None]:
# Query patches for the counterfactual search: training-split patches that
# contain tumor cells but no cytotoxic T cells.
# NOTE(review): get_split_info() is called only for its effect here; presumably it
# exposes the "splits" column used below — confirm against the library.
dataset.get_split_info()
meta = dataset.metadata
is_query = (
    (meta["Contains_Tumor"] == 1)
    & (meta["Contains_Tcytotoxic"] == 0)
    & (meta["splits"] == "train")
)
select_metadata = meta[is_query]

In [None]:
# Channels the counterfactual search is allowed to perturb; all others stay fixed.
channel_to_perturb = [
    "Glnsynthetase", "CCR4", "PDL1", "LAG3", "CD105endoglin",
    "TIM3", "CXCR4", "PD1", "CYR61", "CD44",
    "IL10", "CXCL12", "CXCR3", "Galectin9", "YAP",
]

# Probability cutoff for classification.
threshold = 0.43

# Optimization parameters. kappa is derived from the threshold by mapping it
# linearly from [0, 1] onto [-1, 1] (threshold 0.5 -> kappa 0).
# NOTE(review): presumably kappa acts as the optimizer's decision margin — confirm.
optimization_param = dict(
    use_kdtree=True,
    theta=40.0,
    kappa=(threshold - 0.5) * 2,
    learning_rate_init=0.1,
    beta=1.0,
    max_iterations=1000,  # set to >1000 for better performance
    c_init=1000.0,
    c_steps=5,
)

# load model: prefer the checkpoint this tutorial was written against. If that file is
# missing (e.g. after retraining, which can produce a different epoch/step name), fall
# back to the most recently modified *.ckpt in the same directory.
from pathlib import Path

checkpoint_dir = Path(f"{dataset.root_dir}/model/checkpoints")
checkpoint_path = checkpoint_dir / "epoch=41-step=12432.ckpt"
if not checkpoint_path.is_file():
    candidates = sorted(checkpoint_dir.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint (*.ckpt) found in {checkpoint_dir}; train a model in Step 4 first."
        )
    checkpoint_path = candidates[-1]
# Report which checkpoint is used, since the fallback may pick a different file.
print(f"Loading checkpoint: {checkpoint_path}")
model = mp.load_model(str(checkpoint_path))

In [None]:
import os

# Use all but one CPU core. os.cpu_count() can return None, and on a single-core
# machine "count - 1" would be 0, so clamp the worker count to at least 1.
num_workers = max(1, (os.cpu_count() or 1) - 1)

# Generate counterfactuals using the trained model for the query patches selected above.
# NOTE(review): target_class=1 presumably means "Contains_Tcytotoxic == 1" (the selected
# patches have 0) — confirm against the classifier's label encoding.
cf = mp.get_counterfactual(
    images=select_metadata,
    dataset=dataset,
    target_class=1,
    model=model,
    channel_to_perturb=channel_to_perturb,
    optimization_params=optimization_param,
    threshold=threshold,
    save_dir=f"{dataset.root_dir}/cf/",
    device="cpu",
    num_workers=num_workers,
)