In [1]:
%matplotlib inline


# Behavioral cloning with PyTorch


Here we show how to perform behavioral cloning on a Minari dataset using [PyTorch](https://pytorch.org/).
We will start by generating a dataset from an expert policy for the [CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/) environment, a classic control problem.
The objective is to keep the pole balanced on the cart, and the agent receives a reward of +1 for every timestep the episode continues.



## Imports
For this tutorial you will need the [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) library, which you can install with `pip install rl_zoo3`, in addition to Minari and PyTorch.
Next, let's import all the required packages and set the random seed for reproducibility:



In [3]:
import os
import sys

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium import spaces
from rl_zoo3.train import train
from stable_baselines3 import PPO
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import minari
from minari import DataCollector


torch.manual_seed(42)

<torch._C.Generator at 0x126138e10>
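`torch.manual_seed` only seeds PyTorch's generator. If you also want NumPy and Python's `random` module to be reproducible, a small helper can seed all three at once. This is a minimal sketch. `set_seed` is a hypothetical helper and is not part of any of the imported libraries. Full determinism may also depend on hardware and library settings.

```python
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Hypothetical helper: seed Python's, NumPy's and PyTorch's global generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds PyTorch's generators on all devices (CPU and CUDA)


set_seed(42)
```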

## Policy training
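Now we can train the expert policy using RL Baselines3 Zoo.
Its `train()` function parses its arguments from `sys.argv`, so we set them as if calling it from the command line, and train a PPO agent on the environment: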



In [4]:
sys.argv = ["python", "--algo", "ppo", "--env", "CartPole-v1"]
train()

Seed: 2154629021
Loading hyperparameters from: /Users/frankcholula/Workspace/school/FRL-playground/.frl/lib/python3.10/site-packages/rl_zoo3/hyperparams/ppo.yml
Default hyperparameters for environment (ones being tuned will be overridden):
OrderedDict([('batch_size', 256),
             ('clip_range', 'lin_0.2'),
             ('ent_coef', 0.0),
             ('gae_lambda', 0.8),
             ('gamma', 0.98),
             ('learning_rate', 'lin_0.001'),
             ('n_envs', 8),
             ('n_epochs', 20),
             ('n_steps', 32),
             ('n_timesteps', 100000.0),
             ('policy', 'MlpPolicy')])
Using 8 environments
Creating test environment
Using cpu device
Log path: logs/ppo/CartPole-v1_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 16.9     |
|    ep_rew_mean     | 16.9     |
| time/              |          |
|    fps             | 20600    |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    tot

This will generate a new folder named `logs` containing the expert policy; the model used below is stored in `logs/ppo/CartPole-v1_1`.
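The default hyperparameters printed above include strings such as `lin_0.001` and `lin_0.2`. As far as we understand RL Zoo's behavior, these denote linear schedules. The value starts at the given number and decays linearly to zero. It is driven by the `progress_remaining` value that Stable-Baselines3 passes to schedules, which is 1.0 at the start of training and 0.0 at the end.

The plain-Python sketch below mimics this with two illustrative helpers of our own, `linear_schedule` and `parse_hyperparam`; they are not RL Zoo's actual code. It also works through the rollout arithmetic implied by the same hyperparameters. This assumes PPO keeps collecting full rollouts until the timestep budget is reached.

```python
import math


def linear_schedule(initial_value: float):
    """Illustrative helper: a value that decays linearly from initial_value to 0."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1.0 (start of training) to 0.0 (end)
        return progress_remaining * initial_value
    return schedule


def parse_hyperparam(value):
    """Illustrative helper: turn 'lin_<x>' strings into schedules, keep other values."""
    if isinstance(value, str) and value.startswith("lin_"):
        return linear_schedule(float(value.split("_", 1)[1]))
    return value


lr = parse_hyperparam("lin_0.001")
clip = parse_hyperparam("lin_0.2")

# Rollout arithmetic for the printed CartPole-v1 hyperparameters
n_envs, n_steps, batch_size, n_timesteps = 8, 32, 256, 100_000
rollout_size = n_envs * n_steps                    # transitions collected per PPO iteration
minibatches_per_epoch = rollout_size // batch_size  # minibatches in each of the n_epochs passes
iterations = math.ceil(n_timesteps / rollout_size)  # iterations needed to reach the budget

print(lr(1.0), lr(0.5), clip(0.0))
print(rollout_size, minibatches_per_epoch, iterations)
```

With 8 environments and 32 steps each, every PPO iteration gathers 256 transitions. That is exactly one minibatch of size 256, so each of the 20 epochs per iteration performs a single gradient step.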



## Dataset generation
Now let's generate the dataset using the [DataCollector](https://minari.farama.org/api/data_collector/) wrapper:




In [20]:
env = DataCollector(gym.make('CartPole-v1'))
path = os.path.join(os.path.abspath(''), 'logs', 'ppo', 'CartPole-v1_1', 'best_model.zip')
agent = PPO.load(path)

total_episodes = 1_000
for i in tqdm(range(total_episodes)):
    # Use a different but reproducible seed per episode; re-using seed=42 on every
    # reset would start all 1,000 episodes from the same initial state.
    obs, _ = env.reset(seed=42 + i)
    while True:
        action, _ = agent.predict(obs)
        obs, rew, terminated, truncated, info = env.step(action)

        if terminated or truncated:
            break

100%|██████████| 1000/1000 [01:03<00:00, 15.77it/s]
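Re-seeding the environment with the same value before every episode, for example by calling `env.reset(seed=42)` on each iteration, makes every episode start from the same initial state. The only variation then comes from the stochastic actions of `agent.predict`, which defaults to `deterministic=False`. There are two ways to get varied yet reproducible starting states. One is to seed only the first reset, since Gymnasium keeps the current generator when `seed=None`. The other is to pass a different seed per episode.

The sketch below illustrates the effect with NumPy only. `sample_initial_state` is a hypothetical stand-in for CartPole's reset, which draws each of the four state variables uniformly from [-0.05, 0.05]. It is not the Gymnasium environment itself.

```python
import numpy as np


def sample_initial_state(seed=None):
    """Hypothetical stand-in for CartPole's reset: 4 state variables drawn
    uniformly from [-0.05, 0.05] with a generator seeded by `seed`."""
    rng = np.random.default_rng(seed)
    return rng.uniform(low=-0.05, high=0.05, size=(4,))


# Same seed for every episode: all start states are identical
same_seed = [sample_initial_state(seed=42) for _ in range(3)]
# One seed per episode: start states differ but remain reproducible
per_episode = [sample_initial_state(seed=42 + i) for i in range(3)]

print(all(np.array_equal(same_seed[0], s) for s in same_seed))
print(len({s.tobytes() for s in per_episode}))
```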


In [21]:
dataset = env.create_dataset(
    dataset_id="CartPole-v1/ppo-1000-v1",
    algorithm_name="ppo",
    code_permalink="https://github.com/frankcholula/FRL-playground/blob/main/code/behavioral_cloning.py",
    author="Frank Lu",
    author_email="lu.phrank@gmail.com",
    description="Behavioral cloning dataset for CartPole-v1 using PPO",
    eval_env="CartPole-v1"
)

Once the cell has executed, the dataset is saved to disk. You can list the locally available datasets with the `minari list local` command.



## Behavioral cloning with PyTorch
Now we can use PyTorch to learn the policy from the offline dataset.
Let's define the policy network:



In [22]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In this scenario, the output dimension will be two, one logit for each of the two discrete actions of `CartPole-v1` (push the cart to the left or to the right). The input dimension will be four, corresponding to the observation space of `CartPole-v1` (cart position, cart velocity, pole angle and pole angular velocity).
Our next step is to load the dataset and set up the training loop. The `MinariDataset` is compatible with the PyTorch Dataset API, so we can load it directly with a [PyTorch DataLoader](https://pytorch.org/docs/stable/data.html).
However, since episodes can have different lengths, we need to pad them to a common length within each batch.
To achieve this, we can use the [collate_fn](https://pytorch.org/docs/stable/data.html#working-with-collate-fn) argument of the PyTorch DataLoader. Let's create the `collate_fn` function:
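By default, `pad_sequence` fills the shorter sequences with zeros (`padding_value=0.0`). For the actions this is a subtle pitfall. In `CartPole-v1` the value `0` is also a valid action (push left), so padded steps look like genuine expert data unless they are masked out. The toy example below shows the padded layout and one way of building a mask from the episode lengths. The mask construction is our own illustration and is not something `collate_fn` returns.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy "episodes" with 3 and 2 actions
actions = [torch.tensor([1, 0, 1]), torch.tensor([0, 1])]

# Stack into a (batch, max_len) tensor; the missing step is filled with 0
padded = pad_sequence(actions, batch_first=True)

# Boolean mask that is True only for the real (non-padded) steps
lengths = torch.tensor([len(a) for a in actions])
mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]

print(padded)
print(mask)
print(padded[mask])  # real actions only, flattened in row-major order
```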



In [38]:
def collate_fn(batch):
    # Episodes differ in length, so every per-step field is zero-padded to the
    # longest episode in the batch, giving tensors of shape (batch, max_len, ...).
    return {
        "id": torch.Tensor([x.id for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.observations) for x in batch],
            batch_first=True
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.actions) for x in batch],
            batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.rewards) for x in batch],
            batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.terminations) for x in batch],
            batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.truncations) for x in batch],
            batch_first=True
        )
    }

We can now proceed to load the data and create the training loop.
To begin, let's initialize the DataLoader, neural network, optimizer, and loss.



In [39]:
minari_dataset = minari.load_dataset("CartPole-v1/ppo-1000-v1")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)

env = minari_dataset.recover_environment()
observation_space = env.observation_space
action_space = env.action_space
assert isinstance(observation_space, spaces.Box)
assert isinstance(action_space, spaces.Discrete)

policy_net = PolicyNetwork(np.prod(observation_space.shape), action_space.n)
optimizer = torch.optim.Adam(policy_net.parameters())
loss_fn = nn.CrossEntropyLoss()

In [43]:
episode = minari_dataset[0]
print(episode)
print(episode.__dict__.keys())

EpisodeData(id=0, total_steps=500, observations=ndarray of shape (501, 4) and dtype float32, actions=ndarray of shape (500,) and dtype int64, rewards=ndarray of 500 floats, terminations=ndarray of 500 bools, truncations=ndarray of 500 bools, infos=dict with the following keys: [])
dict_keys(['id', 'observations', 'actions', 'rewards', 'terminations', 'truncations', 'infos'])
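The printed shapes show that an episode with 500 steps stores 501 observations but only 500 actions. The extra observation is the state reached after the last action. For behavioral cloning we pair each observation with the action taken in it and drop the final observation. This is why the training loop slices `batch['observations'][:, :-1]`. A toy NumPy version of this alignment, with made-up values, looks like this:

```python
import numpy as np

# Toy episode with the same layout as above but only T = 3 steps:
# T + 1 observations (including the final one) and T actions.
T = 3
observations = np.arange((T + 1) * 4, dtype=np.float32).reshape(T + 1, 4)
actions = np.array([1, 0, 1], dtype=np.int64)

# observations[t] is the state in which actions[t] was taken; the last
# observation follows the final action and has no matching action.
states = observations[:-1]
pairs = list(zip(states, actions))

print(states.shape, actions.shape, len(pairs))
```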


Since the action space is discrete, we frame behavioral cloning as a classification problem and use the cross-entropy loss.
We then train the policy to predict the expert's actions from the observations:



In [41]:
num_epochs = 32

for epoch in range(num_epochs):
    for batch in dataloader:
        # Drop the final observation of each episode: it has no matching action.
        a_pred = policy_net(batch["observations"][:, :-1])  # (B, T, n_actions)

        # A step is real up to and including the first termination/truncation;
        # everything after it is padding and must not contribute to the loss.
        done = (batch["terminations"] | batch["truncations"]).long()
        mask = (done.cumsum(dim=1) - done) == 0  # (B, T)

        # CrossEntropyLoss expects (N, C) logits and (N,) integer class indices.
        loss = loss_fn(a_pred[mask], batch["actions"][mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch: {epoch}/{num_epochs}, Loss: {loss.item()}")
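`nn.CrossEntropyLoss` expects logits of shape `(N, C)` together with integer class indices of shape `(N,)`. When the network is applied to a padded batch of shape `(B, T, 4)`, `nn.Linear` acts on the last dimension and returns logits of shape `(B, T, C)`. If you pass a 3-D tensor like that directly, PyTorch interprets it as `(N, C, d1)` and normalizes over the time axis instead of over the actions.

The sketch below uses random logits and small, made-up shapes. It flattens only the valid (unpadded) steps with a boolean mask before computing the loss, then checks the result against a manual log-softmax computation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 2, 3, 2                    # batch size, padded length, number of actions
logits = torch.randn(B, T, C)        # stands in for policy_net(observations[:, :-1])
actions = torch.tensor([[1, 0, 1],
                        [0, 1, 0]])  # second episode padded with a dummy 0
mask = torch.tensor([[True, True, True],
                     [True, True, False]])

loss_fn = nn.CrossEntropyLoss()
# Boolean indexing keeps the real steps and flattens them:
# logits[mask] -> (N, C), actions[mask] -> (N,)
loss = loss_fn(logits[mask], actions[mask])

# Same value by hand: mean negative log-probability of the expert's actions
log_probs = F.log_softmax(logits[mask], dim=-1)
manual = -log_probs.gather(1, actions[mask].unsqueeze(1)).mean()

print(loss.item(), manual.item())
```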

Now we can evaluate whether the policy has learned from the expert!



In [44]:
env = gym.make("CartPole-v1", render_mode="human")
obs, _ = env.reset(seed=42)
done = False
accumulated_rew = 0
with torch.no_grad():  # no gradients are needed at evaluation time
    while not done:
        # Greedy action: index of the largest logit, as a plain Python int
        action = policy_net(torch.as_tensor(obs)).argmax().item()
        obs, rew, ter, tru, _ = env.step(action)
        done = ter or tru
        accumulated_rew += rew

env.close()
print("Accumulated rew: ", accumulated_rew)

Accumulated rew:  500.0


The rendered episode shows that the learned policy solves this simple control task. It collects the maximum return of 500, since `CartPole-v1` truncates episodes after 500 steps.
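A single episode from one seed is a thin test. A more robust check runs the policy for several episodes with different seeds and reports every return. The sketch below does this and assumes only the Gymnasium `reset`/`step` API. `evaluate` is a hypothetical helper of our own, and the commented usage lines assume the `policy_net` trained above.

```python
from typing import Callable, List


def evaluate(env, policy: Callable, n_episodes: int = 10, seed: int = 0) -> List[float]:
    """Hypothetical helper: run `policy` for several episodes, return each total reward.

    Assumes a Gymnasium-style env: reset(seed=...) -> (obs, info) and
    step(action) -> (obs, reward, terminated, truncated, info).
    """
    returns = []
    for i in range(n_episodes):
        obs, _ = env.reset(seed=seed + i)  # a different, reproducible start per episode
        done, total = False, 0.0
        while not done:
            obs, rew, terminated, truncated, _ = env.step(policy(obs))
            done = terminated or truncated
            total += rew
        returns.append(total)
    return returns


# Usage with the trained network (no rendering needed):
# env = gym.make("CartPole-v1")
# with torch.no_grad():
#     returns = evaluate(env, lambda o: policy_net(torch.as_tensor(o)).argmax().item())
# print(sum(returns) / len(returns))
```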


