In [1]:
%matplotlib inline
from fns import *
from syntheticdata import synthetic_generation

In [2]:
def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16, filter_size=3, pool_size=2, summaries=True):
    """
    Creates a new convolutional unet for the given parametrization.
    
    :param x: input tensor, shape [?,nx,ny,nz,channels]
    :param keep_prob: dropout probability tensor
    :param channels: number of channels in the input image
    :param n_class: number of output labels
    :param layers: number of layers in the net
    :param features_root: number of features in the first layer
    :param filter_size: size of the convolution filter
    :param pool_size: size of the max pooling operation
    :param summaries: Flag if summaries should be created
    """
    
    logging.info("Layers: {layers}, FeaturesRoot: {features_root}, ConvolutionSize: {filter_size}*{filter_size}*{filter_size}, PoolingSize: {pool_size}*{pool_size}*{pool_size}"\
                 .format(layers=layers, features_root=features_root, filter_size=filter_size, pool_size=pool_size))
    
    in_node = x
    batch_size = tf.shape(x)[0] 
    weights = []
    biases = []
    convs = []
    pools = OrderedDict()
    deconv = OrderedDict()
    dw_h_convs = OrderedDict()
    up_h_convs = OrderedDict()
    in_size = 144
    size = in_size

   
    # down layers
    with tf.name_scope('going_down'):
        for layer in range(0, layers):
            with tf.name_scope('layer_down_%d'%layer):
                features = 2**layer*features_root
                stddev = 1 / (filter_size**3 * features)
                if layer == 0:
                    w1 = weight_variable([filter_size, filter_size, filter_size, channels, features], stddev)
                else:
                    w1 = weight_variable([filter_size, filter_size, filter_size, features//2, features], stddev)
                w2 = weight_variable([filter_size, filter_size, filter_size, features, features], stddev)
                b1 = bias_variable([features])
                b2 = bias_variable([features])
                
                conv1 = conv3d(in_node, w1, keep_prob)
                tmp_h_conv = tf.nn.elu(conv1 + b1)
                conv2 = conv3d(tmp_h_conv, w2, keep_prob)
                dw_h_convs[layer] = tf.nn.elu(conv2 + b2)
                
                logging.info("Down Convoltion Layer: {layer} Size: {size}".format(layer=layer,size=dw_h_convs[layer].get_shape()))
                
                weights.append((w1, w2))
                biases.append((b1, b2))
                convs.append((conv1, conv2))

                size -= 4    
                if layer < layers-1:
                    pools[layer] = max_pool(dw_h_convs[layer], pool_size)
                    in_node = pools[layer]
                    size /= 2    
        
    in_node = dw_h_convs[layers-1]
        
    # up layers
    with tf.name_scope('going_up'):
        for layer in range(layers-2, -1, -1):   
            with tf.name_scope('layer_up_%d'%layer):
                features = 2**(layer+1)*features_root
                stddev = 1 / (filter_size**3 * features)

                wd = weight_variable_devonc([pool_size, pool_size, pool_size, features//2, features], stddev)
                bd = bias_variable([features//2])
                h_deconv = tf.nn.elu(deconv3d(in_node, wd, pool_size) + bd)
                h_deconv_concat = crop_and_concat(dw_h_convs[layer], h_deconv)    
                deconv[layer] = h_deconv_concat

                w1 = weight_variable([filter_size, filter_size, filter_size, features, features//2], stddev)
                w2 = weight_variable([filter_size, filter_size, filter_size, features//2, features//2], stddev)
                b1 = bias_variable([features//2])
                b2 = bias_variable([features//2])

                conv1 = conv3d(h_deconv_concat, w1, keep_prob)
                h_conv = tf.nn.elu(conv1 + b1)
                conv2 = conv3d(h_conv, w2, keep_prob)
                in_node = tf.nn.elu(conv2 + b2)
                up_h_convs[layer] = in_node
                
                logging.info("Up Convoltion Layer: {layer} Size: {size}".format(layer=layer,
                                                                                size=tf.shape(dw_h_convs[layer])))
                
                weights.append((w1, w2))
                biases.append((b1, b2))
                convs.append((conv1, conv2))

                size *= 2
                size -= 4

    # Output Map
    with tf.name_scope('output_map'):
        #stddev = 1 / (features_root)
        weight = weight_variable([1, 1, 1, features_root, 1], stddev)
        bias = bias_variable([1])
        conv = conv3d(in_node, weight, tf.constant(1.0))
        output_map = tf.nn.sigmoid(conv + bias)
        up_h_convs["out"] = output_map
        logging.info("Output map shape {size}, offset {offset}".format(size=output_map.get_shape(), offset=int(in_size-size)))

        if summaries:
#             for i, (c1, c2) in enumerate(convs):
#                 tf.summary.image('summary_conv_%03d_01'%i, get_image_summary(c1))
#                 tf.summary.image('summary_conv_%03d_02'%i, get_image_summary(c2))

#             for k in pools.keys():
#                 tf.summary.image('summary_pool_%03d'%k, get_image_summary(pools[k]))

#             for k in deconv.keys():
#                 tf.summary.image('summary_deconv_concat_%03d'%k, get_image_summary(deconv[k]))

            for k in dw_h_convs.keys():
                tf.summary.histogram("dw_convolution_%03d"%k + '/activations', dw_h_convs[k])

            for k in up_h_convs.keys():
                tf.summary.histogram("up_convolution_%s"%k + '/activations', up_h_convs[k])

        variables = []
        for w1,w2 in weights:
            variables.append(w1)
            variables.append(w2)

        for b1,b2 in biases:
            variables.append(b1)
            variables.append(b2)
        
        variables.append(weight)
        variables.append(bias)
        
    return output_map, variables, int(in_size - size)

In [3]:
class Unet(object):
    """
    A unet implementation

    :param channels: (optional) number of channels in the input image
    :param n_class: (optional) number of output labels
    :param cost: (optional) name of the cost function. Default is 'cross_entropy'
    :param cost_kwargs: (optional) kwargs passed to the cost function. See Unet._get_cost for more options
    """

    def __init__(self, channels=1, n_class=1, cost="dice_coefficient", predict_thresh=0.5, cost_kwargs={}, **kwargs):
        tf.reset_default_graph()

        self.n_class = n_class
        self.summaries = kwargs.get("summaries", True)

        self.x = tf.placeholder(tf.float32, shape=[None, None, None, None, channels], name='data')
        self.y = tf.placeholder(tf.float32, shape=[None, None, None, None, n_class], name='target')
        self.keep_prob = tf.placeholder(tf.float32)  # dropout (keep probability)

        logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs)
        logging.info("Actual Output Shape: {}".format(logits.get_shape()))
        logging.info("Desired Output Shape: {}".format(self.y.get_shape()))

        self.logits = logits
        self.predicter = self.logits
        self.predicter_label = tf.cast(self.predicter >= predict_thresh, tf.float32)
        self.correct_pred = tf.cast(
            tf.equal(tf.reshape(self.predicter_label, [-1, n_class]), tf.reshape(self.y, [-1, n_class])), tf.float32)
        self.cost = self._get_cost(self.logits, self.predicter, cost, cost_kwargs)
        self.gradients_node = tf.gradients(self.cost, self.variables)
        self.cross_entropy = tf.reduce_mean(
            cross_entropy(tf.reshape(self.y, [-1, n_class], name='cross_entro_label_reshape'),
                          tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class], name='px_logit_reshape')))
        self.accuracy = tf.reduce_mean(self.correct_pred)

    def _get_cost(self, logits, predicter, cost_name, cost_kwargs):
        """
        Constructs the cost function, either cross_entropy, weighted cross_entropy or dice_coefficient.
        Optional arguments are:
        class_weights: weights for the different classes in case of multi-class imbalance
        regularizer: power of the L2 regularizers added to the loss function
        """
        with tf.name_scope('cost_function'):
            logging.info('*' * 50)
            logging.info('getting cost')
            logging.info("Logits: {}".format(logits.get_shape()))
            logging.info("Y: {}".format(self.y.get_shape()))
            flat_logits = tf.reshape(logits, [-1, self.n_class], name='flat_logits_reshape')
            flat_predicter = tf.reshape(predicter, [-1, self.n_class], name='flat_predicter_reshape')
            flat_labels = tf.reshape(self.y, [-1, self.n_class], name='flat_labels_reshape')
            if cost_name == "cross_entropy":
                class_weights = cost_kwargs.pop("class_weights", None)

                if class_weights is not None:
                    class_weights = tf.constant(np.array(class_weights, dtype=np.float32))

                    weight_map = tf.multiply(flat_labels, class_weights, name='weightmap')
                    weight_map = tf.reduce_sum(weight_map, axis=1)

                    loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
                    weighted_loss = tf.multiply(loss_map, weight_map, name='weightloss')

                    loss = tf.reduce_mean(weighted_loss)

                else:
                    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
                                                                                  labels=flat_labels))
            elif cost_name == "dice_coefficient":
                intersection = tf.reduce_sum(flat_predicter * flat_labels)
                union = tf.reduce_sum(flat_predicter) + tf.reduce_sum(flat_labels)
                loss = 1 - 2 * intersection / union

            else:
                raise ValueError("Unknown cost function: " % cost_name)

            regularizer = cost_kwargs.pop("regularizer", None)
            if regularizer is not None:
                regularizers = sum([tf.nn.l2_loss(variable) for variable in self.variables])
                loss += (regularizer * regularizers)

        return loss

    def predict(self, model_path, x_test):
        """
        Uses the model to create a prediction for the given data

        :param model_path: path to the model checkpoint to restore
        :param x_test: Data to predict on. Shape [n, nx, ny, nz, channels]
        :returns prediction: The unet prediction Shape [n, px, py, pz, labels] (px=nx-self.offset/2)
        """

        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            # Initialize variables
            sess.run(init)

            # Restore model weights from previously saved model
            self.restore(sess, model_path)

            y_dummy = np.empty((x_test.shape[0], x_test.shape[1], x_test.shape[2], x_test.shape[3], self.n_class))
            prediction = sess.run(self.predicter_label, feed_dict={self.x: x_test, self.y: y_dummy, self.keep_prob: 1.})

        return prediction

    def save(self, sess, model_path):
        """
        Saves the current session to a checkpoint

        :param sess: current session
        :param model_path: path to file system location
        """

        saver = tf.train.Saver()
        save_path = saver.save(sess, model_path)
        return save_path

    def restore(self, sess, model_path):
        """
        Restores a session from a checkpoint

        :param sess: current session instance
        :param model_path: path to file system checkpoint location
        """

        saver = tf.train.Saver()
        saver.restore(sess, model_path)
        logging.info("Model restored from file: %s" % model_path)

## setting up the unet

In [4]:
net = Unet(channels=1, 
           n_class=1, 
           layers=4, 
           pool_size=2,
           features_root=16, summaries=True,
          )

## training

In [5]:
# data_provider = ImageDataProvider(array=False)
# trainer = Trainer(net, batch_size=3, optimizer="adam")
# path = trainer.train(data_provider, 
#                      "./unet_trained",
#                      training_array = None,
#                      validation_array = None,
#                      testing_array = None,
#                      training_iters=37, 
#                      epochs=100, 
#                      dropout=0.6, 
#                      restore= True,
#                      display_step=1)

### Predict

In [18]:
provider = ImageDataProvider()
# testing_data, label_data = provider._load_data_and_label(provider.testing_data_files,3)

testing_data, label_data = provider._load_data_and_label(provider.training_data_files,10)


Number of training data used: 37
Number of validation data used: 7
Number of testing data used: 3


In [19]:
testing_data.shape

(10, 148, 148, 148, 1)

In [20]:
provider.testing_data_files

['/mnt/DATA/gp1514/Dropbox/2016-paolo/preprocessed_data/LabelMapsNEW2_1.00-1.00-1.00_clahe/028/case.nrrd',
 '/mnt/DATA/gp1514/Dropbox/2016-paolo/preprocessed_data/LabelMapsNEW2_1.00-1.00-1.00_clahe/051/case.nrrd',
 '/mnt/DATA/gp1514/Dropbox/2016-paolo/preprocessed_data/LabelMapsNEW2_1.00-1.00-1.00_clahe/054/case.nrrd']

In [38]:
i = 8
prediction = net.predict("./unet_trained/model 6.cpkt", testing_data[i][np.newaxis,...])[0][:,:,:,0]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


In [41]:
print(np.unique(prediction, return_counts=True))
print(prediction.shape)
print(label_data.shape)

(array([ 0.,  1.], dtype=float32), array([213416,   2584]))
(60, 60, 60)
(10, 148, 148, 148, 1)


In [40]:
%matplotlib notebook
xs,ys,zs = np.where(prediction == 1)

fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xs+44, ys+44, zs+44, marker='o', alpha=0.3, s=5)
plt.show()

# fig = plt.figure(figsize=(6,6))
# ax = fig.add_subplot(111, projection='3d')
xs,ys,zs = np.where(label_data[i, ...,0] == 1)
ax.scatter(xs, ys, zs, marker='o',color='g', alpha=0.1, s=5)
plt.show()

<IPython.core.display.Javascript object>

In [77]:
from skimage import measure
from skimage import filters

In [29]:
np.random.seed(1)
islands = measure.label(prediction)
K = np.max(islands)
cp =sns.color_palette("Set2", K)
fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')
for j in range(1,K):
    xs,ys,zs = np.where(islands == j)
    ax.scatter(xs, ys, zs, marker='o',color=cp[j], alpha=0.9, s=2)
plt.show()

<IPython.core.display.Javascript object>

In [30]:
np.random.seed(1)
islands = measure.label(label_data[i,...,0])
K = np.max(islands)
cp =sns.color_palette("Set2", K)
fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')
for j in range(1,K):
    xs,ys,zs = np.where(islands == j)
    ax.scatter(xs, ys, zs, marker='o',color=cp[j], alpha=0.9, s=2)
plt.show()

<IPython.core.display.Javascript object>

In [23]:
np.alltrue(blobs_labels==all_labels)

True

In [14]:
import pyximport; pyximport.install()
from fns.set_metrics import *

In [15]:
print(prediction.shape)
print(label_data[0,...,0].shape)
lab = crop_to_shape(label_data[...,0], prediction[np.newaxis,...].shape)
print(lab[0].shape)

(60, 60, 60)
(148, 148, 148)
(60, 60, 60)


In [16]:
labval = np.where(lab[0]!=False)
predval = np.where(prediction!=1)
# hausdorff_distance(np.array(labval),np.array(predval))
np.array(labval).shape
np.array(predval).shape

(3, 209990)

In [264]:
image_name = '/mnt/DATA/gp1514/Dropbox/2016-paolo/preprocessed_data/LabelMapsNEW2_1.00-1.00-1.00/070/case.nrrd'
label_name = '/mnt/DATA/gp1514/Dropbox/2016-paolo/preprocessed_data/LabelMapsNEW2_1.00-1.00-1.00/070/needles.nrrd'
# img = provider._load_file(image_name, np.float32, padding="noise")
label = provider._load_file(label_name, np.bool, padding="zero")

# data = provider._process_data(img)
label = provider._process_labels(label)

data, label = provider._post_process(data, label)

# nx = data.shape[0]
# ny = data.shape[1]
# nz = data.shape[2]

# data, label = data.reshape(1, nx, ny, nz, provider.channels), label.reshape(1, nx, ny, nz, provider.n_class)

In [265]:
# print(img.shape)
# print(data.shape)

In [266]:
# data.shape

In [267]:
tiles = (148,148,148)
tile = 148

In [268]:
data, options = nrrd.read(image_name)
data = data.astype(np.float32)
print(data.shape)
d = data.resize(max(data.shape[0],tile),
               max(data.shape[1],tile),
               max(data.shape[2],tile))
print(data.shape)
print(options)

(300, 195, 230)
(300, 195, 230)
{'keyvaluepairs': {}, 'type': 'unsigned short', 'dimension': 3, 'space': 'left-posterior-superior', 'sizes': [300, 195, 230], 'space directions': [['1', '0', '0'], ['0', '1', '0'], ['0', '0', '1']], 'kinds': ['domain', 'domain', 'domain'], 'endian': 'little', 'encoding': 'gzip', 'space origin': ['-148.06295776367188', '-90.153846740722656', '-145.76899719238281']}


In [269]:
def getpad(size,block):
    if size%block:
        pad = block*(size//block+1) - size
    else:
        pad = 0
    return pad

def getpads(sizes, blocks):
    pads = []
    for i in range(3):
        print(sizes, blocks)
        pads.append([0,getpad(sizes[i], blocks[i])])
    return pads

def tiler(input, tile_shape):
    input = input[np.newaxis,...]
    ts = tile_shape
    input_shape = input.get_shape().as_list()
    print(input_shape)
    paddings = getpads(input_shape[:], tile_shape)
    batch = tf.space_to_batch_nd(input, tile_shape, paddings, name=None)
    batch = tf.transpose(batch, [3,1,2,0,4])
    batch = tf.reshape(batch, (-1, ts[0], ts[1], ts[2], input_shape[-1]))
    return batch


def reshapeData(batch_x, batch_y, tile_shape=(148,148,148), keepNoNeedle=True, start=0, data_provider=ImageDataProvider()):
    logits, labels = [], []
    tf.reset_default_graph()
    with tf.Session() as sess:
        # initialize the graph
        batch_x = tf.Variable(batch_x)
        batch_y = tf.Variable(batch_y)
        tf.global_variables_initializer().run()

        resx = tiler(batch_x, tile_shape).eval()
        resy = tiler(batch_y, tile_shape).eval()
        if keepNoNeedle or np.max(resy)!=0:
            logits.append(resx)
            labels.append(resy)
    logits = np.array(logits)
    labels = np.array(labels)
    return np.concatenate(logits, axis=0), np.concatenate(labels, axis=0)

def reshapeDataset(nbOfCases, tile_shape=(100,100,100), keepNoNeedle=False, start=0, data_provider=ImageDataProvider()):
    logits, labels = [], []
    tf.reset_default_graph()
    with tf.Session() as sess:
        # initialize the graph
        for i in range(nbOfCases):
            batch_x, batch_y = data_provider(1)
            batch_x = tf.Variable(batch_x)
            batch_y = tf.Variable(batch_y)
            tf.global_variables_initializer().run()

            resx = tiler(batch_x, (100,100,100)).eval()
            resy = tiler(batch_y, (100,100,100)).eval()
            if keepNoNeedle or np.max(resy)!=0:
                logits.append(resx)
                labels.append(resy)
    logits = np.array(logits)
    labels = np.array(labels)
    return np.concatenate(logits, axis=0), np.concatenate(labels, axis=0)

Number of training data used: 37
Number of validation data used: 7
Number of testing data used: 3
Number of training data used: 37
Number of validation data used: 7
Number of testing data used: 3


In [270]:
data.shape
data = np.pad(data,((44,44),(44,44), (44,44)), mode='mean')
data.shape

(388, 283, 318)

In [271]:
Mx, My, Mz = data.shape
# Mx,My,Mz=150,150,150
tile_in = 60
kx = Mx//tile_in + 1*((Mx%tile_in)>0)
ky = Mx//tile_in + 1*((My%tile_in)>0)
kz = Mz//tile_in + 1*((Mz%tile_in)>0)
print(Mx,My,Mz)
print(kx,ky,kz)

388 283 318
7 7 6


In [273]:


off_x = 60
off_y = 60
off_z = 60


print(off_x, off_y, off_z)

arr_data = []
nbTiles = 0
for i in range(kx):
    for j in range(ky):
        for k in range(kz):
            x = min(off_x*i, Mx - tile)
            y = min(off_y*j, My - tile)
            z = min(off_z*k, Mz - tile)
            x = np.int(x)
            y = np.int(y)
            z = np.int(z)
            print(x,y,z)
            data_s = data[x : x + tile, y : y + tile, z : z + tile ]
#             print(data_s.shape)
            arr_data.append(data_s)
            nbTiles += 1
            if (off_z*k+1) > (Mz - tile):
                break
        if (off_y*j+1) > (My - tile):
                break
    if (off_x*i+1) > (Mx - tile):
                break
print("number of tiles: %d " % nbTiles)
arr_data = np.array(arr_data)            

60 60 60
0 0 0
0 0 60
0 0 120
0 0 170
0 60 0
0 60 60
0 60 120
0 60 170
0 120 0
0 120 60
0 120 120
0 120 170
0 135 0
0 135 60
0 135 120
0 135 170
60 0 0
60 0 60
60 0 120
60 0 170
60 60 0
60 60 60
60 60 120
60 60 170
60 120 0
60 120 60
60 120 120
60 120 170
60 135 0
60 135 60
60 135 120
60 135 170
120 0 0
120 0 60
120 0 120
120 0 170
120 60 0
120 60 60
120 60 120
120 60 170
120 120 0
120 120 60
120 120 120
120 120 170
120 135 0
120 135 60
120 135 120
120 135 170
180 0 0
180 0 60
180 0 120
180 0 170
180 60 0
180 60 60
180 60 120
180 60 170
180 120 0
180 120 60
180 120 120
180 120 170
180 135 0
180 135 60
180 135 120
180 135 170
240 0 0
240 0 60
240 0 120
240 0 170
240 60 0
240 60 60
240 60 120
240 60 170
240 120 0
240 120 60
240 120 120
240 120 170
240 135 0
240 135 60
240 135 120
240 135 170
number of tiles: 80 


In [274]:
arr_data[1].shape
#input shape size required 1,148,148,148,1

(148, 148, 148)

In [275]:
arr_out = []
for i in trange(arr_data.shape[0]):
    img = arr_data[i]
    img = img[np.newaxis,...,np.newaxis]
    out = net.predict("./unet_trained/model 6.cpkt", img)[0][:,:,:,0]
    out_p = np.pad(out,((44,44),(44,44), (44,44)), mode='constant', constant_values=[0])
    arr_out.append(out_p)


  0%|          | 0/80 [00:00<?, ?it/s][A

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt



  1%|▏         | 1/80 [00:02<03:17,  2.50s/it][A

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt



  2%|▎         | 2/80 [00:04<03:10,  2.44s/it][A


INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


  4%|▍         | 3/80 [00:07<03:05,  2.41s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


  5%|▌         | 4/80 [00:09<03:02,  2.40s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


  6%|▋         | 5/80 [00:11<02:59,  2.39s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


  8%|▊         | 6/80 [00:14<02:57,  2.40s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


  9%|▉         | 7/80 [00:16<02:55,  2.40s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 10%|█         | 8/80 [00:19<02:53,  2.42s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 11%|█▏        | 9/80 [00:21<02:56,  2.48s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 12%|█▎        | 10/80 [00:24<02:52,  2.47s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 14%|█▍        | 11/80 [00:26<02:50,  2.47s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 15%|█▌        | 12/80 [00:29<02:49,  2.49s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 16%|█▋        | 13/80 [00:31<02:49,  2.52s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 18%|█▊        | 14/80 [00:34<02:47,  2.54s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 19%|█▉        | 15/80 [00:36<02:44,  2.54s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 20%|██        | 16/80 [00:39<02:42,  2.54s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 21%|██▏       | 17/80 [00:42<02:40,  2.54s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 22%|██▎       | 18/80 [00:45<02:48,  2.71s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 24%|██▍       | 19/80 [00:47<02:46,  2.73s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 25%|██▌       | 20/80 [00:50<02:44,  2.74s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 26%|██▋       | 21/80 [00:53<02:42,  2.75s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 28%|██▊       | 22/80 [00:56<02:42,  2.79s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 29%|██▉       | 23/80 [00:59<02:39,  2.80s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 30%|███       | 24/80 [01:02<02:37,  2.82s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 31%|███▏      | 25/80 [01:04<02:36,  2.84s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 32%|███▎      | 26/80 [01:07<02:32,  2.83s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 34%|███▍      | 27/80 [01:10<02:31,  2.86s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 35%|███▌      | 28/80 [01:13<02:31,  2.91s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 36%|███▋      | 29/80 [01:16<02:29,  2.92s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 38%|███▊      | 30/80 [01:19<02:26,  2.93s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 39%|███▉      | 31/80 [01:22<02:23,  2.94s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 40%|████      | 32/80 [01:25<02:22,  2.96s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 41%|████▏     | 33/80 [01:28<02:20,  2.98s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 42%|████▎     | 34/80 [01:31<02:15,  2.95s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 44%|████▍     | 35/80 [01:34<02:13,  2.96s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 45%|████▌     | 36/80 [01:37<02:11,  3.00s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 46%|████▋     | 37/80 [01:40<02:09,  3.02s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 48%|████▊     | 38/80 [01:43<02:06,  3.02s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 49%|████▉     | 39/80 [01:46<02:05,  3.06s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 50%|█████     | 40/80 [01:49<02:02,  3.05s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 51%|█████▏    | 41/80 [01:52<01:58,  3.04s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 52%|█████▎    | 42/80 [01:55<01:55,  3.05s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 54%|█████▍    | 43/80 [01:58<01:53,  3.06s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 55%|█████▌    | 44/80 [02:02<01:50,  3.07s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 56%|█████▋    | 45/80 [02:05<01:48,  3.09s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 57%|█████▊    | 46/80 [02:08<01:45,  3.09s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 59%|█████▉    | 47/80 [02:11<01:41,  3.07s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 60%|██████    | 48/80 [02:14<01:37,  3.05s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 61%|██████▏   | 49/80 [02:17<01:35,  3.08s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 62%|██████▎   | 50/80 [02:20<01:33,  3.11s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 64%|██████▍   | 51/80 [02:23<01:30,  3.13s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 65%|██████▌   | 52/80 [02:26<01:27,  3.12s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 66%|██████▋   | 53/80 [02:30<01:24,  3.13s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 68%|██████▊   | 54/80 [02:33<01:21,  3.14s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 69%|██████▉   | 55/80 [02:36<01:19,  3.18s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 70%|███████   | 56/80 [02:39<01:16,  3.20s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 71%|███████▏  | 57/80 [02:43<01:13,  3.21s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 72%|███████▎  | 58/80 [02:46<01:10,  3.19s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 74%|███████▍  | 59/80 [02:49<01:07,  3.20s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 75%|███████▌  | 60/80 [02:52<01:04,  3.20s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 76%|███████▋  | 61/80 [02:55<01:01,  3.22s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 78%|███████▊  | 62/80 [02:59<00:58,  3.24s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 79%|███████▉  | 63/80 [03:02<00:55,  3.27s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 80%|████████  | 64/80 [03:05<00:52,  3.30s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 81%|████████▏ | 65/80 [03:09<00:49,  3.31s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 82%|████████▎ | 66/80 [03:12<00:46,  3.32s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 84%|████████▍ | 67/80 [03:15<00:43,  3.32s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 85%|████████▌ | 68/80 [03:19<00:40,  3.34s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 86%|████████▋ | 69/80 [03:22<00:37,  3.37s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 88%|████████▊ | 70/80 [03:26<00:33,  3.40s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 89%|████████▉ | 71/80 [03:29<00:30,  3.41s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 90%|█████████ | 72/80 [03:32<00:27,  3.40s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 91%|█████████▏| 73/80 [03:36<00:23,  3.42s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 92%|█████████▎| 74/80 [03:39<00:20,  3.43s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 94%|█████████▍| 75/80 [03:43<00:17,  3.41s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 95%|█████████▌| 76/80 [03:46<00:13,  3.41s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 96%|█████████▋| 77/80 [03:50<00:10,  3.42s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 98%|█████████▊| 78/80 [03:53<00:07,  3.52s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


 99%|█████████▉| 79/80 [03:57<00:03,  3.50s/it]

INFO:tensorflow:Restoring parameters from ./unet_trained/model 6.cpkt


100%|██████████| 80/80 [04:00<00:00,  3.51s/it]


In [276]:
data = np.zeros((Mx, My, Mz))
l=-1
# for i in range(kx):
#     for j in range(ky):
#         for k in range(kz):
#             l+=1
#             x = tile*i - off_x*i
#             y = tile*j - off_y*j
#             z = tile*k - off_z*k
#             x = np.int(x)
#             y = np.int(y)
#             z = np.int(z)
#             data[x : x + tile, y : y + tile, z : z + tile ] += arr_out[l]
            
            
            
            
for i in range(kx):
    for j in range(ky):
        for k in range(kz):
            l+=1
            x = min(off_x*i, Mx - tile)
            y = min(off_y*j, My - tile)
            z = min(off_z*k, Mz - tile)
            x = np.int(x)
            y = np.int(y)
            z = np.int(z)
            data[x : x + tile, y : y + tile, z : z + tile ] += arr_out[l]
            if (off_z*k+1) > (Mz - tile):
                break
        if (off_y*j+1) > (My - tile):
                break
    if (off_x*i+1) > (Mx - tile):
                break
            
            
data = np.array(data)
# data[np.where(data<l//2)]=0
# data[np.where(data>=l//2)]=1
data = data.astype(np.int8)
data=data[44:-44,44:-44,44:-44]
print(np.unique(data, return_counts=True))
print(data.shape)

(array([0, 1, 2, 3, 4], dtype=int8), array([13159265,   234036,    60829,        1,      869]))
(300, 195, 230)


In [195]:
islands = measure.label(data)
K = np.max(islands)
cp =sns.color_palette("Set2", K)
fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111, projection='3d')
for j in range(1,K):
    xs,ys,zs = np.where(islands == j)
    ax.scatter(xs, ys, zs, marker='o',color=cp[j], alpha=0.9, s=2)
plt.show()

<IPython.core.display.Javascript object>

In [277]:
nrrd.write('test70.nrrd', data, options=options)

In [197]:
data.shape

(300, 195, 230)