# Example IMC analysis with Morpheus

## Step 0 (optional): set seed for reproducibility 

In [1]:
from lightning.pytorch import seed_everything
seed_everything(42)

%reload_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 42


## Step 1: Creating a SpatialDataset Object

In this tutorial, we will start by creating a `SpatialDataset` object, which will hold all relevant information about the dataset we will be working with. 

### Prerequisites

To create a `SpatialDataset` object, you will need:
- The path to the CSV file containing all single-cell expression information
- A list of channel names

### CSV File Structure

The expected structure of the CSV file is as follows:
- Each row corresponds to a single cell
- Columns for each channel name, with expression values specified
- Five additional columns with the following names and information:

| Column Name         | Description                               | Datatype    |
|---------------------|-------------------------------------------|-------------|
| `ImageNumber`       | Unique ID for each image                  | Integer     |
| `PatientID`         | Unique ID for each patient                | Str/Integer |
| `CellType`          | Cell type                                 | Str         |
| `Location_Center_X` | X coordinate of the cell center (microns) | Float       |
| `Location_Center_Y` | Y coordinate of the cell center (microns) | Float       |

Note: Additional metadata columns beyond these will not be used for the analysis performed in this tutorial.

To create a `SpatialDataset` object, use the following code. Remember to replace the value of `data_path` (and of `model_path`) with the actual paths on your system.

In [3]:
import morpheus as mp

# Single-cell expression table (one row per cell) -- change to your own directory.
data_path = "/groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/new2/singlecell.csv"
# Checkpoint handed to the dataset as its model path.
checkpoint_path = "/groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/final/model/epoch=21-step=7106.ckpt"
# Columns requested on top of the required ones (presumably extra metadata to keep -- confirm).
extra_cols = ["Cancer_Stage", "IHC_T_score"]

dataset = mp.SpatialDataset(
    input_path=data_path,
    additional_cols=extra_cols,
    model_path=checkpoint_path,
)

41 channels inferred from input CSV: ['Vimentin', 'CD163', 'B2M', 'CD134', 'CD68', 'GLUT1', 'CD3', 'Lag3', 'PD1', 'CCL4_mRNA', 'CCL18_mRNA', 'HistoneH3', 'CCR2', 'PDL1', 'CXCL8_mRNA', 'CXCL10_mRNA', 'CXCL12_mRNA', 'CXCL13_mRNA', 'CD8', 'CCL2_mRNA', 'CCL22_mRNA', 'CXCL9_mRNA', 'SMA', 'DapB_mRNA', 'SOX10', 'CCL8_mRNA', 'CD31', 'CCL19_mRNA', 'Mart1', 'pRB', 'cleavedPARP', 'DNA1', 'DNA2', 'CK5', 'CD15', 'MPO', 'CD38', 'HLADR', 'S100', 'Cadherin11', 'FAP']
Input path: /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/new2/singlecell.csv
Patch path: /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/new2/patch.h5
Split directory: /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/new2/split
Model path: /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/final/model/epoch=21-step=7106.ckpt
Counterfactual directory: /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/new2/cf


## Step 2: patch images and mask cells

In [4]:
# Cell types of interest, and the subset of them passed as `cell_to_mask`.
cell_types = ["Tcytotoxic", "Tumor"]
mask_cell_types = ["Tcytotoxic"]

# Patch geometry.
patch_size = 16  # patch size, in pixels
pixel_size = 3  # pixel size, in microns

# Build the masked patches. On a re-run the captured output shows the existing
# patch file being loaded instead.
dataset.generate_masked_patch(
    patch_size=patch_size,
    pixel_size=pixel_size,
    cell_types=cell_types,
    cell_to_mask=mask_cell_types,
    save=True,
)

File /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/new2/patch.h5 already exists, existing file loaded


## Step 3: generate data splits for model training

Next, we will need to generate train, validation, and test data splits for model training. We want to stratify our splits by the label we want to predict.

In [5]:
# Label column: the splits are stratified by it here, and later cells pass it
# as `predict_label`.
colname = "Contains_Tcytotoxic"

dataset.generate_data_splits(
    stratify_by=colname,
)

Data splits already exist in /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/new2/split


In [6]:
# Evaluate the checkpoint attached to `dataset` and report test metrics.
colname = "Contains_Tcytotoxic"
model_arch = "unet"
n_channels, img_size = dataset.n_channels, dataset.img_size

# No model_path is passed here; the run log shows the dataset's own model path
# being tested.
mp.test_model(
    dataset,
    predict_label=colname,
    in_channels=n_channels,
    img_size=img_size,
)

/central/home/zwang2/.cache/pypoetry/virtualenvs/morpheus-spatial-ndDQRg-x-py3.9/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /central/home/zwang2/.cache/pypoetry/virtualenvs/mor ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/central/home/zwang2/.cache/pypoetry/virtualenvs/morpheus-spatial-ndDQRg-x-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `ten

Testing model at /groups/mthomson/zwang2/IMC/output/hochMelanoma_sz48_pxl3_nc41/final/model/epoch=21-step=7106.ckpt


/central/home/zwang2/.cache/pypoetry/virtualenvs/morpheus-spatial-ndDQRg-x-py3.9/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /central/home/zwang2/.cache/pypoetry/virtualenvs/mor ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 79/79 [00:12<00:00,  6.53it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8505405187606812
       test_auroc           0.7512026429176331
        test_bce            0.5790185928344727
        test_bmc             0.550147533416748
         test_f1            0.6326158046722412
     test_precision         0.7272216081619263
       test_recall          0.5663365721702576
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


## Step 4: train classifier model

In [None]:
# Architecture and input dimensions; the latter are read off the dataset.
model_arch = "unet"
n_channels, img_size = dataset.n_channels, dataset.img_size

# Build the (untrained) patch classifier.
model = mp.PatchClassifier(n_channels, img_size, model_arch)

# Options forwarded to mp.train as `trainer_params`.
trainer_params = {"max_epochs": 30, "accelerator": "auto", "logger": False}

# Fit on `dataset`, predicting the `colname` label; `model` is rebound to the
# object returned by mp.train.
model = mp.train(
    model=model,
    dataset=dataset,
    predict_label=colname,
    trainer_params=trainer_params,
)

In [None]:
# Evaluate a checkpoint on this dataset and report test metrics.
#
# `model_path` is deliberately not passed: as the earlier evaluation cell shows,
# mp.test_model then tests the checkpoint configured on `dataset`, which is
# guaranteed to belong to this dataset. A checkpoint trained on a different
# dataset (e.g. one with a different number of channels) must not be used here,
# because `in_channels` is taken from this dataset.
#
# To evaluate the checkpoint written by the training cell above instead, add
# `model_path="<path to your .ckpt>"` to the call.
mp.test_model(
    dataset=dataset,
    predict_label=colname,
    in_channels=dataset.n_channels,
    img_size=dataset.img_size,
)

## Step 5: generate counterfactuals using trained classifier

In [None]:
# Images to generate counterfactuals for.
# get_split_info() is called first; presumably it provides the "splits" column
# filtered on below -- confirm.
dataset.get_split_info()

patch_meta = dataset.metadata
has_tumor = patch_meta["Contains_Tumor"] == 1
lacks_tcytotoxic = patch_meta["Contains_Tcytotoxic"] == 0
in_train_split = patch_meta["splits"] == "train"

# Keep training-split patches that contain tumor cells but no cytotoxic T cells.
select_metadata = patch_meta[has_tumor & lacks_tcytotoxic & in_train_split]

In [None]:
# Channels allowed to be perturbed.
# NOTE(review): of these names only "PDL1" and "PD1" appear verbatim among the
# 41 channels listed when the dataset was created ("Lag3" and "CXCL12_mRNA" are
# spelled differently there, the rest are absent) -- confirm that this list
# matches the dataset in use before running the counterfactual step.
channel_to_perturb = [
    "Glnsynthetase",
    "CCR4",
    "PDL1",
    "LAG3",
    "CD105endoglin",
    "TIM3",
    "CXCR4",
    "PD1",
    "CYR61",
    "CD44",
    "IL10",
    "CXCL12",
    "CXCR3",
    "Galectin9",
    "YAP",
]

# Probability cutoff for classification.
threshold = 0.43

# Optimization parameters passed to mp.get_counterfactual.
optimization_param = {
    "use_kdtree": True,
    "theta": 40.0,
    # Derived from the cutoff: (threshold - 0.5) * 2, i.e. about -0.14 here.
    "kappa": (threshold - 0.5) * 2,
    "learning_rate_init": 0.1,
    "beta": 80.0,
    "max_iterations": 1000,
    "c_init": 1000.0,
    "c_steps": 5,
    "numerical_diff": False,
}

# Example of the selected instances to generate counterfactuals for.
print(f"Number of selected instances: {len(select_metadata)}")
# Bare last expression so the notebook renders the frame as a rich table
# instead of the plain-text dump produced by print().
select_metadata.head()

In [None]:
# Generate counterfactuals using the trained model.
# NOTE(review): this checkpoint path differs from the model path printed when
# the dataset was created -- confirm it is the checkpoint you intend to explain.
mp.get_counterfactual(
    images=select_metadata.iloc[:1],  # only the first selected instance
    dataset=dataset,
    target_class=1,
    model_path=f"{dataset.root_dir}/model/checkpoints/epoch=42-step=13287.ckpt",
    channel_to_perturb=channel_to_perturb,
    optimization_params=optimization_param,
    threshold=threshold,
    save_dir=f"{dataset.root_dir}/cf/",
    device="cpu",
    num_workers=1,
    verbosity=0,
    # Read the dimensions from the dataset directly (as the test cell does)
    # rather than relying on `n_channels` / `img_size` left over from an
    # earlier cell that may not have been run.
    model_kwargs={
        "in_channels": dataset.n_channels,
        "img_size": dataset.img_size,
    },
)