In [None]:
!pip install segmentation-models-pytorch

In [4]:
import torch
import torch.nn as nn
from segmentation_models_pytorch import Unet

import math 

class Block(nn.Module):
    """Two 'same'-padded 3D convolutions, each followed by ReLU, then a 2x max-pool.

    The first convolution maps ``in_channels`` to ``out_channels // 2``
    channels and the second maps those to ``out_channels``. The convolutions
    keep the spatial size; ``MaxPool3d(2)`` then halves each spatial dim
    (rounding down).
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        widths = (in_channels, out_channels // 2, out_channels)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.Conv3d(c_in, c_out, kernel_size=kernel_size, padding='same'))
            stages.append(nn.ReLU())
        stages.append(nn.MaxPool3d(2))
        # Position inside this Sequential determines the state_dict keys
        # (conv_block.0.*, conv_block.2.*), so the layer order must stay fixed.
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        pooled = self.conv_block(x)
        return pooled


class Classifier(nn.Module):
    """3D Convolutional Neural Network for LUNA16 dataset.

    Four ``Block`` stages (each ending in a 2x max-pool) are followed by a
    flatten + linear head that produces two logits. Input is expected as
    ``(N, in_channels, D, H, W)``.

    Args:
        in_channels: Number of channels in the input volume.
        out_channels: Channels of the last backbone stage; the earlier stages
            use ``out_channels // 8``, ``// 4`` and ``// 2``.
        kernel_size: Convolution kernel size forwarded to every ``Block``.
        batch_norm: If True, apply ``BatchNorm3d`` over the input channels
            before the backbone; otherwise the input passes through unchanged.

    NOTE(review): the head's ``Linear`` is hard-wired to
    ``3 * 3 * 2 * out_channels`` input features, so after four 2x poolings the
    spatial volume must be exactly 18 voxels (e.g. a 48 x 48 x 32 input).
    Confirm the patch size produced by the data pipeline.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 64, kernel_size: int = 3, batch_norm: bool = False):
        super().__init__()
        # BatchNorm3d's num_features must equal the input channel count;
        # a hard-coded 1 fails as soon as in_channels != 1.
        self.tail = nn.BatchNorm3d(in_channels) if batch_norm else nn.Identity()
        # Channel width doubles per stage: c/8 -> c/4 -> c/2 -> c.
        self.backbone = nn.Sequential(
            Block(in_channels, out_channels // 8, kernel_size),
            Block(out_channels // 8, out_channels // 4, kernel_size),
            Block(out_channels // 4, out_channels // 2, kernel_size),
            Block(out_channels // 2, out_channels, kernel_size),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(3 * 3 * 2 * out_channels, 2),
        )
        self.init_weights()

    def init_weights(self):
        """Re-initialise every Conv3d and Linear layer in the model.

        Weights use Kaiming-normal init (``mode='fan_out'``, ReLU gain);
        biases are drawn uniformly from ``[-1/sqrt(fan_out), 1/sqrt(fan_out)]``.
        """
        for module in self.modules():
            if not isinstance(module, (nn.Conv3d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(
                module.weight,
                a=0,
                mode='fan_out',
                nonlinearity='relu'
            )
            if module.bias is not None:
                # Private torch helper returning (fan_in, fan_out) for the weight.
                _, fan_out = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_out)
                nn.init.uniform_(module.bias, -bound, bound)

    def forward(self, x):
        """Return ``(logits, probs)``, both of shape ``(N, 2)``.

        ``probs`` is the softmax of the *detached* logits, so no gradient
        flows through it; losses must be computed from ``logits``.
        """
        x = self.tail(x)
        x = self.backbone(x)
        out = self.head(x)
        prob = nn.functional.softmax(out.detach(), dim=1)
        return out, prob

In [14]:
class Segmenter(nn.Module):
    """2D U-Net (ResNet-34 encoder, ImageNet weights) for LUNA16 dataset.

    ``segmentation_models_pytorch.Unet`` is a 2D network, so inputs are
    ``(N, in_channels, H, W)`` and outputs are per-pixel sigmoid
    probabilities of shape ``(N, classes, H, W)``.

    Args:
        in_channels: Number of input channels given to the Unet encoder.
        classes: Number of output mask channels.
        batch_norm: If True, apply ``BatchNorm2d`` over the input channels
            before the Unet; otherwise the input passes through unchanged.

    NOTE(review): smp's Unet expects H and W divisible by 32 with the default
    encoder depth — confirm the slice/patch size fed to this model.
    """

    def __init__(self, in_channels: int = 3, classes: int = 1, batch_norm: bool = False):
        super().__init__()
        # The Unet consumes 4D (N, C, H, W) tensors, so input normalisation
        # must be 2D and sized to the real channel count; a BatchNorm3d(1)
        # here would reject every 4D input.
        self.tail = nn.BatchNorm2d(in_channels) if batch_norm else nn.Identity()
        # encoder_weights="imagenet" makes smp load pretrained encoder weights
        # when the model is built.
        self.backbone = Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=classes,
        )
        # With no `activation` argument the Unet returns raw scores; the
        # sigmoid maps them to [0, 1].
        self.head = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """Apply Kaiming init to Conv3d/Linear layers found inside ``self.head``.

        NOTE(review): ``self.head`` is a parameter-free ``nn.Sigmoid`` and the
        backbone is not visited, so this loop currently matches no module and
        leaves every weight (including the pretrained encoder) untouched.
        Confirm whether that is intended.
        """
        for module in self.head.modules():
            if not isinstance(module, (nn.Conv3d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(
                module.weight,
                a=0,
                mode='fan_out',
                nonlinearity='relu'
            )
            if module.bias is not None:
                # Private torch helper returning (fan_in, fan_out) for the weight.
                _, fan_out = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_out)
                nn.init.uniform_(module.bias, -bound, bound)

    def forward(self, x):
        """Return sigmoid mask probabilities of shape ``(N, classes, H, W)``."""
        x = self.tail(x)
        x = self.backbone(x)
        out = self.head(x)
        return out