In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pydicom
import os
from PIL import Image
import glob
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [2]:
class ConvBlock3D(nn.Module):
    """Conv3d -> BatchNorm3d -> in-place ReLU, packaged as one module.

    Args:
        in_channels: number of channels of the incoming volume.
        out_channels: number of channels produced by the convolution; also the
            number of features normalised by the batch-norm layer.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv3d``.

    The three layers are stored in ``self.conv`` (an ``nn.Sequential``) at
    indices 0, 1 and 2. Parameter / buffer names are therefore ``conv.0.*``
    for the convolution and ``conv.1.*`` for the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # Layers are created in the same order they are applied, so RNG-based
        # weight initialisation is consumed conv-first.
        conv_layer = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        norm_layer = nn.BatchNorm3d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU to ``x``."""
        activated = self.conv(x)
        return activated

In [3]:
class Encoder3D(nn.Module):
    """3-D convolutional encoder built from repeated (ConvBlock3D -> max-pool) stages.

    Each stage applies a ``ConvBlock3D`` with kernel 3, stride 1, padding 1
    (which keeps the spatial size unchanged). It then applies
    ``nn.MaxPool3d(kernel_size=2, stride=2)``, which halves every spatial
    dimension and floors odd sizes. After ``len(features)`` stages each spatial
    dimension is therefore ``floor(size / 2 ** len(features))``. Every input
    spatial dimension must be at least ``2 ** len(features)``; otherwise a
    pooling layer would produce an empty output and PyTorch raises.

    Args:
        in_channels: number of channels of the input volume.
        features: output channel count of each stage, in order. The default
            ``(32, 64, 128, 256)`` reproduces the original fixed architecture,
            including its module indices / state_dict names
            (``encoder.0``, ``encoder.2``, ``encoder.4``, ``encoder.6`` for the
            conv blocks and the odd indices for the pooling layers).

    Attributes:
        encoder: the ``nn.Sequential`` holding all stages.
        out_channels: number of channels of the encoder's output
            (``features[-1]``).

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels, features=(32, 64, 128, 256)):
        super(Encoder3D, self).__init__()
        features = tuple(features)
        if not features:
            # An empty Sequential would silently act as identity; refuse instead.
            raise ValueError("features must contain at least one stage")

        layers = []
        prev_channels = in_channels
        for stage_channels in features:
            # Append conv then pool so indices match the original hand-written layout.
            layers.append(
                ConvBlock3D(prev_channels, stage_channels, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.MaxPool3d(kernel_size=2, stride=2))
            prev_channels = stage_channels

        self.encoder = nn.Sequential(*layers)
        self.out_channels = prev_channels

    def forward(self, x):
        """Run ``x`` through every encoder stage.

        ``x`` is presumably a batched volume of shape
        ``(N, in_channels, D, H, W)``, which is the layout ``nn.Conv3d`` expects.
        TODO: confirm against the data pipeline.
        """
        return self.encoder(x)