In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Lambda, Add, LeakyReLU,  \
                                    MaxPooling2D, concatenate, UpSampling2D,\
                                    Multiply, ZeroPadding2D, Cropping2D,Activation
  

def fft_layer(image):
    """
    Forward 2-D FFT of a tensor that stores a complex signal as two channels.

    :param image: 4-D tensor whose last axis holds the real part at index 0
        and the imaginary part at index 1; any further channels are ignored.
        Both slices must share a float dtype, as required by ``tf.complex``.
    :return: Tensor with the same first three dimensions and a trailing axis of
        size 2: index 0 is the real part and index 1 the imaginary part of
        ``tf.signal.fft2d`` applied over axes 1 and 2 of the complex input.
    """
    # Pull the real / imaginary planes off the channel axis.
    re_plane = Lambda(lambda t: t[:, :, :, 0])(image)
    im_plane = Lambda(lambda t: t[:, :, :, 1])(image)

    # fft2d transforms the two innermost axes of the 3-D complex tensor.
    spectrum = tf.signal.fft2d(tf.complex(re_plane, im_plane))

    # Stacking on a new last axis == expand_dims on each part + concat.
    return tf.stack([tf.math.real(spectrum), tf.math.imag(spectrum)], axis=-1)


def ifft_layer(kspace_2channel):
    """
    Inverse 2-D FFT of a tensor that stores a complex signal as two channels.

    :param kspace_2channel: 4-D tensor whose last axis holds the real part at
        index 0 and the imaginary part at index 1; any further channels are
        ignored. Both slices must share a float dtype (``tf.complex``).
    :return: Tensor with the same first three dimensions and a trailing axis of
        size 2: index 0 is the real part and index 1 the imaginary part of
        ``tf.signal.ifft2d`` applied over axes 1 and 2 of the complex input.
    """
    # Pull the real / imaginary planes off the channel axis.
    re_plane = Lambda(lambda t: t[:, :, :, 0])(kspace_2channel)
    im_plane = Lambda(lambda t: t[:, :, :, 1])(kspace_2channel)

    # ifft2d transforms the two innermost axes of the 3-D complex tensor.
    signal = tf.signal.ifft2d(tf.complex(re_plane, im_plane))

    # Re-pack as a 2-channel (real, imaginary) tensor on the last axis.
    return tf.stack([tf.math.real(signal), tf.math.imag(signal)], axis=-1)
def spatial_attention(input_feature, kernel_size=7):
    """
    Rescale a feature map with a learned single-channel spatial attention mask.

    The channel-wise mean and max of ``input_feature`` are stacked into a
    2-channel map, passed through a sigmoid convolution producing one channel,
    and that mask is multiplied into every channel of ``input_feature``.

    :param input_feature: Keras tensor with channels on the last axis.
    :param kernel_size: Kernel size of the convolution that produces the mask
        (default 7).
    :return: ``input_feature`` multiplied element-wise by the attention mask
        (the single mask channel broadcasts across all feature channels).
    """
    # Raw TensorFlow ops are wrapped in Lambda layers, as elsewhere in this
    # file, so they are valid on symbolic Keras tensors in both tf.keras 2.x
    # and Keras 3 (which rejects tf functions applied to KerasTensors).
    avg_pool = Lambda(lambda t: tf.reduce_mean(t, axis=-1, keepdims=True))(input_feature)
    max_pool = Lambda(lambda t: tf.reduce_max(t, axis=-1, keepdims=True))(input_feature)
    pooled = concatenate([avg_pool, max_pool], axis=-1)
    attention = Conv2D(1, kernel_size=kernel_size, padding='same', activation='sigmoid')(pooled)
    return Multiply()([input_feature, attention])

def cnn_block(cnn_input, error_map, depth, nf, kshape, channels):
    """
    CNN block with error propagation and spatial attention.

    When ``error_map`` is given it is concatenated with ``cnn_input`` along the
    last (channel) axis before the first convolution, and spatial attention is
    applied right after that first convolution. Without an error map, spatial
    attention is applied after the fourth convolution instead (only reached
    when ``depth >= 4``).

    :param cnn_input: Input tensor to the block (image or k-space data).
    :param error_map: Error map from the previous block (None for the first block).
    :param depth: Number of (convolution, LeakyReLU) pairs before the final 1x1 convolution.
    :param nf: Number of filters in the convolution layers.
    :param kshape: Shape of the convolutional kernel.
    :param channels: Output channels of the final 1x1 convolution; must equal the
        channel count of ``cnn_input`` for the residual addition.
    :return: Tuple ``(rec, new_error_map)`` where ``rec = cnn_input + correction``
        and ``new_error_map = cnn_input - rec`` (the negated correction).
    """
    # Guide the block with the previous error map by stacking it onto the
    # input along the channel axis, so both the data and the error map reach
    # the first convolution.
    if error_map is not None:
        x = concatenate([cnn_input, error_map], axis=-1)
    else:
        x = cnn_input

    # First convolution layer
    x = Conv2D(nf, kshape, padding='same')(x)
    x = LeakyReLU(alpha=0.1)(x)

    # With an error map, attend right after the first convolution so features
    # derived from the error map can re-weight spatial locations early.
    if error_map is not None:
        x = spatial_attention(x)

    # Remaining convolution layers (the first one is already built above).
    for ii in range(1, depth):
        x = Conv2D(nf, kshape, padding='same')(x)
        x = LeakyReLU(alpha=0.1)(x)
        # Without an error map, attention is applied after the 4th convolution.
        if error_map is None and ii == 3:
            x = spatial_attention(x)

    # 1x1 projection back to `channels`, then residual connection to the input.
    final_conv = Conv2D(channels, (1, 1), activation='linear')(x)
    rec1 = Add()([final_conv, cnn_input])

    # Error map passed to the next block: input minus output, which equals the
    # negated correction `-final_conv`. No ground truth is involved here.
    new_error_map = cnn_input - rec1

    return rec1, new_error_map


def unet_block(unet_input, kshape=(3, 3), channels=2):
    """
    Residual U-Net with three 2x2 poolings (48/64/128/256 filters).

    Every resolution level uses three same-padded ReLU convolutions. Because the
    input is pooled three times and later upsampled and concatenated with the
    matching encoder features, the spatial size of ``unet_input`` must be
    divisible by 8.

    :param unet_input: Input layer
    :param kshape: Kernel size
    :param channels: Channels of the final 1x1 convolution; must match the
        channel count of ``unet_input`` for the residual addition.
    :return: 2-channel, complex reconstruction (``unet_input`` + learned residual)
    """

    def conv_stack(tensor, filters):
        # Three same-padded ReLU convolutions at one resolution level.
        for _ in range(3):
            tensor = Conv2D(filters, kshape, activation='relu', padding='same')(tensor)
        return tensor

    def up_and_merge(deep, skip):
        # 2x upsampling followed by a channel-wise skip connection.
        return concatenate([UpSampling2D(size=(2, 2))(deep), skip], axis=-1)

    # Encoder
    enc1 = conv_stack(unet_input, 48)
    enc2 = conv_stack(MaxPooling2D(pool_size=(2, 2))(enc1), 64)
    enc3 = conv_stack(MaxPooling2D(pool_size=(2, 2))(enc2), 128)
    bottleneck = conv_stack(MaxPooling2D(pool_size=(2, 2))(enc3), 256)

    # Decoder
    dec3 = conv_stack(up_and_merge(bottleneck, enc3), 128)
    dec2 = conv_stack(up_and_merge(dec3, enc2), 64)
    dec1 = conv_stack(up_and_merge(dec2, enc1), 48)

    residual = Conv2D(channels, (1, 1), activation='linear')(dec1)
    return Add()([residual, unet_input])

def unet_block2(unet_input, kshape=(3, 3), channels=2):
    """
    Residual U-Net with two 2x2 poolings (48/96/192 filters).

    Every resolution level uses three same-padded ReLU convolutions. The input
    is pooled twice and later upsampled and concatenated with the matching
    encoder features, so the spatial size of ``unet_input`` must be divisible by 4.

    :param unet_input: Input layer
    :param kshape: Kernel size
    :param channels: Channels of the final 1x1 convolution; must match the
        channel count of ``unet_input`` for the residual addition.
    :return: 2-channel, complex reconstruction (``unet_input`` + learned residual)
    """

    def conv_stack(tensor, filters):
        # Three same-padded ReLU convolutions at one resolution level.
        for _ in range(3):
            tensor = Conv2D(filters, kshape, activation='relu', padding='same')(tensor)
        return tensor

    def up_and_merge(deep, skip):
        # 2x upsampling followed by a channel-wise skip connection.
        return concatenate([UpSampling2D(size=(2, 2))(deep), skip], axis=-1)

    # Encoder
    enc1 = conv_stack(unet_input, 48)
    enc2 = conv_stack(MaxPooling2D(pool_size=(2, 2))(enc1), 96)
    bottleneck = conv_stack(MaxPooling2D(pool_size=(2, 2))(enc2), 192)

    # Decoder
    dec2 = conv_stack(up_and_merge(bottleneck, enc2), 96)
    dec1 = conv_stack(up_and_merge(dec2, enc1), 48)

    residual = Conv2D(channels, (1, 1), activation='linear')(dec1)
    return Add()([residual, unet_input])

def DC_block(rec, mask, sampled_kspace, channels, kspace=False):
    """
    Data-consistency step: combine a reconstruction with the sampled k-space.

    Computes ``rec_kspace * (1 - mask) + sampled_kspace``. Presumably ``mask`` is
    binary (1 = sampled) and ``sampled_kspace`` is zero wherever ``mask`` is 0,
    so sampled locations take the measured values and the remaining locations
    keep the network's prediction -- NOTE(review): confirm against the data pipeline.

    :param rec: Reconstructed data, can be k-space or image domain
    :param mask: undersampling mask, same shape as ``rec_kspace``
    :param sampled_kspace: measured (undersampled) k-space added back in
    :param channels: unused; kept for interface compatibility
    :param kspace: Boolean, if true, the input is k-space, if false it is image domain
    :return: k-space after data consistency
    """
    rec_kspace = rec if kspace else Lambda(fft_layer)(rec)

    unsampled = 1 - mask
    predicted_part = Multiply()([rec_kspace, unsampled])
    return Add()([predicted_part, sampled_kspace])

def deep_cascade_flat_unrolled(depth_str='iiiiii', H=256, W=256, depth=5, kshape=(3, 3), nf=48, channels=2):
    """
    Deep Cascade model that incorporates error propagation and spatial attention.

    Each character of ``depth_str`` adds one stage: an IFFT (only for 'i'), a
    ``cnn_block`` that also receives the previous stage's error map, and a
    ``DC_block`` that re-inserts the k-space input. The model output is the
    IFFT of the last stage's k-space.

    :param depth_str: One character per stage: 'i' runs the CNN in the image
        domain, 'k' runs it directly on k-space.
    :param H: Image height
    :param W: Image width
    :param depth: Number of (conv, LeakyReLU) pairs in each ``cnn_block``
    :param kshape: Kernel size
    :param nf: Number of filters in each convolutional layer
    :param channels: Channels of the k-space input and the mask; must be 2 when
        ``depth_str`` contains 'i', because ``ifft_layer`` emits 2 channels.
    :return: Deep Cascade model taking ``[kspace_input, mask]``; ``kspace_input``
        is also used as the sampled k-space in every DC step.
    :raises ValueError: If ``depth_str`` contains characters other than 'i' and
        'k', or an 'i' stage is requested with ``channels != 2``.
    """
    unknown = sorted(set(depth_str) - {'i', 'k'})
    if unknown:
        raise ValueError(f"depth_str may only contain 'i' and 'k'; got {unknown}")
    if 'i' in depth_str and channels != 2:
        raise ValueError(
            "image-domain ('i') stages require channels=2 because fft_layer/"
            f"ifft_layer produce a single (real, imag) pair; got channels={channels}")

    inputs = Input(shape=(H, W, channels))  # k-space input (also the DC data)
    mask = Input(shape=(H, W, channels))    # Undersampling mask

    x = inputs          # every stage starts from k-space
    error_map = None    # first stage has no error map

    for domain in depth_str:
        if domain == 'i':
            # Switch to the image domain for this stage's CNN.
            x = Lambda(ifft_layer)(x)

        # NOTE(review): the error map is in the domain of the previous stage;
        # mixing 'i' and 'k' in depth_str feeds an error map from the other
        # domain into this block.
        cnn_output, error_map = cnn_block(x, error_map, depth, nf, kshape, channels)

        # Data consistency always returns k-space for the next stage.
        x = DC_block(cnn_output, mask, inputs, channels, kspace=(domain == 'k'))

    # Final output (transform back to image domain)
    out = Lambda(ifft_layer)(x)
    return Model(inputs=[inputs, mask], outputs=out)


def deep_cascade_unet(depth_str='ki', H=218, W=170, Hpad = 3, Wpad = 3, kshape=(3, 3),channels = 22):
    """
    Deep Cascade of zero-padded U-Nets interleaved with data-consistency steps.

    Each character of ``depth_str`` adds one stage: an IFFT (only for 'i'),
    zero-padding by ``(Hpad, Wpad)``, ``unet_block``, cropping back to
    ``(H, W)``, and a ``DC_block`` that re-inserts the k-space input. The model
    output is the IFFT of the last stage's k-space.

    NOTE(review): the defaults combine an 'i' stage with ``channels=22``, which
    ``fft_layer``/``ifft_layer`` cannot represent (they emit 2 channels), so the
    defaults raise ``ValueError``; callers presumably pass ``channels=2`` or a
    k-space-only ``depth_str``.

    :param depth_str: One character per stage: 'i' = image domain, 'k' = k-space.
    :param H: Input height
    :param W: Input width
    :param Hpad: Zero-padding added above and below before each U-Net
    :param Wpad: Zero-padding added left and right before each U-Net
    :param kshape: Kernel size
    :param channels: Channels of the k-space input and the mask; must be 2 when
        ``depth_str`` contains 'i'.
    :return: Keras ``Model`` taking ``[kspace_input, mask]``.
    :raises ValueError: If ``depth_str`` contains characters other than 'i' and
        'k', an 'i' stage is requested with ``channels != 2``, or a padded
        spatial size is not divisible by 8.
    """
    unknown = sorted(set(depth_str) - {'i', 'k'})
    if unknown:
        raise ValueError(f"depth_str may only contain 'i' and 'k'; got {unknown}")
    if 'i' in depth_str and channels != 2:
        raise ValueError(
            "image-domain ('i') stages require channels=2 because fft_layer/"
            f"ifft_layer produce a single (real, imag) pair; got channels={channels}")
    # unet_block pools three times by 2, so each padded size must be divisible
    # by 8 for the upsampled features to line up with their skip connections.
    for axis_name, size, pad in (('H', H, Hpad), ('W', W, Wpad)):
        if size is not None and (size + 2 * pad) % 8:
            raise ValueError(
                f"{axis_name} + 2*{axis_name}pad = {size + 2 * pad} must be divisible by 8 for unet_block")

    inputs = Input(shape=(H, W, channels))
    mask = Input(shape=(H, W, channels))

    x = inputs  # every stage starts from k-space
    for domain in depth_str:
        if domain == 'i':
            # Switch to the image domain for this stage's U-Net.
            x = Lambda(ifft_layer)(x)

        # Pad so the U-Net's pooling/upsampling round-trips exactly, then crop back.
        x = ZeroPadding2D(padding=(Hpad, Wpad))(x)
        x = unet_block(x, kshape, channels)
        x = Cropping2D(cropping=(Hpad, Wpad))(x)

        # Data consistency always returns k-space for the next stage.
        x = DC_block(x, mask, inputs, channels, kspace=(domain == 'k'))

    out = Lambda(ifft_layer)(x)
    return Model(inputs=[inputs, mask], outputs=out)