## Example IMC analysis with Morpheus

#### Step 0: import required packages

In [None]:
import morpheus as mp
from lightning.pytorch import seed_everything

# Set seed for reproducibility: seed_everything (imported from
# lightning.pytorch above) is called once, before any data splitting or
# model training in this notebook.
seed_everything(42)

# IPython magics (only valid inside IPython/Jupyter, not in a plain .py file):
# load or reload the autoreload extension and switch it to mode 2.
# NOTE(review): per the IPython docs, mode 2 re-imports edited modules before
# each cell runs — handy while developing morpheus itself; confirm it is
# wanted in a published example.
%reload_ext autoreload
%autoreload 2

#### Step 1: create a dataset object

In [None]:
import os

# Working directory that holds the input file and receives every output of
# this notebook.
# change to your own directory
root_dir = "/groups/mthomson/zwang2/IMC/output/cedarsLiver_sz48_pxl3_nc44/temp"

# Folder handed to mp.train as save_model_dir and searched for checkpoints
# later on.
model_dir = os.path.join(root_dir, "models")

# Build the dataset object from the .h5 file inside root_dir.
data_path = os.path.join(root_dir, "crc.h5")
livertumor = mp.SpatialDataset(data_path=data_path)

#### Step 2: generate data splits to prepare for model training

Next, we will need to generate train, validation, and test data splits for model training. We want to stratify our splits by the label we want to predict.

In [None]:
# Name of the label to predict; it is reused as label_name when training.
label_name = "Tcytotoxic"
# Generate the train / validation / test splits, stratified by that label.
livertumor.generate_data_splits(stratify_by=label_name)

#### Step 3: train PyTorch model

In [None]:
# Trainer settings passed through to mp.train: two epochs, accelerator left
# on "auto", logger disabled.
trainer_params = dict(max_epochs=2, accelerator="auto", logger=False)

# Build the classifier from the dataset's channel count and image size.
model_arch = "unet"
n_channels = livertumor.n_channels
img_size = livertumor.img_size
model = mp.PatchClassifier(n_channels, img_size, arch=model_arch)

# Train on the dataset for the chosen label. The value returned by mp.train
# is bound back to `model`, replacing the freshly initialized instance.
model = mp.train(
    model=model,
    dataset=livertumor,
    label_name=label_name,
    trainer_params=trainer_params,
    save_model_dir=model_dir,
)

#### Step 4: generate counterfactuals

In [None]:
# Images to generate counterfactuals for: metadata rows with Tumor == 1 and
# Tcytotoxic == 0.
has_tumor = livertumor.metadata["Tumor"] == 1
lacks_tcytotoxic = livertumor.metadata["Tcytotoxic"] == 0
select_metadata = livertumor.metadata[has_tumor & lacks_tcytotoxic]

# Load a saved checkpoint; this rebinds `model`, replacing whatever the
# training cell left in memory.
# NOTE(review): the file name refers to epoch 49, while the training cell in
# this notebook runs with max_epochs=2 — confirm this checkpoint exists, or
# point model_path at the one your own run produced.
checkpoint_relpath = "checkpoints/epoch=49-step=17400.ckpt"
model_path = os.path.join(model_dir, checkpoint_relpath)
model = livertumor.load_model(model_path, arch="unet")

In [None]:
# Threshold for classification.
threshold = 0.5

# Names of the channels that are allowed to be perturbed.
channel_to_perturb = [
    "Glnsynthetase", "CCR4", "PDL1", "LAG3", "CD105endoglin",
    "TIM3", "CXCR4", "PD1", "CYR61", "CD44",
    "IL10", "CXCL12", "CXCR3", "Galectin9", "YAP",
]

# Optimization parameters handed to the counterfactual search.
optimization_param = dict(
    use_kdtree=True,
    theta=40.0,
    kappa=0,  # recommended setting: (threshold - 0.5) * 2
    learning_rate_init=0.1,
    beta=40.0,
    max_iterations=1000,
    c_init=1000.0,
    c_steps=5,
)

# Generate counterfactuals with the trained model for the first two selected
# images, using class 1 as the target; save_dir points at <root_dir>/cf.
cf_images = select_metadata.iloc[:2]
cf_save_dir = os.path.join(root_dir, "cf")
cf = mp.get_counterfactual(
    images=cf_images,
    dataset=livertumor,
    model=model,
    target_class=1,
    threshold=threshold,
    channel_to_perturb=channel_to_perturb,
    optimization_params=optimization_param,
    save_dir=cf_save_dir,
)

In [None]:
# pandas is used below but was not imported by any earlier cell.
import pandas as pd

# Start from the patient / image identifiers and the dataset's labels.
metadata = livertumor.metadata[["PatientID", "ImageNumber"]]
label = livertumor.label

# merge metadata and label using ImageNumber as the common column
metadata = metadata.merge(label, on="ImageNumber")
# Turn the row index into an explicit "patch_index" column.
metadata = metadata.reset_index().rename(columns={"index": "patch_index"})

# add misc to metadata with columns "location_x_index" and "location_y_index"
# NOTE(review): `misc` is not defined anywhere in this notebook. It is
# presumably a two-column array of patch locations with one row per patch —
# define it before running this cell, and confirm its row order matches
# `metadata`, because join() aligns on the row index.
metadata = metadata.join(
    pd.DataFrame(misc, columns=["location_x_index", "location_y_index"])
)

# Write inside root_dir rather than to a hard-coded absolute path, so that
# changing root_dir at the top of the notebook also moves this output.
metadata.to_csv(os.path.join(root_dir, "metadata.csv"), index=False)

In [None]:
import h5py

# NOTE(review): data_path is the file `livertumor` was loaded from, and mode
# "w" truncates an existing file. Confirm the dataset no longer needs to read
# from that file (or write to a different path) before running this cell.
intensity = livertumor.intensity

# Chunk by up to 100 images of 16 x 16 x 44, clamped to the actual array shape
# so that a smaller array (e.g. fewer than 100 images) does not make h5py
# reject the chunk shape. Assumes a 4-D array — TODO confirm axis order.
chunk_shape = tuple(
    min(size, limit) for size, limit in zip(intensity.shape, (100, 16, 16, 44))
)

# Use data_path instead of repeating the absolute path, so that changing
# root_dir at the top of the notebook also redirects this output.
with h5py.File(data_path, "w") as f:
    # Dataset holding the images, gzip-compressed and chunked.
    f.create_dataset(
        "images",
        data=intensity,
        compression="gzip",
        chunks=chunk_shape,
        dtype=intensity.dtype,
    )

    # Dataset holding the metadata table, converted to a NumPy record array.
    metadata_numpy = livertumor.metadata.to_records(index=False)
    f.create_dataset("metadata", data=metadata_numpy, dtype=metadata_numpy.dtype)

    # Dataset holding the channel names.
    f.create_dataset("channel_names", data=livertumor.channel)