This notebook has to run in an environment with SAM, on Windows it will require WSL.

In [2]:
from imaris_ims_file_reader.ims import ims
import dask.array as da
import os
import gc
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tifffile import imread
from skimage.color import gray2rgb
import matplotlib.pyplot as plt
from skimage.segmentation import clear_border
from tqdm import tqdm
import pickle as pkl

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from sem_quant.load_config import load_config
from sem_quant.utils import smart_path

import torch
torch.cuda.is_available()

True

In [3]:
device = torch.device("cuda")

torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

In [7]:
config_file_name = "../A01_config.json"

In [11]:
im_path

'/mnt/i/CBI/Jonathan/CLEM/Birder/88EM87 A/00001_ashlar.ome.tif'

In [10]:
# read in config file

config = load_config(config_file_name)

im_path = smart_path(config.paths.im_path)
analysis_dir = smart_path(config.paths.analysis_dir)

output_prefix = config.paths.output_prefix
mitos_data_suffix = config.paths.mitos_data_suffix
axons_data_suffix = config.paths.axons_data_suffix

axons_res = config.data_properties.axons_res
mitos_res = config.data_properties.mitos_res
row_offset = config.data_properties.row_offset
col_offset = config.data_properties.col_offset

sam2_checkpoint = config.sam_model.sam2_checkpoint
model_cfg = config.sam_model.model_cfg


In [None]:
output_sub_dir = mitos_data_suffix
df_path = os.path.join(analysis_dir,f'{output_prefix}{axons_data_suffix}.pkl')

os.makedirs(os.path.join(analysis_dir,output_sub_dir),exist_ok=True)

In [12]:
im_path

'/mnt/i/CBI/Jonathan/CLEM/Birder/88EM87 A/00001_ashlar.ome.tif'

In [13]:
store = imread(im_path, aszarr=True)
im = da.from_zarr(store,mitos_res)
im

Unnamed: 0,Array,Chunk
Bytes,781.04 MiB,1.00 MiB
Shape,"(25535, 32073)","(1024, 1024)"
Dask graph,800 chunks in 2 graph layers,800 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 781.04 MiB 1.00 MiB Shape (25535, 32073) (1024, 1024) Dask graph 800 chunks in 2 graph layers Data type uint8 numpy.ndarray",32073  25535,

Unnamed: 0,Array,Chunk
Bytes,781.04 MiB,1.00 MiB
Shape,"(25535, 32073)","(1024, 1024)"
Dask graph,800 chunks in 2 graph layers,800 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [17]:
# import df
df = pd.read_pickle(df_path)

In [18]:
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    points_per_side=64,
    points_per_batch=64,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=2,
    box_nms_thresh=0.7,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=1000,
    use_m2m=False,
)

In [19]:
res_adjust = axons_res - mitos_res
mito_props = ['area', 'predicted_iou', 'stability_score']

st = 110
for ind, row in tqdm(df.loc[st:, :].iterrows(), total=len(df.loc[st:, :])):

    # get the cell interior image of high res
    row_start = (int(row['inside_bbox-0']) + row_offset)*2**res_adjust
    row_end = (int(row['inside_bbox-2']) + row_offset)*2**res_adjust
    col_start = (int(row['inside_bbox-1']) + col_offset)*2**res_adjust
    col_end = (int(row['inside_bbox-3']) + col_offset)*2**res_adjust

    im_test = im[row_start:row_end, col_start:col_end]
    
    # run the segmentation
    im_rgb = gray2rgb(im_test).compute()
    masks_org = mask_generator.generate(im_rgb)

    # saving the masks
    file_name = f'{output_prefix}{str(row.label).zfill(6)}_mito.pkl'
    with open(os.path.join(analysis_dir, output_sub_dir, file_name), 'wb') as f:
        pkl.dump(masks_org, f)

    torch.cuda.empty_cache()
    del im_test
    del im_rgb 
    del masks_org
    gc.collect()

  0%|          | 0/1198 [00:00<?, ?it/s]


Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  masks = self._transforms.postprocess_masks(
 14%|█▎        | 164/1198 [39:16<4:07:37, 14.37s/it]


KeyboardInterrupt: 