Copyright (c) 2022, salesforce.com, inc and MILA.  
All rights reserved.  
SPDX-License-Identifier: BSD-3-Clause  
For full license text, see the LICENSE file in the repo root  
or https://opensource.org/licenses/BSD-3-Clause  

Get started quickly with end-to-end multi-agent RL using WarpDrive! This notebook shows a basic example of how to create a simple Rice environment and train agents on it.

**Try this notebook on [Colab](http://colab.research.google.com/github/salesforce/warp-drive/blob/master/tutorials/simple-end-to-end-example.ipynb)!**

## ⚠️ PLEASE NOTE:
This notebook runs on a GPU runtime.\
If running on Colab, choose Runtime > Change runtime type from the menu, then select `GPU` in the 'hardware accelerator' dropdown menu.

### Dependencies

First, install the WarpDrive package (uncomment the `pip install` cell below if it is not already installed).

In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
# !pip install rl-warp-drive

In [None]:
import os
import torch

from rice_cuda import RiceCuda
from warp_drive.env_wrapper import EnvWrapper
from warp_drive.training.trainer import Trainer
from warp_drive.utils.env_registrar import EnvironmentRegistrar

# Allocate a small float tensor on the GPU through PyTorch.
# NOTE(review): presumably this forces PyTorch to set up its CUDA context
# before WarpDrive uses the GPU — confirm against the WarpDrive docs.
pytorch_cuda_init_success = torch.cuda.FloatTensor(8)

# Environment, Training, and Model Hyperparameters

In [None]:
# Settings for the first run (negotiation enabled). The "env" entry is
# unpacked into RiceCuda(...) and the whole dictionary is given to the Trainer.
run_config = {
    # Environment settings
    "env": {
        "negotiation_on": 1,
        "num_discretization_cells": 10,
    },
    # Trainer settings
    "trainer": {
        # Number of environment replicas (number of GPU blocks used)
        "num_envs": 100,
        # Total batch size used for training per iteration
        # (across all the environments)
        "train_batch_size": 10000,
        # Total number of episodes to run the training for
        # (can be arbitrarily high!)
        "num_episodes": 100000,
    },
    # Policy network settings
    "policy": {
        "regions": {
            "to_train": True,
            "gamma": 0.92,  # discount factor
            "lr": 0.0005,  # learning rate
            "entropy_coeff": [[0, 0.5], [1000000, 0.1], [5000000, 0.05]],
            "vf_loss_coeff": [
                [0, 0.0001],
                [1000000, 0.001],
                [5000000, 0.01],
                [10000000, 0.1],
            ],
            "model": {
                "type": "fully_connected",
                # Dimension(s) of the fully connected layers, as a list
                "fc_dims": [256, 256],
                # Load model parameters from a saved checkpoint (if specified)
                "model_ckpt_filepath": "",
            },
        },
    },
    # Checkpoint saving settings
    "saving": {
        "metrics_log_freq": 10,  # how often (in iterations) to print the metrics
        "model_params_save_freq": 5000,  # how often (in iterations) to save the model parameters
        "basedir": "/tmp",  # base folder used for saving
        "name": "rice",
        "tag": "example",
    },
}

# End-to-End Training Loop

In [None]:
# Point WarpDrive at the CUDA source of the Rice environment.
# "__file__" is a string literal here, so the directory resolves to the
# current working directory rather than to the location of a module.
this_file_dir = os.path.dirname(os.path.abspath("__file__"))
env_registrar = EnvironmentRegistrar()
env_registrar.add_cuda_env_src_path(
    RiceCuda.name, os.path.join(this_file_dir, "rice_build.cu")
)

# Alternative registration path (kept for reference, not used):
# cpu_env = EnvWrapper(Rice())
# add_cpu_env = env_registrar.add(device="cpu")
# add_cpu_env(cpu_env)
# add_gpu_env = env_registrar.add(device="gpu")
# add_gpu_env(cpu_env)

# Wrap the environment; use_cuda must be True so that it runs on the GPU.
env_wrapper = EnvWrapper(
    RiceCuda(**run_config["env"]),
    num_envs=run_config["trainer"]["num_envs"],
    use_cuda=True,
    env_registrar=env_registrar,
)

# Policy models can be shared between agents. Here every agent id is mapped
# to the single "regions" policy.
policy_tag_to_agent_id_map = {
    "regions": list(range(env_wrapper.env.num_agents)),
}

# Build the trainer from the wrapped environment and the run configuration.
trainer = Trainer(
    env_wrapper=env_wrapper,
    config=run_config,
    policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
)

# Run the training loop.
trainer.train()

# Shut off gracefully
# trainer.graceful_close()

### Fetch episode states

In [None]:
# Variables to read back from the trained environment. Any variable
# registered in rice_cuda.py can be listed here.
desired_outputs = [
    "T_i",  # Temperature
    "M_i",  # Carbon mass
    "sampled_actions",
    "minMu",
]

episode_states = trainer.fetch_episode_states(desired_outputs)

# Release the trainer now that the states have been fetched.
trainer.graceful_close()

In [None]:
import matplotlib.pyplot as plt

In [None]:
def get_episode_T_AT(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the series stored at ``T_i[:, 0, 0]``.

    Args:
        episode_states: mapping with a ``'T_i'`` entry that supports
            NumPy-style indexing with three indices.
            NOTE(review): presumably laid out as (step, region, layer) — confirm.
        negotiation_on: when truthy, only every third entry along the first
            axis is kept.
        plot: when truthy, the series is drawn in a new matplotlib figure.

    Returns:
        The selected values.
    """
    stride = 3 if negotiation_on else 1
    values = episode_states['T_i'][::stride, 0, 0]

    if plot:
        figure = plt.figure()
        plt.plot(values[:], label='Temperature - Atmosphere')
        figure.legend()
        figure.show()

    return values


def get_episode_T_LO(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the series stored at ``T_i[:, 0, 1]``.

    Args:
        episode_states: mapping with a ``'T_i'`` entry that supports
            NumPy-style indexing with three indices.
            NOTE(review): presumably laid out as (step, region, layer) — confirm.
        negotiation_on: when truthy, only every third entry along the first
            axis is kept.
        plot: when truthy, the series is drawn in a new matplotlib figure.

    Returns:
        The selected values, whether or not a plot was requested.
    """
    state = 'T_i'
    if negotiation_on:
        values = episode_states[state][::3, 0, 1]
    else:
        values = episode_states[state][:, 0, 1]

    if plot:
        fig = plt.figure()
        plt.plot(values[:], label='Temperature - Lower Oceans')
        fig.legend()
        # plt.yscale('log')
        fig.show()

    # Returned outside the ``if plot`` branch so callers also get the data
    # when no plot is requested.
    return values

def get_episode_M_AT(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the series stored at ``M_i[:, 0, 0]``.

    Args:
        episode_states: mapping with an ``'M_i'`` entry that supports
            NumPy-style indexing with three indices.
            NOTE(review): presumably laid out as (step, region, reservoir) — confirm.
        negotiation_on: when truthy, only every third entry along the first
            axis is kept.
        plot: when truthy, the series is drawn in a new matplotlib figure.

    Returns:
        The selected values.
    """
    stride = 3 if negotiation_on else 1
    values = episode_states['M_i'][::stride, 0, 0]

    if plot:
        figure = plt.figure()
        plt.plot(values[:], label='Carbon - Atmosphere')
        figure.legend()
        figure.show()

    return values

def get_episode_M_UP(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the series stored at ``M_i[:, 0, 1]``.

    Args:
        episode_states: mapping with an ``'M_i'`` entry that supports
            NumPy-style indexing with three indices.
            NOTE(review): presumably laid out as (step, region, reservoir) — confirm.
        negotiation_on: when truthy, only every third entry along the first
            axis is kept.
        plot: when truthy, the series is drawn in a new matplotlib figure.

    Returns:
        The selected values.
    """
    state = 'M_i'
    if negotiation_on:
        values = episode_states[state][::3, 0, 1]
    else:
        # Index 0 on the second axis, matching the negotiation branch
        # (a ``:0`` slice here would select nothing).
        values = episode_states[state][:, 0, 1]

    if plot:
        fig = plt.figure()
        plt.plot(values[:], label='Carbon - Upper Strata')
        fig.legend()
        # plt.yscale('log')
        fig.show()

    return values

def get_episode_M_LO(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the series stored at ``M_i[:, 0, 2]``.

    This function reads index 2 of the last axis and labels its plot
    'Carbon - Lower Oceans', so it is named ``get_episode_M_LO``; under the
    name ``get_episode_M_UP`` it would shadow the function of that name
    defined just before it.

    Args:
        episode_states: mapping with an ``'M_i'`` entry that supports
            NumPy-style indexing with three indices.
            NOTE(review): presumably laid out as (step, region, reservoir) — confirm.
        negotiation_on: when truthy, only every third entry along the first
            axis is kept.
        plot: when truthy, the series is drawn in a new matplotlib figure.

    Returns:
        The selected values.
    """
    state = 'M_i'
    if negotiation_on:
        values = episode_states[state][::3, 0, 2]
    else:
        values = episode_states[state][:, 0, 2]

    if plot:
        fig = plt.figure()
        plt.plot(values[:], label='Carbon - Lower Oceans')
        fig.legend()
        # plt.yscale('log')
        fig.show()

    return values

def get_episode_minMu(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the ``'minMu'`` entry of the episode states.

    Args:
        episode_states: mapping with a ``'minMu'`` entry that supports
            NumPy-style indexing with two indices.
            NOTE(review): presumably laid out as (step, agent) — confirm.
        negotiation_on: when truthy, only every third row is kept.
        plot: when truthy, one figure is drawn per column of the result.

    Returns:
        The selected rows with all of their columns.
    """
    stride = 3 if negotiation_on else 1
    values = episode_states['minMu'][::stride, :]

    if not plot:
        return values

    # One separate figure per column.
    for agent in range(len(values[0])):
        figure = plt.figure()
        plt.plot(values[:, agent], label=f'minMu -  Agent:{agent}')
        figure.legend()
        figure.show()

    return values

In [None]:
def get_episode_MuAction(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the second-to-last sampled action.

    Reads ``episode_states['sampled_actions'][..., -2]``, the key under which
    the actions are fetched from the trainer.

    Args:
        episode_states: mapping with a ``'sampled_actions'`` entry that
            supports NumPy-style indexing with three indices.
            NOTE(review): presumably laid out as (step, agent, action) with
            the mitigation (mu) action second to last — confirm.
        negotiation_on: when truthy, only every third entry along the first
            axis is kept.
        plot: when truthy, one figure is drawn per column of the result.

    Returns:
        The selected values, one column per entry of the second axis.
    """
    state = 'sampled_actions'
    if negotiation_on:
        values = episode_states[state][::3, :, -2]
    else:
        values = episode_states[state][:, :, -2]

    if plot:
        for agent in range(len(values[0])):
            fig = plt.figure()
            plt.plot(values[:, agent], label='Mu Action -  Agent:' + str(agent))
            fig.legend()
            # plt.yscale('log')
            fig.show()

    return values

def get_episode_SavingAction(episode_states, negotiation_on, plot=0):
    """Return (and optionally plot) the last sampled action.

    Reads ``episode_states['sampled_actions'][..., -1]``, the key under which
    the actions are fetched from the trainer.

    Args:
        episode_states: mapping with a ``'sampled_actions'`` entry that
            supports NumPy-style indexing with three indices.
            NOTE(review): presumably laid out as (step, agent, action) with
            the saving action last — confirm.
        negotiation_on: when truthy, only every third entry along the first
            axis is kept.
        plot: when truthy, one figure is drawn per column of the result.

    Returns:
        The selected values, one column per entry of the second axis.
    """
    state = 'sampled_actions'
    if negotiation_on:
        values = episode_states[state][::3, :, -1]
    else:
        values = episode_states[state][:, :, -1]

    if plot:
        for agent in range(len(values[0])):
            fig = plt.figure()
            # Labelled as the saving action so these figures are not confused
            # with the ones drawn for the mu action.
            plt.plot(values[:, agent], label='Saving Action -  Agent:' + str(agent))
            fig.legend()
            # plt.yscale('log')
            fig.show()

    return values

In [None]:
# Display the raw 'sampled_actions' entry fetched from the trainer.
episode_states['sampled_actions']

In [None]:
# Plot the T_i[::3, 0, 0] series of the first run (negotiation_on=1, plot=1).
get_episode_T_AT(episode_states, 1, 1)

In [None]:
# Plot the T_i[::3, 0, 1] series of the first run (negotiation_on=1, plot=1).
get_episode_T_LO(episode_states, 1, 1)

In [None]:
# Plot every column of 'minMu' for the first run (negotiation_on=1, plot=1).
get_episode_minMu(episode_states, 1, 1)

In [None]:
# Plot the M_i[::3, 0, 0] series of the first run (negotiation_on=1, plot=1).
get_episode_M_AT(episode_states, 1, 1)

In [None]:
# Full-length series for index 0 of the second axis, taken from the
# negotiation-on run. Using index 1 on that axis instead would select the
# next entry (presumably the next region — confirm).
T_AT_reg_0_neg_on, T_LO_reg_0_neg_on = (
    episode_states['T_i'][:, 0, layer] for layer in range(2)
)

M_AT_reg_0_neg_on, M_UP_reg_0_neg_on, M_LO_reg_0_neg_on = (
    episode_states['M_i'][:, 0, reservoir] for reservoir in range(3)
)



In [None]:
# Display the raw 'minMu' entry fetched from the trainer.
episode_states['minMu']

In [None]:
import matplotlib.pyplot as plt

# One figure per temperature series of the negotiation-on run.
# T_AT is labelled "Atmosphere" to agree with get_episode_T_AT.
fig = plt.figure()
plt.plot(T_AT_reg_0_neg_on, label='Temperature - Atmosphere - negotiation on')
fig.legend()
# plt.yscale('log')
fig.show()

fig = plt.figure()
plt.plot(T_LO_reg_0_neg_on, label="Temperature - Lower Oceans - negotiation on")
fig.legend()
# plt.yscale('log')
fig.show()

In [None]:
import matplotlib.pyplot as plt

# One figure per temperature series of the negotiation-on run.
# T_AT is labelled "Atmosphere" to agree with get_episode_T_AT.
fig = plt.figure()
plt.plot(T_AT_reg_0_neg_on, label='Temperature - Atmosphere - negotiation on')
fig.legend()
# plt.yscale('log')
fig.show()

fig = plt.figure()
plt.plot(T_LO_reg_0_neg_on, label="Temperature - Lower Oceans - negotiation on")
fig.legend()
# plt.yscale('log')
fig.show()

In [None]:
# Settings for the second run (negotiation disabled). The "env" entry is
# unpacked into RiceCuda(...) and the whole dictionary is given to the Trainer.
run_config = {
    # Environment settings
    "env": {
        "negotiation_on": 0,
    },
    # Trainer settings
    "trainer": {
        # Number of environment replicas (number of GPU blocks used)
        "num_envs": 100,
        # Total batch size used for training per iteration
        # (across all the environments)
        "train_batch_size": 10000,
        # Total number of episodes to run the training for
        # (can be arbitrarily high!)
        "num_episodes": 30000,
    },
    # Policy network settings
    "policy": {
        "regions": {
            "to_train": True,
            "gamma": 0.98,  # discount factor
            "lr": 0.005,  # learning rate
            "model": {
                "type": "fully_connected",
                # Dimension(s) of the fully connected layers, as a list
                "fc_dims": [256, 256],
                # Load model parameters from a saved checkpoint (if specified)
                "model_ckpt_filepath": "",
            },
        },
    },
    # Checkpoint saving settings
    "saving": {
        "metrics_log_freq": 10,  # how often (in iterations) to print the metrics
        "model_params_save_freq": 5000,  # how often (in iterations) to save the model parameters
        "basedir": "/tmp",  # base folder used for saving
        "name": "rice",
        "tag": "example",
    },
}

In [None]:
# Point WarpDrive at the CUDA source of the Rice environment.
# "__file__" is a string literal here, so the directory resolves to the
# current working directory rather than to the location of a module.
this_file_dir = os.path.dirname(os.path.abspath("__file__"))
env_registrar = EnvironmentRegistrar()
env_registrar.add_cuda_env_src_path(
    RiceCuda.name, os.path.join(this_file_dir, "rice_build.cu")
)

# Alternative registration path (kept for reference, not used):
# cpu_env = EnvWrapper(Rice())
# add_cpu_env = env_registrar.add(device="cpu")
# add_cpu_env(cpu_env)
# add_gpu_env = env_registrar.add(device="gpu")
# add_gpu_env(cpu_env)

# Wrap the environment; use_cuda must be True so that it runs on the GPU.
env_wrapper = EnvWrapper(
    RiceCuda(**run_config["env"]),
    num_envs=run_config["trainer"]["num_envs"],
    use_cuda=True,
    env_registrar=env_registrar,
)

# Policy models can be shared between agents. Here every agent id is mapped
# to the single "regions" policy.
policy_tag_to_agent_id_map = {
    "regions": list(range(env_wrapper.env.num_agents)),
}

# Build the trainer from the wrapped environment and the run configuration.
trainer = Trainer(
    env_wrapper=env_wrapper,
    config=run_config,
    policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
)

# Run the training loop.
trainer.train()


In [None]:
# Variables to read back from the trained environment. Any variable
# registered in rice_cuda.py can be listed here.
desired_outputs = [
    "T_i",  # Temperature
    "M_i",  # Carbon mass
    "sampled_actions",
    "minMu",
]

episode_states_neg_off = trainer.fetch_episode_states(desired_outputs)

# Release the trainer now that the states have been fetched.
trainer.graceful_close()

In [None]:
# Plot the T_i[:, 0, 0] series of the second run (negotiation_on=0, plot=1).
get_episode_T_AT(episode_states_neg_off, 0, 1)

In [None]:
# Plot the M_i[:, 0, 0] series of the second run (negotiation_on=0, plot=1).
get_episode_M_AT(episode_states_neg_off, 0, 1)

In [None]:
# Full-length series for index 0 of the second axis, taken from the
# negotiation-off run. Using index 1 on that axis instead would select the
# next entry (presumably the next region — confirm).
T_AT_reg_0_neg_off, T_LO_reg_0_neg_off = (
    episode_states_neg_off['T_i'][:, 0, layer] for layer in range(2)
)

M_AT_reg_0_neg_off, M_UP_reg_0_neg_off, M_LO_reg_0_neg_off = (
    episode_states_neg_off['M_i'][:, 0, reservoir] for reservoir in range(3)
)

In [None]:
# One figure per temperature series of the negotiation-off run.
# T_AT is labelled "Atmosphere" to agree with get_episode_T_AT.
fig = plt.figure()
plt.plot(T_AT_reg_0_neg_off, label='Temperature - Atmosphere - negotiation off')
fig.legend()
# plt.yscale('log')
fig.show()

fig = plt.figure()
plt.plot(T_LO_reg_0_neg_off, label="Temperature - Lower Oceans - negotiation off")
fig.legend()
# plt.yscale('log')
fig.show()