# Packages

In [None]:
! git clone https://github.com/soumickmj/pytorch-complex.git
! mv /content/pytorch-complex/* .
!pip install torchinfo

In [None]:
# Import necessary libraries
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler

from torch.utils.data import Dataset, TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, DataLoader
from sklearn.model_selection import KFold, train_test_split
from skimage.metrics import structural_similarity as ssim
from torchinfo import summary

# Import custom complex number support for PyTorch
import torchcomplex
from torchcomplex import nn


In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data Import


In [None]:
def process_training_data(train_file_path):
    """
    Read the training set from an HDF5 (.mat) file and build complex arrays.

    Both datasets are read completely, cast to float32, and have axis 2 moved
    to the last position. The first entry along axis 1 of the localizer is
    dropped (per the original author's note: the magnitude channel, i.e. the
    localizer with all Tx channels on). The remaining axis-1 entries are
    treated as interleaved (real, imaginary) pairs.

    Progress information (array shapes after each step) is printed.

    Parameters
    ----------
    train_file_path : str
        Path to the training HDF5 file.

    Returns
    -------
    x_ : np.ndarray
        Complex-valued localizer training data.
    y_ : np.ndarray
        Complex-valued input (target) training data.
    """

    def _as_complex(interleaved):
        # Even positions along axis 1 hold real parts, odd positions imaginary parts
        return interleaved[:, ::2, :, :] + 1j * interleaved[:, 1::2, :, :]

    # Read both datasets into memory as single-precision arrays
    with h5py.File(train_file_path, "r") as h5:
        loc = h5["lvLovalizerSave"][:, :, :, :].astype(np.float32)  # "x" data (localizer)
        inp = h5["lvSaveDataInput"][:, :, :, :].astype(np.float32)  # "y" data (input / target)

    print("\nOriginal Shapes:")
    print("Training Input Data Shape:", inp.shape)
    print("Training Localizer Data Shape:", loc.shape)

    # Axis 2 becomes the trailing axis for both arrays
    loc, inp = (np.moveaxis(arr, 2, -1) for arr in (loc, inp))

    print("\nAfter Moving Axis:")
    print("Training Input Data Shape:", inp.shape)
    print("Training Localizer Data Shape:", loc.shape)

    # Drop the leading axis-1 entry of the localizer (magnitude channel)
    loc = np.delete(loc, 0, axis=1)

    print("\nAfter Deleting Magnitude Value:")
    print("Training Localizer Data Shape:", loc.shape)

    # De-interleave and recombine as complex values
    x_ = _as_complex(loc)
    y_ = _as_complex(inp)

    print("\nComplex Training Data Shapes:")
    print("Complex Training Localizer Data Shape:", x_.shape)
    print("Complex Training Input Data Shape:", y_.shape)

    return x_, y_


# ------------------------------------------------------------
# Example usage
# ------------------------------------------------------------

# Relative path: resolved against the kernel's current working directory.
train_file_path = "TrainingData.mat"
# x_ (localizer) and y_ (targets) are module-level globals consumed by the
# "Data preparation" cell further down.
x_, y_ = process_training_data(train_file_path)

In [None]:
def process_validation_data(val_file_path):
    """
    Load validation data from an HDF5 (.mat) file and convert
    real/imaginary channel pairs into complex-valued arrays.

    Both datasets are cast to float32 on load, exactly as
    ``process_training_data`` does, so that the complex results are
    single precision (complex64) and match the training arrays instead of
    inheriting whatever precision the file happens to store.

    Progress information (array shapes after each step) is printed.

    Parameters
    ----------
    val_file_path : str
        Path to the validation HDF5 file.

    Returns
    -------
    x_test_ : np.ndarray
        Complex-valued localizer validation data.
    y_test_ : np.ndarray
        Complex-valued input (target) validation data.
    """

    # ------------------------------------------------------------
    # Load data (cast to float32 for consistency with the training loader)
    # ------------------------------------------------------------
    with h5py.File(val_file_path, "r") as val_file:
        # "x_test" data (localizer)
        localizer_data_val = val_file["lvLovalizerSave"][:, :, :].astype(np.float32)
        # "y_test" data (input / target)
        input_data_val = val_file["lvSaveDataInput"][:, :, :, :].astype(np.float32)

    print("\nOriginal Shapes:")
    print("Validation Input Data Shape:", input_data_val.shape)
    print("Validation Localizer Data Shape:", localizer_data_val.shape)

    # ------------------------------------------------------------
    # Move axes
    # ------------------------------------------------------------
    localizer_data_val = np.moveaxis(localizer_data_val, 2, -1)
    input_data_val = np.moveaxis(input_data_val, 2, -1)
    # NOTE(review): the target array gets a second axis move that
    # process_training_data does not apply - confirm that the validation
    # file really uses a different axis layout.
    input_data_val = np.moveaxis(input_data_val, 1, -1)

    print("\nAfter Moving Axis:")
    print("Validation Input Data Shape:", input_data_val.shape)
    print("Validation Localizer Data Shape:", localizer_data_val.shape)

    # ------------------------------------------------------------
    # Remove magnitude channel (localizer with all Tx channels on)
    # ------------------------------------------------------------
    localizer_data_val = np.delete(localizer_data_val, 0, axis=1)

    print("\nAfter Deleting Magnitude Value:")
    print("Validation Localizer Data Shape:", localizer_data_val.shape)

    # ------------------------------------------------------------
    # Split real / imaginary parts (interleaved along axis 1)
    # ------------------------------------------------------------
    # Localizer
    localizer_real_val = localizer_data_val[:, ::2, :, :]
    localizer_imag_val = localizer_data_val[:, 1::2, :, :]

    # Input / target
    input_real_val = input_data_val[:, ::2, :, :]
    input_imag_val = input_data_val[:, 1::2, :, :]

    # ------------------------------------------------------------
    # Combine into complex-valued arrays
    # ------------------------------------------------------------
    x_test_ = localizer_real_val + 1j * localizer_imag_val
    y_test_ = input_real_val + 1j * input_imag_val

    print("\nComplex Validation Data Shapes:")
    print("Complex Validation Localizer Data Shape:", x_test_.shape)
    print("Complex Validation Input Data Shape:", y_test_.shape)

    return x_test_, y_test_


# ------------------------------------------------------------
# Example usage
# ------------------------------------------------------------

# Validation set location, resolved against the kernel's working directory
val_file_path = "ValidationData.mat"
x_test_, y_test_ = process_validation_data(val_file_path)

In [None]:
# Convert to PyTorch tensor and move to device
# NOTE(review): x_test_tensor is not referenced again in the visible cells of
# this notebook - presumably it is meant for a later inference step.
x_test_tensor = torch.from_numpy(x_test_).to(device)


# Custom Functions

In [None]:
def size_of(x):
    """Print the storage size of tensor ``x`` in MiB (numel * element_size / 1024 / 1024)."""
    n_bytes = x.numel() * x.element_size()
    print(n_bytes / 1024 / 1024)

def count(net):
    """Return the total number of elements over all parameters of ``net``."""
    total = 0
    for param in net.parameters():
        total += param.numel()
    return total

In [None]:
###---###---###---###

''' Initial weights '''

###---###---###---###


def _init_weights(module):
        if isinstance(module, torchcomplex.nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.02) 

# Loss Function

In [None]:
###---###---###---###

""" Loss Function """

###---###---###---###

class ComplexMSELoss:
    """
    Half mean-squared-error for complex-valued data.

    Computes ``mean(0.5 * |true - prediction|**2)`` and returns a real scalar
    tensor. Inputs may be NumPy arrays or tensors; tensor inputs keep their
    autograd history, so the result can be back-propagated.
    """

    def __call__(self, true, prediction):
        # as_tensor (unlike torch.tensor) does not detach tensors from the graph
        prediction_tensor = torch.as_tensor(prediction, dtype=torch.complex64)
        # Put the reference on the same device as the prediction
        true_tensor = torch.as_tensor(
            true, dtype=torch.complex64, device=prediction_tensor.device
        )

        # Squared magnitude of the complex residual; squaring the complex
        # number itself would give a complex (non-differentiable-as-loss) value
        diff = true_tensor - prediction_tensor
        return (0.5 * (diff.real ** 2 + diff.imag ** 2)).mean()
    

    
class PerpLoss(nn.Module):
    """
    Perpendicular loss plus a weighted L1 term for complex-valued tensors.

    Parameters
    ----------
    eps : float
        Added to denominators to avoid division by zero.
    l1factor : float
        Weight of the element-wise L1 term.
    mask : bool
        If true, average only over elements where ``|target| > 1e-3``;
        otherwise average over all elements.
    """

    def __init__(self, eps=1e-8, l1factor=1.0, mask=False):
        super().__init__()
        self.eps, self.l1factor, self.mask = eps, l1factor, mask

    def forward(self, target, prediction):
        t_re, t_im = target.real, target.imag
        p_re, p_im = prediction.real, prediction.imag

        # |Im(conj(target) * prediction)| divided by |prediction|
        perp = torch.abs(t_re * p_im - t_im * p_re) / (torch.abs(prediction) + self.eps)

        # True where the angle between target and prediction is below pi/2
        # (detached: used only to select a branch)
        below_90_deg = ((target / prediction).real > 0).detach()
        ploss = torch.where(below_90_deg, perp, 2 * torch.abs(target) - perp)

        elementwise = ploss + self.l1factor * torch.nn.functional.l1_loss(
            prediction, target, reduction='none'
        )

        if not self.mask:
            # Plain mean over every element
            return elementwise.mean()

        # Masked mean over elements whose target magnitude exceeds the threshold
        keep = target.abs() > 1e-3
        return (elementwise * keep).sum() / (keep.sum() + self.eps)



# 2D Convolutional Block

In [None]:
###---###---###---###

""" 2D Convolutional Block """

###---###---###---###


class C2D_Block(nn.Module):
    """
    Two-layer complex convolution block.

    Data flow: conv3x3 -> [batch norm] -> activation -> conv3x3 -> [batch norm]
    -> [+ 1x1-projected input] -> activation.

    Parameters
    ----------
    in_c : int
        Number of input channels.
    n_filters : int
        Number of output channels of both convolutions.
    batchnorm : bool
        Insert a batch-norm layer after each convolution.
    skip : bool
        Add a residual path that projects the input with a 1x1 convolution
        (bias initialised to zero).
    """

    def __init__(self, in_c, n_filters, batchnorm, skip):
        super().__init__()
        cnn = torchcomplex.nn

        # Layers are created in a fixed order so seeded initialisation is stable
        self.conv1 = cnn.Conv2d(in_c, n_filters, kernel_size=3, padding=1)
        self.bn1 = cnn.BatchNorm2d(n_filters) if batchnorm else None
        self.bn2 = cnn.BatchNorm2d(n_filters) if batchnorm else None

        # AdaptiveCmodReLU is used on purpose; the original author advises
        # against CReLU here (see the torchcomplex docs)
        self.relu1 = cnn.AdaptiveCmodReLU(n_filters)
        self.conv2 = cnn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1)
        self.relu2 = cnn.AdaptiveCmodReLU(n_filters)

        self.skip = cnn.Conv2d(in_c, n_filters, kernel_size=1) if skip else None
        if self.skip is not None:
            with torch.no_grad():
                self.skip.bias.zero_()

    def forward(self, xin):
        h = self.conv1(xin)
        if self.bn1 is not None:
            h = self.bn1(h)
        h = self.conv2(self.relu1(h))
        if self.bn2 is not None:
            h = self.bn2(h)
        if self.skip is not None:
            # Residual connection is added before the final activation
            h = h + self.skip(xin)
        return self.relu2(h)

# Encoder

In [None]:
###---###---###---###

""" Encoder """

###---###---###---###


class Encoder(nn.Module):
    """
    Contracting path: a stack of ``C2D_Block`` + downsampling stages followed
    by a bottleneck convolution.

    Parameters
    ----------
    in_c : int
        Number of input channels.
    dropout : float
        Base dropout rate; 0 (or any falsy value) disables dropout.
    features : sequence of int
        Output channels of each encoder level (must be non-empty).
    maxpool : bool
        Downsample with 2x2 max pooling; otherwise with a stride-2
        convolution that outputs ``features[-1]`` channels.
    batchnorm, skip : bool
        Forwarded to every ``C2D_Block``.

    ``forward`` returns ``(bottleneck_output, skip_connections)`` where
    ``skip_connections`` lists each level's block output, shallowest first.
    """

    def __init__(self, in_c, dropout, features, maxpool, batchnorm, skip):
        super().__init__()
        self.encBlocks = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for feature in features:
            self.encBlocks.append(C2D_Block(in_c, feature, batchnorm=batchnorm, skip=skip))
            if maxpool:
                # Pooling keeps the channel count of the block
                self.downsamples.append(torchcomplex.nn.MaxPool2d(2))
                in_c = feature
            else:
                # Strided convolution maps every level to features[-1] channels
                down = torch.nn.Sequential(
                    torchcomplex.nn.Conv2d(feature, features[-1], kernel_size=3, stride=2, padding=1),
                    torchcomplex.nn.AdaptiveCmodReLU(features[-1]),
                )
                self.downsamples.append(down)
                in_c = features[-1]

        # forward() reads self.dropout1; it was previously stored only as
        # self.dropouts, which raised AttributeError for shallow encoders.
        self.dropout1 = torchcomplex.nn.Dropout2d(dropout) if dropout else torch.nn.Identity()
        self.dropouts = self.dropout1  # legacy alias kept for backward compatibility
        self.dropout2 = torchcomplex.nn.Dropout2d(dropout * 2) if dropout else torch.nn.Identity()
        self.dropout3 = torchcomplex.nn.Dropout2d(dropout * 3) if dropout else torch.nn.Identity()
        # self.bottleneck = C2D_Block(features[-1], features[-1]*2)
        self.bottleneck = torch.nn.Sequential(
            torchcomplex.nn.Conv2d(features[-1], features[-1], kernel_size=3, padding=1),
            torchcomplex.nn.AdaptiveCmodReLU(features[-1]),
        )

    def forward(self, x):
        skip_connections = []
        # downsampling
        for depth, (block, down) in enumerate(zip(self.encBlocks, self.downsamples)):
            x = block(x)
            skip_connections.append(x)
            x = down(x)

        # Dropout is applied once, after the last level; `depth` is the index
        # of that last level.
        # NOTE(review): Decoder applies its dropout inside the per-level loop.
        # Confirm whether this one was meant to be inside the loop as well;
        # the placement is left unchanged to preserve trained behaviour.
        if depth < 2:
            x = self.dropout1(x)
        else:
            x = self.dropout2(x)

        x = self.bottleneck(x)
        x = self.dropout3(x)

        return x, skip_connections

# Decoder

In [None]:
###---###---###---###

""" Decoder """

###---###---###---###


class Decoder(nn.Module):
    """
    Expanding path: for every level, upsample, concatenate the matching
    encoder feature map along the channel axis, and refine with a
    ``C2D_Block``.

    Parameters
    ----------
    dropout : float
        Base dropout rate; 0 (or any falsy value) disables dropout.
    features : sequence of int
        Encoder channel widths, shallowest first (same value passed to the
        encoder).
    upsample : bool
        Use bilinear upsampling followed by a 3x3 convolution; otherwise a
        stride-2 transposed convolution.
    batchnorm, skip : bool
        Forwarded to every ``C2D_Block``.
    """

    def __init__(self, dropout, features, upsample, batchnorm, skip):
        super().__init__()

        # Output widths walk the encoder widths from deepest to shallowest;
        # the first input width is the bottleneck width, features[-1].
        features_out = list(reversed(features))
        features_in = [features[-1], *features[:0:-1]]

        self.upConvs = nn.ModuleList()
        self.decBlocks = nn.ModuleList()
        for fin, fout in zip(features_in, features_out):
            if upsample:
                self.upConvs.append(
                    torch.nn.Sequential(
                        torchcomplex.nn.Upsample(mode="bilinear", scale_factor=2, size=None),
                        torchcomplex.nn.Conv2d(fin, fout, kernel_size=3, padding=1),
                        torchcomplex.nn.AdaptiveCmodReLU(fout),
                    )
                )
            else:
                self.upConvs.append(
                    torch.nn.Sequential(
                        torchcomplex.nn.ConvTranspose2d(fin, fout, 2, stride=2),
                        torchcomplex.nn.AdaptiveCmodReLU(fout),
                    )
                )
            # 2 * fout input channels: upsampled features + encoder skip features
            self.decBlocks.append(C2D_Block(2 * fout, fout, batchnorm=batchnorm, skip=skip))
        # with torch.no_grad():
        #     self.upConvs.apply(_init_weights)

        self.dropout1 = torchcomplex.nn.Dropout2d(dropout) if dropout else torch.nn.Identity()
        self.dropout2 = torchcomplex.nn.Dropout2d(dropout * 2) if dropout else torch.nn.Identity()

    def forward(self, x, skipped_feautures):
        """
        Decode ``x`` using ``skipped_feautures`` (encoder outputs, deepest
        first). ``strict=True`` raises if the number of skip tensors does not
        match the number of decoder levels.
        """
        for depth, (up, block, skipped) in enumerate(zip(self.upConvs, self.decBlocks, skipped_feautures, strict=True)):
            x = up(x)
            x = torch.cat([x, skipped], dim=1)
            x = block(x)
            # The two deepest levels get the doubled dropout rate
            if depth < 2:
                x = self.dropout2(x)
            else:
                x = self.dropout1(x)
        return x

    def crop(self, encFeaturs, x):
        """
        Centre-crop ``encFeaturs`` to the spatial size (last two axes) of ``x``.

        Implemented with plain slicing; the previous version referenced
        ``CenterCrop``, which is not imported anywhere in this notebook.

        Raises
        ------
        ValueError
            If ``encFeaturs`` is spatially smaller than ``x``.
        """
        (_, _, H, W) = x.shape
        enc_h, enc_w = encFeaturs.shape[-2:]
        if enc_h < H or enc_w < W:
            raise ValueError(
                f"cannot crop features of size {(enc_h, enc_w)} to larger size {(H, W)}"
            )
        top = (enc_h - H) // 2
        left = (enc_w - W) // 2
        encFeaturs = encFeaturs[..., top:top + H, left:left + W]

        return encFeaturs

In [None]:
class Head(nn.Module):
    """
    Output head: a chain of 3x3 complex convolutions, each followed by an
    AdaptiveCmodReLU activation, and a final 3x3 convolution (no activation)
    whose bias is initialised to zero.

    Parameters
    ----------
    features_in : int
        Number of input channels.
    features_out : int
        Number of output channels.
    features_hidden : sequence of int
        Channel widths of the hidden convolutions.
    """

    def __init__(self, features_in, features_out=1, features_hidden=(64,32,16)):
        super().__init__()

        # Consecutive width pairs define the hidden convolutions
        widths = [features_in, *features_hidden]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                torchcomplex.nn.Conv2d(c_in, c_out, 3, padding=1),
                torchcomplex.nn.AdaptiveCmodReLU(c_out),
            ]
        layers.append(torchcomplex.nn.Conv2d(widths[-1], features_out, 3, padding=1))
        self.net = torch.nn.Sequential(*layers)

        # Start the output layer without an offset
        with torch.no_grad():
            self.net[-1].bias.zero_()

    def forward(self, x):
        return self.net(x)

# UNet

In [None]:
###---###---###---###

""" UNet """

###---###---###---###


class UNet(nn.Module):
    """
    Complex-valued U-Net with one shared encoder and eight independent
    decoder/head branches.

    Sub-modules are registered as ``encoder``, ``decoder1`` .. ``decoder8``
    and ``head1`` .. ``head8`` (in that order). ``forward`` returns a tuple
    with the eight head outputs, branch 1 first.

    Parameters
    ----------
    in_c : int
        Number of input channels.
    out_c : int
        Number of output channels of every head.
    dropout : float
        Base dropout rate passed to encoder and decoders.
    features : sequence of int
        Channel widths of the encoder levels.
    maxpool, upsample, batchnorm, skip : bool
        Architecture switches forwarded to ``Encoder`` / ``Decoder``.
    """

    _N_BRANCHES = 8  # number of decoder/head pairs

    def __init__(self, in_c, out_c, dropout, features, maxpool=True, upsample=False, batchnorm=False, skip=True):
        super().__init__()
        self.encoder = Encoder(in_c, dropout, features, maxpool=maxpool, batchnorm=batchnorm, skip=skip)

        branch_ids = range(1, self._N_BRANCHES + 1)

        # All decoders are created before any head, so the construction
        # (and random-initialisation) order matches explicit attributes.
        for i in branch_ids:
            setattr(
                self,
                f"decoder{i}",
                Decoder(dropout, features, upsample=upsample, batchnorm=batchnorm, skip=skip),
            )
        for i in branch_ids:
            setattr(self, f"head{i}", Head(features[0], out_c))

    def forward(self, x):
        x, encFeatures = self.encoder(x)
        # Decoders consume the skip connections deepest-first
        deepest_first = encFeatures[::-1]

        outputs = []
        for i in range(1, self._N_BRANCHES + 1):
            decoder = getattr(self, f"decoder{i}")
            head = getattr(self, f"head{i}")
            outputs.append(head(decoder(x, deepest_first)))
        return tuple(outputs)

# Data preparation

In [None]:
# ============================================================
# Data preparation
# ============================================================

# Convert input data to PyTorch tensor (shares memory with the NumPy array)
x_tensor = torch.from_numpy(x_)

# Select first 8 channels from target data
y_tensor = torch.tensor(y_[:, 0:8, :, :])

# Create TensorDataset:
#   x_tensor        : input
#   y_tensor split  : one tensor per Tx channel, each keeping a
#                     singleton channel axis (unsqueeze(2) then unbind(1))
data = TensorDataset(
    x_tensor,
    *y_tensor.unsqueeze(2).unbind(1)
)

# Split into training and validation sets (80 / 20).
# This cell runs before torch.manual_seed(42) in the training cell, so the
# split gets its own seeded generator; otherwise the sample order (and with
# it the later index-based train/val split) would differ on every run.
train_dataset, val_dataset = random_split(
    data,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(42),
)

# Concatenate back if a unified dataset is required
# (the training cell re-splits this unified dataset by index)
dataset = ConcatDataset([train_dataset, val_dataset])


# Hyperparameters

In [None]:
# ============================================================
# Hyperparameters
# ============================================================

# NOTE(review): an earlier comment here mentioned adding light Gaussian noise
# to the inputs; no such augmentation step is visible in this notebook.
params = {
    "lr": 1e-4,                         # lower LR for long, stable training
    "gamma": 0.9985,                    # very slow exponential decay over 4000 epochs
    "batch_size": 1,
    "dropout": 0.005,                   # slightly higher to counter long training
    "num_epochs": 4000,
    "weight_decay": 0.05,               # regularization to prevent overfitting
    "features": (32, 32, 64, 128, 256),
    "maxpool": True,
    "batchnorm": False,
    "skip": True,
    "upsample": True,
    "clip_grad_norm": 1.0,              # safety for long runs
}


# ------------------------------------------------------------
# Loss function
# ------------------------------------------------------------
def criterion(gt, pred):
    """
    Mean squared error between two complex tensors, computed on their
    (real, imag) float views.

    The loss is symmetric in its arguments; the training loop calls it
    positionally as ``criterion(output, target)``.
    """
    return torch.nn.functional.mse_loss(
        torch.view_as_real(gt),
        torch.view_as_real(pred),
    )


# Train Model

In [None]:
# ============================================================
# Reproducibility
# ============================================================
# Seeds torch's global RNG only; everything below that draws random numbers
# (weight initialisation, index split, samplers) depends on this call and on
# the order of the statements that follow.
torch.manual_seed(42)

# ============================================================
# Model
# ============================================================
# Architecture switches come from `params`; channel counts are fixed here.
model = UNet(
    in_c=32,
    out_c=1,
    dropout=params["dropout"],
    features=params["features"],
    maxpool=params["maxpool"],
    skip=params["skip"],
    batchnorm=params["batchnorm"],
    upsample=params["upsample"],
).to(device)

# Counts every parameter element (no requires_grad filter is applied)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.2f} Mio.")
print("number_trainable_parameters =", n_params)
print(model)

# If criterion is a callable/loss instance, this is just informational
# (the criterion from the Hyperparameters cell is a plain function, so this
# prints "function")
try:
    print("criterion =", criterion.__class__.__name__)
except Exception:
    pass

# ============================================================
# Train/val split (via samplers)
# ============================================================
# Per-epoch loss histories and wall-clock start of the whole run
train_loss = []
valid_loss = []
start_time = time.time()

# Split the sample indices 90 / 10; both loaders wrap the same `dataset`
# and differ only in the index subset their sampler draws from.
train_idx, val_idx = torch.utils.data.random_split(
    torch.arange(len(dataset)),
    (0.9, 0.1),
)

train_loader = DataLoader(
    dataset,
    batch_size=params["batch_size"],
    sampler=SubsetRandomSampler(train_idx),
    num_workers=10,
)

val_loader = DataLoader(
    dataset,
    batch_size=params["batch_size"],
    sampler=SubsetRandomSampler(val_idx),
    num_workers=10,
)

# ============================================================
# Optimizer + scheduler
# ============================================================
optimizer = optim.AdamW(
    model.parameters(),
    lr=params["lr"],
    weight_decay=params["weight_decay"],
)

# Multiplies the learning rate by `gamma` on every scheduler.step() call
scheduler = lr_scheduler.ExponentialLR(
    optimizer,
    gamma=params["gamma"],
)

# ============================================================
# Training loop
# ============================================================
# One pass over train_loader followed by one pass over val_loader per epoch;
# the mean batch loss of each pass is appended to train_loss / valid_loss.
for epoch in range(params["num_epochs"]):
    epoch_start = time.time()

    # ------------------------
    # Train
    # ------------------------
    model.train()
    train_epoch_loss = 0.0

    # Each batch is (input, target_1, ..., target_k): `y` collects the targets
    for step, (x, *y) in enumerate(train_loader):
        x = x.to(device)
        y = [yi.to(device) for yi in y]

        optimizer.zero_grad()

        outputs = model(x)  # expected: iterable/list of heads
        # criterion is called positionally as (output, target); the batch
        # loss is the unweighted mean over all heads
        head_losses = [criterion(out_i, y_i) for out_i, y_i in zip(outputs, y)]
        loss = sum(head_losses) / len(head_losses)

        loss.backward()
        # Clip the global gradient norm before the optimizer update
        torch.nn.utils.clip_grad_norm_(model.parameters(), params["clip_grad_norm"])
        optimizer.step()

        train_epoch_loss += loss.item()

    # len(train_loader) is the number of batches in the epoch
    avg_train_loss = train_epoch_loss / len(train_loader)
    train_loss.append(avg_train_loss)

    # LR step (printed explicitly); the scheduler advances once per epoch
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step()
    lr_after = optimizer.param_groups[0]["lr"]

    print("\n" + "-" * 60)
    print(f"Epoch {epoch + 1:4d}/{params['num_epochs']} | train_loss: {avg_train_loss:.6f}")
    print(f"lr: {lr_before:.6e} -> {lr_after:.6e}")

    # ------------------------
    # Validation
    # ------------------------
    model.eval()
    val_epoch_loss = 0.0

    # No gradients are tracked and no parameters are updated in this pass
    with torch.no_grad():
        for x, *y in val_loader:
            x = x.to(device)
            y = [yi.to(device) for yi in y]

            outputs = model(x)
            head_losses = [criterion(out_i, y_i) for out_i, y_i in zip(outputs, y)]
            val_loss = sum(head_losses) / len(head_losses)

            val_epoch_loss += val_loss.item()

    avg_val_loss = val_epoch_loss / len(val_loader)
    valid_loss.append(avg_val_loss)

    # Covers both the training and the validation pass of this epoch
    epoch_time = time.time() - epoch_start
    print(f"val_loss: {avg_val_loss:.6f} | epoch_time: {epoch_time:.2f}s")

# ============================================================
# Done
# ============================================================
total_time = time.time() - start_time
print(f"\n[INFO] total time taken to train the model: {total_time:.2f}s")


# Save Model

In [None]:
# Bundle the final training state into a single checkpoint file.
# "epochs" holds the zero-based index of the last epoch that ran, and only the
# final train / validation loss values are stored (not the full histories).
checkpoint = {
    "epochs": epoch,
    "parameters": params,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "train_loss": train_loss[-1],
    "validation_loss": valid_loss[-1],
}
torch.save(checkpoint, 'SaveModel.pth')
