In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pydicom
import os
from PIL import Image
import glob
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [None]:
# import pre_processing 
from pre_processing import HandScanDataset2, transform, validation_transform, train_df, valid_df

In [2]:
class ConvBlock3D(nn.Module):
    """Basic 3D building block: convolution, then batch norm, then ReLU.

    Args:
        in_channels: Channel count of the incoming feature volume.
        out_channels: Channel count produced by the convolution (also the
            number of features normalised by the batch norm).
        kernel_size: Forwarded unchanged to ``nn.Conv3d``.
        stride: Forwarded unchanged to ``nn.Conv3d``.
        padding: Forwarded unchanged to ``nn.Conv3d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        convolution = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        normalization = nn.BatchNorm3d(out_channels)
        activation = nn.ReLU(inplace=True)
        # The attribute name and the layer order define the state_dict keys
        # (``conv.0.*`` for the convolution, ``conv.1.*`` for the batch norm).
        self.conv = nn.Sequential(convolution, normalization, activation)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        return self.conv(x)

In [3]:
class Encoder3D(nn.Module):
    """Four-stage 3D convolutional encoder.

    Every stage is a ``ConvBlock3D`` (kernel 3, stride 1, padding 1) followed
    by an ``nn.MaxPool3d`` with kernel 2 and stride 2. The stage widths are
    32, 64, 128 and 256, so the output has 256 channels and has passed through
    four 2x poolings.

    Args:
        in_channels: Channel count of the input volume.
    """

    def __init__(self, in_channels):
        super().__init__()
        stage_widths = (32, 64, 128, 256)
        stages = []
        previous_width = in_channels
        # Modules are appended in conv, pool, conv, pool, ... order so the
        # Sequential indices (0..7) match a hand-written layer list.
        for width in stage_widths:
            stages.append(
                ConvBlock3D(previous_width, width, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.MaxPool3d(kernel_size=2, stride=2))
            previous_width = width
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the encoded feature volume for ``x``."""
        return self.encoder(x)

In [4]:
class Decoder3D(nn.Module):
    """Four-stage 3D transposed-convolution decoder ending in a sigmoid.

    Three stages of ``ConvTranspose3d`` (kernel 2, stride 2) -> ``BatchNorm3d``
    -> ``ReLU`` reduce the channel count to 128, 64 and 32. A final
    ``ConvTranspose3d`` (kernel 2, stride 2) produces a single channel that is
    passed through ``nn.Sigmoid``.

    Args:
        in_channels: Channel count of the feature volume fed to the decoder.
    """

    def __init__(self, in_channels):
        super().__init__()
        layers = []
        current_width = in_channels
        # Hidden upsampling stages; the append order fixes Sequential indices
        # 0..8.
        for width in (128, 64, 32):
            layers.append(
                nn.ConvTranspose3d(current_width, width, kernel_size=2, stride=2)
            )
            layers.append(nn.BatchNorm3d(width))
            layers.append(nn.ReLU(inplace=True))
            current_width = width
        # Output head: single-channel upsampling followed by a sigmoid.
        layers.append(nn.ConvTranspose3d(current_width, 1, kernel_size=2, stride=2))
        layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the decoded single-channel volume for ``x``."""
        return self.decoder(x)

In [5]:
class MaskedAutoencoder3D(nn.Module):
    """Masked autoencoder for 3D volumes built from ``Encoder3D`` and ``Decoder3D``.

    The input is multiplied by a mask, encoded to a 256-channel feature
    volume, and decoded again.

    Args:
        in_channels: Channel count of the input volume, passed to ``Encoder3D``.
    """

    def __init__(self, in_channels):
        super().__init__()
        bottleneck_channels = 256  # width of the feature volume handed to the decoder
        self.encoder = Encoder3D(in_channels)
        self.decoder = Decoder3D(bottleneck_channels)

    def forward(self, x, mask):
        """Reconstruct a volume from the masked input.

        Args:
            x: Input volume. Presumably (batch, channels, depth, height,
                width) -- confirm against the data loader.
            mask: Multiplied elementwise with ``x``, so it must broadcast
                against it. Zeros hide the corresponding entries and ones
                keep them.

        Returns:
            The decoder output for the encoded, masked input.
        """
        visible = x * mask
        latent = self.encoder(visible)
        return self.decoder(latent)

In [None]:
def create_mask(shape, mask_ratio=0.7):
    """Create a random binary voxel mask (1 = keep, 0 = masked).

    A random subset of positions over the trailing three dimensions of
    ``shape`` is set to zero. The same spatial pattern is repeated across all
    leading dimensions (e.g. channels or batch), so every leading slice is
    masked at exactly the requested ratio. Call the function once per sample
    if independent masks are needed.

    Previously the indices were drawn from ``shape[1] * shape[2] * shape[3]``
    but written into the flattened full array, so only the first leading
    slice was masked when ``shape[0] > 1``. For shapes of the form
    ``(1, D, H, W)`` the result is unchanged for the same NumPy RNG state.

    Args:
        shape: Shape of the mask to create; must have at least three
            dimensions.
        mask_ratio: Fraction of spatial positions to zero out, in ``[0, 1]``.
            The count is truncated towards zero.

    Returns:
        A ``torch.float32`` tensor of shape ``shape`` containing only 0s and 1s.

    Raises:
        ValueError: If ``shape`` has fewer than three dimensions or
            ``mask_ratio`` is outside ``[0, 1]``.
    """
    shape = tuple(int(dim) for dim in shape)
    if len(shape) < 3:
        raise ValueError(
            f"shape must have at least 3 dimensions, got {len(shape)}"
        )
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

    spatial_shape = shape[-3:]
    num_elements = int(np.prod(spatial_shape))
    num_masked = int(mask_ratio * num_elements)

    # Draw the hidden positions over the spatial volume only.
    spatial_mask = np.ones(num_elements, dtype=np.float32)
    indices = np.random.choice(num_elements, num_masked, replace=False)
    spatial_mask[indices] = 0.0
    spatial_mask = spatial_mask.reshape(spatial_shape)

    # Repeat the spatial pattern over every leading dimension. broadcast_to
    # returns a read-only view, so copy it before handing it to torch.
    mask = np.broadcast_to(spatial_mask, shape).copy()
    return torch.tensor(mask, dtype=torch.float32)