# import package

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from scipy.linalg import orth

import random
from scipy import linalg
from scipy import signal
import time
import scipy
import math
import h5py
from tqdm.auto import tqdm
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False


# Pin the NumPy and PyTorch random streams so repeated runs draw the same data.
# NOTE: the stdlib `random` module imported above is not seeded here.
myseed = 42069
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(myseed)
torch.manual_seed(myseed)
torch.cuda.manual_seed_all(myseed)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


print('The device used is', device)

# construct training data

In [4]:
# --- Scenario constants shared by the generators and estimators below ---
d = 5   # number of targets (sources)
m = 8   # number of array elements
snr = 10   # SNR value handed to the signal generators
snapshots = 200   # samples per observation

# Offsets / variance scalings used by construct_signal and
# construct_coherent_signal.
mean_signal_power = 0
var_signal_power = 1
mean_noise = 0
var_noise = 1

# Example directions in [-pi/2, pi/2) and complex amplitudes.  The draws stay
# in the original order (rand, randn, randn) so the seeded stream is unchanged.
doa = (np.random.rand(d) - 1/2) * np.pi
p = np.sqrt(1) * (np.random.randn(d) + np.random.randn(d) * 1j)

# Element positions 0, 1, ..., m-1.
array = np.linspace(0, m, m, endpoint=False)

# Scan grid of 360 angles in [-pi/2, pi/2), stored as a single row: (1, 360).
angles = np.linspace(- np.pi/2, np.pi/2, 360, endpoint=False)[np.newaxis, :]
r = angles.shape[1]   # number of grid points



#*******************************************#
#   uniform linear array steering vector    #
#*******************************************#
def ULA_action_vector(theta, angle=None):
    """Return the steering vector of a uniform linear array.

    Element ``k`` contributes ``exp(-1j * pi * k * sin(angle))``.

    Both call forms that appear in this file are supported:

    * ``ULA_action_vector(theta)`` -- the element positions default to
      ``0, 1, ..., m-1`` built from the module-level ``m``.
    * ``ULA_action_vector(array, theta)`` -- explicit element positions
      followed by the angle.  This form previously raised ``TypeError``
      because the function only accepted one argument.

    Args:
        theta: the angle in radians (one-argument form) or the element
            positions (two-argument form).
        angle: the angle in radians when positions are passed first.

    Returns:
        Complex numpy array with one entry per array element.
    """
    if angle is None:
        positions = np.linspace(0, m, m, endpoint=False)
    else:
        # Two-argument form: the first positional value is the geometry.
        positions, theta = np.asarray(theta), angle
    return np.exp(- 1j * np.pi * positions * np.sin(theta))



#***********************#
#   construct signals   #
#***********************#
def construct_signal(thetas,snr,snapshots):
    """Simulate array observations of independent sources plus noise.

    Args:
        thetas: source directions in radians, one per source.
        snr: value used to scale the source amplitude.
        snapshots: number of time samples to generate.

    Returns:
        ``(observations, sources)`` where ``observations`` has shape
        ``(m, snapshots)`` and ``sources`` has shape
        ``(len(thetas), snapshots)``; both are complex.
    """
    d = len(thetas)

    # NOTE(review): the factor 10**(snr/10) multiplies the amplitude; if `snr`
    # is meant as a power ratio in dB the usual factor would be 10**(snr/20)
    # -- confirm before relying on the nominal SNR.
    amplitude = np.sqrt(var_signal_power) * (10 ** (snr / 10))
    sources = amplitude * (np.random.randn(d, snapshots) + 1j * np.random.randn(d, snapshots)) + mean_signal_power

    # One steering vector per source, stacked as rows: (d, m).
    steering = np.array([ULA_action_vector(direction) for direction in thetas])

    noise = np.sqrt(var_noise) * (np.random.randn(m, snapshots) + 1j * np.random.randn(m, snapshots)) + mean_noise

    observations = steering.T @ sources + noise
    return observations, sources



#********************************#
#   construct coherent signals   #
#********************************#
def construct_coherent_signal(thetas,snr,snapshots):
    """Simulate array observations where every source emits the same waveform.

    A single random waveform is drawn and copied to all sources, so the
    sources are fully coherent.

    Args:
        thetas: source directions in radians, one per source.
        snr: value used to scale the source amplitude (same scaling as in
            construct_signal).
        snapshots: number of time samples to generate.

    Returns:
        ``(observations, sources)`` with shapes ``(m, snapshots)`` and
        ``(len(thetas), snapshots)``; both are complex.
    """
    d = len(thetas)

    amplitude = np.sqrt(var_signal_power) * (10 ** (snr / 10))
    waveform = amplitude * (np.random.randn(1, snapshots) + 1j * np.random.randn(1, snapshots)) + mean_signal_power

    # Every source row is a copy of the same waveform.
    sources = np.tile(waveform, (d, 1))

    steering = np.array([ULA_action_vector(direction) for direction in thetas])
    noise = np.sqrt(var_noise) * (np.random.randn(m, snapshots) + 1j * np.random.randn(m, snapshots)) + mean_noise

    observations = steering.T @ sources + noise
    return observations, sources




# ********************#
#   create dataset   #
# ********************#
def create_dataset(name, size,snr,snapshots ,coherent=False, save=True):
    """Generate `size` observations with `d` random sources and store them.

    This single definition replaces two near-identical copies; the later copy
    (which returned nothing) was the one actually in effect, so that
    behaviour is kept.  Use create_test_data to obtain the arrays in memory.

    Args:
        name: output path without the ``.h5`` suffix.
        size: number of observations to generate.
        snr: value forwarded to the signal generator.
        snapshots: samples per observation.
        coherent: use construct_coherent_signal instead of construct_signal.
        save: write datasets ``'X'`` (shape ``(size, m, snapshots)``, complex)
            and ``'Y'`` (shape ``(size, d)``, directions in radians) to
            ``name + '.h5'``.

    Returns:
        None.
    """
    generate = construct_coherent_signal if coherent else construct_signal

    # Allocate the complex buffer directly instead of summing two temporaries.
    X = np.zeros((size, m, snapshots), dtype=complex)
    Thetas = np.zeros((size, d))

    for i in tqdm(range(size)):
        thetas = np.pi * (np.random.rand(d) - 1 / 2)  # random source directions
        X[i] = generate(thetas,snr,snapshots)[0]
        Thetas[i] = thetas

    if save:
        # The context manager closes the file even if a write fails.
        with h5py.File(name + '.h5', 'w') as hf:
            hf.create_dataset('X', data=X)
            hf.create_dataset('Y', data=Thetas)



#***************************************#
#   create dataset with large variety   #
#***************************************#
def create_complete_dataset(name, size,snr, snapshots,num_sources=[d], coherent=False, save=True):
    """Generate observations with a varying number of sources and store them.

    Sample ``i`` uses ``num_sources[i % len(num_sources)]`` sources, so each
    listed count is used equally often.  Label rows have six slots; unused
    slots are filled with ``np.pi``.

    Args:
        name: output path without the ``.h5`` suffix.
        size: number of observations to generate.
        snr: value forwarded to the signal generator.
        snapshots: samples per observation.
        num_sources: source counts to cycle through (each at most 6).  The
            list is only read, never modified.
        coherent: use construct_coherent_signal instead of construct_signal.
        save: write datasets ``'X'`` and ``'Y'`` to ``name + '.h5'``.

    Returns:
        None.
    """
    max_sources = 6   # width of a label row
    generate = construct_coherent_signal if coherent else construct_signal

    # Allocate the complex buffer directly instead of summing two temporaries.
    X = np.zeros((size, m, snapshots), dtype=complex)
    Thetas = np.zeros((size, max_sources))

    for i in tqdm(range(size)):
        num = num_sources[i % len(num_sources)]   # create equal sized sets for each num. of sources
        thetas = np.pi * (np.random.rand(num) - 1/2)   # random source direction

        X[i] = generate(thetas,snr,snapshots)[0]
        Thetas[i] = np.pad(thetas, (0, max_sources - num), 'constant', constant_values=np.pi)

    if save:
        # The context manager closes the file even if a write fails.
        with h5py.File(name + '.h5', 'w') as hf:
            hf.create_dataset('X', data=X)
            hf.create_dataset('Y', data=Thetas)


#*****************************************************#
#   create dataset with varying sources and mixed SNR  #
#*****************************************************#
def create_complete_mixsnr_dataset(name,size,snapshots,num_sources=[d], coherent=False, save=True):
    """Generate observations with varying source counts and SNR values.

    Sample ``i`` uses ``num_sources[i % len(num_sources)]`` sources and the
    SNR value ``[0, 5, 10][i % 3]``.  Label rows have six slots; unused slots
    are filled with ``np.pi``.

    Args:
        name: output path without the ``.h5`` suffix.
        size: number of observations to generate.
        snapshots: samples per observation.
        num_sources: source counts to cycle through (each at most 6).  The
            list is only read, never modified.
        coherent: use construct_coherent_signal instead of construct_signal.
        save: write datasets ``'X'`` and ``'Y'`` to ``name + '.h5'``.

    Returns:
        None.
    """
    max_sources = 6   # width of a label row
    snr_list=[ 0, 5,10]
    generate = construct_coherent_signal if coherent else construct_signal

    # Allocate the complex buffer directly instead of summing two temporaries.
    X = np.zeros((size, m, snapshots), dtype=complex)
    Thetas = np.zeros((size, max_sources))

    for i in tqdm(range(size)):
        num = num_sources[i % len(num_sources)]   # create equal sized sets for each num. of sources
        thetas = np.pi * (np.random.rand(num) - 1/2)   # random source direction

        snr=snr_list[i % len(snr_list)]

        X[i] = generate(thetas,snr,snapshots)[0]
        Thetas[i] = np.pad(thetas, (0, max_sources - num), 'constant', constant_values=np.pi)

    if save:
        # The context manager closes the file even if a write fails.
        with h5py.File(name + '.h5', 'w') as hf:
            hf.create_dataset('X', data=X)
            hf.create_dataset('Y', data=Thetas)






# ***********************#
#   create test data    #
# ***********************#
def create_test_data(size,snr,snapshots ,coherent=False):
    """Generate `size` observations with `d` random sources, in memory only.

    Args:
        size: number of observations to generate.
        snr: value forwarded to the signal generator.
        snapshots: samples per observation.
        coherent: use construct_coherent_signal instead of construct_signal.

    Returns:
        ``(X, Thetas)``: complex array of shape ``(size, m, snapshots)`` and
        the matching directions in radians, shape ``(size, d)``.
    """
    generate = construct_coherent_signal if coherent else construct_signal

    X = np.zeros((size, m, snapshots), dtype=complex)
    Thetas = np.zeros((size, d))

    for row in tqdm(range(size)):
        # Directions are uniform in [-pi/2, pi/2).
        thetas = np.pi * (np.random.rand(d) - 1 / 2)
        X[row] = generate(thetas,snr,snapshots)[0]
        Thetas[row] = thetas

    return X, Thetas



def create_complete_test_dataset(size,snr, snapshots,num_sources=[d], coherent=False):
    """Generate observations with a varying number of sources, in memory only.

    Sample ``i`` uses ``num_sources[i % len(num_sources)]`` sources.  Label
    rows have six slots; unused slots are filled with ``np.pi``.

    Args:
        size: number of observations to generate.
        snr: value forwarded to the signal generator.
        snapshots: samples per observation.
        num_sources: source counts to cycle through.
        coherent: use construct_coherent_signal instead of construct_signal.

    Returns:
        ``(X, Thetas)``: complex array of shape ``(size, m, snapshots)`` and
        padded labels of shape ``(size, 6)``.
    """
    generate = construct_coherent_signal if coherent else construct_signal

    X = np.zeros((size, m, snapshots), dtype=complex)
    Thetas = np.zeros((size, 6))

    for row in range(size):
        count = num_sources[row % len(num_sources)]
        thetas = np.pi * (np.random.rand(count) - 1/2)

        X[row] = generate(thetas,snr,snapshots)[0]
        Thetas[row] = np.pad(thetas, (0, 6 - count), 'constant', constant_values=np.pi)

    return X, Thetas





def permutations(predDoA):
    """Return every ordering of the list `predDoA` as a list of lists.

    Orderings are produced in index order: all orderings starting with
    element 0 first, then those starting with element 1, and so on.  An empty
    input yields an empty list; a single-element input yields ``[predDoA]``.
    """
    count = len(predDoA)
    if count == 0:
        return []
    if count == 1:
        return [predDoA]

    orderings = []
    for position, head in enumerate(predDoA):
        # Everything except the chosen head, in original order.
        rest = predDoA[:position] + predDoA[position + 1:]
        orderings.extend([head] + tail for tail in permutations(rest))

    return orderings





#*******************************#
#   cluster small eigenvalues   #
#*******************************#
def cluster(evs):
    """Return the entries of `evs` whose magnitude is close to the last one.

    An entry is kept when ``abs(entry) < abs(evs[-1]) + threshold``.  The
    last entry is used as the reference, so callers are expected to pass the
    values with the smallest one at the end.
    """
    threshold = 1.25   # non-coherent
    # threshold = 0.1   # coherent
    cutoff = abs(evs[-1]) + threshold
    return evs[np.abs(evs) < cutoff]

#*********************************#
#   the classic MUSIC algorithm   #
#*********************************#
def classicMUSIC(incident, array, continuum, sources=None):
    """Estimate directions of arrival with the MUSIC pseudo-spectrum.

    Args:
        incident: complex observations, one row per array element and one
            column per snapshot.
        array: numpy array of element positions.
        continuum: 2-D array whose rows are scan grids of angles in radians
            (the last row determines the returned spectrum).
        sources: number of sources if known; when falsy it is estimated from
            the eigenvalues with `cluster`.

    Returns:
        ``(DoA, spectrum, d)``: grid indices of the selected peaks (ordered
        by increasing peak height), the pseudo-spectrum over the grid, and
        the number of sources used.
    """
    covariance = np.cov(incident)

    # The sample covariance is Hermitian, so eigh applies.  It returns real
    # eigenvalues in ascending order; flip to descending so that columns
    # `d:` really are the noise subspace (eig gives no ordering guarantee).
    eigenvalues, eigenvectors = linalg.eigh(covariance)
    eigenvalues = eigenvalues[::-1]
    eigenvectors = eigenvectors[:, ::-1]

    if sources:   # number of sources known
        d = sources
    else:
        n = cluster(eigenvalues).shape[0]   # multiplicity of the smallest eigenvalue
        d = array.shape[0] - n   # and get number of signal sources

    En = eigenvectors[:, d:]

    numSamples = continuum.shape[1]
    spectrum = np.zeros(numSamples)
    positions = np.asarray(array).reshape(-1, 1)
    for axis in continuum:
        # Steering vectors for the whole grid at once: (elements, numSamples).
        steering = np.exp(-1j * np.pi * positions * np.sin(axis))
        # a^H En En^H a equals ||En^H a||^2, which is real and non-negative.
        projected = En.conj().T @ steering
        spectrum = 1.0 / np.sum(np.abs(projected) ** 2, axis=0)

    DoA, _ = signal.find_peaks(spectrum)

    # Keep the d highest peaks (none when d is 0).
    ranked = np.argsort(spectrum[DoA])
    DoA = DoA[ranked[max(len(ranked) - d, 0):]]

    return DoA, spectrum,d



#*********************************#
#   MUSIC for one-bit data        #
#*********************************#
def One_bit_MUSIC(incident, array, continuum, sources=None):
    """Estimate directions of arrival with MUSIC from (one-bit) observations.

    Args:
        incident: complex observations, one row per array element and one
            column per snapshot.
        array: numpy array of element positions.
        continuum: 2-D array whose rows are scan grids of angles in radians
            (the last row determines the returned spectrum).
        sources: number of sources if known; when falsy it is estimated from
            the eigenvalues with `cluster`.

    Returns:
        ``(DoA, spectrum, d)``: grid indices of the selected peaks (ordered
        by increasing peak height), the pseudo-spectrum over the grid, and
        the number of sources used.
    """
    covariance = np.cov(incident)   # sample covariance matrix

    # NOTE(review): the previous code computed
    #     sin(pi/2 * covariance) + 1j * sin(covariance)
    # into a misspelled variable (`coveriance`) that was never read, so the
    # decomposition always ran on the plain sample covariance.  That effective
    # behaviour is kept here.  If an arcsine-law correction is wanted, it is
    # normally applied to the normalised covariance as
    #     sin(pi/2 * Re(R)) + 1j * sin(pi/2 * Im(R));
    # enabling it changes the eigenvalue scale that `cluster` thresholds on,
    # so confirm the intended formula and threshold first.

    # eigh returns real eigenvalues in ascending order for a Hermitian
    # matrix; flip to descending so columns `d:` are the noise subspace.
    eigenvalues, eigenvectors = linalg.eigh(covariance)
    eigenvalues = eigenvalues[::-1]
    eigenvectors = eigenvectors[:, ::-1]

    if sources:   # number of sources known
        d = sources
    else:
        n = cluster(eigenvalues).shape[0]   # estimate multiplicity of smallest eigenvalue...
        d = array.shape[0] - n   # and get number of signal sources

    En = eigenvectors[:, d:]

    numSamples = continuum.shape[1]
    spectrum = np.zeros(numSamples)
    positions = np.asarray(array).reshape(-1, 1)
    for axis in continuum:
        # Steering vectors for the whole grid at once: (elements, numSamples).
        steering = np.exp(-1j * np.pi * positions * np.sin(axis))
        # a^H En En^H a equals ||En^H a||^2, which is real and non-negative.
        projected = En.conj().T @ steering
        spectrum = 1.0 / np.sum(np.abs(projected) ** 2, axis=0)

    DoA, _ = signal.find_peaks(spectrum)

    # Keep the d highest peaks (none when d is 0).
    ranked = np.argsort(spectrum[DoA])
    DoA = DoA[ranked[max(len(ranked) - d, 0):]]

    return DoA, spectrum,d


#******************************#
#   the Beamformer algorithm   #
#******************************#
def beamformer(incident, array, continuum, sources=None):
    """Estimate directions of arrival with the conventional beamformer.

    The spectrum at each grid angle is ``|a^H R a| / ||a||^2`` where ``R`` is
    the sample covariance and ``a`` the steering vector.

    Args:
        incident: complex observations, one row per array element and one
            column per snapshot.
        array: element positions.
        continuum: 2-D array whose rows are scan grids of angles in radians
            (the last row determines the returned spectrum).
        sources: number of peaks to keep; when falsy all peaks are returned.

    Returns:
        ``(peaks, spectrum)``.  With `sources` given, the `sources` highest
        peaks ordered by increasing height; otherwise all peaks ordered by
        decreasing height.
    """
    covariance = np.cov(incident)

    numSamples = continuum.shape[1]
    spectrum = np.zeros(numSamples)
    positions = np.asarray(array).reshape(-1, 1)

    for axis in continuum:
        # Steering vectors for the whole grid at once: (elements, numSamples).
        steering = np.exp(-1j * np.pi * positions * np.sin(axis))
        # Column-wise quadratic form a^H R a.
        power = np.sum(steering.conj() * (covariance @ steering), axis=0)
        spectrum = np.abs(power) / np.sum(np.abs(steering) ** 2, axis=0)

    DoAsMUSIC, _ = signal.find_peaks(spectrum)

    if sources:
        DoAsMUSIC = DoAsMUSIC[np.argsort(spectrum[DoAsMUSIC])[-sources:]]
    else:
        DoAsMUSIC = DoAsMUSIC[np.argsort(- spectrum[DoAsMUSIC])]

    return DoAsMUSIC, spectrum

#***********************************#
#   mean minimal permutation rmse   #
#***********************************#
def mean_min_perm_rmse(predDoA, trueDoA):
    """Return the mean over samples of the permutation-minimal RMSE.

    For each sample every ordering of the predicted directions is compared
    with the true directions; differences are wrapped into
    ``[-pi/2, pi/2)`` and the smallest RMSE over all orderings is kept.

    Args:
        predDoA: predicted directions in radians, one row per sample.
        trueDoA: true directions in radians, shape ``(samples, sources)``.

    Returns:
        The average of the per-sample minimal RMSE values.
    """
    import itertools   # local import: avoids np.math, removed in NumPy 2.0

    trueDoA = np.asarray(trueDoA)
    num_samples = trueDoA.shape[0]

    allMSE = np.zeros(num_samples)
    for i in range(num_samples):
        # All orderings of this sample's predictions: (k!, k).
        candidates = np.array(list(itertools.permutations(predDoA[i])))
        if candidates.size == 0:
            continue   # no sources: the error stays 0

        diff = ((candidates - trueDoA[i]) + np.pi / 2) % np.pi - np.pi / 2
        allMSE[i] = np.amin(np.mean(diff ** 2, axis=1) ** (1 / 2))

    return np.mean(allMSE)






    

# generate training data

In [None]:
# Sizes of the generated splits.
num_train, num_val, num_val_snr = 200000, 20000, 5000

# Source counts cycled through in the mixed sets.
_mixed_source_counts = [2,3,4,5]

# Training and validation sets at the default snr / snapshots.
create_complete_dataset('train_nocoherent_snr10_200k', num_train, snr,snapshots,num_sources=_mixed_source_counts, coherent=False)
create_complete_dataset('validation_nocoherent_snr10_20k', num_val, snr,snapshots,num_sources=_mixed_source_counts, coherent=False)
# Extra validation set with 100 snapshots.
create_complete_dataset('vali(100)_nocoherent_snr10_5k', num_val_snr, 10,100,num_sources=_mixed_source_counts, coherent=False)
# Extra validation set with exactly two sources.
create_complete_dataset('vali(0)2d_nocoherent_snr10_5k', num_val_snr, 10,snapshots,num_sources=[2], coherent=False)



# create_complete_mixsnr_dataset('data/train_nocoherent_snrmix_200k', num_train,200,num_sources=[2,3,4,5], coherent=False)
# create_complete_mixsnr_dataset('data/validation_nocoherent_snrmix_20k', num_val,200,num_sources=[2,3,4,5], coherent=False)
# create_complete_dataset('data/vali(0)_nocoherent_snr0_10k', num_val_snr, 0,200,num_sources=[2,3,4,5], coherent=False)
# create_complete_dataset('data/vali(-5)_nocoherent_snr-5_10k', num_val_snr, -5,200,num_sources=[2,3,4,5], coherent=False)



#  dataset

In [5]:
class Deep_augmented_MUSIC_Dataset(Dataset):
    """Map-style dataset over numpy inputs and optional numpy labels.

    Both arrays are converted to float32 tensors once, at construction.
    """

    def __init__(self, x, y=None):
        """Store `x` (inputs) and, when given, `y` (targets) as float tensors."""
        self.data = torch.from_numpy(x).float()
        self.label = None if y is None else torch.from_numpy(y).float()

    def __getitem__(self, index):
        """Return ``(input, label)`` when labels exist, otherwise the input."""
        sample = self.data[index]
        if self.label is None:
            return sample
        return sample, self.label[index]

    def __len__(self):
        """Return the number of samples (size of the first input dimension)."""
        return len(self.data)


In [None]:




# Open the HDF5 files written by the generation cell (read-only).
# NOTE(review): these handles are never closed in this cell.
hf_train = h5py.File('train_nocoherent_snr10_200k.h5', 'r')
hf_val = h5py.File('validation_nocoherent_snr10_20k.h5', 'r')
hf_val_snr = h5py.File('vali(100)_nocoherent_snr10_5k.h5', 'r')
hf_val_s = h5py.File('vali(0)2d_nocoherent_snr10_5k.h5', 'r')

# Alternative (coherent) files:
# hf_train = h5py.File('data/train_coherent_d5_snr10_256k.h5', 'r')
# hf_val = h5py.File('data/validation_coherent_d5_snr10_25.6k.h5', 'r')

# Read the 'X' (observations) and 'Y' (labels) tables fully into memory.
X_train, Y_train = (np.array(hf_train.get(key)) for key in ('X', 'Y'))
X_val, Y_val = (np.array(hf_val.get(key)) for key in ('X', 'Y'))
X_val_snr, Y_val_snr = (np.array(hf_val_snr.get(key)) for key in ('X', 'Y'))
X_val_s, Y_val_s = (np.array(hf_val_s.get(key)) for key in ('X', 'Y'))

# Stack real and imaginary parts along axis 1 (complex -> real), then keep
# only the sign of every entry (one-bit quantisation).
X_train = np.sign(np.concatenate([X_train.real, X_train.imag], axis=1))
X_val = np.sign(np.concatenate([X_val.real, X_val.imag], axis=1))
X_val_snr = np.sign(np.concatenate([X_val_snr.real, X_val_snr.imag], axis=1))
X_val_s = np.sign(np.concatenate([X_val_s.real, X_val_s.imag], axis=1))



# Report the shapes / dtype of every split.
print("Training Set Size:","Y_train",Y_train.shape,"X_train",X_train.shape,X_train.dtype)
print("Verification set size:","Y_val",Y_val.shape,"X_val",X_val.shape,X_val.dtype)
print("Verification set snapshot size:","Y_val_snr",Y_val_snr.shape,"X_val",X_val_snr.shape,X_val_snr.dtype)
print("Verify snapshot set size:","Y_val_s",Y_val_s.shape,"X_val",X_val_s.shape,X_val_s.dtype)

# Wrap the numpy arrays as torch datasets.
train_set = Deep_augmented_MUSIC_Dataset(X_train,Y_train)
val_set = Deep_augmented_MUSIC_Dataset(X_val,Y_val)
val_set_snr = Deep_augmented_MUSIC_Dataset(X_val_snr,Y_val_snr)
val_set_s = Deep_augmented_MUSIC_Dataset(X_val_s,Y_val_s)

batch_size=128

# Only the training loader shuffles; every loader drops the last partial batch.
_eval_loader_options = dict(batch_size=batch_size, shuffle=False, drop_last=True)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(val_set, **_eval_loader_options)
valid_loader_snr = DataLoader(val_set_snr, **_eval_loader_options)
valid_loader_s = DataLoader(val_set_s, **_eval_loader_options)




#  doa estimation network

In [6]:
# Steering matrix of the scan grid: column i is the array response
# exp(-1j * pi * array * sin(angles[0, i])), giving shape
# (len(array), angles.shape[1]).  It is built directly from `array` and
# `angles` because ULA_action_vector is defined with a single `theta`
# parameter, so the former two-argument call raised TypeError.
a = np.exp(-1j * np.pi * np.outer(array, np.sin(angles[0])))
a = torch.from_numpy(a).to(torch.complex64).to(device)


def calculate_spectrum(En):
    """Evaluate ``1 / |a_k^H (En En^H) a_k|`` for every grid column ``a_k``.

    Args:
        En: batched complex matrices, indexed ``[batch, row, column]``.

    Returns:
        Real tensor with one value per batch entry and grid column, on
        `device`.  Uses the module-level steering matrix `a`.
    """
    basis = En.to(device)
    # En @ En^H for every batch entry.
    gram = basis @ torch.conj(basis.permute(0,2,1)).to(device)

    projected = torch.matmul(gram, a).to(device)

    # Column-wise inner product with conj(a) gives a^H (En En^H) a.
    quad_form = (projected * torch.conj(a)).sum(dim=1)

    return (1.0/abs(quad_form)).to(device)



class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a sequence-first tensor.

    The table is stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature indices hold sines and odd ones
    cosines of ``position * 10000 ** (-index / d_model)``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        steps = torch.arange(max_len).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        phase = steps * inv_freq   # (max_len, d_model / 2)

        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(phase)
        table[:, 0, 1::2] = torch.cos(phase)
        self.register_buffer('pe', table)

    def forward(self, x):
        """Return `x` plus the codes for its first ``x.size(0)`` positions."""
        seq_len = x.size(0)
        return x.to(device) + self.pe[:seq_len].to(device)



#**********************************#
#   eigendecomposition helper      #
#**********************************#

class calculate_EVD(nn.Module):
    """Eigendecompose a stacked real/imag matrix and sort by |eigenvalue|."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return eigenvalues and eigenvectors ordered by ascending magnitude.

        Args:
            x: real tensor whose rows ``:8`` along dim 1 are the real part and
                rows ``8:`` the imaginary part of a square complex matrix
                (``[size, 16, 8]`` per the original comments).

        Returns:
            ``(order_value, order_vec)``: complex eigenvalues ``[size, 8]``
            and eigenvectors ``[size, 8, 8]`` whose columns follow the same
            order.
        """
        # Combine the two halves into the complex pseudo covariance matrix.
        pseudo_cov = torch.complex(x[:, :8, :].to(device), x[:, 8:, :].to(device))

        xVal, xVec = torch.linalg.eig(pseudo_cov)

        # Sort positions by eigenvalue magnitude, smallest first.
        _, idx = torch.sort(abs(xVal).to(device), dim=1)

        # Reorder eigenvalues and eigenvector columns for the whole batch at
        # once (same selection as a per-sample index_select).
        order_value = torch.gather(xVal.to(device), 1, idx)
        column_index = idx.unsqueeze(1).expand(-1, xVec.shape[1], -1)
        order_vec = torch.gather(xVec.to(device), 2, column_index)

        return order_value,order_vec




class RNN_music_EVD_two(nn.Module):
    """GRU front end + eigendecomposition + learned weighting + peak finder.

    Several layer sizes are hard-coded (16, 8, 128, 360), so the model only
    fits ``m == 8`` and the 360-point scan grid.
    """

    def __init__(self,m):

        super().__init__()

        self.m=m
        width = 2*self.m   # real and imaginary parts stacked

        self.BN=torch.nn.BatchNorm1d(width).to(device)

        self.evd=calculate_EVD().to(device)

        self.gru_layer = nn.GRU(width, width,batch_first=True).to(device)

        self.input_linear = nn.Linear(in_features=16, out_features=128).to(device)

        # Maps the 16 eigenvalue features to 8 weights in (0, 1).
        self.change_noise = nn.Sequential(
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=8),
            nn.Sigmoid(),
        ).to(device)

        # Peak finder: 360-point spectrum -> 6 direction outputs.
        self.output = nn.Sequential(
            nn.Linear(in_features=360, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=6),
        ).to(device)

        # Classifier head: 16 eigenvalue features -> 4 outputs.
        self.output_d = nn.Sequential(
            nn.Linear(in_features=16, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128),
            nn.Linear(in_features=128, out_features=4),
        ).to(device)


    def forward(self,x,sub=True):
        """Return ``(directions, class outputs, eigenvalue features)``.

        Args:
            x: input batch; presumably ``[size, 16, snapshots]`` (BatchNorm1d
                over 16 channels) -- confirm against the data loader.
            sub: unused.

        Returns:
            ``x1`` (``[size, 6]``), ``x2`` (``[size, 4]``) and ``xVal``
            (``[size, 16]``, detached real/imag parts of the eigenvalues).
        """
        size=x.shape[0]

        normed = self.BN(x).to(device)

        # Sequence-major layout for the GRU: [size, snapshots, 16].
        sequence = normed.permute(0,2,1).float().to(device)

        # Keep only the final hidden state of the GRU.
        _, hidden = self.gru_layer(sequence)

        # 128 values per sample, viewed as a stacked real/imag 16 x 8 matrix.
        pseudo_cov = self.input_linear(hidden).to(device)
        pseudo_cov = pseudo_cov.reshape(size,16,8).to(device)

        eig_vals, eig_vecs = self.evd(pseudo_cov)

        # No gradient flows through the eigenvalues.
        eig_vals = eig_vals.detach()
        xVal = torch.cat([eig_vals.real.float().to(device), eig_vals.imag.float().to(device)],dim=1)

        # Noise subspace selector: one weight per eigenvalue.
        weights = self.change_noise(xVal.to(device))
        weight_matrix = torch.diag_embed(weights)
        weight_matrix = torch.complex(weight_matrix,torch.zeros([size,8,8]).float().to(device)).to(device)
        # NOTE(review): diag(weights) @ eig_vecs scales the ROWS of the
        # eigenvector matrix; weighting individual eigenvectors (columns)
        # would be eig_vecs @ diag(weights).  Kept as is -- confirm intent.
        weighted = torch.matmul(weight_matrix.to(device),eig_vecs.to(device))

        spectrum = calculate_spectrum(weighted).to(device)

        x1 = self.output(spectrum.to(device)).to(device)

        x2 = self.output_d(xVal)

        return x1,x2,xVal



#**********************************#
#   trans_music                    #
#**********************************#

class trans_music_two(nn.Module):
    """Transformer-encoder front end followed by spectrum and peak finder.

    Uses BatchNorm1d over the 16 input channels.  Several layer sizes are
    hard-coded (16, 8, 128, 360).
    """

    def __init__(self,m):

        super().__init__()

        self.m=m

        self.BN=torch.nn.BatchNorm1d(16).to(device)

        self.pos_encoder = PositionalEncoding(16)   # position embedding

        encoder_layer=torch.nn.TransformerEncoderLayer(
            d_model=16,
            nhead=8,   # attention heads
            dim_feedforward=1024,   # width of the feed-forward sub-layer
            dropout=0,
            activation="relu",
            layer_norm_eps=1e-05,
            batch_first=True,
            norm_first=False,
            device=None,
            dtype=None).to(device)

        self.encoder=torch.nn.TransformerEncoder(
            encoder_layer,
            num_layers=3,
            norm=None).to(device)

        self.input_linear = nn.Linear(in_features=16, out_features=128).to(device)

        # Peak finder: 360-point spectrum -> 6 direction outputs.
        self.output = nn.Sequential(
            nn.Linear(in_features=360, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=6),
        ).to(device)

        # Classifier head: 128 features -> 4 outputs.
        self.output_d = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=32, out_features=4),
        ).to(device)


    def forward(self, x):
        """Return ``(directions, class outputs, feature vector)``.

        Args:
            x: input batch; ``[size, 16, 200]`` per the original comments.

        Returns:
            ``x8`` (``[size, 6]``), ``x9`` (``[size, 4]``, computed from the
            detached feature vector) and ``vector`` (``[size, 128]``).
        """
        size=x.shape[0]

        normed = self.BN(x).to(device)

        # Sequence-first layout for the position codes: [200, size, 16].
        seq_first = normed.permute(2,0,1).float().to(device)
        coded = self.pos_encoder(seq_first.to(device)).to(device)

        # Back to batch-first for the encoder: [size, 200, 16].
        batch_first = coded.permute(1,0,2).float().to(device)
        encoded = self.encoder(batch_first.to(device))

        # Average over the sequence dimension: [size, 16].
        pooled = torch.mean(encoded,dim=1)

        vector = self.input_linear(pooled).to(device)   # [size, 128]

        # View as stacked real/imag halves and combine into [size, 8, 8].
        stacked = vector.reshape(size,16,8).to(device)
        subspace = torch.complex(stacked[ :,:8 ,:].to(device),stacked[ :,8: ,:].to(device))

        spectrum = calculate_spectrum(subspace).to(device)
        spectrum = spectrum.float().to(device)

        x8 = self.output(spectrum.to(device)).to(device)

        # Classifier input is detached so it does not train the encoder.
        x9 = self.output_d(vector.detach())

        return x8,x9,vector
    



    
class trans_music_two_LN(nn.Module):
    """Variant of the transformer front end that uses LayerNorm on the input.

    LayerNorm is applied over the 16 features of every time step.  Several
    layer sizes are hard-coded (16, 8, 128, 360).
    """

    def __init__(self,m):

        super().__init__()

        self.m=m

        self.LN= torch.nn.LayerNorm(normalized_shape = [16]).to(device)

        self.pos_encoder = PositionalEncoding(16)   # positional embedding

        encoder_layer=torch.nn.TransformerEncoderLayer(
            d_model=16,
            nhead=8,   # attention heads
            dim_feedforward=1024,   # width of the feed-forward sub-layer
            dropout=0,
            activation="relu",
            layer_norm_eps=1e-05,
            batch_first=True,
            norm_first=False,
            device=None,
            dtype=None).to(device)

        self.encoder=torch.nn.TransformerEncoder(
            encoder_layer,
            num_layers=3,
            norm=None).to(device)

        self.input_linear = nn.Linear(in_features=16, out_features=128).to(device)

        # Peak finder: 360-point spectrum -> 6 direction outputs.
        self.output = nn.Sequential(
            nn.Linear(in_features=360, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=6),
        ).to(device)

        # Classifier head: 128 features -> 4 outputs.
        self.output_d = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=32, out_features=4),
        ).to(device)


    def forward(self, x):
        """Return ``(directions, class outputs, feature vector)``.

        Args:
            x: input batch; ``[size, 16, 200]`` per the original comments.

        Returns:
            ``x8`` (``[size, 6]``), ``x9`` (``[size, 4]``, computed from the
            detached feature vector) and ``vector`` (``[size, 128]``).
        """
        size=x.shape[0]

        # Feature-last layout for LayerNorm: [size, 200, 16].
        feature_last = x.permute(0,2,1).float().to(device)
        normed = self.LN(feature_last.to(device)).to(device)

        # Sequence-first layout for the position codes: [200, size, 16].
        seq_first = normed.permute(1,0,2).float().to(device)
        coded = self.pos_encoder(seq_first.to(device)).to(device)

        # Back to batch-first for the encoder: [size, 200, 16].
        batch_first = coded.permute(1,0,2).float().to(device)
        encoded = self.encoder(batch_first.to(device))

        # Average over the sequence dimension: [size, 16].
        pooled = torch.mean(encoded,dim=1)

        vector = self.input_linear(pooled).to(device)   # [size, 128]

        # View as stacked real/imag halves and combine into [size, 8, 8].
        stacked = vector.reshape(size,16,8).to(device)
        subspace = torch.complex(stacked[ :,:8 ,:].to(device),stacked[ :,8: ,:].to(device))

        spectrum = calculate_spectrum(subspace).to(device)
        spectrum = spectrum.float().to(device)

        x8 = self.output(spectrum.to(device)).to(device)

        # Classifier input is detached so it does not train the encoder.
        x9 = self.output_d(vector.detach())

        return x8,x9,vector



#  classifier network

In [10]:


class separate_est_d(nn.Module):
    """Classifier head trained on features of a pre-trained trans_music_two.

    Args:
        name: path of the saved ``state_dict`` for the trans_music_two model.
    """

    def __init__(self,name):
        super(separate_est_d, self).__init__()

        self.output_d = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128) ,
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=32, out_features=4)
                ).to(device)
        self.net_DOA=trans_music_two(m=8).to(device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host (and vice versa).
        self.net_DOA.load_state_dict(torch.load(name, map_location=device))

    def forward(self, x):
        """Return the 4 classifier outputs for input batch `x`."""
        # The pre-trained network is used as a fixed feature extractor.
        # NOTE(review): no_grad does not stop its BatchNorm layer from
        # updating running statistics while this module is in train mode --
        # confirm whether net_DOA should be put in eval mode.
        with torch.no_grad():
            _, _, features = self.net_DOA(x)   # feature vector, [size, 128]

        return self.output_d(features.detach())


class separate_est_d2(nn.Module):
    """Classifier head trained on features of a pre-trained RNN_music_EVD_two.

    Args:
        name: path of the saved ``state_dict`` for the RNN_music_EVD_two model.
    """

    def __init__(self,name):
        super(separate_est_d2, self).__init__()

        self.output_d = nn.Sequential(
            nn.Linear(in_features=16, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128) ,
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=128, out_features=128),
            nn.Linear(in_features=128, out_features=4)
                ).to(device)
        self.net_DOA=RNN_music_EVD_two(m=8).to(device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host (and vice versa).
        self.net_DOA.load_state_dict(torch.load(name, map_location=device))

    def forward(self, x):
        """Return the 4 classifier outputs for input batch `x`."""
        # The pre-trained network is used as a fixed feature extractor.
        # NOTE(review): no_grad does not stop its BatchNorm layer from
        # updating running statistics while this module is in train mode --
        # confirm whether net_DOA should be put in eval mode.
        with torch.no_grad():
            _, _, features = self.net_DOA(x)   # eigenvalue features, [size, 16]

        return self.output_d(features.detach())




#  loss function

In [None]:
#****************************#
#   enumerate permutations   #
#****************************#
def permutations(predDoA):
    """Return all orderings of the list `predDoA` as a list of lists.

    Orderings starting with element 0 come first, then those starting with
    element 1, and so on.  An empty input gives ``[]``; a one-element input
    gives ``[predDoA]``.
    """
    if len(predDoA) == 0:
        return []
    if len(predDoA) == 1:
        return [predDoA]

    # Choose each element as the head in turn and permute the remainder.
    return [
        [predDoA[i]] + tail
        for i in range(len(predDoA))
        for tail in permutations(predDoA[:i] + predDoA[i + 1:])
    ]


# Permutation lookup matrices for 2..5 sources.
# For n sources, the rows of every permuted n x n identity matrix are stacked
# (n! * n rows) and transposed, giving shape (n, n! * n).  Multiplying a
# [size, n] tensor by this matrix yields every reordering of each row.
# math.factorial replaces np.math.factorial, which was removed in NumPy 2.0.

a2=[[1,0],[0,1]]
a2=permutations(a2)
b2=np.array(a2)
perm2=b2.reshape(math.factorial(2)*2,2)
perm2=perm2.swapaxes(0,1)
print(perm2.shape)
perm2=torch.from_numpy(perm2).float().to(device)



a3=[[1,0,0],[0,1,0,],[0,0,1]]
a3=permutations(a3)
b3=np.array(a3)
perm3=b3.reshape(math.factorial(3)*3,3)
perm3=perm3.swapaxes(0,1)
print(perm3.shape)
perm3=torch.from_numpy(perm3).float().to(device)


a4=[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]
a4=permutations(a4)
b4=np.array(a4)
perm4=b4.reshape(math.factorial(4)*4,4)
perm4=perm4.swapaxes(0,1)
print(perm4.shape)
perm4=torch.from_numpy(perm4).float().to(device)



a5=[[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,0,1]]
a5=permutations(a5)
b5=np.array(a5)
perm5=b5.reshape(math.factorial(5)*5,5)
perm5=perm5.swapaxes(0,1)
print(perm5.shape)
perm5=torch.from_numpy(perm5).float().to(device)



# Cache of permutation index tensors, keyed by (number of sources, device).
_PERM_INDEX_CACHE = {}


def _perm_index(num_sources):
    """Return a long tensor [num_sources!, num_sources] of all index orderings."""
    from itertools import permutations as index_permutations

    key = (num_sources, str(device))
    if key not in _PERM_INDEX_CACHE:
        _PERM_INDEX_CACHE[key] = torch.tensor(
            list(index_permutations(range(num_sources))),
            dtype=torch.long,
            device=device,
        )
    return _PERM_INDEX_CACHE[key]


def perm_rmse(predDoA, trueDoA):
    """Permutation-invariant periodic RMSE between predicted and true angles.

    Each row of ``trueDoA`` is compared with the matching row of ``predDoA``
    under every ordering of the true angles. The per-row error is the smallest
    RMSE over all orderings, and the result is the mean of those minima.

    Angle differences are wrapped into ``[-pi/2, pi/2)`` before squaring.
    ``torch.remainder`` is used for the wrap because ``torch.fmod`` keeps the
    sign of the dividend and therefore leaves negative differences unwrapped,
    which made the error depend on the sign of ``true - pred``.

    Args:
        predDoA: tensor ``[size, num_sources]`` of predicted angles (radians).
        trueDoA: tensor ``[size, num_sources]`` of true angles (radians).

    Returns:
        0-dim tensor on ``device`` holding the mean minimum RMSE.

    Note:
        The number of orderings grows as ``num_sources!``; the file uses
        2 to 5 sources, but any count is accepted.
    """
    num_sources = trueDoA.shape[1]

    trueDoA = trueDoA.to(device)
    predDoA = predDoA.unsqueeze(1).to(device)  # [size, 1, num_sources]

    # Gather every ordering of the true angles: [size, num_sources!, num_sources]
    allPerms = trueDoA[:, _perm_index(num_sources)]

    # Periodic difference with period pi, centred on zero.
    diff = torch.remainder(allPerms - predDoA + np.pi / 2, np.pi) - np.pi / 2

    # RMSE per ordering, then the best ordering for each row.
    diff2, _ = torch.min(torch.mean(diff ** 2, dim=2) ** (1 / 2), dim=1)

    return torch.mean(diff2)


# def perm_rmse(predDoA, trueDoA):


#     size=trueDoA.shape[0]#
#     num_sources = trueDoA.shape[1]

#     allPerms = np.zeros((size, np.math.factorial(num_sources), num_sources))# 1 120   5

#     for i in range(size):

#         allPerms[i] = permutations(list(trueDoA[i].cpu().numpy()))

#     allPerms=torch.from_numpy(allPerms).to(device)

#     predDoA=predDoA.unsqueeze(1)
   
#     diff=torch.fmod( allPerms-predDoA+np.pi / 2 , np.pi ) - np.pi / 2

#     diff2,ind=torch.min(torch.mean(diff**2,dim=2)**(1/2),dim=1)

#     return  torch.mean(diff2)


#***************#
#   mean rmse   #
#***************#
def loss_doa(predDoA, trueDoA):
    """Mean permutation-invariant RMSE over the rows of a batch.

    The number of valid angles in each row is taken as the argmax position of
    the corresponding ``trueDoA`` row; only that many leading entries of the
    prediction and the target are passed to ``perm_rmse``.

    Returns:
        tensor of shape ``[1]`` on ``device``: the per-row losses averaged
        over the batch.
    """
    source_counts = torch.argmax(trueDoA, dim=1)

    total = torch.zeros(1).to(device)
    for row, count in enumerate(source_counts):
        pred_row = predDoA[row, :count].unsqueeze(dim=0)
        true_row = trueDoA[row, :count].unsqueeze(dim=0)
        total = total + perm_rmse(pred_row, true_row)

    return total / len(source_counts)



def  loss_d(predp,trueDoA):
    """Cross-entropy between predicted count scores and the true count.

    The true count of each row is the argmax position of that ``trueDoA`` row.
    It is turned into a one-hot row of width 4 at index ``count - 2`` and used
    as the target of ``nn.CrossEntropyLoss``.

    Args:
        predp: class scores, one row per sample (4 columns to match the target).
        trueDoA: target batch whose row-wise argmax gives the count.
    """
    source_counts = torch.argmax(trueDoA, dim=1)

    # One-hot targets, shape [size, 4].
    one_hot = torch.zeros(trueDoA.shape[0], 4).to(device)
    for row, count in enumerate(source_counts):
        one_hot[row][count - 2] = 1

    cross_entropy = nn.CrossEntropyLoss()
    return cross_entropy(predp, one_hot)





def  loss_d_by_one(predp,trueDoA):
    """Mean absolute error between ``predp`` and the true count.

    The true count of each row is the argmax position of that ``trueDoA`` row,
    reshaped to a ``[size, 1]`` column before being subtracted from ``predp``.
    """
    batch = trueDoA.shape[0]
    counts = torch.argmax(trueDoA, dim=1).reshape(batch, 1)
    return torch.mean(abs(predp - counts))



def loss_doa_test(predDoA,trueDoA,num_true_t,num_true_p):
    """Mean permutation-invariant RMSE at test time, given predicted counts.

    For each row the first ``num_true_t[i]`` entries of prediction and target
    are compared with ``perm_rmse``. When the predicted count is smaller than
    the true one, the entries of ``predDoA[i]`` from index ``num_true_p[i]``
    onward are first overwritten with uniform random angles in
    ``[-pi/2, pi/2)``.

    Args:
        predDoA: 2-D tensor of predicted angles. It is modified in place for
            rows whose count is under-estimated.
        trueDoA: 2-D tensor of true angles.
        num_true_t: true count per row.
        num_true_p: predicted count per row.

    Returns:
        tensor of shape ``[1]`` on ``device``: the per-row losses averaged
        over the batch.
    """
    total = torch.zeros(1).to(device)

    # Row width of the prediction tensor (the original hard-coded 6 here).
    width = predDoA.shape[1]

    for i, elem in enumerate(num_true_t):

        if num_true_p[i] < num_true_t[i]:
            # Too few sources predicted: pad the missing slots with random angles.
            n_pred = int(num_true_p[i])
            predDoA[i, n_pred:] = (torch.rand(1, width - n_pred) - 0.5) * torch.pi

        # Equal and over-estimated counts are scored on the first `elem` entries as-is.
        total = total + perm_rmse(predDoA[i, :elem].unsqueeze(dim=0),
                                  trueDoA[i, :elem].unsqueeze(dim=0))

    return total / len(num_true_t)

# define Network

In [None]:

# Run-mode switches: a switch set to 1 selects the matching entry in the
# table below.
T1 = 0
T2 = 0
T5 = 0
T6 = 1

T7 = 0
A3 = 0
A1 = 0
T3 = 0

# Training state shared by the training cells.
train_loss_record = []
valid_loss_record = []

best_loss = 1
break_flag = 0
train_time = []
best_acc = 0

net = 0
model_location = './model_d/'

# Each entry: (switch, constructor, path the trained weights are saved to).
# Entries are checked in this order and every enabled one is built, so when
# several switches are on the last enabled entry is the one left in `net`.
model_choices = (
    (T1, lambda: trans_music_two(m=8), './model_d/transmusic.pt'),
    (T2, lambda: separate_est_d('./model_d/transmusic.pt'), './model_d/transmusic_分类器.pt'),
    (T5, lambda: trans_music_two(m=8), './model_d/transmusic_onebit.pt'),
    (T6, lambda: separate_est_d('./model_d/transmusic_onebit.pt'), './model_d/transmusic_onebit_分类器.pt'),
    (T3, lambda: trans_music_two(m=8), './model_d/transmusic_onebit.pt'),
    (T7, lambda: trans_music_two(m=8), './model_d/transmusic_onebit_相干.pt'),
    (A3, lambda: separate_est_d2('./model_d/DA_music.pt'), './model_d/DA_music_分类器.pt'),
    (A1, lambda: RNN_music_EVD_two(m=8), './model_d/DA_music.pt'),
)

for enabled, build_net, save_path in model_choices:
    if enabled == 1:
        net = build_net()
        model_location = save_path


# Debug aid: list every parameter of the selected network.
# for name, param in net.named_parameters():
#     print(name,'-->',param.type(),'-->',param.dtype,'-->',param.shape)


# Report the total and the trainable parameter counts.
total_params = sum(weight.numel() for weight in net.parameters())
print('All parameters', total_params)
total_trainable_params = sum(
    weight.numel() for weight in net.parameters() if weight.requires_grad
)
print('All trainable parameters', total_trainable_params)



# training network

In [None]:

learning_rate = 0.001
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

epoch_num = 200

# Uncomment to resume from the weights saved at model_location.
# net.load_state_dict(torch.load(model_location))


def _evaluate_joint(model, loader):
    """Evaluate the joint DOA/count network on one data loader.

    Returns:
        (total_loss, doa_loss, count_loss, accuracy), each averaged over the
        batches of ``loader``: ``loss_doa + loss_d``, ``loss_doa``, ``loss_d``
        and the fraction of samples whose predicted count matches the target.
    """
    model.eval()
    total_sum = 0.0
    doa_sum = 0.0
    count_sum = 0.0
    batch_accs = []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            pred, count_logits, _ = model(x)

            doa_loss = loss_doa(pred, y)
            count_loss = loss_d(count_logits, y)
            total_sum += (doa_loss + count_loss).item()
            doa_sum += doa_loss.item()
            count_sum += count_loss.item()

            # Class index k stands for a count of k + 2 (loss_d builds its
            # one-hot target at index count - 2).
            num_true_p = torch.argmax(F.softmax(count_logits, dim=1), dim=1) + 2
            num_true_t = torch.argmax(y, dim=1)
            batch_accs.append((num_true_p == num_true_t).float().mean().item())

    num_batches = len(loader)
    return (total_sum / num_batches,
            doa_sum / num_batches,
            count_sum / num_batches,
            sum(batch_accs) / len(batch_accs))


# Keep the start time in the shared list, but time THIS run from its own
# start: train_time[0] may belong to an earlier run of a training cell.
train_time.append(time.time())
run_start = train_time[-1]

for epoch in range(epoch_num):

    train_loss = 0.0
    train_loss1 = 0.0
    train_loss2 = 0.0

    net.train()  # set model to training mode
    for X_input, X_target in tqdm(train_loader):

        optimizer.zero_grad()

        X_input = X_input.to(device)
        X_target = X_target.to(device)

        X_pre, X_d, _ = net(X_input)  # forward propagation

        loss1 = loss_doa(X_pre, X_target)   # angle loss
        loss2 = loss_d(X_d, X_target)       # count classification loss
        batch_loss = loss1 + loss2

        batch_loss.backward()
        optimizer.step()

        train_loss += batch_loss.item()
        train_loss1 += loss1.item()
        train_loss2 += loss2.item()

    train_loss = train_loss / len(train_loader)
    train_loss1 = train_loss1 / len(train_loader)
    train_loss2 = train_loss2 / len(train_loader)
    train_loss_record.append(train_loss)

    # Main validation set: drives checkpoint selection below.
    val_loss, val_loss1, val_loss2, train_acc0 = _evaluate_joint(net, valid_loader)
    valid_loss_record.append(val_loss)

    # Two extra validation sets, reported only (see the labels in the prints).
    _, valsnr_loss1, valsnr_loss, train_acc = _evaluate_joint(net, valid_loader_snr)
    _, valsnr_loss2, vals_loss, train_acc1 = _evaluate_joint(net, valid_loader_s)

    # Alternative checkpoint rule based on the validation DOA loss:
    # if val_loss1<best_loss:
    #     best_loss=val_loss1
    #     torch.save(net.state_dict(), model_location)

    if train_acc0 > best_acc:  # checkpoint on validation count accuracy
        best_acc = train_acc0
        torch.save(net.state_dict(), model_location)  # Save model to specified path

        print('save the model', 'epoch:', epoch)
        break_flag = 0
    else:
        break_flag += 1

    # After more than 20 epochs without improvement only a separator is
    # printed and the counter restarts; training is not stopped.
    if break_flag > 20:
        break_flag = 0
        print('_____________________________________________')

    # best_acc is shown because it is the quantity that selects the checkpoint;
    # best_loss is never updated by this loop.
    print('|epoch={:2d}/{:3d}|total_time={:.2f}m|train_loss={:.4f}({:.4f}dB) | val_loss={:.4f}({:.4f}dB) | best_acc={:.4f}'.format(
        epoch, epoch_num, (time.time() - run_start) / 60, train_loss, 20 * np.log10(train_loss), val_loss, 20 * np.log10(val_loss), best_acc))

    print('|train_loss1={:.4f}({:.4f}dB)|val_loss1={:.4f}({:.4f}dB)|train_loss2={:.4f}({:.4f}dB)|val_loss2={:.4f}({:.4f}dB)|Accuracy:{:.4f}'.format(
        train_loss1, 20 * np.log10(train_loss1), val_loss1, 20 * np.log10(val_loss1), train_loss2, 20 * np.log10(train_loss2), val_loss2, 20 * np.log10(val_loss2), train_acc0))

    print('|10100 snapshot SNR verification set=loss error {:.4f}({:.4f}dB)   | ACC error {:.4f}({:.4f}dB) |Accuracy:{:.4f}'.format(valsnr_loss1, 20 * np.log10(valsnr_loss1), valsnr_loss, 20 * np.log10(valsnr_loss), train_acc))

    print('|d=2 Loss error {:.4f}({:.4f}dB)   | ACC error {:.4f}({:.4f}dB) |Accuracy:{:.4f}'.format(valsnr_loss2, 20 * np.log10(valsnr_loss2), vals_loss, 20 * np.log10(vals_loss), train_acc1))



# training classifiers

In [None]:
learning_rate = 0.001
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
# best_acc=0
epoch_num = 200

# Uncomment to resume from the weights saved at model_location.
# net.load_state_dict(torch.load(model_location))


def _evaluate_classifier(model, loader):
    """Evaluate the count classifier on one data loader.

    Returns:
        (count_loss, accuracy), each averaged over the batches of ``loader``:
        ``loss_d`` and the fraction of samples whose predicted count matches
        the target.
    """
    model.eval()
    loss_sum = 0.0
    batch_accs = []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            count_logits = model(x)

            loss_sum += loss_d(count_logits, y).item()

            # Softmax over dim=1, i.e. across the classes of each sample.
            # Class index k stands for a count of k + 2 (loss_d builds its
            # one-hot target at index count - 2).
            num_true_p = torch.argmax(F.softmax(count_logits, dim=1), dim=1) + 2
            num_true_t = torch.argmax(y, dim=1)
            batch_accs.append((num_true_p == num_true_t).float().mean().item())

    return loss_sum / len(loader), sum(batch_accs) / len(batch_accs)


# Keep the start time in the shared list, but time THIS run from its own
# start: train_time[0] may belong to an earlier run of a training cell.
train_time.append(time.time())
run_start = train_time[-1]

for epoch in range(epoch_num):

    train_loss = 0.0

    net.train()  # set model to training mode
    for X_input, X_target in tqdm(train_loader):

        optimizer.zero_grad()

        X_input = X_input.to(device)
        X_target = X_target.to(device)

        X_d = net(X_input)  # forward propagation

        batch_loss = loss_d(X_d, X_target)
        batch_loss.backward()
        optimizer.step()

        train_loss += batch_loss.item()  # total loss of all batches in the epoch

    train_loss = train_loss / len(train_loader)
    train_loss_record.append(train_loss)

    # Main validation set: drives checkpoint selection below.
    val_loss, train_acc0 = _evaluate_classifier(net, valid_loader)
    valid_loss_record.append(val_loss)

    # Two extra validation sets, reported only (see the labels in the prints).
    valsnr_loss, train_acc = _evaluate_classifier(net, valid_loader_snr)
    vals_loss, train_acc1 = _evaluate_classifier(net, valid_loader_s)

    if train_acc0 > best_acc:  # checkpoint on validation accuracy
        best_acc = train_acc0

        torch.save(net.state_dict(), model_location)  # Save model to specified path
        print('save the model', 'epoch:', epoch)
        break_flag = 0
    else:
        break_flag += 1

    # After more than 20 epochs without improvement only a separator is
    # printed and the counter restarts; training is not stopped.
    if break_flag > 20:
        break_flag = 0
        print('_____________________________________________')

    # The original paired best_acc with a dB value computed from best_loss,
    # which this loop never updates; only the accuracy is reported now.
    print('|epoch={:2d}/{:3d}|total_time={:.2f}m|train_loss={:.4f}({:.4f}dB) | val_loss={:.4f}({:.4f}dB) | best_acc={:.4f}'.format(
        epoch, epoch_num, (time.time() - run_start) / 60, train_loss, 20 * np.log10(train_loss), val_loss, 20 * np.log10(val_loss), best_acc))

    print('Verification set accuracy:{:.4f}'.format(train_acc0))

    print('|100Validation set={:.4f}({:.4f}dB) |accuracy:{:.4f}'.format(valsnr_loss, 20 * np.log10(valsnr_loss), train_acc))

    print('|d=2Validation set={:.4f}({:.4f}dB) |accurac:{:.4f}'.format(vals_loss, 20 * np.log10(vals_loss), train_acc1))


# loss reduction chart

In [None]:

plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False

# Two side-by-side panels of the recorded losses: linear scale on the left,
# 20*log10 (dB) on the right. Each entry is
# (subplot position, train curve, validation curve, y-limits, y-label).
panels = (
    (1, train_loss_record, valid_loss_record, (0., 0.5), 'perm_rmse loss'),
    (2, 20*np.log10(train_loss_record), 20*np.log10(valid_loss_record), (-26., -3.0), 'perm_rmse (dB)'),
)

plt.figure(figsize=(20, 6))
for position, train_curve, valid_curve, y_range, y_label in panels:
    plt.subplot(1, 2, position)
    plt.plot(train_curve, label='train_loss')
    plt.plot(valid_curve, label='valid_loss')
    plt.ylim(*y_range)
    plt.xlabel('Training epoch')
    plt.ylabel(y_label)
    plt.title('Learning curve of {}'.format("loss"))
    plt.legend()

plt.show()
