In [1]:
import os
import random
from pathlib import Path
from typing import Union

import h5py
import numpy as np
import torch
from datasets import *
from data_utils import *
from fastmri.data.subsample import EquiSpacedMaskFunc, RandomMaskFunc
from fastmri.data.transforms import tensor_to_complex_np, to_tensor
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import config
from torch.utils.data import DataLoader

file_data = '/itet-stor/mcrespo/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train/file_brain_AXT1POST_203_6000861.h5'
dataset = KCoordDataset(file_data, n_slices=3, n_volumes=1, with_mask=True, acceleration=3, center_frac=0.15, mask_type='Random')




training w/o center


In [93]:
from fastmri.data.transforms import tensor_to_complex_np, to_tensor

vol_id = 0
file = file_data
n_volumes = 1
n_slices = 1
with_mask = False
acceleration = 3
center_frac = 0.15
mask_type = 'Random'


with h5py.File(file, "r") as hf:
    volume_kspace = to_tensor(preprocess_kspace(hf["kspace"][()]))[:n_slices]

##################################################
# Mask creation
##################################################
if mask_type == "Random":
    mask_func = RandomMaskFunc(
    center_fractions=[center_frac], accelerations=[acceleration]
)
elif mask_type == "Equispaced": 
    mask_func = EquiSpacedMaskFunc(
    center_fractions=[center_frac], accelerations=[acceleration])
    
shape = (1,) * len(volume_kspace.shape[:-3]) + tuple(
    volume_kspace.shape[-3:])
mask, _ = mask_func(
    shape, None, vol_id
)  # use the volume index as random seed.

# mask, left_idx, right_idx = remove_center(mask)
_, left_idx, right_idx = remove_center(mask)  # NOTE: Uncomment to include the center region in the training data. Note that 'left_idx' and 'right_idx' are still needed.

##################################################
# Computing the indices
##################################################
n_slices, n_coils, height, width = volume_kspace.shape[:-1]
if with_mask:
    kx_ids = torch.where(mask.squeeze())[0]
else:
    kx_ids = torch.arange(width)
    # kx_ids = torch.from_numpy(np.setdiff1d(np.arange(width), np.arange(left_idx, right_idx))) # NOTE: Uncomment to include all the datapoints (fully-sampled volume), with the exception of the center region.
ky_ids = torch.arange(height)
kz_ids = torch.arange(n_slices)
coil_ids = torch.arange(n_coils)

kspace_ids = torch.meshgrid(kx_ids, ky_ids, coil_ids, indexing="ij")
kspace_ids = torch.stack(kspace_ids, dim=-1).reshape(-1, len(kspace_ids))

##################################################
# Computing the inputs
##################################################
# Convert indices into normalized coordinates in [-1, 1].
kspace_coords = torch.zeros((kspace_ids.shape[0], 3), dtype=torch.float)
kspace_coords[:, 0] = (2 * kspace_ids[:, 0]) / (width - 1) - 1
kspace_coords[:, 1] = (2 * kspace_ids[:, 1]) / (height - 1) - 1
kspace_coords[:, 2] = (2 * kspace_ids[:, 2]) / (n_coils - 1) - 1
# kspace_coords[:, 3] = (2 * kspace_ids[:, 3]) / (n_coils - 1) - 1

In [86]:
left_idx
right_idx

184

In [98]:
not_center_idx = kx_ids[left_idx:right_idx]
mask_kx = torch.ones(kx_ids.shape)
mask_kx[not_center_idx] = 0

kx_new = kx_ids.clone()
kx_new = torch.where(mask_kx.squeeze())[0]

In [4]:
def normalize (data, norm_factor):
    """Function that normalizes a data matrix to the range [-1,1]"""
    n_data = (2*data) / (norm_factor - 1) - 1 
    return n_data

def denormalize (n_data, norm_factor):
    """Function that reverts a normalized data matrix to the original range, specified by norm_factor"""
    data = ((n_data + 1) * (norm_factor - 1))/2
    return data

def split_batch (data, size_minibatch):
    """Function that performs the random spliting of the dataloader batch into Ns subsets of generally the same size"""
    total_batch = data.shape[0]
    iter = total_batch//size_minibatch
    sample_batch = []
    last_idx = 0
    
    for i in range(iter):
        if i == 0:
            minibatch = data[:size_minibatch,...]
        elif i==iter-1:
            minibatch = data[last_idx+1:,...]
        else:
            minibatch = data[last_idx+1:last_idx+size_minibatch,...]
        
        sample_batch.append(minibatch)
        last_idx += size_minibatch
    return sample_batch, iter

def compute_Lsquares (X, Y, alpha):
    """Solves the Least Squares giving matrix W"""
    P_TxP = torch.matmul(X.T, X)
    P_TxT = torch.matmul(X.T, Y)

    reg = alpha * torch.eye(P_TxP.shape[0])
    W = torch.linalg.solve(P_TxP + reg, P_TxT)
    return W

def complex_distance(W1, W2):
    """
    Computes the L2 distance between two complex matrices W1 and W2.
    It compares both real and imaginary parts.
    """
    # Separate real and imaginary parts
    W1_real, W1_imag = W1.real, W1.imag
    W2_real, W2_imag = W2.real, W2.imag

    # Compute the squared differences for both real and imaginary parts
    real_diff = torch.norm(W1_real - W2_real) ** 2
    # print(real_diff)
    imag_diff = torch.norm(W1_imag - W2_imag) ** 2
    # print(imag_diff)

    # Return the combined distance
    return real_diff + imag_diff


def L_pisco (Ws):
    """Function to compute the Pisco loss
    Inputs:
    - Ws (list) : contains the corresponding Ws computed from Least squares
    
    """
    # Compare the Ws, obtain the Pisco loss
    total_loss = 0
    Ns = len(Ws)
    for i in range(Ns):
        for j in range(Ns):
            if i!=j:
                diff = Ws[i].flatten() - Ws[j].flatten()
                pisco = torch.linalg.norm(diff,ord=2)
                total_loss += pisco
                
    return (1/Ns**2) * total_loss

def get_grappa_matrixes (inputs, shape):
    """Function that generates two matrixes out of the input coordinates of the batch points     
    - n_r_kcoors : normalized and reshaped matrix containing the kspace coordinates 
        dim -> (Nm x Nc x 4)
    - n_r_patch : normalized and reshaped matrix containing the kspace coordinates of the neighbourhood for each point in first matrix
        dim -> (Nm·Nn x Nc x 4)
    """
    n_slices, n_coils, height, width = shape
    k_coors = torch.zeros((inputs.shape[0], 4), dtype=torch.float)
    k_coors[:,0] = denormalize(inputs[:,0], width)
    k_coors[:,1] = denormalize(inputs[:,1], height)
    k_coors[:,2] = denormalize(inputs[:,2], n_slices)
    k_coors[:,3] = denormalize(inputs[:,3], n_coils)

    #### Reshape:
    # Reshape input matrixes for coilID to be considered dim : n_points x N_coils x 4
    r_kcoors = np.repeat(k_coors[:, np.newaxis, :], n_coils, axis=1)
    r_kcoors[...,-1] = torch.arange(n_coils)

    
    ##### Reshape patches matrix to : n_points x n_neighbours x N_coils x 4
    build_neighbours = get_patch()
    patch_coors = build_neighbours(r_kcoors)
    
    # Reshape so that dim : n_points x N_n x Nc x 4 (kx,ky,kz, n_coils coordinates)
    r_patch = torch.zeros((patch_coors.shape[0],patch_coors.shape[1], r_kcoors.shape[2]))
    r_patch[...,:3] = patch_coors
    r_patch = np.repeat(r_patch[:, :, np.newaxis], n_coils, axis=2)
    r_patch[...,-1] = torch.arange(n_coils)

    ### For predicting, normalize coordinates back to [-1,1]
    # Normalize the NP neighbourhood coordinates
    n_r_patch = torch.zeros((r_patch.shape), dtype=torch.float)
    n_r_patch[:,:,:,0] = normalize(r_patch[:,:,:,0], width)
    n_r_patch[:,:,:,1] = normalize(r_patch[:,:,:,1], height)
    n_r_patch[:,:,:,2] = normalize(r_patch[:,:,:,2], n_slices)
    n_r_patch[:,:,:,3] = normalize(r_patch[:,:,:,3], n_coils)
    # Flatten the first dimensions for the purpose of kvalue prediction
    Nn = n_r_patch.shape[1]
    n_r_patch = n_r_patch.view(-1, n_coils, 4)

    # Normalize the Nt targets coordinates
    n_r_koors = torch.zeros((r_kcoors.shape), dtype=torch.float)
    n_r_koors[:,:,0] = normalize(r_kcoors[:,:,0], width)
    n_r_koors[:,:,1] = normalize(r_kcoors[:,:,1], height)
    n_r_koors[:,:,2] = normalize(r_kcoors[:,:,2], n_slices)
    n_r_koors[:,:,3] = normalize(r_kcoors[:,:,3], n_coils)
    
    return n_r_koors, n_r_patch, Nn


class get_patch:
    def __init__(
        self, 
        width = 320,
        height = 320,
        patch_size=9, 
        ):
        
        self.width = width
        self.height = height
        self.patch_size = patch_size
        
        super().__init__()
    
    def forward(self, batch_coors: torch.Tensor) -> torch.Tensor:
        """Returns the 3x3 neighbors for all points in a batch.
        Inputs : 
        - batch_coors : matrix of dimension batch_size x 4 denormalized coordinates (kx,ky,kz,coilid)
        """
        
        shifts = torch.tensor([[-1, -1], [0, -1], [1, -1],
                [ -1, 0], [ 1, 0],
                [ -1, 1], [ 0, 1], [ 1, 1]], device=batch_coors.device)  

        # Extract kx, ky from k_coor
        kx = batch_coors[:,:,0][:,0].unsqueeze(1)  # shape: (batch_size, 1)
        ky = batch_coors[:,:,1][:,1].unsqueeze(1)  # shape: (batch_size, 1)
        kz = batch_coors[:,:,2][:,0].unsqueeze(1)  # shape: (batch_size, 1)
        
        # Compute all neighbor shifts at once (apply shifts to kx, ky)
        kx_neighbors = torch.clamp(kx + shifts[:, 0], 0, self.width - 1)
        ky_neighbors = torch.clamp(ky + shifts[:, 1], 0, self.height - 1)
        
        # Ouput of neighbors dim : batch_size x nneighbors x 3 coordinates (kx, ky, kz)
        neighbors = torch.stack([kx_neighbors, ky_neighbors, kz.repeat(1, self.patch_size-1)], dim=-1)
        return neighbors
    
    def __call__(self, batch_coors: torch.Tensor) -> torch.Tensor:
        return self.forward(batch_coors)
    

In [74]:
### Get the grid for computing PISCO 
dataloader = DataLoader(dataset, batch_size=120000, shuffle=True, pin_memory=False)
count = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_coils = 4
n_slices = 3
width = 320
height = 320

for inputs, targets in dataloader:
    inputs, targets = inputs.to(device), targets.to(device)
    count += 1
    if count > 1:
        break

k_coors = torch.zeros((inputs.shape[0], 4), dtype=torch.float)
k_coors[:,0] = denormalize(inputs[:,0], width)
k_coors[:,1] = denormalize(inputs[:,1], height)
k_coors[:,2] = denormalize(inputs[:,2], n_slices)
k_coors[:,3] = denormalize(inputs[:,3], n_coils)


# Remove the edges 
leftmost_vedge = (k_coors[:, 1] == 0)
rightmost_vedge = (k_coors[:, 1] == 319)
upmost_vedge = (k_coors[:, 0] == 0)
downmost_vedge = (k_coors[:, 0] == 319)

edges = leftmost_vedge | rightmost_vedge | upmost_vedge | downmost_vedge
k_nedge = k_coors[~edges]

# #### Reshape:
# # Reshape input matrixes for coilID to be considered dim : n_points x N_coils x 4
r_kcoors = np.repeat(k_nedge[:, np.newaxis, :], n_coils, axis=1)
r_kcoors[...,-1] = torch.arange(n_coils)

# ##### Reshape patches matrix to : n_points x n_neighbours x N_coils x 4
build_neighbours = get_patch()
patch_coors = build_neighbours(r_kcoors)

# Reshape so that dim : n_points x N_n x Nc x 4 (kx,ky,kz, n_coils coordinates)
r_patch = torch.zeros((patch_coors.shape[0],patch_coors.shape[1], r_kcoors.shape[2]))
r_patch[...,:3] = patch_coors
r_patch = np.repeat(r_patch[:, :, np.newaxis], n_coils, axis=2)
r_patch[...,-1] = torch.arange(n_coils)

### For predicting, normalize coordinates back to [-1,1]
# Normalize the NP neighbourhood coordinates
n_r_patch = torch.zeros((r_patch.shape), dtype=torch.float)
n_r_patch[:,:,:,0] = normalize(r_patch[:,:,:,0], width)
n_r_patch[:,:,:,1] = normalize(r_patch[:,:,:,1], height)
n_r_patch[:,:,:,2] = normalize(r_patch[:,:,:,2], n_slices)
n_r_patch[:,:,:,3] = normalize(r_patch[:,:,:,3], n_coils)
# Flatten the first dimensions for the purpose of kvalue prediction
Nn = n_r_patch.shape[1]
n_r_patch = n_r_patch.view(-1, n_coils, 4)

# Normalize the Nt targets coordinates
n_r_koors = torch.zeros((r_kcoors.shape), dtype=torch.float)
n_r_koors[:,:,0] = normalize(r_kcoors[:,:,0], width)
n_r_koors[:,:,1] = normalize(r_kcoors[:,:,1], height)
n_r_koors[:,:,2] = normalize(r_kcoors[:,:,2], n_slices)
n_r_koors[:,:,3] = normalize(r_kcoors[:,:,3], n_coils)

In [76]:
torch.min(r_patch)

tensor(0.)

In [5]:
from model import *
model = Siren()
size_minibatch = 1000

t_predicted = torch.zeros((n_r_koors.shape[0], n_coils), dtype=torch.complex64)
patch_predicted = torch.zeros((n_r_patch.shape[0], n_coils), dtype=torch.complex64)

for coil_id in range(n_coils):
    t_predicted[:,coil_id] = torch.view_as_complex(model(n_r_koors[:,coil_id,:]))
    patch_predicted[:,coil_id] = torch.view_as_complex(model(n_r_patch[:,coil_id,:]))

# # Reshape back the patches_matrix
patch_predicted = patch_predicted.view(n_r_koors.shape[0], Nn, n_coils)

# size_minibatch = 300
T_s, Ns = split_batch(t_predicted, size_minibatch)
P_s, Ns = split_batch(patch_predicted, size_minibatch)


# ######## Here compute the Lpisco

In [7]:
T_s[17]

tensor([[-0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j],
        [-0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j],
        [-0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j],
        ...,
        [-0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j],
        [-0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j],
        [-0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j, -0.0218+0.0151j]],
       grad_fn=<SliceBackward0>)

In [42]:
## L pisco
##################################
alpha = 1.e-4
Ws = []

# Generate the list of Ws from the subset of minibatches 
for i, t_s in enumerate(T_s):
    p_s = P_s[i]
    p_s = torch.flatten(p_s, start_dim=1)
    print()
    ws = compute_Lsquares(p_s, t_s, alpha)
    print(ws.shape)
    Ws.append(ws)


pisco_loss = L_pisco (Ws) # Ws is a list of Ws' from the minibatches

print(pisco_loss)


torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])

torch.Size([32, 4])
tensor(1.9872e-05, grad_fn=<MulBackward0>)


In [43]:
## Measure distortion in Ws
tensor_magnitudes = [torch.abs(tensor) for tensor in Ws]
stacked_tensors = torch.stack(tensor_magnitudes)
std_dev_across_tensors = torch.std(stacked_tensors, dim=0)
torch.norm(std_dev_across_tensors)

tensor(0.0020, grad_fn=<LinalgVectorNormBackward0>)