In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import numpy as np
import cc3d

import cv2

from PIL import Image

import pandas as pd

from tqdm import tqdm_notebook as tqdm

from skimage.measure import regionprops, label, find_contours
from scipy.spatial.distance import cdist

import scipy.ndimage as ndimage
from skimage import morphology

import tifffile

In [None]:
# Map pixel values from [0, 255] to [-1, 1].
def normalize_x(image):
    """Scale an 8-bit-range image to [-1, 1] (0 -> -1, 127.5 -> 0, 255 -> 1).

    Args:
        image: scalar or numpy array with values nominally in [0, 255].

    Returns:
        ``image / 127.5 - 1``; inverse of ``denormalize_x``.
    """
    half_range = 127.5
    return image / half_range - 1


def denormalize_x(image):
    """Inverse of ``normalize_x``: map values from [-1, 1] back to [0, 255]."""
    return 127.5 * (image + 1)


# Map label values from [0, 255] to [0, 1].
def normalize_y(image):
    """Scale a [0, 255] label image to [0, 1] by dividing by 255."""
    max_value = 255
    return image / max_value


# Map label values from [0, 1] back to [0, 255].
def denormalize_y(image):
    """Inverse of ``normalize_y``: multiply by 255."""
    return 255 * image

In [None]:
# Load the input (EM) images.
def load_X_gray(folder_path):
    """Load every ``.png`` in ``folder_path`` as a grayscale stack normalized to [-1, 1].

    Files are read in sorted filename order. The first image fixes the output
    height and width; every other image must have the same size.

    Args:
        folder_path (str): directory containing the input PNG images.

    Returns:
        tuple: ``(images, image_files)`` where ``images`` is a float32 array of
        shape (N, H, W, 1) passed through ``normalize_x``, and ``image_files``
        is the sorted list of PNG filenames.

    Raises:
        FileNotFoundError: if the folder contains no ``.png`` files, or a file
            cannot be decoded by OpenCV.
        ValueError: if an image's size differs from the first image's.
    """
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if os.path.splitext(name)[1] == '.png'
    )
    if not image_files:
        # Fail with a clear message instead of an IndexError on image_files[0].
        raise FileNotFoundError(f"no .png files found in {folder_path}")

    images = None
    for i, image_file in enumerate(image_files):
        path = os.path.join(folder_path, image_file)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports unreadable files by returning None, not by raising.
            raise FileNotFoundError(f"could not read image: {path}")
        if images is None:
            # The first image fixes the stack geometry.
            images = np.zeros((len(image_files), image.shape[0], image.shape[1], 1), np.float32)
        elif image.shape != images.shape[1:3]:
            raise ValueError(
                f"{path} has size {image.shape}, expected {images.shape[1:3]}"
            )
        images[i] = normalize_x(image[:, :, np.newaxis])
    return images, image_files


# Load the label (annotation) images.
def load_Y_gray(folder_path, thresh=None, normalize=True):
    """Load every ``.png`` in ``folder_path`` as a grayscale label stack.

    Files are read in sorted filename order. Every image is resized to the size
    of the first image. Prints the resulting array shape.

    Args:
        folder_path (str): directory containing the label PNG images.
        thresh (number | None): if not None, each image is binarized with
            ``cv2.threshold(image, thresh, 255, cv2.THRESH_BINARY)`` before
            resizing. ``0`` is a valid threshold and is honoured.
        normalize (bool): if True, values are divided by 255 via ``normalize_y``.

    Returns:
        tuple: ``(images, image_files)`` where ``images`` is a float32 array of
        shape (N, H, W, 1) and ``image_files`` is the sorted list of filenames.

    Raises:
        FileNotFoundError: if the folder contains no ``.png`` files, or a file
            cannot be decoded by OpenCV.
    """
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if os.path.splitext(name)[1] == '.png'
    )
    if not image_files:
        # Fail with a clear message instead of an IndexError on image_files[0].
        raise FileNotFoundError(f"no .png files found in {folder_path}")

    images = None
    height = width = None
    for i, image_file in enumerate(image_files):
        path = os.path.join(folder_path, image_file)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports unreadable files by returning None, not by raising.
            raise FileNotFoundError(f"could not read image: {path}")
        if images is None:
            # The first image (before thresholding) fixes the stack geometry.
            height, width = image.shape[:2]
            images = np.zeros((len(image_files), height, width, 1), np.float32)
        # `is not None` so that thresh=0 is applied rather than skipped as falsy.
        if thresh is not None:
            _, image = cv2.threshold(image, thresh, 255, cv2.THRESH_BINARY)
        # NOTE(review): cv2.resize defaults to bilinear interpolation, so when sizes
        # differ a thresholded mask can regain non-binary values at its edges.
        image = cv2.resize(image, (width, height))
        image = image[:, :, np.newaxis]
        images[i] = normalize_y(image) if normalize else image

    print(images.shape)

    return images, image_files

In [None]:
def dilate_imgs(imgs, radius=3):
    """Binary-dilate channel 0 of a stack with a ball-shaped footprint.

    Args:
        imgs (numpy.ndarray): (Z, Y, X, C) stack; only channel 0 is used.
        radius (int): radius passed to ``morphology.ball``. Defaults to 3, the
            value this function always used.

    Returns:
        numpy.ndarray: (Z, Y, X) uint8 array of 0/1. The channel axis is dropped.
    """
    footprint = morphology.ball(radius)
    dilated_imgs = morphology.binary_dilation(imgs[:, :, :, 0], footprint).astype(np.uint8)

    return dilated_imgs


def quantify_mercs(labeled_mito_imgs, mercs_imgs):
    """Count voxels that are foreground (> 0) in both volumes.

    Args:
        labeled_mito_imgs (numpy.ndarray): labeled mitochondria volume; any
            value > 0 counts as mitochondrion.
        mercs_imgs (numpy.ndarray): MERCs volume; any value > 0 counts. Must
            be broadcast-compatible with ``labeled_mito_imgs``.

    Returns:
        int: number of overlapping voxels (the MERCs volume).
    """
    in_mito = labeled_mito_imgs > 0
    in_mercs = mercs_imgs > 0
    overlap = np.logical_and(in_mito, in_mercs)
    return np.count_nonzero(overlap)


def erode_mito(imgs, erosion_iteration):
    """Binary-erode a 3-D volume with the 6-connected (face-neighbour) cross.

    Args:
        imgs (numpy.ndarray): 3-D volume; nonzero voxels are foreground.
        erosion_iteration (int): number of erosion passes. Per SciPy's
            ``binary_erosion``, a value < 1 repeats the erosion until the
            result stops changing.

    Returns:
        numpy.ndarray: boolean array with the same shape as ``imgs``.
    """
    face_cross = ndimage.generate_binary_structure(3, 1)
    return ndimage.binary_erosion(imgs, structure=face_cross, iterations=erosion_iteration)


def dilate_mito(imgs, dilation_iteration):
    """Binary-dilate a 3-D volume with the 6-connected (face-neighbour) cross.

    Args:
        imgs (numpy.ndarray): 3-D volume; nonzero voxels are foreground.
        dilation_iteration (int): number of dilation passes (SciPy semantics).

    Returns:
        numpy.ndarray: boolean array with the same shape as ``imgs``.
    """
    face_cross = ndimage.generate_binary_structure(3, 1)
    return ndimage.binary_dilation(imgs, structure=face_cross, iterations=dilation_iteration)


def contract_contours(imgs, erosion_iteration):
    """Return the shell removed by the ``erosion_iteration``-th erosion step.

    Computes ``erode(imgs, k - 1) - erode(imgs, k)`` with ``k = erosion_iteration``,
    using the 6-connected erosion of ``erode_mito``. The result holds the voxels
    that are foreground after k - 1 erosions but not after k.

    Args:
        imgs (numpy.ndarray): 3-D volume; nonzero voxels are foreground.
        erosion_iteration (int): depth k >= 1 of the shell.

    Returns:
        numpy.ndarray: uint8 0/1 array with the same shape as ``imgs``.

    Raises:
        ValueError: if ``erosion_iteration < 1``.
    """
    if erosion_iteration < 1:
        raise ValueError(f"erosion_iteration must be >= 1, got {erosion_iteration}")
    if erosion_iteration == 1:
        # "Eroded zero times" is the foreground itself. Passing iterations=0 to
        # SciPy would instead erode until the volume is empty, and the uint8
        # subtraction below would then wrap around to 255.
        outer = (np.asarray(imgs) != 0).astype("uint8")
    else:
        outer = erode_mito(imgs, erosion_iteration - 1).astype("uint8")
    inner = erode_mito(imgs, erosion_iteration).astype("uint8")

    # inner is a subset of outer, so the uint8 difference stays in {0, 1}.
    contours_imgs = outer - inner
    return contours_imgs


def convert_img2coords(imgs):
    """List the first-three-axis indices of every nonzero element of ``imgs``.

    Indices of any axes beyond the third are dropped. With a (Z, Y, X, C) input,
    a (z, y, x) position therefore appears once per nonzero channel.

    Args:
        imgs (numpy.ndarray): array with at least 3 dimensions.

    Returns:
        numpy.ndarray: float64 array of shape (N, 3), rows in row-major order.
    """
    nz = imgs.nonzero()
    return np.column_stack((nz[0], nz[1], nz[2])).astype(np.float64)


def pick_counters(imgs):
    """Draw the per-slice contours of a binary stack, with the image frame blanked.

    For each slice, ``cv2.findContours`` (RETR_TREE, CHAIN_APPROX_SIMPLE) is run
    and the contours are drawn 1 px wide. All four image edges (first/last row
    and first/last column) are then set to 0. The previous ``cv2.line`` calls
    mixed up (x, y) order: they cleared rows 0, 1 and the last row, and left
    the first and last columns untouched.

    Args:
        imgs (numpy.ndarray): (Z, Y, X) uint8 stack, as required by
            ``cv2.findContours``.

    Returns:
        numpy.ndarray: (Z, Y, X) float32 array; 1.0 on contour pixels, else 0.0.
    """
    n_slices, height, width = imgs.shape[0], imgs.shape[1], imgs.shape[2]
    contour_stack = np.zeros((n_slices, height, width), np.float32)

    for z in range(n_slices):
        # [-2] selects the contour list under both OpenCV 3 (3 return values)
        # and OpenCV 4 (2 return values).
        contours = cv2.findContours(imgs[z], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        canvas = np.zeros((height, width), np.float32)
        cv2.drawContours(canvas, contours, -1, 1, 1)

        # Blank the image frame; contour segments there presumably come from
        # the crop boundary rather than the membrane.
        canvas[0, :] = 0
        canvas[-1, :] = 0
        canvas[:, 0] = 0
        canvas[:, -1] = 0

        contour_stack[z] = canvas

    return contour_stack

In [None]:
def make_intensity_map(OMM_POINT, CJ_POINT, EM_IMAGES):
    """Sample EM intensities within 1 voxel of the line from CJ_POINT to OMM_POINT.

    Every voxel of ``EM_IMAGES`` is considered, including zero-intensity ones.
    The previous version only visited nonzero voxels. It also dropped points
    where ``arccos`` got a ratio rounded above 1 (e.g. the CJ point itself).

    Args:
        OMM_POINT (array-like): (z, y, x) of the OMM end of the line, in the
            index space of ``EM_IMAGES``.
        CJ_POINT (array-like): (z, y, x) of the CJ end of the line.
        EM_IMAGES (numpy.ndarray): EM sub-volume indexed as [z, y, x].

    Returns:
        list[dict]: one dict per voxel whose perpendicular distance to the
        CJ->OMM line is <= 1, in row-major voxel order, with keys
        ``"location"`` (signed distance along the line from CJ_POINT towards
        OMM_POINT) and ``"Intensity"`` (the voxel value). If the two points
        coincide, voxels within distance 1 of CJ_POINT are returned with
        location 0.
    """
    em = np.asarray(EM_IMAGES)
    cj = np.asarray(CJ_POINT, dtype=float)
    axis = np.asarray(OMM_POINT, dtype=float) - cj

    # (N, 3) integer coordinates of every voxel, in row-major order.
    coords = np.indices(em.shape[:3]).reshape(3, -1).T
    diffs = coords - cj
    sq_dist = np.sum(diffs * diffs, axis=1)
    axis_sq = float(np.dot(axis, axis))

    if axis_sq == 0:
        locations = np.zeros(len(coords))
        perp_sq = sq_dist
    else:
        dots = diffs @ axis
        locations = dots / np.sqrt(axis_sq)
        # Squared perpendicular distance via projection; avoids arccos/sin
        # round-off and the division by a zero-length CJ->voxel vector.
        perp_sq = sq_dist - dots * dots / axis_sq

    keep = perp_sq <= 1
    return [
        {"location": loc, "Intensity": em[z, y, x]}
        for (z, y, x), loc in zip(coords[keep], locations[keep])
    ]
    
    


def controller(GENE, CROP_NO, base_dir="/home/suga/drobo/DeepLearningData/research_010_NIH3T3"):
    """Run ``culc_each_mito`` on every mitochondrion folder of one crop and save the CJ volumes.

    NOTE(review): a later cell redefines ``controller`` (the small-component
    filter). Once that cell has run, calls to ``controller`` use the later version.

    Args:
        GENE (str): condition folder name, e.g. "shCtrl_003".
        CROP_NO (str): crop folder name, e.g. "cropped_001".
        base_dir (str): dataset root. Defaults to the path previously hard-coded here.

    Side effects:
        Writes ``lamellar_cj_107_002.tiff`` and ``tubular_cj_107_002.tiff`` into
        each processed folder.
    """
    input_dir = f"{base_dir}/{GENE}/annotations/{CROP_NO}/cropped_mito_datasets"

    for filename in tqdm(sorted(os.listdir(input_dir))):
        # Only 4-character entries are processed (presumably zero-padded
        # mitochondrion IDs -- confirm against the dataset layout).
        if len(filename) != 4:
            continue

        input_path = f"{input_dir}/{filename}"
        lamellar_cj, tubular_cj = culc_each_mito(input_path)

        tifffile.imwrite(f"{input_dir}/{filename}/lamellar_cj_107_002.tiff", lamellar_cj)
        tifffile.imwrite(f"{input_dir}/{filename}/tubular_cj_107_002.tiff", tubular_cj)
    
    
def culc_each_mito(input_path):
    """Select crista-junction (CJ) candidates for one mitochondrion folder.

    Reads ``ori_imgs.tiff``, ``mito_imgs.tiff``, ``lamellar_imgs.tiff`` and
    ``tubular_imgs.tiff`` from ``input_path``. For each erosion depth
    ``erode_num`` in 3, 4, 5, the steps are:

    * Take the shell between erosion depths ``erode_num - 1`` and ``erode_num``
      of the mitochondrion mask (``contract_contours``).
    * Intersect the shell with the lamellar+tubular cristae mask and label the
      pieces with 26-connectivity.
    * For each piece, find the closest point of the OMM contour image built by
      ``pick_counters``.
    * Keep the piece when ``denormalize_x`` of the maximum EM intensity sampled
      between that OMM point and the piece (``make_intensity_map``) is below 127,
      and the OMM point was not already claimed at a shallower depth.

    Args:
        input_path (str): folder containing the four TIFF stacks above.

    Returns:
        tuple: ``(lamellar_cj, tubular_cj)`` -- ``lamellar_imgs`` and
        ``tubular_imgs`` multiplied by the union of the kept pieces over all depths.
    """
    
    EM_imgs = tifffile.imread(f"{input_path}/ori_imgs.tiff")
    mito_imgs = tifffile.imread(f"{input_path}/mito_imgs.tiff")
    lamellar_imgs = tifffile.imread(f"{input_path}/lamellar_imgs.tiff")
    tubular_imgs = tifffile.imread(f"{input_path}/tubular_imgs.tiff")
    
    # 0/1 mask of the union of lamellar and tubular cristae.
    target_cristae_imgs = np.where(
        lamellar_imgs + tubular_imgs > 0,
        1,
        0
    )
    
    # Per-slice OMM contour image and the coordinates of its contour voxels.
    omm_imgs = pick_counters(mito_imgs.astype("uint8"))
    omm_points = convert_img2coords(omm_imgs)
    
    erode_cj_imgs = []
    # OMM points already claimed by a kept piece at a shallower depth.
    OMM_points_imgs = np.zeros_like(omm_imgs)
    
    for erode_num in range(3, 6):
        
        # OMM points claimed at the current depth.
        OMM_points_imgs_n = np.zeros_like(omm_imgs)
    
        ### (1) Take the shell erode_num px inside the OMM
        counter_eroded_imgs = contract_contours(mito_imgs, erosion_iteration=erode_num).astype("uint8")
    
        ### (2) Overlap between the cristae and the eroded-mito shell
        duplicated_cj_imgs = counter_eroded_imgs[..., np.newaxis] * target_cristae_imgs[..., np.newaxis]
    
        ### (3) Label the overlap (dup_cj)
        labeled_cj_imgs = cc3d.connected_components(duplicated_cj_imgs[:,:,:,0].astype(int), connectivity = 26)
    
        cj_imgs = np.zeros_like(duplicated_cj_imgs)
    
        for label_no in range(1, np.max(labeled_cj_imgs)+1):
            labled_cj = (labeled_cj_imgs == label_no).astype("uint8")[..., np.newaxis]
        
            ### (4) Find the OMM point with the shortest distance to this labeled piece
            labeld_cj_points = convert_img2coords(labled_cj)
            distances = cdist(omm_points, labeld_cj_points)
            ## distances has shape (len(omm_points), len(labeld_cj_points))
            # NOTE(review): np.min raises on an empty array if omm_points is empty.
        
            # On ties, the first minimum in row-major order is used.
            OMM_points_coords = np.nonzero(distances==np.min(distances))
            OMM_point = omm_points[OMM_points_coords[0][0]].astype(int)
            cj_point = labeld_cj_points[OMM_points_coords[1][0]].astype(int)
        
            ### (5) Build the intensity map between the labeled CJ and the OMM point
            # Crop EM to the bounding box of the two points and shift both points
            # into the box's local coordinates.
            target_EM_image = EM_imgs[
                min(OMM_point[0], cj_point[0]) : max(OMM_point[0], cj_point[0]) + 1,
                min(OMM_point[1], cj_point[1]) : max(OMM_point[1], cj_point[1]) + 1,
                min(OMM_point[2], cj_point[2]) : max(OMM_point[2], cj_point[2]) + 1
            ]
            target_OMM_point = (
                OMM_point[0] - min(OMM_point[0], cj_point[0]),
                OMM_point[1] - min(OMM_point[1], cj_point[1]),
                OMM_point[2] - min(OMM_point[2], cj_point[2])
            )
            target_cj_point = (
                cj_point[0] - min(OMM_point[0], cj_point[0]),
                cj_point[1] - min(OMM_point[1], cj_point[1]),
                cj_point[2] - min(OMM_point[2], cj_point[2])
            )
            intensity_map_lst = make_intensity_map(np.array(target_OMM_point), np.array(target_cj_point), target_EM_image)
        
            ### (6) Remove if it falls below the threshold
            ### (original note: "treat anything above 0.4 as rejected")
            # NOTE(review): the code KEEPS the piece when denormalize_x(max intensity) < 127,
            # i.e. it rejects bright paths; the 0.4 in the note above does not match 127 -- confirm.
            # NOTE(review): max() raises ValueError if intensity_map_lst is empty.
            is_isolated = denormalize_x(max([ node['Intensity'] for node in intensity_map_lst ])) < 127
            if is_isolated:
                
                ### (7) Map CJ points and OMM points one-to-one; priority order is 3px -> 4px -> 5px
                # NOTE(review): only points claimed at shallower depths (OMM_points_imgs) are checked;
                # two pieces at the same depth can still share one OMM point.
                if OMM_points_imgs[OMM_point[0], OMM_point[1], OMM_point[2]] == 1:
                    pass
                else:
                    cj_imgs += labled_cj
                    OMM_points_imgs_n[OMM_point[0], OMM_point[1], OMM_point[2]] = 1 
                
        erode_cj_imgs.append(cj_imgs[:,:,:,0])
        # Add this depth's claimed OMM points to the claimed set.
        OMM_points_imgs = ((OMM_points_imgs + OMM_points_imgs_n) > 0).astype(int)
        
    # NOTE(review): indexes exactly three depths, matching range(3, 6) above.
    merged_erode_cj_imgs = ((erode_cj_imgs[0] + erode_cj_imgs[1] + erode_cj_imgs[2]) > 0).astype(int)
    lamellar_cj = lamellar_imgs * merged_erode_cj_imgs
    tubular_cj = tubular_imgs * merged_erode_cj_imgs
    
    return lamellar_cj, tubular_cj


In [None]:
# Detect CJ candidates for every crop of both conditions (control first, then OPA1 knockdown).
for gene in ("shCtrl_003", "shOPA1_003"):
    for crop_no in ("cropped_001", "cropped_002", "cropped_003", "cropped_004", "cropped_005"):
        controller(gene, crop_no)


In [None]:
def controller(GENE, CROP_NO, base_dir="/home/suga/drobo/DeepLearningData/research_010_NIH3T3", min_voxels=3):
    """Drop small CJ components and save the filtered lamellar/tubular CJ volumes.

    NOTE(review): this redefines the earlier ``controller`` cell's function; the
    earlier version is no longer reachable after this cell runs.

    For every 4-character entry of the crop's ``cropped_mito_datasets`` folder:

    * Read ``lamellar_cj_107_002.tiff`` and ``tubular_cj_107_002.tiff``.
    * Label their union with 26-connectivity.
    * Keep only components with at least ``min_voxels`` voxels.
    * Write ``lamellar_cj_114_001.tiff`` and ``tubular_cj_114_001.tiff``.

    Args:
        GENE (str): condition folder name, e.g. "shCtrl_003".
        CROP_NO (str): crop folder name, e.g. "cropped_001".
        base_dir (str): dataset root. Defaults to the path previously hard-coded here.
        min_voxels (int): minimum component size to keep. The default 3 matches
            the previous ``np.sum(target_cj) > 2`` test.
    """
    input_dir = f"{base_dir}/{GENE}/annotations/{CROP_NO}/cropped_mito_datasets"

    for filename in tqdm(sorted(os.listdir(input_dir))):
        # Only 4-character entries are processed (presumably zero-padded
        # mitochondrion IDs -- confirm against the dataset layout).
        if len(filename) != 4:
            continue

        input_path = f"{input_dir}/{filename}"
        lamellar_cj = tifffile.imread(f"{input_path}/lamellar_cj_107_002.tiff")
        tubular_cj = tifffile.imread(f"{input_path}/tubular_cj_107_002.tiff")

        cj = (lamellar_cj + tubular_cj > 0).astype(int)
        labeled_cj = cc3d.connected_components(cj, connectivity=26)

        # Voxel count per label in a single pass (index 0 is background). This
        # replaces one full-volume comparison per label.
        sizes = np.bincount(labeled_cj.ravel().astype(np.intp), minlength=1)
        keep = sizes >= min_voxels
        keep[0] = False
        # Build the mask directly in the TIFF's dtype. The old in-place
        # `+= int` accumulation fails for unsigned TIFF dtypes.
        removed_cj = keep[labeled_cj].astype(lamellar_cj.dtype)

        removed_lamellar_cj = lamellar_cj * removed_cj
        removed_tubular_cj = tubular_cj * removed_cj

        tifffile.imwrite(f"{input_path}/lamellar_cj_114_001.tiff", removed_lamellar_cj)
        tifffile.imwrite(f"{input_path}/tubular_cj_114_001.tiff", removed_tubular_cj)
            

In [None]:
# Filter small CJ components for every control crop.
for crop_no in ("cropped_001", "cropped_002", "cropped_003", "cropped_004", "cropped_005"):
    controller("shCtrl_003", crop_no)

In [None]:
# Filter small CJ components for every OPA1-knockdown crop.
for crop_no in ("cropped_001", "cropped_002", "cropped_003", "cropped_004", "cropped_005"):
    controller("shOPA1_003", crop_no)
