In [None]:
class QNAgent_constraints:
    ''' Epsilon-greedy Q-learning agent whose state vector and action space
    both have ``constraints_size`` entries.

    The agent keeps a bounded replay memory, selects actions with an
    epsilon-greedy policy over ``model`` outputs, trains ``model`` on sampled
    transitions, and offers several reward-shaping helpers that combine an
    energy term, a free-parameter term and a ``simp`` term using the weights
    given at construction.'''

    def __init__(self, constraints_size, N, model, criterion,
                 optimizer, device, egreedy, egreedy_decay,
                 egreedy_min, gamma, weight_E, weight_params, weight_simp = 0):
        ''' Store the network, training objects and hyper-parameters.

        constraints_size: length of the state vector and number of actions.
        N: stored as ``self.N``; not used by the methods of this class.
        model / criterion / optimizer: network, loss callable and optimizer
            used by ``replay`` and ``replay_with_loop``.
        device: device on which stored transitions are placed.
        egreedy, egreedy_decay, egreedy_min: initial exploration rate, its
            multiplicative decay per replay call, and its lower threshold.
        gamma: discount rate.
        weight_E, weight_params, weight_simp: weights of the energy,
            free-parameter and ``simp`` terms in the reward functions.'''

        # Input size
        self.state_size = constraints_size

        # Output size
        self.constraints_size = constraints_size
        self.N = N
        self.action_size = self.constraints_size

        # Parameters
        self.memory = deque(maxlen=10000)
        self.gamma = gamma    # discount rate
        self.epsilon = egreedy  # exploration rate
        self.epsilon_min = egreedy_min
        self.epsilon_decay = egreedy_decay
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

        # Reward weights: the reward methods read these from self, so they
        # must be stored here.
        self.weight_E = weight_E
        self.weight_params = weight_params
        self.weight_simp = weight_simp

        self.Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

    def remember(self, state, action, reward, next_state, done):
        ''' Append one transition to the replay memory.

        ``state`` and ``next_state`` are stored as float tensors of shape
        (1, state_size), ``action`` and ``reward`` as one-element tensors, and
        ``done`` is stored unchanged.'''
        info = [torch.FloatTensor(state).reshape(1, self.state_size).to(self.device),
                torch.tensor([action], device = self.device),
                torch.FloatTensor([reward]).to(self.device),
                torch.FloatTensor(next_state).reshape(1, self.state_size).to(self.device),
                done]
        self.memory.append(self.Transition(*info))

    def act(self, state):
        ''' Return an action index for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random index in
        ``range(constraints_size)`` is returned; otherwise the argmax of the
        model output.'''
        state = torch.FloatTensor(state).reshape(1, self.state_size).to(self.device)

        if np.random.rand() <= self.epsilon:
            return random.randrange(self.constraints_size)

        # Pure inference: no autograd graph is needed to pick the greedy action.
        with torch.no_grad():
            q_values = self.model(state)
        return int(torch.argmax(q_values))

    def _decay_epsilon(self):
        ''' Multiply epsilon by its decay factor while it is above epsilon_min.'''
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def replay_with_loop(self, batch_size):
        ''' Same as replay but with a for loop, which is less optimized.

        Performs one optimizer step per sampled transition, then decays
        epsilon once.'''
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            output = self.model(state)

            # The target is a constant for the optimizer: build it without
            # tracking gradients, as replay does with detach().
            with torch.no_grad():
                target = reward
                if not done:
                    target = reward + self.gamma * torch.max(self.model(next_state))
                # Only the taken action's entry differs from the prediction,
                # so every other entry contributes zero loss.
                target_f = output.detach().clone()
                target_f[0, action] = target

            self.optimizer.zero_grad()
            loss = self.criterion(output, target_f)
            loss.backward()
            self.optimizer.step()

        self._decay_epsilon()

    def replay(self, batch_size):
        ''' Sample ``batch_size`` transitions and do one batched Q-learning step.

        The target is ``reward`` for terminal transitions and
        ``reward + gamma * max_a Q(next_state, a)`` otherwise. Epsilon is
        decayed once after the optimizer step.'''
        transitions = random.sample(self.memory, batch_size)
        batch = self.Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # remember() always stores a tensor for next_state, so whether a
        # transition is terminal can only be read from its `done` flag.
        non_final = [not bool(d) for d in batch.done]
        non_final_mask = torch.tensor(non_final, device=self.device, dtype=torch.bool)

        state_action_values = self.model(state_batch).gather(1, action_batch.reshape(batch_size, 1))

        # Start from the rewards (terminal case) and add the discounted
        # bootstrap term only where the episode continues.
        expected_values = reward_batch.clone()
        if any(non_final):  # torch.cat raises on an empty list
            non_final_next_states = torch.cat(
                [s for s, keep in zip(batch.next_state, non_final) if keep])
            with torch.no_grad():
                next_q = self.model(non_final_next_states).max(1)[0]
            expected_values[non_final_mask] = next_q * self.gamma + reward_batch[non_final_mask]

        # Optimize the model
        self.optimizer.zero_grad()
        loss = self.criterion(state_action_values, expected_values.unsqueeze(1))
        loss.backward()
        self.optimizer.step()

        self._decay_epsilon()

    def past_reward(self, mem_array, value, mean = True, en = False):
        ''' Return a 0/1 reward for ``value`` compared with its history
        ``mem_array``.

        If mean = True, the reward is 1 when ``value`` is strictly below the
        mean of ``mem_array`` (or strictly above it when en = True, i.e. when
        the quantity is to be maximized). If mean = False, the reward is 1
        when ``value`` is <= the minimum of ``mem_array`` (this one does not
        work very well yet...). ``mem_array`` must be non-empty.'''
        if mean:  # Comparing to previous mean
            baseline = sum(mem_array) / len(mem_array)
            # en=True: maximize (energy); en=False: minimize (parameters)
            improved = value > baseline if en else value < baseline
            return 1 if improved else 0

        # Searching for minimal
        # TODO: implement maximizing energy; `en` is ignored in this branch.
        return 1 if value <= min(mem_array) else 0

    def init_dense_reward(self, freeP, energy):
        ''' Initiates the previous-step values (prevP, prevE) used by
        dense_reward and discrete_previous_reward.'''
        self.prevP = freeP
        self.prevE = energy

    def discrete_mean_reward(self, env, freeP, energy, simp, err):
        ''' This function gives a reward by comparing the current value of a variable
        to the mean value of the previous values, through the past_reward function.
        Here, the rewards rew_t and rew_E will be discrete.

        Reads the histories ``env.mem_freeP`` and ``env.mem_E``. Returns 0
        when ``err == 0``.'''
        if err == 0:  # there was an error calculating the energy
            return 0

        # Time reward (now a function of params)
        rew_t = self.past_reward(env.mem_freeP, freeP)
        # Energy reward
        rew_E = self.past_reward(env.mem_E, energy, en = True)
        # Total reward
        return self.weight_E*rew_E + self.weight_params*rew_t + self.weight_simp*simp

    def discrete_previous_reward(self, env, freeP, energy, simp, err):
        ''' This function gives a reward by comparing the current value of a variable
        to the previous value. Here, the rewards rew_t and rew_E will be discrete.

        Returns 0 when ``err == 0``. ``env`` is not used.'''
        # NOTE(review): unlike dense_reward, this method does not update
        # prevP/prevE, so the comparison baseline only changes when
        # init_dense_reward or dense_reward is called -- confirm intended.
        if err == 0:  # there was an error calculating the energy
            return 0

        # Time reward (now a function of params): fewer free parameters
        rew_t = 1 if self.prevP > freeP else 0
        # Energy reward: higher energy than before
        rew_E = 1 if self.prevE < energy else 0
        # Total reward
        return self.weight_E*rew_E + self.weight_params*rew_t + self.weight_simp*simp

    def dense_reward(self, env, freeP, energy, simp, err):
        ''' This functions gives rewards by comparing the current value of a variable
        to the previous one. Here, the rewards rew_t and rew_E will be continuous.

        Returns 0 when ``err == 0``. prevP and prevE are always overwritten
        with the current values, including on error. ``env`` is not used.'''
        # NOTE(review): rew_E below rewards a *decrease* in energy, whereas
        # the discrete reward methods reward an increase -- confirm the sign.
        if err == 0:
            rew = 0
        else:  # The smaller the variable wrt previous value, the bigger the reward.
            rew_t = self.prevP - freeP
            rew_E = self.prevE - energy
            rew = self.weight_E*rew_E + self.weight_params*rew_t + self.weight_simp*simp

        self.prevP = freeP
        self.prevE = energy

        return rew