In [1]:
# Sanity check of the Python version (expected 3.7.16). Note that `!python` runs the
# first `python` on the shell's PATH, which is not necessarily the kernel's interpreter.
!python --version # should say 3.7.16

Python 3.7.16


In [2]:
import os
# Project-local modules, imported before the chdir below (presumably from the
# notebook's own directory).
from agent import Agent, recursive_obs_dict_to_spaces_dict
from action_net import ActionNet
# NOTE(review): `rice` is imported only after moving to the parent directory,
# presumably because it lives there (or reads data files relative to it) — confirm.
# This cell is not idempotent: re-running it climbs one more directory level.
os.chdir("..")
from rice import Rice

import torch
from collections import deque
import numpy as np

from tqdm import tqdm

from typing import List

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Environment with the negotiation protocol switched off.
env = Rice(negotiation_on=False)

In [4]:
# Keep the first observation: it is used below to build each agent's observation space.
initial_state = env.reset()

In [5]:
def get_features_n() -> int:
    """
    Count the leading observation features that matter for the actions
    (negotiation-related features are not counted).

    The count is the sum of:
    - for every global and public feature, the product of its value's shape
      after the first axis (presumably the time axis — confirm);
    - one slot per region (``env.num_regions``);
    - one slot per private feature.
    """
    shared_features = env.global_features + env.public_features
    shared_size = sum(
        np.prod(env.global_state[name]['value'].shape[1:]) for name in shared_features
    )
    region_slots = env.num_regions
    private_slots = len(env.private_features)
    return shared_size + region_slots + private_slots

features_n = get_features_n()
features_n

77

In [6]:
def get_action_spaces_n() -> int:
    """
    Count the action sub-spaces used for the economic actions (negotiation excluded).

    This is the length of the concatenation of the savings, mitigation-rate,
    export, import and tariff ``nvec`` sequences of ``env``.
    """
    action_nvecs = (
        env.savings_action_nvec
        + env.mitigation_rate_action_nvec
        + env.export_action_nvec
        + env.import_actions_nvec
        + env.tariff_actions_nvec
    )
    return len(action_nvecs)

action_spaces_n = get_action_spaces_n()
action_spaces_n

11

In [7]:
def get_actions_n(agent_id: int) -> int:
    """
    Total number of discrete choices over the first ``action_spaces_n``
    action sub-spaces of ``agent_id`` (the sum of their ``.n``).

    Reads the module-level ``env`` and ``action_spaces_n``.
    """
    relevant_spaces = env.action_space[agent_id][:action_spaces_n]
    total = 0
    for space in relevant_spaces:
        total += space.n
    return total

actions_n = get_actions_n(agent_id = 0)
actions_n

110

In [8]:
def trim_observation(state: dict, agent_id: int):
    """
    Return ``agent_id``'s observation restricted to its action-relevant part.

    Only the first ``features_n`` entries of ``'features'`` and the first
    ``actions_n`` entries of ``'action_mask'`` are kept. The per-agent dict is
    shallow-copied first, so ``state`` itself keeps its original entries
    (if the values are numpy arrays, the slices are views sharing their memory).
    """
    trimmed = state[agent_id].copy()
    trimmed.update(
        features=trimmed['features'][:features_n],
        action_mask=trimmed['action_mask'][:actions_n],
    )
    return trimmed

def get_observation_space(agent_id: int):
    """
    Build the action-relevant observation space of ``agent_id``.

    The space is derived from the trimmed version of ``initial_state``,
    the observation returned by the first ``env.reset()``.
    """
    trimmed_first_obs = trim_observation(initial_state, agent_id)
    space = recursive_obs_dict_to_spaces_dict(trimmed_first_obs)
    return space

def get_action_space(agent_id: int):
    """
    Return the action sub-spaces that are relevant for taking actions, i.e.
    the first ``action_spaces_n`` entries of ``env.action_space[agent_id]``.
    """
    agent_spaces = env.action_space[agent_id]
    return agent_spaces[:action_spaces_n]
    

In [9]:
def create_agents() -> List[Agent]:
    """
    Build one new ``Agent`` per agent id found in ``initial_state``.

    Each agent is given the trimmed (action-relevant) observation space and
    the relevant action sub-spaces for its id.
    """
    return [
        Agent(
            observation_space=get_observation_space(agent_id),
            action_space=get_action_space(agent_id),
            id=agent_id,
        )
        for agent_id in initial_state
    ]

In [10]:
# One newly constructed agent per agent id in `initial_state`.
agents = create_agents()

In [11]:
# Smoke test: every agent picks an action from its trimmed initial observation,
# then the environment is advanced by a single step.
collective_action = {}
for agent in agents:
    action, _ = agent.act(0, trim_observation(initial_state, agent.id))
    collective_action[agent.id] = np.array(action)

state, reward, done, _ = env.step(collective_action)

In [12]:
# Show the actions submitted to env.step in the previous cell, keyed by agent id.
collective_action

{0: array([5, 3, 8, 6, 7, 0, 6, 4, 5, 6, 2]),
 1: array([8, 8, 2, 0, 3, 5, 8, 4, 4, 7, 0]),
 2: array([6, 2, 0, 0, 4, 0, 5, 7, 2, 3, 2]),
 3: array([8, 8, 6, 7, 4, 2, 5, 0, 3, 2, 4])}

In [13]:
# Adapted from https://huggingface.co/deep-rl-course/unit4/hands-on?fw=pt
def reinforce(agents: List[Agent],
              n_training_episodes: int,
              gamma: float,
              lr: float = .0005) -> None:
    """
    Train ``agent.nets[0]`` of every agent with REINFORCE (Monte-Carlo policy
    gradient), using one independent Adam optimizer per agent.

    In each episode all agents act simultaneously on their trimmed observations
    for ``env.episode_length`` steps. Afterwards each agent's discounted returns
    are standardized and used to weight the negative log-probabilities of the
    actions it took; the summed loss is back-propagated and one optimizer step
    is taken.

    Args:
        agents: agents to train; ``agent.id`` indexes the per-agent entries of
            the observations and rewards returned by ``env``.
        n_training_episodes: number of full episodes to simulate.
        gamma: discount factor applied to future rewards.
        lr: Adam learning rate (default: the previously hard-coded 5e-4).
    """
    optimizers = {agent.id: torch.optim.Adam(agent.nets[0].parameters(), lr=lr) for agent in agents}
    # Added to the std so an episode with constant returns does not divide by zero.
    eps = np.finfo(np.float32).eps.item()

    for _episode in tqdm(range(n_training_episodes)):
        saved_log_probs = {agent.id: [] for agent in agents}
        rewards = {agent.id: [] for agent in agents}
        state = env.reset()

        # Generate a whole episode
        for _step in range(env.episode_length):
            collective_action = {}
            for agent in agents:
                action, log_prob = agent.act(0, trim_observation(state, agent.id))
                saved_log_probs[agent.id].append(log_prob)
                collective_action[agent.id] = np.array(action)

            state, reward, _, _ = env.step(collective_action)

            for agent in agents:
                rewards[agent.id].append(reward[agent.id])

        # Per agent: discounted + standardized returns -> one policy-gradient update
        for agent in agents:
            returns = _standardize(torch.tensor(_discounted_returns(rewards[agent.id], gamma)), eps)
            policy_loss = [-log_prob * disc_return
                           for log_prob, disc_return in zip(saved_log_probs[agent.id], returns)]
            if not policy_loss:
                continue  # empty episode: nothing to learn from (torch.cat([]) would raise)
            loss = torch.cat(policy_loss).sum()

            optimizers[agent.id].zero_grad()
            loss.backward()
            optimizers[agent.id].step()


def _discounted_returns(rewards, gamma: float) -> list:
    """
    Return ``G_t = r_t + gamma * G_{t+1}`` for every step ``t`` of one episode.

    Computed back-to-front in one pass; an empty ``rewards`` yields ``[]``.
    """
    returns = []
    running = 0
    for reward_t in reversed(rewards):
        running = gamma * running + reward_t
        returns.append(running)
    returns.reverse()
    return returns


def _standardize(values: torch.Tensor, eps: float) -> torch.Tensor:
    """
    Shift ``values`` to zero mean and scale to unit (unbiased) standard deviation.

    With fewer than two elements the unbiased std is NaN, so only the mean is
    removed in that case instead of propagating NaNs into the weights.
    """
    centered = values - values.mean()
    if values.numel() < 2:
        return centered
    return centered / (values.std() + eps)

In [14]:
# gamma=1.0: future rewards are not discounted.
reinforce(agents, n_training_episodes=100, gamma=1.0)

100%|█████████████████████████████████████████| 100/100 [00:29<00:00,  3.43it/s]


In [25]:
def evaluate_agents(agents: List[Agent]) -> dict:
    """
    Roll out one full episode with ``agents`` and return ``env.global_state``.

    Agents receive the same trimmed observations they are built and trained on
    (via ``trim_observation``), so their inputs match their observation spaces.
    Actions are chosen under ``torch.no_grad()`` because evaluation never
    back-propagates; the log-probabilities returned by ``act`` are discarded.

    NOTE(review): this returns the live ``env.global_state`` object; a later
    ``env.reset()`` may replace or mutate it — copy it if several results must
    be kept side by side.
    """
    state = env.reset()
    with torch.no_grad():
        for _step in range(env.episode_length):
            collective_action = {}
            for agent in agents:
                action, _ = agent.act(0, trim_observation(state, agent.id))
                collective_action[agent.id] = np.array(action)
            state, _, _, _ = env.step(collective_action)
    return env.global_state

In [16]:
def baseline() -> dict:
    """
    Evaluate a freshly created set of agents (presumably untrained) for one
    episode and return the resulting ``env.global_state``, as a reference
    point for trained agents.
    """
    fresh_agents = create_agents()
    return evaluate_agents(fresh_agents)