Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #93 from inferno-pytorch/hyper-dev
Browse files Browse the repository at this point in the history
Add BNRelu3D layers, Random Transpostition trafo and more segmentation metrics
  • Loading branch information
nasimrahaman committed Nov 24, 2017
2 parents aa26432 + 2139662 commit becd717
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 3 deletions.
24 changes: 22 additions & 2 deletions inferno/extensions/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'StridedConvELU2D', 'StridedConvELU3D',
'DilatedConvELU2D', 'DilatedConvELU3D',
'Conv2D', 'Conv3D',
'BNReLUConv2D',
'BNReLUConv2D', 'BNReLUConv3D',
'BNReLUDepthwiseConv2D',
'ConvSELU2D', 'ConvSELU3D']

Expand Down Expand Up @@ -287,6 +287,26 @@ def forward(self, input):
return conved


class BNReLUConv3D(ConvActivation):
"""
3D BN-ReLU-Conv layer with 'SAME' padding and He weight initialization.
"""
def __init__(self, in_channels, out_channels, kernel_size):
super(BNReLUConv3D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dim=3,
activation=nn.ReLU(inplace=True),
initialization=KaimingNormalWeightsZeroBias(0))
self.batchnorm = nn.BatchNorm3d(in_channels)

def forward(self, input):
normed = self.batchnorm(input)
activated = self.activation(normed)
conved = self.conv(activated)
return conved


class BNReLUDepthwiseConv2D(ConvActivation):
"""
2D BN-ReLU-Conv layer with 'SAME' padding, He weight initialization and depthwise convolution.
Expand Down Expand Up @@ -343,4 +363,4 @@ def __init__(self, in_channels, out_channels, kernel_size):
kernel_size=kernel_size,
dim=3,
activation=activation,
initialization=SELUWeightsZeroBias())
initialization=SELUWeightsZeroBias())
305 changes: 305 additions & 0 deletions inferno/extensions/metrics/voi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
from .base import Metric

import numpy as np
import scipy.sparse as sparse


class VoiScore(Metric):
"""
Computes a score based on the variation of information according to [1].
References
----------
[1] Meila, M. (2007). Comparing clusterings - an information based
distance. Journal of Multivariate Analysis 98, 873-895.
"""
def forward(self, prediction, target):
assert(len(prediction) == len(target))
segmentation = prediction.cpu().numpy()
target = target.cpu().numpy()
return np.mean([sum(voi(segmentation[i], target[i]))
for i in range(len(prediction))])


# Copied from `cremi-python`
# https://github.com/cremi/cremi_python/blob/master/cremi/evaluation/voi.py

# Evaluation code courtesy of Juan Nunez-Iglesias, taken from
# https://github.com/janelia-flyem/gala/blob/master/gala/evaluate.py

def voi(seg, gt, ignore_reconstruction=[], ignore_groundtruth=[0]):
"""Return the conditional entropies of the variation of information metric. [1]
Let X be a seg, and Y a ground truth labelling. The variation of
information between the two is the sum of two conditional entropies:
VI(X, Y) = H(X|Y) + H(Y|X).
The first one, H(X|Y), is a measure of oversegmentation, the second one,
H(Y|X), a measure of undersegmentation. These measures are referred to as
the variation of information split or merge error, respectively.
Parameters
----------
seg : np.ndarray, int type, arbitrary shape
A candidate segmentation.
gt : np.ndarray, int type, same shape as `seg`
The ground truth segmentation.
ignore_seg, ignore_gt : list of int, optional
Any points having a label in this list are ignored in the evaluation.
By default, only the label 0 in the ground truth will be ignored.
Returns
-------
(split, merge) : float
The variation of information split and merge error, i.e., H(X|Y) and H(Y|X)
References
----------
[1] Meila, M. (2007). Comparing clusterings - an information based
distance. Journal of Multivariate Analysis 98, 873-895.
"""
hyxg, hxgy = split_vi(seg, gt, ignore_reconstruction, ignore_groundtruth)
return hxgy, hyxg


def split_vi(x, y=None, ignore_x=[0], ignore_y=[0]):
"""Return the symmetric conditional entropies associated with the VI.
The variation of information is defined as VI(X,Y) = H(X|Y) + H(Y|X).
If Y is the ground-truth segmentation, then H(Y|X) can be interpreted
as the amount of under-segmentation of Y and H(X|Y) is then the amount
of over-segmentation. In other words, a perfect over-segmentation
will have H(Y|X)=0 and a perfect under-segmentation will have H(X|Y)=0.
If y is None, x is assumed to be a contingency table.
Parameters
----------
x : np.ndarray
Label field (int type) or contingency table (float). `x` is
interpreted as a contingency table (summing to 1.0) if and only if `y`
is not provided.
y : np.ndarray of int, same shape as x, optional
A label field to compare to `x`.
ignore_x, ignore_y : list of int, optional
Any points having a label in this list are ignored in the evaluation.
Ignore 0-labeled points by default.
Returns
-------
sv : np.ndarray of float, shape (2,)
The conditional entropies of Y|X and X|Y.
See Also
--------
vi
"""
_, _, _, hxgy, hygx, _, _ = vi_tables(x, y, ignore_x, ignore_y)
# false merges, false splits
return np.array([hygx.sum(), hxgy.sum()])


def vi_tables(x, y=None, ignore_x=[0], ignore_y=[0]):
"""Return probability tables used for calculating VI.
If y is None, x is assumed to be a contingency table.
Parameters
----------
x, y : np.ndarray
Either x and y are provided as equal-shaped np.ndarray label fields
(int type), or y is not provided and x is a contingency table
(sparse.csc_matrix) that may or may not sum to 1.
ignore_x, ignore_y : list of int, optional
Rows and columns (respectively) to ignore in the contingency table.
These are labels that are not counted when evaluating VI.
Returns
-------
pxy : sparse.csc_matrix of float
The normalized contingency table.
px, py, hxgy, hygx, lpygx, lpxgy : np.ndarray of float
The proportions of each label in `x` and `y` (`px`, `py`), the
per-segment conditional entropies of `x` given `y` and vice-versa, the
per-segment conditional probability p log p.
"""
if y is not None:
pxy = contingency_table(x, y, ignore_x, ignore_y)
else:
cont = x
total = float(cont.sum())
# normalize, since it is an identity op if already done
pxy = cont / total

# Calculate probabilities
px = np.array(pxy.sum(axis=1)).ravel()
py = np.array(pxy.sum(axis=0)).ravel()
# Remove zero rows/cols
nzx = px.nonzero()[0]
nzy = py.nonzero()[0]
nzpx = px[nzx]
nzpy = py[nzy]
nzpxy = pxy[nzx, :][:, nzy]

# Calculate log conditional probabilities and entropies
lpygx = np.zeros(np.shape(px))
lpygx[nzx] = xlogx(divide_rows(nzpxy, nzpx)).sum(axis=1).squeeze() # \sum_x{p_{y|x} \log{p_{y|x}}}
hygx = -(px * lpygx) # \sum_x{p_x H(Y|X=x)} = H(Y|X)

lpxgy = np.zeros(np.shape(py))
lpxgy[nzy] = xlogx(divide_columns(nzpxy, nzpy)).sum(axis=0)
hxgy = -(py * lpxgy)

return [pxy] + list(map(np.asarray, [px, py, hxgy, hygx, lpygx, lpxgy]))


def contingency_table(seg, gt, ignore_seg=[0], ignore_gt=[0], norm=True):
"""Return the contingency table for all regions in matched segmentations.
Parameters
----------
seg : np.ndarray, int type, arbitrary shape
A candidate segmentation.
gt : np.ndarray, int type, same shape as `seg`
The ground truth segmentation.
ignore_seg : list of int, optional
Values to ignore in `seg`. Voxels in `seg` having a value in this list
will not contribute to the contingency table. (default: [0])
ignore_gt : list of int, optional
Values to ignore in `gt`. Voxels in `gt` having a value in this list
will not contribute to the contingency table. (default: [0])
norm : bool, optional
Whether to normalize the table so that it sums to 1.
Returns
-------
cont : scipy.sparse.csc_matrix
A contingency table. `cont[i, j]` will equal the number of voxels
labeled `i` in `seg` and `j` in `gt`. (Or the proportion of such voxels
if `norm=True`.)
"""
segr = seg.ravel()
gtr = gt.ravel()
ignored = np.zeros(segr.shape, np.bool)
data = np.ones(len(gtr))
for i in ignore_seg:
ignored[segr == i] = True
for j in ignore_gt:
ignored[gtr == j] = True
data[ignored] = 0
cont = sparse.coo_matrix((data, (segr, gtr))).tocsc()
if norm:
cont /= float(cont.sum())
return cont


def divide_columns(matrix, row, in_place=False):
"""Divide each column of `matrix` by the corresponding element in `row`.
The result is as follows: out[i, j] = matrix[i, j] / row[j]
Parameters
----------
matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)
The input matrix.
column : a 1D np.ndarray, shape (N,)
The row dividing `matrix`.
in_place : bool (optional, default False)
Do the computation in-place.
Returns
-------
out : same type as `matrix`
The result of the row-wise division.
"""
if in_place:
out = matrix
else:
out = matrix.copy()
if type(out) in [sparse.csc_matrix, sparse.csr_matrix]:
if type(out) == sparse.csc_matrix:
convert_to_csc = True
out = out.tocsr()
else:
convert_to_csc = False
row_repeated = np.take(row, out.indices)
nz = out.data.nonzero()
out.data[nz] /= row_repeated[nz]
if convert_to_csc:
out = out.tocsc()
else:
out /= row[np.newaxis, :]
return out


def divide_rows(matrix, column, in_place=False):
"""Divide each row of `matrix` by the corresponding element in `column`.
The result is as follows: out[i, j] = matrix[i, j] / column[i]
Parameters
----------
matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)
The input matrix.
column : a 1D np.ndarray, shape (M,)
The column dividing `matrix`.
in_place : bool (optional, default False)
Do the computation in-place.
Returns
-------
out : same type as `matrix`
The result of the row-wise division.
"""
if in_place:
out = matrix
else:
out = matrix.copy()
if type(out) in [sparse.csc_matrix, sparse.csr_matrix]:
if type(out) == sparse.csr_matrix:
convert_to_csr = True
out = out.tocsc()
else:
convert_to_csr = False
column_repeated = np.take(column, out.indices)
nz = out.data.nonzero()
out.data[nz] /= column_repeated[nz]
if convert_to_csr:
out = out.tocsr()
else:
out /= column[:, np.newaxis]
return out


def xlogx(x, out=None, in_place=False):
"""Compute x * log_2(x).
We define 0 * log_2(0) = 0
Parameters
----------
x : np.ndarray or scipy.sparse.csc_matrix or csr_matrix
The input array.
out : same type as x (optional)
If provided, use this array/matrix for the result.
in_place : bool (optional, default False)
Operate directly on x.
Returns
-------
y : same type as x
Result of x * log_2(x).
"""
if in_place:
y = x
elif out is None:
y = x.copy()
else:
y = out
if type(y) in [sparse.csc_matrix, sparse.csr_matrix]:
z = y.data
else:
z = y
nz = z.nonzero()
z[nz] *= np.log2(z[nz])
return y
17 changes: 16 additions & 1 deletion inferno/io/transform/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,21 @@ def image_function(self, image):
return np.rot90(image, k=self.get_random_variable('k'))


class RandomTranspose(Transform):
"""Random 2d transpose."""
def __init__(self, **super_kwargs):
super(RandomTranspose, self).__init__(**super_kwargs)

def build_random_variables(self, **kwargs):
np.random.seed()
self.set_random_variable('do_transpose', np.random.uniform() > 0.5)

def image_function(self, image):
if self.get_random_variable('do_transpose'):
image = np.transpose(image)
return image


class RandomFlip(Transform):
"""Random left-right or up-down flips."""
def __init__(self, allow_lr_flips=True, allow_ud_flips=True, **super_kwargs):
Expand Down Expand Up @@ -433,7 +448,7 @@ def image_function(self, image):
th, tw = self.size
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return image[x1:x1+tw, y1:y1+th]
return image[x1:x1 + tw, y1:y1 + th]


class BinaryMorphology(Transform):
Expand Down

0 comments on commit becd717

Please sign in to comment.