
Merge pull request #132 from DerThorsten/master
doc update
DerThorsten committed Aug 10, 2018
2 parents 1df62d3 + 5d55397 commit 970e027
Showing 12 changed files with 223 additions and 67 deletions.
9 changes: 0 additions & 9 deletions _requirements_readthedocs.txt

This file was deleted.

3 changes: 2 additions & 1 deletion docs/conf.py
@@ -84,7 +84,8 @@ def __getattr__(cls, name):
     'sphinx.ext.graphviz',
     'sphinx_gallery.gen_gallery',
     'sphinxcontrib.bibtex',
-    'sphinx.ext.napoleon'
+    'sphinx.ext.napoleon',
+    'sphinxcontrib.inlinesyntaxhighlight'
 ]


2 changes: 1 addition & 1 deletion docs/environment.yml
@@ -18,4 +18,4 @@ dependencies:
 - sphinx-gallery
 - sphinxcontrib-napoleon
 - sphinxcontrib-bibtex
-
+- sphinxcontrib-inlinesyntaxhighlight
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -4,7 +4,7 @@ Welcome to inferno's documentation!
 Contents:

 .. toctree::
-   :maxdepth: 4
+   :maxdepth: 1

    readme
    installation
169 changes: 169 additions & 0 deletions examples/plot_unet_tutorial.py
@@ -0,0 +1,169 @@
"""
UNet Tutorial
================================
A short tutorial on using
the UNet framework in inferno
"""

##############################################################################
# Preface
# --------------
# We start with some unspectacular multi-purpose imports needed for this example
import pylab
import torch
import numpy

##############################################################################
# should CUDA be used
USE_CUDA = True


##############################################################################
# Dataset
# --------------
# For simplicity we will use a toy dataset where we need to perform
# a binary segmentation task.
from inferno.io.box.binary_blobs import get_binary_blob_loaders

# convert labels from long to float as needed by
# binary cross entropy loss
label_transform = lambda x: torch.from_numpy(x).float()

train_loader, test_loader, validate_loader = get_binary_blob_loaders(
    size=8,                   # <-- how many images per {train, test, validate}
    train_batch_size=2,
    length=256,               # <-- spatial size of the images
    gaussian_noise_sigma=1.5, # <-- how noisy the images are
    train_label_transform=label_transform,
    validate_label_transform=label_transform
)
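
##############################################################################
# A quick peek at what the loaders yield -- a sketch added for illustration,
# assuming the settings above (train_batch_size=2, length=256):
image, target = next(iter(train_loader))
print(image.size())   # e.g. torch.Size([2, 1, 256, 256])  (batch, channel, H, W)
print(target.size())  # e.g. torch.Size([2, 256, 256])     (batch, H, W)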

image_channels = 1 # <-- number of channels of the image
pred_channels = 1 # <-- number of channels needed for the prediction

##############################################################################
# Visualize Dataset
# ~~~~~~~~~~~~~~~~~~~~~~
fig = pylab.figure()

for i, (image, target) in enumerate(train_loader):
    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(image[0, 0, ...])
    ax.set_title('raw data')
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(target[0, ...])
    ax.set_title('ground truth')
    break
fig.tight_layout()
pylab.show()


##############################################################################
# Simple UNet
# ----------------------------
# We start with a very simple, predefined
# res-block UNet. By default, this UNet uses ReLUs (in conjunction with batchnorm)
# as nonlinearities.
# With :code:`activated=False` we make sure that the last layer
# is not activated, since we chain the UNet with a sigmoid
# activation function.
from inferno.extensions.layers.unet import ResBlockUNet
from inferno.extensions.layers import RemoveSingletonDimension

model = torch.nn.Sequential(
    ResBlockUNet(dim=2, in_channels=image_channels, out_channels=pred_channels, activated=False),
    RemoveSingletonDimension(dim=1),
    torch.nn.Sigmoid()
)
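
##############################################################################
# A note added for clarity: the UNet itself outputs a tensor of shape
# (batch, 1, H, W). :code:`RemoveSingletonDimension(dim=1)` drops the channel
# axis so that the sigmoid output matches the (batch, H, W) float targets
# expected by :code:`BCELoss`.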

##############################################################################
# While the model above will work in principle, it has some drawbacks.
# Within the UNet, the number of features is increased by a multiplicative
# factor while going down, the so-called gain. The default value for the gain is 2.
# Since we start with only a single channel, we could either increase the gain,
# or use some convolutions to increase the number of channels
# before the UNet.
from inferno.extensions.layers import ConvReLU2D
model_a = torch.nn.Sequential(
    ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
    ResBlockUNet(dim=2, in_channels=5, out_channels=pred_channels, activated=False),
    RemoveSingletonDimension(dim=1),
    torch.nn.Sigmoid()
)
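
##############################################################################
# As a quick sanity check -- a sketch added for illustration, not part of the
# original tutorial -- the chained model should map a batch of single-channel
# images to per-pixel probabilities of the same spatial size:
with torch.no_grad():
    out = model_a(torch.rand(1, image_channels, 64, 64))
print(out.size())  # expected: torch.Size([1, 64, 64]), values in (0, 1)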


##############################################################################
# Training
# ----------------------------
# To train the UNet, we use inferno's Trainer class.
# Since we train more models later in this example, we encapsulate
# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for
# an example dedicated to the trainer itself).
from inferno.trainers import Trainer
from inferno.utils.python_utils import ensure_dir

def train_model(model, loaders, **kwargs):

    trainer = Trainer(model)
    trainer.build_criterion('BCELoss')
    trainer.build_optimizer('Adam')
    trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))
    trainer.save_every((kwargs.get('save_every', 10), 'epochs'))
    trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dir')))
    trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 20))

    # bind the loaders
    trainer.bind_loader('train', loaders[0])
    trainer.bind_loader('validate', loaders[1])

    if USE_CUDA:
        trainer.cuda()

    # do the training
    trainer.fit()

    return trainer


trainer = train_model(model=model_a, loaders=[train_loader, validate_loader], save_dir='model_a')



##############################################################################
# Prediction
# ----------------------------
# The trainer holds the trained model, which we can use for prediction.
# We use :code:`unwrap` to convert the results to numpy arrays.
from inferno.utils.torch_utils import unwrap

trainer.eval_mode()


for image, target in test_loader:

    # transfer the image to the GPU if used
    image = image.cuda() if USE_CUDA else image

    # get batch size from image
    batch_size = image.size()[0]

    prediction = trainer.apply_model(image)

    image = unwrap(image, as_numpy=True, to_cpu=True)
    prediction = unwrap(prediction, as_numpy=True, to_cpu=True)

    fig = pylab.figure()

    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(image[0, 0, ...])
    ax.set_title('raw data')

    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(target[0, ...])
    ax.set_title('ground truth')

    ax = fig.add_subplot(1, 3, 3)
    ax.imshow(prediction[0, ...])
    ax.set_title('prediction')

    fig.tight_layout()
    pylab.show()
3 changes: 2 additions & 1 deletion inferno/extensions/layers/convolutional.py
@@ -15,7 +15,8 @@
            'Conv2D', 'Conv3D',
            'BNReLUConv2D', 'BNReLUConv3D',
            'BNReLUDepthwiseConv2D',
-           'ConvSELU2D', 'ConvSELU3D']
+           'ConvSELU2D', 'ConvSELU3D',
+           'ConvReLU2D', 'ConvReLU3D']
 _all = __all__

class ConvActivation(nn.Module):
27 changes: 26 additions & 1 deletion inferno/extensions/layers/reshape.py
@@ -10,7 +10,7 @@
            'Concatenate', 'Cat',
            'ResizeAndConcatenate', 'PoolCat',
            'GlobalMeanPooling', 'GlobalMaxPooling',
-           'Sum', 'SplitChannels']
+           'Sum', 'SplitChannels', 'Squeeze', 'RemoveSingletonDimension']
 _all = __all__

class View(nn.Module):
@@ -206,3 +206,28 @@ def forward(self, input):
        split_0 = input[:, 0:split_location, ...]
        split_1 = input[:, split_location:, ...]
        return split_0, split_1



class Squeeze(nn.Module):
    def __init__(self):
        super(Squeeze, self).__init__()

    def forward(self, x):
        return x.squeeze()
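
# Note (an editorial aside, not part of this commit): Tensor.squeeze() removes
# *every* singleton axis, including a batch axis of size 1, so
# RemoveSingletonDimension below is the safer choice when exactly one known
# axis should be dropped.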

class RemoveSingletonDimension(nn.Module):
    def __init__(self, dim=1):
        super(RemoveSingletonDimension, self).__init__()
        self.dim = dim

    def forward(self, x):
        size = list(x.size())
        if size[self.dim] != 1:
            raise RuntimeError("RemoveSingletonDimension expects a singleton at dim %d" % self.dim)

        # keep every axis in full, but index the singleton axis with 0,
        # which removes it from the result
        slicing = [slice(0, s) for s in size]
        slicing[self.dim] = 0

        return x[tuple(slicing)]
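
# A hypothetical usage sketch (added for illustration, not part of this
# commit):
#
#     x = torch.rand(4, 1, 64, 64)
#     RemoveSingletonDimension(dim=1)(x).size()  # -> torch.Size([4, 64, 64])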
23 changes: 18 additions & 5 deletions inferno/io/box/binary_blobs.py
@@ -87,7 +87,7 @@ def __getitem__(self, index):

         image -= image.mean()
         image /= image.std()
-
+        label = label.astype('long')
         try:
             # Apply transforms
             if self.image_transform is not None:
@@ -102,26 +102,39 @@ def __getitem__(self, index):
            raise

         image = image[None,...]
-        return image, label.astype('long')
+        return image, label

     def __len__(self):
         return self.size


 def get_binary_blob_loaders(train_batch_size=1, test_batch_size=1,
                             num_workers=1,
+                            train_image_transform=None,
+                            train_label_transform=None,
+                            train_joint_transform=None,
+                            validate_image_transform=None,
+                            validate_label_transform=None,
+                            validate_joint_transform=None,
+                            test_image_transform=None,
+                            test_label_transform=None,
+                            test_joint_transform=None,
                             **kwargs):

-    trainset = BinaryBlobs(split='train', **kwargs)
-    testset = BinaryBlobs(split='test', **kwargs)
-    validset = BinaryBlobs(split='validate', **kwargs)
+    trainset = BinaryBlobs(split='train', image_transform=train_image_transform,
+                           label_transform=train_label_transform, joint_transform=train_joint_transform, **kwargs)
+    testset = BinaryBlobs(split='test', image_transform=test_image_transform,
+                          label_transform=test_label_transform, joint_transform=test_joint_transform, **kwargs)
+    validset = BinaryBlobs(split='validate', image_transform=validate_image_transform,
+                           label_transform=validate_label_transform, joint_transform=validate_joint_transform, **kwargs)

     trainloader = data.DataLoader(trainset, batch_size=train_batch_size,
                                   num_workers=num_workers)

     testloader = data.DataLoader(testset, batch_size=test_batch_size,
                                  num_workers=num_workers)

     validloader = data.DataLoader(validset, batch_size=test_batch_size,
                                   num_workers=num_workers)

4 changes: 2 additions & 2 deletions inferno/trainers/__init__.py
@@ -1,4 +1,4 @@
 from . import basic
 from . import callbacks
-
-__all__ = ['basic','callbacks-']
+from .basic import Trainer
+__all__ = ['basic', 'callbacks', 'Trainer']
3 changes: 2 additions & 1 deletion requirements_dev.txt
@@ -14,4 +14,5 @@ scipy>=0.13.0
 h5py
 scikit-image
 sphinx-gallery
-sphinxcontrib-napoleon
+sphinxcontrib-napoleon
+sphinxcontrib-inlinesyntaxhighlight
21 changes: 0 additions & 21 deletions setup.cfg

This file was deleted.

24 changes: 0 additions & 24 deletions tox.ini

This file was deleted.
