# Training RNA2seg on Zarr-Saved SpatialData  

This notebook demonstrates how to train RNA2seg on spatial transcriptomics data stored in a Zarr file. The process consists of four main steps:  

1. **Patch Creation** – Split the image and transcripts into overlapping patches of manageable size (saved in the Zarr file).  
2. **Filtered Target Generation** – Create a curated segmentation mask from a teacher model for RNA2seg training (saved in the Zarr file).  
3. **Model Training** – Train RNA2seg using the generated patches and filtered segmentation.  
4. **Inference on the Whole Dataset** – Use the notebook `apply_model_on_zarr.ipynb` to apply the trained model to the entire dataset (not covered in this notebook).


## Imports

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import cv2
import torch
import numpy as np
from tqdm import tqdm
import spatialdata as sd
from pathlib import Path
import albumentations as A

from rna2seg.dataset_zarr import (
    RNA2segDataset, custom_collate_fn, compute_consistent_cell
)


### Set your own paths

In [3]:

# Path to the SpatialData Zarr store
merfish_zarr_path = "/media/tom/Transcend/open_merfish/test_spatial_data/from_cluster/test_mouse_ileum.zarr"

# Directory where the trained model is saved
path_save_model = "/media/tom/Transcend/open_merfish/test_spatial_data/from_cluster/modelm"


## Step 1: Create training patches from Zarr files

In this step, the dataset (image + transcripts) is divided into patches of size `patch_width × patch_width` with an overlap of `patch_overlap`. This allows processing images of a manageable size while preserving spatial continuity.  
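To see how `patch_width` and `patch_overlap` determine the number of patches, here is a minimal sketch of an overlapping tiling, assuming a stride of `patch_width - patch_overlap` along each axis and a last patch clipped at the image border. The helper names are hypothetical and this is not sopa's implementation, although for the 9000 × 9000 `staining_z3` image it reproduces the 64 and 81 patches reported below for the `1200_50` and `1200_150` settings.

```python
import math


def patch_bounds_1d(image_size: int, patch_width: int, patch_overlap: int) -> list[tuple[int, int]]:
    """Hypothetical helper: (start, end) of overlapping patches along one axis."""
    stride = patch_width - patch_overlap
    if stride <= 0:
        raise ValueError("patch_overlap must be smaller than patch_width")
    # Enough patches to reach the image border, and at least one patch.
    n = max(1, math.ceil((image_size - patch_overlap) / stride))
    # The last patch is clipped to the image border.
    return [(i * stride, min(i * stride + patch_width, image_size)) for i in range(n)]


def count_patches(height: int, width: int, patch_width: int, patch_overlap: int) -> int:
    """Number of 2D patches = patches along y times patches along x."""
    n_y = len(patch_bounds_1d(height, patch_width, patch_overlap))
    n_x = len(patch_bounds_1d(width, patch_width, patch_overlap))
    return n_y * n_x


print(count_patches(9000, 9000, patch_width=1200, patch_overlap=50))   # 64
print(count_patches(9000, 9000, patch_width=1200, patch_overlap=150))  # 81
```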

**Process** 
- The dataset, stored in Zarr format, is loaded.  
- Patch coordinates are saved as a shapes element in the Zarr store: `sopa_patches_rna2seg_[patch_width]_[patch_overlap]`. 
- A `.rna2seg` directory is created to store the transcript data corresponding to each patch.  
- The transcript information for each patch is saved in CSV format for further processing.  
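The per-patch transcript export can be pictured as a bounding-box filter on the transcript table. The sketch below assumes transcripts with `x`, `y` and `gene` columns, patch bounds given as `(xmin, ymin, xmax, ymax)`, and a hypothetical `<patch_dir>/<patch_index>/transcripts.csv` layout; the files actually written by `create_patch_rna2seg` may be organized differently.

```python
import tempfile
from pathlib import Path

import pandas as pd


def save_patch_transcripts(transcripts: pd.DataFrame, bounds, patch_dir, patch_index: int) -> Path:
    """Hypothetical helper: write the transcripts falling inside one patch to CSV."""
    xmin, ymin, xmax, ymax = bounds
    # Keep transcripts whose coordinates lie inside the patch (bounds inclusive).
    inside = transcripts["x"].between(xmin, xmax) & transcripts["y"].between(ymin, ymax)
    out_dir = Path(patch_dir) / str(patch_index)
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / "transcripts.csv"
    transcripts.loc[inside].to_csv(csv_path, index=False)
    return csv_path


# Toy example: only the first two transcripts fall inside a 1200 x 1200 patch at the origin.
df = pd.DataFrame({"x": [10, 500, 1300], "y": [20, 600, 50], "gene": ["A", "B", "C"]})
with tempfile.TemporaryDirectory() as tmp:
    path = save_patch_transcripts(df, (0, 0, 1200, 1200), tmp, patch_index=0)
    print(pd.read_csv(path)["gene"].tolist())  # ['A', 'B']
```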


In [4]:
from rna2seg.dataset_zarr import create_patch_rna2seg

# Load the SpatialData object and set the patching parameters
sdata = sd.read_zarr(merfish_zarr_path)
image_key = "staining_z3"
patch_width = 1200
patch_overlap = 50
points_key = "transcripts"
min_transcripts_per_patch = 0
folder_patch_rna2seg = Path(merfish_zarr_path) / f".rna2seg_{patch_width}_{patch_overlap}"

# Create the patches in sdata and precompute a transcript CSV file for each patch (via sopa)
create_patch_rna2seg(
    sdata,
    image_key=image_key,
    points_key=points_key,
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    min_transcripts_per_patch=min_transcripts_per_patch,
    folder_patch_rna2seg=folder_patch_rna2seg,
    overwrite=True,
)
print(sdata)

[INFO] (sopa.patches._patches) Added 64 patche(s) to sdata['sopa_patches_rna2seg_1200_50']


[########################################] | 100% Completed | 16.20 s
[########################################] | 100% Completed | 12.26 s
SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/from_cluster/test_mouse_ileum.zarr
├── Images
│     └── 'staining_z3': DataTree[cyx] (5, 9000, 9000), (5, 4500, 4500), (5, 2250, 2250), (5, 1125, 1125), (5, 562, 562)
├── Points
│     └── 'transcripts': DataFrame with shape: (<Delayed>, 9) (2D points)
└── Shapes
      ├── 'Cellbound1': GeoDataFrame shape: (3258, 1) (2D shapes)
      ├── 'DAPI': GeoDataFrame shape: (2377, 1) (2D shapes)
      ├── 'sopa_patches_rna2seg_1200_50': GeoDataFrame shape: (64, 3) (2D shapes)
      ├── 'sopa_patches_rna2seg_1200_150': GeoDataFrame shape: (81, 3) (2D shapes)
      └── 'test_rnas2eg': GeoDataFrame shape: (3011, 1) (2D shapes)
with coordinate systems:
    ▸ 'global', with elements:
        test_rnas2eg (Shapes)
    ▸ 'microns', with elements:
        staining_z3

## Step 2: Create a Consistent Target for Training RNA2seg

**Input:** Spatial data with potentially erroneous nucleus and cell segmentations.  
**Output:** Curated cell and nucleus segmentations for training RNA2seg, saved in the Zarr store.

This step refines two segmentations stored in the Zarr file: the **cell segmentation** (`key_shape_cell_seg`) and the **nucleus segmentation** (`key_shape_nuclei_seg`).  
The goal is to generate a **teacher segmentation** by filtering out cells and nuclei that are inconsistent with each other.  

**Process** 
1. Load the segmentations (`Cellbound1` and `DAPI`) from the Zarr file.  
2. Apply a **consistency check** to remove unreliable segmentations:  
   - **Consistent cell segmentation** → `Cellbound1_consistent`  
   - **Consistent nuclei segmentation** → `DAPI_consistent`  
3. Save the refined segmentations back into the Zarr file.  

This yields cleaner, more reliable annotations for training or fine-tuning RNA2seg.  
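As an illustration of what a cell/nucleus consistency check can look like, the sketch below keeps a cell only if it intersects between one and `max_nuclei` nuclei and contains at least `contain_threshold` of the area of each of them. It uses shapely and is one plausible reading of parameters such as `threshold_intersection_contain` and `max_cell_nb_intersecting_nuclei`; it is not the logic of `compute_consistent_cell`, and `filter_consistent` is a hypothetical name.

```python
from shapely.geometry import box
from shapely.geometry.base import BaseGeometry


def filter_consistent(cells: list[BaseGeometry], nuclei: list[BaseGeometry],
                      contain_threshold: float = 0.95, max_nuclei: int = 1):
    """Hypothetical consistency filter between cell and nucleus polygons."""
    kept_cells, kept_nuclei = [], []
    for cell in cells:
        touching = [n for n in nuclei if cell.intersects(n)]
        # Drop cells with no nucleus or with too many intersecting nuclei.
        if not touching or len(touching) > max_nuclei:
            continue
        # Every intersecting nucleus must lie (almost) entirely inside the cell.
        if all(cell.intersection(n).area / n.area >= contain_threshold for n in touching):
            kept_cells.append(cell)
            kept_nuclei.extend(touching)
    return kept_cells, kept_nuclei


cells = [box(0, 0, 10, 10), box(20, 0, 30, 10), box(40, 0, 50, 10)]
nuclei = [
    box(2, 2, 5, 5),                       # fully inside cell 0
    box(21, 1, 23, 3), box(25, 5, 28, 8),  # two nuclei in cell 1
    box(48, 2, 53, 5),                     # only 40% inside cell 2
]
kept_cells, kept_nuclei = filter_consistent(cells, nuclei)
print(len(kept_cells), len(kept_nuclei))  # 1 1
```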


In [5]:
key_cell_segmentation = "Cellbound1"
key_nuclei_segmentation = "DAPI"
# Names of the new shapes elements that will be created in sdata
key_cell_consistent = "Cellbound1_consistent"
key_nucleus_consistent = "DAPI_consistent"

sdata = sd.read_zarr(merfish_zarr_path)

sdata, _ = compute_consistent_cell(
    sdata=sdata,
    key_shape_nuclei_seg=key_nuclei_segmentation,
    key_shape_cell_seg=key_cell_segmentation,
    key_cell_consistent=key_cell_consistent,
    key_nuclei_consistent=key_nucleus_consistent,
    image_key="staining_z3",
    threshold_intersection_contain=0.95,
    threshold_intersection_intersect=1,
    accepted_nb_nuclei_per_cell=None,
    max_cell_nb_intersecting_nuclei=1,
)

Resolving conflicts: 7688it [00:00, 10603.65it/s]


## Step 3: Training RNA2seg

Now, we will train RNA2seg using the target segmentation created in Step 2.  

### Initialize an RNA2segDataset

In [6]:
from rna2seg.models import RNA2seg

transform_resize = A.Compose([
    A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST),
])

dataset = RNA2segDataset(
    sdata=sdata,
    channels_dapi=["DAPI"],
    channels_cellbound=["Cellbound1"],
    shape_patch_key=f"sopa_patches_rna2seg_{patch_width}_{patch_overlap}",  # created in Step 1
    key_cell_consistent=key_cell_consistent,  # created in Step 2
    key_nucleus_consistent=key_nucleus_consistent,  # created in Step 2
    key_nuclei_segmentation=key_nuclei_segmentation,
    gene_column="gene",
    density_threshold=None,
    kernel_size_background_density=10,
    kernel_size_rna2img=0.5,
    max_filter_size_rna2img=2,
    transform_resize=transform_resize,
    training_mode=True,
    patch_dir=folder_patch_rna2seg,
    patch_width=patch_width,  # same values as in Step 1
    patch_overlap=patch_overlap,
    use_cache=True,
)

print(len(dataset))

No module named 'vmunet'
VMUnet not loaded
100%|██████████| 64/64 [00:00<00:00, 83.27it/s]
Number of valid patches: 48
100%|██████████| 64/64 [00:00<00:00, 80.87it/s]
compute density threshold


100%|██████████| 48/48 [00:07<00:00,  6.35it/s]

Time to compute density threshold: 7.559024s
48
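`transform_resize` uses `cv2.INTER_NEAREST`, which matters for label masks: nearest-neighbour resampling only copies existing pixel values, so resizing a 1200 × 1200 patch to 512 × 512 cannot create new or fractional cell labels, unlike bilinear interpolation. The NumPy sketch below illustrates the idea with a hypothetical `resize_nearest` helper; it is not the albumentations or OpenCV implementation.

```python
import numpy as np


def resize_nearest(arr: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Hypothetical nearest-neighbour resize: each output pixel copies one input pixel."""
    in_h, in_w = arr.shape[:2]
    # Map each output row/column back to the input row/column it falls into.
    rows = (np.arange(out_h) * in_h / out_h).astype(int)
    cols = (np.arange(out_w) * in_w / out_w).astype(int)
    return arr[rows[:, None], cols[None, :]]


mask = np.zeros((1200, 1200), dtype=np.int32)
mask[100:400, 100:400] = 1   # cell 1
mask[600:900, 600:1000] = 2  # cell 2
small = resize_nearest(mask, 512, 512)
print(small.shape, np.unique(small))  # (512, 512) [0 1 2]
```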





In [7]:
dataset[2].keys()

dict_keys(['img_cellbound', 'dapi', 'rna_img', 'mask_flow', 'mask_gradient', 'idx', 'patch_index', 'bounds', 'segmentation_nuclei'])
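Each sample contains an `rna_img` built from the patch's transcripts. How RNA2seg builds it (including the role of `kernel_size_rna2img` and `max_filter_size_rna2img`) is not shown here; as a rough intuition only, the sketch below bins transcript coordinates into a per-pixel count image with `numpy.histogram2d`. The helper name and the plain counting scheme are assumptions.

```python
import numpy as np


def rasterize_transcripts(x, y, bounds, out_size: int) -> np.ndarray:
    """Hypothetical helper: count transcripts per pixel of an out_size x out_size grid."""
    xmin, ymin, xmax, ymax = bounds
    # histogram2d bins its first argument along axis 0, so pass y first to get
    # an image indexed as [row (y), column (x)].
    img, _, _ = np.histogram2d(y, x, bins=out_size, range=[[ymin, ymax], [xmin, xmax]])
    return img


x = np.array([10.0, 10.0, 1190.0])
y = np.array([10.0, 10.0, 50.0])
img = rasterize_transcripts(x, y, bounds=(0, 0, 1200, 1200), out_size=512)
print(img.shape, img.sum(), img[4, 4], img[21, 507])  # (512, 512) 3.0 2.0 1.0
```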

### Train/validation split

In [8]:
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler

train_indices, val_indices = train_test_split(
    range(len(dataset)), test_size=0.1, shuffle=True, random_state=42
)
train_sampler = SubsetRandomSampler(train_indices) 
valid_sampler = SubsetRandomSampler(val_indices)
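With 48 valid patches, `test_size=0.1` and `batch_size=2`, the loader lengths printed in the next cell follow from simple arithmetic: scikit-learn rounds a float `test_size` up, giving 5 validation and 43 training patches, and a `DataLoader` without `drop_last` counts a final incomplete batch. The helper below is hypothetical and only reproduces that calculation.

```python
import math


def split_and_batch_counts(n_samples: int, test_size: float, batch_size: int):
    """Hypothetical helper: sizes of a train/validation split and the batches per loader."""
    n_val = math.ceil(test_size * n_samples)  # a float test_size is rounded up
    n_train = n_samples - n_val
    # Without drop_last, a final incomplete batch still counts as a batch.
    return n_train, n_val, math.ceil(n_train / batch_size), math.ceil(n_val / batch_size)


print(split_and_batch_counts(48, test_size=0.1, batch_size=2))  # (43, 5, 22, 3)
```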

### Initialize Dataloaders

In [9]:
training_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,  # must stay False when a sampler is given; the sampler shuffles
    num_workers=0,
    sampler=train_sampler,
    collate_fn=custom_collate_fn,
)
print(f"len(training_loader) {len(training_loader)}")

validation_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    sampler=valid_sampler,
    collate_fn=custom_collate_fn,
)
print(f"len(validation_loader) {len(validation_loader)}")

len(training_loader) 22
len(validation_loader) 3
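Each sample is a dict (see the keys listed above), presumably with some fields that do not share a shape across samples, which is why the loaders use `custom_collate_fn` rather than PyTorch's default collation. Its internals are not shown here; the hypothetical sketch below illustrates one common approach: stack fields whose arrays share a shape and keep the others as per-sample lists. It uses NumPy for simplicity (with tensors the same idea would use `torch.stack`), and the field shapes in the example are made up.

```python
import numpy as np


def collate_dict_samples(samples: list[dict]) -> dict:
    """Hypothetical collate function for dict samples."""
    batch = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        same_shape = (
            all(isinstance(v, np.ndarray) for v in values)
            and len({v.shape for v in values}) == 1
        )
        # Stack same-shape arrays into one batch array; keep anything else as a list.
        batch[key] = np.stack(values) if same_shape else values
    return batch


samples = [
    {"dapi": np.zeros((512, 512)), "idx": 0, "segmentation_nuclei": np.zeros((3, 4))},
    {"dapi": np.ones((512, 512)), "idx": 1, "segmentation_nuclei": np.zeros((5, 4))},
]
batch = collate_dict_samples(samples)
print(batch["dapi"].shape, batch["idx"], type(batch["segmentation_nuclei"]))
# (2, 512, 512) [0, 1] <class 'list'>
```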


### Initialize RNA2seg Model

In [10]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cpu


In [11]:
rna2seg = RNA2seg(
    device,
    net='unet',
    flow_threshold = 0.9,
    cellbound_flow_threshold = 0.4,
    pretrained_model = None,
)
rna2seg = rna2seg.to(device)

optimizer = torch.optim.AdamW(rna2seg.parameters(), lr=0.001, weight_decay=0.01)


initiaisation of CPnet
Initiaisation of ChannelInvariantNet


### Training RNA2seg

In [12]:
from rna2seg.train import train_one_epoch

best_val_loss = np.inf

# Train for 3 epochs. Judging by the log below, train_one_epoch also runs a
# validation pass every `print_loss_every` batches; it receives `path_save_model`
# as the location for the saved model.
for epoch_index in tqdm(range(3)):
    train_one_epoch(
        device=device,
        epoch_index=epoch_index,
        rna2seg=rna2seg,
        training_loader=training_loader,
        optimizer=optimizer,
        print_loss_every=int(len(training_loader) / 3),
        tb_writer=None,
        validation_loader=validation_loader,
        path_save_model=path_save_model,
        cellbound_prob=0.8,
        best_val_loss=best_val_loss,
    )

  0%|          | 0/3 [00:00<?, ?it/s]
validation loss: 13.405842463175455
best_val_loss: 13.405842463175455
validation loss: 12.745603243509928
best_val_loss: 12.745603243509928
validation loss: 12.134397188822428
best_val_loss: 12.134397188822428
 33%|███▎      | 1/3 [02:30<05:01, 150.83s/it]
validation loss: 11.542235692342123
best_val_loss: 11.542235692342123
validation loss: 10.643851280212402
best_val_loss: 10.643851280212402
validation loss: 10.981072107950846
 67%|██████▋   | 2/3 [04:50<02:24, 144.46s/it]
validation loss: 10.06841786702474
best_val_loss: 10.06841786702474
validation loss: 9.189206759134928
best_val_loss: 9.189206759134928
validation loss: 9.122661113739014
best_val_loss: 9.122661113739014
100%|██████████| 3/3 [07:15<00:00, 145.18s/it]
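In the log, a `best_val_loss` line appears only when the validation loss improves (in the second epoch, 10.98 does not replace 10.64), which suggests a "keep the best model" rule. The sketch below applies that rule to the second-epoch losses; the helper is hypothetical, and the actual saving logic presumably lives inside `train_one_epoch`. Note that the loop above passes `best_val_loss` without reassigning it, so whether the best value carries over between epochs depends on `train_one_epoch`, which is not shown here.

```python
import math


def improvement_steps(val_losses, best=math.inf):
    """Hypothetical helper: indices where a 'save if best' rule would save, plus the final best."""
    saved = []
    for i, loss in enumerate(val_losses):
        if loss < best:  # strictly better than every previous validation loss
            best = loss
            saved.append(i)
    return saved, best


epoch2_losses = [11.54, 10.64, 10.98]  # rounded from the second epoch of the log above
print(improvement_steps(epoch2_losses))  # ([0, 1], 10.64)
```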





