merge #1

Merged
5 commits merged on May 11, 2017
1 change: 1 addition & 0 deletions AUTHORS.rst
@@ -12,6 +12,7 @@ Contributors

* `@FelixGruen <https://github.com/FelixGruen>`_
* `@ameya005 <https://github.com/ameya005>`_
* `@agrafix <https://github.com/agrafix>`_

Citations
---------
25 changes: 25 additions & 0 deletions tf_unet/image_util.py
@@ -100,6 +100,31 @@ def __call__(self, n):

        return X, Y

class SimpleDataProvider(BaseDataProvider):
    """
    A simple data provider for numpy arrays.
    Assumes that data and label are numpy arrays with the dimensions
    data `[n, X, Y, channels]` and label `[n, X, Y, classes]`, where
    `n` is the number of images and `X`, `Y` the size of the image.

    :param data: data numpy array. Shape=[n, X, Y, channels]
    :param label: label numpy array. Shape=[n, X, Y, classes]
    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping
    :param channels: (optional) number of channels, default 1
    :param n_class: (optional) number of classes, default 2

    """

    def __init__(self, data, label, a_min=None, a_max=None, channels=1, n_class=2):
        super(SimpleDataProvider, self).__init__(a_min, a_max)
        self.data = data
        self.label = label
        self.file_count = data.shape[0]
        self.n_class = n_class
        self.channels = channels

    def _next_data(self):
        # draw a random image/label pair from the in-memory arrays
        idx = np.random.choice(self.file_count)
        return self.data[idx], self.label[idx]


class ImageDataProvider(BaseDataProvider):
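For context, a minimal sketch of how the new SimpleDataProvider could be driven with in-memory arrays. The synthetic data, shapes, and thresholding below are illustrative only; they follow the shapes documented in the new docstring, and the batching via BaseDataProvider.__call__ comes from the existing base class, not from this diff:

import numpy as np
from tf_unet.image_util import SimpleDataProvider

# 20 synthetic grayscale 64x64 images with one-hot labels for 2 classes
data = np.random.rand(20, 64, 64, 1)
label = np.zeros((20, 64, 64, 2))
label[..., 1] = data[..., 0] > 0.5   # "foreground" where the pixel is bright
label[..., 0] = 1 - label[..., 1]    # "background" everywhere else

provider = SimpleDataProvider(data, label, channels=1, n_class=2)
x, y = provider(4)  # BaseDataProvider.__call__ draws a random batch of 4 pairs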
6 changes: 5 additions & 1 deletion tf_unet/unet.py
@@ -370,7 +370,7 @@ def _initialize(self, training_iters, output_path, restore):

        return init

    def train(self, data_provider, output_path, training_iters=10, epochs=100, dropout=0.75, display_step=1, restore=False):
    def train(self, data_provider, output_path, training_iters=10, epochs=100, dropout=0.75, display_step=1, restore=False, write_graph=False):
        """
        Launches the training process

@@ -381,6 +381,7 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo
        :param dropout: dropout probability
        :param display_step: number of steps till outputting stats
        :param restore: Flag if a previous model should be restored
        :param write_graph: Flag if the computation graph should be written as a protobuf file to the output path
        """
        save_path = os.path.join(output_path, "model.cpkt")
        if epochs == 0:
@@ -389,6 +390,9 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo
        init = self._initialize(training_iters, output_path, restore)

        with tf.Session() as sess:
            if write_graph:
                # dump the computation graph as a binary protobuf before training starts
                tf.train.write_graph(sess.graph_def, output_path, "graph.pb", False)

            sess.run(init)

            if restore:
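A hedged sketch of how the new write_graph flag might be used end to end, building on the SimpleDataProvider example above. The Unet and Trainer calls follow the repo's existing API, but the concrete parameter values are illustrative, and the re-import step only assumes the TensorFlow 1.x GraphDef API:

import tensorflow as tf
from tf_unet import unet

# `provider` is the SimpleDataProvider instance from the sketch above
net = unet.Unet(channels=1, n_class=2, layers=3, features_root=16)
trainer = unet.Trainer(net)

# write_graph=True additionally dumps the session graph as a binary
# protobuf ("graph.pb") into the output path before training starts
trainer.train(provider, "./unet_trained",
              training_iters=10, epochs=2, write_graph=True)

# the dumped graph.pb can later be parsed and re-imported, e.g. for freezing or serving
graph_def = tf.GraphDef()
with open("./unet_trained/graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")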