# Packages

In [None]:
# =============================
# Core
# =============================
import numpy as np
import pandas as pd
import h5py

# =============================
# Plotting
# =============================
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Optional: if you actually use these
# import matplotlib.cm as cm
# import seaborn as sns

# =============================
# TensorFlow / Keras  (use tf.keras only)
# =============================
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import Callback, LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (
    Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose,
    BatchNormalization, Dropout, Flatten, Reshape, Concatenate, Add,
    Activation, LeakyReLU, ReLU, concatenate
)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.utils import plot_model

# =============================
# Sklearn
# =============================
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# =============================
# Stats / Metrics
# =============================
from scipy.stats import pearsonr
from skimage.metrics import structural_similarity as compare_ssim

# =============================
# Jupyter widgets
# =============================
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display

# =============================
# Datashader (only if used)
# =============================
# import datashader as ds
# from datashader.mpl_ext import dsshow


In [None]:
# =============================
# Training data
# =============================

# The .mat file is opened as HDF5 (MATLAB v7.3); copy both arrays out before it closes.
with h5py.File("TrainData.mat", "r") as mat:
    raw_targets = mat["lvSaveDataInput"][:, :, :, :]
    raw_inputs = mat["lvLovalizerSave"][:, :, :, :]

# Channels are stored on axis 1; move them to the end for channels-last Keras layers.
x_train = np.moveaxis(raw_inputs, 1, -1)
y_train = np.moveaxis(raw_targets, 1, -1)
del raw_inputs, raw_targets

print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)

# Per-sample shapes (batch axis dropped); input_shape feeds Input() when the model is built.
input_shape, output_shape = x_train.shape[1:], y_train.shape[1:]

print("Input shape:", input_shape)
print("Output shape:", output_shape)

# =============================

In [None]:
# =============================
# Validation data
# =============================

# Same HDF5 layout as the training file: targets and localizer inputs.
with h5py.File("ValData.mat", "r") as mat:
    raw_targets = mat["lvSaveDataInput"][:, :, :, :]
    raw_inputs = mat["lvLovalizerSave"][:, :, :, :]

# Axis 1 holds the channels; make them the trailing axis.
x_val, y_val = np.moveaxis(raw_inputs, 1, -1), np.moveaxis(raw_targets, 1, -1)
del raw_inputs, raw_targets

for label, arr in (("x_val", x_val), ("y_val", y_val)):
    print(f"{label} shape:", arr.shape)

# =============================

# Double Convolutional Block 2D

In [None]:
def C2D_BLock(x, n_filters, kernel_size=3, batchnorm=True, kernel_initializer="he_normal"):
    """Two stacked Conv2D -> [BatchNorm] -> ReLU stages (the basic U-Net block).

    Parameters
    ----------
    x : Keras tensor
        Input feature map.
    n_filters : int
        Number of filters in both convolutions.
    kernel_size : int, default 3
        Side length of the square convolution kernel.
    batchnorm : bool, default True
        Insert ``BatchNormalization`` between each convolution and its activation.
    kernel_initializer : str or keras initializer, default "he_normal"
        Initializer for both convolution kernels. The original code referenced an
        undefined global ``init`` (NameError). The choice only matters for training
        from scratch; pretrained weights loaded afterwards overwrite it.

    Returns
    -------
    Keras tensor with ``n_filters`` channels. The spatial size is unchanged
    because of ``padding="same"``.
    """
    # Both stages are identical, so build them in a loop. Layer creation order is
    # the same as the unrolled original, which keeps h5 weight loading compatible.
    for _ in range(2):
        x = Conv2D(
            n_filters,
            (kernel_size, kernel_size),
            padding="same",
            kernel_initializer=kernel_initializer,
        )(x)
        if batchnorm:
            x = BatchNormalization()(x)
        # Use LeakyReLU(alpha=0.2) here instead to keep a small slope for x < 0
        # rather than zeroing negatives.
        x = ReLU()(x)
    return x

# Encoder

In [None]:
def encoder(input_img, convolution_type, iterations, n_filters, dropout, batchnorm, max_dropout=0.9):
    """U-Net contracting path with depth-scaled dropout.

    Level ``i`` applies ``convolution_type`` with ``n_filters * 2**i`` filters.
    Every level except the last is followed by 2x2 max-pooling and dropout. The
    dropout rate is ``dropout`` for levels 0-1 and ``2 * dropout`` from level 2 on.
    The last level is the bottleneck: it gets ``3 * dropout`` and no pooling.

    Parameters
    ----------
    input_img : Keras tensor
        Encoder input.
    convolution_type : callable
        Block builder, called as
        ``convolution_type(x, n_filters=..., kernel_size=3, batchnorm=...)``.
    iterations : int
        Number of levels, bottleneck included.
    n_filters : int
        Filters at level 0 (doubled at each deeper level).
    dropout : float
        Base dropout rate before per-level scaling.
    batchnorm : bool
        Forwarded to ``convolution_type``.
    max_dropout : float, default 0.9
        Upper bound applied to every scaled rate. Keras ``Dropout`` rejects rates
        above 1, and ``3 * dropout`` exceeds that for any ``dropout > 1/3``,
        including ``define_unet2D``'s default of 0.5. Rates at or below the cap
        are passed through unchanged.

    Returns
    -------
    list
        ``iterations`` tensors. Entry ``i`` is the pre-pooling output of level
        ``i`` (used as a skip connection). The last entry is the bottleneck
        after its dropout.
    """
    def _capped(factor):
        # Scale the base rate, but never hand Keras an invalid / degenerate rate.
        return min(dropout * factor, max_dropout)

    features = []
    x = input_img
    for i in range(iterations):
        down = convolution_type(x, n_filters=n_filters * (2 ** i), kernel_size=3, batchnorm=batchnorm)
        if i < iterations - 1:
            features.append(down)
            x = MaxPooling2D((2, 2))(down)
            # Deeper levels (i >= 2) get twice the base rate.
            x = Dropout(_capped(2 if i >= 2 else 1))(x)
        else:
            # Bottleneck: no pooling. Store the post-dropout tensor for the decoder.
            x = Dropout(_capped(3))(down)
            features.append(x)
    return features

# Decoder

In [None]:
def decoder(bottleneck, skip_connections, convolution_type, transpose_conv_type, iterations, n_filters, dropout, batchnorm, heads=8, max_dropout=0.9):
    """U-Net expanding path, replicated into ``heads`` independent decoders.

    All heads share the same ``bottleneck`` and ``skip_connections``. Each head
    builds its own upsampling layers and ends in a 2-channel ``tanh`` 1x1
    convolution named ``outputsCh{head+1}``.

    Parameters
    ----------
    bottleneck : Keras tensor
        Deepest encoder output.
    skip_connections : sequence of Keras tensors
        Encoder features ordered deepest first. Entry ``i`` is concatenated
        after the ``i``-th upsampling.
    convolution_type : callable
        Block builder, called as
        ``convolution_type(x, n_filters=..., kernel_size=3, batchnorm=...)``.
    transpose_conv_type : callable
        Layer class used for 2x upsampling, e.g. ``Conv2DTranspose``.
    iterations : int
        Number of upsampling stages.
    n_filters : int
        Base filter count. Stage ``i`` uses ``n_filters * 2**(iterations - i - 1)``
        filters, i.e. multipliers 8, 4, 2, 1 for four stages.
    dropout : float
        Base dropout rate. It is doubled for the first two stages.
    batchnorm : bool
        Forwarded to ``convolution_type``.
    heads : int, default 8
        Number of independent decoder heads / outputs.
    max_dropout : float, default 0.9
        Upper bound on every scaled dropout rate. Keras ``Dropout`` rejects
        rates above 1.

    Returns
    -------
    list of ``heads`` Keras tensors, each with 2 channels.

    Raises
    ------
    ValueError
        If fewer than ``iterations`` skip connections are supplied.
    """
    if len(skip_connections) < iterations:
        raise ValueError(
            f"decoder needs {iterations} skip connections, got {len(skip_connections)}"
        )

    outputs = []
    for head in range(heads):
        x = bottleneck
        for i in range(iterations):
            current_filters = n_filters * (2 ** (iterations - i - 1))

            # 2x upsampling, then join with the matching encoder feature map
            # along the channel axis (the original called an unimported `concatenate`).
            x = transpose_conv_type(current_filters, (2, 2), strides=(2, 2), padding="same")(x)
            x = Concatenate()([x, skip_connections[i]])

            # Heavier dropout for the two deepest stages, capped to stay valid.
            x = Dropout(min(dropout * (2 if i < 2 else 1), max_dropout))(x)

            x = convolution_type(x, n_filters=current_filters, kernel_size=3, batchnorm=batchnorm)

        outputs.append(Conv2D(2, (1, 1), activation="tanh", name=f"outputsCh{head+1}")(x))

    return outputs

# U-Net

In [None]:
def define_unet2D(input_img, n_filters=16, dropout=0.5, batchnorm=True, depth=5, heads=8):
    """Build a single-input, multi-head 2-D U-Net.

    One shared encoder feeds ``heads`` independent decoders. Decoder ``k``
    ends in a 2-channel ``tanh`` layer named ``outputsCh{k+1}``.

    Parameters
    ----------
    input_img : Keras tensor
        Model input created with ``Input(...)``.
    n_filters : int, default 16
        Filters of the first encoder level (doubled at each deeper level).
    dropout : float, default 0.5
        Base dropout rate, scaled per level inside ``encoder`` / ``decoder``.
    batchnorm : bool, default True
        Use ``BatchNormalization`` inside every ``C2D_BLock``.
    depth : int, default 5
        Encoder levels including the bottleneck. The decoder gets ``depth - 1``
        upsampling stages. 5 reproduces the original architecture.
    heads : int, default 8
        Number of decoder heads / model outputs.

    Returns
    -------
    keras.Model with ``heads`` outputs.

    Raises
    ------
    ValueError
        If ``depth < 1``, or if a known spatial input size is not divisible by
        ``2 ** (depth - 1)``.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    # Each MaxPooling2D halves H and W and each transpose conv doubles them again.
    # Sizes not divisible by 2**(depth-1) make the skip concatenations fail deep
    # inside Keras, so reject them up front with a readable message.
    stride = 2 ** (depth - 1)
    for size in tuple(input_img.shape)[1:3]:
        if size is not None and size % stride:
            raise ValueError(
                f"spatial input size {size} must be divisible by {stride} for depth={depth}"
            )

    # --- Shared single-slice encoder (no multi-slice fusion) ---
    features = encoder(
        input_img=input_img,
        convolution_type=C2D_BLock,
        iterations=depth,
        n_filters=n_filters,
        dropout=dropout,
        batchnorm=batchnorm,
    )

    # The deepest feature is the bottleneck; the remaining ones are skips, deepest first.
    bottleneck = features[-1]
    skips = features[-2::-1]

    # --- Multi-head decoder ---
    outputs = decoder(
        bottleneck=bottleneck,
        skip_connections=skips,
        convolution_type=C2D_BLock,
        transpose_conv_type=Conv2DTranspose,
        iterations=depth - 1,
        n_filters=n_filters,
        dropout=dropout,
        batchnorm=batchnorm,
        heads=heads,
    )

    return Model(inputs=[input_img], outputs=outputs)

# Initialize Model


In [None]:
# Build the network for the training-input geometry, then restore pretrained weights.
input_img = Input(shape=input_shape)
model = define_unet2D(
    input_img,
    n_filters=32,
    dropout=0.001,
    batchnorm=False,
)
model.load_weights("weights.h5")
model.summary()

# Load model weights & training history

# Site A

In [None]:
# Placeholder paths carried over from the original.
# TODO: point these at the Site A checkpoint and its training history.
WEIGHTS_PATH_SITE_A = ".h5"
HISTORY_PATH_SITE_A = ".npy"

# load_weights() updates `model` in place and does not hand the weights back,
# so record which checkpoint was loaded rather than the call's return value.
model.load_weights(WEIGHTS_PATH_SITE_A)
weights_site_a = WEIGHTS_PATH_SITE_A

# The history is stored as a pickled object inside a 0-d object array (hence
# allow_pickle=True and .item()); presumably a Keras History.history dict.
# allow_pickle must be the bool True. Only load trusted files, because
# unpickling can execute arbitrary code.
history_site_a = np.load(HISTORY_PATH_SITE_A, allow_pickle=True).item()


# Load Testing Data 

# Site A

In [None]:
# =============================
# Testing data Site A
# =============================

# Read both arrays fully while the HDF5 (.mat v7.3) file is open.
with h5py.File("TestDataSiteA.mat", "r") as mat:
    targets = mat["lvSaveDataInput"][:, :, :, :]
    inputs = mat["lvLovalizerSave"][:, :, :, :]

# Channels live on axis 1 in the file; the model expects them last.
x_test_site_a, y_test_site_a = (np.moveaxis(arr, 1, -1) for arr in (inputs, targets))
del inputs, targets

for label, arr in (("x_test_site_a", x_test_site_a), ("y_test_site_a", y_test_site_a)):
    print(f"{label} shape:", arr.shape)

# =============================

In [None]:
# ------------------------------------------------------------
# Ground-truth B1+ maps (complex-valued), Site A test set
# ------------------------------------------------------------

# Reconstruct complex-valued ground truth from interleaved real/imag channels
# (even channels = real part, odd channels = imaginary part).
# This must use the Site A *test* targets so it lines up sample-for-sample with
# prediction_site_a, which is computed from x_test_site_a. The original read
# y_val here, which compares predictions against the wrong data.
b1p_groundtruth_complex_site_a = y_test_site_a[..., 0::2] + 1j * y_test_site_a[..., 1::2]

# or load from saved file
# b1p_groundtruth_complex_site_a = np.load("b1p_gt_complex_site_a.npy")

assert b1p_groundtruth_complex_site_a.shape[0] == x_test_site_a.shape[0], (
    "Site A ground truth and test inputs must contain the same number of samples"
)

# Derived representations, channel-first (Tx channel moved to axis 0) to match the prediction layout
b1p_gt_magnitude_site_a = np.moveaxis(np.abs(b1p_groundtruth_complex_site_a),   -1, 0)
b1p_gt_phase_site_a     = np.moveaxis(np.angle(b1p_groundtruth_complex_site_a), -1, 0)


# Inference on unseen test data Site A

In [None]:
# One forward pass over the Site A test set. The model has one output per head,
# so stacking the returned outputs puts the head (Tx channel) index on axis 0.
prediction_site_a = np.stack(model.predict(x_test_site_a), axis=0)


In [None]:
# ------------------------------------------------------------
# Predicted B1+ maps (complex-valued)
# ------------------------------------------------------------

# Each head emits a (real, imag) pair on the last axis; fuse them into one complex map.
pred_real = prediction_site_a[..., 0]
pred_imag = prediction_site_a[..., 1]
b1p_prediction_complex_site_a = pred_real + 1j * pred_imag
del pred_real, pred_imag
# Alternatively restore a cached copy:
# b1p_prediction_complex_site_a = np.load("b1p_pr_complex_site_a.npy")

# Magnitude / phase. Axis 0 is already the head axis, so no moveaxis is needed.
b1p_pr_magnitude_site_a, b1p_pr_phase_site_a = (
    np.abs(b1p_prediction_complex_site_a),
    np.angle(b1p_prediction_complex_site_a),
)


# Apply Mask Site A

In [None]:
# Placeholder path carried over from the original.
# TODO: point this at the Site A mask file.
MASK_PATH_SITE_A = ".mat"

with h5py.File(MASK_PATH_SITE_A, "r") as f:
    mask_site_a = f["mask"][:, :, :]

# =============================
# Expand mask to every Tx channel
# =============================

# The map grid is everything after the leading Tx axis. The mask must broadcast
# onto it exactly; np.broadcast_shapes itself raises if the shapes are incompatible.
grid_site_a = b1p_gt_magnitude_site_a.shape[1:]
if np.broadcast_shapes(mask_site_a.shape, grid_site_a) != grid_site_a:
    raise ValueError(
        f"Site A mask shape {mask_site_a.shape} does not fit B1+ map grid {grid_site_a}"
    )
if b1p_pr_magnitude_site_a.shape != b1p_gt_magnitude_site_a.shape:
    raise ValueError(
        f"Site A prediction {b1p_pr_magnitude_site_a.shape} and ground truth "
        f"{b1p_gt_magnitude_site_a.shape} differ in shape"
    )

# One mask copy per Tx channel. The count comes from the maps, not a hard-coded 8.
mask_site_a = np.tile(mask_site_a, [b1p_gt_magnitude_site_a.shape[0], 1, 1, 1])

# =============================
# Apply mask
# =============================

# Ground truth:
b1p_gt_magnitude_site_a_masked   = b1p_gt_magnitude_site_a * mask_site_a
b1p_gt_phase_site_a_masked       = b1p_gt_phase_site_a * mask_site_a

# Prediction:
b1p_pr_magnitude_site_a_masked   = b1p_pr_magnitude_site_a   * mask_site_a
b1p_pr_phase_site_a_masked       = b1p_pr_phase_site_a * mask_site_a


# Site B


In [None]:
# =============================
# Testing data Site B
# =============================

# Pull the targets and localizer inputs out of the HDF5-backed .mat file.
with h5py.File("TestDataSiteB.mat", "r") as mat:
    site_b_targets = mat["lvSaveDataInput"][:, :, :, :]
    site_b_inputs = mat["lvLovalizerSave"][:, :, :, :]

# Relocate the channel axis (stored at position 1) to the end.
x_test_site_b = np.moveaxis(site_b_inputs, 1, -1)
y_test_site_b = np.moveaxis(site_b_targets, 1, -1)
del site_b_inputs, site_b_targets

print("x_test_site_b shape:", x_test_site_b.shape)
print("y_test_site_b shape:", y_test_site_b.shape)

# =============================

In [None]:
# ------------------------------------------------------------
# Ground-truth B1+ maps (complex-valued)
# ------------------------------------------------------------

# Even channels hold real parts and odd channels hold imaginary parts;
# pair them to get one complex map per Tx channel.
gt_real_b = y_test_site_b[..., 0::2]
gt_imag_b = y_test_site_b[..., 1::2]
b1p_groundtruth_complex_site_b = gt_real_b + 1j * gt_imag_b
del gt_real_b, gt_imag_b

# Channel-first magnitude and phase (Tx channel moved to axis 0), matching the prediction layout.
b1p_gt_magnitude_site_b, b1p_gt_phase_site_b = (
    np.moveaxis(fn(b1p_groundtruth_complex_site_b), -1, 0) for fn in (np.abs, np.angle)
)


# Inference on unseen test data Site B

In [None]:
# Inference on Site B. The per-head outputs are stacked so axis 0 indexes the head / Tx channel.
prediction_site_b = np.stack(model.predict(x_test_site_b), axis=0)


In [None]:
# ------------------------------------------------------------
# Predicted B1+ maps (complex-valued)
# ------------------------------------------------------------

# Last axis of every head's output is (real, imag); combine into complex values.
b1p_prediction_complex_site_b = (
    prediction_site_b[..., 0] + 1j * prediction_site_b[..., 1]
)

# Magnitude and phase. Axis 0 is already the head axis from stacking at predict time.
b1p_pr_magnitude_site_b, b1p_pr_phase_site_b = (
    fn(b1p_prediction_complex_site_b) for fn in (np.abs, np.angle)
)

# Apply Mask Site B

In [None]:
# Placeholder path carried over from the original.
# TODO: point this at the Site B mask file.
MASK_PATH_SITE_B = ".mat"

with h5py.File(MASK_PATH_SITE_B, "r") as f:
    mask_site_b = f["mask"][:, :, :]

# =============================
# Expand mask to every Tx channel
# =============================

# The map grid is everything after the leading Tx axis. The mask must broadcast
# onto it exactly; np.broadcast_shapes itself raises if the shapes are incompatible.
grid_site_b = b1p_gt_magnitude_site_b.shape[1:]
if np.broadcast_shapes(mask_site_b.shape, grid_site_b) != grid_site_b:
    raise ValueError(
        f"Site B mask shape {mask_site_b.shape} does not fit B1+ map grid {grid_site_b}"
    )
if b1p_pr_magnitude_site_b.shape != b1p_gt_magnitude_site_b.shape:
    raise ValueError(
        f"Site B prediction {b1p_pr_magnitude_site_b.shape} and ground truth "
        f"{b1p_gt_magnitude_site_b.shape} differ in shape"
    )

# One mask copy per Tx channel. The count comes from the maps, not a hard-coded 8.
mask_site_b = np.tile(mask_site_b, [b1p_gt_magnitude_site_b.shape[0], 1, 1, 1])

# =============================
# Apply mask
# =============================

# Ground truth:
b1p_gt_magnitude_site_b_masked   = b1p_gt_magnitude_site_b * mask_site_b
b1p_gt_phase_site_b_masked       = b1p_gt_phase_site_b * mask_site_b

# Prediction:
b1p_pr_magnitude_site_b_masked   = b1p_pr_magnitude_site_b   * mask_site_b
b1p_pr_phase_site_b_masked       = b1p_pr_phase_site_b * mask_site_b

# Simple Plot Comparison

In [None]:
def plot_b1p_est_gt_diff(
    b1p_est_complex,
    b1p_gt_complex,
    slice_idx,
    vmin_mag=0.0,
    vmax_mag=0.25,
    vmax_diff=0.10,
    noise_floor=0.01,
):
    """Show estimated vs. ground-truth |B1+| and their absolute difference for one slice.

    Layout: one column per Tx channel. The rows are EST, GT and |EST - GT|.

    Parameters
    ----------
    b1p_est_complex, b1p_gt_complex : np.ndarray
        Maps indexed ``[channel, slice, ...]`` such that ``arr[ch, slice_idx]`` is
        a 2-D image. Complex and already-magnitude inputs both work, because
        ``np.abs`` is applied. Both must have the same number of channels on axis 0.
    slice_idx : int
        Index along axis 1 to display.
    vmin_mag, vmax_mag : float
        Colour limits shared by the EST and GT rows.
    vmax_diff : float
        Upper colour limit of the difference row (the lower limit is 0).
    noise_floor : float, default 0.01
        Pixels below this value are set to NaN and drawn blank in every row,
        hiding background / masked-out regions.

    Raises
    ------
    ValueError
        If the two inputs have a different number of channels.
    """
    n_ch = b1p_est_complex.shape[0]
    if b1p_gt_complex.shape[0] != n_ch:
        raise ValueError(
            f"channel mismatch: estimate has {n_ch}, ground truth has {b1p_gt_complex.shape[0]}"
        )

    def _hide_floor(img):
        # NaN pixels use the colormap's "bad" colour (transparent by default).
        return np.where(img < noise_floor, np.nan, img)

    # ------------------------------------------------------------
    # Figure layout: 1.75 in per channel column (14 in for 8 channels).
    # squeeze=False keeps `axes` 2-D even for a single channel.
    # ------------------------------------------------------------
    fig, axes = plt.subplots(3, n_ch, figsize=(1.75 * n_ch, 8), squeeze=False)
    fig.suptitle(fr"$B_1^+$ Magnitude – Slice {slice_idx}", fontsize=14)

    # Column titles (Tx channels)
    for ch in range(n_ch):
        axes[0, ch].set_title(f"Ch {ch + 1}", fontsize=10)

    # Row labels. They are kept close to the first column so they stay on the
    # canvas; at x=-0.35 with a 3% left margin they could land off-figure.
    for r, label in enumerate(["EST", "GT", "|EST − GT|"]):
        axes[r, 0].annotate(
            label,
            xy=(-0.04, 0.5),
            xycoords="axes fraction",
            ha="right",
            va="center",
            rotation=90,
            fontsize=11,
        )

    # ------------------------------------------------------------
    # Plot per channel
    # ------------------------------------------------------------
    for ch in range(n_ch):
        mag_est = np.abs(b1p_est_complex[ch, slice_idx])
        mag_gt  = np.abs(b1p_gt_complex[ch, slice_idx])
        # Difference is taken before thresholding, so it uses raw magnitudes.
        mag_df  = np.abs(mag_est - mag_gt)

        axes[0, ch].imshow(
            _hide_floor(mag_est).T, cmap="plasma", vmin=vmin_mag, vmax=vmax_mag
        )
        axes[1, ch].imshow(
            _hide_floor(mag_gt).T, cmap="plasma", vmin=vmin_mag, vmax=vmax_mag
        )
        axes[2, ch].imshow(
            _hide_floor(mag_df).T, cmap="inferno", vmin=0.0, vmax=vmax_diff
        )

        for r in range(3):
            axes[r, ch].axis("off")

    # ------------------------------------------------------------
    # Colorbars: magnitude and difference use different colormaps and limits,
    # so each gets its own bar.
    # ------------------------------------------------------------
    mag_cax = fig.add_axes([0.10, 0.05, 0.38, 0.02])
    mag_cbar = fig.colorbar(axes[0, 0].images[0], cax=mag_cax, orientation="horizontal")
    mag_cbar.set_label("B1⁺ magnitude (a.u.)", fontsize=10)

    diff_cax = fig.add_axes([0.55, 0.05, 0.38, 0.02])
    diff_cbar = fig.colorbar(axes[2, 0].images[0], cax=diff_cax, orientation="horizontal")
    diff_cbar.set_label("|EST − GT| (a.u.)", fontsize=10)

    plt.subplots_adjust(
        left=0.05, right=0.98, top=0.88, bottom=0.12,
        wspace=0.0, hspace=0.0
    )

    plt.show()


In [None]:
# Site A: masked predicted vs. ground-truth |B1+| for slice 10.
plot_b1p_est_gt_diff(
    b1p_pr_magnitude_site_a_masked,
    b1p_gt_magnitude_site_a_masked,
    slice_idx=10,
)

In [None]:
# Site B: same comparison on the unseen site, slice 10.
plot_b1p_est_gt_diff(
    b1p_pr_magnitude_site_b_masked,
    b1p_gt_magnitude_site_b_masked,
    slice_idx=10,
)