# Packages

In [None]:
# =============================
# Imports & GPU configuration
# =============================
import os

# Set before importing tensorflow so the setting is in place when TensorFlow looks for devices.
# The empty string selects no CUDA device, i.e. everything below runs on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "" # Use CPU only; set to "0" or "1" to use GPU(s)

import numpy as np
import matplotlib.pyplot as plt
import h5py
import tensorflow as tf

from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.layers import (
    Conv2D, Conv2DTranspose, MaxPooling2D, Dropout, BatchNormalization,
    ReLU, LeakyReLU, concatenate
)
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    Callback, ModelCheckpoint, LearningRateScheduler, CSVLogger
)

import einops


# Sanity check: report the TensorFlow version and the physical devices it can see.
print("TF version:", tf.__version__)
print("Visible devices:", tf.config.list_physical_devices())


In [None]:
# Second device listing, this time through TensorFlow's internal client module.
# NOTE(review): `tensorflow.python.*` is not a public import path; the
# `tf.config.list_physical_devices()` call in the first cell presumably covers
# the same need - confirm before relying on this cell.
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())


# Slice Stacker

In [None]:
def stack_data_semi3d(data_set, slices=13, window=3):
    """Turn a stack of 2-D slices into overlapping multi-slice ("semi-3D") samples.

    The first axis of ``data_set`` is interpreted as ``subject * slices``:
    consecutive groups of ``slices`` entries belong to the same subject. Within
    each subject a sliding window of ``window`` neighbouring slices is taken, so
    no window ever mixes slices from two different subjects.

    Args:
        data_set: Array of shape ``(subjects * slices, ..., channel)``; the
            notebook passes ``(subjects * slices, x, y, channel)``.
        slices: Number of slices stored per subject (default 13).
        window: Number of neighbouring slices per sample (default 3).

    Returns:
        Array of shape
        ``(subjects * (slices - window + 1), ..., window, channel)``: the window
        axis is inserted directly in front of the channel axis. For 4-D input
        and the defaults this is ``(subjects * 11, x, y, 3, channel)``.

    Raises:
        ValueError: If the input has fewer than two axes, if ``window`` is not
            in ``1..slices``, or if the first axis is not a multiple of
            ``slices``.
    """
    data_set = np.asarray(data_set)
    if data_set.ndim < 2:
        raise ValueError("data_set needs at least a slice axis and a channel axis")
    if not 1 <= window <= slices:
        raise ValueError(f"window must be between 1 and slices={slices}, got {window}")

    total_slices = data_set.shape[0]
    if total_slices % slices != 0:
        raise ValueError(
            f"first axis ({total_slices}) is not a multiple of slices per subject ({slices})"
        )

    # (subject * slices, ...) -> (subject, slices, ...)
    per_subject = data_set.reshape(total_slices // slices, slices, *data_set.shape[1:])

    # Slide along the slice axis only; the window becomes a new trailing axis:
    # (subject, sample, ..., channel, window)
    windows = np.lib.stride_tricks.sliding_window_view(per_subject, window, axis=1)

    # Put the window axis in front of the channel axis:
    # (subject, sample, ..., window, channel)
    windows = np.moveaxis(windows, -1, -2)

    # Merge subject and sample into one leading sample axis (this copies the
    # strided view into a regular array).
    return windows.reshape(-1, *windows.shape[2:])

# Data Import

In [None]:
# =============================
# Training data
# =============================

# Both arrays are read fully into memory while the file is open.
with h5py.File("TrainData.mat", "r") as f:
    y_train = f["lvSaveDataInput"][:, :, :, :]
    x_train = f["lvLovalizerSave"][:, :, :, :]

# Move channel axis (1 → last)
x_train = np.moveaxis(x_train, 1, -1)
y_train = np.moveaxis(y_train, 1, -1)

print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)

# =============================

# Inputs become windows of 3 neighbouring slices; the target is the map of the
# middle slice (index 1 on the window axis) of each window.
x_train = stack_data_semi3d(x_train)
y_train = stack_data_semi3d(y_train)[..., 1, :]

# Important for model building later on. These per-sample shapes must be taken
# AFTER stacking: the model unstacks its input along axis -2, which is the
# slice-window axis that only exists once stack_data_semi3d has run.
input_shape = x_train.shape[1:]
output_shape = y_train.shape[1:]

print("Input shape:", input_shape)
print("Output shape:", output_shape)

In [None]:
# =============================
# Validation data
# =============================

# Pull both arrays out of the file before it is closed again.
with h5py.File("ValData.mat", "r") as val_file:
    x_val = val_file["lvLovalizerSave"][:, :, :, :]
    y_val = val_file["lvSaveDataInput"][:, :, :, :]

# Channel axis: position 1 -> last position
x_val, y_val = (np.moveaxis(array, 1, -1) for array in (x_val, y_val))

print("x_val shape:", x_val.shape)
print("y_val shape:", y_val.shape)

# =============================

# Inputs: sliding slice windows. Targets: only the middle slice of each window.
x_val = stack_data_semi3d(x_val)
y_val = stack_data_semi3d(y_val)
y_val = y_val[..., 1, :]

In [None]:
# =============================
# Testing data
# =============================

# Pull both arrays out of the file before it is closed again.
with h5py.File("TestData.mat", "r") as test_file:
    x_test = test_file["lvLovalizerSave"][:, :, :, :]
    y_test = test_file["lvSaveDataInput"][:, :, :, :]

# Channel axis: position 1 -> last position
x_test, y_test = (np.moveaxis(array, 1, -1) for array in (x_test, y_test))

print("x_test shape:", x_test.shape)
print("y_test shape:", y_test.shape)

# =============================

# Inputs: sliding slice windows. Targets: only the middle slice of each window.
x_test = stack_data_semi3d(x_test)
y_test = stack_data_semi3d(y_test)
y_test = y_test[..., 1, :]

# Loss Function

In [None]:
def perpendicular_loss(target, prediction, eps=1e-8, l1factor=1.0, use_mask=False):
    """
    Perpendicular loss for complex-valued tensors:
    total = P + l1factor * L1

    P is the component of ``target`` perpendicular to ``prediction``,
    ``|Re(t) * Im(p) - Im(t) * Re(p)| / (|p| + eps)``. Where the two vectors
    are 90 degrees or more apart it is mirrored to ``2 * |t| - P``.

    Args:
        target, prediction: complex tensors of same shape
        eps: small constant guarding the divisions by |prediction| and by the
            number of masked elements
        l1factor: weight of the L1 term ``|prediction - target|``
        use_mask: if True, average only over elements with ``|target| > 1e-3``

    Returns:
        scalar loss (mean over all elements, masked if use_mask=True)
    """
    t_real = tf.math.real(target)
    t_imag = tf.math.imag(target)
    p_real = tf.math.real(prediction)
    p_imag = tf.math.imag(prediction)

    # cross term = |target.real * pred.imag - target.imag * pred.real|  (real, >= 0)
    cross = tf.abs(t_real * p_imag - t_imag * p_real)

    abs_pred = tf.abs(prediction)
    abs_target = tf.abs(target)

    # perpendicular component (P_raw)
    ploss_raw = cross / (abs_pred + eps)

    # angle < 90 degrees  <=>  Re(target / prediction) > 0
    #                     <=>  Re(target * conj(prediction)) > 0
    # The second form needs no complex division, so a prediction of exactly 0
    # yields a plain False here instead of a comparison against inf/NaN.
    dot = t_real * p_real + t_imag * p_imag
    angle_smaller_90 = dot > 0  # boolean

    # symmetric perpendicular loss P
    ploss = tf.where(angle_smaller_90,
                     ploss_raw,
                     2.0 * abs_target - ploss_raw)

    # L1 component
    l1loss = tf.abs(prediction - target)

    # P + L1
    total = ploss + l1factor * l1loss

    if not use_mask:
        return tf.reduce_mean(total)

    # Average over elements whose target magnitude exceeds the threshold only.
    mask = tf.cast(abs_target > 1e-3, total.dtype)
    masked_sum = tf.reduce_sum(total * mask)
    mask_count = tf.reduce_sum(mask)
    return masked_sum / (mask_count + eps)


# Custom Learning-Rate Metric, Exponential Decay Scheduler, and Training Callback

In [None]:
def get_lr_metric(optimizer):
    """Build a metric-style function that reports ``optimizer.learning_rate``.

    The returned function is named ``lr`` and takes the usual
    ``(y_true, y_pred)`` pair, both of which it ignores.
    """

    def lr(y_true, y_pred):
        # Read on every call, so a rate changed later on is reflected here.
        current_rate = optimizer.learning_rate  # formerly `optimizer.lr`
        return current_rate

    return lr

def lr_scheduler(epoch, lr):
    """Exponential learning-rate decay, applied once per ``decay_step`` epochs.

    Epoch 0 keeps the incoming rate; every later epoch that is a multiple of
    ``decay_step`` returns ``lr * decay_rate``.
    """
    decay_rate = 0.99944
    decay_step = 1

    # Nothing to do on the very first epoch or between decay steps.
    if not epoch or epoch % decay_step != 0:
        return lr
    return lr * decay_rate

class CustomCallback(Callback):
    """Print training and validation loss in scientific notation after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Treat a missing logs dict like an empty one.
        logs = {} if logs is None else logs

        # Without both losses there is nothing to format; report what arrived.
        if not ("loss" in logs and "val_loss" in logs):
            print(f"Epoch {epoch} ended but logs are missing expected keys: {logs}")
            return

        print("loss: {:.4e} - val_loss: {:.4e}".format(logs["loss"], logs["val_loss"]))
            

# Double Convolutional Block 2D

In [None]:
def ConvBlock(n_filters, kernel_size=3, batchnorm=True):
    """
    Creates a reusable convolutional block as a Sequential model.

    The block is two identical stages of Conv2D ("same" padding), optional
    BatchNormalization, then ReLU. Because a layer container is returned
    instead of a tensor, calling the same block on several inputs reuses the
    same weights.

    Args:
        n_filters (int): Number of filters of both convolutions
        kernel_size (int): Edge length of the square kernel (default: 3)
        batchnorm (bool): Whether to add BatchNormalization after each
            convolution (default: True)

    Returns:
        A ``Sequential`` model holding the layers described above.

    Note:
        Reads the module-level ``init`` kernel initializer when called, so
        ``init`` must be defined before the first call.
    """
    layers = []
    for _ in range(2):
        layers.append(
            Conv2D(
                filters=n_filters,
                kernel_size=(kernel_size, kernel_size),
                kernel_initializer=init,
                padding="same",
            )
        )
        if batchnorm:
            layers.append(BatchNormalization())
        layers.append(ReLU())

    return Sequential(layers)

In [None]:
def C2D_BLock(input_tensor, n_filters, kernel_size=3, batchnorm=True):
    """
    Apply two Conv2D stages directly to a tensor (functional style).

    Each stage is Conv2D with "same" padding, optional BatchNormalization,
    then ReLU. New layers are created on every call, so no weights are shared
    between calls. The kernel initializer is the module-level ``init``.

    Args:
        input_tensor: Input tensor from previous layer
        n_filters (int): Number of filters for convolution
        kernel_size (int): Size of convolution kernel (default: 3)
        batchnorm (bool): Whether to apply batch normalization (default: True)

    Returns:
        Output tensor after two convolution operations
    """
    kernel = (kernel_size, kernel_size)
    tensor = input_tensor

    for _stage in range(2):
        conv = Conv2D(
            filters=n_filters,
            kernel_size=kernel,
            kernel_initializer=init,
            padding="same",
        )
        tensor = conv(tensor)
        if batchnorm:
            tensor = BatchNormalization()(tensor)
        tensor = ReLU()(tensor)

    return tensor

# Encoder

In [None]:
def encoder_semi3D(x, convs, dropout):
    """Run one input through the encoder levels given by ``convs``.

    Every level applies its conv block. All levels but the last store the
    block output as skip feature and pass on a 2x2 max-pooled, dropped-out
    tensor. The last level (bottleneck) is not pooled; its stored feature is
    the block output after dropout.

    Dropout rates: ``dropout`` for levels 0-1, ``2 * dropout`` from level 2
    on, ``3 * dropout`` at the bottleneck.
    NOTE(review): these multiples exceed 1 for larger ``dropout`` values,
    presumably outside what the Dropout layer accepts - confirm.

    Returns:
        List with one feature tensor per entry of ``convs``, shallowest first.
    """
    bottleneck_level = len(convs) - 1
    features = []

    for level, block in enumerate(convs):
        block_output = block(x)

        if level == bottleneck_level:
            # Bottleneck: no pooling, and the dropped-out tensor is the feature.
            features.append(Dropout(dropout * 3)(block_output))
        else:
            features.append(block_output)
            pooled = MaxPooling2D((2, 2))(block_output)
            x = Dropout(dropout * (2 if level >= 2 else 1))(pooled)

    return features

In [None]:
def multi_slice_encoder_semi3D(input_img, convolution_type, iterations, n_filters, dropout, batchnorm):
    """Process multi-slice input through encoder.

    The input is split along its second-to-last axis into individual slices.
    Every slice runs through the same conv blocks (one block per level, built
    once and reused for all slices). Per level, the features of all slices
    are concatenated along the channel axis and fused by a 1x1 convolution.

    Args:
        input_img: Symbolic Keras input whose axis -2 indexes the slices.
        convolution_type: Factory called as ``(n_filters, kernel_size,
            batchnorm)`` that returns a callable block.
        iterations (int): Number of encoder levels, bottleneck included.
        n_filters (int): Filters of the first level; doubled at every level.
        dropout (float): Base dropout rate handed to ``encoder_semi3D``.
        batchnorm (bool): Forwarded to ``convolution_type``.

    Returns:
        List of ``iterations`` fused feature tensors, shallowest first.

    Raises:
        ValueError: If the size of the slice axis is not statically known.
    """
    n_slices = input_img.shape[-2]
    if n_slices is None:
        raise ValueError("the slice axis (axis -2) of input_img must have a static size")

    # Slice the symbolic tensor and merge with the Keras `concatenate` layer
    # instead of raw tf.unstack / tf.concat, so the graph can be built with
    # Keras versions that do not accept symbolic tensors in plain tf.* calls.
    leading = (slice(None),) * (len(input_img.shape) - 2)
    slices = [input_img[leading + (index, slice(None))] for index in range(n_slices)]

    filters = [n_filters * 2**i for i in range(iterations)]
    print("filters", filters)

    # One block per level, shared by all slices.
    convs = [convolution_type(width, 3, batchnorm) for width in filters]
    all_features = [encoder_semi3D(s, convs, dropout) for s in slices]

    combined_features = []
    for depth in range(iterations):
        level_features = [per_slice[depth] for per_slice in all_features]
        # With a single slice there is nothing to concatenate.
        merged = level_features[0] if len(level_features) == 1 else concatenate(level_features)
        combined_features.append(Conv2D(filters[depth], 1)(merged))

    print(combined_features)
    return combined_features

# Decoder

In [None]:
def decoder(bottleneck, skip_connections, convolution_type, transpose_conv_type, iterations, n_filters, dropout, batchnorm, heads=8):
    """Build ``heads`` independent decoder paths on top of one bottleneck.

    Every head starts from ``bottleneck`` and runs ``iterations`` stages of
    stride-2 transpose convolution, concatenation with
    ``skip_connections[stage]``, dropout and a conv block. The head ends in a
    1x1 convolution with 2 filters and tanh activation, named
    ``outputsCh<k>`` with k counting from 1.

    Returns:
        List with one output tensor per head.
    """
    # Filters per stage, halving on the way up: n_filters * 2**(iterations-1), ..., n_filters
    stage_filters = [n_filters * (2 ** (iterations - stage - 1)) for stage in range(iterations)]

    def build_head(head):
        tensor = bottleneck
        for stage, width in enumerate(stage_filters):
            upsample = transpose_conv_type(width, (2, 2), strides=(2, 2), padding='same')
            tensor = concatenate([upsample(tensor), skip_connections[stage]])

            # The first two stages use twice the base dropout rate.
            tensor = Dropout(dropout * (2 if stage < 2 else 1))(tensor)

            tensor = convolution_type(tensor, n_filters=width, kernel_size=3, batchnorm=batchnorm)

        return Conv2D(2, (1, 1), activation='tanh', name=f"outputsCh{head+1}")(tensor)

    return [build_head(head) for head in range(heads)]

# U Net

In [None]:
def define_unet_semi3D(input_img, n_filters=16, dropout=0.5, batchnorm=True):
    """Assemble the semi-3D U-Net: multi-slice encoder plus an 8-head decoder.

    The encoder has 5 levels (the deepest one is the bottleneck); the decoder
    runs 4 upsampling stages per head and receives the other four encoder
    levels as skip connections, deepest first.

    NOTE(review): the encoder multiplies ``dropout`` by up to 3, so the
    default of 0.5 yields a bottleneck rate of 1.5 - confirm that callers
    always pass a smaller value.

    Returns:
        A Keras ``Model`` mapping ``input_img`` to the list of head outputs.
    """
    encoder_levels = multi_slice_encoder_semi3D(
        input_img=input_img,
        convolution_type=ConvBlock,
        iterations=5,
        n_filters=n_filters,
        dropout=dropout,
        batchnorm=batchnorm,
    )
    # Shallowest levels first; the last entry is the bottleneck.
    *skips, bottleneck = encoder_levels

    head_outputs = decoder(
        bottleneck=bottleneck,
        skip_connections=skips[::-1],
        convolution_type=C2D_BLock,
        transpose_conv_type=Conv2DTranspose,
        iterations=4,
        n_filters=n_filters,
        dropout=dropout,
        batchnorm=batchnorm,
        heads=8,
    )

    return Model(inputs=[input_img], outputs=head_outputs)

# Initialize Model


In [None]:
# ============================================================
# Weight initialization
# ============================================================
# Small random initialization helps stabilize early training
# `init` is read as a module-level name by ConvBlock and C2D_BLock, so it has
# to exist before the model is built further down in this cell.
# NOTE(review): this single unseeded instance is handed to every conv layer;
# some Keras versions warn about reusing an unseeded initializer - confirm the
# layers still get independent random weights.
init = RandomNormal(stddev=0.02)


# ============================================================
# Checkpointing & callbacks
# ============================================================
# Path where the best model weights will be stored
checkpoint_path = "checks.weights.h5"

# Save best weights (based on training loss) every 100 batches
checkpoint = ModelCheckpoint(
    filepath=checkpoint_path,
    monitor="loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode="min",
    save_freq=100,   # NOTE: save frequency is in *batches*, not epochs
)

# Callback list:
# - ModelCheckpoint: saves best weights
# - LearningRateScheduler: applies custom LR decay
# - CustomCallback: prints loss / val_loss at epoch end
callbacks_list = [
    checkpoint,
    LearningRateScheduler(lr_scheduler),
    CustomCallback(),
]


# ============================================================
# Model definition
# ============================================================
# Input tensor
# `input_shape` comes from the training-data cell. The encoder splits this
# input along axis -2, so the shape must include the slice-window axis.
input_img = Input(input_shape)

# Semi-3D U-Net architecture (multi-slice encoder, 8 output heads)
model = define_unet_semi3D(
    input_img,
    n_filters=32,
    dropout=0.001,
    batchnorm=False,
)


# ============================================================
# Optimizer & metrics
# ============================================================
optimizer = Adam(
    learning_rate=1e-4,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    clipnorm=1.0,    # Gradient clipping for training stability
    amsgrad=False,
)

# Expose current learning rate as a metric (for logging)
lr_metric = get_lr_metric(optimizer)


# ============================================================
# Compile model
# ============================================================
# The loss is plain MSE on the real/imag output pairs; the perpendicular_loss
# defined earlier in the notebook is not used here.
model.compile(
    optimizer=optimizer,
    loss="mean_squared_error",
    metrics=[
        "mse",       # Standard MSE metric
        lr_metric,   # Learning rate tracking
    ],
)

# Print model architecture
model.summary()


# Train Model

In [None]:
# ============================================================
# Model training
# ============================================================
def _targets_per_head(targets, n_heads=8):
    """Split the last axis into one (real, imag) channel pair per output head."""
    return {
        f"outputsCh{head + 1}": targets[:, :, :, 2 * head:2 * head + 2]
        for head in range(n_heads)
    }


history = model.fit(
    x=x_train,                      # input localizer images
    y=_targets_per_head(y_train),   # one Tx channel (real + imag) per head
    validation_data=(x_val, _targets_per_head(y_val)),
    shuffle=True,                   # shuffle samples each epoch
    epochs=4000,                    # long training for convergence
    batch_size=1,                   # one sample per gradient step
    callbacks=callbacks_list,
)


# Save model weights & training history

In [None]:
# Save trained model weights.
# The second positional parameter of save_weights is `overwrite`, not a file
# mode, so the former 'r' only acted as a truthy flag; it is now spelled out.
# The file name carries the `.weights.h5` suffix (same convention as
# `checkpoint_path`), which newer Keras versions require for weight files.
model.save_weights("Model_Weights.weights.h5", overwrite=True)

In [None]:
# Persist the per-epoch log dict (losses, metrics, LR, etc.).
# np.save appends ".npy" to the name; the dict is stored as a pickled object.
np.save(file='History', arr=history.history)

# Inference on unseen test data

In [None]:
# Run the trained model on the test inputs and pack the result into one array.
# The model has one output per head, so the heads presumably end up on the
# leading axis: (head, sample, x, y, 2) - TODO confirm.
prediction = np.array(model.predict(x_test))


# Save complex-valued ground truth (unseen / validation data)


In [None]:
# ------------------------------------------------------------
# Ground-truth B1+ maps (complex-valued)
# ------------------------------------------------------------

# The predictions are computed from x_test, so the matching ground truth is
# y_test (not y_val). Even channels hold the real parts, odd channels the
# imaginary parts.
b1p_groundtruth_complex = y_test[..., 0::2] + 1j * y_test[..., 1::2]

# Save complex-valued ground truth for downstream analysis
# (layout as stored: Tx channel on the last axis)
np.save("b1p_gt_complex.npy", b1p_groundtruth_complex)

# Derived representations (channel-first for convenience)
b1p_gt_magnitude = np.moveaxis(np.abs(b1p_groundtruth_complex),   -1, 0)
b1p_gt_phase     = np.moveaxis(np.angle(b1p_groundtruth_complex), -1, 0)

In [None]:
# ------------------------------------------------------------
# Predicted B1+ maps (complex-valued)
# ------------------------------------------------------------

# Reconstruct complex-valued prediction from real/imag pairs.
# `prediction` stacks the per-head model outputs, so the head (Tx channel)
# axis is already the leading axis and the last axis is (real, imag).
b1p_prediction_complex = prediction[..., 0] + 1j * prediction[..., 1]

# Save complex-valued prediction for downstream analysis
# (layout as stored: Tx channel on the first axis)
np.save("b1p_pr_complex.npy", b1p_prediction_complex)

# Derived representations (channel-first). No axis move is needed: the channel
# axis is already first, and moving the last axis to the front would bring a
# spatial axis forward instead.
b1p_pr_magnitude = np.abs(b1p_prediction_complex)
b1p_pr_phase     = np.angle(b1p_prediction_complex)
