In [1]:
from functools import partial

import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam

from dpipe.model_core.tnet import TNet3d
from dpipe.torch.model import TorchModel, to_var
from dpipe.torch.utils import NormalizedSoftmaxCrossEntropy, softmax_cross_entropy

In [2]:
batch_size = 8

n_chans_in = 4
n_chans_out = 3

spatial_size = 3 * [200]

x = np.float32(np.random.rand(batch_size, n_chans_in, *spatial_size))
y = np.uint8(np.random.randint(0, n_chans_out-1, size=(batch_size, *spatial_size)))

In [3]:
x.dtype, y.dtype

(dtype('float32'), dtype('uint8'))

In [4]:
y_torch = to_var(y, cuda=True)

In [5]:
y_torch.data.type()

'torch.cuda.ByteTensor'

In [None]:
%%time

y_torch.long().long();

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 416 µs


In [6]:
from abc import ABCMeta

import torch
import torch.nn as nn
from torch.autograd import Variable


def compress_to_2d(x):
    return x.view(*x.size()[:-2], -1)


def softmax_cross_entropy(logits, target, weight=None, reduce=True):
    """Softmax cross entropy loss for Nd data."""
    # target = target.long()
    if target.dim() > 4:
        logits = compress_to_2d(logits)
        target = compress_to_2d(target)

    return nn.functional.cross_entropy(compress_to_2d(logits), compress_to_2d(target), weight=weight, reduce=reduce)


class Eye(nn.Module, metaclass=ABCMeta):
    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.register_buffer('eye', Variable(torch.eye(n_classes)))


class NormalizedSoftmaxCrossEntropy(Eye):
    def forward(self, logits, target):
        # target = target.long()
        flat_target = target.view(-1)
        count = self.eye.index_select(0, flat_target).sum(0)
        weight = flat_target.size()[0] / count

        return softmax_cross_entropy(logits, target, weight=weight)

In [6]:
structure = [
    [[], [], [8]],
    [[8, 8], [16, 16, 32, 32], [64]],
    [[16, 16, 32, 32, 64, 64]],
]

stride = 3

model_core = TNet3d(n_chans_in, n_chans_out, nn.functional.relu, structure, stride)

In [7]:
logits2ce = softmax_cross_entropy
logits2nce = NormalizedSoftmaxCrossEntropy(n_chans_out)

logits2pred = partial(nn.functional.softmax, dim=1)

model_ce = TorchModel(model_core, logits2pred=logits2pred, logits2loss=logits2ce, optimize=Adam)
model_nce = TorchModel(model_core, logits2pred=logits2pred, logits2loss=logits2nce, optimize=Adam)

In [8]:
print(x.shape, y.shape)

(8, 4, 200, 200, 200) (8, 200, 200, 200)


In [9]:
%%timeit

model_ce.do_inf_step(x);

1 loop, best of 3: 1.33 s per loop


In [10]:
%%timeit

loss, y_pred = model_ce.do_val_step(x, y)

RuntimeError: Expected object of type Variable[torch.cuda.LongTensor] but found type Variable[torch.cuda.ByteTensor] for argument #1 'target'

In [None]:
%%timeit

loss, y_pred = model_nce.do_val_step(x, y)

In [11]:
y_pred, loss = model_nce.do_val_step(x, y)
loss

1.1036024

In [None]:
def dice_loss_from_proba(input, target, eps=1e-5):
    i = input.view(-1)
    o = target.view(-1)
    return 1 - ((2. * (i * o).sum() + eps) / (i.sum() + o.sum() + eps))


def dice_loss_logits(logits, target):
    return dice_loss_from_proba(nn.functional.sigmoid(logits), target)


class SegmDiceLoss(torch.nn.Module):
    def __init__(self, segm2msegm, eps=1e-5):
        self.n_chans = len(segm2msegm)
        self.segm2msegm = torch.from_numpy(segm2msegm)
        self.eps = eps
        
    def forward(self, logits, target):
        msegm_true = self.segm2msegm.index_select(0, target.view(-1))
        msegm_pred = nn.functional.linear(swap_channels(input).view(-1, self.n_chans), self.segm2msegm)
        return dice_loss_from_proba(msegm_pred, msegm_true)