In [1]:
# This is the method that uses the MATLAB Engine API for Python
import matlab.engine
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import  models, datasets, transforms
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import timm
import pickle
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler, MinMaxScaler
import numpy as np
import scipy.io as scio
from scipy.io import savemat
import h5py
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import gc

In [2]:
device = torch.device('mps') if torch.backends.mps.is_available() else 'cpu'

In [3]:
eng = matlab.engine.start_matlab()

In [4]:
algorithm_input = scio.loadmat('algorithm_input_single.mat')
algorithm_input_mat = algorithm_input['algorithm_input']
algorithm_input_torch = torch.from_numpy(algorithm_input_mat.astype(np.complex64))


In [5]:
print(f"The datatype of the algorithm_input is {algorithm_input_torch.dtype}")

The datatype of the algorithm_input is torch.complex64


In [6]:
print(f"The shape of the algorithm_input is {algorithm_input_torch.shape}")

The shape of the algorithm_input is torch.Size([100000, 70])


In [7]:
algorithm_output = scio.loadmat('algorithm_output_single.mat')
algorithm_output_mat = algorithm_output['algorithm_output']
algorithm_output_torch = torch.from_numpy(algorithm_output_mat.astype(np.complex64))

In [8]:
print(f"The data type of the algorithm_output is {algorithm_output_torch.dtype}")

The data type of the algorithm_output is torch.complex64


In [9]:
print(f"The shape of the algorithm_output is {algorithm_output_torch.shape}")

The shape of the algorithm_output is torch.Size([100000, 70])


In [10]:
main_channels = scio.loadmat('main_channels_single.mat')
main_channels_mat = main_channels['main_channels']
main_channels_torch = torch.from_numpy(main_channels_mat.astype(np.complex64))

In [11]:
print(f"The data type of the main_channels is {main_channels_torch.dtype}")

The data type of the main_channels is torch.complex64


In [12]:
print(f"The shape of the main_channels is {main_channels_torch.shape}")

The shape of the main_channels is torch.Size([100000, 10, 70])


main_channels_mat = torch.load('main_channels_tensor.pt', weights_only=True)

In [13]:
symbols_store = scio.loadmat('symbols_store_single.mat')
symbols_store_mat = symbols_store['symbols_store']
symbols_store_torch = torch.from_numpy(symbols_store_mat.astype(np.int32))

In [14]:
print(f"The data type of the symbols is {symbols_store_torch.dtype}")

The data type of the symbols is torch.int32


In [15]:
print(f"The shape of the symbols is {symbols_store_torch.shape}")

The shape of the symbols is torch.Size([100000, 10])


In [174]:
symbols_store_torch[3,:]

tensor([ 2,  7,  6,  8,  0, 15, 12,  7,  1, 14], dtype=torch.int32)

In [17]:
class CustomDataset(Dataset):
    def __init__(self, algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat):
        # convert into PyTorch tensors and remember them
        self.algorithm_input_torch = algorithm_input_torch
        self.algorithm_output_torch = algorithm_output_torch
        self.main_channels_torch = main_channels_torch
        self.symbols_store_torch = symbols_store_torch
        
    def __len__(self):
        # this should return the size of the dataset
        return len(self.algorithm_input_torch)
    
    def __getitem__(self, idx):
        # this should return one sample from the dataset
        algorithm_input_torch = self.algorithm_input_torch[idx,:]
        algorithm_output_torch = self.algorithm_output_torch[idx,:]
        main_channels_torch = self.main_channels_torch[idx,:,:]
        symbols_store_torch = self.symbols_store_torch[idx,:]
        return algorithm_input_torch, algorithm_output_torch, main_channels_torch, symbols_store_torch

In [18]:
dataset = CustomDataset(algorithm_input_torch, algorithm_output_torch, main_channels_torch, symbols_store_torch)

In [19]:
# First, split the dataset into train and remaining (val + test)
train_set, remaining_set = train_test_split(dataset, test_size=0.2, random_state=42)

# Now, split the remaining set into validation and test sets
val_set, test_set = train_test_split(remaining_set, test_size=0.5, random_state=42)

In [20]:
# Create DataLoaders
batch_size = 64
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size= batch_size)

In [21]:
len(test_loader)

157

In [22]:
batch_alg_in_torch, batch_alg_out_torch, batch_main_chan_torch, batch_sym_torch = next(iter(train_loader))
print(f'shape of batch feature is {batch_alg_in_torch.shape}')
print(f'shape of batch feature is {batch_alg_out_torch.shape}')
print(f'shape of batch feature is {batch_main_chan_torch.shape}')
print(f'shape of batch feature is {batch_sym_torch.shape}')

shape of batch feature is torch.Size([64, 70])
shape of batch feature is torch.Size([64, 70])
shape of batch feature is torch.Size([64, 10, 70])
shape of batch feature is torch.Size([64, 10])


In [23]:
batch_alg_in_torch, batch_alg_out_torch, batch_main_chan_torch, batch_sym_torch = next(iter(train_loader))
print(f'data type of batch feature is {batch_alg_in_torch.dtype}')
print(f'data type of batch feature is {batch_alg_out_torch.dtype}')
print(f'data type of batch feature is {batch_main_chan_torch.dtype}')
print(f'data type of batch feature is {batch_sym_torch.dtype}')

data type of batch feature is torch.complex64
data type of batch feature is torch.complex64
data type of batch feature is torch.complex64
data type of batch feature is torch.int32


In [24]:
print(torch.stack([torch.real(batch_alg_in_torch).float(), torch.imag(batch_alg_in_torch).float()], dim=1).shape)
print(torch.stack([torch.real(batch_alg_out_torch).float(), torch.imag(batch_alg_out_torch).float()], dim=1).shape)
print(torch.stack([torch.real(batch_main_chan_torch).float(), torch.imag(batch_main_chan_torch).float()], dim=1).shape)

torch.Size([64, 2, 70])
torch.Size([64, 2, 70])
torch.Size([64, 2, 10, 70])


In [25]:
def complex_to_interleaved_real(complex_signal):
    real_part = complex_signal.real.to(dtype=torch.float32) 
    imag_part = complex_signal.imag.to(dtype=torch.float32) 
    interleaved_signal = torch.stack((real_part, imag_part), dim=2).reshape(complex_signal.shape[0], -1)
    return interleaved_signal

In [26]:
a = torch.tensor([1.0, 2.0]).unsqueeze(0)
b = torch.tensor([3.0, 4.0]).unsqueeze(0)

test_tensor = a + 1j * b
test_tensor = torch.cat((test_tensor, test_tensor, test_tensor), dim=0)
test_tensor.shape

torch.Size([3, 2])

In [27]:
interleaved_tensor = complex_to_interleaved_real(test_tensor)
interleaved_tensor.shape

torch.Size([3, 4])

In [28]:
def interleaved_real_to_complex(interleaved_signal):
    signal_length = interleaved_signal.shape[1] // 2
    real_part = interleaved_signal[:, 0::2]  # Extract even indices
    imag_part = interleaved_signal[:, 1::2]  # Extract odd indices
    complex_signal = torch.complex(real_part, imag_part)
    return complex_signal

In [29]:
back_to_complex = interleaved_real_to_complex(interleaved_tensor)
back_to_complex

tensor([[1.+3.j, 2.+4.j],
        [1.+3.j, 2.+4.j],
        [1.+3.j, 2.+4.j]])

In [30]:
def compute_papr_complex(signal):
    # Compute |x[n]|^2: Magnitude squared of the complex signal
    power_signal = torch.abs(signal)**2
    
    # Peak power
    peak_power_signal= torch.max(power_signal, dim=1).values

    # Average power
    avg_power_signal = torch.mean(power_signal, dim=1)

    # PAPR
    papr_signal = peak_power_signal / avg_power_signal
    
    return papr_signal

In [31]:
def papr_loss(signal_going_out, signal_coming_in):
    # Compute PAPR before and after
    papr_going_out = compute_papr_complex(signal_going_out)  # Transformed signal
    papr_coming_in = compute_papr_complex(signal_coming_in)  # Original signal

    # Penalize only if PAPR after is greater than PAPR before
    papr_diff = torch.relu(papr_going_out - papr_coming_in)
    
    return torch.mean(papr_diff), torch.mean(papr_going_out), torch.mean(papr_coming_in)

In [32]:
def prepare_for_matlab(batch_alg_in_mat, batch_alg_out_mat, batch_nn_out, batch_main_chan_mat, batch_sym_mat):
    
    batch_alg_in_mat_real = matlab.double(batch_alg_in_mat.real.tolist())
    batch_alg_in_mat_imag = matlab.double(batch_alg_in_mat.imag.tolist())

    batch_alg_out_mat_real = matlab.double(batch_alg_out_mat.real.tolist())
    batch_alg_out_mat_imag = matlab.double(batch_alg_out_mat.imag.tolist())

    batch_nn_out_real = matlab.double(batch_nn_out.real.tolist())
    batch_nn_out_imag = matlab.double(batch_nn_out.imag.tolist())

    batch_main_chan_mat_real = matlab.double(batch_main_chan_mat.real.tolist())
    batch_main_chan_mat_imag = matlab.double(batch_main_chan_mat.imag.tolist())

    #batch_sym_mat = int(batch_sym_mat)
    batch_sym_mat = matlab.uint32(batch_sym_mat.tolist())

    return batch_alg_in_mat_real, batch_alg_in_mat_imag, batch_alg_out_mat_real, batch_alg_out_mat_imag, batch_nn_out_real, batch_nn_out_imag, batch_main_chan_mat_real, batch_main_chan_mat_imag, batch_sym_mat
    

In [169]:
def calculate_ser(alg_in_real, alg_in_imag, alg_out_real, alg_out_imag, nn_out_real, nn_out_imag, main_chan_real, main_chan_imag, sym):
    snr_db = 40
    
    alg_in = torch.complex(alg_in_real, alg_in_imag).float()
    alg_out = torch.complex(alg_out_real, alg_out_imag).float()
    nn_out = torch.complex(nn_out_real, nn_out_imag).float()
    main_chan = torch.complex(main_chan_real, main_chan_imag).float()
    
    batch = main_chan.shape[0]
    Mr = 10
    Mt = 70
    nf_initial = Mr / (Mt - Mr)
    nf = torch.tensor(nf_initial, dtype=torch.float32)
    M = 16
    theta = 0.9
    qam_power = torch.tensor(10, dtype=torch.float32)
    
    SNR = 10 ** (snr_db / 10)
    sigma2n = 1 / SNR
    
    wn = torch.sqrt(torch.tensor(sigma2n / 2)) * (
        torch.randn(batch, Mr) + 1j * torch.randn(batch, Mr)
    )
    
    algorithm_input_reshaped = alg_in.view(batch, 1, -1)
    algorithm_output_reshaped = alg_out.view(batch, 1, -1)
    nn_output_reshaped = nn_out.view(batch, 1, -1)
    
    output_1 = torch.sum(main_chan * algorithm_input_reshaped, dim=2) + wn
    output_1_sym = qamdemod(torch.sqrt(qam_power * nf) * output_1, M)
    
    output_2 = torch.sum(main_chan * algorithm_output_reshaped, dim=2) + wn
    output_2_sym = qamdemod(torch.sqrt((qam_power * nf) / theta) * output_2, M)
    
    output_3 = torch.sum(main_chan * nn_output_reshaped, dim=2) + wn
    output_3_sym = qamdemod(torch.sqrt((qam_power * nf) / theta) * output_3, M)
    
    ser_1 = torch.sum(output_1_sym != sym, dim=1) / Mr
    ser_2 = torch.sum(output_2_sym != sym, dim=1) / Mr
    ser_3 = torch.sum(output_3_sym != sym, dim=1) / Mr
    
    ser = torch.cat([ser_1.unsqueeze(1), ser_2.unsqueeze(1), ser_3.unsqueeze(1)], dim=1)
    return ser


In [142]:
def ser_loss_matlab(batch_alg_in_mat, batch_alg_out_mat, batch_nn_out, batch_main_chan_mat, batch_sym_mat):

    batch_alg_in_mat_real, batch_alg_in_mat_imag, batch_alg_out_mat_real, batch_alg_out_mat_imag, batch_nn_out_real, batch_nn_out_imag, batch_main_chan_mat_real, batch_main_chan_mat_imag , batch_sym_mat = prepare_for_matlab(batch_alg_in_mat, batch_alg_out_mat, batch_nn_out, batch_main_chan_mat, batch_sym_mat)
    ser_mat = eng.calculate_ser(batch_alg_in_mat_real, batch_alg_in_mat_imag, batch_alg_out_mat_real, batch_alg_out_mat_imag, batch_nn_out_real, batch_nn_out_imag, batch_main_chan_mat_real, batch_main_chan_mat_imag , batch_sym_mat)
    ser_torch = torch.tensor(ser_mat, dtype=torch.float32)
    ser_diff = torch.relu(ser_torch[:,2] - ser_torch[:,1])
    return torch.mean(ser_diff), torch.mean(ser_torch[:,2]), torch.mean(ser_torch[:,1])

In [165]:
def ser_loss(batch_alg_in, batch_alg_out, batch_nn_out, batch_main_chan, batch_sym):

    ser = calculate_ser(torch.real(batch_alg_in), torch.imag(batch_alg_in), torch.real(batch_alg_out), torch.imag(batch_alg_out), torch.real(batch_nn_out), torch.imag(batch_nn_out), torch.real(batch_main_chan), torch.imag(batch_main_chan), batch_sym)
    #ser_torch = torch.tensor(ser, dtype=torch.float32)
    ser_diff = torch.relu(ser[:,2] - ser[:,1])
    return torch.mean(ser_diff), torch.mean(ser[:,2]), torch.mean(ser[:,1])

In [128]:
def qammod(symbols, M=16):
    assert M == 16, "This function is specifically for 16-QAM."

    # Define Gray-coded QAM-16 constellation mapping
    real_values = torch.tensor([-3, -3, -3, -3, -1, -1, -1, -1, 3, 3, 3, 3, 1, 1, 1, 1], dtype=torch.float32)
    imag_values = torch.tensor([3, 1, -3, -1, 3, 1, -3, -1, 3, 1, -3, -1, 3, 1, -3, -1], dtype=torch.float32)

    # Flatten input before indexing
    orig_shape = symbols.shape
    symbols = symbols.view(-1)  # Flatten

    # Map input symbols to corresponding constellation points
    modulated_signal = torch.complex(real_values[symbols], imag_values[symbols])

    return modulated_signal.view(orig_shape)  # Reshape back to original

In [130]:
def qamdemod(signal, M=16):
    assert M == 16, "This function is specifically for 16-QAM."

    # Define Gray-coded QAM-16 constellation mapping
    real_values = torch.tensor([-3, -3, -3, -3, -1, -1, -1, -1, 3, 3, 3, 3, 1, 1, 1, 1], dtype=torch.float32)
    imag_values = torch.tensor([3, 1, -3, -1, 3, 1, -3, -1, 3, 1, -3, -1, 3, 1, -3, -1], dtype=torch.float32)

    constellation = torch.complex(real_values, imag_values)

    # Flatten input before processing
    orig_shape = signal.shape
    signal = signal.view(-1)  # Flatten

    # Find the closest constellation point for each received symbol
    distances = torch.abs(signal.unsqueeze(1) - constellation)
    closest_indices = torch.argmin(distances, dim=1)

    return closest_indices.view(orig_shape)  # Reshape back to original

In [136]:
data_to_use = torch.randint(0, 16, (64, 10))

In [137]:
modulated_signal = qammod(data_to_use)

In [138]:
closest_indices = qamdemod(modulated_signal)

In [80]:
batch_sym_torch.shape

torch.Size([64, 10])

In [81]:
batch_sym_torch_test = batch_sym_torch[0,:]

In [83]:
batch_sym_torch_test

tensor([ 8, 13,  3,  3,  5,  6,  3,  9, 14,  0], dtype=torch.int32)

In [85]:
a = 0.9487
b = 0.3162

In [98]:
test_data_real = [a, -a, -a, b, -a, -a, b, -b, -b, -a]
test_data_imag = [-b, -b, a, -a, -b, a, -b, -a, a, a]

In [99]:
test_data_real_torch = torch.tensor(test_data_real, dtype=torch.float32)
test_data_imag_torch = torch.tensor(test_data_imag, dtype=torch.float32)

In [100]:
test_data_torch = torch.complex(test_data_real_torch, test_data_imag_torch)

In [101]:
symbols = qamdemod(test_data_torch, 16)

In [102]:
symbols

tensor([9, 5, 6, 9, 5, 6, 9, 5, 6, 6])

In [64]:
def power_loss(nn_output, algorithm_output_mat_for_nn):
    power_nn_output = torch.sum(torch.square(torch.abs(nn_output)),dim=1)
    power_algorithm_output_mat_for_nn = torch.sum(torch.square(torch.abs(algorithm_output_mat_for_nn)),dim=1)
    power_diff = torch.abs(power_nn_output - power_algorithm_output_mat_for_nn)
    return torch.mean(power_diff), torch.mean(power_nn_output), torch.mean(power_algorithm_output_mat_for_nn)

In [65]:
a = torch.rand([64,10])
b = torch.rand([64,10])
c = torch.rand([64,10])
d = torch.rand([64,10])

In [66]:
ab = torch.complex(a,b)
cd = torch.complex(c,d)

In [67]:
m,n,o = power_loss(ab,cd)

In [68]:
m

tensor(1.4290)

In [41]:
loss = nn.MSELoss()

In [42]:
loss(a,b)

tensor(0.1719)

In [43]:
def mse_loss(nn_output, algorithm_output_mat_for_nn):
    mse = torch.mean(torch.mean(torch.abs(torch.square(nn_output - algorithm_output_mat_for_nn)), dim=1))
    return mse.round(decimals=5)

In [44]:
c = mse_loss(a,b)
print(f'{c:.4f} and {c.shape}')

0.1719 and torch.Size([])


In [47]:
batch_alg_in_torch.dtype

torch.complex64

In [50]:
mse_loss(complex_to_interleaved_real(batch_alg_in_torch), complex_to_interleaved_real(batch_alg_out_torch))

tensor(0.0009)

In [57]:
len(power_loss(complex_to_interleaved_real(batch_alg_in_torch), complex_to_interleaved_real(batch_alg_out_torch))[0])

64

In [55]:
power_loss(batch_alg_in_torch, batch_alg_out_torch)

(tensor([0.0232, 0.0100, 0.0217, 0.0090, 0.0375, 0.0086, 0.0166, 0.0394, 0.0065,
         0.0181, 0.0126, 0.0216, 0.0455, 0.0059, 0.0300, 0.0534, 0.0479, 0.0013,
         0.0030, 0.0168, 0.0001, 0.0015, 0.0054, 0.0318, 0.0212, 0.0653, 0.0372,
         0.0177, 0.0405, 0.0078, 0.0131, 0.0074, 0.0366, 0.0253, 0.0197, 0.0126,
         0.0334, 0.0015, 0.0199, 0.0058, 0.0335, 0.0195, 0.0085, 0.0067, 0.0320,
         0.0081, 0.0225, 0.0602, 0.0242, 0.0147, 0.0464, 0.0086, 0.0594, 0.0177,
         0.0074, 0.0555, 0.0261, 0.0421, 0.0262, 0.0159, 0.0328, 0.0201, 0.0120,
         0.0404]),
 tensor([0.9349, 1.0507, 0.9406, 1.1183, 0.7892, 1.2544, 0.9913, 0.7664, 1.2360,
         1.3330, 1.0260, 0.9670, 0.7224, 1.1259, 0.8657, 0.6489, 0.6813, 1.1918,
         1.1810, 0.9952, 1.1672, 1.1670, 1.1140, 0.8435, 0.9918, 0.4971, 0.8092,
         1.0498, 0.8340, 1.0760, 1.2925, 1.0987, 0.7875, 0.9152, 0.9766, 1.0277,
         0.8658, 1.1373, 0.9522, 1.1290, 0.8159, 0.9756, 1.2426, 1.0846, 0.8498,
         

In [45]:
ab = torch.complex(a,b)

In [None]:
ab

In [63]:
torch.sum(torch.square(torch.abs(ab)))

tensor(6.2460)

In [70]:
a-b

tensor([[-0.4149, -0.0350, -0.1526,  0.0091, -0.1886, -0.2704,  0.1865, -0.1763,
         -0.1867,  0.0612]])

In [49]:
class CSIModel(nn.Module):
    def __init__(self):
        super(CSIModel, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1)
        self.bnconv3 = nn.BatchNorm2d(1)
        
        self.flatten = nn.Flatten()
        
        self.linear1 = nn.Linear(700, 280)
        self.bnlin1 = nn.BatchNorm1d(280)
        
        self.linear2 = nn.Linear(280, 140)


    def forward(self, x):
        #x = x.unsqueeze(1)
        x = F.relu(self.bnconv1(self.conv1(x)))
        x = F.relu(self.bnconv2(self.conv2(x)))
        x = F.relu(self.bnconv3(self.conv3(x)))
        x = self.flatten(x)
        x = F.relu(self.bnlin1(self.linear1(x)))
        x = self.linear2(x)

        return x

In [50]:
test_output_1 = CSIModel()(torch.rand([64,2,10,70]))
test_output_1.shape

torch.Size([64, 140])

In [42]:
test_tensor = torch.rand([32,8,64,64])

In [43]:
F.adaptive_avg_pool2d(test_tensor,(1,1)).squeeze(-1).squeeze(-1).shape

torch.Size([32, 8])

In [51]:
class SignalModel(nn.Module):
    def __init__(self):
        super(SignalModel, self).__init__()
        
        self.linear1 = nn.Linear(140, 70)
        self.bnlin1 = nn.BatchNorm1d(70)

        self.linear2 = nn.Linear(70, 14)
        self.bnlin2 = nn.BatchNorm1d(14)

        self.linear3 = nn.Linear(14, 70)
        self.bnlin3 = nn.BatchNorm1d(70)

        self.linear4 = nn.Linear(70, 140)



    def forward(self, x):
        #x = x.unsqueeze(1)
        x = F.relu(self.bnlin1(self.linear1(x)))
        x = F.relu(self.bnlin2(self.linear2(x)))
        x = F.relu(self.bnlin3(self.linear3(x)))
        x = self.linear4(x)

        return x

In [52]:
test_output_2 = SignalModel()(torch.rand([64,140]))
test_output_2.shape

torch.Size([64, 140])

In [53]:
test_output = torch.cat([test_output_1, test_output_2], dim=1)
test_output.shape

torch.Size([64, 280])

In [54]:
class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()
        
        self.csimodel = CSIModel()
        self.signalmodel = SignalModel()

        
        self.linear1 = nn.Linear(280,140)
        self.bnlin1 = nn.BatchNorm1d(280)
        

    def forward(self, x1, x2):
        #x = x.unsqueeze(1)
        x1 = self.csimodel(x1)
        x2 = self.signalmodel(x2)
        x = torch.cat([x1, x2], dim=1)
        x = self.linear1(self.bnlin1(x))
  
        return x

In [55]:
test_output_final= CombinedModel()(torch.rand([64,2,10,70]), torch.rand([64,140]))
test_output_final.shape

torch.Size([64, 140])

In [170]:
model = CombinedModel().to(device)

# Define the loss functions
loss = torch.nn.MSELoss()  # For classification

# Define an optimizer (both for the encoder and the decoder!)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

#scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.01)  # Learning rate decay scheduler
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=2)

# Variables for early stopping and best parameters
best_loss = float('inf')
patience_limit = 2


best_model = None

train_losses = []
val_losses = []

alpha = 0.4
beta = 0.4
gamma = 0.2

# Train the model
EPOCHS = 10
for epoch in range(EPOCHS):
    running_train_loss = 0.0
    
    model.train()
    progress_bar_train = tqdm(enumerate(train_loader), total=len(train_loader), ncols=200)
    for index, (algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat) in progress_bar_train:
        # Forward pass
        #algorithm_output_mat_for_nn = (batch_complex_autocorrelation(algorithm_output_mat)).to(device)
        main_channels_mat_for_nn = torch.stack([torch.real(main_channels_mat).float(), torch.imag(main_channels_mat).float()], dim=1).to(device)
        algorithm_input_mat_for_nn = (complex_to_interleaved_real(algorithm_input_mat)).to(device)
        algorithm_output_mat_for_nn = (complex_to_interleaved_real(algorithm_output_mat)).to(device)

        
        nn_output = model(main_channels_mat_for_nn, algorithm_input_mat_for_nn)
        
        # Calculate loss
        mse_train_loss = loss(nn_output, algorithm_output_mat_for_nn)

        nn_output_control =  interleaved_real_to_complex(nn_output)
        algorithm_output_mat_for_nn_control =  interleaved_real_to_complex(algorithm_output_mat_for_nn)
        
        papr_diff, nn_papr, alg_papr  = papr_loss(nn_output_control, algorithm_output_mat_for_nn_control)
        ser_diff, nn_ser, alg_ser = ser_loss(algorithm_input_mat, algorithm_output_mat_for_nn_control.cpu(), nn_output_control.cpu(), main_channels_mat, symbols_store_mat)
        power_diff, nn_power, alg_power = power_loss(nn_output_control, algorithm_output_mat_for_nn_control)
        
        train_loss = alpha*papr_diff + beta*ser_diff + gamma*power_diff
        #train_loss = mse_train_loss

        # Backward pass
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Update running loss
        running_train_loss += train_loss.item()
        avg_train_loss = running_train_loss / (index + 1)

        # Get current learning rate from the optimizer
        current_lr = optimizer.param_groups[0]['lr']

        # Print metrics
        #progress_bar_train.set_description(f'Epoch [{epoch + 1}/{EPOCHS}] MSELos:{avg_train_loss1:.4f} MSEWeig{mse_weight:.2f} CELos:{avg_train_loss2:.4f} CEWeig{ce_weight:.2f} TrLos:{avg_train_loss:.4f} Tr.Acc: {avg_train_acc*100:.2f}%')
        progress_bar_train.set_description(f" Epoch [{epoch + 1}/{EPOCHS}] T Loss:{avg_train_loss:.4f} PAPR_dff: {papr_diff:.4f} NN_PAPR: {nn_papr:.4f} Alg_PAPR: {alg_papr:.4f} SER_dff: {ser_diff:.4f} NN_SER: {nn_ser:.4f} Alg_SER: {alg_ser:.4f} Power_dff: {papr_diff:.4f} NN_Power: {nn_power:.4f} Alg_Power: {alg_power:.4f} LR is {current_lr}")
    
    #train_losses.append(avg_train_loss)
    train_losses.append(avg_train_loss)

    print(f"Training has completed epoch {epoch+1}")
    
    # Validation loop
    running_val_loss = 0.0

    
    model.eval()
    progress_bar_val = tqdm(enumerate(val_loader), total=len(val_loader), ncols=200)
    for index, (algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat) in progress_bar_val:
        
        #algorithm_output_mat_for_nn = (batch_complex_autocorrelation(algorithm_output_mat)).to(device)
        main_channels_mat_for_nn = torch.stack([torch.real(main_channels_mat).float(), torch.imag(main_channels_mat).float()], dim=1).to(device)
        algorithm_input_mat_for_nn = (complex_to_interleaved_real(algorithm_input_mat)).to(device)
        algorithm_output_mat_for_nn = (complex_to_interleaved_real(algorithm_output_mat)).to(device)

        
        with torch.no_grad():
            
            nn_output = model(main_channels_mat_for_nn, algorithm_input_mat_for_nn)

            # Calculate losses
            mse_val_loss = loss(nn_output, algorithm_output_mat_for_nn)

            nn_output_control =  interleaved_real_to_complex(nn_output)
            algorithm_output_mat_for_nn_control =  interleaved_real_to_complex(algorithm_output_mat_for_nn)

            papr_diff, nn_papr, alg_papr  = papr_loss(nn_output_control, algorithm_output_mat_for_nn_control)
            ser_diff, nn_ser, alg_ser = ser_loss(algorithm_input_mat, algorithm_output_mat_for_nn_control.cpu(), nn_output_control.cpu(), main_channels_mat, symbols_store_mat)
            power_diff, nn_power, alg_power = power_loss(nn_output_control, algorithm_output_mat_for_nn_control)
        
            val_loss = alpha*papr_diff + beta*ser_diff + gamma*power_diff

            # Update running loss
            running_val_loss += val_loss.item()
            avg_val_loss = running_val_loss / (index + 1)

    
            progress_bar_val.set_description(f" Epoch [{epoch + 1}/{EPOCHS}] V Loss:{avg_val_loss:.4f} PAPR_dff: {papr_diff:.4f} NN_PAPR: {nn_papr:.4f} Alg_PAPR: {alg_papr:.4f} SER_dff: {ser_diff:.4f} NN_SER: {nn_ser:.4f} Alg_SER: {alg_ser:.4f} Power_dff: {papr_diff:.4f} NN_Power: {nn_power:.4f} Alg_Power: {alg_power:.4f}")
    
    #val_losses.append(avg_val_loss)
    val_losses.append(avg_val_loss)
    
    scheduler.step(running_val_loss)


    # Early stopping
    if avg_val_loss < best_loss:  # Now checking for the best accuracy
        best_loss = avg_val_loss
        best_epoch = epoch + 1
        best_train_loss = avg_train_loss
        patience_ = 0
        best_weights = copy.deepcopy(model.state_dict())
        print(f"Best Validation Loss is now: {best_loss:.4f} at Epoch: {best_epoch}")
    else:
        patience_ += 1
        print(f"This is Epoch: {patience_} without improvement")
        print(f"Current Validation Loss is: {avg_val_loss:.4f} at Epoch: {epoch+1}")
        print(f"Best Validation Loss remains: {best_loss:.4f} at Epoch: {best_epoch}")
        if patience_ > patience_limit:  # Patience limit before stopping
            print("Early stopping triggered! Restoring best model weights.")
            print(f"Best Validation Loss was: {best_loss:.4f} at Epoch: {best_epoch}")
            break

best_model = model.cpu()
best_model.load_state_dict(best_weights)


Epoch [1/10] T Loss:1.4263 PAPR_dff: 0.8752 NN_PAPR: 3.0062 Alg_PAPR: 2.1361 SER_dff: 0.0797 NN_SER: 0.9328 Alg_SER: 0.8609 Power_dff: 0.8752 NN_Power: 2.4763 Alg_Power: 1.0154 LR is 0.001: 100%|█| 1

Training has completed epoch 1



Epoch [1/10] V Loss:0.6565 PAPR_dff: 0.6828 NN_PAPR: 2.7743 Alg_PAPR: 2.0916 SER_dff: 0.0688 NN_SER: 0.9500 Alg_SER: 0.8938 Power_dff: 0.6828 NN_Power: 1.9558 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

Best Validation Loss is now: 0.6565 at Epoch: 1



Epoch [2/10] T Loss:0.4336 PAPR_dff: 0.1868 NN_PAPR: 2.2535 Alg_PAPR: 2.1221 SER_dff: 0.0922 NN_SER: 0.9406 Alg_SER: 0.8531 Power_dff: 0.1868 NN_Power: 1.7116 Alg_Power: 1.0067 LR is 0.001: 100%|█| 1

Training has completed epoch 2



Epoch [2/10] V Loss:0.2943 PAPR_dff: 0.2441 NN_PAPR: 2.2875 Alg_PAPR: 2.0916 SER_dff: 0.0750 NN_SER: 0.9250 Alg_SER: 0.8625 Power_dff: 0.2441 NN_Power: 1.3400 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

Best Validation Loss is now: 0.2943 at Epoch: 2



Epoch [3/10] T Loss:0.1985 PAPR_dff: 0.0551 NN_PAPR: 1.8590 Alg_PAPR: 2.1104 SER_dff: 0.0531 NN_SER: 0.9422 Alg_SER: 0.8937 Power_dff: 0.0551 NN_Power: 1.4133 Alg_Power: 1.0374 LR is 0.001: 100%|█| 1

Training has completed epoch 3



Epoch [3/10] V Loss:0.2051 PAPR_dff: 0.0288 NN_PAPR: 1.8384 Alg_PAPR: 2.0916 SER_dff: 0.0688 NN_SER: 0.9187 Alg_SER: 0.8563 Power_dff: 0.0288 NN_Power: 1.1216 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

Best Validation Loss is now: 0.2051 at Epoch: 3



Epoch [4/10] T Loss:0.1332 PAPR_dff: 0.0379 NN_PAPR: 1.6923 Alg_PAPR: 2.0927 SER_dff: 0.0781 NN_SER: 0.9578 Alg_SER: 0.8859 Power_dff: 0.0379 NN_Power: 1.1997 Alg_Power: 1.0343 LR is 0.001: 100%|█| 1

Training has completed epoch 4



Epoch [4/10] V Loss:0.1165 PAPR_dff: 0.0983 NN_PAPR: 1.7893 Alg_PAPR: 2.0916 SER_dff: 0.0625 NN_SER: 0.9250 Alg_SER: 0.8687 Power_dff: 0.0983 NN_Power: 1.3204 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

Best Validation Loss is now: 0.1165 at Epoch: 4



Epoch [5/10] T Loss:0.1096 PAPR_dff: 0.0521 NN_PAPR: 1.6987 Alg_PAPR: 2.0605 SER_dff: 0.0594 NN_SER: 0.9484 Alg_SER: 0.8937 Power_dff: 0.0521 NN_Power: 1.1555 Alg_Power: 1.0621 LR is 0.001: 100%|█| 1

Training has completed epoch 5



Epoch [5/10] V Loss:0.1471 PAPR_dff: 0.1290 NN_PAPR: 1.7412 Alg_PAPR: 2.0916 SER_dff: 0.0938 NN_SER: 0.9375 Alg_SER: 0.8500 Power_dff: 0.1290 NN_Power: 1.2808 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

This is Epoch: 1 without improvement
Current Validation Loss is: 0.1471 at Epoch: 5
Best Validation Loss remains: 0.1165 at Epoch: 4



Epoch [6/10] T Loss:0.0992 PAPR_dff: 0.0326 NN_PAPR: 1.5995 Alg_PAPR: 2.1358 SER_dff: 0.0562 NN_SER: 0.9219 Alg_SER: 0.8750 Power_dff: 0.0326 NN_Power: 1.1577 Alg_Power: 1.0085 LR is 0.001: 100%|█| 1

Training has completed epoch 6



Epoch [6/10] V Loss:0.1497 PAPR_dff: 0.0000 NN_PAPR: 1.5803 Alg_PAPR: 2.0916 SER_dff: 0.0750 NN_SER: 0.9312 Alg_SER: 0.8562 Power_dff: 0.0000 NN_Power: 1.0556 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

This is Epoch: 2 without improvement
Current Validation Loss is: 0.1497 at Epoch: 6
Best Validation Loss remains: 0.1165 at Epoch: 4



Epoch [7/10] T Loss:0.0834 PAPR_dff: 0.0238 NN_PAPR: 1.7538 Alg_PAPR: 2.1012 SER_dff: 0.0688 NN_SER: 0.9422 Alg_SER: 0.8766 Power_dff: 0.0238 NN_Power: 1.1277 Alg_Power: 1.0104 LR is 0.001: 100%|█| 1

Training has completed epoch 7



Epoch [7/10] V Loss:0.0720 PAPR_dff: 0.0000 NN_PAPR: 1.6828 Alg_PAPR: 2.0916 SER_dff: 0.0625 NN_SER: 0.9438 Alg_SER: 0.8813 Power_dff: 0.0000 NN_Power: 1.0606 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

Best Validation Loss is now: 0.0720 at Epoch: 7



Epoch [8/10] T Loss:0.0683 PAPR_dff: 0.0085 NN_PAPR: 1.6112 Alg_PAPR: 2.1236 SER_dff: 0.0766 NN_SER: 0.9438 Alg_SER: 0.8703 Power_dff: 0.0085 NN_Power: 1.0318 Alg_Power: 1.0165 LR is 0.001: 100%|█| 1

Training has completed epoch 8



Epoch [8/10] V Loss:0.0636 PAPR_dff: 0.0000 NN_PAPR: 1.5354 Alg_PAPR: 2.0916 SER_dff: 0.0750 NN_SER: 0.9562 Alg_SER: 0.8813 Power_dff: 0.0000 NN_Power: 1.0438 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

Best Validation Loss is now: 0.0636 at Epoch: 8



Epoch [9/10] T Loss:0.0496 PAPR_dff: 0.0000 NN_PAPR: 1.5125 Alg_PAPR: 2.0842 SER_dff: 0.0781 NN_SER: 0.9422 Alg_SER: 0.8703 Power_dff: 0.0000 NN_Power: 1.0122 Alg_Power: 1.0370 LR is 0.001: 100%|█| 1

Training has completed epoch 9



Epoch [9/10] V Loss:0.0559 PAPR_dff: 0.0000 NN_PAPR: 1.4727 Alg_PAPR: 2.0916 SER_dff: 0.0563 NN_SER: 0.9375 Alg_SER: 0.8812 Power_dff: 0.0000 NN_Power: 1.0753 Alg_Power: 1.0805: 100%|█| 157/157 [00:0

Best Validation Loss is now: 0.0559 at Epoch: 9



Epoch [10/10] T Loss:0.0465 PAPR_dff: 0.0000 NN_PAPR: 1.4086 Alg_PAPR: 2.0901 SER_dff: 0.0719 NN_SER: 0.9328 Alg_SER: 0.8641 Power_dff: 0.0000 NN_Power: 1.0178 Alg_Power: 1.0573 LR is 0.001: 100%|█| 

Training has completed epoch 10



Epoch [10/10] V Loss:0.0473 PAPR_dff: 0.0000 NN_PAPR: 1.3860 Alg_PAPR: 2.0916 SER_dff: 0.0500 NN_SER: 0.9563 Alg_SER: 0.9062 Power_dff: 0.0000 NN_Power: 1.1113 Alg_Power: 1.0805: 100%|█| 157/157 [00:

Best Validation Loss is now: 0.0473 at Epoch: 10


<All keys matched successfully>

In [171]:
test_losses = []
running_test_loss = 0.0


progress_bar_test = tqdm(enumerate(test_loader), total=len(test_loader), ncols=150)
for index, (algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat) in progress_bar_test:
        
    main_channels_mat_for_nn = torch.stack([torch.real(main_channels_mat).float(), torch.imag(main_channels_mat).float()], dim=1)
    algorithm_input_mat_for_nn = (complex_to_interleaved_real(algorithm_input_mat))
    algorithm_output_mat_for_nn = (complex_to_interleaved_real(algorithm_output_mat))
    
    with torch.no_grad():
            
        nn_output = best_model(main_channels_mat_for_nn, algorithm_input_mat_for_nn)

        # Calculate losses
        test_loss = loss(nn_output, algorithm_output_mat_for_nn)

        # Update running loss
        running_test_loss += test_loss.item()
            
        avg_test_loss = running_test_loss / (index + 1)

        nn_output_control =  interleaved_real_to_complex(nn_output)
        algorithm_output_mat_for_nn_control =  interleaved_real_to_complex(algorithm_output_mat_for_nn)
        
        papr_diff, nn_papr, alg_papr = papr_loss(nn_output_control, algorithm_output_mat_for_nn_control)
        ser_diff, nn_ser, alg_ser = ser_loss(algorithm_input_mat, algorithm_output_mat_for_nn_control, nn_output_control, main_channels_mat, symbols_store_mat)

        progress_bar_test.set_description(f'Epoch [{epoch + 1}/{EPOCHS}] Te Loss:{avg_test_loss:.4f} PAPR_dff: {papr_diff:.4f} SER_dff: {ser_diff:.4f}')

        
        if index < 1:
            total_nn_out = nn_output_control
            total_alg_in = algorithm_input_mat
            total_alg_out = algorithm_output_mat_for_nn_control
            total_main_channels = main_channels_mat
            total_symbols = symbols_store_mat
        else:
            total_nn_out = torch.cat([total_nn_out, nn_output_control], dim=0, out=None)
            total_alg_in = torch.cat([total_alg_in, algorithm_input_mat], dim=0, out=None)
            total_alg_out = torch.cat([total_alg_out, algorithm_output_mat_for_nn_control], dim=0, out=None)
            total_main_channels = torch.cat([total_main_channels, main_channels_mat], dim=0, out=None)
            total_symbols = torch.cat([total_symbols, symbols_store_mat], dim=0, out=None)

test_losses.append(avg_test_loss)


poch [10/10] Te Loss:0.0146 PAPR_dff: 0.0000 SER_dff: 0.0438: 100%|████████████████████████████████████████████████| 157/157 [00:02<00:00, 59.42it/s]

In [172]:
total_alg_in_real, total_alg_in_imag, total_alg_out_real, total_alg_out_imag, total_nn_out_real, total_nn_out_imag, total_main_channels_real, total_main_channels_imag, total_symbols = prepare_for_matlab(total_alg_in, total_alg_out, total_nn_out, total_main_channels, total_symbols)

In [173]:
# Save all variables in a dictionary
savemat("output_from_pytorch.mat", {
    "total_alg_in_real": total_alg_in_real,
    "total_alg_in_imag": total_alg_in_imag,
    "total_alg_out_real": total_alg_out_real,
    "total_alg_out_imag": total_alg_out_imag,
    "total_nn_out_real": total_nn_out_real,
    "total_nn_out_imag": total_nn_out_imag,
    "total_main_channels_real": total_main_channels_real,
    "total_main_channels_imag": total_main_channels_imag,
    "total_symbols": total_symbols
})