Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unet 3D #158

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions .idea/tf_unet.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

965 changes: 965 additions & 0 deletions .idea/workspace.xml

Large diffs are not rendered by default.

235 changes: 142 additions & 93 deletions demo/demo_toy_problem.ipynb

Large diffs are not rendered by default.

95 changes: 95 additions & 0 deletions demo/demo_toy_problem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@

# coding: utf-8

# In[1]:


from __future__ import division, print_function
# get_ipython().magic('matplotlib inline')
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
plt.rcParams['image.cmap'] = 'gist_earth'


# In[2]:


from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util


# In[3]:


nx = 572
ny = 572


# In[4]:


generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20)


# In[5]:


x_test, y_test = generator(1)


# In[6]:


fig, ax = plt.subplots(1,2, sharey=True, figsize=(8,4))
ax[0].imshow(x_test[0,...,0], aspect="auto")
ax[1].imshow(y_test[0,...,1], aspect="auto")


# In[7]:


net = unet.Unet(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)


# In[8]:


trainer = unet.Trainer(net, optimizer="momentum", opt_kwargs=dict(momentum=0.2))


# In[9]:


path = trainer.train(generator, "./unet_trained", training_iters=20, epochs=10, display_step=2)


# In[22]:


x_test, y_test = generator(1)

prediction = net.predict("./unet_trained/model.cpkt", x_test)


# In[23]:


fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12,5))
ax[0].imshow(x_test[0,...,0], aspect="auto")
ax[1].imshow(y_test[0,...,1], aspect="auto")
mask = prediction[0,...,1] > 0.9
ax[2].imshow(mask, aspect="auto")
ax[0].set_title("Input")
ax[1].set_title("Ground truth")
ax[2].set_title("Prediction")
fig.tight_layout()
fig.savefig("../docs/toy_problem.png")


# In[ ]:




Binary file added demo/prediction/_init_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_0_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_10_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_11_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_12_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_13_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_14_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_15_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_16_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_1_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_2_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_3_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_4_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_5_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_6_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_7_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_8_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added demo/prediction/epoch_9_0_0.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
489 changes: 489 additions & 0 deletions demo/toy_unet3d.ipynb

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions demo/unet3d_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util

plt.rcParams['image.cmap'] = 'gist_earth'

nx = 256
ny = 256
generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20, depth_3d=8)
x_test, y_test = generator(1)

net = unet.Unet3D(channels=generator.channels, n_class=generator.n_class, layers=3, features_root=16)
trainer = unet.Trainer(net,optimizer="momentum", opt_kwargs=dict(momentum=0.2))
# trainer = unet.Trainer(net,optimizer="adam", opt_kwargs=dict(learning_rate=0.1))

path = trainer.train(generator,"./unet_trained", training_iters=10, epochs=50, display_step=2)
prediction = net.predict("./unet_trained/model.cpkt", x_test)

fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12,5))
ax[0].imshow(x_test[0,0,...,0], aspect="auto")
ax[1].imshow(y_test[0,0,...,1], aspect="auto")
mask = prediction[0,0,...,1] > 0.9
ax[2].imshow(mask, aspect="auto")
ax[0].set_title("Input")
ax[1].set_title("Ground truth")
ax[2].set_title("Prediction")
fig.tight_layout()
fig.savefig("../docs/toy_problem.png")
Binary file not shown.
Binary file modified docs/toy_problem.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
58 changes: 27 additions & 31 deletions tf_unet/image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,97 +24,93 @@
import numpy as np
from tf_unet.image_util import BaseDataProvider


class GrayScaleDataProvider(BaseDataProvider):
channels = 1
n_class = 2

def __init__(self, nx, ny, **kwargs):
super(GrayScaleDataProvider, self).__init__()
self.nx = nx
self.ny = ny
self.kwargs = kwargs
rect = kwargs.get("rectangles", False)
if rect:
self.n_class=3
self.n_class = 3

def _next_data(self):
return create_image_and_label(self.nx, self.ny, **self.kwargs)

class RgbDataProvider(BaseDataProvider):
channels = 3
n_class = 2

def __init__(self, nx, ny, **kwargs):
super(RgbDataProvider, self).__init__()
self.nx = nx
self.ny = ny
self.kwargs = kwargs
rect = kwargs.get("rectangles", False)
if rect:
self.n_class=3
self.n_class = 3


def _next_data(self):
data, label = create_image_and_label(self.nx, self.ny, **self.kwargs)
return to_rgb(data), label

def create_image_and_label(nx,ny, cnt = 10, r_min = 5, r_max = 50, border = 92, sigma = 20, rectangles=False):



def create_image_and_label(nx, ny, cnt=10, r_min=5, r_max=50, border=92, sigma=20, depth_3d = None, rectangles=False):
image = np.ones((nx, ny, 1))
label = np.zeros((nx, ny, 3), dtype=np.bool)
mask = np.zeros((nx, ny), dtype=np.bool)
for _ in range(cnt):
a = np.random.randint(border, nx-border)
b = np.random.randint(border, ny-border)
a = np.random.randint(border, nx - border)
b = np.random.randint(border, ny - border)
r = np.random.randint(r_min, r_max)
h = np.random.randint(1,255)
h = np.random.randint(1, 255)

y,x = np.ogrid[-a:nx-a, -b:ny-b]
m = x*x + y*y <= r*r
y, x = np.ogrid[-a:nx - a, -b:ny - b]
m = x * x + y * y <= r * r
mask = np.logical_or(mask, m)

image[m] = h

label[mask, 1] = 1

if rectangles:
mask = np.zeros((nx, ny), dtype=np.bool)
for _ in range(cnt//2):
for _ in range(cnt // 2):
a = np.random.randint(nx)
b = np.random.randint(ny)
r = np.random.randint(r_min, r_max)
h = np.random.randint(1,255)
r = np.random.randint(r_min, r_max)
h = np.random.randint(1, 255)

m = np.zeros((nx, ny), dtype=np.bool)
m[a:a+r, b:b+r] = True
m[a:a + r, b:b + r] = True
mask = np.logical_or(mask, m)
image[m] = h

label[mask, 2] = 1
label[..., 0] = ~(np.logical_or(label[...,1], label[...,2]))

label[..., 0] = ~(np.logical_or(label[..., 1], label[..., 2]))

image += np.random.normal(scale=sigma, size=image.shape)
image -= np.amin(image)
image /= np.amax(image)

if rectangles:
return image, label
else:
return image, label[..., 1]




def to_rgb(img):
img = img.reshape(img.shape[0], img.shape[1])
img[np.isnan(img)] = 0
img -= np.amin(img)
img /= np.amax(img)
blue = np.clip(4*(0.75-img), 0, 1)
red = np.clip(4*(img-0.25), 0, 1)
green= np.clip(44*np.fabs(img-0.5)-1., 0, 1)
blue = np.clip(4 * (0.75 - img), 0, 1)
red = np.clip(4 * (img - 0.25), 0, 1)
green = np.clip(44 * np.fabs(img - 0.5) - 1., 0, 1)
rgb = np.stack((red, green, blue), axis=2)
return rgb

59 changes: 45 additions & 14 deletions tf_unet/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,52 @@ def _post_process(self, data, labels):
return data, labels

def __call__(self, n):
train_data, labels = self._load_data_and_label()
nx = train_data.shape[1]
ny = train_data.shape[2]

X = np.zeros((n, nx, ny, self.channels))
Y = np.zeros((n, nx, ny, self.n_class))

X[0] = train_data
Y[0] = labels
for i in range(1, n):
depth = self.kwargs.get('depth_3d',None)
if depth is None:
train_data, labels = self._load_data_and_label()
X[i] = train_data
Y[i] = labels

return X, Y
nx = train_data.shape[1]
ny = train_data.shape[2]

X = np.zeros((n, nx, ny, self.channels))
Y = np.zeros((n, nx, ny, self.n_class))

X[0] = train_data
Y[0] = labels
for i in range(1, n):
train_data, labels = self._load_data_and_label()
X[i] = train_data
Y[i] = labels

return X, Y
else:
# get a stack of train_data and labels
train_data, labels = self._load_data_and_label()
nx = train_data.shape[1]
ny = labels.shape[2]
stack_x = np.empty((depth,nx,ny,self.channels)) # temporary placeholder for image stack
stack_y = np.empty((depth,nx,ny,self.n_class))
stack_x[0]=train_data
stack_y[0]=labels
for i in range(1,depth):
train_data, labels = self._load_data_and_label() # get one image and label everytime
stack_x[i]=train_data
stack_y[i]=labels

X = np.zeros((n, depth, nx, ny, self.channels))
Y = np.zeros((n, depth, nx, ny, self.n_class))
X[0] = stack_x
Y[0] = stack_y
# empty placeholder
stack_x = np.empty((depth,nx,ny,self.channels))
stack_y = np.empty((depth,nx,ny,self.n_class))
for j in range(1,n):
for i in range(depth):
train_data, labels = self._load_data_and_label() # get one image and label everytime
stack_x[i] = train_data
stack_y[i] = labels
X[j] = stack_x
Y[j] = stack_y
return X, Y

class SimpleDataProvider(BaseDataProvider):
"""
Expand Down