In [None]:
from cucim.skimage.segmentation import checkerboard_level_set, morphological_chan_vese # works only with cupy 12.3.0
import cupy as cp
import numpy as np
import gc
from tqdm import tqdm
from multiprocessing import Manager, Process, Queue, Value, Lock
import blosc2
import os
from time import sleep
from utils import load_tifstack, chunk_generator

In [None]:
# Number of GPU worker processes; also the number of task queues created below.
num_gpus = 8

# One multiprocessing queue per GPU: the producer pushes chunks, one consumer pops them.
gpu_queues = [Queue() for _gpu in range(num_gpus)]

# Manager-backed dict shared between the GPU consumers (which insert masks)
# and the writer process (which removes them once stored).
manager_acwe = Manager()
acwe_dict = manager_acwe.dict()

# Lock guarding the counter increments, plus the two shared integer counters
# (chunks queued so far / chunks segmented so far).
lock = Lock()
total_chunks, processed_chunks = Value('i', 0), Value('i', 0)

In [None]:
# Producer function to assign chunks to GPU queues dynamically
def producer(scroll, folder_path, chunk_size, gpu_queues, total_chunks):
    """Load chunks from disk and distribute them round-robin over the GPU queues.

    Parameters
    ----------
    scroll : object exposing ``.shape``
        Only ``scroll.shape`` is read; it is forwarded to ``chunk_generator``.
    folder_path : str
        Directory holding the ``chunk_z_y_x_{z}_{y}_{x}.b2nd`` input files.
    chunk_size : sequence
        Forwarded to ``chunk_generator`` together with ``scroll.shape``.
    gpu_queues : list of multiprocessing.Queue
        One queue per consumer; each receives ``(chunk_id, chunk)`` tuples and
        finally a ``None`` sentinel.
    total_chunks : multiprocessing.Value
        Shared counter incremented once per queued chunk.

    Notes
    -----
    Relies on the module-level ``lock`` and ``processed_chunks`` objects, which
    are not part of the signature.
    """
    # Derive the fan-out from the queues actually passed in rather than from a
    # global / hard-coded GPU count, so the two can never disagree.
    n_queues = len(gpu_queues)
    try:
        for i, (z, y, x) in tqdm(enumerate(chunk_generator(scroll.shape, chunk_size))):
            filepath = os.path.join(folder_path, f"chunk_z_y_x_{z}_{y}_{x}.b2nd")
            # Decompress the whole chunk into host memory as float32.
            chunk = blosc2.open(filepath, mode="r")[:, :, :].astype(np.float32)
            chunk_id = (z, y, x)
            gpu_queues[i % n_queues].put((chunk_id, chunk))
            with lock:
                total_chunks.value += 1
            # Back-pressure: pause 20 s for every full "round" of chunks that
            # has been queued but not yet processed.
            backlog = total_chunks.value - processed_chunks.value
            sleep(20 * (backlog // n_queues))
    finally:
        # Always signal the end of the data with a special value (None), even
        # when loading failed; otherwise every consumer blocks on get() forever.
        for gpu_queue in gpu_queues:
            gpu_queue.put(None)

In [None]:
# Consumer function to process chunks on GPU
def process_chunk_on_gpu(gpu_id, task_queue, dict, processed_chunk, lock):
    """Segment queued chunks on one GPU and publish the resulting masks.

    Parameters
    ----------
    gpu_id : int
        CUDA device index selected for this process.
    task_queue : multiprocessing.Queue
        Yields ``(chunk_id, chunk)`` tuples; a ``None`` item ends the loop.
    dict : mapping
        Shared mapping that receives ``chunk_id -> uint8 mask`` entries.
        (The name shadows the builtin; kept for interface compatibility.)
    processed_chunk : multiprocessing.Value
        Shared counter incremented once per finished chunk.
    lock : multiprocessing.Lock
        Held while the mask is published and the counter is incremented.
    """
    cp.cuda.Device(gpu_id).use()
    while True:
        item = task_queue.get()
        if item is None:
            break
        chunk_id, chunk = item
        chunk = cp.asarray(chunk)

        # In-place scaling by the 16-bit maximum
        # (presumably maps uint16 intensities to [0, 1] - confirm with the data).
        chunk /= 65535.

        init_ls = checkerboard_level_set(chunk.shape, 5)
        mask = morphological_chan_vese(image=chunk, num_iter=20, init_level_set=init_ls)

        average_1 = cp.mean(chunk[mask == 1])
        average_2 = cp.mean(chunk[mask == 0])

        if average_2 > average_1:
            # Make label 1 the brighter region. cp.invert is a *bitwise* NOT:
            # on an integer 0/1 level set it produces -1/-2 (255/254 after the
            # uint8 cast) instead of swapping the labels. Comparing against 0
            # swaps them correctly for both boolean and integer masks.
            mask = mask == 0

        # Copy to the host and convert before taking the lock, so the other
        # workers and the producer are not stalled behind this transfer.
        host_mask = mask.get().astype(np.uint8)
        del init_ls, chunk, mask

        # Publish first, then count: the writer treats "processed == total and
        # dict empty" as its stop condition.
        with lock:
            dict[chunk_id] = host_mask
            processed_chunk.value += 1

        del host_mask
        

In [None]:
def writer_mask_process(output_folder, chunk_size, dict, total_chunks, processed_chunks):
    """Persist finished masks from the shared mapping as blosc2 ``.b2nd`` arrays.

    Parameters
    ----------
    output_folder : str
        Destination directory; created if it does not exist.
    chunk_size : sequence of 3 ints
        Used as the blosc2 ``chunks`` shape of every output array.
    dict : mapping
        Shared ``chunk_id -> mask`` mapping filled by the GPU workers; entries
        are deleted once written. (The name shadows the builtin; kept for
        interface compatibility.)
    total_chunks, processed_chunks : multiprocessing.Value
        Shared counters; the loop stops when they are equal and ``dict`` is empty.

    Notes
    -----
    NOTE(review): the stop condition can also hold while the producer is still
    loading (e.g. both counters are 0 at start-up, or the workers have caught
    up). A dedicated "producer finished" signal would be needed to rule that
    out; the signature is left unchanged here.
    """
    clevel = 9
    nthreads = 200
    cparams = {
            "codec": blosc2.Codec.ZSTD,
            "clevel": clevel,
            "filters": [blosc2.Filter.BITSHUFFLE, blosc2.Filter.BYTEDELTA],
            "filters_meta": [0, 0],
            "nthreads": nthreads,
    }

    # Without the directory every write would fail and be retried forever.
    os.makedirs(output_folder, exist_ok=True)

    while True:
        if processed_chunks.value == total_chunks.value and len(dict) == 0:
            break
        pending = list(dict.items())
        if not pending:
            # Nothing to write yet: back off instead of busy-polling the manager.
            sleep(0.5)
            continue
        for chunk_id, mask in pending:
            z, y, x = chunk_id
            filepath = os.path.join(output_folder, f"chunk_z_y_x_{z}_{y}_{x}.b2nd")
            try:
                mask_array = blosc2.empty(mask.shape, dtype=np.uint8, chunks=(chunk_size[0],chunk_size[1],chunk_size[2]), blocks=(100,100,100), urlpath=filepath, cparams=cparams)
                mask_array[:,:,:] = mask
            except Exception as exc:
                # Keep the best-effort retry (the entry stays in ``dict``), but
                # report the failure and pause so a persistent error is visible
                # and does not spin the CPU. KeyboardInterrupt/SystemExit are
                # no longer swallowed.
                print(f"Writer: failed to write chunk {chunk_id} to {filepath}: {exc!r}; will retry")
                sleep(1)
                continue
            del dict[chunk_id]

In [None]:
# Edge lengths (z, y, x) of one processing chunk; shared by the producer and the writer.
chunk_size = [800] * 3
# Source volume handed to the producer, which reads its .shape to enumerate chunks.
scroll = load_tifstack("../Scroll2.volpkg/volumes/20230210143520_grids")

In [None]:
# Launch the single producer process that reads chunks from ./scroll2-denoised
# and feeds them to the per-GPU queues.
producer_process = Process(
    target=producer,
    kwargs={
        "scroll": scroll,
        "folder_path": "./scroll2-denoised",
        "chunk_size": chunk_size,
        "gpu_queues": gpu_queues,
        "total_chunks": total_chunks,
    },
)
producer_process.start()

In [None]:
# Build one consumer process per GPU, each bound to its own task queue,
# then start them in device order.
processes = [
    Process(
        target=process_chunk_on_gpu,
        args=(device, gpu_queues[device], acwe_dict, processed_chunks, lock),
    )
    for device in range(num_gpus)
]
for worker in processes:
    worker.start()

In [None]:
# Launch the writer process that drains the shared mask dict into ./scroll2-denoised/mask
mwriter = Process(
    target=writer_mask_process,
    kwargs={
        "output_folder": "./scroll2-denoised/mask",
        "chunk_size": chunk_size,
        "dict": acwe_dict,
        "total_chunks": total_chunks,
        "processed_chunks": processed_chunks,
    },
)
mwriter.start()

In [None]:
# Wait for every stage to finish before releasing its Process object:
# Process.close() raises ValueError when the underlying process is still running.
producer_process.join()
for p in processes:
    p.join()
mwriter.join()

producer_process.close()
for p in processes:
    p.close()
mwriter.close()