# Running Segmentation Notebook

## Segmentation training is finished and the trained model is available as ./ckpts/CMap_model_epoch_50.pth. This notebook uses it to generate the cell-membrane segmentation and the segmented cell objects.

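### Set your running arguments in ./experiment/TRAIN_TEST.yaml, or override them in the parameter configuration cell below. Note: the prediction cell reads `args.running_data_dir`, `args.running_embryos` and `args.run_max_times`, so check that your configuration provides these (the excerpt below lists `test_data_dir`, `test_embryos` and `test_max_times`).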

```
# =================================
# Parameters for prediction (Paras for training will be preserved)
# =================================
get_memb_bin: True
show_snap: False
get_cell: True


test_data_dir: ./dataset/run
test_embryos: [200109plc1p1,200113plc1p2]
test_max_times: [180,187]

test_transforms: # for testing
  Compose([
    Resize((256,352,224)),
    NumpyType((np.float32, np.float32)),
    ])
    
```

In [1]:
import os
import time
import logging
import random
import argparse
import setproctitle
import numpy as np

import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

import models 
from utils import ParserUse
from utils.show_train import Visualizer


In [2]:
# Let cuDNN benchmark convolution algorithms for the input size (faster; background:
# https://zhuanlan.zhihu.com/p/73711222). This trades away bit-for-bit determinism.
cudnn.benchmark = True

## parse arguments
# Base configuration comes from the 'TRAIN_TEST' experiment config
# (presumably ./experiment/TRAIN_TEST.yaml, per the notes above).
args = ParserUse('TRAIN_TEST', log='run')

# Prediction-time overrides of the parsed configuration -- edit these for your machine.
# Applied in this exact order, one attribute per key.
_prediction_overrides = {
    "mode": 2,
    "gpu": '0',
    "ckpts": './ckpts',
    "suffix": '*.pkl',
    "save_format": 'nii',
    "resume": r'./ckpts/CMap_model_epoch_50.pth',
}
for _key, _value in _prediction_overrides.items():
    setattr(args, _key, _value)

# When True, the dataset would also return targets (return_target); this notebook only predicts.
is_scoring = False

# Optional snapshot visualiser, only built when show_snap is enabled in the config.
visualizer = Visualizer(1) if args.show_snap else None

## Run the prediction (a GPU is required)

In [3]:
# Expose only the GPU(s) listed in args.gpu to this process. CUDA reads this variable
# when it initialises, so it must be set before anything in the kernel touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
assert torch.cuda.is_available(), "GPU is needed for prediction"

# =============================================================
#  Seed every RNG in use with args.seed (from TRAIN_TEST.yaml):
#  torch (CPU), torch CUDA, Python `random`, then NumPy.
#  cudnn.benchmark = True (set in the config cell) can still make
#  GPU results vary slightly between runs.
# =============================================================
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(args.seed)

In [None]:
import shutil


from data import datasets
from utils.prediction_utils import validate, membrane2cell, combine_cells

# Resolve which embryos to process. This must run whenever either stage is enabled:
# the membrane prediction below needs it, and membrane2cell(args) / combine_cells(args)
# in the next cell receive the same args (presumably reading this list -- confirm).
if args.get_memb_bin or args.get_cell:
    root_path = args.running_data_dir
    if args.running_embryos is None:
        # Default: every sub-folder of the data directory. Sorted because os.listdir
        # order is arbitrary, which would otherwise make the run order irreproducible.
        args.running_embryos = sorted(
            name for name in os.listdir(root_path)
            if os.path.isdir(os.path.join(root_path, name))
        )

# Stage 1: predict the binary cell-membrane segmentation with the trained network.
if args.get_memb_bin:
    # Input volumes are read as args.suffix files ('*.pkl' in the config cell); the
    # nii.gz -> pickle conversion is presumably done beforehand (it is not run here).
    # =============================================================
    #  construct network model
    # =============================================================
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = torch.nn.DataParallel(model).cuda()

    print("="*20 + "Loading parameters {}".format(args.resume) + "="*20)
    assert os.path.isfile(args.resume), "{} ".format(args.resume) + "doesn't exist"
    # Load onto the CPU first: load_state_dict() copies the weights into the already
    # GPU-resident model, and this avoids failures when the checkpoint was saved from a
    # GPU index that CUDA_VISIBLE_DEVICES now hides.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    check_point = torch.load(args.resume, map_location="cpu")
    args.start_iter = check_point["iter"]
    model.load_state_dict(check_point["state_dict"])
    # Inference mode for dropout / batch-norm layers (harmless if validate() also calls it).
    model.eval()

    msg = ("Loaded checkpoint '{}' (iter {})".format(args.resume, check_point["iter"]))
    msg = msg + "\n" + str(args)
    # Full config dump; kept at debug level so it does not flood the notebook output.
    logging.debug(msg)

    Dataset = getattr(datasets, args.dataset)
    test_set = Dataset(root=root_path, membrane_names=args.running_embryos, for_train=False, transforms=args.test_transforms,
                       return_target=is_scoring, suffix=args.suffix, max_times=args.run_max_times)
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,  # keep volume order aligned with test_set.names
        num_workers=10,
        pin_memory=True
    )

    #=============================================================
    #  begin prediction
    #=============================================================
    output_saving_path = r'./dataset/run'

    # The EDT cell-membrane volumes are written under {savepath}/{embryo_name}/segMemb.
    # TODO(review): save_format is hard-coded to ".nii.gz" although the config cell sets
    # args.save_format = 'nii'; confirm which one validate() is meant to honour.
    with torch.no_grad():
        validate(
            valid_loader=test_loader,  # dataset loader
            model=model,  # model
            savepath=output_saving_path,  # output folder
            names=test_set.names,  # stack name lists
            scoring=is_scoring,  # keep consistent with return_target above
            save_format=".nii.gz",  # save volume format
            snapsot=visualizer,  # whether keep snap (keyword spelling is validate()'s)
            postprocess=False,
            size=test_set.size
        )


2023-07-18 15:27:53,621 Note: NumExpr detected 48 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2023-07-18 15:27:53,622 NumExpr defaulting to 8 threads.




Getting binary membrane::  75%|████████████████████████        | 276/367 [31:07<15:03,  9.93s/it]

## The Euclidean-distance-transformed cell-membrane segmentation is now saved in ./dataset/run/{embryo_name}/SegMemb and can be opened with ITK-SNAP.

## Next, let's work on the cell object segmentation

In [None]:
# Post-process the binary membrane segmentation: group it into closed 3D cells,
# then merge the labels that belong to dividing cells.
# Relies on membrane2cell / combine_cells imported at the top of the prediction cell,
# so run that cell first (its imports execute even when args.get_memb_bin is False).
# NOTE(review): both helpers receive the whole args object; which fields they read
# (paths, embryo list, time points) is defined in utils.prediction_utils, not here.
if args.get_cell:
    # Read the binary segmentation from the segMemb folder and build cell labels from it.
    membrane2cell(args)
    # Combine labels by detecting dividing cells via the cell-membrane segmentation
    # ("TP" in the message presumably means time point -- confirm).
    print("Begin combine division based on TP...\n")
    combine_cells(args)