In [1]:
# This is the method that uses the MATLAB Engine API for Python
import matlab.engine
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import  models, datasets, transforms
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import timm
import pickle
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler, MinMaxScaler
import numpy as np
import scipy.io as scio
from scipy.io import savemat
import h5py
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import gc

In [2]:
device = torch.device('mps') if torch.backends.mps.is_available() else 'cpu'

In [3]:
eng = matlab.engine.start_matlab()

In [6]:
algorithm_input = scio.loadmat('algorithm_input_single.mat')
algorithm_input_mat = algorithm_input['algorithm_input']

In [9]:
algorithm_output = scio.loadmat('algorithm_output_single.mat')
algorithm_output_mat = algorithm_output['algorithm_output']

main_channels = scio.loadmat('main_channels_.mat')
main_channels_mat = main_channels['main_channels']

In [None]:
main_channels_mat = torch.load('main_channels_tensor.pt', weights_only=True)

In [None]:
symbols_store = scio.loadmat('symbols_store.mat')
symbols_store_mat = symbols_store['symbols_store']

In [14]:
class CustomDataset(Dataset):
    def __init__(self, algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat):
        # convert into PyTorch tensors and remember them
        self.algorithm_input_mat = algorithm_input_mat
        self.algorithm_output_mat = algorithm_output_mat
        self.main_channels_mat = main_channels_mat
        self.symbols_store_mat = symbols_store_mat
        
    def __len__(self):
        # this should return the size of the dataset
        return len(self.algorithm_input_mat)
    
    def __getitem__(self, idx):
        # this should return one sample from the dataset
        algorithm_input_mat = self.algorithm_input_mat[idx,:]
        algorithm_output_mat = self.algorithm_output_mat[idx,:]
        main_channels_mat = self.main_channels_mat[idx,:,:]
        symbols_store_mat = self.symbols_store_mat[idx,:]
        return algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat

In [15]:
dataset = CustomDataset(algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat)

In [16]:
# First, split the dataset into train and remaining (val + test)
train_set, remaining_set = train_test_split(dataset, test_size=200000, random_state=42)

# Now, split the remaining set into validation and test sets
val_set, test_set = train_test_split(remaining_set, test_size=100000, random_state=42)

In [17]:
# Create DataLoaders
batch_size = 64
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size= batch_size)

In [18]:
batch_alg_in_mat, batch_alg_out_mat, batch_main_chan_mat, batch_sym_mat = next(iter(train_loader))
print(f'shape of batch feature is {batch_alg_in_mat.shape}')
print(f'shape of batch feature is {batch_alg_out_mat.shape}')
print(f'shape of batch feature is {batch_main_chan_mat.shape}')
print(f'shape of batch feature is {batch_sym_mat.shape}')

shape of batch feature is torch.Size([64, 70])
shape of batch feature is torch.Size([64, 70])
shape of batch feature is torch.Size([64, 10, 70])
shape of batch feature is torch.Size([64, 10])


In [19]:
torch.stack([torch.real(batch_main_chan_mat).float(), torch.imag(batch_main_chan_mat).float()], dim=1).shape

torch.Size([64, 2, 10, 70])

In [20]:
def complex_to_interleaved_real(complex_signal):
    real_part = complex_signal.real.to(dtype=torch.float32) 
    imag_part = complex_signal.imag.to(dtype=torch.float32) 
    interleaved_signal = torch.stack((real_part, imag_part), dim=2).reshape(complex_signal.shape[0], -1)
    return interleaved_signal

In [21]:
def interleaved_real_to_complex(interleaved_signal):
    signal_length = interleaved_signal.shape[1] // 2
    real_part = interleaved_signal[:, 0::2]  # Extract even indices
    imag_part = interleaved_signal[:, 1::2]  # Extract odd indices
    complex_signal = torch.complex(real_part, imag_part)
    return complex_signal

In [22]:
def compute_papr_complex(signal):
    # Compute |x[n]|^2: Magnitude squared of the complex signal
    power_signal = torch.abs(signal)**2
    
    # Peak power
    peak_power_signal= torch.max(power_signal, dim=1).values

    # Average power
    avg_power_signal = torch.mean(power_signal, dim=1)

    # PAPR
    papr_signal = peak_power_signal / avg_power_signal
    
    return papr_signal

In [23]:
def papr_loss(signal_going_out, signal_coming_in):
    # Compute PAPR before and after
    papr_going_out = compute_papr_complex(signal_going_out)  # Transformed signal
    papr_coming_in = compute_papr_complex(signal_coming_in)  # Original signal

    # Penalize only if PAPR after is greater than PAPR before
    papr_diff = torch.relu(papr_going_out - papr_coming_in)
    
    return torch.mean(papr_diff), torch.mean(papr_going_out), torch.mean(papr_coming_in)

In [24]:
def prepare_for_matlab(batch_alg_in_mat, batch_alg_out_mat, batch_nn_out, batch_main_chan_mat, batch_sym_mat):
    
    batch_alg_in_mat_real = matlab.double(batch_alg_in_mat.real.tolist())
    batch_alg_in_mat_imag = matlab.double(batch_alg_in_mat.imag.tolist())

    batch_alg_out_mat_real = matlab.double(batch_alg_out_mat.real.tolist())
    batch_alg_out_mat_imag = matlab.double(batch_alg_out_mat.imag.tolist())

    batch_nn_out_real = matlab.double(batch_nn_out.real.tolist())
    batch_nn_out_imag = matlab.double(batch_nn_out.imag.tolist())

    batch_main_chan_mat_real = matlab.double(batch_main_chan_mat.real.tolist())
    batch_main_chan_mat_imag = matlab.double(batch_main_chan_mat.imag.tolist())

    batch_sym_mat = matlab.uint32(batch_sym_mat.tolist())

    return batch_alg_in_mat_real, batch_alg_in_mat_imag, batch_alg_out_mat_real, batch_alg_out_mat_imag, batch_nn_out_real, batch_nn_out_imag, batch_main_chan_mat_real, batch_main_chan_mat_imag, batch_sym_mat 

In [25]:
def ser_loss(batch_alg_in_mat, batch_alg_out_mat, batch_nn_out, batch_main_chan_mat, batch_sym_mat):

    batch_alg_in_mat_real, batch_alg_in_mat_imag, batch_alg_out_mat_real, batch_alg_out_mat_imag, batch_nn_out_real, batch_nn_out_imag, batch_main_chan_mat_real, batch_main_chan_mat_imag , batch_sym_mat = prepare_for_matlab(batch_alg_in_mat, batch_alg_out_mat, batch_nn_out, batch_main_chan_mat, batch_sym_mat)
    ser_mat = eng.calculate_ser(batch_alg_in_mat_real, batch_alg_in_mat_imag, batch_alg_out_mat_real, batch_alg_out_mat_imag, batch_nn_out_real, batch_nn_out_imag, batch_main_chan_mat_real, batch_main_chan_mat_imag , batch_sym_mat)
    ser_torch = torch.tensor(ser_mat, dtype=torch.float32)
    ser_diff = torch.relu(ser_torch[:,2] - ser_torch[:,1])
    return torch.mean(ser_diff), torch.mean(ser_torch[:,2]), torch.mean(ser_torch[:,1])

In [26]:
def batch_complex_autocorrelation(signals):
    signals_conj = torch.conj(signals)  # Compute complex conjugate
    results = torch.zeros((signals.size(0), signals.size(1)), dtype=torch.cfloat)  # Output buffer
    
    for i in range(signals.size(0)):  # Process each signal in the batch
        signal = signals[i]
        signal_conj = signals_conj[i]
        # Compute convolution (autocorrelation via convolution)
        result = torch.nn.functional.conv1d(
            signal.view(1, 1, -1),
            signal_conj.flip(0).view(1, 1, -1),
            padding=signal.size(0) - 1,
        )
        results[i] = result.view(-1)[signal.size(0) - 1:]  # Keep only positive lags

    # Separate magnitude and phase for the batch
    magnitudes = torch.abs(results)  # Shape: (batch_size, signal_length)
    phases = torch.angle(results)    # Shape: (batch_size, signal_length)

    # Create 2D autocorrelation maps for each signal in the batch
    auto_maps_real = torch.einsum('bi,bj->bij', torch.real(results), torch.real(results))
    auto_maps_imag = torch.einsum('bi,bj->bij', torch.imag(results), torch.imag(results))
    auto_maps_mag = torch.einsum('bi,bj->bij', magnitudes, magnitudes)  # Outer product: (batch_size, signal_length, signal_length)
    auto_maps_phase = torch.einsum('bi,bj->bij', phases, phases)  # Outer product: (batch_size, signal_length, signal_length)

    # Normalize maps to [0, 1]
    auto_maps_real_normalized = (auto_maps_real - auto_maps_real.min()) / (auto_maps_real.max() - auto_maps_real.min())
    auto_maps_imag_normalized = (auto_maps_imag - auto_maps_imag.min()) / (auto_maps_imag.max() - auto_maps_imag.min())
    auto_maps_mag_normalized = (auto_maps_mag - auto_maps_mag.min()) / (auto_maps_mag.max() - auto_maps_mag.min())
    auto_maps_phase_normalized = (auto_maps_phase - auto_maps_phase.min()) / (auto_maps_phase.max() - auto_maps_phase.min())

    output = torch.cat([auto_maps_real_normalized.unsqueeze(1), auto_maps_imag_normalized.unsqueeze(1), auto_maps_mag_normalized.unsqueeze(1), auto_maps_phase_normalized.unsqueeze(1)], dim = 1)
    
    return output


In [27]:
batch_alg_out_mat.shape

torch.Size([64, 70])

In [28]:
batch_out = batch_complex_autocorrelation(batch_alg_out_mat)

In [29]:
batch_out.shape

torch.Size([64, 4, 70, 70])

In [30]:
class CSIModel(nn.Module):
    def __init__(self):
        super(CSIModel, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=3, stride=1, padding=2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv2 = nn.BatchNorm2d(32)

        self.flatten = nn.Flatten()
        
        self.linear1 = nn.Linear(1728, 280)
        self.bnlin1 = nn.BatchNorm1d(280)
        
        self.linear2 = nn.Linear(280, 140)



    def forward(self, x):
        #x = x.unsqueeze(1)
        x = F.relu(self.bnconv1(self.maxpool1(self.conv1(x))))
        x = F.relu(self.bnconv2(self.maxpool2(self.conv2(x))))
        x = self.flatten(x)
        x = F.relu(self.bnlin1(self.linear1(x)))
        x = self.linear2(x)

        return x

In [31]:
test_output_1 = CSIModel()(torch.rand([64,2,10,70]))
#test_output_shape

In [32]:
class SignalModel(nn.Module):
    def __init__(self):
        super(SignalModel, self).__init__()
        
        self.linear1 = nn.Linear(140, 140)
        self.bnlin1 = nn.BatchNorm1d(140)
        
        self.linear2 = nn.Linear(140, 140)



    def forward(self, x):
        #x = x.unsqueeze(1)
        x = F.relu(self.bnlin1(self.linear1(x)))
        x = self.linear2(x)

        return x

In [33]:
test_output_2 = SignalModel()(torch.rand([64,140]))
#test_output_shape

In [34]:
test_output = torch.cat([test_output_1, test_output_2], dim=1)

In [35]:
test_output.shape

torch.Size([64, 280])

In [36]:
class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()
        
        self.csimodel = CSIModel()
        self.signalmodel = SignalModel()

        
        self.linear1 = nn.Linear(280,140)
        self.bnlin1 = nn.BatchNorm1d(140)
        

    def forward(self, x1, x2):
        #x = x.unsqueeze(1)
        x1 = self.csimodel(x1)
        x2 = self.signalmodel(x2)
        x = torch.cat([x1, x2], dim=1)
        x = self.bnlin1(self.linear1(x))
  
        return x

In [37]:
test_output_final= CombinedModel()(torch.rand([64,2,10,70]), torch.rand([64,140]))
test_output_final.shape

torch.Size([64, 140])

In [38]:
model = CombinedModel().to(device)

# Define the loss functions
loss = torch.nn.MSELoss()  # For classification

# Define an optimizer (both for the encoder and the decoder!)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

#scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.01)  # Learning rate decay scheduler
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=2)

# Variables for early stopping and best parameters
best_loss = float('inf')
patience_limit = 2


best_model = None

train_losses = []
val_losses = []

alpha = 1
beta = 0.1
gamma = 0.1

# Train the model
EPOCHS = 10
for epoch in range(EPOCHS):
    running_train_loss = 0.0
    
    model.train()
    progress_bar_train = tqdm(enumerate(train_loader), total=len(train_loader), ncols=150)
    for index, (algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat) in progress_bar_train:
        # Forward pass
        #algorithm_output_mat_for_nn = (batch_complex_autocorrelation(algorithm_output_mat)).to(device)
        main_channels_mat_for_nn = torch.stack([torch.real(main_channels_mat).float(), torch.imag(main_channels_mat).float()], dim=1).to(device)
        algorithm_input_mat_for_nn = (complex_to_interleaved_real(algorithm_input_mat)).to(device)
        algorithm_output_mat_for_nn = (complex_to_interleaved_real(algorithm_output_mat)).to(device)

        
        nn_output = model(main_channels_mat_for_nn, algorithm_input_mat_for_nn)
        
        # Calculate loss
        initial_loss = loss(nn_output, algorithm_output_mat_for_nn)

        nn_output_control =  interleaved_real_to_complex(nn_output)
        algorithm_output_mat_for_nn_control =  interleaved_real_to_complex(algorithm_output_mat_for_nn)
        
        papr_diff, nn_papr, alg_papr  = papr_loss(nn_output_control, algorithm_output_mat_for_nn_control)
        ser_diff, nn_ser, alg_ser = ser_loss(algorithm_input_mat, algorithm_output_mat_for_nn_control, nn_output_control, main_channels_mat, symbols_store_mat)
        
        train_loss = alpha*initial_loss + beta*papr_diff + gamma*ser_diff

        # Backward pass
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Update running loss
        running_train_loss += train_loss.item()
        avg_train_loss = running_train_loss / (index + 1)

        # Get current learning rate from the optimizer
        current_lr = optimizer.param_groups[0]['lr']

        # Print metrics
        #progress_bar_train.set_description(f'Epoch [{epoch + 1}/{EPOCHS}] MSELos:{avg_train_loss1:.4f} MSEWeig{mse_weight:.2f} CELos:{avg_train_loss2:.4f} CEWeig{ce_weight:.2f} TrLos:{avg_train_loss:.4f} Tr.Acc: {avg_train_acc*100:.2f}%')
        progress_bar_train.set_description(f" Epoch [{epoch + 1}/{EPOCHS}] T Loss:{avg_train_loss:.4f} PAPR_dff: {papr_diff:.4f} NN_PAPR: {nn_papr:.4f} Alg_PAPR: {alg_papr:.4f} SER_dff: {ser_diff:.4f} NN_SER: {nn_ser:.4f} Alg_SER: {alg_ser:.4f} LR is {current_lr}")
    
    #train_losses.append(avg_train_loss)
    train_losses.append(avg_train_loss)

    print(f"Training has completed epoch {epoch+1}")
    
    # Validation loop
    running_val_loss = 0.0

    
    model.eval()
    progress_bar_val = tqdm(enumerate(val_loader), total=len(val_loader), ncols=150)
    for index, (algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat) in progress_bar_val:
        
        #algorithm_output_mat_for_nn = (batch_complex_autocorrelation(algorithm_output_mat)).to(device)
        main_channels_mat_for_nn = torch.stack([torch.real(main_channels_mat).float(), torch.imag(main_channels_mat).float()], dim=1).to(device)
        algorithm_input_mat_for_nn = (complex_to_interleaved_real(algorithm_input_mat)).to(device)
        algorithm_output_mat_for_nn = (complex_to_interleaved_real(algorithm_output_mat)).to(device)

        
        with torch.no_grad():
            
            nn_output = model(main_channels_mat_for_nn, algorithm_input_mat_for_nn)

            # Calculate losses
            val_loss = loss(nn_output, algorithm_output_mat_for_nn)

            # Update running loss
            running_val_loss += val_loss.item()
            
            avg_val_loss = running_val_loss / (index + 1)

            nn_output_control =  interleaved_real_to_complex(nn_output)
            algorithm_output_mat_for_nn_control =  interleaved_real_to_complex(algorithm_output_mat_for_nn)
        
            papr_diff, nn_papr, alg_papr = papr_loss(nn_output_control, algorithm_output_mat_for_nn_control)
            ser_diff, nn_ser, alg_ser = ser_loss(algorithm_input_mat, algorithm_output_mat_for_nn_control, nn_output_control, main_channels_mat, symbols_store_mat)

            progress_bar_val.set_description(f" Epoch [{epoch + 1}/{EPOCHS}] V Loss:{avg_val_loss:.4f} PAPR_dff: {papr_diff:.4f} NN_PAPR: {nn_papr:.4f} Alg_PAPR: {alg_papr:.4f} SER_dff: {ser_diff:.4f} NN_SER: {nn_ser:.4f} Alg_SER: {alg_ser:.4f}")
    
    #val_losses.append(avg_val_loss)
    val_losses.append(avg_val_loss)
    
    scheduler.step(running_val_loss)


    # Early stopping
    if avg_val_loss < best_loss:  # Now checking for the best accuracy
        best_loss = avg_val_loss
        best_epoch = epoch + 1
        best_train_loss = avg_train_loss
        patience_ = 0
        best_weights = copy.deepcopy(model.state_dict())
        print(f"Best Validation Loss is now: {best_loss:.4f} at Epoch: {best_epoch}")
    else:
        patience_ += 1
        print(f"This is Epoch: {patience_} without improvement")
        print(f"Current Validation Loss is: {avg_val_loss:.4f} at Epoch: {epoch+1}")
        print(f"Best Validation Loss remains: {best_loss:.4f} at Epoch: {best_epoch}")
        if patience_ > patience_limit:  # Patience limit before stopping
            print("Early stopping triggered! Restoring best model weights.")
            print(f"Best Validation Loss was: {best_loss:.4f} at Epoch: {best_epoch}")
            break

best_model = model.cpu()
best_model.load_state_dict(best_weights)


Epoch [1/10] T Loss:0.1646 PAPR_dff: 0.0006 NN_PAPR: 1.7089 Alg_PAPR: 2.1211 SER_dff: 0.9344 NN_SER: 0.9344 Alg_SER: 0.0000 LR is 0.001: 100%|█| 7813

Training has completed epoch 1



Epoch [1/10] V Loss:0.0080 PAPR_dff: 0.0000 NN_PAPR: 1.6889 Alg_PAPR: 2.1117 SER_dff: 0.9219 NN_SER: 0.9219 Alg_SER: 0.0000: 100%|█| 1563/1563 [00:24

Best Validation Loss is now: 0.0080 at Epoch: 1



Epoch [2/10] T Loss:0.1018 PAPR_dff: 0.0000 NN_PAPR: 1.7370 Alg_PAPR: 2.0894 SER_dff: 0.9375 NN_SER: 0.9438 Alg_SER: 0.0063 LR is 0.001: 100%|█| 7813

Training has completed epoch 2



Epoch [2/10] V Loss:0.0079 PAPR_dff: 0.0000 NN_PAPR: 1.7446 Alg_PAPR: 2.1117 SER_dff: 0.9219 NN_SER: 0.9219 Alg_SER: 0.0000: 100%|█| 1563/1563 [00:27

Best Validation Loss is now: 0.0079 at Epoch: 2



Epoch [3/10] T Loss:0.1025 PAPR_dff: 0.0000 NN_PAPR: 1.7252 Alg_PAPR: 2.1503 SER_dff: 0.9219 NN_SER: 0.9219 Alg_SER: 0.0000 LR is 0.001: 100%|█| 7813

Training has completed epoch 3



Epoch [3/10] V Loss:0.0127 PAPR_dff: 0.0000 NN_PAPR: 1.7223 Alg_PAPR: 2.1117 SER_dff: 0.9500 NN_SER: 0.9531 Alg_SER: 0.0031: 100%|█| 1563/1563 [00:31

This is Epoch: 1 without improvement
Current Validation Loss is: 0.0127 at Epoch: 3
Best Validation Loss remains: 0.0079 at Epoch: 2



Epoch [4/10] T Loss:0.1030 PAPR_dff: 0.0000 NN_PAPR: 1.7179 Alg_PAPR: 2.1090 SER_dff: 0.8937 NN_SER: 0.8938 Alg_SER: 0.0000 LR is 0.001: 100%|█| 7813

Training has completed epoch 4



Epoch [4/10] V Loss:0.0119 PAPR_dff: 0.0000 NN_PAPR: 1.7222 Alg_PAPR: 2.1117 SER_dff: 0.9344 NN_SER: 0.9375 Alg_SER: 0.0031: 100%|█| 1563/1563 [00:25

This is Epoch: 2 without improvement
Current Validation Loss is: 0.0119 at Epoch: 4
Best Validation Loss remains: 0.0079 at Epoch: 2



Epoch [5/10] T Loss:0.1051 PAPR_dff: 0.0054 NN_PAPR: 1.7152 Alg_PAPR: 2.1028 SER_dff: 0.9203 NN_SER: 0.9234 Alg_SER: 0.0031 LR is 0.001:   2%| | 135/

KeyboardInterrupt: 

In [39]:
test_losses = []
running_test_loss = 0.0


progress_bar_test = tqdm(enumerate(test_loader), total=len(test_loader), ncols=150)
for index, (algorithm_input_mat, algorithm_output_mat, main_channels_mat, symbols_store_mat) in progress_bar_test:
        
    main_channels_mat_for_nn = torch.stack([torch.real(main_channels_mat).float(), torch.imag(main_channels_mat).float()], dim=1)
    algorithm_input_mat_for_nn = (complex_to_interleaved_real(algorithm_input_mat))
    algorithm_output_mat_for_nn = (complex_to_interleaved_real(algorithm_output_mat))
    
    with torch.no_grad():
            
        nn_output = best_model(main_channels_mat_for_nn, algorithm_input_mat_for_nn)

        # Calculate losses
        test_loss = loss(nn_output, algorithm_output_mat_for_nn)

        # Update running loss
        running_test_loss += test_loss.item()
            
        avg_test_loss = running_test_loss / (index + 1)

        nn_output_control =  interleaved_real_to_complex(nn_output)
        algorithm_output_mat_for_nn_control =  interleaved_real_to_complex(algorithm_output_mat_for_nn)
        
        papr_diff, nn_papr, alg_papr = papr_loss(nn_output_control, algorithm_output_mat_for_nn_control)
        ser_diff, nn_ser, alg_ser = ser_loss(algorithm_input_mat, algorithm_output_mat_for_nn_control, nn_output_control, main_channels_mat, symbols_store_mat)

        progress_bar_test.set_description(f'Epoch [{epoch + 1}/{EPOCHS}] Te Loss:{avg_test_loss:.4f} PAPR_dff: {papr_diff:.4f} SER_dff: {ser_diff:.4f}')

        
        if index < 1:
            total_nn_out = nn_output_control
            total_alg_in = algorithm_input_mat
            total_alg_out = algorithm_output_mat_for_nn_control
            total_main_channels = main_channels_mat
            total_symbols = symbols_store_mat
        else:
            total_nn_out = torch.cat([total_nn_out, nn_output_control], dim=0, out=None)
            total_alg_in = torch.cat([total_alg_in, algorithm_input_mat], dim=0, out=None)
            total_alg_out = torch.cat([total_alg_out, algorithm_output_mat_for_nn_control], dim=0, out=None)
            total_main_channels = torch.cat([total_main_channels, main_channels_mat], dim=0, out=None)
            total_symbols = torch.cat([total_symbols, symbols_store_mat], dim=0, out=None)

test_losses.append(avg_test_loss)


 0%|                                                                                                                        | 0/1563 [00:00<?, ?it/s]

TypeError: 'NoneType' object is not callable

In [119]:
total_alg_in_real, total_alg_in_imag, total_alg_out_real, total_alg_out_imag, total_nn_out_real, total_nn_out_imag, total_main_channels_real, total_main_channels_imag, total_symbols = prepare_for_matlab(total_alg_in, total_alg_out, total_nn_out, total_main_channels, total_symbols)

In [120]:
# Save all variables in a dictionary
savemat("output_from_pytorch.mat", {
    "total_alg_in_real": total_alg_in_real,
    "total_alg_in_imag": total_alg_in_imag,
    "total_alg_out_real": total_alg_out_real,
    "total_alg_out_imag": total_alg_out_imag,
    "total_nn_out_real": total_nn_out_real,
    "total_nn_out_imag": total_nn_out_imag,
    "total_main_channels_real": total_main_channels_real,
    "total_main_channels_imag": total_main_channels_imag,
    "total_symbols": total_symbols
})