In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import os
import h5py
import matplotlib.pyplot as plt

In [2]:
# Data directory, relative to the notebook. Built with os.path.join so the
# string has no backslash escapes ("\d", "\s" are invalid escape sequences)
# and the path resolves on both Windows and POSIX.
path = os.path.join(".", "dataset", "singlecoil_val")

In [3]:
def revert_mask(mask):
    """Flip a binary mask so ones become zeros and zeros become ones.

    Computed arithmetically as ``-1 * (mask - 1)``. A boolean array therefore
    comes back as an integer array holding 1 where ``mask`` was False.
    """
    shifted = mask - 1
    return -1 * shifted

def lh_pass_filter(ks, low_radius, high_radius):
    """Zero a 2-D array outside an annulus; return the filtered copy and its mask.

    Distances are measured in array cells from index ``(rows // 2, cols // 2)``.
    Both radii are percentages of half the array diagonal,
    ``np.hypot(rows, cols) / 2``.

    Args:
        ks: 2-D array (e.g. one k-space slice). It is not modified.
        low_radius: inner cutoff in percent. Cells whose squared distance is
            ``<=`` the squared inner radius are zeroed.
        high_radius: outer cutoff in percent. Cells whose squared distance is
            ``>=`` the squared outer radius are zeroed.

    Returns:
        tuple: ``(filtered, keep_mask)``.
            - ``filtered`` is a copy of ``ks`` with the rejected cells set to 0.
            - ``keep_mask`` is an integer array of the same 2-D shape, with 1
              for kept cells and 0 for zeroed ones.

    Raises:
        ValueError: if ``ks`` is not 2-D.

    NOTE(review): because the inner test uses ``<=``, ``low_radius=0`` still
    zeroes the single centre cell (presumably the DC term for centred k-space).
    Confirm this is intended.
    """
    rows, cols = ks.shape
    half_diag = np.hypot(rows, cols) / 2
    l_r = half_diag * low_radius / 100
    h_r = half_diag * high_radius / 100
    # Integer centre; replaces np.floor(...).astype(np.int) -- np.int was removed in NumPy 1.24.
    a, b = rows // 2, cols // 2
    y, x = np.ogrid[-a:rows - a, -b:cols - b]
    dist_sq = x * x + y * y
    rejected = (dist_sq >= h_r * h_r) | (dist_sq <= l_r * l_r)
    # Copy so the caller's array is left intact; callers may still need the unfiltered data.
    filtered = ks.copy()
    filtered[rejected] = 0
    keep_mask = np.logical_not(rejected).astype(int)
    return filtered, keep_mask

def image_from_k(slice_kspace):
    """Reconstruct a magnitude image from k-space over the last two axes.

    Steps:
        1. Undo the centring shift with ``ifftshift``.
        2. Apply a 2-D inverse FFT.
        3. Re-centre the result with ``fftshift``.
        4. Return the magnitude.

    All three transforms act only on the last two axes, so any leading axes
    (e.g. a stack of slices) are treated as a batch and keep their order.

    Args:
        slice_kspace: array whose last two axes hold k-space. Assumed centred
            (zero frequency near the middle) -- TODO confirm for the data.

    Returns:
        np.ndarray: real array of the same shape, ``|fftshift(ifft2(ifftshift(k)))|``.
    """
    spatial_axes = (-2, -1)
    uncentred = np.fft.ifftshift(slice_kspace, axes=spatial_axes)
    image = np.fft.ifft2(uncentred, axes=spatial_axes)
    # Restrict the shift to the spatial axes; shifting every axis would roll
    # a leading batch axis.
    return np.abs(np.fft.fftshift(image, axes=spatial_axes))


def to_tensor(data):
    """Turn a NumPy array into a PyTorch tensor.

    Complex input is split into its real and imaginary parts, which are stacked
    along a new trailing axis of size 2. A tensor argument is returned unchanged.

    Args:
        data (np.array or torch.Tensor): value to convert.
    Returns:
        torch.Tensor: result of ``torch.from_numpy`` on the (possibly stacked) array,
        or ``data`` itself if it was already a tensor.
    """
    if torch.is_tensor(data):
        return data
    if not np.iscomplexobj(data):
        return torch.from_numpy(data)
    real_imag = np.stack((data.real, data.imag), axis=-1)
    return torch.from_numpy(real_imag)


In [4]:
class CustomDataSet(Dataset):
    """Map-style dataset: one k-space slice per HDF5 file in ``main_dir``.

    Each item does three things:
        - reads slice ``slice_n`` of the ``"kspace"`` dataset from one file;
        - reconstructs that slice's magnitude image;
        - reconstructs a second image from a band-filtered copy of the slice
          (see ``lh_pass_filter``).

    Args:
        main_dir: directory of HDF5 files. Every entry returned by
            ``os.listdir`` is treated as one file; no extension filter is applied.
        slice_n: index along the first axis of ``"kspace"`` to load
            (assumed to be the slice axis -- TODO confirm against the data).
        freq_l: inner filter radius in percent, forwarded to ``lh_pass_filter``.
        freq_r: outer filter radius in percent, forwarded to ``lh_pass_filter``.
        transform: optional callable applied to the sample tuple before it is
            returned. ``None`` (the default) returns the tuple unchanged.
    """

    def __init__(self, main_dir, slice_n=0, freq_l=0, freq_r=20, transform=None):
        self.main_dir = main_dir
        # os.listdir order is arbitrary; sort so the index -> file mapping is reproducible.
        self.total_imgs = sorted(os.listdir(main_dir))
        self.slice_n = slice_n
        self.freq_l = freq_l
        self.freq_r = freq_r
        self.transform = transform

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Return ``(image, known_freq, known_image, mask)`` for file ``idx``.

        - image: magnitude reconstruction of the band-filtered slice.
        - known_freq: the unfiltered k-space slice as read from the file.
        - known_image: magnitude reconstruction of ``known_freq``.
        - mask: keep-mask returned by ``lh_pass_filter`` (1 = kept).

        If ``transform`` was given, its result is returned instead of the tuple.
        """
        file_path = os.path.join(self.main_dir, self.total_imgs[idx])
        # Indexing the h5py dataset reads the slice into a NumPy array, so the
        # file can be closed right away instead of leaking one handle per item.
        with h5py.File(file_path, "r") as hf:
            known_freq = hf["kspace"][self.slice_n]
        known_image = image_from_k(known_freq)
        # lh_pass_filter may write zeros into the array it is given; filter a
        # copy so known_freq in the sample stays the unfiltered slice.
        ks, mask = lh_pass_filter(known_freq.copy(), self.freq_l, self.freq_r)
        image = image_from_k(ks)
        sample = (image, known_freq, known_image, mask)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

In [5]:
# Validation data: slice 20 of every file under `path`, with default filter radii.
my_dataset = CustomDataSet(path, slice_n=20)
# One sample per batch, in the dataset's order (no shuffling).
# NOTE(review): k-space sizes may differ between files, which would make batch_size > 1
# fail to collate -- confirm before increasing it.
val_loader = torch.utils.data.DataLoader(
    my_dataset,
    batch_size=1,
    shuffle=False,
)

In [7]:
# for image, known_freq, known_image, mask in val_loader:
#     print(image.shape,known_freq.shape, known_image.shape, mask.shape)
#     plt.imshow(mask[0,:,:])