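# Generic U-Net for Fiji

Trains a U-Net segmentation network on annotated TIFF stacks, applies a selected checkpoint to new stacks, and optionally retrains it on additional annotations.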

In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from scipy.optimize import curve_fit
import skimage
from PIL import Image
import random
import utils
import shutil
import os

In [2]:
# Root directory for training/validation TIFFs, checkpoints and outputs. Override it with the
# UNET_DATA_DIR environment variable. A trailing separator is enforced because later cells
# build paths by string concatenation (e.g. data+'NewModels/').
data=os.path.join(os.environ.get('UNET_DATA_DIR', '/home/ubuntu/Data/'), '')

In [3]:
# Filter count of the first U-Net level; deeper levels use 2x, 4x, 8x and 16x this value.
base_scaler = 2 ** 5

In [4]:
# Default intensity normalization and input geometry. x_dim, y_dim and channels are
# overwritten from the training stack's real shape once it is loaded.
mean, std = 4.8e10, 4.48e9
x_dim = y_dim = 512
channels = 1

# Get data

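Load the annotated training stack from a TIFF file; its last channel is the ground-truth mask and all other channels are network input.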

In [5]:
# Read the training stack and move axis 1 of the file to the end, giving
# (image, row, col, channel) -- assumes the file stores channels on axis 1, TODO confirm.
both = skimage.external.tifffile.imread(data + 'Training_RotShift.tif').swapaxes(1, 2).swapaxes(2, 3)

#subsetA=[j for i in range(0,100) for j in range(i*54,i*54+9)]+[j for i in range(0,100) for j in range(27+i*54,27+i*54+9)]
#subsetB=[j for i in range(0,100) for j in range(9+i*54,9+i*54+9)]+[j for i in range(0,100) for j in range(9+27+i*54,9+27+i*54+9)]
#subsetC=[j for i in range(0,100) for j in range(18+i*54,18+i*54+9)]+[j for i in range(0,100) for j in range(18+27+i*54,18+27+i*54+9)]
#subset=subsetA+subsetB+subsetC

#both=both[subset,:,:,:]

# Geometry comes from the data itself; the last channel is the mask, not an input.
x_dim, y_dim = both.shape[1], both.shape[2]
channels = both.shape[3] - 1

In [6]:
# Inputs are every channel but the last; the mask keeps a length-1 channel axis.
train_data = both[:, :, :, :-1]
train_truth = both[:, :, :, [-1]]
both = 0  # drop the name's reference to the raw stack

# Fixed normalization constants (an earlier version computed np.mean/np.std of train_data).
mean, std = 4.8E10, 4.48E9
train_data = (train_data - mean) / std + 0.5

# Mask clean-up: values above 0.1 become 1, negative values become 0.
train_truth[train_truth > 0.1] = 1
train_truth[train_truth < 0] = 0



In [7]:
# Validation stack: same layout and preprocessing as the training stack.
validation = skimage.external.tifffile.imread(data + 'Validation_annotated.tif').swapaxes(1, 2).swapaxes(2, 3)

validation_data = (validation[:, :, :, :-1] - mean) / std + 0.5

validation_truth = validation[:, :, :, [-1]]
validation_truth[validation_truth > 0.1] = 1
validation_truth[validation_truth < 0] = 0

In [8]:
# Echo the normalization constants that are in effect.
print(list((mean, std)))

[48000000000.0, 4480000000.0]


# Design network

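Build the U-Net graph. Each encoder level applies two convolutions with a leaky-ReLU activation, followed by 2×2 max-pooling before the next level. Each decoder level upsamples with a stride-2 transposed convolution, concatenates the matching encoder features (skip connection) and applies two more convolutions. A final 1×1 convolution followed by tanh produces the output map, which is trained with a mean-squared-error loss.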

In [9]:
# Build the U-Net graph (TensorFlow 1.x graph mode).
# NOTE: keras layers are auto-named in creation order (conv2d, conv2d_1, ...) and the
# checkpoints restored in later cells rely on those names -- do not reorder layer creation.
tf.reset_default_graph()
# The graph-level seed must be set on the fresh graph *before* any op is created. The call in
# the training cell happens after the initializers below exist and therefore does not seed them.
tf.set_random_seed(123456)
#Input and output
# x: (batch, x_dim, y_dim, channels) normalized images; y: single-channel target mask.
# NOTE(review): the retraining cells feed 2-channel targets, which this 1-channel y
# (and the 1-filter output conv) cannot accept -- the retrain path needs a matching graph.
x=tf.placeholder(dtype=tf.float32, shape=[None, x_dim,y_dim,channels], name='x')
y=tf.placeholder(dtype=tf.float32, shape=[None, x_dim,y_dim,1], name='y')
lr=tf.placeholder(dtype=tf.float32, shape=[], name='learning_rate')
# NOTE(review): dr is never used below (there is no dropout layer).
dr=tf.placeholder(dtype=tf.float32, shape=[], name='dropout_rate')

xr=tf.identity(x)
yr=tf.identity(y)

#Going down
# Encoder: two SAME-padded convolutions per level (utils.leaky_relu, presumably a leaky ReLU),
# then a 2x2/stride-2 max-pool before the next level doubles the filter count.
# Four poolings: x_dim and y_dim must be divisible by 16 or the skip concatenations below
# will not have matching shapes.
A1=tf.keras.layers.Conv2D(base_scaler, 5, padding='SAME', activation=utils.leaky_relu)(xr)
A2=tf.keras.layers.Conv2D(base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(A1)

B0=tf.nn.max_pool(A2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
B1=tf.keras.layers.Conv2D(2*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(B0)
B2=tf.keras.layers.Conv2D(2*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(B1)

C0=tf.nn.max_pool(B2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
C1=tf.keras.layers.Conv2D(4*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(C0)
C2=tf.keras.layers.Conv2D(4*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(C1)

D0=tf.nn.max_pool(C2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
D1=tf.keras.layers.Conv2D(8*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(D0)
D2=tf.keras.layers.Conv2D(8*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(D1)

# Bottleneck at 1/16 resolution with 16*base_scaler filters.
E0=tf.nn.max_pool(D2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
E1=tf.keras.layers.Conv2D(16*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(E0)
E2=tf.keras.layers.Conv2D(16*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(E1)

#Coming up
# Decoder: a stride-2 transposed convolution doubles the resolution, the result is concatenated
# on the channel axis with the same-level encoder output (skip connection), then two convolutions.
DD0=tf.keras.layers.Conv2DTranspose(8*base_scaler, kernel_size=3, strides=2, padding='SAME')(E2)
DD1=tf.concat(axis=3, values=[DD0,D2])
DD2=tf.keras.layers.Conv2D(8*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(DD1)
DD3=tf.keras.layers.Conv2D(8*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(DD2)

CC0=tf.keras.layers.Conv2DTranspose(4*base_scaler, kernel_size=3, strides=2, padding='SAME')(DD3)
CC1=tf.concat(axis=3, values=[CC0,C2])
CC2=tf.keras.layers.Conv2D(4*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(CC1)
CC3=tf.keras.layers.Conv2D(4*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(CC2)

BB0=tf.keras.layers.Conv2DTranspose(2*base_scaler, kernel_size=3, strides=2, padding='SAME')(CC3)
BB1=tf.concat(axis=3, values=[BB0,B2])
BB2=tf.keras.layers.Conv2D(2*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(BB1)
BB3=tf.keras.layers.Conv2D(2*base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(BB2)

AA0=tf.keras.layers.Conv2DTranspose(base_scaler, kernel_size=3, strides=2, padding='SAME')(BB3)
AA1=tf.concat(axis=3, values=[AA0,A2])
AA2=tf.keras.layers.Conv2D(base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(AA1)
AA3=tf.keras.layers.Conv2D(base_scaler, 3, padding='SAME', activation=utils.leaky_relu)(AA2)

# 1x1 convolution down to one map. NOTE(review): despite its name, `logits` already has the
# leaky-ReLU activation applied, and tanh is applied on top, so `probs` lies in (-1, 1) and can
# be slightly negative -- it is not a sigmoid probability.
logits=tf.keras.layers.Conv2D(1, 1, padding='SAME', activation=utils.leaky_relu)(AA3)
probs=tf.tanh(logits, name='probabilities')

# Loss: mean squared error between prediction and target over all pixels.
diff=tf.subtract(probs, yr)
LSQ=tf.multiply(diff,diff)
#Added this to make the outlines more potent in error function
#OutError, MaskError=tf.split(LSQ, [1,1], 3)
#loss=1*tf.reduce_mean(OutError)+0.1*tf.reduce_mean(MaskError)
loss=tf.reduce_mean(LSQ, name='error')
# Regularization loss is collected but currently not added to the objective.
l2_loss = tf.losses.get_regularization_loss()
#loss=loss+l2_loss/1000

train_op=tf.train.AdamOptimizer(learning_rate=lr, name='trainer').minimize(loss)

# TensorBoard: loss plus up to 3 images each of input channel 0, target and prediction.
tf.summary.scalar('loss', loss)
tf.summary.image('input', tf.reshape(x[:,:,:,0], [-1, x_dim, y_dim,1]), 3)
tf.summary.image('standard', tf.reshape(y[:,:,:,0], [-1, x_dim, y_dim,1]) , 3)
tf.summary.image('outline', tf.reshape(probs[:,:,:,0], [-1, x_dim, y_dim,1]), 3)



merge = tf.summary.merge_all()

Instructions for updating:
Colocations handled automatically by placer.


# Training

In [1]:
# Train from scratch with a piecewise-constant learning-rate schedule. Every 100 steps:
# evaluate on the first 3 validation images, log TensorBoard summaries, plot them and save a
# checkpoint to data+'NewModels/'.
shutil.rmtree('./logs/', ignore_errors=True)  # fresh TensorBoard logs; no error when absent
tf.set_random_seed(123456)  # only affects ops created after this point
np.random.seed(123456)  # reproducible batch selection and intensity jitter
sess=tf.Session()
sess.run(tf.global_variables_initializer())
# 60 checkpoints are saved (steps 0, 100, ..., 5900); with max_to_keep=50 the oldest are deleted.
saver = tf.train.Saver(max_to_keep=50)
test_writer = tf.summary.FileWriter('./logs/170/test')
current_best=[0,1000000000]
# tf.train.Saver.save does not create a missing parent directory.
os.makedirs(data+'NewModels/', exist_ok=True)

# Steps 0-499, 500-1499, 1500-2199 and 2200-5999 run at the four rates, respectively.
learning_rates=[0.000195, 0.0001, 0.00002, 0.00001]
learning_rate_steps=[500,1500, 2200, 6000]
current_step=0
for lrate, lrs in zip(learning_rates, learning_rate_steps):
    for i in range(current_step, lrs):
        # Batch of 15 distinct training images; one random offset in [-0.3, 0.3) is added to
        # the whole batch as intensity augmentation.
        idx=np.random.choice(train_data.shape[0], replace=False, size=[15])
        cur_train=train_data[idx,:,:,:]+np.random.uniform(-0.3, 0.3, 1)
        cur_truth=train_truth[idx,:,:]
        _, results, losses=sess.run([train_op,  probs, loss], feed_dict={x:cur_train, y:cur_truth, lr:lrate})

        if (i%100==0):
            print(i)
            print("Training loss: ",losses)
            # Fixed validation subset (first 3 images) so plots are comparable across steps.
            idx=range(0,3, 1)
            sub_validation_data=validation_data[idx, :,:,:]
            sub_validation_truth=validation_truth[idx, :,:]
            summary, results, losses=sess.run([merge, probs, loss], feed_dict={x:sub_validation_data, y:sub_validation_truth})
            test_writer.add_summary(summary, i)
            print(results.shape)
            print("Validation loss: ",losses)
            for ti in range (0,3):
                utils.plot_3x1(sub_validation_data[ti,:,:,0], sub_validation_truth[ti,:,:,0], results[ti,:,:,0])
                plt.show()
            saver.save(sess, data+'NewModels/Model'+str(i))
    current_step=lrs
# Push any buffered summaries to disk.
test_writer.flush()

NameError: name 'shutil' is not defined

In [1]:
# Serialize the graph structure (not the trained variable values) next to the checkpoints,
# using the same data root as every other cell instead of a cwd-relative 'Data/' path.
tf.train.write_graph(sess.graph_def, data+'NewModels/', 'saved_3100.pb', as_text=False)

NameError: name 'tf' is not defined

In [None]:
# Spot-check the trained network on a random batch of 25 training images. Indices are
# drawn from the same array they index (previously drawn from the validation set's size).
idx=np.random.choice(train_data.shape[0], replace=False, size=[25])
cur_train=train_data[idx,:,:,:]
cur_truth=train_truth[idx,:,:]
# Inference only: probs/loss do not depend on the learning rate, so the stale loop
# variable `lrate` is not fed (it is undefined unless the training cell was run).
results, losses=sess.run([probs, loss], feed_dict={x:cur_train, y:cur_truth})

ids=4
utils.plot_3x1(cur_train[ids,:,:,0], cur_truth[ids,:,:,0], results[ids,:,:,0])

# Process new datafiles using trained network

Restore the best-performing checkpoint; its training step is hard-coded in the cell below.

In [None]:
# Open a new session and load the chosen checkpoint's weights. Running the initializer
# first is redundant once restore succeeds, but harmless.
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

data_model = '{}v1/Model{}'.format(data, 4300)
saver.restore(sess, data_model)

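Define a function that runs the network on one TIFF stack and saves the input channels together with the prediction as a new TIFF.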

In [None]:
def process_file(sess, file_path, save_dir, process_batch_size=20):
    """Run the network on one TIFF stack and save input + prediction as an ImageJ TIFF.

    The stack is normalized with the global ``mean``/``std`` used for training and fed
    through the graph (global tensors ``x`` and ``probs``) in batches.

    Args:
        sess: tf.Session holding the restored network.
        file_path: TIFF to process. A 3-D stack is treated as (image, row, col) with one
            channel; otherwise axis 1 is moved last, exactly as for the training data.
        save_dir: output directory (created if missing); the input's file name is kept.
        process_batch_size: number of images per ``sess.run`` call.

    Returns:
        float64 array (images, channels + 1, rows, cols): the raw input channels followed by
        the prediction, with negative values set to 0. The same data is written to disk as
        float32.
    """
    file_data=skimage.external.tifffile.imread(file_path)
    if file_data.ndim==3:
        # Single-channel stack: add a trailing channel axis (keeps the real row/col sizes,
        # so non-square images are no longer scrambled by a reshape to x_dim x x_dim).
        file_data=file_data[:, :, :, np.newaxis]
    else:
        file_data=np.swapaxes(np.swapaxes(file_data,1,2), 2,3)
    num_images, rows, cols, channels=file_data.shape

    # Same normalization as the training data.
    inference_data=(file_data-mean)/std+0.5

    output=np.zeros([num_images, rows, cols, channels+1])
    # The raw (un-normalized) input is stored next to the prediction.
    output[:,:,:,0:channels]=file_data
    print('Starting')
    for start in range(0, num_images, process_batch_size):
        stop=min(start+process_batch_size, num_images)
        results=sess.run(probs, feed_dict={x:inference_data[start:stop]})
        output[start:stop,:,:,channels]=results[:,:,:,0]
    print('Done')
    del inference_data

    # Back to channel-on-axis-1 order for writing.
    output=np.swapaxes(np.swapaxes(output,3,2),2,1)
    # NOTE(review): this also clips the raw input channels, not just the prediction.
    np.place(output, output<0, 0)
    os.makedirs(save_dir, exist_ok=True)
    out_path=os.path.join(save_dir, os.path.basename(file_path))
    skimage.external.tifffile.imsave(out_path, output.astype('float32'), imagej=True)
    print('Written')
    return output


In [None]:
import glob

# Apply the restored network to every TIFF in the input directory. `output` ends up
# holding the result for the last file processed.
tif_pattern = '/n/projects/smc/public/STN/19-6-3b_ali_level_fullsize/DeepLearn/Data/*.tif'
for f in glob.glob(tif_pattern):
    output = process_file(sess, f, data + 'v1/Output/')

In [None]:
# Inspect one processed image: raw channel 0 and the prediction, which process_file always
# stores as the LAST channel (index 1 is another raw channel for multi-channel inputs).
ids=4
utils.plot_3x1(output[ids,0,:,:], output[ids,-1,:,:], output[ids,-1,:,:])

# Retrain using new data and best old model

In [None]:
# Keep the previous run's checkpoints: NewModels/ is renamed to OldModels/ so retraining
# can restore from it. os.rename raises if OldModels/ already exists and is non-empty.
models_src = data + 'NewModels/'
if os.path.isdir(models_src):
    os.rename(models_src, data + 'OldModels/')

In [None]:
# Retraining stack, channels moved last; its final two channels hold the targets.
both = skimage.external.tifffile.imread(data + 'Training_retrain_RotShift.tif').swapaxes(1, 2).swapaxes(2, 3)

x_dim, y_dim = both.shape[1], both.shape[2]
channels = both.shape[3] - 2

In [None]:
# Load the hyper-parameters stored with the best model, one value per line in the order the
# retraining loop writes them. The `with` block closes the file even if parsing fails.
# NOTE: this overrides x_dim/channels set above, but does not rebuild the graph -- re-run the
# network cell if base_scaler, x_dim or channels changed.
with open(data+'Network.txt') as file:
    base_scaler=int(file.readline())
    baseline_noise=float(file.readline())
    x_dim=int(file.readline())
    mean=float(file.readline())
    std=float(file.readline())
    model=int(file.readline())
    channels=int(file.readline())

In [None]:
# Inputs are all but the last two channels, scaled by 1/655350 and then normalized like the
# original training data; the last two channels are targets (values > 0.1 set to 1).
# NOTE(review): the graph built above has a single output channel, so these 2-channel
# targets need a matching network definition.
train_data = (both[:, :, :, :-2] / 655350.0 - mean) / std + 0.5
train_truth = both[:, :, :, [-2, -1]]
train_truth[train_truth > 0.1] = 1

validation = skimage.external.tifffile.imread(data + 'Validation_annotated.tif').swapaxes(1, 2).swapaxes(2, 3)

validation_data = (validation[:, :, :, :-2] / 655350.0 - mean) / std + 0.5

validation_truth = validation[:, :, :, [-2, -1]]
validation_truth[validation_truth > 0.1] = 1


In [None]:
# Start retraining from a checkpoint of the previous run (now under OldModels/).
# Step 1500 is fixed here; using `model` from Network.txt would pick the recorded best step.
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

data_model = '{}OldModels/Model{}'.format(data, 1500)
saver.restore(sess, data_model)

In [None]:
# Fine-tune the restored model on the retraining data with a piecewise-constant learning
# rate. Every 100 steps: evaluate on the first 3 validation images, log summaries, record the
# best validation loss in Network.txt, plot, and save a checkpoint to data+'NewModels/'.
shutil.rmtree('./logs/', ignore_errors=True)  # fresh TensorBoard logs; no error when absent
tf.set_random_seed(123456)  # only affects ops created after this point
np.random.seed(123456)  # reproducible batch selection and intensity jitter
saver = tf.train.Saver(max_to_keep=50)
test_writer = tf.summary.FileWriter('./logs/170/test')
current_best=[0,1000000000]
# NewModels/ was renamed to OldModels/ before retraining, and tf.train.Saver.save does not
# create a missing parent directory -- recreate it or every save below fails.
os.makedirs(data+'NewModels/', exist_ok=True)

# Steps 0-499, 500-1499 and 1500-2199 run at the first three rates.
# NOTE(review): the last boundary repeats 2200, so the 0.00001 phase runs zero steps.
learning_rates=[0.00005, 0.00002, 0.00002, 0.00001]
learning_rate_steps=[500,1500, 2200, 2200]
current_step=0
for lrate, lrs in zip(learning_rates, learning_rate_steps):
    for i in range(current_step, lrs):
        # Batch of 25 distinct images plus one random offset in
        # [-baseline_noise, baseline_noise) applied to the whole batch.
        idx=np.random.choice(train_data.shape[0], replace=False, size=[25])
        cur_train=train_data[idx,:,:,:]+np.random.uniform(-baseline_noise, baseline_noise, 1)
        cur_truth=train_truth[idx,:,:]
        _, results, losses=sess.run([train_op,  probs, loss], feed_dict={x:cur_train, y:cur_truth, lr:lrate})
        if (i%100==0):
            print(i)
            print("Training loss: ",losses)
            # Fixed validation subset (first 3 images) so results are comparable across steps.
            idx=range(0,3, 1)
            sub_validation_data=validation_data[idx, :,:,:]
            sub_validation_truth=validation_truth[idx, :,:]
            summary, results, losses=sess.run([merge, probs, loss], feed_dict={x:sub_validation_data, y:sub_validation_truth})
            test_writer.add_summary(summary, i)
            print(results.shape)
            print("Validation loss: ",losses)
            if (losses<current_best[1]):
                current_best=[i, losses]
                # One value per line, in the order the Network.txt loader cell reads them back.
                with open(data+'Network.txt', 'w') as file:
                    for value in (base_scaler, baseline_noise, x_dim, mean, std, i, channels):
                        file.write(str(value)+'\n')
            for ti in range (0,3):
                utils.plot_3x1(sub_validation_data[ti,:,:,0], results[ti,:,:,0], results[ti,:,:,1])
                plt.show()
            saver.save(sess, data+'NewModels/Model'+str(i))
    current_step=lrs
# Push any buffered summaries to disk.
test_writer.flush()