<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/examples/pytorch_dist/%5BDist%5D_Distributed_RPC_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

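Known issue: `rpc.init_rpc` freezes when this notebook runs in Colab. Note that `init_rpc` blocks until every rank has joined the same rendezvous, so all ranks must share one `init_method` and register under distinct worker names.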

> Tutorial: https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html

> The RL example is based on: https://github.com/pytorch/examples/blob/main/reinforcement_learning/actor_critic.py

In [1]:
# @title Imports
import argparse
import gymnasium
import numpy as np
from itertools import count
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [13]:
# Cart Pole Gym

# Hyper-parameters shared by the cells below.
seed = 543            # base RNG seed
gamma = 0.99          # discount factor used when computing returns
log_interval = 10     # log progress every `log_interval` episodes

# Create and seed a CartPole env in this (launching) process, then seed
# torch's global RNG.
env = gymnasium.make('CartPole-v1')
env.reset(seed=seed)
torch.manual_seed(seed)

# RPC worker names: one agent, plus observers whose name is formatted with
# their rank via str.format.
AGENT_NAME = "agent"
OBSERVER_NAME = "obs{}"

In [3]:
# @title Libs - Policy model

SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

class Policy(nn.Module):
  """Implement both actor and critic in one model.

  A shared hidden layer feeds two heads: the actor head produces a
  probability distribution over the discrete actions, and the critic head
  produces a single state-value estimate.
  """

  def __init__(self, state_dim: int = 4, action_dim: int = 2,
               hidden_dim: int = 128, dropout_p: float = 0.6):
    """Build the shared layer and both heads.

    Args:
      state_dim: length of the state vector (4 for CartPole).
      action_dim: number of discrete actions (2 for CartPole).
      hidden_dim: width of the shared hidden layer.
      dropout_p: dropout probability applied to the shared hidden layer.
    """
    super().__init__()
    self.state_dim = state_dim

    # common layer
    self.affine1 = nn.Linear(state_dim, hidden_dim)
    self.dropout = nn.Dropout(p=dropout_p)

    # actor's head
    self.actor_head = nn.Linear(hidden_dim, action_dim)

    # critic head
    self.value_head = nn.Linear(hidden_dim, 1)

    # action & reward buffer
    self.saved_actions = []
    self.rewards = []

  def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward for both actor and critic.

    Args:
      x: the state, either unbatched with shape [state_dim] or batched with
        shape [..., state_dim].

    Returns:
      (action_prob, state_value) with shapes [..., action_dim] and [..., 1];
      action_prob sums to 1 over the last dimension.

    Raises:
      ValueError: if the last dimension of x is not state_dim.
    """
    # Accept both an unbatched state and a batch of states; only the feature
    # dimension is fixed.
    if x.dim() == 0 or x.shape[-1] != self.state_dim:
      raise ValueError(
          f"expected a state with last dimension {self.state_dim}, "
          f"got shape {tuple(x.shape)}")
    x = self.affine1(x)     # [..., H]
    x = self.dropout(x)     # [..., H]
    x = F.relu(x)

    action_logits = self.actor_head(x)
    action_prob = F.softmax(action_logits, dim=-1)   # [..., action_dim]

    state_value = self.value_head(x)                  # [..., 1]

    return action_prob, state_value

## Observer

> The Observer is in charge of interacting with the (expensive) environment.

In [11]:
import torch.distributed.rpc as rpc

class Observer:
  """An Observer is in charge of interacting with env.

  In some RL setups it is expensive to interact with the env, while it is
  relatively cheap to run the agent (e.g. because it has a small model). In
  this case we want to run multiple Observers in a distributed way and reuse
  the same agent.
  """

  def __init__(self, seed: int | None = None):
    """Create this observer's private CartPole env.

    Must be constructed inside an initialized RPC worker, because it reads
    this worker's id from `rpc.get_worker_info()`.

    Args:
      seed: seed for the env's first reset. Defaults to None (unseeded) so the
        class can also be created remotely with no constructor arguments.
    """
    self.id = rpc.get_worker_info().id
    self.env = gymnasium.make('CartPole-v1')
    self.env.reset(seed=seed)

  def run_episode(self, agent_rref, n_steps: int = 10000) -> float:
    """Run a single episode (full trajectory) and return its total reward.

    Args:
      agent_rref: RRef to the remote agent, which must expose
        `select_action(ob_id, state)` and `report_reward(ob_id, reward)`.
      n_steps: hard cap on the number of env steps in this episode.

    Returns:
      The sum of the rewards collected during the episode.
    """
    state, _ = self.env.reset()
    ep_reward = 0.0           # Total reward on this episode

    for _ in range(n_steps):
      # Send the state to the remote agent and get the action back.
      action = agent_rref.rpc_sync().select_action(self.id, state)

      state, reward, terminated, truncated, _ = self.env.step(action)

      # Send the reward to the remote agent.
      agent_rref.rpc_sync().report_reward(self.id, reward)

      ep_reward += reward

      # gymnasium ends an episode either by termination (a terminal state) or
      # by truncation (e.g. the env's time limit); both must stop the rollout,
      # otherwise the env keeps being stepped past the end of the episode.
      if terminated or truncated:
        break

    return ep_reward

## Agent

> The Agent is in charge of predicting the next action.

In [16]:
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical

class Agent:
  """Central agent that owns the policy and serves actions to remote Observers.

  Observers call `select_action` / `report_reward` over RPC while an episode
  runs; the agent buffers log-probabilities and rewards per observer id and
  applies one REINFORCE-style policy-gradient update per round in
  `finish_episode`.
  """

  def __init__(self, world_size):
    """Create the policy and one remote Observer on every non-agent rank.

    Args:
      world_size: total number of RPC workers. Rank 0 is this agent; ranks
        1..world_size-1 must be observers registered as
        OBSERVER_NAME.format(rank).
    """
    self.ob_rrefs = []
    self.agent_rref = RRef(self)
    self.rewards = {}
    self.saved_log_probs = {}
    self.policy = Policy()
    self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
    self.eps = np.finfo(np.float32).eps.item()
    self.running_reward = 0     # The running epoch-level total reward.
    self.reward_threshold = gymnasium.make('CartPole-v1').spec.reward_threshold

    # Rank 0 is the agent itself, so observers live on ranks 1..world_size-1.
    for ob_rank in range(1, world_size):
      ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
      # Seed every observer differently so their trajectories differ.
      self.ob_rrefs.append(remote(ob_info, Observer, args=(seed + ob_rank,)))
      self.rewards[ob_info.id] = []
      self.saved_log_probs[ob_info.id] = []

  def select_action(self, ob_id, state):
    """Sample an action for observer `ob_id` and buffer its log-probability.

    Called over RPC; calls for different observers may run concurrently, but
    each one only touches the buffers keyed by its own `ob_id`.

    Args:
      ob_id: RPC worker id of the calling observer.
      state: the observation returned by the env (presumably a float numpy
        array of shape (4,) for CartPole - confirm).

    Returns:
      The sampled action as a Python int.
    """
    assert len(self.rewards[ob_id]) == len(self.saved_log_probs[ob_id])

    # Feed the policy an unbatched state; it returns action_prob of shape [2].
    state = torch.as_tensor(state, dtype=torch.float32)
    action_prob, _ = self.policy(state)

    # Sample from the probs and emit the enum int of the next action.
    sampler = Categorical(action_prob)
    next_action = sampler.sample()               # scalar

    # log_prob is the log of the probability of the selected action.
    log_prob = sampler.log_prob(next_action)     # scalar
    assert torch.allclose(log_prob, action_prob[next_action].log(), atol=1e-5)
    self.saved_log_probs[ob_id].append(log_prob)

    return next_action.item()

  def report_reward(self, ob_id, reward):
    """Buffer the reward observer `ob_id` received for its last action."""
    self.rewards[ob_id].append(reward)

    assert len(self.rewards[ob_id]) == len(self.saved_log_probs[ob_id])

  def run_episode(self):
    """Run one episode on every observer concurrently and wait for all of them."""
    # rpc_async() on an RRef proxies the method call to the RRef's owner.
    futs = [
        ob_rref.rpc_async().run_episode(self.agent_rref)
        for ob_rref in self.ob_rrefs
    ]
    for fut in futs:
      fut.wait()

  def finish_episode(self, discount=None):
    """Apply one policy-gradient update from the buffered trajectories.

    Discounted returns are computed separately for each observer's trajectory,
    then normalized together across all observers.

    Args:
      discount: discount factor for returns; defaults to the module-level
        `gamma`.

    Returns:
      The minimum total reward among the observers for this round.
    """
    if discount is None:
      discount = gamma

    # Join log-probs and per-observer discounted returns into flat lists.
    probs, returns = [], []
    for ob_id, ob_rewards in self.rewards.items():
      probs.extend(self.saved_log_probs[ob_id])
      # Reset the running return for each observer so that returns do not
      # leak across trajectory boundaries.
      R = 0.0
      ob_returns = []
      for r in reversed(ob_rewards):
        R = r + discount * R
        ob_returns.append(R)
      ob_returns.reverse()
      returns.extend(ob_returns)

    # Use the minimum observer reward to calculate the running reward.
    min_reward = min(sum(ob_rewards) for ob_rewards in self.rewards.values())
    self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

    # Clear saved probs and rewards.
    for ob_id in self.rewards:
      self.rewards[ob_id] = []
      self.saved_log_probs[ob_id] = []

    # Nothing to learn from if no step was taken this round.
    if not probs:
      return min_reward

    returns = torch.tensor(returns)
    # std() of a single element is NaN, which would poison the weights.
    if returns.numel() > 1:
      returns = (returns - returns.mean()) / (returns.std() + self.eps)

    # stack (not cat) works for both scalar and shape-[1] log-probs.
    policy_loss = torch.stack(
        [-log_prob * ret for log_prob, ret in zip(probs, returns)]).sum()
    self.optimizer.zero_grad()
    policy_loss.backward()
    self.optimizer.step()
    return min_reward

In [None]:
import os

import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.distributed as dist


def init_process(rank: int, world_size: int):
  """Entry point of one worker process: rank 0 is the agent, others observe.

  Every rank rendezvous through the same `env://` init method, which reads
  MASTER_ADDR / MASTER_PORT from the environment, and registers under a
  unique worker name. `init_rpc` blocks until all `world_size` workers have
  joined, so a mismatched rendezvous or duplicate name makes it hang.

  Args:
    rank: this process's rank in [0, world_size).
    world_size: total number of processes (1 agent + world_size - 1 observers).
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Shared by all ranks so that they all join the same RPC group.
  backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://")

  if rank == 0:
    # rank 0 is the agent
    rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size,
                 rpc_backend_options=backend_options)
    print("Agent RPC started!!!")

    agent = Agent(world_size)
    print(f"This will run until reward threshold of {agent.reward_threshold}"
                " is reached. Ctrl+C to exit.")
    for i_episode in count(1):
      agent.run_episode()
      last_reward = agent.finish_episode()

      if i_episode % log_interval == 0:
        print(f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
            f"{agent.running_reward:.2f}")
      if agent.running_reward > agent.reward_threshold:
        print(f"Solved! Running reward is now {agent.running_reward}!")
        break
  else:
    # Other ranks are observers; they only serve RPCs issued by the agent.
    rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size,
                 rpc_backend_options=backend_options)
    print(f"Observer RPC started on {rank=}!!!")

  # Every rank blocks here until all RPCs finish, then shuts down RPC.
  rpc.shutdown()


# Rendezvous address used by every worker's `env://` RPC init method; the
# child processes inherit these environment variables.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29502'  # Pick another free port if 29502 is in use.


world_size = 4   # rank 0 is the agent, ranks 1..3 are observers

# Use "fork" explicitly: `init_process` is defined in the notebook's __main__,
# which children started with "spawn"/"forkserver" cannot import.
ctx = mp.get_context("fork")

processes = []
for rank in range(world_size):
  p = ctx.Process(target=init_process, args=(rank, world_size))
  p.start()
  processes.append(p)

for p in processes:
  p.join()

# Surface worker crashes instead of silently finishing the cell.
failed = [(r, proc.exitcode) for r, proc in enumerate(processes)
          if proc.exitcode != 0]
if failed:
  raise RuntimeError(f"Worker processes exited abnormally (rank, exitcode): {failed}")

Starting process with rank=0, world_size=4
Starting process with rank=1, world_size=4
Starting process with rank=2, world_size=4
Starting process with rank=3, world_size=4
