In [104]:
import numpy as np

import torch
import torch.nn as nn
import torch.functional as F

class Unet(nn.Module):
    """
    U-net for use in multiplanar u-net.

    All layers are created once in ``__init__`` and registered as submodules.
    Their weights are therefore:

    * exposed through ``parameters()`` / ``state_dict()``;
    * moved by ``.to(device)``;
    * switched by ``.train()`` / ``.eval()``;
    * kept fixed between forward passes.

    Architecture, with ``f_i = init_filters * 2**i`` and ``c(f) = int(f * cf)``:

    * contracting path: ``depth`` levels of
      [conv3x3 - ReLU - conv3x3 - ReLU - BatchNorm] with ``c(f_i)`` channels,
      each followed by 2x2 max pooling. The pre-pool activations are kept as
      skip connections.
    * bottleneck: one [conv3x3 - ReLU - conv3x3 - ReLU - BatchNorm] block with
      ``c(f_depth)`` channels.
    * expansive path: ``depth`` levels of
      [upsample x2 - conv2x2 - ReLU - BatchNorm]. Each is concatenated with the
      matching skip connection and then passed through a
      [conv3x3 - ReLU - conv3x3 - ReLU - BatchNorm] block.
    * head: 1x1 convolution to ``n_classes`` channels, followed by a softmax
      over the channel dimension. ``forward`` returns probabilities, not
      logits.

    Args:
        n_classes: number of output channels (classes).
        img_rows, img_cols: stored on the instance; not used to build layers.
        n_channels: number of channels of the input tensor.
        depth: number of down-/up-sampling levels. Input height and width must
            be divisible by ``2**depth``; otherwise max pooling floors the size
            and the skip-connection concatenation fails.
        cf: multiplier applied to every filter count (result truncated to int).
    """
    def __init__(self,
                 n_classes,
                 img_rows=None,
                 img_cols=None,
                 n_channels=1,
                 depth=4,
                 cf=np.sqrt(2)):
        super(Unet, self).__init__()
        self.n_classes = n_classes
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.n_channels = n_channels
        self.depth = depth
        self.cf = cf

        # Initial number of filters which will double/half after every down/up layer
        self.init_filters = 64

        # Contracting path: one double-conv block per level; filters double per level.
        self.contracting_blocks = nn.ModuleList()
        in_channels = n_channels
        filters = self.init_filters
        for _ in range(depth):
            out_channels = int(filters * cf)
            self.contracting_blocks.append(self._double_conv(in_channels, out_channels))
            in_channels = out_channels
            filters *= 2
        # MaxPool2d has no parameters, so one instance can serve every level.
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Bottleneck between the two paths (filters == init_filters * 2**depth here).
        bottleneck_channels = int(filters * cf)
        self.encoder = self._double_conv(in_channels, bottleneck_channels)

        # Expansive path: an upsampling block and a double-conv block per level.
        self.up_blocks = nn.ModuleList()
        self.expansive_blocks = nn.ModuleList()
        in_channels = bottleneck_channels
        filters //= 2
        for _ in range(depth):
            out_channels = int(filters * cf)
            self.up_blocks.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2,
                    padding="same",
                ),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
            ))
            # The skip connection at this level also has out_channels channels,
            # so the concatenated input has twice as many.
            self.expansive_blocks.append(self._double_conv(2 * out_channels, out_channels))
            in_channels = out_channels
            filters //= 2

        # 1x1 classification head followed by a softmax over the class channels.
        self.classifier = nn.Conv2d(
            in_channels=in_channels,
            out_channels=n_classes,
            kernel_size=1,
        )
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build a [conv3x3 - ReLU - conv3x3 - ReLU - BatchNorm] block."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding="same",
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding="same",
            ),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def _run_contractive_path(self, _x):
        """Run the contracting path.

        Returns:
            A tuple ``(x, cross_connections, filters)``:

            * ``x``: pooled output of the deepest level.
            * ``cross_connections``: pre-pool activation of every level,
              shallowest first.
            * ``filters``: ``init_filters * 2**depth``, the filter count of the
              bottleneck.
        """
        x = _x
        cross_connections = []
        for block in self.contracting_blocks:
            skip = block(x)
            cross_connections.append(skip)
            x = self.pool(skip)
        filters = self.init_filters * 2 ** self.depth
        return x, cross_connections, filters

    def _run_encoder(self, _x, filters):
        """Run the bottleneck block.

        ``filters`` is kept for backward compatibility only; channel counts are
        fixed at construction time.
        """
        return self.encoder(_x)

    def _run_expansive_path(self, _x, filters, cross_connections):
        """Run the expansive path.

        Args:
            _x: bottleneck output.
            filters: kept for backward compatibility only (unused).
            cross_connections: skip connections ordered deepest first.
                ``cross_connections[i]`` is merged at expansive level ``i``.
        """
        x = _x
        for i, (up, block) in enumerate(zip(self.up_blocks, self.expansive_blocks)):
            merged = torch.cat([cross_connections[i], up(x)], dim=1)
            x = block(merged)
        return x

    def forward(self, x):
        """Return per-pixel class probabilities of shape (N, n_classes, H, W)."""
        x, cross_connections, filters = self._run_contractive_path(x)
        x = self._run_encoder(x, filters)
        # The deepest skip connection is consumed first on the way up.
        x = self._run_expansive_path(x, filters, cross_connections[::-1])
        x = self.classifier(x)
        return self.softmax(x)


if __name__ == "__main__":
    # Smoke run: push one all-ones, single-channel 256x256 image through a 4-class net.
    unet = Unet(n_classes=4)
    x = torch.ones(1, 1, 256, 256)
    unet(x)

In [105]:
unet = Unet(n_classes=4)

In [106]:
x = torch.ones(1, 1, 256, 256)

In [107]:
# Call the module itself rather than .forward() so any registered hooks run.
out = unet(x)

In [108]:
out.size()

torch.Size([1, 4, 256, 256])