In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import random_split
import cv2
import pywt


# Helper function for loading TIFF images
def load_tiff_images(folder_path):
    images = []
    for file_name in sorted(os.listdir(folder_path)):
        if file_name.endswith(".tif") or file_name.endswith(".tiff"):
            img = Image.open(os.path.join(folder_path, file_name))
            images.append(np.array(img, dtype=np.float32))
    return images


In [2]:
def create_random_mask(height, width, center_fraction=0.1, undersample_fraction=0.1, seed=50):
    """
    Create a 2D random mask with a fully sampled center region.

    Args:
        height: The height of the mask (frequency encoding dimension).
        width: The width of the mask (phase encoding dimension).
        center_fraction: Fraction of the image to be fully sampled at the center.
        undersample_fraction: Fraction of the remaining k-space to be sampled randomly.

    Returns:
        mask_2d: A binary mask with the same shape as the k-space data.
    """
    if seed is not None:
        np.random.seed(seed)  # Set the random seed locally
        mask = np.zeros((height, width))
    
    # Fully sample the center region
    center_height = int(height * center_fraction)
    center_width = int(width * center_fraction)
    center_start_h = (height - center_height) // 2
    center_start_w = (width - center_width) // 2
    mask[center_start_h:center_start_h + center_height, center_start_w:center_start_w + center_width] = 1
    
    # Randomly sample the remaining k-space
    remaining_mask = np.random.choice([0, 1], size=(height, width), p=[undersample_fraction, 1-undersample_fraction])
    mask = np.maximum(mask, remaining_mask)  # Ensure the center is fully sampled
    
    return mask

In [3]:
class MRIDataset(Dataset):
    def __init__(self, undersampled_folder, ground_truth_folder, mask):
        self.undersampled_images = load_tiff_images(undersampled_folder)
        self.ground_truth_images = load_tiff_images(ground_truth_folder)
        self.mask = mask
        
        # Ensure consistent shapes by filtering
        self.filtered_undersampled_images = []
        self.filtered_ground_truth_images = []

        for undersampled, ground_truth in zip(self.undersampled_images, self.ground_truth_images):
            if ground_truth.shape == (640, 320) and undersampled.shape == (640, 320):
                self.filtered_undersampled_images.append(undersampled)
                self.filtered_ground_truth_images.append(ground_truth)

        print(f"Filtered dataset contains {len(self.filtered_undersampled_images)} samples.")

    def __len__(self):
        return len(self.filtered_undersampled_images)
    
    def __getitem__(self, idx):
        undersampled = self.filtered_undersampled_images[idx]
        ground_truth = self.filtered_ground_truth_images[idx]
        
        # Normalize images
        undersampled = undersampled / np.max(undersampled)
        ground_truth = ground_truth / np.max(ground_truth)
        
        return (
            torch.tensor(undersampled, dtype=torch.float32),  # Apply mask
            torch.tensor(ground_truth, dtype=torch.float32),
        )


In [4]:

class VSQPReconstructionNetwork(nn.Module):
    def __init__(self, num_layers=5):
        super(VSQPReconstructionNetwork, self).__init__()
        self.num_layers = num_layers
        self.mu = nn.Parameter(torch.tensor(0.1))  # Learnable penalty parameter
        
        # Use nn.ParameterList to store the regularization thresholds
        self.regularizers = nn.ParameterList([
            nn.Parameter(torch.tensor(0.1)) for _ in range(num_layers)  # Thresholds for soft-thresholding
        ])
    
    def forward(self, x_init, y, mask):
        """
        x_init: Initial estimate of the image (image space)
        y: Undersampled k-space data (in image space, transformed to k-space)
        mask: Sampling mask (applied in k-space)
        """
        x = x_init
        y_k = torch.fft.fftshift(torch.fft.fft2(y))  # Transform y to k-space and apply fftshift

        for i in range(self.num_layers):
            # Regularization step (proximal operator - soft-thresholding)
            z = self.apply_wavelet_soft_threshold(x, self.regularizers[i])

            # Data consistency step
            # Step 1: Forward model (apply mask in k-space)
            x_k = torch.fft.fftshift(torch.fft.fft2(x))  # Transform x to k-space and apply fftshift
            Ax = mask * x_k  # Apply k-space mask
            
            # Step 2: Residual in k-space
            residual_k = y_k - Ax
            
            # Step 3: Backproject residual (to image space)
            residual_k_shifted = mask * residual_k  # Apply mask again (if needed)
            residual = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(residual_k_shifted)).real)  # Inverse FT to image space

            # Combine regularized and data consistency updates
            x = self.mu * z + residual

        return x

    def apply_wavelet_soft_threshold(self, x, threshold):
        # Detach and move the tensor to CPU before converting to NumPy
        x_np = x.detach().cpu().numpy()
        
        # Wavelet decomposition
        coeffs = pywt.wavedec2(x_np, wavelet='haar')
        
        # Apply soft thresholding to coefficients
        thresholded_coeffs = [
            pywt.threshold(c, threshold.item(), mode='soft') if isinstance(c, np.ndarray) else c
            for c in coeffs
        ]
        
        # Wavelet reconstruction
        x_thresholded = pywt.waverec2(thresholded_coeffs, wavelet='haar')
        
        # Convert back to a PyTorch tensor
        return torch.tensor(x_thresholded, device=x.device, dtype=torch.float32)




def train_network(model, train_dataloader, val_dataloader, num_epochs, learning_rate, device, loss_ratio = 0.5):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        for undersampled, ground_truth in train_dataloader:
            undersampled = undersampled.to(device)
            ground_truth = ground_truth.to(device)           
            ground_truth_k = torch.fft.fftshift(torch.fft.fft2(ground_truth))  # Fully-sampled k-space
            
            optimizer.zero_grad()
            output = model(undersampled, undersampled, mask=torch.tensor(mask_2d, dtype=torch.float32).to(device))
            output_k = torch.fft.fftshift(torch.fft.fft2(output))  # Fully-sampled k-space
            
            ground_truth_k = ground_truth_k / torch.abs(ground_truth_k).max()
            output_k = output_k / torch.abs(output_k).max()

            loss = loss_fn(output, ground_truth)
            #loss = loss_fn(torch.log1p(output_k.abs()), torch.log1p(ground_truth_k.abs()))
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
        
        train_loss /= len(train_dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.10f}")
        
        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for undersampled, ground_truth in val_dataloader:
                undersampled = undersampled.to(device)
                ground_truth = ground_truth.to(device)
                ground_truth_k = torch.fft.fftshift(torch.fft.fft2(ground_truth))  # Fully-sampled k-space               
                output = model(undersampled, undersampled, mask=torch.tensor(mask_2d, dtype=torch.float32).to(device))
                output_k = torch.fft.fftshift(torch.fft.fft2(output))  # Reconstructed k-space
                
                ground_truth_k = ground_truth_k / torch.abs(ground_truth_k).max()
                output_k = output_k / torch.abs(output_k).max()            
                
                loss = loss_fn(output, ground_truth)
            
                #loss = loss_fn(torch.log1p(output_k.abs()), torch.log1p(ground_truth_k.abs()))
                val_loss += loss.item()
        
        val_loss /= len(val_dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss:.10f}")

            
def save_results_as_numpy(undersampled, output, ground_truth, index, save_dir="results"):
    os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist
    
    np.save(os.path.join(save_dir, f"undersampled_{index}.npy"), undersampled)
    np.save(os.path.join(save_dir, f"output_{index}.npy"), output)
    np.save(os.path.join(save_dir, f"ground_truth_{index}.npy"), ground_truth)
    #print(f"Results for sample {index} saved in {save_dir}")

            
def process_and_save_all_test_images(model, test_dataloader, device, save_dir="results"):
    model.eval()
    os.makedirs(save_dir, exist_ok=True)  # Ensure the save directory exists
    
    with torch.no_grad():
        for i, (undersampled, ground_truth) in enumerate(test_dataloader):
            print(f"Processing test sample {i}")

            # Move data to device
            undersampled = undersampled.to(device)
            ground_truth = ground_truth.to(device)

            # Forward pass
            output = model(undersampled, undersampled, mask=torch.tensor(mask_2d, dtype=torch.float32).to(device))

            # Convert tensors to numpy arrays
            undersampled_np = undersampled[0].cpu().numpy()  # Move to CPU and convert to NumPy
            output_np = output[0].cpu().numpy()  # Move to CPU and convert to NumPy
            ground_truth_np = ground_truth[0].cpu().numpy()  # Move to CPU and convert to NumPy

            # Save the results
            save_results_as_numpy(undersampled_np, output_np, ground_truth_np, index=i, save_dir=save_dir)

            #print(f"Saved test sample {i} to {save_dir}")
            
            # Free memory
            del undersampled, ground_truth, output
            torch.cuda.empty_cache()




if __name__ == "__main__":
        # Define paths
    undersampled_folder = r"D:\Class Project\209\brain_multicoil_train_batch_1\noisy_images_0.1"
    ground_truth_folder = r"D:\Class Project\209\brain_multicoil_train_batch_1\ground_truth_images"
    
    # Create mask
    height, width = 640, 320  # Example dimensions (adjust as needed)
    np.random.seed(50)
    mask_2d = create_random_mask(height, width, center_fraction=0.1, undersample_fraction=0.3, seed=50)# this need to match the mask used to generate data!!!
    
    # Define dataset
    dataset = MRIDataset(undersampled_folder, ground_truth_folder, mask_2d)
    
    # Split dataset
    train_size = int(0.7 * len(dataset))
    val_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    
    # Create dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    # Initialize model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VSQPReconstructionNetwork(num_layers=5)
    
    # Train model
    train_network(model, train_dataloader, val_dataloader, num_epochs=20, learning_rate=0.001, device=device)
    
    # Test and visualize
    #test_network(model, test_dataloader, device=device)
    
    # Evaluate the model
    #evaluate_network(model, dataloader, device=device)
    process_and_save_all_test_images(model, test_dataloader, device=device, save_dir="results")



Filtered dataset contains 2049 samples.
Epoch [1/20], Training Loss: 0.2774887845
Epoch [1/20], Validation Loss: 0.2263909994
Epoch [2/20], Training Loss: 0.2397465750
Epoch [2/20], Validation Loss: 0.2319755329
Epoch [3/20], Training Loss: 0.2439987736
Epoch [3/20], Validation Loss: 0.2325923015
Epoch [4/20], Training Loss: 0.2440936961
Epoch [4/20], Validation Loss: 0.2329993158
Epoch [5/20], Training Loss: 0.2443499697
Epoch [5/20], Validation Loss: 0.2326943576
Epoch [6/20], Training Loss: 0.2442412279
Epoch [6/20], Validation Loss: 0.2327189903
Epoch [7/20], Training Loss: 0.2442833703
Epoch [7/20], Validation Loss: 0.2328094497
Epoch [8/20], Training Loss: 0.2442770720
Epoch [8/20], Validation Loss: 0.2328458433
Epoch [9/20], Training Loss: 0.2444667960
Epoch [9/20], Validation Loss: 0.2331820394
Epoch [10/20], Training Loss: 0.2444678399
Epoch [10/20], Validation Loss: 0.2323507711


KeyboardInterrupt: 