# In [None]:
# -*- coding: utf-8 -*-
"""
Created on Thu Jun  1 10:25:26 2023

@author: ZLi27
"""
# import tensorflow as tf
# from tensorflow import keras
import scipy.io as scio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import numpy as np
from tqdm import tqdm

# Per-sample CSI dimensions: receive antennas, transmit antennas, subcarriers.
Rx_num, Tx_num, SC_num = 4, 32, 128

def DownPrecoding(channel_est):  ### Optional
    """Derive one precoding vector per subcarrier from an estimated channel.

    Args:
        channel_est: data reshapeable to (-1, Rx_num, Tx_num, SC_num, 2), where the
            last axis holds the real and imaginary parts.

    Returns:
        Complex array of shape (batch, SC_num, Tx_num, 1). For every sample and
        subcarrier it is the conjugate of the first row of ``Vh`` from the SVD of the
        Rx x Tx channel matrix. That is the right singular vector of the largest
        singular value, because NumPy returns singular values in descending order.
    """
    # Split real/imag planes and rebuild complex (batch, subcarrier, Rx, Tx) matrices.
    planes = np.reshape(channel_est, (-1, Rx_num, Tx_num, SC_num, 2))
    h = planes[..., 0] + 1j * planes[..., 1]
    h = np.transpose(h, [0, 3, 1, 2])

    # Batched SVD over the trailing (Rx, Tx) axes; only Vh is needed.
    vh = np.linalg.svd(h, full_matrices=False)[2]
    dominant = np.conj(vh[:, :, 0, :])  # dominant right singular vector (MRT)
    return np.reshape(dominant, (-1, SC_num, Tx_num, 1))

def EqChannelGain(channel, PrecodingVector):
    """Compute the effective channel gain of a precoder on the true channel.

    Args:
        channel: true CSI, reshapeable to (-1, Rx_num, Tx_num, SC_num, 2) with the
            last axis holding real/imag parts.
        PrecodingVector: per-subcarrier precoders, presumably shaped
            (batch, SC_num, Tx_num, 1) as produced by ``DownPrecoding``. TODO confirm
            for other callers.

    Returns:
        Array of shape (batch, SC_num) holding |(H w)^H (H w)| = ||H w||^2, where
        each precoder w is first scaled to unit norm.
    """
    # True channel as complex (batch, subcarrier, Rx, Tx) matrices.
    planes = np.reshape(channel, (-1, Rx_num, Tx_num, SC_num, 2))
    H = np.transpose(planes[..., 0] + 1j * planes[..., 1], [0, 3, 1, 2])

    # Normalize each precoder: w / sqrt(w^H w).
    w_herm = np.transpose(np.conj(PrecodingVector), (0, 1, 3, 2))
    w = PrecodingVector / np.sqrt(np.matmul(w_herm, PrecodingVector))

    # Received signal vector H w and its squared norm per subcarrier.
    Hw = np.matmul(H, w)
    gain = np.matmul(np.transpose(np.conj(Hw), (0, 1, 3, 2)), Hw)
    return np.reshape(np.absolute(gain), (-1, SC_num))

def DataRate(h_sub_gain, sigma2_UE):  ### Score
    """Average achievable rate log2(1 + gain / sigma2_UE).

    The rate is averaged over the last axis (subcarriers), then over all remaining
    entries (CSI samples).

    Args:
        h_sub_gain: channel gains, e.g. the (batch, SC_num) output of
            ``EqChannelGain``. Either a ``torch.Tensor`` or anything NumPy accepts.
        sigma2_UE: noise power that the gain is divided by to form the SNR.

    Returns:
        A 0-d ``torch.Tensor`` for tensor input, otherwise a NumPy scalar.
    """
    # The original mixed np.log2 with torch.mean. torch.mean raises TypeError on the
    # NumPy arrays EqChannelGain normally returns, so each input type now stays in
    # its own library.
    if isinstance(h_sub_gain, torch.Tensor):
        Rate = torch.log2(1 + h_sub_gain / sigma2_UE)
        Rate_OFDM = torch.mean(Rate, dim=-1)  ### averaging over subcarriers
        return torch.mean(Rate_OFDM)  ### averaging over CSI samples

    Rate = np.log2(1 + np.asarray(h_sub_gain) / sigma2_UE)
    Rate_OFDM = np.mean(Rate, axis=-1)  ### averaging over subcarriers
    return np.mean(Rate_OFDM)  ### averaging over CSI samples


##############################################Resnet9D################################################
import torch
import torch.nn as nn
import torch.nn.functional as F

class Mish(nn.Module):
    """Element-wise Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        gate = F.softplus(x).tanh()
        return x * gate

class ResidualBlock(nn.Module):
    """Basic 1-D residual block with BatchNorm and Mish activations.

    Main path: 3-tap conv (with ``stride``) -> BN -> Mish -> 3-tap conv -> BN.
    Shortcut: identity, or a 1x1 strided conv + BN when the stride or the channel
    count changes. The sum of both paths goes through a final Mish.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Construction order is kept so parameter initialization consumes the RNG
        # identically and state_dict keys stay the same.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.mish1 = Mish()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.mish2 = Mish()

        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        residual = self.shortcut(x)
        h = self.mish1(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return self.mish2(h + residual)

class ResNet9D(nn.Module):
    """Small 1-D residual network regressing ``num_classes`` values from a feature vector.

    The input of shape (batch, input_dim) is treated as a length-1 sequence with
    ``input_dim`` channels. It passes through a stem convolution and four
    single-block residual stages (64 -> 32 -> 128 -> 256 -> 512 channels), is
    globally average-pooled, and is mapped by a linear layer to ``num_classes``
    outputs.

    Args:
        input_dim: number of input features (Conv1d input channels).
        num_classes: size of the final linear layer.
        output_shape: per-sample shape the flat output is reshaped to. It is applied
            only when its element count equals ``num_classes``. Otherwise the output
            is returned flat as (batch, num_classes).
    """

    def __init__(self, input_dim=3, num_classes=2, output_shape=(4, 32, 128, 2)):
        super(ResNet9D, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.mish1 = Mish()
        self.layer1 = self._make_layer(64, 32, 1)
        self.layer2 = self._make_layer(32, 128, 2)
        self.layer3 = self._make_layer(128, 256, 2)
        self.layer4 = self._make_layer(256, 512, 2)
        self.linear = nn.Linear(512, num_classes)
        # The old forward always did view(-1, 4, 32, 128, 2). That crashed, or merged
        # several samples into one, for any num_classes other than 4*32*128*2 (for
        # example the default num_classes=2). Reshape only when the sizes agree.
        if output_shape is not None and int(np.prod(output_shape)) == num_classes:
            self.output_shape = tuple(output_shape)
        else:
            self.output_shape = None

    def _make_layer(self, in_channels, out_channels, stride):
        # NOTE: the ``in_channels`` argument is unused. The running self.in_channels is
        # used instead; it equals the argument for every call made in __init__.
        layers = [ResidualBlock(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(2)  # (batch, input_dim) -> (batch, input_dim, 1)
        out = self.mish1(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. avg_pool1d(out, 1) was a no-op; the two agree for
        # the length-1 sequences produced above.
        out = F.adaptive_avg_pool1d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        if self.output_shape is not None:
            out = out.view(-1, *self.output_shape)
        return out


class ANN_TypeI(nn.Module):
    """CSI estimator: wraps a ResNet9D with a 4*32*128*2-wide regression head."""

    def __init__(self, input_dim=3):
        super(ANN_TypeI, self).__init__()
        head_size = 4 * 32 * 128 * 2
        self.res = ResNet9D(input_dim=input_dim, num_classes=head_size)

    def forward(self, x):
        return self.res(x)

class RadioMap_Model_TypeI(nn.Module): ### Generate RadioMapI (Input:location, Output:beamforming vector)
    """Radio map of type I: maps a location to per-subcarrier precoding vectors.

    ``self.ann`` estimates the CSI from the location. ``DownPrecoding`` then turns
    that estimate into precoders with a NumPy SVD.
    """

    def __init__(self, input_dim=3):
        super(RadioMap_Model_TypeI, self).__init__()
        self.ann = ANN_TypeI(input_dim) ## Neural Network (input:location, output: the estimated CSI)

    def forward(self, p):
        """Return the precoders for locations ``p`` as a tensor built from NumPy output.

        The precoders are computed in NumPy, so no gradient flows back to
        ``self.ann`` through this output.
        """
        CSI_est = self.ann(p)
        # detach() is required. Converting a tensor that requires grad to NumPy raises
        # RuntimeError, so the old `.cpu()` call only worked under torch.no_grad().
        PrecodingVector = torch.tensor(DownPrecoding(CSI_est.detach().cpu().numpy()))
        return PrecodingVector


def get_device():
    """Pick the compute device name, preferring CUDA, then Apple MPS, then CPU.

    Prints the choice and returns it as a string usable with ``.to(device)``.
    """
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using {device} device")
    return device

def get_dataset_iter(features, label, device, batch_size=128, shuffle=True, validation_split=0.1):
    """Split (features, label) into train/validation DataLoaders.

    The validation part holds ``int(len(features) * validation_split)`` samples. The
    split uses a generator seeded with 42, so it is reproducible across runs. Both
    tensors are copied onto ``device``. Only the training loader honours ``shuffle``.
    """
    n_total = len(features)
    n_val = int(n_total * validation_split)
    x = torch.tensor(features, device=device)
    y = torch.tensor(label, device=device)
    full_set = Data.TensorDataset(x, y)
    split_rng = torch.Generator().manual_seed(42)
    train_part, val_part = Data.random_split(full_set, lengths=[n_total - n_val, n_val], generator=split_rng)
    return (
        torch.utils.data.DataLoader(train_part, batch_size=batch_size, shuffle=shuffle),
        torch.utils.data.DataLoader(val_part, batch_size=batch_size, shuffle=False),
    )

def get_tensor(features, device):
    """Copy ``features`` into a new tensor on ``device`` (dtype inferred by torch)."""
    return torch.tensor(features, device=device)


def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Predictions and targets are cast to float64 before ``loss_fn``. Every 100th
    batch, the batch loss and the approximate number of samples processed are
    printed.
    """
    n_samples = len(dataloader.dataset)
    model.train()
    for step, (inputs, targets) in enumerate(tqdm(dataloader)):
        prediction = model(inputs)
        batch_loss = loss_fn(prediction.to(torch.float64), targets.to(torch.float64))

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100 == 0:
            seen = (step + 1) * len(inputs)
            print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{n_samples:>5d}]")

def verify(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the mean of the per-batch losses.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Predictions and targets are cast to float64 before ``loss_fn``.

    Raises:
        ValueError: if the dataloader yields no batches, e.g. when a validation
            split of ``int(size * validation_split)`` rounds down to zero samples.
            Previously this surfaced as a bare ZeroDivisionError.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("verify() received an empty dataloader; increase validation_split or the dataset size")
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in tqdm(dataloader):
            pred = model(X)
            test_loss += loss_fn(pred.to(torch.float64), y.to(torch.float64)).item()
    test_loss /= num_batches
    print(f"Test Error: Avg loss: {test_loss:>8f} \n")
    return test_loss

def inference(tensor, model):
    """Run ``model`` on ``tensor`` without autograd and return the output as NumPy.

    The model is left in eval mode afterwards.
    """
    model.eval()
    with torch.no_grad():
        host_output = model(tensor).cpu()
    return host_output.numpy()


## device: cuda > mps > cpu
device = get_device()

## Parameters (same values as the module-level constants at the top of the file)
SC_num, Tx_num, Rx_num = 128, 32, 4  ## subcarriers, Tx antennas, Rx antennas
sigma2_UE = 0.1  ## noise power; DataRate(..., sigma2_UE) divides the gain by it

#### Read Data
f = scio.loadmat('./train.mat')
data = f['train']
location_data = data[0, 0]['loc']
channel_data = data[0, 0]['CSI']
print(location_data.shape)
print(channel_data.shape)

## basic train parameters
loss_fn = nn.MSELoss()
EPOCHS, LR, batch_size = 1000, 5e-4, 128

########################################################   Scheme 1   ########################################################
## ANN for CSI estimation
train_dataset_iter, val_dataset_iter = get_dataset_iter(location_data, channel_data, device, batch_size=batch_size)

# from VIT import ANN_TypeI
input_dim = 3
net_type1 = ANN_TypeI(input_dim).to(device)
opt_adam_type1 = torch.optim.Adam(net_type1.parameters(), lr=LR)
# Validation loss is an MSE (lower is better). The old code started from 0 and saved
# whenever `test_loss > best_loss`, i.e. it checkpointed the model each time it got
# WORSE and reported max(test_losses) as the best.
best_loss = float("inf")
best_epoch = None  # 1-based epoch at which best_loss was reached
test_losses = []
RadioMap_TypeI = RadioMap_Model_TypeI(input_dim=input_dim)
RadioMap_TypeI.to(device)
# Share the network being trained with the wrapper, so its state_dict always holds the
# current weights when it is saved.
RadioMap_TypeI.ann = net_type1

for i in range(EPOCHS):
    print(f"Epoch {i+1}\n-------------------------------")
    train(train_dataset_iter, net_type1, loss_fn, opt_adam_type1)
    test_loss = verify(val_dataset_iter, net_type1, loss_fn)
    test_losses.append(test_loss)
    if test_loss < best_loss:
        best_loss = test_loss
        best_epoch = i + 1
        torch.save(RadioMap_TypeI.state_dict(), "./model.pth")
        print("Saved PyTorch Model State to model.pth")
        # torch.save(net_type1.state_dict(), f'csi_pre.pth')
    print(f"the best test loss is {best_loss}, epoch {best_epoch}")
print("Done!")





