In [1]:
from pre_processing_py import HandScanDataset2, transform, validation_transform, train_df, valid_df, training_data_dir, display_images
from timm.models.layers import DropPath, trunc_normal_
import numpy as np

/Users/eleanorbolton/Library/CloudStorage/OneDrive-UniversityofLeeds/CCP_MRI_image_subset
Sample 0: Image shape: torch.Size([1, 20, 512, 512]), Label: 1
Sample 1: Image shape: torch.Size([1, 20, 512, 512]), Label: 0
Sample 2: Image shape: torch.Size([1, 20, 512, 512]), Label: 0
Batch image shape: torch.Size([1, 1, 20, 512, 512])
Batch label shape: torch.Size([1])


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.utils.data import DataLoader, Dataset
import numpy as np


In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=1, patch_size=4, embed_dim=64):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, D/patch_size, H/patch_size, W/patch_size)
        x = rearrange(x, 'b c d h w -> b (d h w) c')  # Flatten to (B, N, embed_dim)
        return x


In [4]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, mlp_ratio=4.0):
        super(TransformerBlock, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.norm1(x + self.attn(x, x, x)[0])
        x = self.norm2(x + self.mlp(x))
        return x


In [5]:
class DisruptiveAutoencoder(nn.Module):
    def __init__(self, in_channels=1, patch_size=4, embed_dim=64, depth=4, num_heads=4, mlp_ratio=4.0):
        super(DisruptiveAutoencoder, self).__init__()
        
        # Encoder
        self.patch_embed = PatchEmbedding(in_channels, patch_size, embed_dim)
        self.encoder_layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        
        # Decoder
        self.decoder_layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        self.reconstruction_head = nn.Linear(embed_dim, patch_size ** 3 * in_channels)

    def forward(self, x):
        original_shape = x.shape  # Save the original shape
        
        x = self.patch_embed(x)
        
        # Encoder
        for layer in self.encoder_layers:
            x = layer(x)
        
        # Decoder
        for layer in self.decoder_layers:
            x = layer(x)
        
        # Reconstruction
        x = self.reconstruction_head(x)
        
        # Reshape back to original dimensions before applying L1 loss
        # Assume d_patches=2, h_patches=64, w_patches=64 (or calculate as needed)
        d_patches, h_patches, w_patches = 2, 64, 64
        x = rearrange(x, 'b (d h w) (p1 p2 p3 c) -> b c (d p1) (h p2) (w p3)', d=d_patches, h=h_patches, w=w_patches, p1=4, p2=4, p3=4)
        
        # Upsample to the original shape
        x = F.interpolate(x, size=original_shape[2:], mode='trilinear', align_corners=False)
        
        return x




    def add_noise(self, x, noise_level=0.1):
        noise = torch.randn_like(x) * noise_level
        return x + noise

    def downsample(self, x, scale_factor=0.5):
        return F.interpolate(x, scale_factor=scale_factor, mode='trilinear', align_corners=False)

    def local_mask(self, x, mask_ratio=0.15):
        """
        Apply local masking by setting a percentage of channels to zero.

        Args:
            x: Input tensor with shape [batch_size, channels, depth, height, width].
            mask_ratio: Ratio of channels to be masked.

        Returns:
            Masked tensor with the same shape as input.
        """
        # Determine the shape of the input tensor
        b, c, d, h, w = x.shape  # Assume x is [batch_size, channels, depth, height, width]
        
        num_masked = int(mask_ratio * c)
        mask_indices = torch.randperm(c)[:num_masked]  # Randomly select indices of channels to mask

        x[:, mask_indices, :, :, :] = 0  # Mask the selected channels
        return x



    def compute_loss(self, reconstructed, original, zsim, zlabel, alpha=0.05):
        # Ensure the reconstructed output matches the original size
        reconstructed = F.interpolate(reconstructed, size=original.shape[2:], mode='trilinear', align_corners=False)

        # L1 reconstruction loss
        l1_loss = F.l1_loss(reconstructed, original)

        # Contrastive loss (LCMCL)
        bce_loss = F.binary_cross_entropy_with_logits(zsim, zlabel)
        contrastive_loss = alpha * bce_loss

        # Total loss
        total_loss = l1_loss + contrastive_loss
        return total_loss, l1_loss, contrastive_loss




In [11]:
batch_size = 1
# Initialize dataset and data loader
train_dataset = HandScanDataset2(labels_df=train_df, data_dir=training_data_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [12]:
# Initialize the dataset with the selected subjects
valid_dataset = HandScanDataset2(labels_df=valid_df, data_dir=training_data_dir, transform=validation_transform)
# Create a data loader for testing
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

In [8]:
# Select a few subjects (e.g., the first three subjects)
test_subjects_df = train_df.iloc[:3]

# Initialize the dataset with the selected subjects
test_dataset = HandScanDataset2(labels_df=test_subjects_df, data_dir=training_data_dir, transform=transform)

# Create a data loader for testing
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


In [9]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')
        self.activation_count = 0

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                self.activation_count += 1
                return True
        return False


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from einops import rearrange

def visualize_reconstruction(dae_model, patches, idx, device, num_slices=1):
    """
    Visualizes the original, tokenized patches, and reconstructed images.

    Args:
        dae_model (nn.Module): The trained autoencoder model.
        patches (torch.Tensor): The input patches.
        idx (int): Index of the image (used for title).
        device (torch.device): The device to which the model and data are moved.
        num_slices (int): Number of slices to visualize.
    """
    dae_model.eval()  # Set the model to evaluation mode
    
    # Move patches to the appropriate device
    patches = patches.to(device)

    with torch.no_grad():
        # Forward pass to get the intermediate and final outputs
        noisy_patches = dae_model.add_noise(patches)
        downsampled_patches = dae_model.downsample(noisy_patches)
        tokenized_patches = dae_model.patch_embed(downsampled_patches)  # Tokenization step
        masked_patches = dae_model.local_mask(downsampled_patches)
        reconstructed_patches = dae_model(masked_patches)

    # Convert tensors to numpy arrays for visualization
    original_image = patches.cpu().numpy()[0, 0, :, :, :]  # Assuming single-channel, 3D volume
    reconstructed_image = reconstructed_patches.cpu().numpy()[0, 0, :, :, :]

    # Reconstruct patches into a grid to visualize
    patch_grid = rearrange(tokenized_patches.cpu().numpy()[0], '(d h w) c -> d h w c', d=2, h=64, w=64)
    patch_grid_image = np.mean(patch_grid, axis=-1)[0]  # Visualizing the mean across the channels

    # Display the images
    plt.figure(figsize=(15, 5))

    for slice_idx in range(min(num_slices, original_image.shape[0])):
        plt.subplot(3, num_slices, slice_idx + 1)
        plt.imshow(original_image[slice_idx], cmap='gray')
        plt.title(f'Original Image {idx} - Slice {slice_idx}')
        plt.axis('off')

        plt.subplot(3, num_slices, num_slices + slice_idx + 1)
        plt.imshow(patch_grid_image, cmap='gray')
        plt.title(f'Patches {idx} - Slice {slice_idx}')
        plt.axis('off')

        plt.subplot(3, num_slices, 2 * num_slices + slice_idx + 1)
        plt.imshow(reconstructed_image[slice_idx], cmap='gray')
        plt.title(f'Reconstructed {idx} - Slice {slice_idx}')
        plt.axis('off')

    plt.show()


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Define the DisruptiveAutoencoder model (this should be done before the training loop)
dae_model = DisruptiveAutoencoder(
    in_channels=1,       # Input channel dimension, typically 1 for grayscale medical images
    patch_size=4,        # Size of each patch
    embed_dim=64,        # Embedding dimension size
    depth=4,             # Number of transformer layers
    num_heads=4,         # Number of attention heads
    mlp_ratio=4.0        # Ratio of MLP hidden layer dimension to embedding dimension
)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dae_model = dae_model.to(device)

# Optimizer and loss function
optimizer = torch.optim.Adam(dae_model.parameters(), lr=0.001)

# Number of epochs
num_epochs = 10

# Define the EarlyStopper
early_stopper = EarlyStopper(patience=3, min_delta=0.001)

# Training loop
for epoch in range(num_epochs):
    dae_model.train()  # Set model to training mode
    epoch_loss = 0
    for patches, labels in train_loader:
        patches = patches.to(device)  # Move patches to the same device as the model
        labels = labels.to(device)    # Move labels to the same device as the model

        optimizer.zero_grad()

        # Apply noise, downsampling, and local masking
        noisy_patches = dae_model.add_noise(patches)
        downsampled_patches = dae_model.downsample(noisy_patches)
        
        # Forward pass
        masked_patches = dae_model.local_mask(downsampled_patches)
        reconstructed_patches = dae_model(masked_patches)

        # Generate zsim and zlabel for contrastive learning
        zsim = torch.mm(reconstructed_patches.view(reconstructed_patches.size(0), -1), 
                        reconstructed_patches.view(reconstructed_patches.size(0), -1).T)
        zlabel = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()  # Binary label matrix for positive pairs

        # Compute combined loss
        total_loss, l1_loss, contrastive_loss = dae_model.compute_loss(reconstructed_patches, patches, zsim, zlabel)
        epoch_loss += total_loss.item()

        # Backpropagation
        total_loss.backward()
        optimizer.step()

    # Validation phase (or calculate validation loss)
    dae_model.eval()  # Set model to evaluation mode
    validation_loss = 0
    with torch.no_grad():
        for patches, labels in valid_loader:
            patches = patches.to(device)
            labels = labels.to(device)

            # Forward pass
            noisy_patches = dae_model.add_noise(patches)
            downsampled_patches = dae_model.downsample(noisy_patches)
            masked_patches = dae_model.local_mask(downsampled_patches)
            reconstructed_patches = dae_model(masked_patches)

            # Compute loss
            zsim = torch.mm(reconstructed_patches.view(reconstructed_patches.size(0), -1), 
                            reconstructed_patches.view(reconstructed_patches.size(0), -1).T)
            zlabel = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()

            _, l1_loss, contrastive_loss = dae_model.compute_loss(reconstructed_patches, patches, zsim, zlabel)
            validation_loss += l1_loss.item()

    # Average validation loss
    validation_loss /= len(valid_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {validation_loss:.4f}')

    # Check early stopping
    if early_stopper.early_stop(validation_loss):
        print("Early stopping triggered.")
        break

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}, '
          f'L1 Loss: {l1_loss.item():.4f}, Contrastive Loss: {contrastive_loss.item():.4f}')


Error reading /Users/eleanorbolton/Library/CloudStorage/OneDrive-UniversityofLeeds/t1_vibe_we_hand_subset/CCP_393/f6a3f98dc1/t1_vibe_we/1.3.12.2.1107.5.2.36.40258.2016030308584524314805431.DCM: File is missing DICOM File Meta Information header or the 'DICM' prefix is missing from the header. Use force=True to force reading.
Error reading /Users/eleanorbolton/Library/CloudStorage/OneDrive-UniversityofLeeds/t1_vibe_we_hand_subset/CCP_181/983862e0ae/t1_vibe_we/1.3.12.2.1107.5.2.36.40258.201309201106223154520723.DCM: File is missing DICOM File Meta Information header or the 'DICM' prefix is missing from the header. Use force=True to force reading.
Epoch [1/10], Validation Loss: 7.5646
Epoch [1/10], Loss: 6.6615, L1 Loss: 8.2956, Contrastive Loss: 0.0000
Error reading /Users/eleanorbolton/Library/CloudStorage/OneDrive-UniversityofLeeds/t1_vibe_we_hand_subset/CCP_393/f6a3f98dc1/t1_vibe_we/1.3.12.2.1107.5.2.36.40258.2016030308584524314805431.DCM: File is missing DICOM File Meta Information h