In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Lambda, Add, LeakyReLU,  \
                                    MaxPooling2D, concatenate, UpSampling2D, Multiply, ZeroPadding2D, Cropping2D
             

In [2]:
%run ./Modules/layer.ipynb
%run ./Modules/activation.ipynb


Exception: File `'./Modules/layer.ipynb'` not found.

In [None]:

  

def fft_layer(image):
    """
    Forward 2-D FFT of a 2-channel (real, imaginary) tensor.

    Channel 0 of the last axis is taken as the real part and channel 1 as the
    imaginary part. The `[:, :, :, k]` indexing fixes the input to rank 4, and
    the slices are rank 3. tf.signal.fft2d acts on the two innermost axes of its
    input, so the transform runs over axes 1 and 2 of `image` (presumably H, W).

    :param image: rank-4 tensor with real/imag in channels 0/1 of the last axis
    :return: k-space tensor; real part in channel 0, imaginary part in channel 1
        of a new last axis
    """
    re_part = Lambda(lambda t: t[:, :, :, 0])(image)
    im_part = Lambda(lambda t: t[:, :, :, 1])(image)

    spectrum = tf.signal.fft2d(tf.complex(re_part, im_part))

    # Stacking on a new last axis is the same as expand_dims(-1) + concat(-1).
    return tf.stack([tf.math.real(spectrum), tf.math.imag(spectrum)], axis=-1)

def fft_layer2(real,imag):
    """
    Forward 2-D FFT on separate real / imaginary tensors.

    tf.signal.fft2d transforms the two innermost axes of its input. For a
    channels-last (B, H, W, C) tensor those are (W, C), not (H, W). Move the
    channel axis in front of H, W first if a spatial transform is intended.

    :param real: real part
    :param imag: imaginary part, same shape as `real`
    :return: (real, imag) of the transform, same shape as the inputs
    """
    spectrum = tf.signal.fft2d(tf.complex(real, imag))
    return tf.math.real(spectrum), tf.math.imag(spectrum)


def ifft_layer(kspace_2channel):
    """
    Inverse 2-D FFT of a 2-channel (real, imaginary) k-space tensor.

    Channel 0 of the last axis is taken as the real part and channel 1 as the
    imaginary part. The `[:, :, :, k]` indexing fixes the input to rank 4, so
    tf.signal.ifft2d (innermost two axes of the rank-3 slices) runs over axes
    1 and 2 of the input (presumably H, W).

    :param kspace_2channel: rank-4 tensor with real/imag in channels 0/1
    :return: image-domain tensor; real part in channel 0, imaginary part in
        channel 1 of a new last axis
    """
    k_re = Lambda(lambda t: t[:, :, :, 0])(kspace_2channel)
    k_im = Lambda(lambda t: t[:, :, :, 1])(kspace_2channel)

    img = tf.signal.ifft2d(tf.complex(k_re, k_im))

    # Stacking on a new last axis is the same as expand_dims(-1) + concat(-1).
    return tf.stack([tf.math.real(img), tf.math.imag(img)], axis=-1)
def ifft_layer2(real,imag):
    """
    Inverse 2-D FFT on separate real / imaginary tensors.

    tf.signal.ifft2d transforms the two innermost axes of its input. For a
    channels-last (B, H, W, C) tensor those are (W, C), not (H, W). Move the
    channel axis in front of H, W first if a spatial transform is intended.

    :param real: real part
    :param imag: imaginary part, same shape as `real`
    :return: (real, imag) of the inverse transform, same shape as the inputs
    """
    img = tf.signal.ifft2d(tf.complex(real, imag))
    return tf.math.real(img), tf.math.imag(img)


In [None]:
def DC_block(rec,mask,sampled_kspace,channels,kspace = False):
    """
    Data-consistency step: keep the reconstruction only where `mask` is 0 and
    add the measured k-space.

    Computes rec_kspace * (1 - mask) + sampled_kspace. With a binary mask
    (presumably 1 at sampled positions), sampled positions come from
    `sampled_kspace`. This assumes `sampled_kspace` is zero where nothing was
    sampled (TODO confirm with the data pipeline).

    :param rec: reconstruction, in k-space or in the image domain
    :param mask: undersampling mask, same shape as the k-space tensor
    :param sampled_kspace: measured (undersampled) k-space
    :param channels: accepted for interface compatibility; not used here
    :param kspace: True if `rec` is already k-space; False to FFT it first
    :return: k-space after data consistency
    """
    rec_kspace = rec if kspace else Lambda(fft_layer)(rec)
    unsampled = 1 - mask
    kept = Multiply()([rec_kspace, unsampled])
    return Add()([kept, sampled_kspace])

In [None]:
from tensorflow.keras.layers import Concatenate, Conv2D, Activation, Multiply
import tensorflow.keras.backend as K

def spatial_attention_complex(real, imag, eps=1e-7):
    """
    Apply a spatial attention map to a complex feature map.

    The map uses the CBAM-style recipe: channel-wise mean and max of the
    complex magnitude, a 7x7 convolution, then a sigmoid. The same
    [B, H, W, 1] map multiplies both the real and the imaginary part, so the
    phase of each feature is unchanged.

    :param real: real part of the feature map, shape [B, H, W, C]
    :param imag: imaginary part of the feature map, shape [B, H, W, C]
    :param eps: small positive constant added under the square root. The
        gradient of sqrt is infinite at 0, which would give NaN gradients
        (0 * inf) wherever a feature is exactly zero.
    :return: (real_out, imag_out) with the attention map applied
    """
    # Complex magnitude |z| = sqrt(re^2 + im^2). tf ops are used instead of the
    # global `K`, which later cells in this notebook rebind to a different
    # backend module.
    magnitude = tf.sqrt(tf.square(real) + tf.square(imag) + eps)  # [B, H, W, C]

    # Channel-wise descriptors.
    avg_pool = tf.reduce_mean(magnitude, axis=-1, keepdims=True)  # [B, H, W, 1]
    max_pool = tf.reduce_max(magnitude, axis=-1, keepdims=True)   # [B, H, W, 1]
    descriptor = Concatenate(axis=-1)([avg_pool, max_pool])       # [B, H, W, 2]

    # 7x7 convolution -> sigmoid gate in (0, 1).
    attention_map = Conv2D(filters=1, kernel_size=7, padding='same',
                           activation='sigmoid')(descriptor)      # [B, H, W, 1]

    # Broadcast the single-channel map over all C channels of both parts.
    real_out = Multiply()([real, attention_map])
    imag_out = Multiply()([imag, attention_map])
    return real_out, imag_out


In [None]:
from keras.layers import Conv2D, LeakyReLU, Add
from keras import backend as K
def cnn_block(cnn_input, nf, kshape, channels):
    """
    Image-domain complex CNN block with a residual connection.

    Steps:
    - five complex 'same'-padded convolutions, each followed by CLeaky_ReLU;
    - complex spatial attention;
    - a final 1x1 complex convolution, whose output is added back to the block
      input via add_with.

    :param cnn_input: 2-channel tensor; channel 0 is the real part, channel 1
        the imaginary part
    :param nf: number of filters of the five hidden complex convolutions
    :param kshape: kernel size of the hidden convolutions
    :param channels: accepted for interface compatibility; not used in the body
    :return: residual real and imaginary parts concatenated on the last axis
    """
    # Slicing 0:1 / 1:2 keeps a trailing axis of size 1; this is equivalent to
    # indexing a channel and then expand_dims(-1).
    re_in = cnn_input[..., 0:1]
    im_in = cnn_input[..., 1:2]

    re_feat, im_feat = re_in, im_in
    for _ in range(5):
        re_feat, im_feat = complex_Conv2D(nf, kshape, padding="same")(re_feat, im_feat)
        re_feat, im_feat = CLeaky_ReLU(re_feat, im_feat)

    re_feat, im_feat = spatial_attention_complex(re_feat, im_feat)

    # NOTE(review): this 1x1 conv asks for 2 filters while the residual inputs
    # carry 1 channel each. Confirm that complex_Conv2D / add_with yield the
    # 2-channel output that DC_block multiplies against the mask.
    re_out, im_out = complex_Conv2D(2, (1, 1))(re_feat, im_feat)

    re_res, im_res = add_with(re_out, im_out, re_in, im_in)
    return concatenate([re_res, im_res], axis=-1)


In [None]:
from keras.layers import Conv2D, LeakyReLU, Add
from keras import backend as K
def _transform_hw_channels_last(transform, real, imag):
    """
    Run a 2-D FFT helper (fft_layer2 / ifft_layer2) over the H, W axes of
    channels-last (B, H, W, C) real/imag tensors.

    tf.signal.fft2d / ifft2d act on the two innermost axes, which for a
    channels-last tensor are (W, C). Moving C in front of H and W makes the
    transform run over the spatial axes; the result is then moved back.
    """
    to_channels_first = [0, 3, 1, 2]
    to_channels_last = [0, 2, 3, 1]
    real_cf, imag_cf = transform(tf.transpose(real, to_channels_first),
                                 tf.transpose(imag, to_channels_first))
    return (tf.transpose(real_cf, to_channels_last),
            tf.transpose(imag_cf, to_channels_last))


def cnn_block_kspace(cnn_input, nf, kshape, channels):
    """
    K-space complex CNN block with image-domain spatial attention and a
    residual connection.

    Steps:
    - five complex 'same'-padded convolutions, each followed by CLeaky_ReLU,
      run on k-space;
    - the features are inverse-FFT'd over (H, W) and passed through
      spatial_attention_complex;
    - they are FFT'd back to k-space and reduced by a 1x1 complex convolution;
    - the result is added to the block input via add_with.

    :param cnn_input: 2-channel k-space tensor; channel 0 is the real part,
        channel 1 the imaginary part
    :param nf: number of filters of the five hidden complex convolutions
    :param kshape: kernel size of the hidden convolutions
    :param channels: accepted for interface compatibility; not used in the body
    :return: residual real and imaginary parts concatenated on the last axis
    """
    # Slicing 0:1 / 1:2 keeps a trailing axis of size 1 (same as index +
    # expand_dims(-1)), giving e.g. (None, 256, 256, 1).
    re_in = cnn_input[..., 0:1]
    im_in = cnn_input[..., 1:2]

    re_feat, im_feat = re_in, im_in
    for _ in range(5):
        re_feat, im_feat = complex_Conv2D(nf, kshape, padding="same")(re_feat, im_feat)
        re_feat, im_feat = CLeaky_ReLU(re_feat, im_feat)

    # The features are channels-last (B, H, W, nf). Calling ifft_layer2 /
    # fft_layer2 on them directly would transform over (W, nf), so route
    # through the helper that transforms over (H, W).
    re_feat, im_feat = _transform_hw_channels_last(ifft_layer2, re_feat, im_feat)
    re_feat, im_feat = spatial_attention_complex(re_feat, im_feat)
    re_feat, im_feat = _transform_hw_channels_last(fft_layer2, re_feat, im_feat)

    # NOTE(review): this 1x1 conv asks for 2 filters while the residual inputs
    # carry 1 channel each. Confirm that complex_Conv2D / add_with yield the
    # 2-channel output that DC_block multiplies against the (H, W, 2) mask.
    re_out, im_out = complex_Conv2D(2, (1, 1))(re_feat, im_feat)

    re_res, im_res = add_with(re_out, im_out, re_in, im_in)
    return concatenate([re_res, im_res], axis=-1)


In [None]:
def deep_cascade_flat_unrolled(depth_str='ikikii', H=256, W=256, kshape=(3, 3), nf=48, channels=2):
    """
    Build an unrolled cascade of image-domain and k-space CNN stages.

    Every stage is a CNN followed by a data-consistency (DC) step. DC_block
    always returns k-space, so every stage starts from a k-space tensor.

    :param depth_str: one character per stage:
        'i' -> inverse FFT to the image domain, then cnn_block;
        'k' -> cnn_block_kspace applied directly to the current k-space.
    :param H: image height
    :param W: image width
    :param kshape: kernel size for the CNN blocks
    :param nf: number of filters in the CNN blocks
    :param channels: number of input/output channels (real + imaginary = 2)
    :return: Keras model with inputs [undersampled k-space, mask] and a
        2-channel image-domain output
    :raises ValueError: if depth_str contains a character other than 'i' or 'k'
    """
    from keras.layers import Input, Lambda
    from keras.models import Model

    # Reject unknown stage codes. Previously they skipped the CNN but still
    # ran DC_block with a stale domain flag, FFT-ing data already in k-space.
    invalid = sorted(set(depth_str) - {'i', 'k'})
    if invalid:
        raise ValueError(
            f"depth_str may only contain 'i' and 'k'; got invalid stage code(s) {invalid}"
        )

    inputs = Input(shape=(H, W, channels))  # undersampled k-space
    mask = Input(shape=(H, W, channels))    # undersampling mask (presumably 1 = sampled)

    x = inputs
    for stage in depth_str:
        if stage == 'i':
            x = Lambda(ifft_layer)(x)
            x = cnn_block(x, nf, kshape, channels)  # CNN in image domain
            in_kspace = False
        else:  # 'k'
            x = cnn_block_kspace(x, nf, kshape, channels)  # CNN in k-space
            in_kspace = True

        # DC after every CNN; it FFTs image-domain output first when needed.
        x = DC_block(x, mask, inputs, channels, kspace=in_kspace)

    out = Lambda(ifft_layer)(x)  # final IFFT back to the image domain
    return Model(inputs=[inputs, mask], outputs=out)
