In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import random_split
import cv2
import pywt


# Helper function for loading TIFF images
def load_tiff_images(folder_path):
    images = []
    for file_name in sorted(os.listdir(folder_path)):
        if file_name.endswith(".tif") or file_name.endswith(".tiff"):
            img = Image.open(os.path.join(folder_path, file_name))
            images.append(np.array(img, dtype=np.float32))
    return images


In [2]:
class MRIDataset(Dataset):
    def __init__(self, undersampled_folder, ground_truth_folder, mask):
        self.undersampled_images = load_tiff_images(undersampled_folder)
        self.ground_truth_images = load_tiff_images(ground_truth_folder)
        self.mask = mask
        
        # Ensure consistent shapes by filtering
        self.filtered_undersampled_images = []
        self.filtered_ground_truth_images = []

        for undersampled, ground_truth in zip(self.undersampled_images, self.ground_truth_images):
            if ground_truth.shape == (640, 320) and undersampled.shape == (640, 320):
                self.filtered_undersampled_images.append(undersampled)
                self.filtered_ground_truth_images.append(ground_truth)

        print(f"Filtered dataset contains {len(self.filtered_undersampled_images)} samples.")

    def __len__(self):
        return len(self.filtered_undersampled_images)
    
    def __getitem__(self, idx):
        undersampled = self.filtered_undersampled_images[idx]
        ground_truth = self.filtered_ground_truth_images[idx]
        
        # Normalize images
        undersampled = undersampled / np.max(undersampled)
        ground_truth = ground_truth / np.max(ground_truth)
        
        return (
            torch.tensor(undersampled, dtype=torch.float32),  # Apply mask
            torch.tensor(ground_truth, dtype=torch.float32),
        )


In [3]:
def create_random_mask(height, width, center_fraction=0.1, undersample_fraction=0.1, seed=50):#at least 0.5
    """
    Create a 2D random mask with a fully sampled center region.

    Args:
        height: The height of the mask (frequency encoding dimension).
        width: The width of the mask (phase encoding dimension).
        center_fraction: Fraction of the image to be fully sampled at the center.
        undersample_fraction: Fraction of the remaining k-space to be sampled randomly.

    Returns:
        mask_2d: A binary mask with the same shape as the k-space data.
    """
    if seed is not None:
        np.random.seed(seed)  # Set the random seed locally
        mask = np.zeros((height, width))
    
    # Fully sample the center region
    center_height = int(height * center_fraction)
    center_width = int(width * center_fraction)
    center_start_h = (height - center_height) // 2
    center_start_w = (width - center_width) // 2
    mask[center_start_h:center_start_h + center_height, center_start_w:center_start_w + center_width] = 1
    
    # Randomly sample the remaining k-space
    remaining_mask = np.random.choice([0, 1], size=(height, width), p=[undersample_fraction, 1-undersample_fraction])
    mask = np.maximum(mask, remaining_mask)  # Ensure the center is fully sampled
    
    return mask

In [4]:
class DAE_4(nn.Module):
    def __init__(self):
        super(DAE_4, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [5]:
class dae_bottleneck(nn.Module):
    def __init__(self):
        super(dae_bottleneck, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        bottleneck = self.bottleneck(encoded)
        decoded = self.decoder(bottleneck)
        return decoded



In [6]:
class UNetDAE(nn.Module):
    def __init__(self, input_channels=1, base_filters=64):
        super(UNetDAE, self).__init__()
        
        # Encoder
        self.enc1 = self.conv_block(input_channels, base_filters)
        self.enc2 = self.conv_block(base_filters, base_filters * 2)
        self.enc3 = self.conv_block(base_filters * 2, base_filters * 4)
        self.enc4 = self.conv_block(base_filters * 4, base_filters * 8)
        
        # Bottleneck
        self.bottleneck = self.conv_block(base_filters * 8, base_filters * 16)
        
        # Decoder
        self.up4 = nn.ConvTranspose2d(base_filters * 16, base_filters * 8, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(base_filters * 16, base_filters * 8)
        
        self.up3 = nn.ConvTranspose2d(base_filters * 8, base_filters * 4, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(base_filters * 8, base_filters * 4)
        
        self.up2 = nn.ConvTranspose2d(base_filters * 4, base_filters * 2, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(base_filters * 4, base_filters * 2)
        
        self.up1 = nn.ConvTranspose2d(base_filters * 2, base_filters, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(base_filters * 2, base_filters)
        
        # Final output layer
        self.final = nn.Conv2d(base_filters, input_channels, kernel_size=1)

        # Pooling
        self.pool = nn.MaxPool2d(2)
        
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        
        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))
        
        # Decoder with skip connections
        dec4 = self.dec4(torch.cat([self.up4(bottleneck), enc4], dim=1))
        dec3 = self.dec3(torch.cat([self.up3(dec4), enc3], dim=1))
        dec2 = self.dec2(torch.cat([self.up2(dec3), enc2], dim=1))
        dec1 = self.dec1(torch.cat([self.up1(dec2), enc1], dim=1))
        
        return self.final(dec1)


In [None]:

class VSQPReconstructionNetwork(nn.Module):
    def __init__(self, num_layers=5):
        super(VSQPReconstructionNetwork, self).__init__()
        self.num_layers = num_layers
        self.mu = nn.Parameter(torch.tensor(0.1))  # Learnable penalty parameter
        self.regularization_net = UNetDAE()  # Replace wavelet-based regularization with DAE

    def forward(self, x_init, y, mask):
        """
        x_init: Initial estimate of the image (image space)
        y: Undersampled k-space data (in image space, transformed to k-space)
        mask: Sampling mask (applied in k-space)
        """
        x = x_init
        y_k = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(y)))  # Transform y to k-space and apply fftshift

        for i in range(self.num_layers):
            # Regularization step using the DAE
            z = self.regularization_net(x.unsqueeze(1)).squeeze(1)

            # Data consistency step
            x_k = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(x)))  # Transform x to k-space and apply fftshift
            Ax = mask * x_k  # Apply k-space mask
            residual_k = y_k - Ax
            residual = torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(residual_k)))  # Inverse FT to image space
            residual = torch.abs(residual)
            # Combine regularized and data consistency updates
            x = self.mu * z + residual

        return x



def train_network(model, train_dataloader, val_dataloader, num_epochs, learning_rate, device):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        for undersampled, ground_truth in train_dataloader:
            undersampled = undersampled.to(device)
            ground_truth = ground_truth.to(device)           
            ground_truth_k = torch.fft.fftshift(torch.fft.fft2(ground_truth))  # Fully-sampled k-space
            
            optimizer.zero_grad()
            output = model(undersampled, undersampled, mask=torch.tensor(mask_2d, dtype=torch.float32).to(device))
            output_k = torch.fft.fftshift(torch.fft.fft2(output))  # Fully-sampled k-space
            
            ground_truth_k = ground_truth_k / torch.abs(ground_truth_k).max()
            output_k = output_k / torch.abs(output_k).max()

            loss = loss_fn(output, ground_truth)
            #loss = loss_fn(torch.log1p(output_k.abs()), torch.log1p(ground_truth_k.abs()))
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
        
        train_loss /= len(train_dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.10f}")
        
        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for undersampled, ground_truth in val_dataloader:
                undersampled = undersampled.to(device)
                ground_truth = ground_truth.to(device)
                ground_truth_k = torch.fft.fftshift(torch.fft.fft2(ground_truth))  # Fully-sampled k-space               
                output = model(undersampled, undersampled, mask=torch.tensor(mask_2d, dtype=torch.float32).to(device))
                output_k = torch.fft.fftshift(torch.fft.fft2(output))  # Reconstructed k-space
                
                ground_truth_k = ground_truth_k / torch.abs(ground_truth_k).max()
                output_k = output_k / torch.abs(output_k).max()            
                
                loss = loss_fn(output, ground_truth)
                #loss = loss_fn(torch.log1p(output_k.abs()), torch.log1p(ground_truth_k.abs()))# 
                val_loss += loss.item()
        
        val_loss /= len(val_dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.10f}")

            
def save_results_as_numpy(undersampled, output, ground_truth, index, save_dir="results"):
    os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist
    
    np.save(os.path.join(save_dir, f"undersampled_{index}.npy"), undersampled)
    np.save(os.path.join(save_dir, f"output_{index}.npy"), output)
    np.save(os.path.join(save_dir, f"ground_truth_{index}.npy"), ground_truth)
    #print(f"Results for sample {index} saved in {save_dir}")

            
def process_and_save_all_test_images(model, test_dataloader, device, save_dir="results"):
    model.eval()
    os.makedirs(save_dir, exist_ok=True)  # Ensure the save directory exists
    
    with torch.no_grad():
        for i, (undersampled, ground_truth) in enumerate(test_dataloader):
            #print(f"Processing test sample {i}")

            # Move data to device
            undersampled = undersampled.to(device)
            ground_truth = ground_truth.to(device)

            # Forward pass
            output = model(undersampled, undersampled, mask=torch.tensor(mask_2d, dtype=torch.float32).to(device))

            # Convert tensors to numpy arrays
            undersampled_np = undersampled[0].cpu().numpy()  # Move to CPU and convert to NumPy
            output_np = output[0].cpu().numpy()  # Move to CPU and convert to NumPy
            ground_truth_np = ground_truth[0].cpu().numpy()  # Move to CPU and convert to NumPy

            # Save the results
            save_results_as_numpy(undersampled_np, output_np, ground_truth_np, index=i, save_dir=save_dir)

            #print(f"Saved test sample {i} to {save_dir}")
            
            # Free memory
            del undersampled, ground_truth, output
            torch.cuda.empty_cache()




if __name__ == "__main__":
        # Define paths
    undersampled_folder = r"D:\Class Project\209\brain_multicoil_train_batch_1\noisy_images_0.5"
    ground_truth_folder = r"D:\Class Project\209\brain_multicoil_train_batch_1\ground_truth_images"
    
    # Create mask
    height, width = 640, 320  # Example dimensions (adjust as needed)
    np.random.seed(50)
    mask_2d = create_random_mask(height, width, center_fraction=0.2, undersample_fraction=0.5, seed=50)# this need to match the mask used to generate data!!!
    
    # Define dataset
    dataset = MRIDataset(undersampled_folder, ground_truth_folder, mask_2d)
    
    # Split dataset
    train_size = int(0.7 * len(dataset))
    val_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    
    # Create dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    # Initialize model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VSQPReconstructionNetwork(num_layers=5)
    
    # Train model
    train_network(model, train_dataloader, val_dataloader, num_epochs=20, learning_rate=0.001, device=device)
    
    # Test and visualize
    #test_network(model, test_dataloader, device=device)
    
    # Evaluate the model
    #evaluate_network(model, dataloader, device=device)
    process_and_save_all_test_images(model, test_dataloader, device=device, save_dir="results")
    print('complete!')



Filtered dataset contains 2049 samples.


RuntimeError: Input type (struct c10::complex<float>) and bias type (float) should be the same