This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit 1d2dc58: Merge pull request #133 from DerThorsten/master
improved unet tutorial
DerThorsten committed Aug 10, 2018
2 parents: 970e027 + 436f0df
Showing 4 changed files with 132 additions and 39 deletions.
165 changes: 128 additions & 37 deletions examples/plot_unet_tutorial.py
@@ -9,7 +9,7 @@
 # Preface
 # --------------
 # We start with some unspectacular multi-purpose imports needed for this example
-import pylab
+import matplotlib.pyplot as plt
 import torch
 import numpy

@@ -27,13 +27,15 @@

 # convert labels from long to float as needed by
 # binary cross entropy loss
-label_transform = lambda x : torch.from_numpy(x).float()
+def label_transform(x):
+    return torch.from_numpy(x).float()
+#label_transform = lambda x : torch.from_numpy(x).float()

 train_loader, test_loader, validate_loader = get_binary_blob_loaders(
     size=8,                    # how many images per {train, test, validate}
     train_batch_size=2,
     length=256,                # <= size of the images
-    gaussian_noise_sigma=1.5,  # <= how noisy the images are
+    gaussian_noise_sigma=1.4,  # <= how noisy the images are
     train_label_transform=label_transform,
     validate_label_transform=label_transform
 )
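Note: the float conversion matters because PyTorch's binary cross-entropy
losses expect float targets, while the blob dataset yields long/int64 labels.
A minimal sanity check (the random labels here are illustrative, not part of
the commit):

    import numpy as np
    import torch

    labels = np.random.randint(0, 2, size=(256, 256))  # int64 labels, as the dataset yields
    t = label_transform(labels)
    assert t.dtype == torch.float32                    # float targets, as BCE-style losses need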
@@ -44,7 +46,7 @@
 ##############################################################################
 # Visualize Dataset
 # ~~~~~~~~~~~~~~~~~~~~~~
-fig = pylab.figure()
+fig = plt.figure()

 for i, (image, target) in enumerate(train_loader):
     ax = fig.add_subplot(1, 2, 1)
@@ -55,7 +57,7 @@
     ax.set_title('ground truth')
     break
 fig.tight_layout()
-pylab.show()
+plt.show()


 ##############################################################################
@@ -85,12 +87,16 @@
 from inferno.extensions.layers import ConvReLU2D
 model_a = torch.nn.Sequential(
     ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
-    ResBlockUNet(dim=2, in_channels=5, out_channels=pred_channels, activated=False),
-    RemoveSingletonDimension(dim=1),
-    torch.nn.Sigmoid()
+    ResBlockUNet(dim=2, in_channels=5, out_channels=pred_channels, activated=False,
+                 res_block_kwargs=dict(batchnorm=True, size=2)),
+    RemoveSingletonDimension(dim=1)
+    # torch.nn.Sigmoid()
 )


 ##############################################################################
 # Training
 # ----------------------------
@@ -104,12 +110,12 @@
 def train_model(model, loaders, **kwargs):

     trainer = Trainer(model)
-    trainer.build_criterion('BCELoss')
-    trainer.build_optimizer('Adam')
-    trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))
-    trainer.save_every((kwargs.get('save_every', 10), 'epochs'))
-    trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor')))
-    trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 20))
+    trainer.build_criterion('BCEWithLogitsLoss')
+    trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001))
+    #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))
+    #trainer.save_every((kwargs.get('save_every', 10), 'epochs'))
+    #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor')))
+    trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 200))

     # bind the loaders
     trainer.bind_loader('train', loaders[0])
@@ -124,7 +130,7 @@ def train_model(model, loaders, **kwargs):
     return trainer


-trainer = train_model(model=model_a, loaders=[train_loader, validate_loader], save_dir='model_a')
+trainer = train_model(model=model_a, loaders=[train_loader, validate_loader], save_dir='model_a', lr=0.01)
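Note: the criterion switch belongs together with dropping the final
torch.nn.Sigmoid from model_a above: BCEWithLogitsLoss fuses the sigmoid into
the loss via the log-sum-exp trick, so the model should now output raw logits.
A minimal sketch of the equivalence (shapes are illustrative):

    import torch

    logits = torch.randn(2, 256, 256)                    # raw model outputs
    targets = torch.randint(0, 2, (2, 256, 256)).float()

    loss_old = torch.nn.BCELoss()(torch.sigmoid(logits), targets)  # old pairing
    loss_new = torch.nn.BCEWithLogitsLoss()(logits, targets)       # new pairing
    assert torch.allclose(loss_old, loss_new, atol=1e-5)           # same value, stabler gradients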



@@ -133,37 +139,122 @@
 # ----------------------------
 # The trainer contains the trained model and we can do predictions.
 # We use :code:`unwrap` to convert the results to numpy arrays.
-trainer.eval_mode()
+# Since we want to do many predictions we
+# encapsulate the prediction in a function
 from inferno.utils.torch_utils import unwrap

 def predict(trainer, test_loader, save_dir=None):
+    trainer.eval_mode()
     for image, target in test_loader:

         # transfer image to gpu
         image = image.cuda() if USE_CUDA else image

         # get batch size from image
         batch_size = image.size()[0]

         prediction = trainer.apply_model(image)
+        prediction = torch.nn.functional.sigmoid(prediction)

         image = unwrap(image, as_numpy=True, to_cpu=True)
         prediction = unwrap(prediction, as_numpy=True, to_cpu=True)
+        target = unwrap(target, as_numpy=True, to_cpu=True)

         for b in range(batch_size):

-            fig = pylab.figure()
+            fig = plt.figure()

-            ax = fig.add_subplot(1, 3, 1)
-            ax.imshow(image[0,0,...])
+            ax = fig.add_subplot(2, 2, 1)
+            ax.imshow(image[b,0,...])
             ax.set_title('raw data')

-            ax = fig.add_subplot(1, 3, 2)
-            ax.imshow(target[0,...])
-            ax.set_title('raw data')
+            ax = fig.add_subplot(2, 2, 2)
+            ax.imshow(target[b,...])
+            ax.set_title('ground truth')

-            ax = fig.add_subplot(1, 3, 3)
-            ax.imshow(prediction[0,...])
-            ax.set_title('raw data')
+            ax = fig.add_subplot(2, 2, 4)
+            ax.imshow(prediction[b,...])
+            ax.set_title('prediction')

+            fig.tight_layout()
+            plt.show()

+###################################################
+# do the prediction
+predict(trainer=trainer, test_loader=test_loader)
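Note: prediction now applies the sigmoid explicitly, since the trained model
ends in logits. As an aside, torch.nn.functional.sigmoid was deprecated in
later PyTorch releases; in current code the call would read:

    prediction = torch.sigmoid(trainer.apply_model(image))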




+##############################################################################
+# Custom UNet
+# ----------------------------
+# Often one needs a UNet with custom layers.
+# Here we show how to implement such a customized UNet.
+# To this end we derive from :code:`UNetBase`.
+# For the sake of this example we will create
+# a rather exotic UNet which uses different types
+# of convolutions/non-linearities in the different
+# branches of the UNet.
+from inferno.extensions.layers import UNetBase
+from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D, Conv2D
+
+class MySimple2DUnet(UNetBase):
+    def __init__(self, in_channels, out_channels, depth=3, **kwargs):
+        super(MySimple2DUnet, self).__init__(in_channels=in_channels, out_channels=out_channels,
+                                             dim=2, depth=depth, **kwargs)
+
+    def conv_op_factory(self, in_channels, out_channels, part, index):
+
+        if part == 'down':
+            return torch.nn.Sequential(
+                ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
+                ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
+            )
+        elif part == 'bottom':
+            return torch.nn.Sequential(
+                ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
+                ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
+            )
+        elif part == 'up':
+            # are we in the very last block?
+            if index + 1 == self.depth:
+                return torch.nn.Sequential(
+                    ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
+                    Conv2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
+                )
+            else:
+                return torch.nn.Sequential(
+                    ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
+                    ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
+                )
+        else:
+            raise RuntimeError("unknown part: " + str(part))
+
+    # this function CAN be implemented; if not, max pooling is used by default
+    def downsample_op_factory(self, index):
+        return torch.nn.MaxPool2d(kernel_size=2, stride=2)
+
+    # this function CAN be implemented; if not, upsampling is used by default
+    def upsample_op_factory(self, index):
+        return torch.nn.Upsample(mode='bilinear', align_corners=False, scale_factor=2)

+model_b = torch.nn.Sequential(
+    ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
+    MySimple2DUnet(in_channels=5, out_channels=pred_channels),
+    RemoveSingletonDimension(dim=1)
+)
+
+###################################################
+# do the training (with the same function as before)
+trainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001)
+
+###################################################
+# do the prediction (with the same function as before)
+predict(trainer=trainer, test_loader=test_loader)

-fig.tight_layout()
-pylab.show()
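Note: a quick way to sanity-check a factory-based UNet like MySimple2DUnet
above is a dummy forward pass; the sizes below are illustrative, and the
spatial extent must be divisible by 2**depth so pooling and upsampling invert
cleanly:

    net = MySimple2DUnet(in_channels=5, out_channels=1, depth=3)
    x = torch.rand(1, 5, 64, 64)   # 64 is divisible by 2**3
    print(net(x).shape)            # expected: torch.Size([1, 1, 64, 64])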
2 changes: 1 addition & 1 deletion inferno/extensions/layers/reshape.py
@@ -222,7 +222,7 @@ def __init__(self, dim=1):
     def forward(self, x):
         size = list(x.size())
         if size[self.dim] != 1:
-            raise RuntimeError("RemoveSingletonDimension expects a single channel at dim %d"%d)
+            raise RuntimeError("RemoveSingletonDimension expects a single channel at dim %d, shape=%s"%(self.dim, str(size)))

         slicing = []
         for s in size:
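Note: the old message was itself broken: "%d" % d raised a NameError because
no variable d existed, masking the intended RuntimeError. A small sketch of
the fixed behavior (shapes are illustrative, and the import path assumes the
module layout shown in this diff):

    import torch
    from inferno.extensions.layers.reshape import RemoveSingletonDimension

    layer = RemoveSingletonDimension(dim=1)
    try:
        layer(torch.rand(2, 3, 8, 8))  # dim 1 has size 3, not 1
    except RuntimeError as e:
        print(e)  # now reports the offending dim and the full shape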
3 changes: 2 additions & 1 deletion inferno/extensions/layers/unet.py
@@ -337,7 +337,8 @@ def __init__(self, in_channels, dim, out_channels, unet_kwargs=None,
     def conv_op_factory(self, in_channels, out_channels, part, index):

         # is this the very last convolutional block?
-        very_last = part == 'up' and index + 1 == self.depth
+        very_last = (part == 'up' and index + 1 == self.depth)

         # should the residual block be activated?
         activated = not very_last or self.activated
1 change: 1 addition & 0 deletions inferno/io/box/binary_blobs.py
@@ -87,6 +87,7 @@ def __getitem__(self, index):

         image -= image.mean()
         image /= image.std()
+
         label = label.astype('long')
         try:
             # Apply transforms
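Note: for context, the lines above standardize each image to zero mean and
unit variance before the label cast. An equivalent stand-alone sketch (the
random image is illustrative):

    import numpy as np

    image = (np.random.rand(256, 256) * 255.0).astype('float32')
    image -= image.mean()
    image /= image.std()
    print(image.mean(), image.std())  # approximately 0.0 and 1.0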
