In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pydicom
import os
from PIL import Image
import glob
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd


In [7]:
# import pre_processing 
from pre_processing import HandScanDataset2, transform, validation_transform, train_df, valid_df

In [8]:
class ConvBlock3D(nn.Module):
    """3D convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_channels: Input channel count handed to ``nn.Conv3d``.
        out_channels: Channel count produced by the convolution and
            normalised by ``nn.BatchNorm3d``.
        kernel_size: Forwarded unchanged to ``nn.Conv3d``.
        stride: Forwarded unchanged to ``nn.Conv3d``.
        padding: Forwarded unchanged to ``nn.Conv3d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        convolution = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        normalisation = nn.BatchNorm3d(out_channels)
        activation = nn.ReLU(inplace=True)
        # Kept under the attribute ``conv`` so parameter names stay ``conv.<index>.*``.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Run ``x`` through convolution, batch norm and ReLU, in that order."""
        pipeline = self.conv
        return pipeline(x)

In [9]:
class Encoder3D(nn.Module):
    """Four-stage 3D convolutional encoder.

    Every stage is a ``ConvBlock3D`` (kernel 3, stride 1, padding 1) followed
    by ``nn.MaxPool3d`` with kernel 2 and stride 2. Channel widths progress
    ``in_channels -> 32 -> 64 -> 128 -> 256``.

    Args:
        in_channels: Channel count of the tensor fed to the first stage.
    """

    def __init__(self, in_channels):
        super().__init__()
        channel_plan = (in_channels, 32, 64, 128, 256)
        stages = []
        # Pair each width with its successor: (in, 32), (32, 64), (64, 128), (128, 256).
        for source_width, target_width in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(
                ConvBlock3D(source_width, target_width, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.MaxPool3d(kernel_size=2, stride=2))
        # Same layer order and indices as listing the eight modules by hand.
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the output of the full conv/pool stack applied to ``x``."""
        stack = self.encoder
        return stack(x)

In [10]:
class Decoder3D(nn.Module):
    """Transposed-convolution decoder that ends in a single sigmoid channel.

    Three stages of ``nn.ConvTranspose3d`` (kernel 2, stride 2) +
    ``nn.BatchNorm3d`` + in-place ReLU take the channel width through
    ``in_channels -> 128 -> 64 -> 32``. A final ``nn.ConvTranspose3d``
    (kernel 2, stride 2) maps 32 channels to 1 and is followed by
    ``nn.Sigmoid``.

    Args:
        in_channels: Channel count of the tensor fed to the first layer.
    """

    def __init__(self, in_channels):
        super().__init__()
        hidden_widths = (in_channels, 128, 64, 32)
        layers = []
        for source_width, target_width in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers += [
                nn.ConvTranspose3d(source_width, target_width, kernel_size=2, stride=2),
                nn.BatchNorm3d(target_width),
                nn.ReLU(inplace=True),
            ]
        # Output head: one channel squashed into (0, 1) by the sigmoid.
        layers += [
            nn.ConvTranspose3d(hidden_widths[-1], 1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of the full upsampling stack applied to ``x``."""
        stack = self.decoder
        return stack(x)

In [11]:
class MaskedAutoencoder3D(nn.Module):
    """Autoencoder that reconstructs a volume from an element-wise masked copy.

    Args:
        in_channels: Channel count passed to ``Encoder3D``.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Channel count handed to Decoder3D as its input width.
        bottleneck_channels = 256
        self.encoder = Encoder3D(in_channels)
        self.decoder = Decoder3D(bottleneck_channels)

    def forward(self, x, mask):
        """Multiply ``x`` by ``mask``, encode the product, and decode it.

        Args:
            x: Input tensor.
            mask: Tensor multiplied element-wise with ``x`` before encoding.

        Returns:
            The decoder output for the masked input.
        """
        return self.decoder(self.encoder(x * mask))

In [12]:
def create_mask(shape, mask_ratio=0.7):
    """Build a random binary mask in which ``mask_ratio`` of all entries are 0.

    The zeroed positions are drawn without replacement from *every* element
    of the mask, using NumPy's global random state.

    Args:
        shape: Shape of the mask to create (any sequence of ints, e.g. a
            ``torch.Size``).
        mask_ratio: Fraction of entries to set to 0; must lie in ``[0, 1]``.
            The count is truncated to an integer.

    Returns:
        A ``torch.float32`` tensor of the requested shape holding 1 for kept
        entries and 0 for masked entries.

    Raises:
        ValueError: If ``mask_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

    mask = np.ones(tuple(shape))
    # Count over the whole array. Using only a few of the axes would confine
    # the zeros to a leading prefix of the flattened mask and leave the rest
    # of the batch unmasked.
    num_elements = mask.size
    num_masked = int(mask_ratio * num_elements)
    if num_masked > 0:
        indices = np.random.choice(num_elements, num_masked, replace=False)
        # reshape(-1) on a freshly allocated array is a view, so this writes
        # straight into ``mask``.
        mask.reshape(-1)[indices] = 0
    return torch.tensor(mask, dtype=torch.float32)

In [13]:
# Define the custom loss function
def custom_loss(output, target, alpha=0.8):
    """Blend a voxel-space MSE with an MSE between FFT magnitude spectra.

    Args:
        output: Reconstructed tensor.
        target: Reference tensor with the same shape as ``output``.
        alpha: Weight of the voxel-space term; the frequency term is weighted
            by ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * pixel_loss + (1 - alpha) * freq_loss``.
    """
    mse_loss = nn.MSELoss()
    pixel_loss = mse_loss(output, target)

    # With more than three axes, transform only the trailing three so each
    # leading (batch / channel) entry gets its own spectrum. A transform over
    # every axis would mix independent samples and make the loss grow with
    # batch size. Assumes the trailing three axes are the volume axes, as for
    # nn.Conv3d tensors - TODO confirm against the dataset.
    fft_dims = (-3, -2, -1) if output.dim() > 3 else None

    output_freq = torch.fft.fftn(output, dim=fft_dims)
    target_freq = torch.fft.fftn(target, dim=fft_dims)

    # Compare magnitudes: the spectra are complex and the MSE kernel is not
    # implemented for complex dtypes.
    output_freq_mag = torch.abs(output_freq)
    target_freq_mag = torch.abs(target_freq)

    freq_loss = mse_loss(output_freq_mag, target_freq_mag)

    return alpha * pixel_loss + (1 - alpha) * freq_loss

In [14]:
def train_mae(model, train_loader, val_loader, num_epochs, learning_rate):
    """Train ``model`` as a masked autoencoder and report per-epoch losses.

    For every batch a fresh random mask is drawn with ``create_mask``, the
    model reconstructs the batch from its masked copy, and ``custom_loss``
    compares the reconstruction with the unmasked batch. Optimisation uses
    Adam. Batches and masks are moved to the device of the model parameters.

    Args:
        model: Module called as ``model(images, mask)``.
        train_loader: Iterable yielding ``(images, _)`` pairs for training.
        val_loader: Iterable yielding ``(images, _)`` pairs for validation.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.

    Returns:
        dict with keys ``'train_loss'`` and ``'val_loss'``, each a list with
        one mean batch loss per epoch (``nan`` when the loader yielded no
        batches).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Inputs must sit on the same device as the weights.
    device = next(model.parameters()).device
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_batches = 0
        for images, _ in train_loader:
            images = images.to(device)
            mask = create_mask(images.shape).to(device)
            outputs = model(images, mask)
            loss = custom_loss(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_batches += 1

        val_loss = 0.0
        val_batches = 0
        model.eval()
        with torch.no_grad():
            for images, _ in val_loader:
                images = images.to(device)
                mask = create_mask(images.shape).to(device)
                outputs = model(images, mask)
                loss = custom_loss(outputs, images)
                val_loss += loss.item()
                val_batches += 1

        # Average over the batches actually seen; an empty loader gives nan
        # instead of raising ZeroDivisionError.
        avg_train = train_loss / train_batches if train_batches else float('nan')
        avg_val = val_loss / val_batches if val_batches else float('nan')
        history['train_loss'].append(avg_train)
        history['val_loss'].append(avg_val)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train:.4f}, Val Loss: {avg_val:.4f}')

    return history


In [None]:
# Load the label table that lives inside the dataset folder.
# NOTE(review): the same load / split / save statements are repeated inside the
# `if __name__ == '__main__':` cell further down - confirm both are needed.
training_data_dir = "/Users/eleanorbolton/Library/CloudStorage/OneDrive-UniversityofLeeds/t1_vibe_we_hand_subset/"
csv_path = os.path.join(training_data_dir, 'training_labels_subset.csv')
labels_df = pd.read_csv(csv_path)

# 80/20 split with a fixed seed, stratified on the 'progression' column.
train_df, valid_df = train_test_split(
    labels_df,
    test_size=0.2,
    random_state=42,
    stratify=labels_df['progression'],
)

# Write both partitions next to the source CSV, without the index column.
for _split_file, _split_frame in (('train_split.csv', train_df), ('valid_split.csv', valid_df)):
    _split_frame.to_csv(os.path.join(training_data_dir, _split_file), index=False)

In [17]:
if __name__ == '__main__':
    # Seed the generators behind the random masks (NumPy) and behind weight
    # initialisation / loader shuffling (PyTorch) so a re-run is repeatable.
    # 42 matches the random_state already used for the split below.
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Dataset folder. Set HAND_SCAN_DATA_DIR to run on another machine; the
    # default is the original location.
    training_data_dir = os.environ.get(
        "HAND_SCAN_DATA_DIR",
        "/Users/eleanorbolton/Library/CloudStorage/OneDrive-UniversityofLeeds/t1_vibe_we_hand_subset/",
    )
    csv_path = os.path.join(training_data_dir, 'training_labels_subset.csv')
    labels_df = pd.read_csv(csv_path)

    # Split the data into training and validation sets, stratified on the
    # 'progression' column. These names replace any train_df / valid_df that
    # were imported earlier.
    train_df, valid_df = train_test_split(labels_df, test_size=0.2, random_state=seed, stratify=labels_df['progression'])

    # Save the splits for reference
    train_df.to_csv(os.path.join(training_data_dir, 'train_split.csv'), index=False)
    valid_df.to_csv(os.path.join(training_data_dir, 'valid_split.csv'), index=False)
    train_dataset = HandScanDataset2(labels_df=train_df, data_dir=training_data_dir, transform=transform)
    valid_dataset = HandScanDataset2(labels_df=valid_df, data_dir=training_data_dir, transform=validation_transform)

    # Creating data loaders
    batch_size = 4
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    model = MaskedAutoencoder3D(in_channels=1)
    train_mae(model, train_loader, valid_loader, num_epochs=50, learning_rate=0.001)

{'CCP_120': 1, 'CCP_117': 0, 'CCP_73': 0, 'CCP_386': 1, 'CCP_644': 1, 'CCP_NG_166': 0, 'CCP_NG_42': 1, 'CCP_138': 0, 'CCP_874': 0, 'CCP_907': 1, 'CCP_NG_137': 1, 'CCP_647': 1, 'CCP_262': 0, 'CCP_541': 0, 'CCP_89': 0, 'CCP_202': 0, 'CCP_NG_56': 0, 'CCP_821': 1, 'CCP_66': 0, 'CCP_471': 0, 'CCP_34': 0, 'CCP_181': 1, 'CCP_557': 1, 'CCP_416': 0, 'CCP_NG_207': 0, 'CCP_568': 0, 'CCP_753': 0, 'CCP_NG_181': 0, 'CCP_81': 0, 'CCP_415': 1, 'CCP_NG_8': 1, 'CCP_283': 1, 'CCP_906': 0, 'CCP_968': 0, 'CCP_664': 0, 'CCP_736': 0, 'CCP_355': 0, 'CCP_NG_104': 1, 'CCP_247': 1, 'CCP_NG_188': 1, 'CCP_976': 0, 'CCP_NG_214': 0, 'CCP_824': 1, 'CCP_62': 0, 'CCP_NG_36': 0, 'CCP_802': 0, 'CCP_NG_172': 0, 'CCP_873': 0, 'CCP_207': 0, 'CCP_167': 1, 'CCP_944': 1, 'CCP_133': 1, 'CCP_NG_60': 0, 'CCP_1000': 0, 'CCP_252': 1, 'CCP_672': 0, 'CCP_531': 0, 'CCP_NG_175': 1, 'CCP_NG_107': 1, 'CCP_53': 0, 'CCP_NG_106': 1, 'CCP_105': 0, 'CCP_507': 1, 'CCP_405': 0, 'CCP_172': 1, 'CCP_901': 1, 'CCP_212': 1, 'CCP_485': 0, 'CCP_185': 

RuntimeError: "mse_cpu" not implemented for 'ComplexFloat'