In [1]:
import simpy
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO, A2C, DQN
import math
from packet import Packet

In [82]:
class User:
    """A traffic source: identifier, arrival rate, channel state and packet buffer."""

    def __init__(self, id, arrRate, channelCond):
        self.id = id
        self.arrRate = arrRate
        self.channelCond = channelCond
        # Packets waiting for service; appended to by the environment's arrival process.
        self.buffer = []

In [83]:
# --- Simulation configuration ---------------------------------------------------
numUsers = 2                # number of users (traffic sources)
numRB = 1                   # resource blocks; NOTE(review): not used by the code in this notebook section
buffSize = 10               # max packets queued per user (arrivals beyond this are dropped)
sizeMean = 12000            # reference packet size [bits] (1500 bytes); sizes are drawn uniformly in [sizeMin, sizeMax)
sizeMin = 0.25*sizeMean     # 3000 bits (375 bytes)
sizeMax = 2*sizeMean        # 24000 bits (3000 bytes)
boundMin = 0.001            # min per-packet delay bound [s]
boundMax = 0.005            # max per-packet delay bound [s]
slotTime = 1e-3             # slot duration [s]
# NOTE(review): QueueSys draws arrival rates from randint(1000, 10000) packets/s,
# not from this range; these two constants are currently unused.
arrRateMin = 30             # packets/sec
arrRateMax = 120            # packets/sec
servRateBad = 50*1e6        # service rate in the bad channel state [bit/s] (50 Mbps)
servRateGood = 100*1e6      # service rate in the good channel state [bit/s] (100 Mbps)
p01 = 0.5                   # Gilbert-Elliott P(bad -> good) per slot
p10 = 0.5                   # Gilbert-Elliott P(good -> bad) per slot
maxSlots = 1000             # environment steps per episode
numFeaturesPerPacket = 4    # features per buffered packet in the observation
# 1 slot for the buffer length + numFeaturesPerPacket values for every packet that can be buffered.
numObservations = 1+numFeaturesPerPacket*numUsers*buffSize

In [84]:
import time
from gymnasium import spaces
import csv
import os
class QueueSys(gym.Env):
    """Multi-user packet scheduler simulated with SimPy and exposed as a Gymnasium env.

    Each of the ``numUsers`` users owns a buffer of at most ``buffSize`` packets
    (arrivals to a full buffer are counted in ``self.drop``), fed by an arrival
    process with exponential inter-arrival times. Each user also has a two-state
    Gilbert-Elliott channel (0 = bad -> ``servRateBad``, 1 = good -> ``servRateGood``)
    that takes one transition at every slot boundary (``slotTime`` seconds).

    Action ``a`` in ``Discrete(buffSize * numUsers)`` means "serve packet
    ``a % buffSize`` of user ``a // buffSize``"; an index past the end of that
    user's buffer does nothing and yields reward 0. A served packet earns reward 1
    if it completes within its deadline, otherwise 0 (counted as a violation).
    The episode terminates after ``maxSlots`` steps, and one summary row is then
    appended to the CSV file named by the module-level ``log_file`` global, which
    must be defined before the environment is created.
    """

    def __init__(self):
        self.numEpi = 0
        self.startTime = 0
        self.action_space = spaces.Discrete(buffSize*numUsers)
        self.observation_space = spaces.Box(low=-maxSlots, high=maxSlots, shape=(numObservations,), dtype=np.float32)
        self._init_episode_state()
        header = ['Episode', 'Total Rewards', 'Total Violations', 'Total Packets', '%Success', '%Violated', 'Time']
        for i in range(numUsers):
            header += [f'Rewards{i}', f'Violations{i}', f'Packets{i}', f'Drops{i}']
        if not os.path.exists(log_file):
            # newline='' is required by the csv module to avoid blank rows on Windows.
            with open(log_file, mode='w', newline='') as file:
                csv.writer(file).writerow(header)

    def _init_episode_state(self):
        """Create a fresh SimPy environment, users, arrival processes and per-episode counters."""
        self.simpyEnv = simpy.Environment()
        self.buffer = []
        self.users = []
        # NOTE(review): arrival rates come from [1000, 10000) packets/s, not from the
        # arrRateMin/arrRateMax constants in the configuration cell.
        self.arrRates = np.random.randint(1000, 10000, (numUsers,))
        # randint's upper bound is exclusive: (0, 2) yields 0 (bad) or 1 (good),
        # whereas (0, 1) would always start every channel in the bad state.
        self.channelConds = np.random.randint(0, 2, (numUsers,))
        self.slotInd = 0
        self.numSteps = 0
        self.nextArrTimes = np.zeros(numUsers)
        self.violations = np.zeros(numUsers)
        self.totPackets = np.zeros(numUsers)
        self.totalReward = np.zeros(numUsers)
        self.drop = np.zeros(numUsers)
        self.servRates = [servRateBad, servRateGood]  # indexed by channel state
        # The slot clock is registered before the arrival processes; SimPy processes
        # same-time events in scheduling order, so this order is kept deliberately.
        self.simpyEnv.process(self.slottedTime(slotTime))
        self.server = simpy.Resource(self.simpyEnv, capacity=1)
        for i in range(numUsers):
            user = User(i, self.arrRates[i], self.channelConds[i])
            self.nextArrTimes[i] = np.random.exponential(1.0/user.arrRate)
            self.users.append(user)
            self.buffer += user.buffer
            self.simpyEnv.process(self.arrProcess(user))
        self.nextArrTime = np.min(self.nextArrTimes)

    def arrProcess(self, user):
        """SimPy process generating packet arrivals for ``user``.

        Waits until the user's scheduled arrival time, then draws the next
        exponential inter-arrival time. If the buffer has room, it enqueues a Packet
        with a random size and a delay bound rounded up to whole slots. Otherwise
        the arrival is counted in ``self.drop``.
        """
        while True:
            yield self.simpyEnv.timeout(self.nextArrTimes[user.id] - self.simpyEnv.now)
            interArr = np.random.exponential(1.0/user.arrRate)
            self.nextArrTimes[user.id] = self.simpyEnv.now + interArr
            self.nextArrTime = np.min(self.nextArrTimes)
            if len(user.buffer) < buffSize:
                bound = np.random.uniform(boundMin, boundMax)
                boundSlot = math.ceil(bound/slotTime)
                size = np.random.randint(sizeMin, sizeMax)
                p = Packet(genTime=self.slotInd, bound=boundSlot, size=size, userId=user.id, channelCond=self.channelConds[user.id])
                user.buffer.append(p)
                self.totPackets[user.id] += 1
            else:
                self.drop[user.id] += 1

    def slottedTime(self, slotTime):
        """SimPy process: advance ``self.slotInd`` every ``slotTime`` seconds.

        At each slot boundary every user's channel takes one Gilbert-Elliott
        transition. Each buffered packet then has its ``bound`` decremented and its
        ``channelCond`` refreshed.
        """
        while True:
            yield self.simpyEnv.timeout(slotTime)
            self.slotInd += 1
            for i in range(numUsers):
                self.channelConds[i] = self.GilbertElliot(self.channelConds[i], p01, p10)
                self.users[i].channelCond = self.channelConds[i]
                for p in self.users[i].buffer:
                    p.bound -= 1
                    p.channelCond = self.channelConds[i]

    def servePacket(self, userId, action):
        """SimPy process: transmit packet ``action`` of user ``userId``.

        Holds for ``size / rate`` seconds, using the rate of the user's channel
        state when service starts, then removes the packet. Returns 1 if
        ``p.arrTime + p.boundFix`` (Packet attributes, presumably the arrival slot
        plus the original bound) is not earlier than the current slot. Otherwise it
        returns 0 and records a violation.
        """
        p = self.users[userId].buffer[action]
        channelCond = self.users[userId].channelCond
        servTime = p.size/self.servRates[channelCond]
        yield self.simpyEnv.timeout(servTime)
        self.users[userId].buffer.remove(p)
        if p.arrTime + p.boundFix >= self.slotInd:
            reward = 1
        else:
            reward = 0
            self.violations[userId] += 1
        return reward

    def counterfacts(self, userId, action):
        """Placeholder; not implemented (see QueueSysCounter for counterfactuals)."""
        pass

    def step(self, action):
        """Serve the packet selected by the flat ``action`` index.

        Returns ``(state, reward, terminated, truncated, info)``; ``truncated`` is
        always False and ``info`` is empty.
        """
        if self.numSteps == 0:
            self.startTime = time.time()
        userId = math.floor(action/buffSize)
        action = action % buffSize
        if not self.users[userId].buffer:
            # Nothing queued for this user: advance the simulation to the next
            # scheduled arrival time and process one more event (presumably that
            # arrival). The arrival may belong to another user.
            self.simpyEnv.run(self.nextArrTime)
            self.simpyEnv.step()
        if action < len(self.users[userId].buffer):
            servProc = self.simpyEnv.process(self.servePacket(userId, action))
            reward = self.simpyEnv.run(servProc)  # the process' return value
        else:
            reward = 0
        self.numSteps += 1
        self.totalReward[userId] += reward
        buff = [p for user in self.users for p in user.buffer]
        state = self._get_obs(buff, self.slotInd, self.channelConds)
        terminated = bool(self.numSteps >= maxSlots)
        if terminated:
            self.numEpi += 1
            self._log_episode()
        truncated = False
        info = {}
        return state, reward, terminated, truncated, info

    def _log_episode(self):
        """Append one summary row for the finished episode to ``log_file``."""
        totReward = np.sum(self.totalReward)
        totViolations = np.sum(self.violations)
        totPackets = np.sum(self.totPackets)
        row = [self.numEpi, totReward, totViolations, totPackets,
               totReward/totPackets, totViolations/totPackets, time.time() - self.startTime]
        for i in range(numUsers):
            row += [self.totalReward[i], self.violations[i], self.totPackets[i], self.drop[i]]
        with open(log_file, mode='a', newline='') as file:
            csv.writer(file).writerow(row)

    def _get_obs(self, buffer, slotInd, channelConds):
        """Encode ``buffer`` (a flat list of packets) as a float32 observation.

        With ``n = len(buffer)``, ``state[0] = n``. Four consecutive length-``n``
        segments start at 1, 1+n, 1+2n and 1+3n. Per packet they hold: the slots
        left until its deadline (``arrTime + boundFix - slotInd``), its user's
        channel state, its service time in whole slots at that channel's rate, and
        ``slotInd``. All remaining entries stay 0.
        """
        state = np.zeros((numObservations), dtype=np.float32)
        n = len(buffer)
        state[0] = n
        for i, p in enumerate(buffer):
            cond = channelConds[p.userId]
            state[1 + i] = p.arrTime + p.boundFix - slotInd
            state[1 + n + i] = cond
            state[1 + 2*n + i] = math.ceil((p.size/self.servRates[cond])/slotTime)
            state[1 + 3*n + i] = slotInd
        return state

    def reset(self, seed=None, options=None):
        """Start a new episode with a fresh simulation.

        ``seed`` only seeds Gymnasium's ``self.np_random``. The simulation itself
        draws from the global ``np.random``, so passing a seed does not make
        episodes reproducible on its own.
        """
        super().reset(seed=seed, options=options)
        self._init_episode_state()
        obs = self._get_obs(self.buffer, self.slotInd, self.channelConds)
        info = {}
        return obs, info

    def GilbertElliot(self, cond, p01, p10):
        """One slot of the two-state Gilbert-Elliott channel.

        From the good state (truthy ``cond``), move to bad (0) with probability
        ``p10``. From the bad state, move to good (1) with probability ``p01``.
        """
        u = np.random.uniform(0, 1)
        if cond:
            return 0 if u <= p10 else 1
        return 1 if u <= p01 else 0
            

    

In [85]:
# # from stable_baselines3 import A2C
# from stable_baselines3.common.env_checker import check_env

# global log_file 
# for i in range(5):
#     log_file = f'./csvLogs/DQNuser{numUsers}_seed{i}.csv'
#     np.random.seed(i)
#     if os.path.exists(log_file):
#         os.remove(log_file)
#     que = QueueSys()
#     check_env(que,warn=True)
#     model = DQN('MlpPolicy', que,verbose=1)
#     model.learn(total_timesteps=1000000)

# for i in range(5):
#     log_file = f'./csvLogs/PPOuser{numUsers}_seed{i}.csv'
#     np.random.seed(i)
#     if os.path.exists(log_file):
#         os.remove(log_file)
#     que = QueueSys()
#     check_env(que,warn=True)
#     model = PPO('MlpPolicy', que,verbose=1)
#     model.learn(total_timesteps=1000000)

# for i in range(5):
#     log_file = f'./csvLogs/A2Cuser{numUsers}_seed{i}.csv'
#     np.random.seed(i)
#     if os.path.exists(log_file):
#         os.remove(log_file)
#     que = QueueSys()
#     check_env(que,warn=True)
#     model = A2C('MlpPolicy', que,verbose=1)
#     model.learn(total_timesteps=1000000)

In [86]:
import copy
class QueueSysCounter(QueueSys):
    """QueueSys whose ``step`` also reports counterfactual outcomes in ``info``.

    ``info["counterFactStates"]`` has one dict per buffered packet that was *not*
    chosen at this step. Each dict holds the observation (``"state"``), the flat
    action id that would have served that packet (``"action"``), and the 0/1
    deadline reward serving it would have earned (``"reward"``).
    """

    def step(self, action):
        """Serve the chosen packet as ``QueueSys.step`` does, then attach counterfactuals.

        Returns ``(state, reward, terminated, truncated, info)``. ``info`` holds
        ``computeCounterFacts`` evaluated at the decision point just before the
        chosen packet was served.
        """
        if self.numSteps == 0:
            self.startTime = time.time()
        userId = math.floor(action/buffSize)
        action = action % buffSize
        if not self.users[userId].buffer:
            # Nothing queued for this user: advance to the next scheduled arrival
            # time and process one more event (presumably that arrival).
            self.simpyEnv.run(self.nextArrTime)
            self.simpyEnv.step()
        # Snapshot the decision point *before* serving. servePacket() removes the
        # served packet and advances the simulation (arrivals, channel updates), so
        # the live users / channelConds would otherwise describe the post-action state.
        stateBeforeAction = {
            'userState': [self._snapshot_user(user) for user in self.users],
            'userId': userId,
            'action': action,
            'time': self.simpyEnv.now,
            'slotInd': self.slotInd,
            'channelCond': self.channelConds.copy(),
        }
        if action < len(self.users[userId].buffer):
            servProc = self.simpyEnv.process(self.servePacket(userId, action))
            reward = self.simpyEnv.run(servProc)
        else:
            reward = 0
        self.numSteps += 1
        self.totalReward[userId] += reward
        buff = [p for user in self.users for p in user.buffer]
        state = self._get_obs(buff, self.slotInd, self.channelConds)
        counterFactStates = self.computeCounterFacts(stateBeforeAction)
        truncated = False
        terminated = bool(self.numSteps >= maxSlots)
        if terminated:
            self.numEpi += 1
            totReward = np.sum(self.totalReward)
            totViolations = np.sum(self.violations)
            totPackets = np.sum(self.totPackets)
            row = [self.numEpi, totReward, totViolations, totPackets,
                   totReward/totPackets, totViolations/totPackets, time.time() - self.startTime]
            for i in range(numUsers):
                row += [self.totalReward[i], self.violations[i], self.totPackets[i], self.drop[i]]
            # newline='' is required by the csv module to avoid blank rows on Windows.
            with open(log_file, mode='a', newline='') as file:
                csv.writer(file).writerow(row)
        info = {"counterFactStates": counterFactStates}
        return state, reward, terminated, truncated, info

    @staticmethod
    def _snapshot_user(user):
        """Shallow copy of ``user`` with its own copy of the buffer list (packet objects are shared)."""
        clone = copy.copy(user)
        clone.buffer = list(user.buffer)
        return clone

    def computeCounterFacts(self, stateInfo):
        """Evaluate serving each *other* buffered packet instead of the chosen one.

        :param stateInfo: dict built in ``step`` before serving. ``userState`` is a
            list of users whose ``buffer`` lists reflect the decision point.
            ``userId`` / ``action`` are the chosen user and the index in that user's
            buffer. ``slotInd`` is the current slot and ``channelCond`` the
            per-user channel states. ``time`` is not used.
        :return: list of dicts ``{"state", "action", "reward"}``. ``action`` is the
            flat action id ``u * buffSize + j``. ``state`` is the observation of the
            remaining buffer at the counterfactual departure slot. ``reward`` is 1
            if that departure slot is within the packet's deadline, else 0.
        """
        users = stateInfo["userState"]
        takenUser = stateInfo["userId"]
        takenIdx = stateInfo["action"]
        channelConds = stateInfo["channelCond"]
        slotInd = stateInfo["slotInd"]
        # Flatten all buffers, remembering which (user, index) each packet came from.
        buffer = []
        owners = []
        for u, user in enumerate(users):
            for j, p in enumerate(user.buffer):
                buffer.append(p)
                owners.append((u, j))
        counterFactStates = []
        for i, (p, (u, j)) in enumerate(zip(buffer, owners)):
            if u == takenUser and j == takenIdx:
                continue  # the action actually taken is not a counterfactual
            tempBuff = buffer[:i] + buffer[i + 1:]
            # Service time in whole slots, the same conversion _get_obs uses.
            servSlots = math.ceil((p.size/self.servRates[channelConds[p.userId]])/slotTime)
            deptSlot = slotInd + servSlots
            reward = 0 if deptSlot > p.arrTime + p.boundFix else 1
            state = self._get_obs(tempBuff, deptSlot, channelConds)
            counterFactStates.append({"state": state, "action": u*buffSize + j, "reward": reward})
        return counterFactStates
        

In [87]:
import numpy as np
from typing import Generator, Optional
from stable_baselines3.common.buffers import RolloutBuffer, RolloutBufferSamples
import torch as th
from stable_baselines3.common.vec_env import VecNormalize

class RolloutBufferWithCounterfactuals(RolloutBuffer):
    """RolloutBuffer that also stores counterfactual transitions reported by the env.

    Every call to ``add`` appends the step's counterfactual entries (dicts with
    ``state``, ``action`` and ``reward``) to parallel ``cf_*`` lists. Each entry is
    tagged in ``cf_steps`` with the rollout step it belongs to. A step without
    counterfactuals gets one placeholder entry with ``cf_validity == 0``.
    ``compute_returns_and_advantage`` subtracts each step's mean counterfactual
    reward from its observed TD error before the GAE recursion.
    """

    def __init__(self, buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=1, max_cf_length=5):
        super(RolloutBufferWithCounterfactuals, self).__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)
        self.max_cf_length = max_cf_length  # NOTE(review): stored but not read by this class
        self._reset_counterfactuals()

    def _reset_counterfactuals(self):
        """Empty all counterfactual storage and restart the step counter."""
        self.cf_observations = []
        self.cf_actions = []
        self.cf_rewards = []
        self.cf_validity = []  # 1 = real counterfactual, 0 = placeholder for a step without any
        self.cf_values = []
        self.cf_steps = []     # rollout step each cf_* entry belongs to (index-aligned with the lists above)
        self.cf_step = 0       # step index the next add() call will use

    def reset(self):
        super(RolloutBufferWithCounterfactuals, self).reset()
        self._reset_counterfactuals()

    def add(self, obs, action, reward, episode_start, value, log_prob, cf_state=None):
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param cf_state: list of counterfactual dicts with ``state``, ``action`` and
            ``reward`` keys for this step; empty/None stores a single invalid placeholder.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
        self._add_counterfactuals(cf_state)

    def _add_counterfactuals(self, cf_state):
        """Append this step's counterfactuals (or one invalid placeholder) and advance ``cf_step``."""
        if cf_state:
            for cf in cf_state:
                self.cf_observations.append(cf["state"])
                self.cf_actions.append(cf["action"])
                self.cf_rewards.append(cf["reward"])
                self.cf_validity.append(1)
                self.cf_steps.append(self.cf_step)
        else:
            self.cf_observations.append(np.zeros(self.obs_shape))
            self.cf_actions.append(-1)
            self.cf_rewards.append(0)
            self.cf_validity.append(0)
            # Must also be tagged, otherwise cf_steps falls out of alignment with the
            # other cf_* lists and later lookups read the wrong rewards/validity.
            self.cf_steps.append(self.cf_step)
        self.cf_step += 1

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        """Yield shuffled minibatches of the standard rollout data (counterfactuals are not sampled)."""
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """GAE where each step's TD error is reduced by its mean valid counterfactual reward."""
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()
        # Group valid counterfactual rewards by step once, instead of rescanning
        # every stored entry for every step.
        cf_rewards_by_step = {}
        for cf_step, valid, cf_reward in zip(self.cf_steps, self.cf_validity, self.cf_rewards):
            if valid == 1:
                cf_rewards_by_step.setdefault(cf_step, []).append(cf_reward)
        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            # Observed TD error
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            cf_rewards = cf_rewards_by_step.get(step)
            if cf_rewards:
                delta = delta - np.mean(cf_rewards)
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        self.returns = self.advantages + self.values


In [88]:
from stable_baselines3.ppo import PPO
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.utils import explained_variance
import torch as th
from torch.nn import functional as F
from stable_baselines3.common.type_aliases import RolloutReturn
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.utils import obs_as_tensor, safe_mean

class PPOCounterFact(PPO):
    def __init__(self, *args, **kwargs):
        """Build a PPO model, then install a counterfactual-aware rollout buffer.

        All arguments are forwarded unchanged to ``PPO.__init__``.
        """
        super().__init__(*args, **kwargs)
        buffer_kwargs = dict(
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
            device=self.device,
        )
        self.rollout_buffer = RolloutBufferWithCounterfactuals(
            self.n_steps, self.observation_space, self.action_space, **buffer_kwargs
        )
    def collect_rollouts(
        self,
        env: VecEnv,
        callback,
        rollout_buffer: RolloutBufferWithCounterfactuals,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBufferWithCounterfactuals``.

        Follows the structure of stable-baselines3's on-policy rollout loop. In
        addition, each step forwards ``infos[0]["counterFactStates"]`` (if present)
        to the buffer. ``add`` is called exactly once per step so that the buffer
        position stays aligned with ``n_steps``.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions

            if isinstance(self.action_space, spaces.Box):
                if self.policy.squash_output:
                    # Unscale the actions to match env bounds
                    # if they were previously squashed (scaled in [-1, 1])
                    clipped_actions = self.policy.unscale_action(clipped_actions)
                else:
                    # Otherwise, clip the actions to avoid out of bound error
                    # as we are sampling from an unbounded Gaussian distribution
                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)
            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if not callback.on_step():
                return False

            self._update_info_buffer(infos, dones)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstraping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                    rewards[idx] += self.gamma * terminal_value

            # The buffer keeps one counterfactual list per step, so only env 0's
            # counterfactuals are recorded. NOTE(review): extend add() if n_envs > 1.
            # A missing key still adds the step, with an invalid placeholder.
            cf_states = infos[0].get("counterFactStates")
            rollout_buffer.add(
                self._last_obs,  # type: ignore[arg-type]
                actions,
                rewards,
                self._last_episode_starts,  # type: ignore[arg-type]
                values,
                log_probs,
                cf_states,
            )
            self._last_obs = new_obs  # type: ignore[assignment]
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.update_locals(locals())

        callback.on_rollout_end()

        return True
    
    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                if self.normalize_advantage and len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())
                
                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # cf_observations, cf_actions, cf_rewards, cf_validity, cf_steps = cf_data
                # cf_observations = th.as_tensor(cf_observations, dtype=th.float32, device=self.device)
                # cf_actions = th.as_tensor(cf_actions, dtype=th.float32, device=self.device)
                # cf_rewards = th.as_tensor(cf_rewards, dtype=th.float32, device=self.device)
                # cf_validity = th.as_tensor(cf_validity, dtype=th.bool, device=self.device)
                # cf_actions[~cf_validity] = 0
                # cf_steps = th.as_tensor(cf_steps, dtype=th.float32, device=self.device)
                # # print(f'cf_observation shape = {cf_observations[cf_validity].shape}')
                
                # cf_values, cf_log_probs, cf_entropy = self.policy.evaluate_actions(cf_observations, cf_actions)
                # cf_advantages = cf_rewards - cf_values.flatten()
                # if self.normalize_advantage and len(cf_advantages) > 1:
                #     cf_advantages = (cf_advantages - cf_advantages.mean()) / (cf_advantages.std() + 1e-8)

                # # Counterfactual policy loss
                # cf_ratio = th.exp(cf_log_probs - log_prob)
                # cf_policy_loss_1 = cf_validity * cf_advantages * cf_ratio
                # cf_policy_loss_2 = cf_validity * cf_advantages * th.clamp(cf_ratio, 1 - clip_range, 1 + clip_range)
                # cf_policy_loss = -th.min(cf_policy_loss_1, cf_policy_loss_2).mean()

                # # Counterfactual value loss
                # cf_value_loss = F.mse_loss(cf_rewards, cf_values.flatten())

                # # Counterfactual entropy loss
                # if cf_entropy is None:
                #     cf_entropy_loss = -th.mean(-cf_log_probs)
                # else:
                #     cf_entropy_loss = -th.mean(cf_entropy)

                # # Combined loss
                # cf_coeff = 0.0
                # loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
                # loss += cf_coeff*(cf_policy_loss + self.ent_coef * cf_entropy_loss + self.vf_coef * cf_value_loss)
                
                # cf_observations, cf_actions, cf_rewards, cf_validity,cf_steps = cf_data
                # valid_cf_indices = np.nonzero(cf_validity)
                # print(f'Training step = {epoch},  ')
                # if len(valid_cf_indices) > 0:
                #     cf_observations = th.cat([th.tensor(cf_observations[idx], dtype=th.float32).to(self.device) for idx in valid_cf_indices])
                #     cf_actions = th.cat([th.tensor(cf_actions[idx], dtype=th.float32).to(self.device) for idx in valid_cf_indices])
                #     cf_rewards = th.cat([th.tensor(cf_rewards[idx], dtype=th.float32).to(self.device) for idx in valid_cf_indices])

                #     cf_values, cf_log_prob, cf_entropy = self.policy.evaluate_actions(cf_observations, cf_actions)

                #     cf_advantages = cf_rewards - cf_values.flatten()
                #     cf_advantages = (cf_advantages - cf_advantages.mean()) / (cf_advantages.std() + 1e-8)

                #     # Counterfactual policy loss
                #     cf_ratio = th.exp(cf_log_prob)
                #     cf_policy_loss_1 = cf_advantages * cf_ratio
                #     cf_policy_loss_2 = cf_advantages * th.clamp(cf_ratio, 1 - clip_range, 1 + clip_range)
                #     cf_policy_loss = -th.min(cf_policy_loss_1, cf_policy_loss_2).mean()

                #     # Counterfactual value loss
                #     cf_value_loss = F.mse_loss(cf_rewards, cf_values.flatten())

                #     # Counterfactual entropy loss
                #     cf_entropy_loss = -th.mean(cf_entropy)

                #     # Combined loss
                #     cf_loss = cf_policy_loss + self.ent_coef * cf_entropy_loss + self.vf_coef * cf_value_loss
                #     cf_coeff = 0.01
                #     loss += cf_coeff*cf_loss


                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

In [89]:
# Train PPO on QueueSysCounter for several independent runs, one per seed index,
# logging per-run CSV output to `log_file` and saving each trained model.
# NOTE(review): `log_file` is assigned at module level here; the environment
# presumably reads it as a global when writing its CSV log — confirm in QueueSysCounter.
from stable_baselines3.common.env_checker import check_env

NUM_SEEDS = 5
TOTAL_TIMESTEPS = 400000

np.random.seed(0)
# Make sure the output directories exist before the env / model write into them.
os.makedirs('./csvLogs', exist_ok=True)
os.makedirs('./models', exist_ok=True)
for i in range(NUM_SEEDS):
    # A plain module-level assignment (a `global` statement is a no-op at module scope).
    log_file = f'./csvLogs/PPO_{p01:.1f}_{p10:.1f}_{numUsers}_seed{i}.csv'
    if os.path.exists(log_file):
        os.remove(log_file)
    que = QueueSysCounter()
    check_env(que, warn=True)
    # Pass the run index as the SB3 seed so python/numpy/torch RNGs are seeded and
    # "seed{i}" in the file names identifies a reproducible run (np.random.seed
    # alone does not seed torch's weight init or action sampling).
    model = PPO('MlpPolicy', que, verbose=1, seed=i)
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
    # NOTE(review): a plain `PPO` model is saved under a "CPPO_" prefix while the
    # CSV log uses "PPO_"; confirm whether the counterfactual PPO subclass was intended.
    fileName = f'./models/CPPO_{p01:.1f}_{p10:.1f}_{numUsers}_seed{i}'
    model.save(fileName)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1e+03    |
|    ep_rew_mean     | 354      |
| time/              |          |
|    fps             | 2125     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1e+03       |
|    ep_rew_mean          | 472         |
| time/                   |             |
|    fps                  | 1456        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.028518341 |
|    clip_fraction        | 0.174       |
|    clip_range           | 0.2         |
|    entropy_loss   

In [90]:
from stable_baselines3 import A2C
from stable_baselines3.common.utils import obs_as_tensor
from stable_baselines3.common.utils import explained_variance
import torch as th
from torch.nn import functional as F
from stable_baselines3.common.type_aliases import RolloutReturn
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.utils import obs_as_tensor, safe_mean

class A2CCounterFact(A2C):
    """A2C variant whose rollout buffer also stores counterfactual states.

    Behaves like stable-baselines3's ``A2C`` except that ``rollout_buffer`` is a
    ``RolloutBufferWithCounterfactuals``. Every collected transition is stored
    together with the ``'counterFactStates'`` entry the environment returns in
    its step ``info`` dict.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the buffer created during A2C's model setup with one whose
        # ``add`` also accepts counterfactual states.
        self.rollout_buffer = RolloutBufferWithCounterfactuals(
            self.n_steps, self.observation_space, self.action_space, gamma=self.gamma,
            gae_lambda=self.gae_lambda, n_envs=self.n_envs, device=self.device
        )

    def collect_rollouts(
            self,
            env: VecEnv,
            callback,
            rollout_buffer: RolloutBufferWithCounterfactuals,
            n_rollout_steps: int,
        ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        Each step's ``info`` must contain a ``'counterFactStates'`` entry, which is
        stored alongside the transition.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        :raises KeyError: if a step's info dict has no ``'counterFactStates'`` entry.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions

            if isinstance(self.action_space, spaces.Box):
                if self.policy.squash_output:
                    # Unscale the actions to match env bounds
                    # if they were previously squashed (scaled in [-1, 1])
                    clipped_actions = self.policy.unscale_action(clipped_actions)
                else:
                    # Otherwise, clip the actions to avoid out of bound error
                    # as we are sampling from an unbounded Gaussian distribution
                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)
            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if not callback.on_step():
                return False

            self._update_info_buffer(infos, dones)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstraping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                    rewards[idx] += self.gamma * terminal_value

            # Store the transition exactly once per step. Adding once per info dict
            # (as before) silently skipped steps whose info lacked counterfactuals,
            # leaving the buffer under-filled, and would store a step n_envs times
            # when several envs are used.
            missing = [idx for idx, info in enumerate(infos) if "counterFactStates" not in info]
            if missing:
                raise KeyError(
                    f"step info for env index(es) {missing} has no 'counterFactStates' entry; "
                    "every step must provide counterfactual states"
                )
            cf_states = [info["counterFactStates"] for info in infos]
            rollout_buffer.add(
                self._last_obs,  # type: ignore[arg-type]
                actions,
                rewards,
                self._last_episode_starts,  # type: ignore[arg-type]
                values,
                log_probs,
                # Single env: pass that env's states unchanged (the original layout).
                # NOTE(review): for n_envs > 1 a per-env list is passed — confirm this
                # matches what RolloutBufferWithCounterfactuals.add expects.
                cf_states[0] if len(cf_states) == 1 else cf_states,
            )
            self._last_obs = new_obs  # type: ignore[assignment]
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.update_locals(locals())

        callback.on_rollout_end()

        return True

    def train(self) -> None:
        """
        Update policy using the currently gathered
        rollout buffer (one gradient step over whole data).
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)

        # This will only loop once (get all data in one go)
        for rollout_data in self.rollout_buffer.get(batch_size=None):
            actions = rollout_data.actions
            if isinstance(self.action_space, spaces.Discrete):
                # Convert discrete action from float to long
                actions = actions.long().flatten()

            values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
            values = values.flatten()

            # Normalize advantage (not present in the original implementation)
            advantages = rollout_data.advantages
            if self.normalize_advantage:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # Policy gradient loss
            policy_loss = -(advantages * log_prob).mean()

            # Value loss using the TD(gae_lambda) target
            value_loss = F.mse_loss(rollout_data.returns, values)

            # Entropy loss favor exploration
            if entropy is None:
                # Approximate entropy when no analytical form
                entropy_loss = -th.mean(-log_prob)
            else:
                entropy_loss = -th.mean(entropy)

            loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

            # Optimization step
            self.policy.optimizer.zero_grad()
            loss.backward()

            # Clip grad norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        self._n_updates += 1
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/entropy_loss", entropy_loss.item())
        self.logger.record("train/policy_loss", policy_loss.item())
        self.logger.record("train/value_loss", value_loss.item())
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

In [91]:
# Train A2C on QueueSysCounter for several independent runs, one per seed index,
# logging per-run CSV output to `log_file` and saving each trained model.
# NOTE(review): `log_file` is assigned at module level here; the environment
# presumably reads it as a global when writing its CSV log — confirm in QueueSysCounter.
from stable_baselines3.common.env_checker import check_env

NUM_SEEDS = 5
TOTAL_TIMESTEPS = 400000

np.random.seed(0)
# Make sure the output directories exist before the env / model write into them.
os.makedirs('./csvLogs', exist_ok=True)
os.makedirs('./models', exist_ok=True)
for i in range(NUM_SEEDS):
    # A plain module-level assignment (a `global` statement is a no-op at module scope).
    log_file = f'./csvLogs/A2C_{p01:.1f}_{p10:.1f}_{numUsers}_seed{i}.csv'
    if os.path.exists(log_file):
        os.remove(log_file)
    que = QueueSysCounter()
    check_env(que, warn=True)
    # Pass the run index as the SB3 seed so python/numpy/torch RNGs are seeded and
    # "seed{i}" in the file names identifies a reproducible run.
    # NOTE(review): this trains the stock `A2C`, not the `A2CCounterFact` subclass
    # defined in the previous cell — confirm which one is intended.
    model = A2C('MlpPolicy', que, verbose=1, seed=i)
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
    fileName = f'./models/A2C_{p01:.1f}_{p10:.1f}_{numUsers}_seed{i}'
    model.save(fileName)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
------------------------------------
| time/                 |          |
|    fps                | 1381     |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -2.26    |
|    explained_variance | 0.593    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 1.78     |
|    value_loss         | 1.02     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1e+03    |
|    ep_rew_mean        | 462      |
| time/                 |          |
|    fps                | 1348     |
|    iterations         | 200      |
|    time_elapsed       | 0        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss 