In [1]:
!python --version

Python 3.7.16


In [1]:
import os
os.chdir("..")
from rice import Rice
import torch
import numpy as np

from tqdm import tqdm

import copy

from PPO_copied import ActorCritic, PPO, RolloutBuffer

In [2]:
env = Rice(region_yamls_filename='fewer_region_yamls/')

In [3]:
def create_agents(env : Rice):
    agents = []
    initial_state = env.reset()
    for i in range(env.num_regions):
        agents.append(
            PPO(
                state_dim = len(initial_state[i]['features']), 
                action_dim = env.action_space[i]
            )
        )
    return agents

In [4]:
agents = create_agents(env)

In [5]:
for epoch in tqdm(range(20)):
    state = env.reset()
    for t in range(env.episode_length):
        collective_action = {}
        for agent_id in range(len(agents)):
            collective_action[agent_id] = agents[agent_id].select_action(state[agent_id])
        state, reward, _, _ = env.step(collective_action)
        for agent_id in range(len(agents)):
            agents[agent_id].buffer.rewards.append(reward[agent_id])
            
    for agent in agents:
        agent.update()

100%|███████████████████████████████████████████| 20/20 [01:30<00:00,  4.51s/it]


In [6]:
def evaluate_agents(agents):
    state = env.reset()
    actions = {i : [] for i in range(len(agents))}
    for t in range(env.episode_length):
        collective_action = {}
        for agent_id in range(len(agents)):
            action = agents[agent_id].select_action(state[agent_id])
            collective_action[agent_id] = action
            actions[agent_id].append(action)
        state, reward, _, _ = env.step(collective_action)
    return copy.deepcopy(env.global_state), actions

In [7]:
def baseline():
    return evaluate_agents(create_agents(env))

In [8]:
t, a = evaluate_agents(agents)
b, a1 = baseline()

In [9]:
t["global_temperature"]

{'value': array([[0.85      , 0.0068    ],
        [0.9886611 , 0.02788   ],
        [1.1085788 , 0.05189953],
        [1.2165788 , 0.07831651],
        [1.3185287 , 0.10677307],
        [1.4157696 , 0.13706696],
        [1.51711   , 0.16903453],
        [1.6073216 , 0.20273641],
        [1.6988846 , 0.23785104],
        [1.7900738 , 0.27437687],
        [1.8814331 , 0.3122693 ],
        [1.9729142 , 0.3514984 ],
        [2.0626662 , 0.3920338 ],
        [2.1559305 , 0.4337996 ],
        [2.2436054 , 0.47685286],
        [2.3255155 , 0.52102166],
        [2.4074383 , 0.56613404],
        [2.4916945 , 0.61216664],
        [2.5804365 , 0.65915483],
        [2.6757383 , 0.7071869 ],
        [2.767941  , 0.75640064]], dtype=float32),
 'norm': 10.0}

In [10]:
b["global_temperature"]

{'value': array([[0.85      , 0.0068    ],
        [0.9886611 , 0.02788   ],
        [1.1112729 , 0.05189953],
        [1.224051  , 0.07838386],
        [1.3282546 , 0.10702554],
        [1.4249457 , 0.13755627],
        [1.5149925 , 0.169741  ],
        [1.6060003 , 0.20337228],
        [1.6999042 , 0.23843798],
        [1.7963964 , 0.27497464],
        [1.8892277 , 0.3130102 ],
        [1.9828362 , 0.35241562],
        [2.0811887 , 0.39317614],
        [2.1757264 , 0.43537647],
        [2.2614012 , 0.4788852 ],
        [2.345068  , 0.5234481 ],
        [2.4346073 , 0.5689886 ],
        [2.5280492 , 0.6156291 ],
        [2.6199832 , 0.6634396 ],
        [2.7128935 , 0.71235317],
        [2.806301  , 0.76236665]], dtype=float32),
 'norm': 10.0}