# ComSeg on large-scale data with SOPA / SpatialData

#### ComSeg is now integrated into SOPA (https://gustaveroussy.github.io/sopa/tutorials/comseg/). This tutorial is deprecated.

To make it easier to apply ComSeg to large datasets, ComSeg can be used with SOPA: https://gustaveroussy.github.io/sopa/  
SOPA is built on top of SpatialData.

The following example is a modified version of the official SOPA tutorial:
https://gustaveroussy.github.io/sopa/tutorials/api_usage/  
It was written for version 1.0.14 of SOPA.

In [1]:
import pandas as pd
import sopa.segmentation
import sopa.io
from matplotlib import pyplot as plt
from tqdm import tqdm
from sopa._sdata import to_intrinsic
import sopa
from spatialdata import SpatialData, read_zarr
from spatialdata.models import PointsModel
from pathlib import Path

### Load example data

In [2]:
# Example dataset from sopa.io.uniform(), and the element / column names it uses.
sdata = sopa.io.uniform()
image_key, points_key, gene_column = "image", "transcripts", "genes"

### Segmentation of nuclei with Cellpose
The nucleus segmentation will be used as a prior by ComSeg.

In [3]:
# Tile the image into patches of width 1200 with an overlap of 50, then write them so the
# segmentation below can run patch by patch.
patches = sopa.segmentation.Patches2D(sdata, image_key, patch_width=1200, patch_overlap=50)
patches.write()
from sopa._sdata import get_spatial_image
# Print the channel names of the image, to check that the channel used below ("DAPI") exists.
print(get_spatial_image(sdata, image_key).c.values)
channels = ["DAPI"]
# Cellpose settings for nuclei; tune diameter / thresholds for your own data.
method = sopa.segmentation.methods.cellpose_patch(diameter=35, channels=channels, flow_threshold=2, cellprob_threshold=-6)
segmentation = sopa.segmentation.StainingSegmentation(sdata, method, channels, min_area=2500)


# The cellpose boundaries will be temporarily saved here. You can choose a different path.
cellpose_temp_dir = "tuto.zarr/.sopa_cache/cellpose"
segmentation.write_patches_cells(cellpose_temp_dir)

# Read back the per-patch cells; solve_conflicts presumably merges cells that were detected
# twice in the overlap between neighbouring patches.
cells = sopa.segmentation.StainingSegmentation.read_patches_cells(cellpose_temp_dir)
cells = sopa.segmentation.shapes.solve_conflicts(cells)
shapes_key = "cellpose_boundaries" # name of the key given to the cells in sdata.shapes
sopa.segmentation.StainingSegmentation.add_shapes(sdata, cells, image_key, shapes_key)

### Compute the patches for ComSeg

In [4]:
image_key = "image"
points_key = "transcripts" # (ignore this for multiplex imaging)
gene_column = "genes" # (optional) column of sdata[points_key] containing the gene names
config_comseg = {}
# Output folder for the transcript patches (one sub-folder per patch index). The variable
# keeps its Baysor name from the original SOPA tutorial, but the folders are used by ComSeg.
baysor_temp_dir = "tuto.zarr/.sopa_cache/comseg"

# Split the transcripts into patches of width 200 with an overlap of 50 and write each patch
# to disk; use_prior=True is passed so the prior segmentation can be used downstream.
patches = sopa.segmentation.Patches2D(sdata, points_key, patch_width=200, patch_overlap=50)
valid_indices = patches.patchify_transcripts(baysor_temp_dir, config=config_comseg, use_prior=True)

### Compute nucleus centroids for ComSeg

In [5]:
def add_centroids_to_sdata(sdata,
                           points_key="transcripts",
                           shapes_key='cellpose_boundaries', z_constant=None,
                           centroid_key="centroid"):
    """Add the centroids of a shapes element to ``sdata`` as a points element.

    The centroids of the polygons in ``sdata[shapes_key]`` are stored as
    ``sdata[centroid_key]`` and then converted with ``to_intrinsic`` to the
    intrinsic coordinate system of ``sdata[points_key]``.

    Args:
        sdata: SpatialData object containing ``shapes_key`` and ``points_key``.
        points_key: Name of the transcripts points element.
        shapes_key: Name of the shapes element whose centroids are computed.
        z_constant: If not None, every centroid gets this z value. If None and
            the transcripts have a ``z`` column, their single z value is used.
            If None and they have no ``z`` column, the centroids are 2D.
        centroid_key: Name of the points element created for the centroids.

    Returns:
        The same ``sdata`` object, modified in place.

    Raises:
        ValueError: If ``z_constant`` is None and the transcripts' ``z`` column
            does not hold exactly one distinct value (3D transcripts with a 2D
            segmentation), since the centroid z cannot be inferred.
    """
    centroids = sdata[shapes_key].geometry.centroid
    coords = {"x": list(centroids.x), "y": list(centroids.y)}

    if z_constant is None and "z" in sdata[points_key].columns:
        # A 2D segmentation can only be lifted to 3D when all transcripts lie
        # on one z-plane. Raise explicitly: an `assert` disappears under -O.
        z_values = list(sdata[points_key].z.unique().compute())
        if len(z_values) != 1:
            raise ValueError("3D point cloud with 2D segmentation, manually set z_constant")
        z_constant = z_values[0]

    if z_constant is not None:
        coords["z"] = [z_constant] * len(coords["y"])

    # NOTE(review): the centroids are parsed without explicit transformations,
    # although they are expressed in the coordinate frame of `shapes_key`.
    # This is presumably only correct when that frame matches the default
    # one assumed by PointsModel.parse -- confirm for non-identity data.
    sdata[centroid_key] = PointsModel.parse(pd.DataFrame(coords))
    sdata[centroid_key] = to_intrinsic(sdata, sdata[centroid_key], points_key)
    return sdata

# Separate SpatialData holding only the nucleus boundaries and the transcripts; the centroid
# points are added to this object, so `sdata` itself is left unchanged.
sdata_centroid = SpatialData()
sdata_centroid['cellpose_boundaries'] = sdata['cellpose_boundaries']
sdata_centroid['transcripts'] = sdata['transcripts']
# z_constant=1: every centroid gets z = 1.
# NOTE(review): if the transcripts have no "z" column, confirm ComSeg accepts 3D centroids
# alongside 2D transcripts.
sdata_centroid = add_centroids_to_sdata(sdata_centroid,
                           points_key="transcripts",
                           shapes_key='cellpose_boundaries',
                               z_constant=1)


# Write the centroids patch by patch, with the same patch size / overlap as the transcripts,
# into their own cache folder.
baysor_temp_dir = "tuto.zarr/.sopa_cache/comseg_centroid"
config_comseg ={}
# NOTE: from here on `points_key` refers to the centroids, not to the transcripts.
points_key = "centroid"
# NOTE(review): this grid is built from the centroid extent, not the transcript extent; the
# ComSeg loop assumes patch i covers the same area in both folders -- confirm the bounds match.
patches = sopa.segmentation.Patches2D(sdata_centroid, points_key, patch_width=200, patch_overlap=50)
valid_indices = patches.patchify_transcripts(baysor_temp_dir, config=config_comseg, use_prior=True)

### Run ComSeg on each patch

In [9]:
import comseg
from comseg import dataset as ds
from comseg import dictionary
from comseg import model
import json

#### HYPERPARAMETERS ####
MEAN_CELL_DIAMETER = 15  # in micrometer
MAX_CELL_RADIUS = 50  # in micrometer
#########################

# Per-patch folders written by the patchify_transcripts calls (one sub-folder per patch index).
path_transcript = "tuto.zarr/.sopa_cache/comseg"
path_centroid = "tuto.zarr/.sopa_cache/comseg_centroid"

# NOTE(review): `patches` is the last Patches2D built, i.e. the centroid grid. Patch i of the
# transcripts and patch i of the centroids are assumed to cover the same area; both grids use
# the same patch_width / patch_overlap but come from different points elements, so confirm
# their bounds (and patch counts) coincide.
for patch_index in tqdm(range(len(patches.ilocs))):
    path_dataset_folder = Path(path_transcript) / str(patch_index)
    path_dataset_folder_centroid = Path(path_centroid) / str(patch_index)

    if not path_dataset_folder.is_dir():
        # Report and skip instead of failing inside ComSeg on a missing patch folder.
        tqdm.write(f"Skipping patch {patch_index}: {path_dataset_folder} not found")
        continue

    # Load the transcripts of this patch (unit scale on every axis).
    dataset = ds.ComSegDataset(
        path_dataset_folder=path_dataset_folder,
        dict_scale={"x": 1, "y": 1, "z": 1},
        mean_cell_diameter=MEAN_CELL_DIAMETER,
        gene_column=gene_column,
    )

    # Edge weights of the transcript graph, estimated on a sample of 10000 transcripts.
    dico_proba_edge, count_matrix = dataset.compute_edge_weight(
        images_subset=None,
        n_neighbors=40,
        sampling=True,
        sampling_size=10000,
    )

    # Community detection using the nucleus centroids of the same patch as a prior.
    Comsegdict = dictionary.ComSegDict(
        dataset=dataset,
        mean_cell_diameter=MEAN_CELL_DIAMETER,
        community_detection="with_prior",
        prior_name="cell",
    )
    Comsegdict.run_all(
        max_cell_radius=MAX_CELL_RADIUS,
        path_dataset_folder_centroid=path_dataset_folder_centroid,
        file_extension=".csv",
    )

    # Save counts + polygons in the patch folder; `path_transcript` is later handed to
    # sopa's Baysor `resolve`, presumably because it expects these exact file names.
    anndata_comseg, json_dict = Comsegdict.anndata_from_comseg_result(
        return_polygon=True,
        alpha=0.5,
        min_rna_per_cell=5,
    )
    anndata_comseg.write_loom(path_dataset_folder / 'segmentation_counts.loom')
    with open(path_dataset_folder / "segmentation_polygons.json", 'w', encoding="utf-8") as f:
        json.dump(json_dict['transcripts'], f)

### Aggregate results

In [7]:

# Merge the per-patch ComSeg outputs (segmentation_counts.loom / segmentation_polygons.json)
# with SOPA's Baysor resolver, which reads them from the per-patch folders in `path_transcript`.
from sopa.segmentation.baysor.resolve import resolve
resolve(sdata, path_transcript, gene_column, min_area=10)
# Shapes key holding the merged cells after `resolve` (it keeps the Baysor name).
shapes_key = "baysor_boundaries"
# Aggregate per cell: presumably gene counts plus average channel intensities (average_intensities=True).
aggregator = sopa.segmentation.Aggregator(sdata, image_key=image_key, shapes_key=shapes_key)
aggregator.compute_table(gene_column=gene_column, average_intensities=True)
sdata

### Plot results

In [8]:

import spatialdata_plot  # imported for the `.pl` plotting accessor used below

# Build the figure layer by layer: points in red, the image, then the cell outlines
# (white, no fill), displayed in the "global" coordinate system.
plot_sdata = sdata.pl.render_points(size=0.01, color="r")
plot_sdata = plot_sdata.pl.render_images()
plot_sdata = plot_sdata.pl.render_shapes(shapes_key, outline=True, fill_alpha=0, outline_color="w")
plot_sdata.pl.show("global")
plt.show()