In [109]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
import tkinter as tk
import math
import os
%gui tk
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import Path, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE
from IPython.display import clear_output
from tqdm import tqdm

def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` for a positive int.

    ``0`` is special-cased and maps to ``1``.
    """
    if x == 0:
        return 1
    # bit_length of (x - 1) is the exponent of the first power of two
    # that is not below x.
    exponent = (x - 1).bit_length()
    return 1 << exponent

def subdivide(shape=None):
    """Choose the ``(z, y, x)`` tile counts handed to ``predict`` as ``n_tiles``.

    Z is split first: one tile while ``shape[0] < 48``, otherwise a power of
    two derived from ``shape[0] / 24``.  The voxel count of one Z tile is then
    compared with a reference volume of 12*512*512 voxels to pick power-of-two
    tile counts for Y, with X getting half as many tiles as Y.

    Args:
        shape: Sequence of axis lengths whose first entry is the Z length.
            When omitted, the module-level ``size`` set by the calling script
            is used, which is how existing callers invoke this function.

    Returns:
        Tuple ``(zslice, yslice, xslice)`` of ints, each at least 1.
    """
    if shape is None:
        shape = size  # global assigned by the processing loop

    zslice = next_power_of_2(math.floor(shape[0] / 24))

    # Voxels per Z tile.
    product = 1 / zslice
    for extent in shape:
        product *= extent

    ref = 12 * 512 * 512
    target = math.sqrt(product / ref)
    yslice = next_power_of_2(math.floor(target) + 1)
    # For stacks below the reference volume yslice is 1; halving it used to
    # yield 0 tiles along X, which is not a valid tile count.
    xslice = max(1, yslice // 2)
    return (zslice, yslice, xslice)
        

# `import tkinter` alone does not load the `filedialog` submodule, so import it
# explicitly instead of relying on some other package having done so.
from tkinter import filedialog

# Hidden root window so the folder dialogs open on top of other windows.
root = tk.Tk()
root.withdraw()
root.call('wm', 'attributes', '.', '-topmost', True)

chan = int(input("Number of channels: "))
directory = filedialog.askdirectory(title="Choose folder directory")
if not directory:
    # askdirectory returns an empty selection when the dialog is cancelled.
    raise ValueError("No folder selected.")

# One CARE model per channel, stored in channel order.
models = []
for channel in range(1, chan + 1):
    modeldir = filedialog.askdirectory(title="Choose model for channel " + str(channel),
                                       initialdir=os.path.dirname("models/"))
    if not modeldir:
        raise ValueError("No model folder selected for channel " + str(channel) + ".")
    # The chosen folder is the model itself: its last component is the model
    # name and its parent is the base directory CARE looks in.
    model_base, model_name = os.path.split(os.path.normpath(modeldir))
    models.append(CARE(config=None, name=model_name, basedir=model_base))

# Build the list of TIFF stacks once (output of step 1) and iterate over exactly
# that list, so the progress-bar total always matches the number of iterations.
tif_files = [name for name in os.listdir(directory) if name.endswith(".tif")]
filecounter = len(tif_files)

if tif_files:
    Path(directory + "/Restored/").mkdir(exist_ok=True)

for file in tqdm(tif_files, total=filecounter, unit="files"):
    current = imread(directory + "/" + file)
    size = current.shape  # subdivide() reads this module-level name
    print(file)
    if chan == 1:
        # Single-channel stacks are treated as ZYX and stay 3-D after prediction.
        result = models[0].predict(current, 'ZYX', n_tiles=subdivide())
        axes = 'ZYX'
    else:
        result = np.zeros(size)
        size = (size[0],) + size[2:] # Convert ZCYX to ZYX
        for c in range(chan):
            print('Channel ' + str(c+1))
            result[:,c,:,:] = models[c].predict(current[:,c,:,:], 'ZYX', n_tiles=subdivide())
        axes = 'ZCYX'
    # The axes string must have one letter per dimension of `result`.
    save_tiff_imagej_compatible(directory + "/Restored/%s" % file, result, axes)
    clear_output(wait=True)
print("Done!")


100%|██████████| 1/1 [06:42<00:00, 402.83s/files]

Done!
