In [None]:
%load_ext autoreload
%autoreload 2

import PIL
from PIL import Image
import numpy as np
import glob
import skimage
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
from tqdm import tqdm

from utils.segmentation import Segmentation
from utils.bg_fg_prep import saliency_detect
from utils.image_io import *


from skimage.filters import threshold_li, threshold_mean
from skimage.morphology import disk
from skimage.morphology import erosion, dilation, opening, closing, skeletonize


In [None]:
def empty_dir(folder):
    """Delete every file, symlink and sub-directory inside ``folder``.

    The folder itself is kept. A failure on one entry is reported with
    ``print`` and does not stop the loop. A path that is not an existing
    directory is treated as already empty.

    Args:
        folder: path of the directory to clear.
    """
    # The notebook's import cell does not import shutil; without this local
    # import every sub-directory removal raised NameError, and the old broad
    # `except Exception` hid it behind a print.
    import shutil

    if not os.path.isdir(folder):
        # e.g. a fresh run before any output exists; os.listdir would raise.
        print('not a directory, nothing to empty: ', folder)
        return

    for entry in os.listdir(folder):
        entry_path = os.path.join(folder, entry)
        try:
            # Check links before directories: os.path.isdir follows symlinks,
            # and shutil.rmtree refuses to operate on a symlink.
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except OSError as exc:
            # Best effort: report and continue with the remaining entries.
            print(exc)
    print('empty directory: ', folder)


### Cell images (FISH dataset)

In [None]:
# Dataset name: selects the data/<folder_name>/ inputs and output/<folder_name>/ results.
folder_name = 'FISH'

#### Utility functions
- convert single-channel images into three-channel (RGB) images
- generate foreground/background hint masks

In [None]:
# Convert the single-channel images in images_1/ into three-channel copies in images/.
# NOTE(review): the hint-generation and network cells read images_1/, not images/;
# the three-channel copies are only used by the disabled saliency_detect call —
# confirm they are still needed.

src_dir = os.path.join('data', folder_name, 'images_1')
dst_dir = os.path.join('data', folder_name, 'images')
# Loop-invariant: create the destination once instead of on every image.
os.makedirs(dst_dir, exist_ok=True)

# sorted() gives a deterministic processing order across runs and platforms.
for src_path in tqdm(sorted(glob.glob(os.path.join(src_dir, '*.png')))):
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(src_path) as pil_img:
        # assumes pil_to_np yields a channel-first array with a singleton
        # channel axis for grayscale input — TODO confirm
        gray = pil_to_np(pil_img).squeeze()
    rgb = skimage.color.gray2rgb(gray)   # appends a trailing 3-channel axis
    rgb_chw = rgb.transpose(2, 0, 1)     # move channels first for np_to_pil
    # os.path.basename is portable, unlike splitting the path on '/'.
    np_to_pil(rgb_chw).save(os.path.join(dst_dir, os.path.basename(src_path)))
    

In [None]:
def process_img(img):
    """Mask of pixels strictly brighter than the image mean, closed then dilated.

    A morphological closing with a radius-3 disk is applied first, then the
    result is dilated with a radius-4 disk.
    """
    above_mean = img > threshold_mean(img)
    closed = closing(above_mean, disk(3))
    return dilation(closed, disk(4))

def get_bg(img):
    """Background mask: pixels at or below Li's threshold, dilated by a radius-5 disk."""
    at_or_below = img <= threshold_li(img)
    # `dilation` is the same function as skimage.morphology.dilation.
    return dilation(at_or_below, disk(5))

def get_bg_from_annot(annot, erosion_radius=0):
    """Background mask from an annotation array.

    Returns a float64 array with the same shape as ``annot``. It holds 1.0
    where the annotation is exactly 0 and 0.0 everywhere else (including
    negative values).

    Args:
        annot: annotation array.
        erosion_radius: if > 0, erode ``annot`` with a disk of this radius
            before building the mask. The default 0 reproduces the previous
            behaviour: the old code computed a radius-3 erosion but then
            discarded it (``annot_ero = annot`` overwrote the result).
    """
    annot = np.asarray(annot)
    if erosion_radius > 0:
        annot = erosion(annot, disk(erosion_radius))
    # Vectorized form of: zeros, then set 1 where annot == 0.
    return (annot == 0).astype(np.float64)

In [None]:
def _mean_threshold_eroded(img, keep_above, radius):
    """Threshold ``img`` at its mean (strictly), then erode the mask with a disk."""
    thresh = threshold_mean(img)
    binary = img > thresh if keep_above else img < thresh
    return erosion(binary, disk(radius))


def process_fg(img, radius=4):
    """Foreground hint: pixels strictly darker than the image mean, eroded.

    NOTE(review): the foreground is taken to be the below-mean pixels —
    confirm this matches the intensity polarity of the data.

    Args:
        img: input image (as read by the hint-generation cell).
        radius: radius of the erosion disk; 4 matches the previous hard-coded value.

    Returns:
        The eroded mask returned by skimage's ``erosion``. Pixels exactly at
        the mean belong to neither the foreground nor the background hint.
    """
    return _mean_threshold_eroded(img, keep_above=False, radius=radius)


def process_bg(img, radius=4):
    """Background hint: pixels strictly brighter than the image mean, eroded.

    Args:
        img: input image (as read by the hint-generation cell).
        radius: radius of the erosion disk; 4 matches the previous hard-coded value.

    Returns:
        The eroded mask returned by skimage's ``erosion``.
    """
    return _mean_threshold_eroded(img, keep_above=True, radius=radius)

In [None]:
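# Disabled alternative: hint generation via utils.bg_fg_prep.saliency_detect on the
# three-channel images/ copies. The next cell generates the hints by mean-thresholding instead.
# saliency_detect('data/'+ folder_name + '/images/*.png', dest='data/' + folder_name,
#                t1=115, t2=120)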

In [None]:
# Generate foreground/background hint masks from the original images in images_1/.

dest = 'data/{}'.format(folder_name)
fg_dir = os.path.join(dest, 'output_fg')
bg_dir = os.path.join(dest, 'output_bg')
os.makedirs(fg_dir, exist_ok=True)
os.makedirs(bg_dir, exist_ok=True)

folder = 'data/{0}/{1}/*.png'.format(folder_name, 'images_1')

# sorted() gives a deterministic processing order.
for img in tqdm(sorted(glob.glob(folder))):
    # imread's second positional parameter is `as_gray`, not an OpenCV-style
    # read flag, so the old `imread(img, 0)` meant as_gray=False. Spelled out here.
    image = skimage.io.imread(img, as_gray=False)
    file_name = os.path.basename(img)

    # Boolean masks -> uint8 (True -> 255) so they can be written as PNG.
    image_fg = skimage.img_as_ubyte(process_fg(image))
    skimage.io.imsave(os.path.join(fg_dir, file_name), image_fg)

    # NOTE(review): an earlier note here read `1-`, hinting that the background
    # mask may once have been inverted — confirm the current polarity is intended.
    image_bg = skimage.img_as_ubyte(process_bg(image))
    skimage.io.imsave(os.path.join(bg_dir, file_name), image_bg)

----------------------------

## Run the segmentation network

In [None]:
# Clear results from previous runs. The path is built from folder_name so it
# matches run_net's default output location ('output/' + folder_name).
for sub_dir in ('mask', 'fixed_mask', 'left_right', 'reconstruct'):
    out_dir = os.path.join('output', folder_name, sub_dir)
    # Skip directories that do not exist yet (e.g. on a fresh run) instead of
    # handing them to empty_dir.
    if os.path.isdir(out_dir):
        empty_dir(out_dir)

In [None]:
def run_net(img_name, output_path=None, **segmentation_kwargs):
    """Segment one image from data/<folder_name>/images_1/ using a one-channel input.

    Loads the image and its background/foreground hint masks (written by the
    hint-generation cell to output_bg/ and output_fg/). It then builds a
    ``Segmentation`` network and runs ``optimize()`` followed by ``finalize()``.

    Args:
        img_name: file name (e.g. '0000.png') present in images_1/, output_bg/
            and output_fg/.
        output_path: directory passed to Segmentation. If None, defaults to
            'output/' + folder_name, evaluated at call time so a later change
            to folder_name is honoured.
        **segmentation_kwargs: overrides for the default Segmentation keyword
            arguments set below (e.g. ``plot_during_training=False`` for batch runs).
    """
    if output_path is None:
        output_path = 'output/' + folder_name
    data_dir = 'data/' + folder_name

    def _load(path):
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(path) as pil_img:
            return pil_to_np(pil_img)

    image_path = data_dir + '/images_1/' + img_name
    image = _load(image_path)
    bg_hint = _load(data_dir + '/output_bg/' + img_name)
    fg_hint = _load(data_dir + '/output_fg/' + img_name)

    params = dict(input_depth=2, output_depth=1, psnr_goal=40,
                  output_path=output_path, show_every=500,
                  first_step_iter_num=2000, second_step_iter_num=4000,
                  plot_during_training=True)
    params.update(segmentation_kwargs)

    net = Segmentation(image_name=image_path, image=image,
                       bg_hint=bg_hint, fg_hint=fg_hint, **params)
    net.optimize()
    net.finalize()

In [None]:
%%time

run_net('0000.png')

In [None]:
# Image file names to segment, in sorted (deterministic) order.
# Use the same rule as the glob('*.png') calls in the other cells: a '.png'
# suffix and no leading dot. This keeps hidden files (e.g. '._0000.png') away
# from run_net and rejects names such as 'xpng'.
images_dir = os.path.join('data', folder_name, 'images_1')
cell_list = sorted(
    name for name in os.listdir(images_dir)
    if name.endswith('.png') and not name.startswith('.')
)

In [None]:
# Segment every image in cell_list with run_net's default settings.
for img_name in cell_list:
    run_net(img_name)