In [1]:
!python --version

Python 3.7.16


In [1]:
import os
# Switch the working directory to the parent folder before the remaining imports —
# presumably so the project-local modules (`rice`, `PPO_copied`) and relative data
# paths resolve; confirm against the repository layout.
# NOTE(review): this is not idempotent — re-running the cell moves up one more level.
os.chdir("..")
from rice import Rice
import torch
import numpy as np

from tqdm import tqdm

import copy

from PPO_copied import ActorCritic, PPO, RolloutBuffer

In [2]:
# Instantiate the Rice environment, pointing it at 'fewer_region_yamls/'
# (presumably a directory of per-region YAML configs, resolved relative to the
# current working directory — TODO confirm).
env = Rice(region_yamls_filename='fewer_region_yamls/')

In [3]:
def create_agents(env : Rice):
    """Build one PPO agent for every region of ``env``.

    The environment is reset once to obtain an initial observation. For each
    region index, the length of that region's ``'features'`` entry becomes the
    agent's ``state_dim`` and ``env.action_space[region]`` is passed through
    unchanged as ``action_dim``.

    Note that calling this function resets ``env`` as a side effect.

    Returns:
        list: PPO agents ordered by region index (``0 .. env.num_regions - 1``).
    """
    first_observation = env.reset()

    def _agent_for(region):
        # Size the agent from this region's observation and action space.
        return PPO(
            state_dim=len(first_observation[region]['features']),
            action_dim=env.action_space[region],
        )

    return [_agent_for(region) for region in range(env.num_regions)]

In [4]:
# One PPO agent per region of `env`; create_agents resets `env` to size each agent.
agents = create_agents(env)

In [5]:
# Number of training epochs; the training loop calls `update()` on every agent
# once at the end of each epoch.
epochs = 100
# Number of full episodes rolled out per epoch before the agents are updated.
batch_size = 8

In [6]:
# Training: every epoch rolls out `batch_size` complete episodes, appending each
# agent's reward and terminal flag to its own buffer, then updates each agent once.
for _epoch in tqdm(range(epochs)):
    for _episode in range(batch_size):
        observations = env.reset()
        for step in range(env.episode_length):
            # Ask every agent for its action on its own slice of the observation.
            joint_action = {
                idx: agent.select_action(observations[idx])
                for idx, agent in enumerate(agents)
            }
            observations, rewards, _, _ = env.step(joint_action)

            # The last step of the episode is flagged as terminal for all agents.
            is_last_step = step == env.episode_length - 1
            for idx, agent in enumerate(agents):
                agent.buffer.rewards.append(rewards[idx])
                agent.buffer.is_terminals.append(is_last_step)

    # One update per agent after the whole batch of episodes has been collected.
    for learner in agents:
        learner.update()

100%|█████████████████████████████████████████| 100/100 [02:43<00:00,  1.63s/it]


In [7]:
def evaluate_agents(agents, environment=None):
    """Roll out one full episode with ``agents`` and record what they did.

    Args:
        agents: One agent per region, indexable by region id
            (``agents[0] .. agents[len(agents) - 1]``). Each agent must expose
            ``select_action(observation)``.
        environment: Environment to roll out in. It must provide ``reset()``,
            ``step(actions)``, ``episode_length`` and ``global_state``. When
            omitted, the notebook-level ``env`` is used, so existing calls of
            the form ``evaluate_agents(agents)`` behave as before.

    Returns:
        tuple: ``(final_global_state, actions)``. ``final_global_state`` is a
        deep copy of ``environment.global_state`` taken after the last step;
        ``actions`` maps each agent index to the list of actions that agent
        selected, one entry per step.
    """
    # Fall back to the notebook's global environment when none is supplied.
    sim = env if environment is None else environment

    observations = sim.reset()
    actions = {agent_id: [] for agent_id in range(len(agents))}
    for _ in range(sim.episode_length):
        collective_action = {}
        for agent_id in range(len(agents)):
            # NOTE(review): if `select_action` also records the transition in
            # `agent.buffer` (the training loop only appends rewards/terminals,
            # which suggests it might), evaluation leaves entries behind —
            # confirm, and clear the buffer before any further training.
            action = agents[agent_id].select_action(observations[agent_id])
            collective_action[agent_id] = action
            actions[agent_id].append(action)
        # Only the next observation is needed; reward, done and info are ignored.
        observations, _, _, _ = sim.step(collective_action)

    # Deep copy so the returned snapshot is independent of the live state object.
    return copy.deepcopy(sim.global_state), actions

In [8]:
def baseline():
    """Evaluate a freshly constructed set of agents on the notebook's ``env``.

    The agents come straight from ``create_agents`` and have never been
    updated, so the result serves as a reference point for trained agents.

    Returns:
        Whatever ``evaluate_agents`` returns for those fresh agents.
    """
    fresh_agents = create_agents(env)
    return evaluate_agents(fresh_agents)

In [9]:
# t / a: final global state and per-agent action history from one episode
# played by the trained agents.
t, a = evaluate_agents(agents)
# b / a1: the same pair for freshly created agents that have never been updated.
b, a1 = baseline()

In [10]:
# "global_temperature" entry of the final state reached by the trained agents.
t["global_temperature"]

{'value': array([[0.85      , 0.0068    ],
        [0.9886611 , 0.02788   ],
        [1.1111463 , 0.05189953],
        [1.220285  , 0.0783807 ],
        [1.3222697 , 0.1069283 ],
        [1.4201028 , 0.13731183],
        [1.5159597 , 0.1693816 ],
        [1.6031895 , 0.20304605],
        [1.686077  , 0.23804964],
        [1.7684636 , 0.27425033],
        [1.8498844 , 0.31160566],
        [1.9285836 , 0.35006264],
        [2.007049  , 0.38952565],
        [2.085516  , 0.42996374],
        [2.1595397 , 0.47135255],
        [2.2326097 , 0.51355726],
        [2.3051276 , 0.5565336 ],
        [2.3765416 , 0.60024846],
        [2.4474492 , 0.64465576],
        [2.5204687 , 0.6897256 ],
        [2.5922585 , 0.73549414]], dtype=float32),
 'norm': 10.0}

In [11]:
# "global_temperature" entry of the final state reached by the never-updated baseline agents.
b["global_temperature"]

{'value': array([[0.85      , 0.0068    ],
        [0.9886611 , 0.02788   ],
        [1.1099201 , 0.05189953],
        [1.2210369 , 0.07835004],
        [1.3243413 , 0.10691722],
        [1.4231193 , 0.13735282],
        [1.5202795 , 0.16949698],
        [1.6136744 , 0.20326655],
        [1.7061995 , 0.23852675],
        [1.795763  , 0.27521858],
        [1.8858715 , 0.31323218],
        [1.9733851 , 0.35254815],
        [2.0592983 , 0.3930691 ],
        [2.1479213 , 0.4347248 ],
        [2.2294114 , 0.4775547 ],
        [2.3113794 , 0.5213511 ],
        [2.3914547 , 0.5661018 ],
        [2.4662201 , 0.6117356 ],
        [2.53875   , 0.6580977 ],
        [2.6157198 , 0.705114  ],
        [2.6963778 , 0.75287914]], dtype=float32),
 'norm': 10.0}

In [12]:
# Query agent 0's policy directly on region 0's 'features' from a fresh reset
# (this resets `env` again). The three return values are unpacked as
# (acts, logits, value) — presumably the sampled action(s), the raw action
# scores and the critic's value estimate; confirm against PPO_copied.ActorCritic.act.
acts, logits, value = agents[0].policy.act(torch.FloatTensor(env.reset()[0]['features']))

In [13]:
# Normalise `logits` with a softmax along dim 0 so the displayed entries sum to 1.
logits.softmax(0)

tensor([0.2121, 0.0351, 0.1010, 0.0927, 0.0277, 0.0246, 0.1008, 0.1045, 0.1000,
        0.0666, 0.1349])