# Training RNA2seg on Zarr-Saved SpatialData  

This notebook demonstrates how to train RNA2seg on spatial transcriptomics data stored in a Zarr file. The process consists of four main steps:  

1. **Patch Creation** – Split the dataset into patches of manageable size for efficient processing (patch coordinates are saved in the Zarr file).  
2. **Filtered Target Generation** – Filter the existing cell and nucleus segmentations into a curated teacher segmentation for RNA2seg training (saved in the Zarr file).  
3. **Model Training** – Train RNA2seg on the generated patches and the filtered segmentation.  
4. **Apply to the Whole Dataset** – Use the notebook `apply_model_on_zarr.ipynb` to apply the trained model to the entire dataset.


## Imports

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import cv2
import torch
import numpy as np
from tqdm import tqdm
import spatialdata as sd
from pathlib import Path
import albumentations as A

from rna2seg.dataset_zarr import (
    RNA2segDataset, custom_collate_fn, compute_consistent_cell
)


## Step 1: Create training patches from Zarr files

In this step, the dataset (image + transcripts) is divided into patches of size `patch_width × patch_width` with an overlap of `patch_overlap`. This allows processing images of a manageable size while preserving spatial continuity.  

**Process** 
- The dataset, stored in Zarr format, is loaded.  
- Patch coordinates are saved as a shapes element in the Zarr store, under the key `sopa_patches_rna2seg_{patch_width}_{patch_overlap}`. 
- A `.rna2seg` directory is created inside the Zarr store to hold the transcript data of each patch.  
- The transcripts of each patch are saved in CSV format for further processing.  
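To make the patch geometry concrete, here is a minimal sketch of how a rectangular region can be tiled with square patches of side `patch_width` whose neighbours share a strip `patch_overlap` wide. The helper `compute_patch_bounds` is hypothetical and written only for illustration; it is not the sopa implementation, which may treat the image borders differently.

```python
from typing import List, Tuple


def compute_patch_bounds(
    xmin: int, ymin: int, xmax: int, ymax: int, patch_width: int, patch_overlap: int
) -> List[Tuple[int, int, int, int]]:
    """Hypothetical helper: tile [xmin, xmax] x [ymin, ymax] with square patches
    of side patch_width, where neighbouring patches share patch_overlap units."""
    if not 0 <= patch_overlap < patch_width:
        raise ValueError("need 0 <= patch_overlap < patch_width")
    step = patch_width - patch_overlap  # distance between consecutive patch origins
    boxes = []
    y0 = ymin
    while True:
        x0 = xmin
        while True:
            # Clip to the region, so the last row/column of patches may be smaller.
            boxes.append((x0, y0, min(x0 + patch_width, xmax), min(y0 + patch_width, ymax)))
            if x0 + patch_width >= xmax:
                break
            x0 += step
        if y0 + patch_width >= ymax:
            break
        y0 += step
    return boxes


# Notebook settings on a hypothetical 2350 x 2350 region.
for box in compute_patch_bounds(0, 0, 2350, 2350, patch_width=1200, patch_overlap=50):
    print(box)
```

With `patch_width=1200` and `patch_overlap=50`, consecutive patch origins are 1150 units apart, so this hypothetical 2350 × 2350 region is covered by a 2 × 2 grid.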


In [3]:
from rna2seg.dataset_zarr import create_patch_rnaseg

# Load the SpatialData object and set the parameters
merfish_zarr_path = "/Users/alice/Documents/data/Cell_Segmentation/test_mouse_ileum.zarr"  # MODIFY PATH
sdata = sd.read_zarr(merfish_zarr_path)
image_key = "staining_z3"
patch_width = 1200
patch_overlap = 50
points_key = "transcripts"
min_transcripts_per_patch = 0
folder_patch_rna2seg = Path(merfish_zarr_path) / ".rna2seg"

# Create the patches in sdata and precompute a transcript CSV for each patch with sopa
create_patch_rnaseg(
    sdata,
    image_key=image_key,
    points_key=points_key,
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    min_transcripts_per_patch=min_transcripts_per_patch,
    folder_patch_rna2seg=folder_patch_rna2seg,
    overwrite=True,
)
print(sdata)

## Step 2: Create a Consistent Target for Training RNA2seg

**Input:** Spatial data with potentially erroneous nucleus and cell segmentations.  
**Output:** Curated cell and nucleus segmentations for training RNA2seg, saved in the Zarr file.

This step refines two segmentations stored in the Zarr file: the **cell segmentation** (`key_shape_cell_seg`) and the **nuclei segmentation** (`key_shape_nuclei_seg`).  
The goal is to generate a **teacher segmentation** by filtering out inconsistencies between cells and nuclei.  

**Process** 
1. Load the segmentations (`Cellbound1` and `DAPI`) from the Zarr file.  
2. Apply a **consistency check** to remove unreliable segmentations:  
   - **Consistent cell segmentation** → `Cellbound1_consistent`  
   - **Consistent nuclei segmentation** → `DAPI_consistent`  
3. Save the refined segmentations back into the Zarr file.  

This yields more reliable annotations for training or fine-tuning RNA2seg.  
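The exact filtering rules are defined by `compute_consistent_cell` and its thresholds; the sketch below only illustrates the general idea under simplifying assumptions. Shapes are approximated by axis-aligned boxes `(x0, y0, x1, y1)` instead of polygons, and a cell is kept when exactly one nucleus has at least `contain_threshold` of its area inside the cell and no other nucleus touches it. The default of 0.95 mirrors `threshold_intersection_contain=0.95` by analogy only. The helper names and the rule itself are hypothetical, not rna2seg's documented behaviour.

```python
from typing import Dict, List, Tuple

# Simplification: shapes are axis-aligned boxes (x0, y0, x1, y1), not polygons.
Box = Tuple[float, float, float, float]


def area(box: Box) -> float:
    x0, y0, x1, y1 = box
    return max(0.0, x1 - x0) * max(0.0, y1 - y0)


def intersection_area(a: Box, b: Box) -> float:
    # Overlap rectangle; area() returns 0 when the boxes do not intersect.
    return area((max(a[0], b[0]), max(a[1], b[1]), min(a[2], b[2]), min(a[3], b[3])))


def consistent_pairs(
    cells: Dict[str, Box],
    nuclei: Dict[str, Box],
    contain_threshold: float = 0.95,
) -> List[Tuple[str, str]]:
    """Hypothetical filter: keep a cell only if it contains exactly one nucleus
    (>= contain_threshold of the nucleus area inside the cell) and no other
    nucleus overlaps it."""
    kept = []
    for cell_id, cell in cells.items():
        touching = [n for n, nuc in nuclei.items() if intersection_area(cell, nuc) > 0]
        contained = [
            n for n in touching
            if intersection_area(cell, nuclei[n]) / area(nuclei[n]) >= contain_threshold
        ]
        if len(contained) == 1 and len(touching) == 1:
            kept.append((cell_id, contained[0]))
    return kept


cells = {"c1": (0, 0, 10, 10), "c2": (20, 0, 30, 10), "c3": (40, 0, 50, 10)}
nuclei = {"n1": (2, 2, 6, 6), "n2": (22, 2, 26, 6), "n3": (24, 4, 28, 8), "n4": (48, 2, 54, 6)}
print(consistent_pairs(cells, nuclei))  # -> [('c1', 'n1')]: c2 holds two nuclei, n4 sticks out of c3
```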


In [None]:
key_cell_segmentation = "Cellbound1"
key_nuclei_segmentation = "DAPI"
# Names of the new shapes that will be created in sdata
key_cell_consistent = "Cellbound1_consistent"
key_nucleus_consistent = "DAPI_consistent"

merfish_zarr_path = "/Users/alice/Documents/data/Cell_Segmentation/test_mouse_ileum.zarr" # MODIFY PATH
sdata = sd.read_zarr(merfish_zarr_path)

sdata, _ = compute_consistent_cell(
    sdata=sdata,
    key_shape_nuclei_seg=key_nuclei_segmentation,
    key_shape_cell_seg=key_cell_segmentation,
    key_cell_consistent=key_cell_consistent,
    key_nuclei_consistent=key_nucleus_consistent,
    image_key="staining_z3",
    threshold_intersection_contain=0.95,
    threshold_intersection_intersect=1,
    accepted_nb_nuclei_per_cell=None,
    max_cell_nb_intersecting_nuclei=1,
)

## Step 3: Training RNA2seg

Now, we will train RNA2seg using the target segmentation created in Step 2.  

### Initialize an RNA2segDataset

In [20]:
from rna2seg.models import RNA2seg

transform_resize = A.Compose([
    A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST),
])

dataset = RNA2segDataset(
    sdata=sdata,
    channels_dapi=["DAPI"],
    channels_cellbound=["Cellbound1"],
    shape_patch_key=f"sopa_patches_rna2seg_{patch_width}_{patch_overlap}",  # created in Step 1
    key_cell_consistent=key_cell_consistent,  # created in Step 2
    key_nucleus_consistent=key_nucleus_consistent,  # created in Step 2
    key_nuclei_segmentation=key_nuclei_segmentation,
    gene_column="gene",
    density_threshold=None,
    kernel_size_background_density=10,
    kernel_size_rna2img=0.5,
    max_filter_size_rna2img=2,
    transform_resize=transform_resize,
    training_mode=True,
    patch_dir=folder_patch_rna2seg,
    patch_width=patch_width,  # same value as in Step 1
    patch_overlap=patch_overlap,  # same value as in Step 1
    use_cache=True,
)

print(len(dataset))

No module named 'vmunet'
VMUnet not loaded
100%|██████████| 64/64 [00:01<00:00, 43.16it/s]
Number of valid patches: 48
100%|██████████| 64/64 [00:01<00:00, 39.21it/s]
compute density threshold


100%|██████████| 48/48 [00:06<00:00,  7.39it/s]

Time to compute density threshold: 6.498080s
48
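The `transform_resize` pipeline above uses `cv2.INTER_NEAREST` to bring each 1200 × 1200 patch down to 512 × 512. Nearest-neighbour sampling matters wherever the transform touches label masks: each output pixel copies one input label, whereas bilinear interpolation averages neighbours (labels 1 and 3 could blend into a non-existent label 2). Below is a minimal NumPy sketch of nearest-neighbour resizing; `resize_nearest` is a hypothetical helper, and its pixel-centre sampling may select slightly different source pixels than OpenCV's `INTER_NEAREST`.

```python
import numpy as np


def resize_nearest(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D label mask (sketch, pixel-centre sampling)."""
    in_h, in_w = mask.shape
    # Map the centre of each output pixel back to the input pixel that contains it.
    rows = np.floor((np.arange(out_h) + 0.5) * in_h / out_h).astype(int)
    cols = np.floor((np.arange(out_w) + 0.5) * in_w / out_w).astype(int)
    rows = np.clip(rows, 0, in_h - 1)
    cols = np.clip(cols, 0, in_w - 1)
    # Every output value is copied from the input, so no new labels can appear.
    return mask[np.ix_(rows, cols)]


mask = np.zeros((1200, 1200), dtype=np.int32)
mask[100:600, 100:600] = 1    # cell 1
mask[700:1100, 700:1100] = 3  # cell 3
small = resize_nearest(mask, 512, 512)
print(small.shape, np.unique(small))
```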





In [21]:
dataset[2].keys()

dict_keys(['img_cellbound', 'dapi', 'rna_img', 'mask_flow', 'mask_gradient', 'idx', 'patch_index', 'bounds', 'segmentation_nuclei'])
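Each sample is a dictionary that mixes fixed-size images (such as `dapi` and `rna_img`) with other entries. PyTorch's default collate function stacks tensors and raises an error when they differ in shape, which is one reason to pass a custom `collate_fn` to the `DataLoader`. We do not reproduce rna2seg's `custom_collate_fn` here; the NumPy sketch below, built around the hypothetical `collate_dicts` and a hypothetical ragged key `nuclei_coords`, only illustrates the general pattern of stacking same-shape arrays and keeping everything else as lists.

```python
from typing import Any, Dict, List

import numpy as np


def collate_dicts(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical collate_fn: stack same-shape arrays, keep the rest as lists."""
    batch: Dict[str, Any] = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        same_shape_arrays = (
            all(isinstance(v, np.ndarray) for v in values)
            and len({v.shape for v in values}) == 1
        )
        # Stack into a (batch, ...) array when possible, otherwise keep a list.
        batch[key] = np.stack(values) if same_shape_arrays else values
    return batch


samples = [
    {"dapi": np.zeros((1, 512, 512)), "idx": 0, "nuclei_coords": np.zeros((3, 2))},
    {"dapi": np.ones((1, 512, 512)), "idx": 1, "nuclei_coords": np.zeros((5, 2))},  # ragged
]
batch = collate_dicts(samples)
print(batch["dapi"].shape, batch["idx"], type(batch["nuclei_coords"]))
```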

### Train / Validation split

In [22]:
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler

train_indices, val_indices = train_test_split(
    range(len(dataset)), test_size=0.1, shuffle=True, random_state=42
)
train_sampler = SubsetRandomSampler(train_indices) 
valid_sampler = SubsetRandomSampler(val_indices)

### Initialize Dataloaders

In [23]:
training_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,  # shuffling is handled by the SubsetRandomSampler
    num_workers=0,
    sampler=train_sampler,
    collate_fn=custom_collate_fn,
)

print(f"len(training_loader) {len(training_loader)}")

validation_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    sampler=valid_sampler,
    collate_fn=custom_collate_fn,
)

print(f"len(validation_loader) {len(validation_loader)}")

len(training_loader) 22
len(validation_loader) 3
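The printed loader lengths follow directly from the split. The sketch below reproduces them from the 48 valid patches, relying on scikit-learn rounding a fractional `test_size` up and on the `DataLoader` default `drop_last=False`, which keeps the final partial batch.

```python
import math

n_patches = 48  # "Number of valid patches" reported when building the dataset
test_size = 0.1
batch_size = 2

# scikit-learn rounds a fractional test set size up: ceil(0.1 * 48) = 5.
n_val = math.ceil(test_size * n_patches)
n_train = n_patches - n_val

# With drop_last=False, the last incomplete batch still counts.
train_batches = math.ceil(n_train / batch_size)
val_batches = math.ceil(n_val / batch_size)

print(n_train, n_val, train_batches, val_batches)
```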


### Initialize RNA2seg Model

In [24]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: mps


In [25]:
rnaseg = RNA2seg(
    device,
    net='unet',
    flow_threshold = 0.9,
    cellbound_flow_threshold = 0.4,
    pretrained_model = None,
)
rnaseg = rnaseg.to(device)

optimizer = torch.optim.AdamW(rnaseg.parameters(), lr=0.001, weight_decay=0.01)


initiaisation of CPnet
Initiaisation of ChannelInvariantNet


### Training RNA2seg

In [27]:
from rna2seg.train import train_one_epoch

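best_val_loss = np.inf
path_save_model = "/Users/alice/Documents/data/Cell_Segmentation/test_mouse_ileum"  # MODIFY PATH

# Note: best_val_loss is never updated in this loop, so best-model tracking
# restarts at every epoch (visible in the log below). If train_one_epoch in your
# rna2seg version returns the updated best loss, assign it back:
# best_val_loss = train_one_epoch(...)
for epoch_index in tqdm(range(3)):
    train_one_epoch(
        device=device,
        epoch_index=epoch_index,
        rnaseg=rnaseg,
        training_loader=training_loader,
        optimizer=optimizer,
        print_loss_every=int(len(training_loader) / 3),
        tb_writer=None,
        validation_loader=validation_loader,
        path_save_model=path_save_model,
        cellbound_prob=0.8,
        best_val_loss=best_val_loss,
    )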

  0%|          | 0/3 [00:00<?, ?it/s]

  validation loss: 4.0795408089955645
best_val_loss: 4.0795408089955645
  validation loss: 3.9352603753407798
best_val_loss: 3.9352603753407798
  validation loss: 4.032752990722656
  validation loss: 3.6595205465952554
best_val_loss: 3.6595205465952554
training: 100%|██████████| 22/22 [01:26<00:00,  3.93s/it]

 33%|███▎      | 1/3 [01:26<02:53, 86.57s/it]


  validation loss: 5.133564313252767
best_val_loss: 5.133564313252767
  validation loss: 4.012377103169759
best_val_loss: 4.012377103169759
  validation loss: 3.7041075229644775
best_val_loss: 3.7041075229644775
  validation loss: 4.0180362065633135
training: 100%|██████████| 22/22 [01:26<00:00,  3.91s/it]

 67%|██████▋   | 2/3 [02:52<01:26, 86.30s/it]


  validation loss: 4.1740842660268145
best_val_loss: 4.1740842660268145
  validation loss: 3.9927518367767334
best_val_loss: 3.9927518367767334
  validation loss: 3.5406501293182373
best_val_loss: 3.5406501293182373
  validation loss: 3.28145170211792
best_val_loss: 3.28145170211792
training: 100%|██████████| 22/22 [01:26<00:00,  3.94s/it]

100%|██████████| 3/3 [04:19<00:00, 86.46s/it]
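In the log, the reported `best_val_loss` resets at each epoch: the first validation of the second epoch (5.13) is reported as a new best, although 3.66 was already reached in the first epoch. Assuming `train_one_epoch` saves a checkpoint whenever a validation loss beats the `best_val_loss` it receives, the sketch below replays the logged validation losses to compare per-epoch tracking with tracking carried across epochs. The helper `improvements` is hypothetical.

```python
import math
from typing import List, Tuple


def improvements(val_losses: List[float], best: float = math.inf) -> Tuple[List[float], float]:
    """Return the losses that would trigger a checkpoint save, and the final best loss."""
    saved = []
    for loss in val_losses:
        if loss < best:
            best = loss
            saved.append(loss)
    return saved, best


# Validation losses copied from the training log above, one list per epoch.
epochs = [
    [4.0795408089955645, 3.9352603753407798, 4.032752990722656, 3.6595205465952554],
    [5.133564313252767, 4.012377103169759, 3.7041075229644775, 4.0180362065633135],
    [4.1740842660268145, 3.9927518367767334, 3.5406501293182373, 3.28145170211792],
]

best = math.inf
for i, losses in enumerate(epochs):
    reset_saves, _ = improvements(losses)        # per-epoch tracking, as in the log
    saved, best = improvements(losses, best)     # best carried across epochs
    print(f"epoch {i}: per-epoch saves {reset_saves}, carried saves {saved}")
```

When the best loss is carried across epochs, the second epoch triggers no save, so its weaker models would not replace the 3.66 checkpoint, and the final best remains 3.28.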





