In [1]:
!pip install monai

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from typing import Dict, Optional, Tuple, Type, Union

import torch
import torch.nn as nn

In [4]:
from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [None]:
__all__ = ["VNet"]

In [None]:
def get_acti_layer(act: Union[Tuple[str, Dict], str], nchan: int = 0):
    if act == "prelu":
        act = ("prelu", {"num_parameters": nchan})
    act_name, act_args = split_args(act)
    act_type = Act[act_name]
    return act_type(**act_args)


class LUConv(nn.Module):
    def __init__(self, spatial_dims: int, nchan: int, act: Union[Tuple[str, Dict], str], bias: bool = False):
        super().__init__()

        self.act_function = get_acti_layer(act, nchan)
        self.conv_block = Convolution(
            spatial_dims=spatial_dims,
            in_channels=nchan,
            out_channels=nchan,
            kernel_size=5,
            act=None,
            norm=Norm.BATCH,
            bias=bias,
        )

    def forward(self, x):
        out = self.conv_block(x)
        out = self.act_function(out)
        return out


def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: Union[Tuple[str, Dict], str], bias: bool = False):
    layers = []
    for _ in range(depth):
        layers.append(LUConv(spatial_dims, nchan, act, bias))
    return nn.Sequential(*layers)


class InputTransition(nn.Module):
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: Union[Tuple[str, Dict], str],
        bias: bool = False,
    ):
        super().__init__()

        # if out_channels % in_channels != 0:
        #     raise ValueError(
        #         f"out channels should be divisible by in_channels. Got in_channels={in_channels}, out_channels={out_channels}."
        #     )

        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act_function = get_acti_layer(act, out_channels)
        self.conv_block = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=5,
            act=None,
            norm=Norm.BATCH,
            bias=bias,
        )

    def forward(self, x):
        print(x.shape)
        out = self.conv_block(x)
        print("intrans")
        print(out.shape)
        repeat_num = self.out_channels // self.in_channels
        x16 = x.repeat([1, repeat_num, 1, 1, 1][: self.spatial_dims + 2])
        print(x16.shape)
        #out = self.act_function(torch.add(out, x16))
        out = self.act_function(out)
        return out


class DownTransition(nn.Module):
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        nconvs: int,
        act: Union[Tuple[str, Dict], str],
        dropout_prob: Optional[float] = None,
        dropout_dim: int = 3,
        bias: bool = False,
    ):
        super().__init__()

        conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
        norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
        dropout_type: Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]] = Dropout[Dropout.DROPOUT, dropout_dim]

        out_channels = 2 * in_channels
        self.down_conv = conv_type(in_channels, out_channels, kernel_size=2, stride=2, bias=bias)
        self.bn1 = norm_type(out_channels)
        self.act_function1 = get_acti_layer(act, out_channels)
        self.act_function2 = get_acti_layer(act, out_channels)
        self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act, bias)
        self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None

    def forward(self, x):
        down = self.act_function1(self.bn1(self.down_conv(x)))
        if self.dropout is not None:
            out = self.dropout(down)
        else:
            out = down
        out = self.ops(out)
        out = self.act_function2(torch.add(out, down))
        return out

class UpTransition(nn.Module):
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        nconvs: int,
        act: Union[Tuple[str, Dict], str],
        dropout_prob: Optional[float] = None,
        dropout_dim: int = 3,
    ):
        super().__init__()

        conv_trans_type: Type[Union[nn.ConvTranspose2d, nn.ConvTranspose3d]] = Conv[Conv.CONVTRANS, spatial_dims]
        norm_type: Type[Union[nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
        dropout_type: Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]] = Dropout[Dropout.DROPOUT, dropout_dim]

        self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2)
        self.bn1 = norm_type(out_channels // 2)
        self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None
        self.dropout2 = dropout_type(0.5)
        self.act_function1 = get_acti_layer(act, out_channels // 2)
        self.act_function2 = get_acti_layer(act, out_channels)
        self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act)

    def forward(self, x, skipx):
        if self.dropout is not None:
            out = self.dropout(x)
        else:
            out = x
        print("uptransition")
        print(out.shape) #(1,256,9,15,15)
        skipxdo = self.dropout2(skipx)
        print(skipxdo.shape) #(1,128,19,30,30)
        out = self.up_conv(out)
        print(out.shape) #torch.Size([1, 128, 18, 30, 30])
        out = self.bn1(out)
        print(out.shape) #torch.Size([1, 128, 18, 30, 30])
        out = self.act_function1(out)
        print(out.shape) #torch.Size([1, 128, 18, 30, 30])
        #out = torch.reshape(out, (out.shape[0],out.shape[1],out.shape[2]+1, out.shape[3], out.shape[4]))
        if out.shape[1] !=64:
            reshape_out = torch.zeros((out.shape[0],out.shape[1],1, out.shape[3], out.shape[4]))
            out = torch.cat((out, reshape_out), 2)
        print(out.shape)
        xcat = torch.cat((out, skipxdo), 1)
        out = self.ops(xcat)
        out = self.act_function2(torch.add(out, xcat))
        return out

class OutputTransition(nn.Module):
    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: Union[Tuple[str, Dict], str],
        bias: bool = False,
    ):
        super().__init__()

        conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]

        self.act_function1 = get_acti_layer(act, out_channels)
        self.conv_block = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=5,
            act=None,
            norm=Norm.BATCH,
            bias=bias,
        )
        self.conv2 = conv_type(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # convolve 32 down to 4 channels
        out = self.conv_block(x)
        out = self.act_function1(out)
        out = self.conv2(out)
        return out

In [None]:
class VNet(nn.Module):
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 3,
        out_channels: int = 4,
        act: Union[Tuple[str, Dict], str] = ("elu", {"inplace": True}),
        dropout_prob: float = 0.5,
        dropout_dim: int = 3,
        bias: bool = False,
    ):
        super().__init__()

        if spatial_dims not in (2, 3):
            raise AssertionError("spatial_dims can only be 2 or 3.")

        self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias)
        self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias)
        self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias)
        self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob, bias=bias)
        self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob, bias=bias)
        self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob)
        self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob)
        self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act)
        self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act)
        self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias)
    def forward(self, x):
        depth = x.shape[2]
        print(x.shape)
        out16 = self.in_tr(x)
        print(out16.shape)
        out32 = self.down_tr32(out16)
        print(out32.shape)
        out64 = self.down_tr64(out32)
        print(out64.shape)
        out128 = self.down_tr128(out64)
        print(out128.shape)
        out256 = self.down_tr256(out128)
        print(out256.shape)
        x = self.up_tr256(out256, out128)
        x = self.up_tr128(x, out64)
        x = self.up_tr64(x, out32)
        x = self.up_tr32(x, out16)
        x = self.out_tr(x)
        return x

In [None]:
import torch
import torch.nn as nn
batch_size = 10
img_shape = (3, 155, 240, 240)
img_tensor = torch.rand(batch_size, *img_shape)

In [None]:
model = VNet()
outputs = model(img_tensor)

torch.Size([10, 3, 155, 240, 240])
torch.Size([10, 3, 155, 240, 240])


In [None]:
model = VNet()
outputs = model(img_tensor)


torch.Size([1, 3, 155, 240, 240])
torch.Size([1, 3, 155, 240, 240])
intrans
torch.Size([1, 16, 155, 240, 240])
torch.Size([1, 15, 155, 240, 240])
torch.Size([1, 16, 155, 240, 240])
torch.Size([1, 32, 77, 120, 120])
torch.Size([1, 64, 38, 60, 60])
torch.Size([1, 128, 19, 30, 30])
torch.Size([1, 256, 9, 15, 15])
uptransition
torch.Size([1, 256, 9, 15, 15])
torch.Size([1, 128, 19, 30, 30])
torch.Size([1, 128, 18, 30, 30])
torch.Size([1, 128, 18, 30, 30])
torch.Size([1, 128, 18, 30, 30])
torch.Size([1, 128, 19, 30, 30])
uptransition
torch.Size([1, 256, 19, 30, 30])
torch.Size([1, 64, 38, 60, 60])
torch.Size([1, 64, 38, 60, 60])
torch.Size([1, 64, 38, 60, 60])
torch.Size([1, 64, 38, 60, 60])
torch.Size([1, 64, 38, 60, 60])
uptransition
torch.Size([1, 128, 38, 60, 60])
torch.Size([1, 32, 77, 120, 120])
torch.Size([1, 32, 76, 120, 120])
torch.Size([1, 32, 76, 120, 120])
torch.Size([1, 32, 76, 120, 120])
torch.Size([1, 32, 77, 120, 120])
uptransition
torch.Size([1, 64, 77, 120, 120])
torch.Siz

In [None]:
outputs.shape

torch.Size([1, 4, 155, 240, 240])