# Packages

In [None]:
! git clone https://github.com/soumickmj/pytorch-complex.git
! mv /content/pytorch-complex/* .
!pip install torchinfo

In [6]:
# Import necessary libraries
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torchvision.transforms import CenterCrop
from torch.utils.data import Dataset, TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, DataLoader
from sklearn.model_selection import KFold, train_test_split
from skimage.metrics import structural_similarity as ssim
from torchinfo import summary

# Import custom complex number support for PyTorch
import torchcomplex
from torchcomplex import nn



In [7]:
device = "cuda" # if torch.cuda.is_available() else "cpu"

device = torch.device('cuda:0')

# Custom Functions

In [8]:
def size_of(x):
    print(x.numel()*x.element_size()/1024/1024)

def count(net):
    return sum(p.numel() for p in net.parameters())

In [None]:
def plot_pred_gt_diff(P_test, ground_truth, slice_idx):
    """
    P_test2, ground_truth2 : [channels, slices, X, Y]
    slice_idx : int
    """

    # --- ensure numpy ---
    if torch.is_tensor(P_test):
        P_test = P_test.detach().cpu().numpy()
    if torch.is_tensor(ground_truth):
        ground_truth = ground_truth.detach().cpu().numpy()

    fig, axes = plt.subplots(3, 8, figsize=(14, 8))
    fig.suptitle(fr"$B_1^+$ Magnitudes – Slice {slice_idx}", fontsize=14)

    for i in range(8):
        axes[0, i].set_title(f"Channel {i + 1}")

    row_labels = ['PR', 'GT', '|PR − GT|']
    for r, label in enumerate(row_labels):
        axes[r, 0].annotate(
            label, xy=(-0.4, 0.5), xycoords='axes fraction',
            ha='right', va='center', rotation=90, fontsize=11
        )

    for i in range(8):
        img_pr = np.abs(P_test[i, slice_idx, :, :])
        img_gt = np.abs(ground_truth[i, slice_idx, :, :])
        img_df = np.abs(img_pr - img_gt)

        img_pr = np.where(img_pr < 0.01, np.nan, img_pr)
        img_gt = np.where(img_gt < 0.01, np.nan, img_gt)
        img_df = np.where(img_df < 0.01, np.nan, img_df)

        axes[0, i].imshow(img_pr.T, cmap='plasma', aspect=1, vmin=0.0, vmax=0.25)
        axes[1, i].imshow(img_gt.T, cmap='plasma', aspect=1, vmin=0.0, vmax=0.25)
        axes[2, i].imshow(img_df.T, cmap='inferno', aspect=1, vmin=0.0, vmax=0.1)

        axes[0, i].axis('off')
        axes[1, i].axis('off')
        axes[2, i].axis('off')

    cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.02])
    cbar = fig.colorbar(axes[0, 0].images[0], cax=cbar_ax, orientation='horizontal')
    cbar.set_label("Intensity in a.u.", fontsize=10)

    plt.subplots_adjust(left=0.03, right=0.98, top=0.88, bottom=0.12,
                        wspace=0.0, hspace=0.0)

    plt.show()

# Data Process

In [None]:
def process_training_data(file_path):
    """
    Process training data from an HDF5 file to create complex-valued arrays.
    
    Parameters:
    -----------
    file_path : str
        Path to the training HDF5 file
        
    Returns:
    --------
    tuple
        (x_, y_, complex_localizer_data_train, complex_input_data_train) where:
        - x_ is the complex-valued localizer data (same as complex_localizer_data_train)
        - y_ is the complex-valued input data (same as complex_input_data_train)
        - complex_localizer_data_train is the complex-valued localizer data (for compatibility)
        - complex_input_data_train is the complex-valued input data (for compatibility)
    """
    # Load data from file
    with h5py.File(file_path, 'r') as train_file:
        localizer_data_train = train_file['lvLovalizerSave'][:,:,:,:]
        input_data_train = train_file['lvSaveDataInput'][:,:,:,:]
    
    # Move axis
    localizer_data_train = np.moveaxis(localizer_data_train, 2, -1)
    input_data_train = np.moveaxis(input_data_train, 2, -1)
    
    # Remove magnitude value
    localizer_data_train = np.delete(localizer_data_train, 0, axis=1)
    
    # Separate real and imaginary parts for localizer data
    localizer_real_train = localizer_data_train[:, ::2, :, :]
    localizer_imag_train = localizer_data_train[:, 1::2, :, :]
    
    # Separate real and imaginary parts for input data
    input_real_train = input_data_train[:, ::2, :, :]
    input_imag_train = input_data_train[:, 1::2, :, :]
    
    # Combine into complex values
    x_ = complex_localizer_data_train = localizer_real_train + 1j * localizer_imag_train
    y_ = complex_input_data_train = input_real_train + 1j * input_imag_train
    
    return x_, y_, complex_localizer_data_train, complex_input_data_train

In [None]:
def process_validation_data(file_path):
    """
    Process validation data from an HDF5 file to create complex-valued arrays.
    
    Parameters:
    -----------
    file_path : str
        Path to the validation HDF5 file
        
    Returns:
    --------
    tuple
        (x_test_, y_test_, complex_localizer_data_val, complex_input_data_val, input_data_val) where:
        - x_test_ is the complex-valued localizer data (same as complex_localizer_data_val)
        - y_test_ is the complex-valued input data
        - complex_localizer_data_val is the complex-valued localizer data (for compatibility)
        - complex_input_data_val is the complex-valued input data (for compatibility)
        - input_data_val is the processed input data before complex conversion
    """
    # Load data from file
    with h5py.File(file_path, 'r') as val_file:
        localizer_data_val = val_file['lvLovalizerSave'][:,:,:]
        input_data_val = val_file['lvSaveDataInput'][:,:,:,:]
    
    # Move axis for localizer data
    localizer_data_val = np.moveaxis(localizer_data_val, 2, -1)
    
    # Move axes for input data (note the double moveaxis operation)
    input_data_val = np.moveaxis(input_data_val, 2, -1)
    y_test_ = input_data_val = np.moveaxis(input_data_val, 1, -1)
    
    # Remove magnitude value
    localizer_data_val = np.delete(localizer_data_val, 0, axis=1)
    
    # Separate real and imaginary parts for localizer data
    localizer_real_val = localizer_data_val[:, ::2, :, :]
    localizer_imag_val = localizer_data_val[:, 1::2, :, :]
    
    # Separate real and imaginary parts for input data
    input_real_val = input_data_val[:, ::2, :, :]
    input_imag_val = input_data_val[:, 1::2, :, :]
    
    # Combine into complex values
    x_test_ = complex_localizer_data_val = localizer_real_val + 1j * localizer_imag_val
    complex_input_data_val = input_real_val + 1j * input_imag_val
    
    return x_test_, y_test_, complex_localizer_data_val, complex_input_data_val, input_data_val


# Initialization Weights

In [9]:
def _init_weights(module):
        if isinstance(module, torchcomplex.nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.02) 

# Loss Function

In [10]:
class ComplexMSELoss:
    def __call__(self, true, prediction):
        # Convert NumPy arrays to PyTorch tensors of complex64 type right at the beginning
        true_tensor = torch.tensor(true, dtype=torch.complex64)
        prediction_tensor = torch.tensor(prediction, dtype=torch.complex64)
        
        # Perform the MSE computation
        return (0.5 * (true_tensor - prediction_tensor) ** 2).mean()
    

    
class PerpLoss(nn.Module):
    def __init__(self, eps=1e-8, l1factor=1.0, mask=False):
        super(PerpLoss, self).__init__()
        self.eps = eps
        self.l1factor = l1factor
        self.mask = mask

    def forward(self, target, prediction):
        # Calculate the cross term as the absolute value of the determinant of the complex numbers
        cross = torch.abs(target.real * prediction.imag - target.imag * prediction.real)
        # Calculate the perpendicular loss component
        ploss_raw = cross / (torch.abs(prediction) + self.eps)
        # Corrected: Ensure the mask is a boolean tensor
        # Here, it's assumed that you want to mask based on the condition that involves 'target'
        # Adjust the condition according to your specific requirements
        mask = target.abs() > 1e-3  # This now produces a boolean tensor
        angle_smaller_90 = ((target / prediction).real > 0).detach()  # is the angle < pi/2 ?
        # Calculate the final loss with the conditional mask applied
        # torch.where now receives a boolean tensor as expected
        ploss = torch.where(angle_smaller_90, ploss_raw, 2 * torch.abs(target) - ploss_raw)
        l1loss = torch.nn.functional.l1_loss(prediction, target, reduction='none')
        
        loss = ploss + self.l1factor * l1loss
        if self.mask:
            loss = (loss * mask).sum() / (mask.sum() + self.eps)  # Returning the maksed mean loss over all elements
        else:
            loss = loss.mean()  # return the mean over all elements
        return loss



# 2D Convolutional Block

In [11]:
class C2D_Block(nn.Module):
    def __init__(self, in_c, n_filters, batchnorm, skip):
        super().__init__()
        self.conv1 = torchcomplex.nn.Conv2d(in_c, n_filters, kernel_size=3, padding=1)

        if batchnorm:
            self.bn1 = torchcomplex.nn.BatchNorm2d(n_filters)
            self.bn2 = torchcomplex.nn.BatchNorm2d(n_filters)
        else:
            self.bn1 = None
            self.bn2 = None

        # self.relu1 = torchcomplex.nn.CReLU() - do not use CReLU here, look up in docs
        self.relu1 = torchcomplex.nn.AdaptiveCmodReLU(n_filters)
        self.conv2 = torchcomplex.nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1)
        # self.relu2 = torchcomplex.nn.CReLU() - do not use CReLU here, look up in docs
        self.relu2 = torchcomplex.nn.AdaptiveCmodReLU(n_filters)
        if skip:
            self.skip = torchcomplex.nn.Conv2d(in_c, n_filters, kernel_size=1)
            with torch.no_grad():
                self.skip.bias.zero_()
        else:
            self.skip = None

    def forward(self, xin):
        x = self.conv1(xin)
        if self.bn1:
            x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        if self.bn2:
            x = self.bn2(x)
        if self.skip:
            x = x + self.skip(xin)
        x = self.relu2(x)
        return x

# Encoder

In [12]:
class Encoder(nn.Module):
    def __init__(self, in_c, dropout, features, maxpool, batchnorm, skip):
        super().__init__()
        self.encBlocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for feature in features:
            self.encBlocks.append(C2D_Block(in_c, feature, batchnorm=batchnorm, skip=skip))
            if maxpool:
                self.downsamples.append(torchcomplex.nn.MaxPool2d(2))
                in_c = feature
            else:
                down = torch.nn.Sequential(
                    torchcomplex.nn.Conv2d(feature, features[-1], kernel_size=3, stride=2, padding=1), torchcomplex.nn.AdaptiveCmodReLU(features[-1])
                )
                self.downsamples.append(down)
                in_c = features[-1]

        self.dropouts = torchcomplex.nn.Dropout2d(dropout) if dropout else torch.nn.Identity()
        self.dropout2 = torchcomplex.nn.Dropout2d(dropout * 2) if dropout else torch.nn.Identity()
        self.dropout3 = torchcomplex.nn.Dropout2d(dropout * 3) if dropout else torch.nn.Identity()
        # self.bottleneck = C2D_Block(features[-1], features[-1]*2)
        self.bottleneck = torch.nn.Sequential(
            torchcomplex.nn.Conv2d(features[-1], features[-1], kernel_size=3, padding=1), torchcomplex.nn.AdaptiveCmodReLU( features[-1])
        )

    def forward(self, x):
        skip_connections = []
        # downsampling
        for depth, (block, down) in enumerate(zip(self.encBlocks, self.downsamples)):
            x = block(x)
            skip_connections.append(x)
            x = down(x)
        if depth < 2:
            x = self.dropout1(x)
        else:
            x = self.dropout2(x)

        x = self.bottleneck(x)
        x = self.dropout3(x)

        return x, skip_connections

# Decoder

In [13]:
class Decoder(nn.Module):
    def __init__(self, dropout, features, upsample, batchnorm, skip):
        super().__init__()

        features_out = list(reversed(features))
        features_in = [features[-1], *features[:0:-1]]
        
        self.upConvs = nn.ModuleList()
        self.decBlocks = nn.ModuleList()
        for fin, fout in zip(features_in, features_out):
            if upsample:
                self.upConvs.append(
                    torch.nn.Sequential(
                        torchcomplex.nn.Upsample(mode="bilinear", scale_factor=2, size=None),
                        torchcomplex.nn.Conv2d(fin, fout, kernel_size=3, padding=1),
                        torchcomplex.nn.AdaptiveCmodReLU(fout),
                    )
                )
            else:
                self.upConvs.append(
                    torch.nn.Sequential(
                        torchcomplex.nn.ConvTranspose2d(fin, fout, 2, stride=2),
                        torchcomplex.nn.AdaptiveCmodReLU(fout),
                    )
                )
            self.decBlocks.append(C2D_Block(2 * fout, fout, batchnorm=batchnorm, skip=skip))
        # with torch.no_grad():
        #     self.upConvs.apply(_init_weights)

        self.dropout1 = torchcomplex.nn.Dropout2d(dropout) if dropout else torch.nn.Identity()
        self.dropout2 = torchcomplex.nn.Dropout2d(dropout * 2) if dropout else torch.nn.Identity()

    def forward(self, x, skipped_feautures):
        for depth, (up, block, skipped) in enumerate(zip(self.upConvs, self.decBlocks, skipped_feautures, strict=True)):
            x = up(x)
            x = torch.cat([x, skipped], dim=1)
            x = block(x)
            if depth < 2:
                x = self.dropout2(x)
            else:
                x = self.dropout1(x)
        return x

    def crop(self, encFeaturs, x):
        (_, _, H, W) = x.shape
        encFeaturs = CenterCrop([H, W])(encFeaturs)

        return encFeaturs

In [None]:
class Head(nn.Module):
    def __init__(self, features_in, features_out=1, features_hidden=(64,32,16)):
        super().__init__()
        
        modules = []
        fin = features_in
        for fout in features_hidden:
            modules.append(torchcomplex.nn.Conv2d(fin,fout, 3, padding=1))
            modules.append(torchcomplex.nn.AdaptiveCmodReLU(fout))
            fin = fout
        modules.append(torchcomplex.nn.Conv2d(fin, features_out, 3, padding=1))
        self.net = torch.nn.Sequential(*modules)
        
        with torch.no_grad():
            self.net[-1].bias.zero_()
    
    def forward(self,x):
        return self.net(x)

In [None]:
class Head(nn.Module):
    def __init__(self, features_in, features_out=1, features_hidden=(64,32,16)):
        super().__init__()
        
        modules = []
        fin = features_in
        for fout in features_hidden:
            modules.append(torchcomplex.nn.Conv2d(fin,fout, 3, padding=1))
            modules.append(torchcomplex.nn.AdaptiveCmodReLU(fout))
            fin = fout
        modules.append(torchcomplex.nn.Conv2d(fin, features_out, 3, padding=1))
        self.net = torch.nn.Sequential(*modules)
        
        with torch.no_grad():
            self.net[-1].bias.zero_()
    
    def forward(self,x):
        return self.net(x)

# UNet

In [15]:
class UNet(nn.Module):
    def __init__(self, in_c, out_c, dropout, features, maxpool=True, upsample=False, batchnorm=False, skip=True):
        super().__init__()
        self.encoder = Encoder(in_c, dropout, features, maxpool=maxpool, batchnorm=batchnorm, skip=skip)

        self.decoder1 = Decoder(dropout, features, upsample=upsample, batchnorm=batchnorm, skip=skip)
        self.decoder2 = Decoder(dropout, features ,upsample=upsample, batchnorm=batchnorm, skip=skip)
        self.decoder3 = Decoder(dropout, features ,upsample=upsample, batchnorm=batchnorm, skip=skip)
        self.decoder4 = Decoder(dropout, features ,upsample=upsample, batchnorm=batchnorm, skip=skip)
        self.decoder5 = Decoder(dropout, features ,upsample=upsample, batchnorm=batchnorm, skip=skip)
        self.decoder6 = Decoder(dropout, features ,upsample=upsample, batchnorm=batchnorm, skip=skip)
        self.decoder7 = Decoder(dropout, features ,upsample=upsample, batchnorm=batchnorm, skip=skip)
        self.decoder8 = Decoder(dropout, features ,upsample=upsample, batchnorm=batchnorm, skip=skip)

        self.head1 = Head(features[0], out_c)
        self.head2 = Head(features[0], out_c)
        self.head3 = Head(features[0], out_c)
        self.head4 = Head(features[0], out_c)
        self.head5 = Head(features[0], out_c)
        self.head6 = Head(features[0], out_c)
        self.head7 = Head(features[0], out_c)
        self.head8 = Head(features[0], out_c)

    def forward(self, x):
        x, encFeatures = self.encoder(x)

        decFeatures1 = self.decoder1(x, encFeatures[::-1])
        output1 = self.head1(decFeatures1)

        decFeatures2 = self.decoder2(x, encFeatures[::-1])
        output2 = self.head2(decFeatures2)

        decFeatures3 = self.decoder3(x, encFeatures[::-1])
        output3 = self.head3(decFeatures3)

        decFeatures4 = self.decoder4(x, encFeatures[::-1])
        output4 = self.head4(decFeatures4)

        decFeatures5 = self.decoder5(x, encFeatures[::-1])
        output5 = self.head5(decFeatures5)

        decFeatures6 = self.decoder6(x, encFeatures[::-1])
        output6 = self.head6(decFeatures6)

        decFeatures7 = self.decoder7(x, encFeatures[::-1])
        output7 = self.head7(decFeatures7)

        decFeatures8 = self.decoder8(x, encFeatures[::-1])
        output8 = self.head8(decFeatures8)

        return output1, output2, output3, output4, output5, output6, output7, output8

# Data Import

# Site A

In [111]:
train_file_path = 'TrainingData.mat'

x_, y_, complex_localizer_data_train, complex_input_data_train = process_training_data(train_file_path)

print("\nComplex Training Data Shapes:")
print("x_ shape:", x_.shape)
print("y_ shape:", y_.shape)
print("complex_localizer_data_train shape:", complex_localizer_data_train.shape)
print("complex_input_data_train shape:", complex_input_data_train.shape)


Complex Training Data Shapes:
x_ shape: (156, 32, 128, 96)
y_ shape: (156, 8, 128, 96)
complex_localizer_data_train shape: (156, 32, 128, 96)
complex_input_data_train shape: (156, 8, 128, 96)


In [112]:
test_file_path = 'TestDataSiteA.mat'

x_test_, y_test_, complex_localizer_data_val, complex_input_data_val, input_data_val = process_validation_data(val_file_path)

print("\nComplex Validation Data Shapes:")
print("x_test_ shape:", x_test_.shape)
print("y_test_ shape:", y_test_.shape)
print("complex_localizer_data_val shape:", complex_localizer_data_val.shape)
print("complex_input_data_val shape:", complex_input_data_val.shape)
print("input_data_val shape:", input_data_val.shape)


Complex Validation Data Shapes:
x_test_ shape: (39, 32, 128, 96)
y_test_ shape: (39, 128, 96, 16)
complex_localizer_data_val shape: (39, 32, 128, 96)
complex_input_data_val shape: (39, 64, 96, 16)
input_data_val shape: (39, 128, 96, 16)


In [115]:
# Convert to PyTorch tensor and move to device
x_test_tensor  = complex_localizer_data_val_tensor = torch.from_numpy(complex_localizer_data_val).to(device)


# Preparation of Data
x_tensor = torch.from_numpy(x_)

ys= torch.tensor(y_[:,0:8,:,:])
data = TensorDataset(x_tensor,*ys.unsqueeze(2).unbind(1))
train_dataset, val_dataset = random_split(data, [0.8, 0.2])

dataset = ConcatDataset([train_dataset, val_dataset])

###

In [116]:
# ============================================================
# Hyperparameters
# ============================================================

# Add light Gaussian noise to input images in a fraction of cases
params = dict(
    lr=1e-4,                 # lower LR for long, stable training
    gamma=0.9985,            # very slow exponential decay over 4000 epochs
    batch_size=1,
    dropout=0.005,           # slightly higher to counter long training
    num_epochs=4000,
    weight_decay=0.05,       # regularization to prevent overfitting
    features=(32, 32, 64, 128, 256),
    maxpool=True,
    batchnorm=False,
    skip=True,
    upsample=True,
    clip_grad_norm=1.0,      # safety for long runs
)

# ------------------------------------------------------------
# Loss function
# ------------------------------------------------------------
criterion = lambda gt, pred: torch.nn.functional.mse_loss(
    torch.view_as_real(gt),
    torch.view_as_real(pred),
)


In [119]:
# Step 1: Re-initialize the model and optimizer
model = UNet(in_c=32, out_c=1, dropout=params['dropout'], features=params['features'], 
             maxpool=params['maxpool'], skip=params['skip'], batchnorm=params['batchnorm'], 
             upsample=params['upsample']).to(device)

optimizer = optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])

# Step 2: Load the checkpoint
checkpoint = torch.load('SavedModel.pth')

# Step 3: Load the model state
model.load_state_dict(checkpoint['model_state_dict'])

# Step 4: Load the optimizer state
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Step 5: Load the training state (optional, if you want to resume training)
last_saved_epoch = checkpoint['epochs']
train_loss = checkpoint['train_loss']
validation_loss = checkpoint['validation_loss']

In [121]:
print(f"Last saved Epoch: {last_saved_epoch}")
print(f"Train Loss: {train_loss}")
print(f"Validation Loss: {validation_loss}")

Last saved Epoch: 499
Train Loss: 9.56017876856833e-05
Validation Loss: 0.00018803449123739623


In [122]:
with torch.no_grad():
   x_test_tensor = x_test_tensor.to(device)
   P_test = torch.stack(model(x_test_tensor)).cpu().detach()

gt = ground_truth = torch.tensor((input_data_val[:,:,:,::2] + 1j*input_data_val[:,:,:,1::2])).permute(-1,0,-3,-2)
print('ground truth dimensions:', ground_truth.shape)

pt = prediction = P_test
print('prediction dimensions:', prediction.shape)

ground truth dimensions: torch.Size([8, 39, 128, 96])
prediction dimensions: torch.Size([8, 39, 1, 128, 96])


In [None]:
plot_pred_gt_diff(pt, gt, slice_idx=2)


In [None]:
# Check total memory and allocated memory on the GPU
total_memory = torch.cuda.get_device_properties(0).total_memory
allocated_memory = torch.cuda.memory_allocated(0)
cached_memory = torch.cuda.memory_reserved(0)

print(f"Total GPU memory: {total_memory / (1024**3):.2f} GB")
print(f"Allocated GPU memory: {allocated_memory / (1024**3):.2f} GB")
print(f"Cached GPU memory: {cached_memory / (1024**3):.2f} GB")

# Site B

In [None]:
test_file_path2 = 'TestDataSiteB.mat'

x_test_2, y_test_2, complex_localizer_data_val2, complex_input_data_val2, input_data_val2 = process_validation_data(test_file_path2)

In [None]:
with torch.no_grad():
   x_test_tensor2 = torch.tensor(complex_localizer_data_val2).to(device=device,dtype=torch.complex64)
   P_test2 = torch.stack(model(x_test_tensor2)).cpu().detach()

gt2 = ground_truth2 = torch.tensor((input_data_val2[:,:,:,::2] + 1j*input_data_val2[:,:,:,1::2])).permute(3, 0, 1, 2)
print('ground truth dimensions:', ground_truth2.shape)

pt2 = prediction2 = P_test2.squeeze(2)
print('prediction dimensions:', prediction2.shape)

In [None]:
plot_pred_gt_diff(pt2, gt2, slice_idx=2)
