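# Objective

Train an RNAseg model from a SpatialData Zarr store: create the patches, build the training dataset, initialize the model, and run the training loop.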

In [1]:
import spatialdata as sd
from pathlib import Path
from rna_seg.dataset_zarr.patches import create_patch_rnaseg
from rna_seg.dataset_zarr.RNA2segDataset import RNA2segDataset, custom_collate_fn
from rna_seg.models import RNASeg
import numpy as np
import albumentations as A
import cv2
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
from rna_seg.train import train_one_epoch

  warn(f"Failed to load image Python extension: {e}")


No module named 'vmunet'
VMUnet not loaded


### Step 1: Create the training patches from the Zarr file
* Create the patch shapes (crops)
* Create the folder of per-patch CSV files

In [2]:
### Load the SpatialData object and define the path / patching parameters
merfish_zarr_path = "/media/tom/Transcend/open_merfish/test_spatial_data/from_cluster/test_mouse_ileum.zarr"
sdata = sd.read_zarr(merfish_zarr_path)

# Keys of the sdata elements used for patching
image_key = "staining_z3"
points_key = "transcripts"

# Patch parameters, forwarded unchanged to create_patch_rnaseg
patch_width = 1200
patch_overlap = 50
min_transcripts_per_patch = 0

# NOTE(review): not referenced by any other visible cell — confirm it is still needed
folder_patch_rna_seg = Path(merfish_zarr_path) / ".rna_seg"

### Create the patches in sdata and precompute the transcript csv of each patch with sopa
create_patch_rnaseg(
    sdata,
    image_key=image_key,
    points_key=points_key,
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    min_transcripts_per_patch=min_transcripts_per_patch,
    overwrite=True,
)
print(sdata)

  self._check_key(key, self.keys(), self._shared_keys)
[36;20m[INFO] (sopa.patches.patches)[0m 64 patches were saved in sdata['sopa_patches_rna_seg_1200_50']


[########################################] | 100% Completed | 20.12 ss
SpatialData object, with associated Zarr store: /media/tom/Transcend/open_merfish/test_spatial_data/from_cluster/test_mouse_ileum.zarr
├── Images
│     └── 'staining_z3': DataTree[cyx] (5, 9000, 9000), (5, 4500, 4500), (5, 2250, 2250), (5, 1125, 1125), (5, 562, 562)
├── Points
│     └── 'transcripts': DataFrame with shape: (<Delayed>, 9) (2D points)
└── Shapes
      ├── 'Cellbound1': GeoDataFrame shape: (3258, 1) (2D shapes)
      ├── 'Cellbound1_consistent_with_nuclei': GeoDataFrame shape: (1007, 1) (2D shapes)
      ├── 'Cellbound1_consistent_without_nuclei': GeoDataFrame shape: (2239, 1) (2D shapes)
      ├── 'DAPI': GeoDataFrame shape: (2377, 1) (2D shapes)
      ├── 'DAPI_consistent_in_cell': GeoDataFrame shape: (1007, 1) (2D shapes)
      ├── 'DAPI_consistent_not_in_cell': GeoDataFrame shape: (1370, 1) (2D shapes)
      └── 'sopa_patches_rna_seg_1200_50': GeoDataFrame shape: (64, 3) (2D shapes)
with coordinate

### Step 2: Create the training dataset from sdata

In [15]:

# Names of the segmentation shapes already stored in sdata
# NOTE(review): key_shape_cell_seg is not passed to RNA2segDataset below — confirm it is still needed
key_shape_cell_seg = "Cellbound1"
key_nuclei_segmentation = "DAPI"
# Names for the shapes that will be created in sdata later on
key_cell_consistent = "Cellbound1_consistent"
key_nucleus_consistent = "DAPI_consistent"

# Resize transform: 512 x 512 output, nearest-neighbour interpolation
transform_resize = A.Compose(
    [A.Resize(width=512, height=512, interpolation=cv2.INTER_NEAREST)]
)

# NOTE(review): the cache folder here is '.rnaseg' while the patch-creation cell
# defines folder_patch_rna_seg with '.rna_seg' — confirm which directory is intended.
dataset = RNA2segDataset(
    sdata=sdata,
    channels_dapi=["DAPI"],
    channels_cellbound=["Cellbound1"],
    shape_patch_key="sopa_patches_rna_seg_1200_50",
    key_cell_consistent=key_cell_consistent,
    key_nucleus_consistent=key_nucleus_consistent,
    key_nuclei_segmentation=key_nuclei_segmentation,
    gene_column="gene",
    density_threshold=None,
    kernel_size_background_density=10,
    kernel_size_rna2img=0.5,
    max_filter_size_rna2img=2,
    transform_resize=transform_resize,
    training_mode=True,
    path_cache=sdata.path / ".rnaseg",
)

100%|██████████| 64/64 [00:01<00:00, 34.58it/s]
Number of valid patches: 48


### Set the density threshold

In [None]:
# Set the density threshold on the dataset, then show the keys of the first sample.
# NOTE(review): shape=(1200, 1200) equals the patch_width used when creating the
# patches — confirm whether the two values have to stay in sync.
dataset.set_threshold(
    max_nb_crops=500,
    kernel_size=9,
    percentile_threshold=5,
    shape=(1200, 1200),
)
dataset[0].keys()

100%|██████████| 64/64 [00:01<00:00, 34.66it/s]
compute density threshold


 54%|█████▍    | 26/48 [00:04<00:03,  6.03it/s]

# Step 3: Initialize an RNAseg model, optionally load pretrained weights

In [None]:
# Device the model is created on and moved to
device = "cpu"

# pretrained_model=None: no pretrained weights are passed to the model
rnaseg = RNASeg(
    device,
    net='unet',
    flow_threshold=0.9,
    cellbound_flow_threshold=0.4,
    pretrained_model=None,
)
rnaseg.to(device)
print('ok')

In [None]:
# Fraction of the dataset held out for validation
validation_split = 0.1
# Fixed seed so the train/validation split and the batch order are reproducible
# after a kernel restart
SEED = 0

indices = list(range(len(dataset)))
split = int(np.floor(validation_split * len(dataset)))

# Shuffle with a local, seeded generator instead of NumPy's global RNG state
np.random.default_rng(SEED).shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# SubsetRandomSampler draws a new random permutation of its indices every time
# it is iterated, so the order does change at each epoch. Passing a seeded
# generator makes that sequence of permutations reproducible.
train_sampler = SubsetRandomSampler(
    train_indices, generator=torch.Generator().manual_seed(SEED)
)
valid_sampler = SubsetRandomSampler(
    val_indices, generator=torch.Generator().manual_seed(SEED)
)

# shuffle must stay False: DataLoader does not accept shuffle=True together
# with an explicit sampler.
training_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=5,
    shuffle=False,
    num_workers=5,
    sampler=train_sampler,
    collate_fn=custom_collate_fn,
)

print( f"len(training_loader) {len(training_loader)}")

validation_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=3,
    sampler=valid_sampler,
    collate_fn=custom_collate_fn,
)

optimizer = torch.optim.AdamW(
    rnaseg.parameters(),
    lr=0.001,
    weight_decay=0.01,
)

                                

In [None]:
best_val_loss = np.inf
path_save_model = "/home/tom/toremove"

# Interval passed as print_loss_every: a third of the number of batches.
# The max() guard keeps it at 1 or more when the loader has fewer than three
# batches, where int(len(training_loader) / 3) would evaluate to 0.
print_loss_every = max(1, len(training_loader) // 3)

# NOTE(review): best_val_loss is never reassigned inside this loop, so every
# epoch is called with np.inf — confirm whether train_one_epoch returns an
# updated value that should be captured here.
for epoch_index in tqdm(range(3)):
    train_one_epoch(
        device=device,
        epoch_index=epoch_index,
        rnaseg=rnaseg,
        training_loader=training_loader,
        optimizer=optimizer,
        print_loss_every=print_loss_every,
        tb_writer=None,
        validation_loader=validation_loader,
        path_save_model=path_save_model,
        cellbound_prob=0.8,
        best_val_loss=best_val_loss,
    )