<a href="https://colab.research.google.com/github/chardave/BEng-Research-Project/blob/main/RAKI_test2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#mount to google drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#import print_function
import tensorflow as tf
from tensorflow.keras import Model, layers, optimizers
import scipy.io as sio
import numpy as np
#from tensorflow.examples.tutorials.mnist import input_data
import time
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import pandas as pd

In [3]:
import pandas as pd
import numpy as np

# Filepaths for the Excel files
k_sampled_file = "/content/drive/MyDrive/MEng group project/imagescan_data_test2.xlsx"
ref_scan_file = "/content/drive/MyDrive/MEng group project/refscan_data_test2.xlsx"

# Number of coils and slices
num_coils = 4
num_slices = 4

# Function to load and reshape complex data
def load_complex_data_3d(file, prefix):
    # Initialize a dictionary to store 3D matrices for each slice
    slice_data = {}

    for slice_idx in range(1, num_slices + 1):
        # Create a 3D matrix for this slice (x, y, coil)
        slice_matrix = []

        for coil in range(1, num_coils + 1):
            # Construct sheet names
            real_sheet = f"real_{prefix}_c{coil}_s{slice_idx}"
            imag_sheet = f"imag_{prefix}_c{coil}_s{slice_idx}"

            # Read the real and imaginary parts
            real_part = pd.read_excel(file, sheet_name=real_sheet, header=None).to_numpy()
            imag_part = pd.read_excel(file, sheet_name=imag_sheet, header=None).to_numpy()

            # Combine into a complex matrix and append to the slice
            coil_matrix = real_part + 1j * imag_part
            slice_matrix.append(coil_matrix)

        # Stack along the third dimension (coil axis)
        slice_data[f"slice_{slice_idx}"] = np.stack(slice_matrix, axis=-1)

    return slice_data

# Load k-sampled and reference data into 3D matrices
k_sampled_slices = load_complex_data_3d(k_sampled_file, "imagedat")
ref_scan_slices = load_complex_data_3d(ref_scan_file, "refdat")

# Example access: k_sampled_slices["slice_1"] or ref_scan_slices["slice_4"]
# Printing the shape of one of the loaded slices
print(f"Shape of k-sampled slice 1: {k_sampled_slices['slice_1'].shape}")
print(f"Shape of ref-scan slice 4: {ref_scan_slices['slice_4'].shape}")

# Output: Each key corresponds to a 192x133x4 complex matrix.

Shape of k-sampled slice 1: (192, 133, 4)
Shape of ref-scan slice 4: (128, 50, 4)


define functions to get the weight and bias and convolution

In [4]:
# weight initialisation function
def weight_variable(shape,vari_name):
    initial = tf.random.normal(shape, mean=0.0, stddev=0.1,dtype=tf.float32)
    return tf.Variable(initial,name = vari_name)
    # neural network weights need to start with random values before training,
    # which helps in breaking symmetry and ensuring the model learns meanigful features

# bias initialisation function
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape,dtype=tf.float32)
    return tf.Variable(initial)
    # used to shift the output of the neurons and provide more flexibility to the model
    # initialising bias with a small positive value helps avod issues like 'dying RELU'
    # where all neurons are stuck in the zero state due to negative biases

# standard 2D convolution function
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
    # strides defines how te filter moves across the input tensor - moves by 1 in both height
    # width dimensions with no padding in the batch and channel dimensions
    # padding = 'Valid' - output size will be smaller than the input size

# dilated 2D convolution function
def conv2d_dilate(x, W,dilate_rate):
    return tf.nn.convolution(x, W, dilations = [1,dilate_rate])
    # increases receptive field of the convolutional filter without
    # increasing the number of parameters


In [5]:
# Define the learning function
def learning(ACS_input, target_input, accrate_input):
    # Define input dimensions based on the inputs
    ACS_dim_X, ACS_dim_Y, ACS_dim_Z = ACS_input.shape[1:]
    target_dim_X, target_dim_Y, target_dim_Z = target_input.shape[1:]
    # Ensure target and ACS have matching shapes
    # by cropping/padding as necessary.

    diff_X = ACS_dim_X - target_dim_X
    diff_Y = ACS_dim_Y - target_dim_Y

    # Handle X dimension
    if diff_X > 0:  # ACS is larger, crop ACS
        start_X = diff_X // 2
        end_X = start_X + target_dim_X
        ACS_input = ACS_input[:, start_X:end_X, :, :]
    elif diff_X < 0:  # Target is larger, pad ACS
        pad_X = abs(diff_X) // 2
        ACS_input = np.pad(ACS_input, ((0, 0), (pad_X, pad_X), (0, 0), (0, 0)), mode='constant')

    # Handle Y dimension
    if diff_Y > 0:  # ACS is larger, crop ACS
        start_Y = diff_Y // 2
        end_Y = start_Y + target_dim_Y
        ACS_input = ACS_input[:, :, start_Y:end_Y, :]
    elif diff_Y < 0:  # Target is larger, pad ACS
        pad_Y = abs(diff_Y) // 2
        ACS_input = np.pad(ACS_input, ((0, 0), (0, 0), (pad_Y, pad_Y), (0, 0)), mode='constant')

    # Update dimensions after cropping/padding
    ACS_dim_X, ACS_dim_Y, ACS_dim_Z = ACS_input.shape[1:]

    # Build the model
    inputs = tf.keras.Input(shape=(ACS_dim_X, ACS_dim_Y, ACS_dim_Z))

    # Convolutional Layers
    conv1 = layers.Conv2D(
        filters=layer1_channels,
        kernel_size=(kernel_x_1, kernel_y_1),
        activation='relu',
        padding='same',
        dilation_rate=(1, accrate_input)
    )(inputs)

    conv2 = layers.Conv2D(
        filters=layer2_channels,
        kernel_size=(kernel_x_2, kernel_y_2),
        activation='relu',
        padding='same',
        dilation_rate=(1, accrate_input)
    )(conv1)

    conv3 = layers.Conv2D(
        filters=target_dim_Z,
        kernel_size=(kernel_last_x, kernel_last_y),
        activation=None,
        padding='same',
        dilation_rate=(1, accrate_input)
    )(conv2)

    # Define the model
    model = Model(inputs=inputs, outputs=conv3)

    # Compile the model
    optimizer = optimizers.Adam(learning_rate=LearningRate)
    model.compile(optimizer=optimizer, loss='mse')

    # Train the model
    history = model.fit(
        x=ACS_input,
        y=target_input,
        batch_size=1,
        epochs=MaxIteration,
        verbose=1
    )

    # Get the trained weights
    W_conv1, W_conv2, W_conv3 = [layer.get_weights()[0] for layer in model.layers if 'conv2d' in layer.name]

    # Final error (loss) after training
    error = history.history['loss'][-1]

    return [W_conv1, W_conv2, W_conv3, error]

Define the network architecture

In [18]:
def cnn_3layer(input_kspace,w1,b1,w2,b2,w3,b3,acc_rate):
    h_conv1 = tf.nn.relu(conv2d_dilate(input_kspace, w1,acc_rate))
    h_conv2 = tf.nn.relu(conv2d_dilate(h_conv1, w2,acc_rate))
    h_conv3 = conv2d_dilate(h_conv2, w3,acc_rate)
    return h_conv3

In [30]:
###################### Reconstruction Parameters ######################

#### Network Parameters ####
kernel_x_1 = 3
kernel_y_1 = 3

kernel_x_2 = 3
kernel_y_2 = 3

kernel_last_x = 1
kernel_last_y = 1

layer1_channels = 16
layer2_channels = 8

MaxIteration = 1
LearningRate = 1e-3

#### Input/Output Data ####
#inputData = 'rawdata.mat'
#input_variable_name = 'kspace'
slices = list(k_sampled_slices.keys()) # list of slice keys
resultName = 'RAKI_recon'
recon_variable_name = 'kspace_recon'

######################################################################

for slice_name in slices:
  print(f"Processing {slice_name}...")

  # Extract the slice data
  kspace = k_sampled_slices[slice_name]


  # normalise the k-space data
  normalise = 0.015/np.max(abs(kspace[:]))
  kspace = np.multiply(kspace,normalise)
  #ACS = np.multiply(ACS, normalise)

  # save dimension of the kspace data and number of coils -
  # should give m1=192, n1=133 and no_ch=4
  [m1, n1, no_ch] = np.shape(kspace)
  no_inds = 1
  kspace_all = kspace;
  kx = np.transpose(np.int32([(range(1,m1+1))]))
  ky = np.int32([(range(1,n1+1))])

  # create a mask to identify columns in the k-space data that have non-zero values
  kspace = np.copy(kspace_all)
  mask = np.squeeze(np.sum(np.sum(np.abs(kspace),0),1))>0;
  picks = np.where(mask == 1);
  kspace = kspace[:,np.int32(picks[0][0]):n1+1,:]
  kspace_all = kspace_all[:,np.int32(picks[0][0]):n1+1,:]

  # preserve a copy of the original data
  kspace_NEVER_TOUCH = np.copy(kspace_all)

  # WHAT IS HAPPENING HERE
  mask = np.squeeze(np.sum(np.sum(np.abs(kspace),0),1))>0;
  picks = np.where(mask == 1);
  d_picks = np.diff(picks,1)
  indic = np.where(d_picks == 1);

  # I STILL DONT KNOW WHAT IS HAPPENING HERE
  mask_x = np.squeeze(np.sum(np.sum(np.abs(kspace),2),1))>0;
  picks_x = np.where(mask_x == 1);
  x_start = picks_x[0][0]
  x_end = picks_x[0][-1]

  # now process the ACS/Refscan
  ACS = ref_scan_slices[slice_name]
  [ACS_dim_X, ACS_dim_Y, ACS_dim_Z] = np.shape(ACS)
  ACS_re = np.zeros([ACS_dim_X,ACS_dim_Y,ACS_dim_Z*2])
  ACS_re[:,:,0:no_ch] = np.real(ACS)
  ACS_re[:,:,no_ch:no_ch*2] = np.imag(ACS)

  acc_rate = d_picks[0][0]
  no_channels = ACS_dim_Z*2

  # save kernel and network parameters in a .mat file
  name_weight = resultName + ('_weight_%d%d,%d%d,%d%d_%d,%d.mat' % (kernel_x_1,kernel_y_1,kernel_x_2,kernel_y_2,kernel_last_x,kernel_last_y,layer1_channels,layer2_channels))
  name_image = resultName + ('_image_%d%d,%d%d,%d%d_%d,%d.mat' % (kernel_x_1,kernel_y_1,kernel_x_2,kernel_y_2,kernel_last_x,kernel_last_y,layer1_channels,layer2_channels))

  existFlag = os.path.isfile(name_image)

  # initialises 3D convolutional kernels - [1, 2, 3, 4, 5] - 1 and 2 are spatial size kernel_x, kernel_y, 3 is input channels, 4 is output channels for layers 1 and 2,
  # 5 maintains input channels for multicoil data - number of output valies for reconstruction

  # first layer
  w1_all = np.zeros([kernel_x_1, kernel_y_1, no_channels, layer1_channels, no_channels],dtype=np.float32)
  # second layer
  w2_all = np.zeros([kernel_x_2, kernel_y_2, layer1_channels,layer2_channels,no_channels],dtype=np.float32)
  # third and last layer
  w3_all = np.zeros([kernel_last_x, kernel_last_y, layer2_channels,acc_rate - 1, no_channels],dtype=np.float32)

  # bias initialisation - flags control to include biases in the layers
  b1_flag = 0
  b2_flag = 0
  b3_flag = 0

  if (b1_flag == 1):
      b1_all = np.zeros([1,1, layer1_channels,no_channels]);
  else:
      b1 = []

  if (b2_flag == 1):
      b2_all = np.zeros([1,1, layer2_channels,no_channels])
  else:
      b2 = []

  if (b3_flag == 1):
      b3_all = np.zeros([1,1, layer3_channels, no_channels])
  else:
      b3 = []

  #target region in ACS for learning and reconstruction
  target_x_start = np.int32((np.ceil(kernel_x_1/2)-1) + (np.ceil(kernel_x_2/2)-1) + (np.ceil(kernel_last_x/2)-1)) * acc_rate
  target_x_end = np.int32(ACS_dim_X - target_x_start -1)

  time_ALL_start = time.time()

  [ACS_dim_X, ACS_dim_Y, ACS_dim_Z] = np.shape(ACS_re)
  ACS = np.reshape(ACS_re, [1,ACS_dim_X, ACS_dim_Y, ACS_dim_Z])
  ACS = np.float32(ACS)

  target_y_start = np.int32((np.ceil(kernel_y_1/2)-1) + (np.ceil(kernel_y_2/2)-1) + (np.ceil(kernel_last_y/2)-1)) * acc_rate
  target_y_end = ACS_dim_Y  - np.int32((np.floor(kernel_y_1/2) + (np.floor(kernel_y_2/2)) + np.floor(kernel_last_y/2))) * acc_rate -1

  target_dim_X = target_x_end - target_x_start + 1
  target_dim_Y = target_y_end - target_y_start + 1

  target_dim_Z = acc_rate - 1

  ## TESTING ##

  print('go!')
  time_Learn_start = time.time()

  errorSum = 0

  if tf.compat.v1.get_default_session():
    tf.compat.v1.keras.backend.clear_session()
  config = tf.compat.v1.ConfigProto()

  for ind_c in range (ACS_dim_Z):
    # process each channel of the ACS data separately
    # initialises a tensorflow session with specified configurations
      # sess = tf.compat.v1.Session(config=config) - dont need to do this
      # set target lines
      target = np.zeros([1,target_dim_X,target_dim_Y,target_dim_Z]) # target_dim_Z depth corresponds to acc_rate-1
                                                                    # number of interpolation targets between sampled points
                                                                    # outputs array filled with relevant ACS data for each undersampled slice
      print('learning channel #',ind_c+1)
      time_channel_start = time.time()

      # loops through each undersampled slice in the y-direction to define the part of the target that corresponds to that acceleration index
      for ind_acc in range(acc_rate-1):
          # adds/subtracts contributions from the kernel sizes to avoid boundary issues and multiplies by acc_rate to align with the undersampling pattern
          target_y_start = np.int32((np.ceil(kernel_y_1/2)-1) + (np.ceil(kernel_y_2/2)-1) + (np.ceil(kernel_last_y/2)-1)) * acc_rate + ind_acc + 1
          target_y_end = ACS_dim_Y  - np.int32((np.floor(kernel_y_1/2) + (np.floor(kernel_y_2/2)) + np.floor(kernel_last_y/2))) * acc_rate + ind_acc
          # assigns the corresponding region of th ACS for the current slice to the target array
          target[0,:,:,ind_acc] = ACS[0,target_x_start:target_x_end + 1, target_y_start:target_y_end +1,ind_c];
          #print(target.shape)
          #print(ACS.shape)

      # learning - outputs trained convolutional kernels for each layer
      [w1,w2,w3,error] = learning(ACS,target,acc_rate)

      w1_all[:,:,:,:,ind_c] = w1
      w2_all[:,:,:,:,ind_c] = w2
      w3_all[:,:,:,:,ind_c] = w3
      time_channel_end = time.time()
      print('Time cost: ', time_channel_end-time_channel_start,'s')
      print('Norm of Error = ',error)
      errorSum = errorSum + error

  time_Learn_end = time.time()
  print('learning step costs: ', (time_Learn_end - time_Learn_start)/60, 'min')

  sio.savemat(name_weight, {'w1': w1_all, 'w2': w2_all, 'w3': w3_all})

  # initialise copies of full k-space data to store reconstructed k-space data
  kspace_recon_all = np.copy(kspace_all)
  kspace_recon_all_nocenter = np.copy(kspace_all)

  kspace = np.copy(kspace_all)

  # identify undersampled k-space indices based on acceleration rate
  # use setdiff1d to find which indices in picks are not covered by the undersampling pattern
  over_samp = np.setdiff1d(picks,np.int32([range(0, n1,acc_rate)]))
  kspace_und = kspace
  kspace_und[:,over_samp,:]=0
  [dim_kspaceUnd_X,dim_kspaceUnd_Y,dim_kspaceUnd_Z] = np.shape(kspace_und)

  # prepare the k-space for neural network input
  # neural networks require real-valued input so real and imaginary parts are separated
  kspace_und_re = np.zeros([dim_kspaceUnd_X, dim_kspaceUnd_Y, dim_kspaceUnd_Z*2])
  kspace_und_re[:,:,0:dim_kspaceUnd_Z] = np.real(kspace_und)
  kspace_und_re[:,:,dim_kspaceUnd_Z:dim_kspaceUnd_Z*2] = np.imag(kspace_und)

  # reshape tensor to 4D - 1st dimension is batch size (1)
  kspace_und_re = np.float32(kspace_und_re)
  kspace_und_re = np.reshape(kspace_und_re,[1,dim_kspaceUnd_X,dim_kspaceUnd_Y,dim_kspaceUnd_Z*2])
  kspace_recon = kspace_und_re

  ## start reconstruction

  if tf.compat.v1.get_default_session():
    tf.compat.v1.keras.backend.clear_session()
  config = tf.compat.v1.ConfigProto()
  config.gpu_options.per_process_gpu_memory_fraction = 1/3 ;

  for ind_c in range(0, no_channels):
    print('Reconstructing Channel #',ind_c+1)

    #sess = tf.compat.v1.Session(config=config)
    #if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
    #    init = tf.initialize_all_variables()
    #else:
    #  init = tf.compat.v1.global_variables_initializer()
    #sess.run(init)

    # grab w and b
    w1 = np.float32(w1_all[:,:,:,:,ind_c])
    w2 = np.float32(w2_all[:,:,:,:,ind_c])
    w3 = np.float32(w3_all[:,:,:,:,ind_c])

    if (b1_flag == 1):
        b1 = b1_all[:,:,:,ind_c]
    if (b2_flag == 1):
        b2 = b2_all[:,:,:,ind_c]
    if (b3_flag == 1):
        b3 = b3_all[:,:,:,ind_c]

    # perform 3-layer convolutional reconstruction
    res = cnn_3layer(kspace_und_re,w1,b1,w2,b2,w3,b3,acc_rate)

    # assign reconstructed data to kspace_recon
    #target_x_end_kspace = dim_kspaceUnd_X - target_x_start
    #target_x_end_kspace = target_x_end - target_x_start + 1  # Change calculation
    target_x_end_kspace = target_x_start + res[0,:,::acc_rate,ind_acc].shape[0]
    print(target_x_end_kspace)
    # reconstruct k-space data by filling in the missing values (undersampled)
    # with the reconstructed values (res)

    #for ind_acc in range(0,acc_rate-1):

     #   target_y_start = np.int32((np.ceil(kernel_y_1/2)-1) + np.int32((np.ceil(kernel_y_2/2)-1)) + np.int32(np.ceil(kernel_last_y/2)-1)) * acc_rate + ind_acc + 1;
     #   target_y_end_kspace = dim_kspaceUnd_Y - np.int32((np.floor(kernel_y_1/2)) + (np.floor(kernel_y_2/2)) + np.floor(kernel_last_y/2)) * acc_rate + ind_acc;
     #   kspace_recon[0,target_x_start:target_x_end_kspace,target_y_start:target_y_end_kspace+1:acc_rate,ind_c] = res[0,:,::acc_rate,ind_acc]

    for ind_acc in range(0,acc_rate-1):

      target_y_start = np.int32((np.ceil(kernel_y_1/2)-1) + np.int32((np.ceil(kernel_y_2/2)-1)) + np.int32(np.ceil(kernel_last_y/2)-1)) * acc_rate + ind_acc + 1;
      target_y_end_kspace = dim_kspaceUnd_Y - np.int32((np.floor(kernel_y_1/2)) + (np.floor(kernel_y_2/2)) + np.floor(kernel_last_y/2)) * acc_rate + ind_acc;

      # Get the shape of the target slice
      target_shape = kspace_recon[0,target_x_start:target_x_end_kspace,target_y_start:target_y_end_kspace+1:acc_rate,ind_c].shape

      # Reshape or slice the 'res' array to match the target shape
      res_slice = res[0,:,::acc_rate,ind_acc][:target_shape[0], :target_shape[1]]

      kspace_recon[0,target_x_start:target_x_end_kspace,target_y_start:target_y_end_kspace+1:acc_rate,ind_c] = res_slice

  # remove batch dimension
  kspace_recon = np.squeeze(kspace_recon)

  # combine real and imaginary parts to reconstruct complex k-space data
  kspace_recon_complex = (kspace_recon[:,:,0:np.int32(no_channels/2)] + np.multiply(kspace_recon[:,:,np.int32(no_channels/2):no_channels],1j))
  kspace_recon_all_nocenter[:,:,:] = np.copy(kspace_recon_complex);

  for sli in range(0,no_ch):
      kspace_recon_all[:,:,sli] = np.fft.ifft2(kspace_recon_all[:,:,sli])

  rssq = (np.sum(np.abs(kspace_recon_all)**2,2)**(0.5))
  sio.savemat(name_image,{recon_variable_name:kspace_recon_complex})

  time_ALL_end = time.time()
  print('All process costs ',(time_ALL_end-time_ALL_start)/60,'mins')
  print('Error Average in Training is ',errorSum/no_channels)














Processing slice_1...
go!
learning channel # 1
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 884ms/step - loss: 2.9308e-10
Time cost:  0.9602491855621338 s
Norm of Error =  2.9307586912885597e-10
learning channel # 2
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 886ms/step - loss: 9.8179e-11
Time cost:  0.9608359336853027 s
Norm of Error =  9.817913149134938e-11
learning channel # 3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 858ms/step - loss: 1.0829e-10
Time cost:  0.9328868389129639 s
Norm of Error =  1.0828673374652098e-10
learning channel # 4
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - loss: 3.7321e-10
Time cost:  1.0995700359344482 s
Norm of Error =  3.732148201596175e-10
learning channel # 5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - loss: 2.6097e-10
Time cost:  1.3278372287750244 s
Norm of Error =  2.609736038383659e-10
learning channel # 6
[1m1/1[0m [32m━━━━━━━━━━━━━━━