In [1]:
# Check the *kernel's* interpreter. `!python --version` reports whatever
# `python` is first on PATH, which need not be the interpreter running this notebook.
import platform
import sys
import warnings

EXPECTED_PYTHON = (3, 7)  # the notebook was developed against Python 3.7.16
if sys.version_info[:2] != EXPECTED_PYTHON:
    warnings.warn(
        f"Expected Python {EXPECTED_PYTHON[0]}.{EXPECTED_PYTHON[1]}.x, "
        f"but the kernel is running {platform.python_version()}"
    )
platform.python_version()

Python 3.7.16


In [20]:
import os
from agent import Agent, recursive_obs_dict_to_spaces_dict
from action_net import ActionNet
os.chdir("..")
from rice import Rice

import torch
from collections import deque
import numpy as np

from tqdm import tqdm

In [3]:
# Seed every RNG the notebook may rely on so training/evaluation runs are repeatable.
import random

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = Rice()

In [46]:
# Smoke test on a throwaway agent, so this cell only needs `env` and the imports
# above. The `agents` list is only created further down in the notebook.
smoke_observation = env.reset()
smoke_agent = Agent(
    recursive_obs_dict_to_spaces_dict(smoke_observation[0]),
    env.action_space[0],
    id=0,
)
smoke_agent.act(0, smoke_observation[0])

((3, 7, 3, 9, 5, 8, 9, 2, 1, 7, 0),
 tensor([-0.5380, -2.8160, -2.3809, -1.9160, -2.9805, -0.6615, -1.2168, -0.8940,
         -1.8528, -2.4258, -2.0461], grad_fn=<CatBackward>))

In [4]:
def create_agents(environment=None):
    """Build one freshly initialised ``Agent`` per key of the environment's reset observation.

    Parameters
    ----------
    environment : optional
        Environment to build agents for; defaults to the notebook-level ``env``.
        Its ``reset()`` must return a mapping keyed by agent id, and its
        ``action_space`` must be indexable by those same ids (assumed to hold
        for ``Rice`` -- TODO confirm if other environments are used).

    Returns
    -------
    list of Agent
        One agent per observation key, in key-iteration order; ``agent.id`` is the key.
    """
    if environment is None:
        environment = env
    initial_observation = environment.reset()
    # Build each agent from *its own* observation/action space rather than
    # region 0's, so heterogeneous regions get correctly shaped networks.
    return [
        Agent(
            recursive_obs_dict_to_spaces_dict(initial_observation[key]),
            environment.action_space[key],
            id=key,
        )
        for key in initial_observation
    ]

In [5]:
# Freshly constructed agents, one per key of `env.reset()`'s observation.
# `reinforce` below updates their `nets[0]` parameters in place.
agents = create_agents()

In [79]:
def _discounted_returns(rewards, gamma):
    """Return the discounted returns G_t = r_t + gamma * G_{t+1} for a reward sequence."""
    returns = [0.0] * len(rewards)
    running_return = 0.0
    for t in reversed(range(len(rewards))):
        running_return = rewards[t] + gamma * running_return
        returns[t] = running_return
    return returns


def _standardize(values, eps):
    """Return ``values`` as a tensor shifted to zero mean and, when possible, scaled to unit std."""
    tensor = torch.tensor(values)
    centered = tensor - tensor.mean()
    if tensor.numel() < 2:
        # The unbiased std of fewer than two samples is NaN; dividing by it would
        # poison the loss (and then every weight), so only center in that case.
        return centered
    return centered / (tensor.std() + eps)


def reinforce(agents, n_training_episodes, gamma, environment=None, lr=0.0005):
    """Train every agent's ``nets[0]`` with vanilla REINFORCE (Monte-Carlo policy gradient).

    For each episode, all agents act simultaneously for ``episode_length`` steps.
    Each agent's per-step rewards are turned into discounted returns, standardized,
    and used to weight ``-log_prob``. One Adam step is then taken per agent.

    Parameters
    ----------
    agents : list of Agent
        Agents to train; ``agent.act(0, obs)`` must return ``(action, log_prob)``
        where ``log_prob`` is a differentiable tensor.
    n_training_episodes : int
        Number of full episodes to roll out (one update per agent per episode).
    gamma : float
        Discount factor applied when computing returns.
    environment : optional
        Environment to train in; defaults to the notebook-level ``env``.
    lr : float
        Adam learning rate (the previous hard-coded value is the default).

    Returns
    -------
    None
        Agents are updated in place.
    """
    if environment is None:
        environment = env
    optimizers = {
        agent.id: torch.optim.Adam(agent.nets[0].parameters(), lr=lr)
        for agent in agents
    }
    # Loop-invariant; avoids division by zero when all returns are equal.
    eps = np.finfo(np.float32).eps.item()

    for _ in tqdm(range(1, n_training_episodes + 1)):
        saved_log_probs = {agent.id: [] for agent in agents}
        rewards = {agent.id: [] for agent in agents}
        state = environment.reset()

        # Roll out one whole episode with all agents acting simultaneously.
        for _ in range(environment.episode_length):
            collective_action = {}
            for agent in agents:
                action, log_prob = agent.act(0, state[agent.id])
                saved_log_probs[agent.id].append(log_prob)
                collective_action[agent.id] = np.array(action)
            state, reward, _done, _info = environment.step(collective_action)
            for agent in agents:
                rewards[agent.id].append(reward[agent.id])

        # One policy-gradient update per agent.
        for agent in agents:
            returns = _standardize(_discounted_returns(rewards[agent.id], gamma), eps)
            # Flatten each term so both scalar and vector log-probs concatenate.
            loss_terms = [
                (-log_prob * disc_return).reshape(-1)
                for log_prob, disc_return in zip(saved_log_probs[agent.id], returns)
            ]
            if not loss_terms:
                continue  # zero-length episode: nothing to learn from
            loss = torch.cat(loss_terms).sum()

            optimizers[agent.id].zero_grad()
            loss.backward()
            optimizers[agent.id].step()

In [86]:
# Train all agents in place: 100 undiscounted (gamma = 1) REINFORCE episodes.
reinforce(
    agents,
    n_training_episodes=100,
    gamma=1.0,
)

100%|█████████████████████████████████████████| 100/100 [00:30<00:00,  3.26it/s]


In [87]:
def evaluate_agents(agents, environment=None):
    """Roll out one full episode with ``agents`` and return the environment's final global state.

    Actions come from ``agent.act(0, obs)`` exactly as in training; whether that
    samples or acts greedily depends on ``Agent.act`` (TODO confirm).

    Parameters
    ----------
    agents : list of Agent
        Agents to evaluate; each acts on ``state[agent.id]``.
    environment : optional
        Environment to roll out in; defaults to the notebook-level ``env``.

    Returns
    -------
    The environment's ``global_state`` after ``episode_length`` steps.
    """
    if environment is None:
        environment = env
    state = environment.reset()
    # No parameter updates happen here, so skip building autograd graphs.
    with torch.no_grad():
        for _ in range(environment.episode_length):
            collective_action = {}
            for agent in agents:
                action, _log_prob = agent.act(0, state[agent.id])
                collective_action[agent.id] = np.array(action)
            state, _reward, _done, _info = environment.step(collective_action)
    return environment.global_state

In [88]:
def baseline():
    """Evaluate a brand-new set of agents from ``create_agents()`` as a reference run.

    Returns whatever ``evaluate_agents`` returns for those freshly built agents.
    """
    untrained_agents = create_agents()
    return evaluate_agents(untrained_agents)

In [89]:
# Final global state after one episode with the trained agents.
trained_final_state = evaluate_agents(agents)
trained_final_state["reward_all_regions"]

{'value': array([[0.        , 0.        , 0.        , 0.        ],
        [0.05850608, 0.72980726, 0.27089572, 0.05207098],
        [0.05331389, 0.6906957 , 0.30329227, 0.06162233],
        [0.05917268, 0.692366  , 0.39722022, 0.08164351],
        [0.05662945, 0.8340226 , 0.28000402, 0.0656942 ],
        [0.05810888, 0.8365032 , 0.3569158 , 0.06932385],
        [0.06290881, 0.67058283, 0.44530904, 0.08765482],
        [0.05929073, 0.7847819 , 0.45354325, 0.085191  ],
        [0.05993723, 0.6306677 , 0.41380388, 0.07190837],
        [0.06280081, 0.8712302 , 0.4101105 , 0.0798427 ],
        [0.06518678, 0.80571026, 0.41909185, 0.09194484],
        [0.06158277, 0.8905454 , 0.40946323, 0.08501077],
        [0.06207387, 0.8796677 , 0.41144574, 0.08396596],
        [0.0669051 , 0.82355547, 0.41746083, 0.08134782],
        [0.06300753, 0.73157513, 0.4340559 , 0.08584155],
        [0.06654098, 0.8901327 , 0.39776883, 0.08433434],
        [0.06389566, 0.8599195 , 0.3237495 , 0.08570278],
     

In [90]:
# Same metric for freshly created agents, for comparison with the trained run.
baseline_final_state = baseline()
baseline_final_state["reward_all_regions"]

{'value': array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 4.09122370e-02,  6.24348581e-01,  5.51530644e-02,
          4.87187319e-03],
        [ 5.10071702e-02,  7.70902336e-01,  2.74724722e-01,
          8.67050067e-02],
        [ 4.52505685e-02,  7.77053893e-01,  3.52690101e-01,
          6.11822903e-02],
        [ 1.15539385e-02,  7.11685836e-01, -0.00000000e+00,
          6.71878085e-02],
        [ 6.78033829e-02,  6.76693901e-08,  4.69006866e-01,
          5.21093234e-02],
        [ 2.24424917e-02,  8.25839460e-01,  4.58496183e-01,
          3.26789566e-03],
        [ 6.15346693e-02,  7.92625546e-01,  3.63857120e-01,
          9.51857343e-02],
        [ 2.65941829e-08,  8.75564933e-01,  5.12179971e-01,
          6.24971837e-02],
        [ 5.47525883e-02,  6.29430473e-01,  1.33975863e-01,
          6.51712045e-02],
        [ 3.38395163e-02,  8.71820867e-01,  5.00552118e-01,
          2.29037385e-02],
        [ 6.65325522e-02,  8.62298