## Example IMC analysis with Morpheus

#### Step 0: import required packages

In [None]:
import morpheus as mp
from lightning.pytorch import seed_everything

# Set seed for reproducibility
seed_everything(42)

%reload_ext autoreload
%autoreload 2

#### Step 1: create a dataset object

In [None]:
# Location of the single-cell CSV to load -- point this at your own copy.
data_path = "/groups/mthomson/zwang2/IMC/output/cedarsLiver_sz48_pxl3_nc44/temp/singlecell.csv"  # change to your own directory

# The 44 marker channel names, passed to SpatialDataset in exactly this order
# (presumably matching the channel order of the input data -- verify if the
# panel changes).
channel_names = [
    "CD45", "Glnsynthetase", "CD163", "NKG2D", "CCR4", "PDL1", "FAP",
    "CD11c", "LAG3", "HepPar1", "FOXP3", "aSMA", "CD4", "CD105endoglin",
    "CD68", "VISTA", "CD20", "CD8a", "TIM3", "CXCR4", "PD1",
    "iNOS", "CD31", "CYR61", "CDX2", "CAIX", "CD3", "CD44",
    "CD15", "CD11b", "HLADR", "IL10", "CXCL12", "HLAABC", "DNA1",
    "DNA2", "GranzymeB", "Ki67", "HistoneH3", "CXCR3", "Galectin9", "YAP",
    "CD14", "CK19",
]

dataset = mp.SpatialDataset(input_path=data_path, channel_names=channel_names)

In [None]:
# Patch geometry used when cutting the dataset into patches.
pixel_size = 3  # size of one pixel, in microns
patch_size = 16  # size of one patch, in pixels

# Cell types of interest, and the subset handed to `cell_to_mask`.
cell_types = ["Tcytotoxic", "Tumor"]
mask_cell_types = ["Tcytotoxic"]

dataset.generate_masked_patch(
    cell_types=cell_types,
    cell_to_mask=mask_cell_types,
    pixel_size=pixel_size,
    patch_size=patch_size,
    save=True,  # persist the generated patches
)

# Show the first few rows of the dataset metadata as an example.
print(dataset.metadata.head())

#### Step 2: generate data splits to prepare for model training

Next, we generate the train, validation, and test splits used for model training. The splits are stratified by the label we want to predict, so that each split keeps a similar proportion of positive and negative patches.

In [None]:
# Metadata column the classifier is trained to predict; the train/validation/
# test splits are stratified on this same column.
label_name = "Contains_Tcytotoxic"
dataset.generate_data_splits(
    stratify_by=label_name,
)

#### Step 3: train a PyTorch model

In [None]:
# Trainer settings forwarded to mp.train (presumably to a Lightning Trainer).
# Only 2 epochs keeps this example fast; use >60 for a well-trained model.
trainer_params = dict(
    max_epochs=2,  # set to >60 for better performance
    accelerator="auto",
    logger=False,
)

# Build a patch classifier sized to the dataset's channel count and image size.
model_arch = "unet"
n_channels, img_size = dataset.n_channels, dataset.img_size
model = mp.PatchClassifier(n_channels, img_size, model_arch)

# Train on `label_name`; `model` is rebound to whatever mp.train returns.
model = mp.train(
    model=model,
    dataset=dataset,
    label_name=label_name,
    trainer_params=trainer_params,
)

#### Step 4: generate counterfactuals

In [None]:
# load model if needed
# model = mp.load_model("/path/to/checkpoint.ckpt")

# Patches to generate counterfactuals for: training-split patches that contain
# tumor cells but are negative for the trained label. Indexing by `label_name`
# (instead of repeating the column literal) keeps this selection in sync with
# the label the model was trained on.
select_metadata = dataset.metadata[
    (dataset.metadata["Contains_Tumor"] == 1)
    & (dataset.metadata[label_name] == 0)
    & (dataset.metadata["splits"] == "train")
]

# channels allowed to be perturbed
channel_to_perturb = [
    "Glnsynthetase",
    "CCR4",
    "PDL1",
    "LAG3",
    "CD105endoglin",
    "TIM3",
    "CXCR4",
    "PD1",
    "CYR61",
    "CD44",
    "IL10",
    "CXCL12",
    "CXCR3",
    "Galectin9",
    "YAP",
]

# threshold for classification
threshold = 0.5

# optimization parameters
optimization_param = {
    "use_kdtree": True,
    "theta": 40.0,
    # Derived from `threshold` so the two cannot drift apart when the
    # threshold is edited (evaluates to 0 for the default threshold of 0.5).
    "kappa": (threshold - 0.5) * 2,
    "learning_rate_init": 0.1,
    "beta": 40.0,
    "max_iterations": 10,  # set to >1000 for better performance
    "c_init": 1000.0,
    "c_steps": 5,
}

In [None]:
# Generate counterfactuals with the trained model: ask for target class 1 for
# the first two selected patches, perturbing only `channel_to_perturb`.
cf = mp.get_counterfactual(
    model=model,
    dataset=dataset,
    images=select_metadata.iloc[:2],  # first two selected patches only
    target_class=1,
    threshold=threshold,
    channel_to_perturb=channel_to_perturb,
    optimization_params=optimization_param,
)