## Importing Libraries

In [1]:
import numpy as np # Importing NumPy for numerical computing, such as array operations and mathematical functions.
import matplotlib.pyplot as plt # Importing Pyplot from Matplotlib for data visualization and plotting graphs.
from matplotlib import cm # Importing 'cm' from Matplotlib for access to color maps used in plotting.
from scipy.special import erfc # Importing 'erfc' (complementary error function) from SciPy's special functions module.
import pandas as pd # Importing Pandas for data manipulation and analysis, particularly for structured data operations.
from modem import PSKModem, QAMModem, PAMModem, FSKModem # Importing various modem types (Phase Shift Keying, Quadrature Amplitude Modulation, Pulse Amplitude Modulation, Frequency Shift Keying) from a custom 'modem' module.
from channels import awgn, rayleighFading, TSMG # Importing functions for simulating different types of channels: Additive White Gaussian Noise, Rayleigh Fading, and TSMG (possibly a custom channel type) from a custom 'channels' module.
from errorRates import ser_rayleigh # Importing a function to calculate symbol error rate in Rayleigh fading channel conditions from a custom 'errorRates' module.
from tqdm.notebook import tqdm # Importing 'tqdm' from tqdm.notebook for displaying progress bars in Jupyter notebooks.
import torch

## Defining the RL-adapted WSN simulation Environment (REINFORCE): 

In [103]:
class WSN_env_RL():
    """Relay-selection environment for a simulated wireless sensor network (WSN).

    By default node 1 is the source, node ``M`` is the destination and the
    nodes in between are candidate relays.  Every time step covers ``Tc``
    symbols: the source transmits to the destination and to every candidate
    relay, each relay forwards the symbols it detected, and the destination
    combines the direct and the relayed copy.  The symbol error rates (SER)
    obtained through every relay are stored in ``self.log`` so that the *next*
    call to :meth:`training_step` can score the relay an agent selected from
    the state it was handed.
    """

    ################################################## Initialization ##################################################

    def __init__(self, M=10, Tc = 20,Env_size = 30,Modem_type='QAM',Mod_order = 4,p_B = 0,R = 20,NOISE_MEMORY=20,random_seed = 10):
        """Build the network and run one warm-up step.

        Parameters:
            M: number of nodes in the network (source and destination included).
            Tc: coherence time, expressed in symbol durations.
            Env_size: side length of the square grid the nodes are placed on.
            Modem_type: 'PSK', 'QAM' or 'PAM' (case-insensitive); any other
                value raises ``KeyError``.
            Mod_order: modulation order handed to the modem constructor.
            p_B, R, NOISE_MEMORY: forwarded unchanged to ``TSMG`` for the
                source -> relay links.
            random_seed: stored on the instance only.  It is NOT used to seed
                any random generator; callers that need reproducibility seed
                NumPy themselves (``BaseAgent`` does so after construction).
        """
        self.M = M
        self.p_B = p_B
        self.R = R
        self.NOISE_MEMORY = NOISE_MEMORY
        self.random_seed = random_seed
        self.agent = None
        # One entry per non-destination node.  Relay action ``a`` is charged at
        # index ``a + 1``; index 0 is never charged (presumably the source).
        self.node_battery_states = np.ones(M-1)

        # Tc/Ts ratio
        self.Tc = Tc

        # Modulation parameters
        modem_dict = {'psk': PSKModem,'qam':QAMModem,'pam':PAMModem}
        self.Modem = modem_dict[Modem_type.lower()](Mod_order)
        self.Mod_order = Mod_order

        # Node and link characteristics
        self.Env_size = Env_size
        _ = self.reset()

    def reset(self,n_time_step = 1):
        """Draw a new topology and new links, clear the log and take one step.

        Battery levels are left untouched.  Returns whatever
        :meth:`training_step` returns for ``n_time_step`` steps.
        """
        self.node_positions = self.init_positions()
        # initiate_links() creates the links_CSI table from scratch.
        self.initiate_links()
        self.log = []
        return self.training_step(n_time_step = n_time_step)

    def init_positions(self):
        """Return an ``(M, 2)`` integer array of pairwise distinct grid positions.

        The first node is pinned to ``(0, Env_size-1)`` and the last node to
        ``(Env_size-1, 0)``; the remaining nodes are drawn uniformly and the
        whole layout is redrawn until no two nodes share a position.
        """
        while True:
            node_positions = np.array([np.random.randint(0,self.Env_size,self.M),np.random.randint(0,self.Env_size,self.M)]).T
            node_positions[0] = np.array([0,self.Env_size-1])
            node_positions[-1] = np.array([self.Env_size-1,0])
            # All rows distinct <=> as many unique coordinate pairs as rows.
            if len({tuple(position) for position in node_positions}) == len(node_positions):
                return node_positions

    def initiate_links(self):
        """Create the ``links_CSI`` table with one row per ordered node pair.

        Columns: link name ``'i->j'``, normalised distance, channel
        coefficient ``h``, ``|h|`` and a ``'G/B'`` flag initialised to ``'G'``.
        """
        self.links_CSI = pd.DataFrame(columns=['link','distance','h','h_abs','G/B'])
        index = 0
        for i in range(1,self.M+1):
            for j in range(1,self.M+1):
                if i != j:
                    d = self.distance(i,j)
                    h = self.update_one_link(d)
                    self.links_CSI.loc[index] = [str(i)+'->'+str(j),d,h,abs(h),'G']
                    index += 1

    def visualize(self):
        """Plot the node layout with matplotlib (nodes numbered, destination marked 'D')."""
        print('visualize  environment :')

        # Double-headed arrows between node indices 0..M-2 and 1..M-1.
        for i in range(self.M - 1):
            for j in range(1,self.M):
                plt.annotate("", xy=(self.node_positions[i,0], self.node_positions[i,1]), xytext=(self.node_positions[j,0], self.node_positions[j,1]),
                arrowprops=dict(arrowstyle="<->"))

        # Labels: 1..M-1 for the non-destination nodes, 'D' for the last node.
        for i, txt in enumerate(np.arange(1,self.M)):
            plt.annotate(txt, (self.node_positions[i,0]+0.1, self.node_positions[i,1]+0.1),color='green',size=20)
        plt.annotate('D', (self.node_positions[-1,0]+0.1, self.node_positions[-1,1]+0.1),color='blue',size=20)

        plt.scatter(self.node_positions[:,0], self.node_positions[:,1],
            c='g', alpha=0.6, lw=0,s=400)
        plt.scatter(self.node_positions[-1,0], self.node_positions[-1,1], c='g', alpha=0.6, lw=0,s=400  )
        plt.xlim(0,self.Env_size)
        plt.ylim(0,self.Env_size)
        plt.axis('off')
        plt.show()

    ################################################## Helper functions ##################################################

    def update_all_links(self):
        """Redraw the channel coefficient of every link from its stored distance."""
        self.links_CSI['h'] = self.links_CSI['distance'].apply(self.update_one_link)
        self.links_CSI['h_abs'] = self.links_CSI['h'].apply(np.abs)

    def update_one_link(self, d):
        """Draw one channel coefficient for a link of normalised length ``d``.

        Takes the first element returned by ``rayleighFading(1, d, 2)`` and
        returns it as a complex NumPy scalar array.
        """
        # The builtin `complex` replaces the 'complex_' alias removed in NumPy 2.0.
        return np.array(rayleighFading(1,d,2)[0],dtype = complex)

    def distance(self, node1, node2):
        """Euclidean distance between two nodes (1-based ids), normalised.

        The normalisation constant is the distance between the first and the
        last node of ``node_positions``.
        """
        reference = np.linalg.norm(self.node_positions[0] - self.node_positions[-1],ord=2)
        return np.linalg.norm(self.node_positions[node1-1] - self.node_positions[node2-1],ord=2)/reference

    def _evaluate_previous_actions(self, actions, n_time_step, allow_battery_consumption):
        """Score ``actions`` against the per-relay SERs logged by the previous call.

        Returns four lists (optimal action, obtained SER, optimal SER, direct
        transmission SER), one entry per time step.  When
        ``allow_battery_consumption`` is set, the chosen relay is charged
        ``Tc`` battery units per step.
        """
        if n_time_step > len(self.log):
            raise IndexError(
                'n_time_step (%d) exceeds the %d time step(s) logged by the previous call'
                % (n_time_step, len(self.log)))

        optimal_actions = []
        obtained_SERs = []
        optimal_SERs = []
        DT_SERs = []
        for time_step in range(n_time_step):
            # Action a in 0..M-3 designates the relay node a+2.
            best_relay = int(actions[time_step])
            relay_SERs = self.log[time_step]['SERs']

            optimal_actions.append(np.argmin(relay_SERs))
            obtained_SERs.append(np.min(relay_SERs[best_relay]))
            optimal_SERs.append(np.min(relay_SERs))
            DT_SERs.append(self.log[time_step]['SER_DT'])

            if allow_battery_consumption:
                self.node_battery_states[best_relay+1] -= self.Tc
        return optimal_actions, obtained_SERs, optimal_SERs, DT_SERs

    def training_step(self,EsN0dB_ = 50,n_time_step = 128,actions = np.zeros(128),source = 1,allow_battery_consumption = False):
        """Score the previous actions, then simulate ``n_time_step`` new steps.

        Parameters:
            EsN0dB_: Es/N0 in dB used on every link of the new steps.
            n_time_step: number of coherence intervals to simulate.
            actions: one relay index (0 .. M-3) per time step, chosen from the
                states returned by the previous call.  An all-zero input is
                replaced by ``n_time_step`` zeros.
            source: 1-based id of the transmitting node.
            allow_battery_consumption: charge ``Tc`` units to each chosen relay.

        Returns:
            ``(optimal_actions, obtained_SER, optimal_SER, direct_SER,
            next_states, False)``.  The four arrays refer to the previously
            logged steps (zeros / ones when nothing was logged yet);
            ``next_states`` is a float tensor with one row per new time step
            and seven features per candidate relay.

        Raises:
            IndexError: if ``n_time_step`` is larger than the number of steps
                logged by the previous call.
        """
        if sum(actions)==0:
            actions = np.zeros(n_time_step)

        ############################## Score the actions on the logged steps ##############################
        if len(self.log) > 0:
            (list_optimal_actions,
             list_obtained_SER_MRC,
             list_optimal_MRC_SER,
             list_DT_SER) = self._evaluate_previous_actions(actions, n_time_step, allow_battery_consumption)
        else:
            list_optimal_actions = np.zeros(n_time_step)
            list_obtained_SER_MRC = np.ones(n_time_step)
            list_optimal_MRC_SER = np.ones(n_time_step)
            list_DT_SER = np.ones(n_time_step)

        ############################## Simulate the next steps / build the next state ##############################
        self.log = np.zeros(n_time_step,dtype=object)

        next_states = []
        for time_step in range(n_time_step):
            state = []
            self.log[time_step] = {}
            self.update_all_links()

            # Tc uniformly random symbols in 0 .. Mod_order-1
            input_syms_chunk = np.random.randint(low=0, high = self.Mod_order, size=self.Tc)

            PGs = []
            SERs = []
            RD_masks = []

            # Direct transmission S -> D
            SD_receivedSyms,h_SD,noise_states_SD = self.send_over_one_link('awgn',input_syms_chunk,source = source,destination = self.M,EsN0dB = EsN0dB_)
            SD_receivedSyms_equalized = np.array(h_SD.conjugate()*SD_receivedSyms/abs(h_SD),dtype = complex)
            SD_detectedSyms = self.Modem.demodulate(SD_receivedSyms_equalized)
            SER_DT = np.sum(SD_detectedSyms.astype(int) != input_syms_chunk)/len(input_syms_chunk)
            self.log[time_step]['SER_DT'] = SER_DT

            # S -> R -> D through every candidate relay
            for relay in range(1,self.M):
                if relay != source:
                    receivedSyms,h_SR,SR_noise_states = self.send_over_one_link('tsmg',input_syms_chunk,source = source,destination = relay,EsN0dB = EsN0dB_)
                    receivedSyms = np.array(h_SR.conjugate()*receivedSyms/abs(h_SR),dtype = complex)
                    SR_detectedSyms = self.Modem.demodulate(receivedSyms)

                    # The relay forwards what it detected.
                    RD_receivedSyms,h_RD,noise_states_RD = self.send_over_one_link('awgn',SR_detectedSyms,source = relay,destination = self.M,EsN0dB = EsN0dB_)

                    # Keep a relayed symbol only if it was detected correctly and its
                    # noise state is non-zero (presumably 1 = good state; confirm against TSMG).
                    RD_mask = SR_noise_states*(SR_detectedSyms == input_syms_chunk)

                    # Maximum ratio combining of the direct and the relayed copy
                    D_receivedSyms = (h_SD.conjugate()*SD_receivedSyms +  h_RD.conjugate()*RD_mask*RD_receivedSyms) /np.linalg.norm([h_RD,h_SD],ord=2)
                    D_receivedSyms_eq = np.array(D_receivedSyms,dtype = complex)
                    D_detectedSyms = self.Modem.demodulate(D_receivedSyms_eq)
                    SER_MRC = np.sum(D_detectedSyms.astype(int) != input_syms_chunk)/len(input_syms_chunk)

                    p_G = sum(SR_noise_states)/self.Tc

                    PGs.append(p_G)
                    SERs.append(SER_MRC)
                    RD_masks.append(sum(RD_mask))

                    # Seven state features per relay
                    state.append(sum(RD_mask)*10/len(RD_mask) - 9)
                    state.append(min(np.abs(h_RD),np.abs(h_SR)))
                    state.append(np.abs(h_RD))
                    state.append(np.abs(h_SR))
                    state.append((1-p_G-0.1)*10)
                    state.append(p_G)
                    state.append(1-p_G)

            self.log[time_step]['SERs'] = SERs
            self.log[time_step]['PGs'] = PGs
            self.log[time_step]['RD_masks'] = RD_masks

            next_states.append(state)

        next_states_tensor = torch.Tensor(np.array(next_states))

        return np.array(list_optimal_actions),np.array(list_obtained_SER_MRC),np.array(list_optimal_MRC_SER),np.array(list_DT_SER),next_states_tensor,False

    def send_over_one_link(self, noise_type, inputSyms, source=1, destination=2, EsN0dB=5):
        """Modulate ``inputSyms`` and send them over the link ``source -> destination``.

        Parameters:
            noise_type: 'awgn' or 'tsmg' (case-insensitive).
            inputSyms: array of symbol indices to transmit.
            source, destination: 1-based node ids of the link.
            EsN0dB: Es/N0 in dB handed to the noise model.

        Returns:
            ``(receivedSyms, h, noise_states)`` where ``h`` is the link's
            current channel coefficient and ``noise_states`` is the second
            value returned by ``TSMG`` (an empty list for 'awgn').

        Raises:
            ValueError: if ``noise_type`` is neither 'awgn' nor 'tsmg'.
        """
        noise_kind = noise_type.lower()
        if noise_kind not in ('awgn', 'tsmg'):
            raise ValueError("noise_type must be 'awgn' or 'tsmg', got %r" % (noise_type,))

        modulatedSyms = self.Modem.modulate(inputSyms)

        # Channel coefficient currently stored for this link
        h = self.links_CSI[self.links_CSI.link == str(source)+'->'+str(destination)]['h'].to_numpy()[0]
        fadedSyms = h * modulatedSyms

        if noise_kind == 'awgn':
            receivedSyms = awgn(fadedSyms, EsN0dB)
            noise_states = []
        else:
            receivedSyms, noise_states = TSMG(fadedSyms, P_B=self.p_B, R=self.R, NOISE_MEMORY=self.NOISE_MEMORY, SNRdB=EsN0dB)

        return receivedSyms, h, noise_states







## Testing the Environment

In [104]:
# Smoke test: build a 10-node environment (4-QAM, 1000 symbols per coherence interval).
env = WSN_env_RL(M = 10,Modem_type='QAM', Tc = 1000, Mod_order=4,Env_size=40,p_B=0.1, R=100, NOISE_MEMORY=100)
# reset() returns the 6-tuple produced by training_step(); only the state tensor and the flag are kept.
_,_,_,_,next_states,done = env.reset(n_time_step=1)
# One row per time step, 7 features for each of the M-2 = 8 candidate relays (56 values).
next_states

tensor([[ 0.6600,  0.3714,  0.3714,  1.1963, -0.6600,  0.9660,  0.0340, -2.7600,
          2.0700,  2.0700,  5.0923,  2.7600,  0.6240,  0.3760,  0.0200,  0.8852,
          1.3026,  0.8852, -0.0200,  0.9020,  0.0980,  0.3800,  1.6092,  3.1514,
          1.6092, -0.3800,  0.9380,  0.0620,  1.0000,  0.5706,  0.5706,  9.9406,
         -1.0000,  1.0000,  0.0000, -3.4000,  1.3866,  1.3866,  2.8093,  3.4000,
          0.5600,  0.4400, -0.1900,  1.5229,  4.8637,  1.5229,  0.1900,  0.8810,
          0.1190,  1.0000,  1.1669,  1.1669,  1.6149, -1.0000,  1.0000,  0.0000]])

## Defining the RL Functions

### Pseudocode

-------------------------------------------------------------------------------
**Input:** a differentiable policy network $\pi_\theta$ with parameters $\theta \in \mathbb{R}^d$ \\
**Algorithm parameters:** 


1.   $\alpha$: step size > 0
2.   n_iterations: number of gradient updates  > 0
3.   n_episodes: number of episodes per gradient update  > 0 

Initialize policy parameters 

loop for n_iterations: \\

  &emsp;&emsp;sample a dataset of episodes according to $\pi_\theta$ \\
  
  &emsp;&emsp;# compute policy gradient \\
  &emsp;&emsp;$\nabla_\theta J(\theta) = \sum_j \sum_t \psi_{jt} \nabla_\theta \ln \pi_\theta (a_t^j \mid s_t^j)$ \\
  
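  &emsp;&emsp; # where the first summation is over the episodes, the second summation is over the time steps of each episode, and $\psi_{jt}$ is the weight given to step $t$ of episode $j$ (in this notebook, the reward obtained at that step). \\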
  
  &emsp;&emsp;# update policy parameters \\
  &emsp;&emsp;$\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)$
  
------------------------------------------------------------------------------

        

### Building the Policy Network

In [79]:
from typing import Sequence, Dict, Any, Optional

import numpy as np

# torch stuff
import torch
import torch.nn as nn
import torch.nn.functional as F 

# data manipulation, colab display, and plotting
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# misc util
import  itertools
import time
from torch.distributions import Categorical


### Building the Base Agent Class


In [105]:
class BaseAgent(object):
    """Base class for policy-gradient agents acting on a ``WSN_env_RL`` environment.

    Owns the environment, a softmax policy network and its Adam optimizer.
    Subclasses supply ``optimize_model(n_episodes, n_iter_per_episode, EsN0dB,
    next_states)``, which must return ``(rewards, accuracy, loss, NRBP)``.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        args:
            config: configuration dictionary.  Keys read here: 'gamma',
                'seed', 'policy_layers' (hidden-layer widths only) and
                'policy_learning_rate'.
        """
        self.config = config

        # Environment: 10 nodes, hence M-2 = 8 candidate relays.
        self.env  = WSN_env_RL(M = 10,Modem_type='QAM', Tc = 1000, Mod_order=4,Env_size=40,p_B=0.1, R=100, NOISE_MEMORY=100)
        # The environment emits 7 state features per candidate relay.
        self.env.observation_space_size = (self.env.M-2) * 7
        self.env.action_space_size      = self.env.M-2
        self.gamma = config['gamma']

        # Seeds are applied after the environment has been built, so the
        # topology drawn by the constructor above does not depend on them.
        np.random.seed(seed=config['seed'])
        torch.manual_seed(config['seed'])

        # Policy network: observation -> hidden layers -> one logit per relay.
        _policy_logits_model = Model(
            [self.env.observation_space_size] +
            config['policy_layers'] + # note that these are only the intermediate layers
            [self.env.action_space_size],
        )

        # NOTE: by design, the policy model takes *batches* of states as input
        # and outputs the probability of each action.
        self.policy_model = nn.Sequential(
            _policy_logits_model, nn.Softmax(dim=1),
        )

        self.policy_optimizer = torch.optim.Adam(
            self.policy_model.parameters(),
            lr=config['policy_learning_rate'],
        )

    def _make_returns(self, rewards: np.ndarray):
        """ Compute the cumulative discounted rewards at each time step
        args:
            rewards: an array of step rewards

        returns:
            returns: a float array of discounted returns from that timestep
                onward (empty when ``rewards`` is empty)
        """
        # Work in float so integer rewards are not truncated by the discounting.
        rewards = np.asarray(rewards, dtype=float)
        returns = np.zeros_like(rewards)
        if len(rewards) == 0:
            return returns
        returns[-1] = rewards[-1]
        for t in reversed(range(len(rewards) - 1)):
            returns[t] = rewards[t] + self.gamma * returns[t + 1]
        return returns

    def train(self ,n_episodes: int, n_iterations: int, plot: bool = True,EsN0dB = 5) -> Sequence[np.ndarray]:
        """ Train.
        args:
            n_episodes: number of episodes for each gradient step
            n_iterations: determine training duration
            plot: when True, show the network layout before training and the
                reward / accuracy / loss curves afterwards
            EsN0dB: Es/N0 in dB used for every environment step

        returns:
            one array of step rewards per training iteration

        Side effects: resets the environment (new topology), updates the
        policy in place and writes 'best_model.pth' whenever the accuracy
        improves while NRBP < 1.
        """
        max_accuracy = 0
        # First steps for the agent: new topology, then one step at the
        # requested SNR to obtain the initial state.
        self.env.reset(n_time_step = 1)
        if plot:
            self.env.visualize()
        _,_,_,_,next_states,_ = self.env.training_step(n_time_step = 1,EsN0dB_=EsN0dB)

        accuracies = []
        losses = []
        rewards = []

        for it in tqdm(range(n_iterations)):

            reward,accuracy,loss,NRBP =  self.optimize_model(n_episodes,1,EsN0dB,next_states)

            self.policy_optimizer.zero_grad()
            loss.backward()
            self.policy_optimizer.step()

            loss_value = loss.detach().numpy().round(2)
            rewards.append(np.array(reward))
            accuracies.append(accuracy)
            losses.append(loss_value)

            if accuracy > max_accuracy and NRBP <1:
                torch.save(self.policy_model,'best_model.pth')
                max_accuracy = accuracy
            print('Iteration ',it + 1,'/',n_iterations,': rewards ',np.mean(rewards[-10:]).round(2),' ## accuracy ',accuracy,' ## policy loss ',loss_value,' ## NRBP ',round(NRBP, 2))

        if plot:
            self.plot_rewards(rewards)
            plt.show()
            # The curves below have one point per training iteration.
            plt.plot([np.mean(i) for i in rewards ],marker='o')
            plt.xlabel("episodes")
            plt.ylabel("Reward")
            plt.show()

            plt.plot(accuracies,marker='o')
            plt.ylabel("accuracy")
            plt.xlabel("episodes")
            plt.show()
            plt.plot(losses,marker='o')
            plt.xlabel("episodes")
            plt.ylabel("policy loss")
            plt.show()

        return rewards

    @staticmethod
    def plot_rewards(rewards: Sequence[np.ndarray], ax: Optional[Any] = None):
        """Line plot of the rewards per iteration ('Epoch') with a standard-deviation band."""
        r = pd.DataFrame((itertools.chain(*(itertools.product([i], rewards[i]) for i in range(len(rewards))))), columns=['Epoch', 'Reward'])
        # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of
        # `errorbar='sd'`; kept as-is because the installed version is unknown.
        if ax is None:
            sns.lineplot(x="Epoch", y="Reward", data=r, ci='sd')
        else:
            sns.lineplot(x="Epoch", y="Reward", data=r, ci='sd', ax=ax)
        



In [106]:
# Insert your code and run this cell
class REINFORCEv1Agent(BaseAgent):
    """ REINFORCE agent whose loss weights each log-probability by its step reward.
    """

    def optimize_model(self, n_episodes,n_iter_per_episode,EsN0dB,next_states):
        """
            Called at each training iteration. It
            (i) gathers a dataset of episodes by stepping the environment, and
            (ii) builds the policy-gradient loss ``-sum(reward * log_prob)``.

            args:
                n_episodes: number of environment steps to collect
                n_iter_per_episode: time steps simulated per environment step
                EsN0dB: Es/N0 in dB passed to the environment
                next_states: batch of states the first action is chosen from

            returns:
                (rewards rounded to 2 decimals as a NumPy array,
                 mean number of optimal relay choices per episode,
                 policy loss tensor (still attached to the graph),
                 mean normalised remaining-battery spread)

            NOTE(review): BaseAgent.train passes the same ``next_states`` on
            every iteration, so the first action of an iteration is chosen
            from a state that no longer matches the environment's log.
        """
        accuracy = 0
        reward_chunks = []
        log_prob_chunks = []
        NRBPVs = []
        n_relays = self.env.M-2

        for episode in range(n_episodes):

            probs = self.policy_model(next_states)

            # Spread of the relays' remaining battery, normalised by the fullest relay.
            relay_batteries = self.env.node_battery_states[1:]
            lowest = np.min(relay_batteries)
            highest = np.max(relay_batteries)
            battery_diff = (relay_batteries-lowest)/highest
            NRBPV = (highest-lowest)/highest
            NRBPVs.append(NRBPV)

            # Relays ranked by decreasing probability.  Default to the most
            # probable one, but prefer the first relay in that ranking whose
            # battery margin over the emptiest relay exceeds 20% of the spread.
            # Only the ranking of the first batch row is inspected.
            ranked_actions = torch.topk(probs, n_relays).indices
            actions = ranked_actions[:,0]
            for rank in range(n_relays):
                if battery_diff[ranked_actions[0,rank]] > 0.2 * NRBPV:
                    actions = ranked_actions[:,rank]
                    break

            log_prob_chunks.append(Categorical(probs).log_prob(actions))

            # Take a step: scores `actions` and produces the next state.
            list_optimal_actions,list_obtained_SER_MRC,list_optimal_MRC_SER,list_DT_SER, next_states,done = self.env.training_step(n_time_step = n_iter_per_episode,actions=actions,EsN0dB_=EsN0dB,allow_battery_consumption = True)

            # An action counts as optimal when it reaches the best achievable SER.
            accuracy += np.sum(list_optimal_MRC_SER == list_obtained_SER_MRC)

            # Reward: +5 offset minus 1000 times the obtained SER.
            reward = -1000*list_obtained_SER_MRC +5
            reward_chunks.append(torch.Tensor(reward))

        accuracy = accuracy/n_episodes
        rewards = torch.cat(reward_chunks, dim=0)
        log_probs = torch.cat(log_prob_chunks, dim=0)

        policy_loss =  -torch.sum(rewards * log_probs)

        return rewards.detach().numpy().round(2),accuracy.round(2),policy_loss,np.mean(NRBPVs)






In [107]:
class Model(nn.Module):
    """Plain multilayer perceptron: Linear layers with ReLU between them."""

    def __init__(self, features: Sequence[int]):
        """Fully-connected Network

        Args:
            features: layer widths, e.g. [input_dim, 16, 16, output_dim].
                A Linear layer is created for each consecutive pair and a
                ReLU follows every Linear layer except the last one.
        """
        super().__init__()

        final_index = len(features) - 2
        stack = []
        for index, (width_in, width_out) in enumerate(zip(features[:-1], features[1:])):
            stack.append(nn.Linear(in_features=width_in, out_features=width_out))
            if index != final_index:
                stack.append(nn.ReLU())

        self.net = nn.Sequential(*stack)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the raw outputs."""
        output = self.net(input)
        return output

## Training

In [None]:
# Agent configuration. BaseAgent reads 'seed', 'gamma', 'policy_layers' and
# 'policy_learning_rate'; 'use_baseline' is not read by any code in this notebook.
config = {
    'seed': 10,
    'gamma': 1.0,
    'policy_layers': [64,16],
    'policy_learning_rate': 1e-3,
    #'value_layers': [128],
    #'value_learning_rate': 1e-3,
    'use_baseline': False,
}


agent = REINFORCEv1Agent(config)
# Give every node the same battery budget; a selected relay is charged Tc units per step.
# NOTE(review): the factors 20 * 265 * 150 are undocumented magic numbers - confirm their meaning.
agent.env.battery_capacity =   agent.env.Tc * 20 * 265 * 150
agent.env.node_battery_states = np.ones(agent.env.M-1) *  agent.env.battery_capacity
# Es/N0 = Eb/N0 + 10*log10(k) with Eb/N0 = 5 dB and k = log2(4) = 2 bits per symbol.
agent.train(n_episodes=128, n_iterations= 50,EsN0dB=5+10*np.log10(2))

## Simulation Results:

In [None]:
class REINFORCEv1Agent(REINFORCEv1Agent):
    """Extends the previously defined ``REINFORCEv1Agent`` with evaluation helpers.

    The class subclasses itself under the same name, so this cell must run
    after the cell that defines the training agent.
    """

    def test_model(self, n_episodes,n_iter_per_EsN0dB,EsN0dB,next_states):
        """Evaluate the saved best policy greedily for ``n_episodes`` environment steps.

        args:
            n_episodes: number of environment steps to evaluate
            n_iter_per_EsN0dB: time steps simulated per environment step
            EsN0dB: Es/N0 in dB passed to the environment
            next_states: batch of states the first action is chosen from

        returns:
            (total_rewards, mean obtained MRC SER, mean optimal MRC SER,
            mean direct-transmission SER); the SERs are rounded to 5 decimals
            and ``total_rewards`` is always an empty list.

        Side effects: replaces ``self.policy_model`` with the model loaded
        from 'best_model.pth' and consumes relay battery in the environment.
        """
        # NOTE(review): 'best_model.pth' holds a whole pickled module; recent
        # PyTorch releases default torch.load to weights_only=True, which
        # rejects such files - confirm against the installed version.
        self.policy_model = torch.load('best_model.pth')
        self.policy_model.eval()
        total_rewards = []
        avg_SER_DT = []
        avg_optimal_SER_MRC = []
        avg_SER_MRC = []
        n_time_step = n_iter_per_EsN0dB

        for episode in range(n_episodes):

            # Greedy action: the most probable relay for every state in the batch.
            probs = self.policy_model(next_states)
            actions = torch.argmax(probs,dim = 1)

            # Take a step: scores `actions` and produces the next state.
            _,list_obtained_SER_MRC,list_optimal_MRC_SER,list_DT_SER, next_states,done = self.env.training_step(n_time_step = n_time_step,actions=actions,EsN0dB_=EsN0dB,allow_battery_consumption = True)

            avg_SER_MRC.append(np.mean(list_obtained_SER_MRC))
            avg_optimal_SER_MRC.append(np.mean(list_optimal_MRC_SER))
            avg_SER_DT.append(np.mean(list_DT_SER))

        return total_rewards,np.mean(avg_SER_MRC).round(5),np.mean(avg_optimal_SER_MRC).round(5),np.mean(avg_SER_DT).round(5)

    def train_and_get_SER(self, n_episodes = 10,
                            n_iter_per_EsN0dB = 1,
                            min_EbN0dB = 5,
                            max_EbN0dB = 6,
                            EbN0dB_step = 1,
                            n_training_iter = 25):
        """Train and test at every Eb/N0 of a sweep, then plot SER and battery levels.

        args:
            n_episodes: episodes per training iteration and per test run
            n_iter_per_EsN0dB: time steps simulated per environment step
            min_EbN0dB, max_EbN0dB, EbN0dB_step: inclusive integer Eb/N0 sweep in dB
            n_training_iter: training iterations per Eb/N0 value

        Side effects: resets the batteries and the environment, trains this
        agent, writes 'Reinforce    .csv' and shows two matplotlib figures.
        """
        EbN0dBs = range(min_EbN0dB,max_EbN0dB+1,EbN0dB_step)
        k = np.log2(self.env.Mod_order)
        EsN0dBs = 10*np.log10(k)+EbN0dBs # EsN0dB calculation

        # Battery budget: Tc units per step, for training plus testing
        # (factor 2), at every Eb/N0 value of the sweep.
        self.env.battery_capacity = len(EbN0dBs)  * self.env.Tc * n_episodes * 2 * n_training_iter
        self.env.node_battery_states = np.ones(self.env.M-1) *  self.env.battery_capacity
        SER_sim_MRC = []
        SER_sim_optimal_MRC = []
        SER_sim_DT = []

        # First steps for the agent
        self.env.reset(n_time_step = n_iter_per_EsN0dB)
        _,_,_,_,next_states,done =self.env.training_step(n_time_step = n_iter_per_EsN0dB,EsN0dB_=EsN0dBs[0])

        for EsN0dB in tqdm(EsN0dBs):
            # Train this instance (not whatever global `agent` happens to exist).
            self.train(n_episodes=n_episodes, n_iterations= n_training_iter,EsN0dB =EsN0dB,plot=False)
            # NOTE(review): `next_states` predates the training above (which
            # resets the environment), so the first test action is chosen
            # from a stale state.
            total_rewards,SER_MRC,SER_optimal_MRC,SER_DT = self.test_model(n_episodes,n_iter_per_EsN0dB,EsN0dB,next_states)
            SER_sim_MRC.append(SER_MRC)
            SER_sim_optimal_MRC.append(SER_optimal_MRC)
            SER_sim_DT.append(SER_DT)

        colors = plt.cm.jet(np.linspace(0,1,12)) # colormap
        fig, ax = plt.subplots(nrows=1,ncols = 1)
        fig.set_size_inches(10, 6)

        ax.semilogy(EbN0dBs,SER_sim_DT,color = colors[3],marker='o',linestyle='-',label='DT AWGN',markersize = 15)
        ax.semilogy(EbN0dBs,SER_sim_MRC,color = colors[1],marker='o',linestyle='-',label='MRC TSMG using REINFORCE',markersize = 15)
        ax.semilogy(EbN0dBs,SER_sim_optimal_MRC,color = colors[4],marker='o',linestyle='-',label='Optimal MRC TSMG',markersize = 15)

        # Persist the simulated curves ('SER_sim_SD_TSMG' holds the direct-transmission SER).
        df = pd.DataFrame()
        df['EbN0dBs'] = EbN0dBs
        df['SER_sim_MRC_TSMG'] = SER_sim_MRC
        df['SER_sim_SD_TSMG'] = SER_sim_DT
        df.to_csv('Reinforce    .csv',index = False)

        ax.set_xlabel('Eb/N0(dB)',fontsize=15);ax.set_ylabel('SER ($P_s$)',fontsize=15)
        ax.set_title('Probability of Symbol Error over Rayleigh flat fading channel',fontsize=15)
        ax.legend()
        ax.xaxis.set_tick_params(labelsize=20)
        ax.yaxis.set_tick_params(labelsize=20)
        ax.grid(True)
        # A log axis cannot take 0 as a limit; only the lower bound is meaningful.
        ax.set_ylim(bottom=0.00006)

        # Remaining battery of the candidate relays, in percent of the capacity.
        fig, ax = plt.subplots(figsize = (10,6))
        print( self.env.node_battery_states[1:self.env.M-1])
        remaining_battery_level = self.env.node_battery_states[1:self.env.M-1]*100/  self.env.battery_capacity
        print( remaining_battery_level)

        barlist=ax.bar([ str(i+2) for i in np.arange(self.env.M-2)],remaining_battery_level)

        # Red bars for relays below 20 %, green otherwise.
        for i in range(len(remaining_battery_level)):
            if abs(remaining_battery_level[i])<20:
                barlist[i].set_color('r')
            else:
                barlist[i].set_color('g')

        ax.set_title("Remaining Battery level for each node",fontsize = 20)
        ax.set_xlabel("nodes",fontsize = 20)
        ax.set_ylabel("Battery level (%)",fontsize = 20)
        ax.xaxis.set_tick_params(labelsize=20)
        ax.yaxis.set_tick_params(labelsize=20)
        ax.grid(True)
        plt.show()


In [None]:
# Configuration for the SER sweep: a single hidden layer of 1024 units.
config = {
    'seed': 15,
    'gamma': 1.0,
    'policy_layers': [1024],
    'policy_learning_rate': 5e-3,
}

# Must run after the cell that adds train_and_get_SER to REINFORCEv1Agent.
agent = REINFORCEv1Agent(config)
# Train and test at every integer Eb/N0 from 0 to 10 dB, then plot SER and battery levels.
agent.train_and_get_SER(n_episodes = 256,
                        n_iter_per_EsN0dB = 1,
                        n_training_iter = 60,
                        min_EbN0dB = 0,
                        max_EbN0dB = 10,
                        EbN0dB_step = 1)
# Show the node layout of the environment as it stands after the sweep.
agent.env.visualize()
