# **Unfolded matrix-inverse-free WMMSE versus WMMSE**

The WMMSE algorithm is implemented as described in the paper "An Iteratively Weighted MMSE Approach to Distributed Sum-Utility Maximization for a MIMO Interfering Broadcast Channel", by Q. Shi et al.

Note:

Functions and variables with "_nn" (i.e. neural network) in the name refer to the deep unfolded WMMSE algorithm.

In the code, the term "precoder" is used to denote the beamformer.

## **Import libraries and set variables**

In [None]:
%tensorflow_version 1.x
# Import libraries
import tensorflow as tf # tensorflow_version 1.x is needed
from tensorflow.python.framework import ops
import numpy as np
import copy
from copy import deepcopy
import time
import matplotlib.pyplot as plt

# Set variables
# These module-level settings are read as globals by the functions defined below.
nr_of_users = 4
nr_of_BS_antennas = 8
nr_of_UE_antennas = 2
nr_of_data_streams = 2
total_power = 10 # power constraint in the weighted sum rate maximization problem 
noise_power = 1
nr_of_iterations = 4
scale_V_every_iteration = True # used to normalize V at every iteration such that the power constraint is met with equality

# For the WMMSE
epsilon = 0.0001 # used to end the iterations of the WMMSE algorithm in Shi et al. when the number of iterations is not fixed (note that the stopping criterion has precedence over the fixed number of iterations: the loop ends as soon as either one is reached)
power_tolerance = 0.0001 # used to end the bisection search in the WMMSE algorithm in Shi et al.
nr_of_iterations_WMMSE = nr_of_iterations # for WMMSE algorithm in Shi et al.
user_weights_WMMSE = np.ones((nr_of_users),) # every user gets weight 1 in the weighted sum rate

# For the matrix-inverse-free WMMSE
nr_of_batches_training = 10000#  used for training
nr_of_batches_test = 1000 # used for testing
nr_of_samples_per_batch = 100
batch_size = nr_of_samples_per_batch # alias used when building the TensorFlow graph

nr_of_iterations_nn = nr_of_iterations  # for the deep unfolded WMMSE in our paper


TensorFlow 1.x selected.


# **Functions to run the WMMSE algorithm described in the paper by Shi et al.**

In [None]:
def compute_sinr_MIMO(channel, precoder, noise_power, user_id):
  """Return the SINR matrix of user `user_id`.

  With H = channel[user_id] and V_j = precoder[j], the result is

      H V_k V_k^H H^H  (noise_power*I + sum_{j != k} H V_j V_j^H H^H)^(-1),   k = user_id.

  Args:
    channel: array indexed as channel[user, :, :]; its first two dimensions
      give the number of users and the number of receive antennas.
    precoder: array indexed as precoder[user, :, :].
    noise_power: scalar noise power added on the diagonal.
    user_id: index of the user whose SINR matrix is computed.

  Returns:
    Complex square matrix with one row/column per receive antenna.
  """
  nr_of_users, nr_of_UE_antennas = np.size(channel,0), np.size(channel,1)

  own_channel = channel[user_id,:,:]
  own_channel_h = np.transpose(np.conj(own_channel))

  def received_covariance(index):
    # H V V^H H^H for the precoder of user `index`, seen through the channel of `user_id`
    beam = precoder[index,:,:]
    return np.matmul(np.matmul(np.matmul(own_channel, beam), np.transpose(np.conj(beam))), own_channel_h)

  signal = received_covariance(user_id)

  # complex accumulator, so the result is complex even for real-valued inputs
  interference = np.zeros((nr_of_UE_antennas,nr_of_UE_antennas), dtype = complex)
  for other_user in range(nr_of_users):
    if other_user == user_id:
      continue
    interference = interference + received_covariance(other_user)

  noise_plus_interference = noise_power*np.eye(nr_of_UE_antennas,nr_of_UE_antennas) + interference

  return np.matmul(signal, np.linalg.inv(noise_plus_interference))


def compute_weighted_sum_rate_MIMO(user_weights, channel, precoder, noise_power):
  """Return the weighted sum rate  sum_k user_weights[k] * log det(I + SINR_k).

  SINR_k is the matrix returned by compute_sinr_MIMO for user k; the natural
  logarithm is used and only the real part of the sum is returned.

  Args:
    user_weights: sequence indexed by user.
    channel: array indexed as channel[user, :, :].
    precoder: array indexed as precoder[user, :, :].
    noise_power: scalar noise power.
  """
  nr_of_users = np.size(channel,0)
  nr_of_UE_antennas = np.size(channel,1)
  identity = np.eye(nr_of_UE_antennas,nr_of_UE_antennas)

  def weighted_rate(user_index):
    sinr = compute_sinr_MIMO(channel, precoder, noise_power, user_index)
    return user_weights[user_index]*np.log(np.linalg.det(identity + sinr))

  # sum() starts from 0 and adds the users in index order
  weighted_sum_rate = sum(weighted_rate(user_index) for user_index in range(nr_of_users))

  return np.real(weighted_sum_rate)


def run_WMMSE_MIMO_more_streams(epsilon, channel, initial_transmitter_precoder_WMMSE, total_power, noise_power, user_weights, max_nr_of_iterations, log = False):
  """Run the WMMSE algorithm of Shi et al. with several data streams per user.

  Every iteration updates, in this order, the receiver precoders U, the MSE
  weights W and the transmitter precoders V. For V, the multiplier mu of the
  power constraint is 0 when the unregularized solution already satisfies the
  constraint; otherwise it is found by bisection on compute_P (eq. (18)).

  Besides its arguments, the function reads the module-level settings
  nr_of_users, nr_of_UE_antennas, nr_of_BS_antennas, nr_of_data_streams and
  power_tolerance, and calls compute_P and compute_weighted_sum_rate_MIMO.

  Args:
    epsilon: the iterations stop once the change of sum_k w_k*log det(W_k)
      between two consecutive iterations drops below this value.
    channel: array indexed as channel[user, :, :] (receive antennas x BS antennas).
    initial_transmitter_precoder_WMMSE: starting point for V, indexed as
      [user, :, :] (BS antennas x data streams). It is not modified.
    total_power: transmit power budget.
    noise_power: noise power at the receivers.
    user_weights: sequence of per-user weights.
    max_nr_of_iterations: upper bound on the number of iterations.
    log: if true, plot the bisection error and the WSR per iteration.

  Returns:
    (transmitter_precoder, receiver_precoder, mse_weights, WSR) where WSR is
    the weighted sum rate of the returned transmitter precoder. When no
    iteration is executed, WSR is evaluated on the initial precoder.
  """

  def hermitian(matrix):
    # conjugate transpose of a 2-D array
    return np.transpose(np.conj(matrix))

  # initialization
  mse_weights = np.zeros((nr_of_users,nr_of_data_streams,nr_of_data_streams), dtype = complex)
  receiver_precoder = np.zeros((nr_of_users,nr_of_UE_antennas,nr_of_data_streams), dtype = complex)
  new_receiver_precoder = np.zeros((nr_of_users,nr_of_UE_antennas,nr_of_data_streams), dtype = complex)
  new_transmitter_precoder = np.zeros((nr_of_users,nr_of_BS_antennas,nr_of_data_streams), dtype = complex)
  WSR_E = [0] # sum_k w_k*log det(W_k) per iteration, used for the stopping criterion
  WSR = [] # weighted sum rate per iteration

  # NOTE: earlier versions drew a random precoder here and immediately replaced
  # it with the supplied one. The values are not needed, but the draws are kept
  # so that the global NumPy random stream seen by the rest of the notebook is
  # unchanged.
  for _ in range(nr_of_users):
    np.random.normal(size = (nr_of_BS_antennas,nr_of_data_streams))

  transmitter_precoder = initial_transmitter_precoder_WMMSE

  identity_UE = np.eye(nr_of_UE_antennas,nr_of_UE_antennas)
  identity_streams = np.eye(nr_of_data_streams,nr_of_data_streams)
  identity_BS = np.eye(nr_of_BS_antennas,nr_of_BS_antennas)
  machine_eps = np.finfo(float).eps

  nr_of_iteration_counter = 1 # to keep track of the number of iteration of the WMMSE
  break_condition = 2*epsilon # guarantees that the first iteration is executed

  while break_condition >= epsilon and nr_of_iteration_counter<=max_nr_of_iterations:

    nr_of_iteration_counter = nr_of_iteration_counter + 1

    ###################################
    # optimize receiver precoder
    for user_index in range(nr_of_users):
      H = channel[user_index,:,:]
      # covariance of the received signal of this user (all users' precoders included)
      received_covariance = np.zeros((nr_of_UE_antennas,nr_of_UE_antennas), dtype = complex)
      for user_index2 in range(nr_of_users):
        V2 = transmitter_precoder[user_index2,:,:]
        received_covariance = received_covariance + np.matmul(H, np.matmul(V2, np.matmul(hermitian(V2), hermitian(H))))

      new_receiver_precoder[user_index,:,:] = np.matmul(np.linalg.inv(identity_UE*noise_power + received_covariance),
                                                        np.matmul(H, transmitter_precoder[user_index,:,:]))

    ####################################
    # optimize mse weights
    for user_index in range(nr_of_users):
      UhHV = np.matmul(np.matmul(hermitian(new_receiver_precoder[user_index,:,:]), channel[user_index,:,:]), transmitter_precoder[user_index,:,:])
      mse_weights[user_index,:,:] = np.linalg.inv(identity_streams - UhHV)

    ####################################
    # optimize transmitter precoder
    A = np.zeros((nr_of_BS_antennas,nr_of_BS_antennas), dtype = complex)
    Lambda = np.zeros((nr_of_BS_antennas,nr_of_BS_antennas), dtype = complex)
    HhUW = [] # H_k^H U_k W_k per user, shared by A, Lambda and the precoder update

    for user_index in range(nr_of_users):
      H = channel[user_index,:,:]
      U = new_receiver_precoder[user_index,:,:]
      W = mse_weights[user_index,:,:]
      HhUW_user = np.matmul(np.matmul(hermitian(H), U), W)
      HhUW.append(HhUW_user)

      A = A + user_weights[user_index]*np.matmul(np.matmul(HhUW_user, hermitian(U)), H)
      Lambda = Lambda + ((user_weights[user_index])**2)*np.matmul(np.matmul(np.matmul(HhUW_user, hermitian(W)), hermitian(U)), H)

    eigenvalues, eigenvectors = np.linalg.eigh(A)
    Sigma_diag_elements = np.real(eigenvalues).copy()

    Phi = np.matmul(np.matmul(np.conj(np.transpose(eigenvectors)),Lambda),eigenvectors)
    Phi_diag_elements = np.real(np.diag(Phi)).copy()

    # Guard against round-off: keep Phi strictly positive and zero out
    # eigenvalues of A that are numerically non-positive.
    Phi_diag_elements[Phi_diag_elements < machine_eps] = machine_eps
    Sigma_diag_elements[Sigma_diag_elements < machine_eps] = 0

    # Check if mu = 0 is a solution (eq.s (15) and (16) in the paper of Shi et al.)
    A_is_invertible = np.linalg.det(A) != 0
    power = 0 # the power of transmitter precoder (i.e. sum of the squared norm) for mu = 0
    if A_is_invertible:
      A_inverse = np.linalg.inv(A)
      for user_index in range(nr_of_users):
        # The user weight belongs to the precoder (see the update below), so it
        # must be part of the power that is compared with the budget.
        candidate = user_weights[user_index]*np.matmul(A_inverse, HhUW[user_index])
        power = power + (np.linalg.norm(candidate))**2

    # If mu = 0 is a solution, then mu_star = 0
    if A_is_invertible and power <= total_power:
      mu_star = 0
    # If mu = 0 is not a solution then we search for the "optimal" mu by bisection
    else:
      power_distance = [] # list to store the distance from total_power in the bisection algorithm
      mu_lower = 0
      # at this value compute_P cannot exceed total_power, because Sigma >= 0
      mu_upper = np.sqrt(1/total_power*np.sum(Phi_diag_elements))

      obtained_power = total_power + 2*power_tolerance # initialization of the obtained power such that we enter the while

      # Bisection search
      while np.absolute(total_power - obtained_power) > power_tolerance:
        mu_new = (mu_lower + mu_upper)/2
        obtained_power = compute_P(Phi_diag_elements, Sigma_diag_elements, mu_new) # eq. (18) in the paper of Shi et al.
        power_distance.append(np.absolute(total_power - obtained_power))

        # When the midpoint coincides with an end point the bracket cannot
        # shrink any further; stop instead of looping forever.
        bracket_exhausted = (mu_new == mu_lower) or (mu_new == mu_upper)

        if obtained_power > total_power:
          mu_lower = mu_new # too much power: a larger mu is needed
        if obtained_power < total_power:
          mu_upper = mu_new # too little power: a smaller mu is needed

        if bracket_exhausted:
          break
      mu_star = mu_new

      if log:
        print("first value:", power_distance[0])
        plt.title("Distance from the target value in bisection (it should decrease)")
        plt.plot(power_distance)
        plt.show()

    # the regularized inverse is the same for every user
    regularized_inverse = np.linalg.inv(A + mu_star*identity_BS)
    for user_index in range(nr_of_users):
      new_transmitter_precoder[user_index,:,:] = user_weights[user_index]*np.matmul(regularized_inverse, HhUW[user_index])

    transmitter_precoder = new_transmitter_precoder.copy()
    receiver_precoder = new_receiver_precoder.copy()

    WSR_E.append( np.real(np.sum(np.multiply(np.log(np.squeeze(np.linalg.det(np.real(mse_weights)))),user_weights))))
    WSR.append(compute_weighted_sum_rate_MIMO(user_weights, channel, transmitter_precoder, noise_power))
    break_condition = np.absolute(WSR_E[-1] - WSR_E[-2])

  if log:
    plt.title("Change of the WSR at each iteration of the WMMSE (it should increase)")
    plt.plot(WSR,'bo')
    plt.show()

  if WSR:
    final_WSR = WSR[-1]
  else:
    # no iteration was executed (e.g. max_nr_of_iterations <= 0): report the rate of the initial precoder
    final_WSR = compute_weighted_sum_rate_MIMO(user_weights, channel, transmitter_precoder, noise_power)

  return transmitter_precoder, receiver_precoder, mse_weights, final_WSR


# Compute power for bisection search in the optimization of the transmitter precoder 
# - eq. (18) in the paper by Shi et al.
def compute_P(Phi_diag_elements, Sigma_diag_elements, mu):
  """Return sum_i Phi_i / (Sigma_i + mu)^2, the power used in the bisection search.

  Args:
    Phi_diag_elements: 1-D array of numerators.
    Sigma_diag_elements: 1-D array with the same length as Phi_diag_elements.
    mu: scalar added to every element of Sigma_diag_elements.
  """
  shift = mu*np.ones(Phi_diag_elements.size)
  power_per_element = Phi_diag_elements/np.square(Sigma_diag_elements + shift)
  return np.sum(power_per_element)


# Computes a channel realization and returns it in two formats, one for the WMMSE and one for the unfolded matrix-inverse-free WMMSE.
# It also returns the initialization value of the transmitter precoder and the receiver precoder, which are used as input in the computation graph of the unfolded matrix-inverse-free WMMSE.
def compute_channel(nr_of_BS_antennas, nr_of_users, total_power ):
  """Draw one channel realization and the matching initial precoders.

  For every user the real and imaginary parts of the channel are drawn from
  np.random.normal and scaled by sqrt(0.5). Complex matrices are also returned
  in a real-valued form [[Re, -Im], [Im, Re]] for the unfolded network.

  The function reads the module-level settings nr_of_UE_antennas,
  nr_of_data_streams and noise_power.

  Returns:
    channel_nn: list with the real-valued form of every user's channel.
    initial_transmitter_precoder: array with the real-valued form of the
      initial transmitter precoders, scaled to total power `total_power`.
    initial_receiver_precoder: list with the real-valued form of the initial
      receiver precoders.
    channel_WMMSE: complex array indexed as [user, :, :].
    initial_transmitter_precoder_WMMSE: complex array with the initial
      transmitter precoders, scaled like initial_transmitter_precoder.
  """

  def real_valued_form(real_part, imag_part):
    # [[Re, -Im], [Im, Re]]
    upper = np.concatenate((real_part, -1*imag_part), axis = 1)
    lower = np.concatenate((imag_part, real_part), axis = 1)
    return np.concatenate((upper, lower), axis = 0)

  single_stream = nr_of_data_streams == 1

  channel_nn = []
  channel_WMMSE = np.zeros((nr_of_users, nr_of_UE_antennas, nr_of_BS_antennas), dtype = complex)
  unscaled_tp_nn = []
  unscaled_tp_WMMSE = []
  initial_receiver_precoder = []
  transmitter_precoder_power = 0

  for user in range(nr_of_users):

    # the real part is drawn before the imaginary part
    channel_real = np.sqrt(0.5)*np.random.normal(size = (nr_of_UE_antennas,nr_of_BS_antennas))
    channel_imag = np.sqrt(0.5)*np.random.normal(size = (nr_of_UE_antennas,nr_of_BS_antennas))
    channel_complex = channel_real + 1j*channel_imag

    channel_WMMSE[user,:,:] = channel_complex
    channel_nn.append(real_valued_form(channel_real, channel_imag))

    if single_stream:
      ## transmitter precoder: conjugate of the sum of the channel rows
      tp = np.reshape(np.sum(channel_real,axis = 0),(nr_of_BS_antennas,1)) +1j*np.reshape(-1*np.sum(channel_imag,axis = 0),(nr_of_BS_antennas,1))

      ## receiver precoder: sum of the channel columns
      rp = np.reshape(np.sum(channel_real,axis = 1),(nr_of_UE_antennas,1)) +1j*np.reshape(np.sum(channel_imag,axis = 1),(nr_of_UE_antennas,1))
      initial_receiver_precoder.append(real_valued_form(np.real(rp), np.imag(rp)))
    else:
      ## transmitter precoder: conjugate transpose of the channel rows with the largest norm, strongest first
      row_norms = np.linalg.norm(channel_complex,axis = 1 )
      strongest_rows = (np.argsort(row_norms))[-nr_of_data_streams:][::-1]
      tp = np.transpose(np.conj(channel_complex[strongest_rows,]))

    transmitter_precoder_power = transmitter_precoder_power + np.linalg.norm(tp)**2
    unscaled_tp_WMMSE.append(tp)
    unscaled_tp_nn.append(real_valued_form(np.real(tp), np.imag(tp)))

  # scale all precoders jointly so that their total power equals total_power
  initial_transmitter_precoder = np.sqrt(total_power)*np.array(unscaled_tp_nn)/np.sqrt(transmitter_precoder_power)
  initial_transmitter_precoder_WMMSE = np.sqrt(total_power)*np.array(unscaled_tp_WMMSE)/np.sqrt(transmitter_precoder_power)

  if not single_stream:
    VVh = 0
    for user in range(nr_of_users):
      V_user = initial_transmitter_precoder_WMMSE[user,:,:]
      VVh = VVh + np.matmul(V_user, np.transpose(np.conj(V_user)))

    ## receiver precoder
    for user in range(nr_of_users):
      HV = np.matmul(channel_WMMSE[user,:,:],initial_transmitter_precoder_WMMSE[user,:,:])

      if nr_of_data_streams == nr_of_UE_antennas:
        # Initialize U as scaled identity matrix
        HVVhHh = np.matmul(np.matmul(channel_WMMSE[user,:,:],VVh),np.transpose(np.conj(channel_WMMSE[user,:,:])))
        scaling = (np.trace(HV))/(noise_power*nr_of_UE_antennas + np.trace(HVVhHh))
        rp = np.eye(nr_of_UE_antennas)*np.real(scaling)
      else:
        # Initialize U as matched filtering
        rp = HV

      initial_receiver_precoder.append(real_valued_form(np.real(rp), np.imag(rp)))

  return channel_nn, initial_transmitter_precoder, initial_receiver_precoder, channel_WMMSE, initial_transmitter_precoder_WMMSE


def compute_WSR_neural_network(H, V, noise_power,user_weights,batch_size):  
  """Build the TensorFlow expression of the weighted sum rate, averaged over the batch.

  H and V are rank-4 tensors (the last two axes are swapped to transpose
  them); axis 1 is summed over to combine the users. The function reads the
  module-level settings nr_of_UE_antennas and nr_of_users.

  The per-user rate is user_weights * 0.5 * log det(I + S (N + noise_power*I)^-1)
  with S the own-signal term and N the term of all other users.
  NOTE(review): the factor 0.5 presumably compensates for the real-valued
  [[Re, -Im], [Im, Re]] representation of the complex matrices - confirm.
  """
  swap_last_two = [0,1,3,2]
  Ht = tf.transpose(H, perm = swap_last_two)

  own_covariance = tf.matmul(V,tf.transpose(V, perm = swap_last_two))
  total_covariance = tf.expand_dims(tf.reduce_sum(own_covariance, axis = 1), axis = 1)
  other_users_covariance = total_covariance - own_covariance

  identity = tf.eye(2*nr_of_UE_antennas, 2*nr_of_UE_antennas, batch_shape = [batch_size, nr_of_users],dtype=tf.dtypes.float64)

  interference = tf.matmul(tf.matmul(H,other_users_covariance), Ht)
  signal = tf.matmul(tf.matmul(H,own_covariance), Ht)

  sinr = tf.matmul(signal,tf.linalg.inv(interference + noise_power*identity))
  rate_per_user = tf.multiply(user_weights,0.5*(tf.log(tf.linalg.det(sinr + identity))))

  # sum over users (axis 1), then over the batch (axis 0), and average over the batch
  summed_over_users = tf.reduce_sum(rate_per_user,axis = 1)
  return tf.reduce_sum(summed_over_users,axis = 0)/batch_size


# **Define the gradient descent for the updates of U and V and Schulz iterations for the update of W**

In [None]:

def Schulz(W,E,I):
  """Apply one Schulz iteration W <- W(2I - EW) and symmetrize the result.

  W, E and I are rank-4 tensors; the symmetrization averages the iterate with
  its transpose over the last two axes.
  """
  correction = 2*I - tf.matmul(E,W)
  W_next = tf.matmul(W,correction)
  W_next_transposed = tf.transpose(W_next,perm=[0,1,3,2])

  return 0.5*(W_next + W_next_transposed)


# Builds one Nesterov-accelerated GD iteration of the V update in the unfolded matrix-inverse-free WMMSE and computes the optimal step size
def GD_step_V_line_search_more_streams_Nesterov( init,init_momentum1, init_momentum2, name, mse_weights, user_weights, U,H, V,V_past, A, total_power):
  """Build one accelerated gradient step for the transmitter precoder V.

  Three float64 variables are created inside tf.variable_scope(name): a raw
  step-size parameter (initialized to `init`, mapped to (0, 2) through
  2*sigmoid) and two momentum coefficients (initialized to `init_momentum1`
  and `init_momentum2`).

  The step starts from V_a = V + momentum1*(V - V_past), the gradient is
  evaluated at V_b = V + momentum2*(V - V_past), and the step length is the
  closed-form ratio numerator/denominator computed below, multiplied by the
  learned step-size factor.

  Besides its arguments, the function reads the module-level settings
  batch_size, nr_of_BS_antennas, nr_of_users, nr_of_data_streams and
  noise_power.
  NOTE(review): all tensor arguments are transposed with perm=[0,1,3,2], i.e.
  they are rank 4; axis 1 is summed/tiled with nr_of_users. The matrices are
  presumably in the real-valued [[Re, -Im], [Im, Re]] form - confirm with the caller.

  Returns:
    (updated_transmitter_precoder, V, step_size_factor_temp, momentum1, momentum2);
    V is returned unchanged, step_size_factor_temp is the raw (pre-sigmoid) variable.
  """

  with tf.variable_scope(name): 

    epsilon_numerical_instability = 10**(-9) # added to the denominator of the step size to avoid a division by zero
    step_size_factor_temp =  tf.Variable(tf.constant(init, dtype=tf.float64), name=name, dtype=tf.float64)
    step_size_factor = 2*tf.math.sigmoid(step_size_factor_temp) # learned factor in (0, 2); equals 1 when the raw variable is 0
    momentum1 =  tf.Variable(tf.constant(init_momentum1, dtype = tf.float64), name = "momentum_1", dtype = tf.float64, trainable = True)
    momentum2 =  tf.Variable(tf.constant(init_momentum2, dtype = tf.float64), name = "momentum_2", dtype = tf.float64, trainable = True)
    I = tf.eye(nr_of_BS_antennas,batch_shape = [batch_size],dtype=tf.dtypes.float64) 

    V_a = V + momentum1*(V - V_past) # point from which the step is taken
    V_b = V + momentum2*(V - V_past) # point at which the gradient is evaluated
    
    Uh = tf.transpose(U,perm=[0,1,3,2])
    WUhU = tf.matmul(mse_weights,tf.matmul(Uh,U)) 

    # Sum over the users of trace(W U^T U), multiplied by the identity: half of the full
    # trace for the "real" block, trace of the lower-left block for the "imaginary" block
    real_sum_trace_WUhU = tf.multiply(tf.expand_dims(tf.expand_dims(tf.reduce_sum(tf.trace(WUhU)*0.5,axis =-1),axis = -1),axis = -1),I) 
    imag_sum_trace_WUhU = tf.multiply(tf.expand_dims(tf.expand_dims(tf.reduce_sum(tf.trace(WUhU[:,:,-nr_of_data_streams:,:nr_of_data_streams]),axis =-1),axis = -1),axis = -1),I) 

    # Repeat the two blocks for every user
    real_sum_trace_WUhU_exp = tf.tile(tf.expand_dims(real_sum_trace_WUhU,axis = 1),(1,nr_of_users,1,1)) 
    imag_sum_trace_WUhU_exp = tf.tile(tf.expand_dims(imag_sum_trace_WUhU,axis = 1),(1,nr_of_users,1,1)) 
    
    # Assemble the block matrix [[real, -imag], [imag, real]]
    sum_trace_WUhU_first_row = tf.concat((real_sum_trace_WUhU_exp,-1*imag_sum_trace_WUhU_exp),axis = 3)
    sum_trace_WUhU_second_row = tf.concat((imag_sum_trace_WUhU_exp,real_sum_trace_WUhU_exp),axis = 3)
    sum_trace_WUhU = tf.concat((sum_trace_WUhU_first_row,sum_trace_WUhU_second_row),axis = 2) 

    # Gradient at V_b: 2*A*V_b - 2*w*H^T*U*W + 2*(noise_power/total_power)*sum_trace_WUhU*V_b
    gradient = 2*tf.matmul(A,V_b)-2*tf.multiply(user_weights,tf.matmul(tf.matmul(tf.transpose(H, perm = [0,1,3,2]),U),mse_weights)) + 2*noise_power*(1/total_power)*tf.matmul(sum_trace_WUhU,V_b)

    ######################################
    # FIND OPTIMAL STEP SIZE##############
    ######################################

    ######################################
    ## NUMERATOR
    ######################################
    

    # FIRST AND SECOND TERMS NUMERATOR
    Vh = tf.transpose(V_a, perm = [0,1,3,2])
    GVh = tf.matmul(gradient,Vh) 
    VGh = tf.transpose(GVh,perm = [0,1,3,2])
    HG = tf.matmul(H,gradient)

    HhU = tf.matmul(tf.transpose(H, perm = [0,1,3,2]),U)

    UhH = tf.transpose(HhU, perm = [0,1,3,2])

    # user_weights[:,:,0,0] picks one weight per (batch, user) entry; every term is summed over the users
    first_term_numerator =  tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(tf.matmul(UhH,GVh),HhU))), user_weights[:,:,0,0]), axis = 1) 
    second_term_numerator =  tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(tf.matmul(UhH,VGh),HhU))), user_weights[:,:,0,0]), axis = 1) 
    

    # THIRD AND FOURTH TERMS NUMERATOR
    Gh = tf.transpose(gradient, perm = [0,1,3,2])
    third_term_numerator = tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(Gh,HhU))),user_weights[:,:,0,0]), axis =1) 
    fourth_term_numerator = tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(Uh,HG))),user_weights[:,:,0,0]), axis =1) 
  

    # FIFTH AND SIXTH TERMS NUMERATOR
    GVh_all_users  = tf.reduce_sum(GVh, axis = 1) 
    GVh_other_users = tf.tile(tf.expand_dims(GVh_all_users, axis = 1),(1,nr_of_users,1,1)) - GVh # sum over all users minus the user's own term
    VGh_other_users = tf.transpose(GVh_other_users, perm = [0,1,3,2])
    
    fifth_term_numerator = tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(UhH,tf.matmul(GVh_other_users,HhU)))),user_weights[:,:,0,0]), axis =1) 
    sixth_term_numerator = tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(UhH,tf.matmul(VGh_other_users,HhU)))),user_weights[:,:,0,0]), axis =1) 


    # SEVENTH TERM NUMERATOR
    trace_WUhU = 0.5*tf.trace(WUhU) 

    # This term carries a factor -1 and is subtracted below, so it enters the numerator with a plus sign
    seventh_term_numerator = tf.reduce_sum(tf.multiply(tf.multiply(tf.tile(tf.expand_dims(-1*noise_power*(1/total_power)*tf.trace(GVh_all_users),axis = -1),(1,nr_of_users)),trace_WUhU)\
                                                       ,user_weights[:,:,0,0]), axis =1) # B
    
    numerator = (first_term_numerator +  second_term_numerator - third_term_numerator - fourth_term_numerator + fifth_term_numerator + sixth_term_numerator - seventh_term_numerator)


    ############################################################################
    # DENOMINATOR
    ############################################################################
    
    # FIRST TERM DENOMINATOR
    GGh = tf.matmul(gradient,Gh) 
    first_term_denominator =  tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(UhH,tf.matmul(GGh,HhU)))),user_weights[:,:,0,0]),axis = -1)


    # SECOND TERM DENOMINATOR
    GGh_all_users = tf.reduce_sum(GGh,axis = 1) 
    GGh_other_users = tf.tile(tf.expand_dims(GGh_all_users,axis = 1),(1,nr_of_users,1,1)) - GGh # sum over all users minus the user's own term

    second_term_denominator = tf.reduce_sum(tf.multiply(0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(UhH,tf.matmul(GGh_other_users,HhU)))),user_weights[:,:,0,0]),axis = -1)


    # THIRD TERM DENOMINATOR
    GGh_trace = tf.tile(tf.expand_dims(tf.reduce_sum(0.5*tf.trace(GGh),axis = 1),axis = -1),(1,nr_of_users)) 
    third_term_denominator = noise_power*(1/total_power)*tf.reduce_sum(tf.multiply( tf.multiply(GGh_trace,trace_WUhU),user_weights[:,:,0,0]),axis = 1) 

    
    denominator = 2*(first_term_denominator + second_term_denominator + third_term_denominator) 

    # One step size per batch entry, repeated for every user and expanded so that it broadcasts against the gradient
    step_size = tf.expand_dims(tf.expand_dims(tf.tile(tf.expand_dims((numerator/(denominator + epsilon_numerical_instability)),axis = -1),(1,nr_of_users)), axis = -1), axis = -1)
    updated_transmitter_precoder = V_a -step_size*step_size_factor*gradient 
    
    return updated_transmitter_precoder, V, step_size_factor_temp, momentum1, momentum2
   
   

# Builds one Nesterov-accelerated GD iteration of the U update in the unfolded matrix-inverse-free WMMSE and computes the optimal step size
def GD_step_U_line_search_more_streams_Nesterov( init, init_momentum1, init_momentum2, name, mse_weights, user_weights, U, U_past, H, V,VVh, noise_power):
  """Build one accelerated gradient step for the receiver precoder U.

  Three float64 variables are created inside tf.variable_scope(name): a raw
  step-size parameter (initialized to `init`, mapped to (0, 2) through
  2*sigmoid) and two momentum coefficients (initialized to `init_momentum1`
  and `init_momentum2`).

  The step starts from U_a = U + momentum1*(U - U_past), the gradient is
  evaluated at U_b = U + momentum2*(U - U_past), and the step length is the
  closed-form ratio numerator/denominator computed below, multiplied by the
  learned step-size factor.

  NOTE: total_power is not a parameter of this function; it is read from the
  module-level setting, as are nr_of_users, nr_of_UE_antennas and
  nr_of_data_streams.
  NOTE(review): all tensor arguments are transposed with perm=[0,1,3,2], i.e.
  they are rank 4. The matrices are presumably in the real-valued
  [[Re, -Im], [Im, Re]] form - confirm with the caller.

  Returns:
    (updated_receiver_precoder, U, step_size_factor_temp, momentum1, momentum2);
    U is returned unchanged, step_size_factor_temp is the raw (pre-sigmoid) variable.
  """

  with tf.variable_scope(name): 
    
    epsilon_numerical_instability = 10**(-9) # added to the denominator of the step size to avoid a division by zero
    step_size_factor_temp =  tf.Variable(tf.constant(init, dtype=tf.float64), name=name, dtype=tf.float64)
    step_size_factor = 2*tf.math.sigmoid(step_size_factor_temp) # learned factor in (0, 2); equals 1 when the raw variable is 0
    momentum1 =  tf.Variable(tf.constant(init_momentum1, dtype = tf.float64), name = "momentum_1", dtype = tf.float64,trainable = True)
    momentum2 =  tf.Variable(tf.constant(init_momentum2, dtype = tf.float64), name = "momentum_2", dtype = tf.float64,trainable = True)

    U_a = U + momentum1*(U - U_past) # point from which the step is taken
    U_b = U + momentum2*(U - U_past) # point at which the gradient is evaluated

    first_term = -1*tf.matmul(H,V)
    second_term = tf.matmul(tf.matmul(tf.matmul(H, VVh), tf.transpose(H, perm = [0,1,3,2])), U_b)

    # 0.5*trace(VVh) tiled to the size of U (2*nr_of_UE_antennas x 2*nr_of_data_streams) for the element-wise product below
    power_V = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(tf.tile(tf.trace(VVh)*0.5,(1,nr_of_users)),axis = -1),(1,1,2*nr_of_UE_antennas)),axis = -1),(1,1,1,2*nr_of_data_streams)) 
    third_term = noise_power*(1/total_power)*tf.multiply(U_b,power_V)

    # Gradient at U_b: 2*w*(-H*V + H*VVh*H^T*U_b + (noise_power/total_power)*power_V*U_b)*W
    gradient = 2*tf.multiply(user_weights,tf.matmul((first_term + second_term + third_term), mse_weights)) 

    ######################################
    # FIND OPTIMAL STEP SIZE##############
    ######################################

    #############################
    # NUMERATOR
    #############################
    # power_V is rebound here with a different tiling (2*nr_of_data_streams x 2*nr_of_data_streams), matching the U^T*gradient products below
    power_V = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(tf.tile(tf.trace(VVh)*0.5,(1,nr_of_users)),axis = -1),(1,1,2*nr_of_data_streams)),axis = -1),(1,1,1,2*nr_of_data_streams)) 

    first_term_num = 0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(tf.matmul(tf.matmul(tf.matmul(tf.transpose(U_a, perm = [0,1,3,2]),H),VVh),tf.transpose(H, perm = [0,1,3,2])), gradient)))
    second_term_num = 0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(tf.matmul(tf.matmul(tf.matmul(tf.transpose(gradient, perm = [0,1,3,2]),H),VVh),tf.transpose(H, perm = [0,1,3,2])), U_a)))

    third_term_num = 0.5*tf.trace(tf.matmul(mse_weights,noise_power*(1/total_power)*tf.multiply(tf.matmul(tf.transpose(U_a,perm = [0,1,3,2]),gradient),power_V)))
    fourth_term_num = 0.5*tf.trace(tf.matmul(mse_weights,noise_power*(1/total_power)*tf.multiply(tf.matmul(tf.transpose(gradient,perm = [0,1,3,2]),U_a),power_V)))

    fifth_term_num = -1*0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(tf.matmul(tf.transpose(V, perm = [0,1,3,2]), tf.transpose(H, perm = [0,1,3,2])), gradient)))
    sixth_term_num = -1*0.5*tf.trace(tf.matmul(mse_weights,tf.matmul(tf.matmul(tf.transpose(gradient, perm = [0,1,3,2]), H), V)))

    numerator = first_term_num + second_term_num + third_term_num + fourth_term_num + fifth_term_num + sixth_term_num

    #############################
    # DENOMINATOR
    ##############################

    # The denominator terms carry no factor 0.5, unlike the numerator terms
    first_term_den = tf.trace(tf.matmul(mse_weights,noise_power*(1/total_power)*tf.multiply(tf.matmul(tf.transpose(gradient, perm = [0,1,3,2]), gradient),power_V)))
    second_term_den = tf.trace(tf.matmul(mse_weights,tf.matmul(tf.matmul(tf.matmul(tf.matmul(tf.transpose(gradient, perm = [0,1,3,2]), H), VVh), tf.transpose(H, perm = [0,1,3,2])), gradient)))
    denominator = first_term_den + second_term_den + epsilon_numerical_instability

    # One step size per (batch, user) entry, expanded so that it broadcasts against the gradient
    step_size = tf.expand_dims(tf.expand_dims((numerator/denominator),axis = -1), axis = -1)

    updated_receiver_precoder = U_a - step_size*step_size_factor*gradient

  # The return statement sits outside the variable scope; the returned objects were all created inside it
  return updated_receiver_precoder, U, step_size_factor_temp, momentum1, momentum2

# **Tensorflow computation graph to run the matrix-inverse-free WMMSE algorithm**

In [None]:
# Start from an empty default graph so that re-running this cell does not add the operations twice
tf.reset_default_graph()

# Inputs of the computation graph. All placeholders are float64 and leave the shape unspecified (shape = None).
channel_input = tf.placeholder(tf.float64, shape = None, name = 'channel_input')
initial_tp = tf.placeholder(tf.float64, shape = None, name = 'initial_transmitter_precoder') 
initial_tp_past = tf.placeholder(tf.float64, shape = None, name = 'initial_transmitter_precoder_past') # needed for the acceleration scheme

initial_rp = tf.placeholder(tf.float64, shape = None, name = 'initial_receiver_precoder') 
initial_rp_past = tf.placeholder(tf.float64, shape = None, name = 'initial_receiver_precoder_past') # needed for the acceleration scheme

# Working names for the precoders; at this point they simply refer to the placeholders above
initial_transmitter_precoder = initial_tp
initial_transmitter_precoder_past = initial_tp_past
initial_receiver_precoder = initial_rp
initial_receiver_precoder_past = initial_rp_past

# Every list below holds seven initial values, all equal to 0.0.
# NOTE(review): presumably one entry per unfolded iteration, indexed by the loop counter - confirm where the lists are used.

# Initial values of the step sizes for the receiver precoder (U)
step_size_factor1_init_U = [0.0] * 7
step_size_factor2_init_U = [0.0] * 7

# Initial values of the first momentum term for the receiver precoder (U)
momentum1_1_init_U = [0.0] * 7
momentum1_2_init_U = [0.0] * 7

# Initial values of the second momentum term for the receiver precoder (U)
momentum2_1_init_U = [0.0] * 7
momentum2_2_init_U = [0.0] * 7

# Initial values of the step sizes for the transmitter precoder (V)
step_size_factor1_init_V = [0.0] * 7
step_size_factor2_init_V = [0.0] * 7
step_size_factor3_init_V = [0.0] * 7
step_size_factor4_init_V = [0.0] * 7

# Initial values of the first momentum term for the transmitter precoder (V)
momentum1_1_init_V = [0.0] * 7
momentum1_2_init_V = [0.0] * 7
momentum1_3_init_V = [0.0] * 7
momentum1_4_init_V = [0.0] * 7

# Initial values of the second momentum term for the transmitter precoder (V)
momentum2_1_init_V = [0.0] * 7
momentum2_2_init_V = [0.0] * 7
momentum2_3_init_V = [0.0] * 7
momentum2_4_init_V = [0.0] * 7

profit = [] # stores the WSR obtained at each iteration
profit_alternative = [] # stores the WSR (computed through the mse weights) obtained at each iteration for the training

# Unit user weights, one per (batch entry, user), and copies tiled to the matrix sizes they are multiplied with:
# 2*nr_of_UE_antennas x 2*nr_of_data_streams (U), 2*nr_of_BS_antennas x 2*nr_of_BS_antennas (A), 2*nr_of_BS_antennas x 2*nr_of_data_streams (V)
user_weights = tf.ones((batch_size,nr_of_users),dtype=tf.float64)
user_weights_U_expanded = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(user_weights,-1),(1,1,2*nr_of_UE_antennas)),-1), (1,1,1,2*nr_of_data_streams)) 

user_weights_A_expanded = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(user_weights,-1),(1,1,2*nr_of_BS_antennas)),-1),(1,1,1,2*nr_of_BS_antennas))
user_weights_V_expanded = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(user_weights,-1),(1,1,2*nr_of_BS_antennas)),-1),(1,1,1,2*nr_of_data_streams))


####################################
# UPDATE OF MSE WEIGHTS
####################################
# E is built from the initial precoders; the MSE weights are then obtained as an approximate inverse of E (Schulz iterations below)
I = tf.eye(2*nr_of_data_streams, 2*nr_of_data_streams, batch_shape = [batch_size, nr_of_users],dtype=tf.dtypes.float64) 

#FIRST TERM OF E
# (I - U^T H V)(I - U^T H V)^T
I_UhHV = I - tf.matmul(tf.matmul(tf.transpose(initial_receiver_precoder,perm = [0,1,3,2]),channel_input), initial_transmitter_precoder)
first_term = tf.matmul(I_UhHV, tf.transpose(I_UhHV, perm = [0,1,3,2]))

#SECOND TERM OF E
# U^T H (sum over the other users of V V^T) H^T U
VVh_single_user = tf.matmul(initial_transmitter_precoder,tf.transpose(initial_transmitter_precoder, perm = [0,1,3,2]))
VVh = tf.expand_dims(tf.reduce_sum(VVh_single_user, axis = 1), axis = 1) 
VVh_other_users = VVh - VVh_single_user
UhHVVhHhU = tf.matmul(tf.matmul(tf.matmul(tf.matmul(tf.transpose(initial_receiver_precoder, perm = [0,1,3,2]), channel_input),VVh_other_users), tf.transpose(channel_input, perm = [0,1,3,2])),initial_receiver_precoder)
second_term = UhHVVhHhU 

#THIRD TERM OF E 
# (noise_power/total_power) * 0.5*trace(VVh) * U^T U, with the trace tiled for the element-wise product
power_V = tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(tf.tile(tf.trace(VVh)*0.5,(1,nr_of_users)),axis = -1),(1,1,2*nr_of_data_streams)),axis = -1),(1,1,1,2*nr_of_data_streams)) 
third_term = noise_power *(1/total_power)* tf.multiply(tf.matmul(tf.transpose(initial_receiver_precoder, perm = [0,1,3,2]),initial_receiver_precoder),power_V)

E = first_term + second_term + third_term 

###########################################
# SCHULZ WITH SPECTRAL RADIUS NORMALIZATION
###########################################

# Starting point: identity divided by 0.5*trace(E)
mse_weights_init = tf.multiply(I,tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(tf.reciprocal(0.5*tf.trace(E)),axis = -1),(1,1,2*nr_of_data_streams)),axis = -1),(1,1,1,2*nr_of_data_streams))) 

D = tf.matmul(E,mse_weights_init)
# Element-wise magnitude, combining the squared top-left and top-right blocks of D
abs_D = tf.sqrt((D**2)[:,:,:nr_of_data_streams,:nr_of_data_streams] + (D**2)[:,:,:nr_of_data_streams,-nr_of_data_streams:] ) 
sum_abs_D = tf.expand_dims(tf.reduce_sum(abs_D, axis = 2),axis = -1) # column sums of abs_D as a column vector

# Largest entry of abs_D times its column sums; its square root rescales the starting point
g = tf.reduce_max(tf.matmul(abs_D,sum_abs_D),axis = [-2,-1]) 
scaling = tf.sqrt(tf.tile(tf.expand_dims(tf.tile(tf.expand_dims(g,axis = -1),(1,1,2*nr_of_data_streams)),axis = -1),(1,1,1,2*nr_of_data_streams))) 
                    
mse_weights_input = tf.divide(mse_weights_init,scaling)

# Two Schulz iterations towards the inverse of E
mse_weights1 = Schulz(mse_weights_input,E,I)
mse_weights = Schulz(mse_weights1,E,I)


# Builds the graph of the unfolded algorithm: nr_of_iterations_nn iterations,
# each made of two receiver precoder steps, an update of the mse weights
# (matrix E followed by two Schulz calls) and four transmitter precoder steps.
# After every iteration the WSR of the resulting transmitter precoder is
# appended to `profit`.

def _scale_to_total_power(precoder):
    """Return `precoder` rescaled so that every sample uses exactly total_power.

    The power of a sample is half the sum, over axis 1 (the user axis), of the
    squared Frobenius norms of the per-user matrices. All matrices of a sample
    are multiplied by the same factor sqrt(total_power) / sqrt(power).
    """
    precoder_power = tf.expand_dims(
        tf.reduce_sum(0.5*(tf.norm(precoder, axis=[-2, -1])**2), axis=1),
        axis=-1)
    power_scaling_ref = (tf.divide(1, tf.sqrt(precoder_power))
                         * tf.sqrt(tf.cast(total_power, dtype=tf.float64)))
    # Repeat the per-sample factor for every user and every matrix entry
    power_scaling_per_user = tf.tile(power_scaling_ref, (1, nr_of_users))
    power_scaling_expanded = tf.tile(
        tf.expand_dims(tf.expand_dims(power_scaling_per_user, axis=-1), axis=-1),
        (1, 1, 2*nr_of_BS_antennas, 2*nr_of_data_streams))
    return tf.multiply(precoder, power_scaling_expanded)


# Transpose of the channel over its last two axes; identical in every iteration
channel_input_transposed = tf.transpose(channel_input, perm=[0, 1, 3, 2])

for loop in range(nr_of_iterations_nn):

    # V V^T per user and its sum over the users. The transmitter precoder is
    # only replaced at the end of the iteration, so these tensors serve both
    # the receiver precoder update and the mse weights update.
    VVh_single_user = tf.matmul(
        initial_transmitter_precoder,
        tf.transpose(initial_transmitter_precoder, perm=[0, 1, 3, 2]))
    VVh = tf.expand_dims(tf.reduce_sum(VVh_single_user, axis=1), axis=1)
    VVh_other_users = VVh - VVh_single_user

    #######################################
    # UPDATE OF RECEIVER PRECODER
    #######################################
    # Two chained steps: the second one starts from the outputs of the first.
    # NOTE(review): initial_receiver_precoder_past is never replaced by
    # receiver_precoder_past2, so every iteration restarts from the same
    # "past" tensor -- confirm this is intended.
    receiver_precoder1, receiver_precoder_past1, step_size1_U, momentum1_1_U, momentum2_1_U = GD_step_U_line_search_more_streams_Nesterov(
        step_size_factor1_init_U[loop], momentum1_1_init_U[loop], momentum2_1_init_U[loop], "GD1",
        mse_weights, user_weights_U_expanded,
        initial_receiver_precoder, initial_receiver_precoder_past,
        channel_input, initial_transmitter_precoder, VVh, noise_power)
    receiver_precoder_final, receiver_precoder_past2, step_size2_U, momentum1_2_U, momentum2_2_U = GD_step_U_line_search_more_streams_Nesterov(
        step_size_factor2_init_U[loop], momentum1_2_init_U[loop], momentum2_2_init_U[loop], "GD2",
        mse_weights, user_weights_U_expanded,
        receiver_precoder1, receiver_precoder_past1,
        channel_input, initial_transmitter_precoder, VVh, noise_power)

    ####################################
    # UPDATE OF MSE WEIGHTS
    ####################################
    receiver_precoder_final_transposed = tf.transpose(receiver_precoder_final, perm=[0, 1, 3, 2])
    UhH = tf.matmul(receiver_precoder_final_transposed, channel_input)

    # FIRST TERM OF E: (I - U^T H V)(I - U^T H V)^T
    I_UhHV = I - tf.matmul(UhH, initial_transmitter_precoder)
    first_term = tf.matmul(I_UhHV, tf.transpose(I_UhHV, perm=[0, 1, 3, 2]))

    # SECOND TERM OF E: U^T H (sum of V V^T over the other users) H^T U
    UhHVVhHhU = tf.matmul(
        tf.matmul(tf.matmul(UhH, VVh_other_users), channel_input_transposed),
        receiver_precoder_final)
    second_term = UhHVVhHhU

    # THIRD TERM OF E: (noise_power / total_power) * (U^T U) scaled elementwise
    # by half the trace of the summed V V^T
    half_trace_VVh = tf.tile(tf.trace(VVh)*0.5, (1, nr_of_users))
    power_V = tf.tile(
        tf.expand_dims(tf.expand_dims(half_trace_VVh, axis=-1), axis=-1),
        (1, 1, 2*nr_of_data_streams, 2*nr_of_data_streams))
    third_term = noise_power * (1/total_power) * tf.multiply(
        tf.matmul(receiver_precoder_final_transposed, receiver_precoder_final), power_V)

    E = first_term + second_term + third_term

    ###########################################
    # SCHULZ WITH SPECTRAL RADIUS NORMALIZATION
    ###########################################
    # The mse weights of the previous iteration, divided by sqrt(g), are the
    # starting point of the two Schulz calls.
    D = tf.matmul(E, mse_weights)
    D_squared = D**2
    abs_D = tf.sqrt(D_squared[:, :, :nr_of_data_streams, :nr_of_data_streams]
                    + D_squared[:, :, :nr_of_data_streams, -nr_of_data_streams:])
    sum_abs_D = tf.expand_dims(tf.reduce_sum(abs_D, axis=2), axis=-1)

    g = tf.reduce_max(tf.matmul(abs_D, sum_abs_D), axis=[-2, -1])
    scaling = tf.sqrt(tf.tile(
        tf.expand_dims(tf.expand_dims(g, axis=-1), axis=-1),
        (1, 1, 2*nr_of_data_streams, 2*nr_of_data_streams)))

    mse_weights_input = tf.divide(mse_weights, scaling)

    mse_weights1 = Schulz(mse_weights_input, E, I)
    mse_weights = Schulz(mse_weights1, E, I)

    ##########################################
    # UPDATE OF TRANSMITTER PRECODER
    ##########################################
    # A: user-weighted sum over the users of H^T U W U^T H
    HhUWUhH = tf.matmul(
        tf.matmul(
            tf.matmul(tf.matmul(channel_input_transposed, receiver_precoder_final), mse_weights),
            receiver_precoder_final_transposed),
        channel_input)
    A = tf.expand_dims(
        tf.reduce_sum(tf.multiply(user_weights_A_expanded, HhUWUhH), axis=1),
        axis=1)

    # Four chained steps: each one starts from the outputs of the previous one.
    # NOTE(review): initial_transmitter_precoder_past is never replaced by
    # transmitter_precoder_past4, so every iteration restarts from the same
    # "past" tensor -- confirm this is intended.
    transmitter_precoder1, transmitter_precoder_past1, step_size1_V, momentum1_1_V, momentum2_1_V = GD_step_V_line_search_more_streams_Nesterov(
        step_size_factor1_init_V[loop], momentum1_1_init_V[loop], momentum2_1_init_V[loop], "step_size1_V",
        mse_weights, user_weights_V_expanded, receiver_precoder_final, channel_input,
        initial_transmitter_precoder, initial_transmitter_precoder_past, A, total_power)
    transmitter_precoder2, transmitter_precoder_past2, step_size2_V, momentum1_2_V, momentum2_2_V = GD_step_V_line_search_more_streams_Nesterov(
        step_size_factor2_init_V[loop], momentum1_2_init_V[loop], momentum2_2_init_V[loop], "step_size2_V",
        mse_weights, user_weights_V_expanded, receiver_precoder_final, channel_input,
        transmitter_precoder1, transmitter_precoder_past1, A, total_power)
    transmitter_precoder3, transmitter_precoder_past3, step_size3_V, momentum1_3_V, momentum2_3_V = GD_step_V_line_search_more_streams_Nesterov(
        step_size_factor3_init_V[loop], momentum1_3_init_V[loop], momentum2_3_init_V[loop], "step_size3_V",
        mse_weights, user_weights_V_expanded, receiver_precoder_final, channel_input,
        transmitter_precoder2, transmitter_precoder_past2, A, total_power)
    transmitter_precoder_final, transmitter_precoder_past4, step_size4_V, momentum1_4_V, momentum2_4_V = GD_step_V_line_search_more_streams_Nesterov(
        step_size_factor4_init_V[loop], momentum1_4_init_V[loop], momentum2_4_init_V[loop], "step_size4_V",
        mse_weights, user_weights_V_expanded, receiver_precoder_final, channel_input,
        transmitter_precoder3, transmitter_precoder_past3, A, total_power)

    ##############################################################################
    # Starting points of the next iteration
    ##############################################################################
    initial_receiver_precoder = receiver_precoder_final

    # The transmitter precoder is rescaled to meet the power constraint with
    # equality either after every iteration (scale_V_every_iteration) or, when
    # that option is off, only after the last iteration.
    if scale_V_every_iteration or loop == nr_of_iterations_nn - 1:
        initial_transmitter_precoder = _scale_to_total_power(transmitter_precoder_final)
    else:
        initial_transmitter_precoder = transmitter_precoder_final

    # LOSS FUNCTION: WSR reached after this iteration
    profit.append(compute_WSR_neural_network(channel_input, initial_transmitter_precoder, noise_power, user_weights, batch_size))

    if loop == nr_of_iterations_nn - 2:

        # The precoder of the second-to-last iteration is scaled to meet the
        # power constraint before its WSR is recorded.
        # NOTE(review): when scale_V_every_iteration is off, the rescaled
        # precoder also becomes the starting point of the last iteration --
        # confirm this is intended.
        if not scale_V_every_iteration:
            initial_transmitter_precoder = _scale_to_total_power(transmitter_precoder_final)

        # WSR given by the precoder of the second-to-last iteration
        WSR_from_V_previous_iteration = compute_WSR_neural_network(channel_input, initial_transmitter_precoder, noise_power, user_weights, batch_size)

# WSR reached after the last unfolded iteration; this is the value reported at test time
WSR_final = profit[-1]

# Training objective: the WSR values of all unfolded iterations added together
WSR = tf.reduce_sum(profit)

# `optimizer` is the training op: one Adam step (learning rate 0.001) that
# minimises -WSR, i.e. maximises the summed WSR
_adam = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer = _adam.minimize(-WSR)


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


# **Running the unfolded matrix-inverse-free WMMSE and the WMMSE**
In this code, we train the unfolded matrix-inverse-free WMMSE network and we compare its performance with the WMMSE by testing both approaches on the same dataset of channel realizations.

In [None]:
# Result containers
WSR_WMMSE = []      # WSR attained by the WMMSE
WSR_ZF = []         # WSR attained by the zero-forcing
WSR_RZF = []        # WSR attained by the regularized zero-forcing
WSR_nn = []         # WSR attained by the deep unfolded WMMSE (one entry per test batch)
training_loss = []  # value of -WSR after every training step
WSR_from_W = []

# All-zero "past" precoders fed to the acceleration scheme; the same arrays
# are reused for every training and test batch.
_past_transmitter_shape = (nr_of_samples_per_batch, nr_of_users,
                           nr_of_BS_antennas*2, 2*nr_of_data_streams)
_past_receiver_shape = (nr_of_samples_per_batch, nr_of_users,
                        nr_of_UE_antennas*2, 2*nr_of_data_streams)
initial_transmitter_precoder_batch_past = np.zeros(_past_transmitter_shape)
initial_receiver_precoder_batch_past = np.zeros(_past_receiver_shape)

# Trains the unfolded network on freshly drawn channel batches, then evaluates
# it and the WMMSE baseline on the same test channels.
with tf.Session() as sess:

    # Fixed seed for the training channels.
    # NOTE(review): the original author marked this seed for removal.
    np.random.seed(5678)
    print("start of session")
    start_of_time = time.time()  # recorded but not read in this part of the file
    sess.run(tf.global_variables_initializer())

    ############################
    # TRAINING
    ############################
    for i in range(nr_of_batches_training):
        batch_for_training = []
        initial_transmitter_precoder_batch = []
        initial_receiver_precoder_batch = []

        # Draw one batch of channels with their initial precoders; the last
        # two return values (used by the WMMSE baseline at test time) are dropped
        for ii in range(nr_of_samples_per_batch):
            channel_realization_nn, init_transmitter_precoder, init_receiver_precoder, _, _ = compute_channel(nr_of_BS_antennas, nr_of_users, total_power)
            batch_for_training.append(channel_realization_nn)
            initial_transmitter_precoder_batch.append(init_transmitter_precoder)
            initial_receiver_precoder_batch.append(init_receiver_precoder)

        training_feed = {
            channel_input: batch_for_training,
            initial_tp: initial_transmitter_precoder_batch,
            initial_tp_past: initial_transmitter_precoder_batch_past,
            initial_rp: initial_receiver_precoder_batch,
            initial_rp_past: initial_receiver_precoder_batch_past,
        }

        # One optimizer step, then a second run on the same batch to log the
        # loss obtained with the updated parameters
        sess.run(optimizer, feed_dict=training_feed)
        training_loss.append(-sess.run(WSR, feed_dict=training_feed))

    ############################
    # TESTING
    ############################
    # For repeatability
    np.random.seed(1234)

    # Running sum of the WMMSE WSR over all test samples
    WSR_WMMSE_batch = 0.0

    for i in range(nr_of_batches_test):
        batch_for_testing = []
        initial_transmitter_precoder_batch = []
        initial_receiver_precoder_batch = []

        # Draw one test batch; every sample is also solved by the WMMSE baseline
        for ii in range(nr_of_samples_per_batch):
            (channel_realization_nn, init_transmitter_precoder, init_receiver_precoder,
             channel_WMMSE, initial_transmitter_precoder_WMMSE) = compute_channel(nr_of_BS_antennas, nr_of_users, total_power)

            batch_for_testing.append(channel_realization_nn)
            initial_transmitter_precoder_batch.append(init_transmitter_precoder)
            initial_receiver_precoder_batch.append(init_receiver_precoder)

            _, _, _, WSR_WMMSE_one_sample = run_WMMSE_MIMO_more_streams(epsilon, channel_WMMSE, initial_transmitter_precoder_WMMSE, total_power, noise_power, user_weights_WMMSE, nr_of_iterations_WMMSE, log = False)
            WSR_WMMSE_batch += WSR_WMMSE_one_sample

        testing_feed = {
            channel_input: batch_for_testing,
            initial_tp: initial_transmitter_precoder_batch,
            initial_tp_past: initial_transmitter_precoder_batch_past,
            initial_rp: initial_receiver_precoder_batch,
            initial_rp_past: initial_receiver_precoder_batch_past,
        }

        # WSR of the unfolded network after its last iteration on this batch
        WSR_nn.append(sess.run(WSR_final, feed_dict=testing_feed))

# Report the test results: the unfolded network is averaged over the test
# batches, the WMMSE sum is averaged over all test samples.
print("The WSR achieved with unfolded matrix-inverse-free WMMSE is: ", np.mean(WSR_nn))
print("The WSR achieved with the WMMSE algorithm is: ", WSR_WMMSE_batch/(nr_of_samples_per_batch*nr_of_batches_test))

# Training curve. training_loss holds one value per training batch, so the
# x-axis counts training batches, not individual channel samples.
plt.figure()
plt.plot(training_loss)
plt.ylabel("Training loss")
plt.xlabel("Training batch index")
