In [None]:
# Notebook setup.
# autoreload mode 2 reloads all modules before executing code, so edits to
# project modules (e.g. src.analysis.components) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
from dotenv import load_dotenv, find_dotenv
# Load environment variables from the .env file located by find_dotenv
# (which searches parent directories). This runs before the settings import,
# presumably because config.settings reads environment variables — confirm.
load_dotenv(find_dotenv())
# NOTE(review): wildcard import; presumably this is where PROJECT_ROOT (used in
# later cells) comes from — confirm and consider importing it explicitly.
from config.settings import *

In [None]:
import cv2 as cv
import numpy as np
from ray import tune, air
from ray.tune.schedulers import ASHAScheduler
from ray.air import session
from src.analysis.components import bottomhat, close_rc, seq_line_open_rc, double_threshold, area_filter, contour_smoothing, close

def _load_grayscale(filepath):
    # Read an image as single-channel, raising instead of silently returning None.
    image = cv.imread(str(filepath), cv.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {filepath}")
    return image


def _iou(prediction, target):
    # Intersection-over-union of two masks; any non-zero pixel counts as foreground.
    pred_fg = np.asarray(prediction) > 0
    target_fg = np.asarray(target) > 0
    if pred_fg.shape != target_fg.shape:
        raise ValueError(
            f"Mask shapes differ: prediction {pred_fg.shape} vs target {target_fg.shape}"
        )
    union = np.count_nonzero(pred_fg | target_fg)
    if union == 0:
        # Both masks are empty: treat as perfect agreement instead of 0/0 -> nan.
        return 1.0
    return np.count_nonzero(pred_fg & target_fg) / union


def hpo_crista_segmentation(config):
    """Segment cristae with the parameters in *config* and report the IoU to Ray Tune.

    Reads ``mito.tif``, ``mask.tif`` and the reference ``cristamask.tif`` from
    ``PROJECT_ROOT / 'data' / 'interim'`` as grayscale images, runs the
    morphological pipeline below, and reports ``{'iou': float}`` once via
    ``session.report``.

    Args:
        config: Trial hyper-parameters (integers sampled by Ray Tune):
            ``mask_ksize`` - size of the elliptical kernel used to erode the mask;
            ``bhat_ksize`` - kernel size passed to ``bottomhat``;
            ``upper_narrow`` / ``upper_wide`` - thresholds passed to ``double_threshold``;
            ``threshold`` - value passed to ``area_filter`` (small component removal).

    Raises:
        FileNotFoundError: If any input image cannot be read.
        ValueError: If the segmentation and the reference mask differ in shape.
    """
    path = PROJECT_ROOT / 'data' / 'interim'
    mito = _load_grayscale(path / 'mito.tif')
    mask = _load_grayscale(path / 'mask.tif')
    target = _load_grayscale(path / 'cristamask.tif')

    # Erode the region mask with an elliptical kernel of the tuned size.
    ksize = config['mask_ksize']
    kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (ksize, ksize))
    m = cv.erode(mask, kernel)

    y = cv.equalizeHist(mito)
    y = cv.equalizeHist(bottomhat(y, config['bhat_ksize']))  # shading correction
    # Force pixels outside the eroded mask to 255 (remove outer membrane);
    # assumes mask.tif is a 0/255 binary image — confirm.
    y = cv.bitwise_or(y, cv.bitwise_not(m))

    y = double_threshold(y, config['upper_narrow'], config['upper_wide'])
    y = area_filter(y, config['threshold'])  # small component removal
    y = close(y, 3)
    # NOTE: disabled pipeline stages, kept for experimentation; re-enable them
    # together with the matching (commented-out) keys in the search space.
    #y = prune(y, config['prunelength'])

    # Fill every external contour so each detected region becomes solid.
    contours, _ = cv.findContours(y, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
    y = np.zeros_like(y)
    y = cv.drawContours(y, contours, -1, 255, cv.FILLED)
    #y = close(y, config['close_ksize'])
    #y = holefill(y)

    # Binarized IoU: robust to masks stored as 0/1 as well as 0/255.
    iou = _iou(y, target)

    session.report({'iou': iou})

# Ray Tune search space for hpo_crista_segmentation.
# tune.randint(lower, upper) samples integers with `upper` exclusive.
config = dict(
    mask_ksize=tune.randint(10, 51),    # erosion kernel size for the mask
    bhat_ksize=tune.randint(10, 61),    # bottom-hat kernel size
    upper_narrow=tune.randint(0, 30),   # first double_threshold parameter
    upper_wide=tune.randint(50, 100),   # second double_threshold parameter
    # prunelength=tune.randint(10, 30),  (prune stage currently disabled)
    threshold=tune.randint(0, 200),     # area_filter parameter
    # close_ksize=tune.randint(1, 11),   (second close stage currently disabled)
)

# Trial logs and results are written under PROJECT_ROOT/log/crista.
resultpath = str(PROJECT_ROOT / 'log')

# NOTE(review): ASHA can only stop a trial between successive reports, and
# hpo_crista_segmentation reports a single value, so this scheduler never
# stops anything early — consider the default FIFO scheduler.
asha = ASHAScheduler(metric="iou", mode="max")

tuner = tune.Tuner(
    hpo_crista_segmentation,
    param_space=config,
    run_config=air.RunConfig(name="crista", storage_path=resultpath),
    tune_config=tune.TuneConfig(scheduler=asha, num_samples=100),
)

results = tuner.fit()

# Pick the trial with the highest reported IoU and show its settings.
best_trial = results.get_best_result(metric='iou', mode='max')
print("Best trial config:", best_trial.config)
print("Best trial IoU:", best_trial.metrics['iou'])