In [1]:
import gym
from model import Policy
from storage import RolloutStorage
from wrappers import  TimeLimit
from a2c import A2C
import torch
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

In [2]:
def make_env(env, time_limit = 5):
    env = TimeLimit(env, time_limit)
    return env

In [3]:
env = gym.make("rware:rware-tiny-2ag-v1")
writer = SummaryWriter("test")
obs_space = env.observation_space
action_space = env.action_space

In [4]:
agents = [A2C(i, obs_space, action_space) for i in range(env.n_agents)]

10
10


In [5]:
obs = env.reset()

In [6]:
#Get the initial obs
for i in range(len(obs)):
    agents[i].storage.obs[0].copy_(torch.tensor(obs[i]))
    agents[i].storage.to('cpu')


for j in range(100):
    #N-steps rollout
    for _ in range(5):
        #Get the action
        with torch.no_grad():
            n_value, n_action, n_log_probs = zip( *[agent.model.act(agent.storage.obs[0]) for agent in agents] )
            n_action = [action.item() for action in n_action]

        #Step on env
        obs, reward, done, infos = env.step(n_action)
        
        #Get the mask
        masks = torch.tensor([[0.0] if done_ else [1.0] for done_ in done])

        #Copy the state transition to agent's on-policy storage
        for i in range(len(agents)):
            agents[i].storage.insert(torch.tensor(obs[i]),
                                    torch.tensor(n_action[i]),
                                    torch.tensor(n_log_probs[i]), 
                                    n_value[i],
                                    torch.tensor(reward[i]),
                                    masks[i])

    print('rollout finished')
    for agent in agents:
        agent.compute_returns()

    for agent in agents:
        loss = agent.update([a.storage for a in agents])
        for k, v in loss.items():
            writer.add_scalar(f"agent{agent.agent_id}/{k}", v, j)

    for agent in agents:
        agent.storage.after_update()
    
    


rollout finished
1
1
rollout finished
1
1
rollout finished
1
1


  torch.tensor(n_log_probs[i]),


rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finished
1
1
rollout finis

In [7]:
loss

{'policy_loss': -1.442805528640747,
 'value_loss': 0.3882187306880951,
 'dist_entropy': 0.016044992208480834,
 'importance_sampling': 1.001254916191101,
 'seac_policy_loss': -1.5284276008605957,
 'seac_value_loss': 0.4578359127044678}

In [8]:
agents[0].storage.obs

tensor([[[6., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[6., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [9]:
agents[0].storage.obs[0]

tensor([[6., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [10]:
agents[0].storage.obs[-1]

tensor([[6., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])