In [1]:
# This is the method that uses the MATLAB Engine API for Python
import matlab.engine
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import  models, datasets, transforms
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import timm
import pickle
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler, MinMaxScaler
import numpy as np
import scipy.io as scio
from scipy.io import savemat
import h5py
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import gc

In [2]:
main_channels = scio.loadmat('main_channels_single.mat')
main_channels_mat = main_channels['main_channels']

In [3]:
class CustomDataset(Dataset):
    def __init__(self, main_channels_mat):
        # convert into PyTorch tensors and remember them
        self.main_channels_mat = main_channels_mat

    def __len__(self):
        # this should return the size of the dataset
        return len(self.main_channels_mat)
    
    def __getitem__(self, idx):
        # this should return one sample from the dataset
        main_channels_mat = self.main_channels_mat[idx,:,:]
        return main_channels_mat

In [4]:
dataset = CustomDataset(main_channels_mat)

In [5]:
# Create DataLoaders
batch_size = 64
test_loader = DataLoader(dataset, shuffle=False, batch_size=batch_size)


In [6]:
batch_main_chan_mat = next(iter(test_loader))

print(f'Shape of batch feature is {batch_main_chan_mat.shape}')
print(f'Data type of batch feature is {batch_main_chan_mat.dtype}')

Shape of batch feature is torch.Size([64, 10, 70])
Data type of batch feature is torch.complex64


class CSIEncoder(nn.Module):
    def __init__(self):
        super(CSIEncoder, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv2 = nn.BatchNorm2d(32)

        self.flatten = nn.Flatten()
        
        self.linear1 = nn.Linear(1728, 280)
        self.bnlin1 = nn.BatchNorm1d(280)
        
        self.linear2 = nn.Linear(280, 140)



    def forward(self, x):
        #x = x.unsqueeze(1)
        x = F.relu(self.bnconv1(self.conv1(x)))
        x = F.relu(self.bnconv2(self.conv2(x)))
        x = self.flatten(x)
        x = F.relu(self.bnlin1(self.linear1(x)))
        x = self.linear2(x)

        return x

In [29]:
class CSIEncoder(nn.Module):
    def __init__(self):
        super(CSIEncoder, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=3, stride=2, padding=(2,0))
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bnconv3 = nn.BatchNorm2d(128)

        self.flatten = nn.Flatten()
        
        self.linear1 = nn.Linear(128*2*9, 280)
        self.bnlin1 = nn.BatchNorm1d(280)
        
        self.linear2 = nn.Linear(280, 140)



    def forward(self, x):
        #x = x.unsqueeze(1)
        x1 = F.relu(self.bnconv1(self.conv1(x)))
        x2 = F.relu(self.bnconv2(self.conv2(x1)))
        x3 = F.relu(self.bnconv3(self.conv3(x2)))
        x = self.flatten(x3)
        x4 = F.relu(self.bnlin1(self.linear1(x)))
        x_encoded = self.linear2(x4)

  
        return x_encoded, x4, x3, x2, x1

In [30]:
test_output_enc = CSIEncoder()(torch.rand([64,2,10,70]))
test_output_enc

(tensor([[-0.4682, -0.1584, -0.1964,  ..., -0.0342,  0.1277, -0.3639],
         [ 0.3153,  0.2861,  0.1187,  ..., -0.4069,  0.2458,  0.1366],
         [ 0.1136, -0.4356,  0.3852,  ..., -0.1684,  0.2199,  0.0412],
         ...,
         [ 0.4471,  0.4414, -0.2760,  ..., -0.7024,  0.1327,  0.1739],
         [-0.0367,  0.3105, -0.0724,  ..., -0.4543, -0.3287,  0.3006],
         [ 0.5778,  0.4133, -0.1631,  ..., -0.2400,  0.7636,  0.2785]],
        grad_fn=<AddmmBackward0>),
 tensor([[0.2436, 1.5620, 0.5737,  ..., 0.7832, 0.9754, 0.0655],
         [0.0000, 0.0000, 0.0000,  ..., 0.1173, 0.0000, 0.0000],
         [0.0000, 0.2173, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 1.8614, 0.5286,  ..., 0.0000, 0.0000, 0.5341],
         [0.0000, 0.0000, 0.4026,  ..., 0.0000, 1.0744, 0.0000],
         [0.0734, 0.0000, 0.0509,  ..., 0.0000, 1.0900, 0.2804]],
        grad_fn=<ReluBackward0>),
 tensor([[[[0.3678, 0.5906, 0.8336,  ..., 1.1359, 0.7094, 1.5801],
           [1.3286

class CSIDecoder(nn.Module):
    def __init__(self):
        super(CSIDecoder, self).__init__()
        
        self.linear1 = nn.Linear(140, 280)
        self.bnlin1 = nn.BatchNorm1d(280)
        
        self.linear2 = nn.Linear(280, 1728)
        self.bnlin2 = nn.BatchNorm1d(1728)

        self.unflatten = nn.Unflatten(1,[32, 3, 18])
        
        self.convT1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bnconv1 = nn.BatchNorm2d(32)

        self.convT2 = nn.ConvTranspose2d(in_channels=32, out_channels=2, kernel_size=3, stride=2, padding=2, output_padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)


    def forward(self, x):
        #x = x.unsqueeze(1)
        x = F.relu(self.bnlin1(self.linear1(x)))
        x = F.relu(self.bnlin2(self.linear2(x)))
        x = self.unflatten(x)
        x = F.relu(self.bnconv1(self.convT1(x)))
        x = self.convT2(x)
       
        return x

In [31]:
class CSIDecoder(nn.Module):
    def __init__(self):
        super(CSIDecoder, self).__init__()

        self.encoder = CSIEncoder()
        
        self.linear1 = nn.Linear(140, 280)
        self.bnlin1 = nn.BatchNorm1d(280)
        
        self.linear2 = nn.Linear(280, 128*2*9)
        self.bnlin2 = nn.BatchNorm1d(128*2*9)

        self.unflatten = nn.Unflatten(1,[128, 2, 9])
        
        self.convT1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bnconv1 = nn.BatchNorm2d(64)
        
        self.convT2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=2, output_padding=1)
        self.bnconv2 = nn.BatchNorm2d(32)

        self.convT3 = nn.ConvTranspose2d(in_channels=32, out_channels=2, kernel_size=3, stride=2, padding=2, output_padding=1)
        self.bnconv3 = nn.BatchNorm2d(2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.convT4 = nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=(1,5), stride=1)

    def forward(self, x_enc_input):
        #x = x.unsqueeze(1)
        x_encoded, x4, x3, x2, x1 = self.encoder(x_enc_input)
        
        x = F.relu(self.bnlin1(self.linear1(x_encoded)))
        x = F.relu(self.bnlin2(self.linear2(x + x4)))
        #x = F.relu(self.bnlin2(self.linear2(torch.cat([x,x4], dim=1))))
        x = self.unflatten(x)
        x = F.relu(self.bnconv1(self.convT1(x + x3)))
        #x = F.relu(self.bnconv1(self.convT1(torch.cat([x,x3], dim=1))))
        x = F.relu(self.bnconv2(self.convT2(x + x2)))
        #x = F.relu(self.bnconv2(self.convT2(torch.cat([x,x2], dim=1))))
        x = F.relu(self.bnconv3(self.convT3(x + x1)))
        #x = F.relu(self.bnconv3(self.convT3(torch.cat([x,x1],dim=1))))
        x = self.convT4(x)
       
        return x

In [32]:
test_output_dec = CSIDecoder()(torch.rand([64,2,10,70]))
test_output_dec.shape

torch.Size([64, 2, 10, 70])

In [33]:
class AutoEncode(nn.Module):
    def __init__(self):
        super(AutoEncode, self).__init__()
        
        self.encoder = CSIEncoder()
        self.decoder = CSIDecoder()


    def forward(self, x):
        #x = x.unsqueeze(1)
        x = self.encoder(x)
        x = self.decoder(x)

        return x

In [34]:
class AutoEncode(nn.Module):
    def __init__(self):
        super(AutoEncode, self).__init__()
        
        self.decoder = CSIDecoder()


    def forward(self, x_enc_input):
        x = self.decoder(x_enc_input)

        return x

In [35]:
test_output_auto = AutoEncode()(torch.rand([64,2,10,70]))
test_output_auto.shape

torch.Size([64, 2, 10, 70])

In [36]:
best_encoder_model = CSIEncoder()
best_encoder_model.load_state_dict(torch.load("best_encoder_weights_with_power_loss.pth", weights_only=True))

<All keys matched successfully>

In [37]:
best_decoder_model = CSIDecoder()
best_decoder_model.load_state_dict(torch.load("best_decoder_weights_with_power_loss.pth", weights_only=True))

<All keys matched successfully>

encoder_path = "best_encoder_with_power_loss.pth"
best_encoder_model = torch.load(encoder_path, weights_only=False)

decoder_path = "best_decoder_with_power_loss.pth"
best_decoder_model = torch.load(decoder_path, weights_only=False)

In [38]:
test_losses = []
running_test_loss = 0.0

loss = torch.nn.MSELoss()  # For classification

progress_bar_test = tqdm(enumerate(test_loader), total=len(test_loader), ncols=100)
for index, (main_channels_mat) in progress_bar_test:
        
    main_channels_mat_for_nn = torch.stack([torch.real(main_channels_mat).float(), torch.imag(main_channels_mat).float()], dim=1)
    
    with torch.no_grad():
            
        nn_output = best_decoder_model(main_channels_mat_for_nn)

        # Calculate losses
        test_loss = loss(nn_output, main_channels_mat_for_nn)

        # Update running loss
        running_test_loss += test_loss.item()
            
        avg_test_loss = running_test_loss / (index + 1)

        progress_bar_test.set_description(f'Test Loss:{avg_test_loss:.4f}')

        if index < 1:
            total_main_channels = nn_output
        else:
            total_main_channels = torch.cat([total_main_channels, nn_output], dim=0, out=None)


test_losses.append(avg_test_loss)


est Loss:0.1330: 100%|█████████████████████████████████████████| 1563/1563 [01:06<00:00, 23.67it/s]

In [39]:
total_main_channels.shape

torch.Size([100000, 2, 10, 70])

In [40]:
# Save all variables in a dictionary
savemat("output_from_pytorch_channel_test_power_loss.mat", {
    "total_main_channels_test": total_main_channels
})