# **PushBlockEnv Reinforcement Learning**

In [1]:
# Relative path of the Unity PushBlock build; passed as `file_name` to UnityMLAgentsEnv
ENV_PATH = "../../../envs/PushBlock"

### **Imports**

In [2]:
### Utility
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

### Torch
import torch
from torch import nn
from torch import optim
### Torch RL
# Env
from torchrl.envs.libs import UnityMLAgentsEnv
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from torchrl.envs.utils import step_mdp, check_env_specs
from torchrl.envs import TransformedEnv, Stack, ExcludeTransform, CatTensors

# Data Collection
from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement


# Model
from model import create_policy, create_value
# Train Util
from train_util import make_loss_module,  compute_trajectory_metrics, loss_dict, Stopwatch, Logger, Checkpointer, compute_single_trajectory_metrics

# Use CUDA when torch reports it as available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Device: {device}")

Using Device: cpu


## **Create Torch Env**

In [3]:
def create_unity_env(graphics=False, **kwargs):
    """Launch the Unity build at ``ENV_PATH`` wrapped in a ``TransformedEnv``.

    Args:
        graphics: When False (default) the environment is created with
            ``no_graphics=True``.
        **kwargs: Forwarded to ``UnityMLAgentsEnv``. Must not contain
            ``file_name``, ``worker_id``, ``no_graphics`` or ``device``,
            which are set here.

    Returns:
        A ``TransformedEnv`` (no transforms appended yet) wrapping the new
        Unity environment, created with ``device="cpu"``.
    """
    # A random worker id is drawn for every instance; presumably this keeps
    # concurrently running instances from colliding — TODO confirm.
    unity_env = UnityMLAgentsEnv(
        file_name=ENV_PATH,
        worker_id=np.random.randint(10000),
        no_graphics=(not graphics),
        **kwargs,
        device="cpu",
    )
    return TransformedEnv(unity_env)

def batch_agents(env, out_key="agents"):
    """Append a ``Stack`` transform that groups the per-agent entries under ``out_key``.

    The group root is the first element of the env's first observation key, and
    the agent names are the keys of the action spec under that root.

    Args:
        env: Environment exposing ``observation_keys``, ``action_spec`` and
            ``append_transform``.
        out_key: Name of the entry that receives the stacked agents.

    Returns:
        The same ``env`` object, with the transform appended.
    """
    root = env.observation_keys[0][0]
    per_agent_keys = [(root, name) for name in env.action_spec[root].keys()]

    # The same per-agent keys are used for the forward and the inverse direction.
    env.append_transform(
        Stack(
            in_keys=per_agent_keys,
            out_key=(out_key,),
            in_key_inv=(out_key,),
            out_keys_inv=list(per_agent_keys),
        )
    )
    return env

def create_base_env(graphics=False, **kwargs):
    """Create the Unity env and append the transforms used throughout this notebook.

    Transforms, in order: agent stacking (``batch_agents``), concatenation of all
    observation keys into ``("agents", "observation")`` while keeping the source
    entries, and exclusion of ``("agents", "group_reward")``.

    Args:
        graphics: Forwarded to ``create_unity_env``.
        **kwargs: Forwarded to ``create_unity_env``.

    Returns:
        The transformed environment.
    """
    # Unity env with all agents batched into a single "agents" entry
    env = batch_agents(create_unity_env(graphics, **kwargs))

    # Observation keys are read before the concatenation transform is appended
    env.append_transform(
        CatTensors(
            in_keys=env.observation_keys,
            out_key=("agents", "observation"),
            del_keys=False,
        )
    )

    # Drop the group reward entry
    env.append_transform(ExcludeTransform(("agents", "group_reward")))
    return env

def create_env(graphics=False, time_scale = 1, **kwargs):
    """Create the transformed environment, optionally changing the engine time scale.

    Args:
        graphics: Forwarded to ``create_base_env``.
        time_scale: Engine time scale. When it equals 1 no configuration side
            channel is created.
        **kwargs: Forwarded to ``create_base_env``. A ``side_channels`` list may
            be supplied; when ``time_scale != 1`` the engine configuration
            channel is appended to it instead of clashing with it.

    Returns:
        The transformed environment.
    """
    if time_scale == 1:
        return create_base_env(graphics, **kwargs)

    engine_config_channel = EngineConfigurationChannel()
    # Merge with caller-supplied side channels so that passing `side_channels`
    # does not raise a duplicate-keyword TypeError.
    side_channels = list(kwargs.pop("side_channels", None) or [])
    side_channels.append(engine_config_channel)

    env = create_base_env(graphics, **kwargs, side_channels=side_channels)
    # The time scale is set only after the environment has been created.
    engine_config_channel.set_configuration_parameters(time_scale=time_scale)
    return env

### **Inspect Specs**

In [4]:
def print_specs(env):
    """Print the action, reward, done and observation specs of ``env``, one per line."""
    for spec_name in ("action_spec", "reward_spec", "done_spec", "observation_spec"):
        print(f"{spec_name}:", getattr(env, spec_name))

# Build an environment at time scale 20, run torchrl's spec check on it, then print its specs
env = create_env(time_scale=20)
check_env_specs(env)
print_specs(env)

[UnityMemory] Configuration Parameters - Can be set up in boot.config
    "memorysetup-bucket-allocator-granularity=16"
    "memorysetup-bucket-allocator-bucket-count=8"
    "memorysetup-bucket-allocator-block-size=4194304"
    "memorysetup-bucket-allocator-block-count=1"
    "memorysetup-main-allocator-block-size=16777216"
    "memorysetup-thread-allocator-block-size=16777216"
    "memorysetup-gfx-main-allocator-block-size=16777216"
    "memorysetup-gfx-thread-allocator-block-size=16777216"
    "memorysetup-cache-allocator-block-size=4194304"
    "memorysetup-typetree-allocator-block-size=2097152"
    "memorysetup-profiler-bucket-allocator-granularity=16"
    "memorysetup-profiler-bucket-allocator-bucket-count=8"
    "memorysetup-profiler-bucket-allocator-block-size=4194304"
    "memorysetup-profiler-bucket-allocator-block-count=1"
    "memorysetup-profiler-allocator-block-size=16777216"
    "memorysetup-profiler-editor-allocator-block-size=1048576"
    "memorysetup-temp-allocator-siz

In [5]:
# Element 1 of env.action_key is used as the action name under "agents" (printed below)
action_key = env.action_key[1]
print(f"action_key: {action_key}")

# Shapes of the stacked observation and action specs; reused later to size the models
observation_shape = env.observation_spec["agents", "observation"].shape
action_shape = env.action_spec["agents", action_key].shape

print(f"observation_shape: {observation_shape}, action_shape: {action_shape}")

action_key: discrete_action
observation_shape: torch.Size([32, 210]), action_shape: torch.Size([32, 7])


In [6]:
# Collect a 10-step rollout (no policy is passed) and display the resulting TensorDict
td = env.rollout(10)
td

TensorDict(
    fields={
        agents: TensorDict(
            fields={
                StackingSensor_size3_OffsetRayPerceptionSensor: Tensor(shape=torch.Size([10, 32, 105]), device=cpu, dtype=torch.float32, is_shared=False),
                StackingSensor_size3_RayPerceptionSensor: Tensor(shape=torch.Size([10, 32, 105]), device=cpu, dtype=torch.float32, is_shared=False),
                discrete_action: Tensor(shape=torch.Size([10, 32, 7]), device=cpu, dtype=torch.int32, is_shared=False),
                done: Tensor(shape=torch.Size([10, 32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10, 32, 210]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10, 32]),
 

### **Inspect Action Space**

Actions are categorical: each agent picks exactly one of 7 choices, encoded as a one-hot vector (hence the mean of 1/7 ≈ 0.143 in the summary below).

In [7]:
# Flatten every action entry of the rollout into a single column and summarize it
actions_df = pd.DataFrame({
    "action": td["agents", action_key].reshape(-1)
})
actions_df.describe()

Unnamed: 0,action
count,2240.0
mean,0.142857
std,0.350005
min,0.0
25%,0.0
50%,0.0
75%,0.0
max,1.0


### **Inspect Observation Space**

Observations consist of binary features and continuous features scaled to the range [0, 1].

In [8]:
# Flatten every observation value of the rollout into a single column and summarize it
obs_df = pd.DataFrame({
    "obs": td["agents", "observation"].reshape(-1)
})
obs_df.describe()

Unnamed: 0,obs
count,67200.0
mean,0.366839
std,0.464883
min,0.0
25%,0.0
50%,0.0
75%,1.0
max,1.0


### **Inspect Reward Space**
Agents receive -0.001 on every step and +1 on success.

In [9]:
# Rewards are stored under the "next" entry of the rollout; flatten and summarize them
reward_df = pd.DataFrame({
    "reward": td["next", "agents", "reward"].reshape(-1),
})
reward_df.describe()

Unnamed: 0,reward
count,320.0
mean,-0.001
std,0.0
min,-0.001
25%,-0.001
50%,-0.001
75%,-0.001
max,-0.001


## **Create Models**

### **Config**

In [10]:
# Network size hyperparameters
HIDDEN_DIM = 256
N_BLOCKS = 3

# Config dict handed to create_policy / create_value; the feature sizes are taken
# from dimension 1 of the observation and action spec shapes
MODEL_CONFIG = {
    "hidden_dim": HIDDEN_DIM,
    "n_blocks": N_BLOCKS,
    "in_features": int(observation_shape[1]),
    "out_features": int(action_shape[1]),
}
MODEL_CONFIG

{'hidden_dim': 256, 'n_blocks': 3, 'in_features': 210, 'out_features': 7}

### **Inspect**

In [11]:
# Build policy/value networks and a loss module for inspection only
policy, value = create_policy(MODEL_CONFIG).to(device), create_value(MODEL_CONFIG).to(device)
loss_module = make_loss_module(policy, value, epsilon=0.1, entropy_coef=0.01, gamma=0.99, lmbda=0.95).to(device)

# Roll out 10 steps with the policy and run the value estimator on the result,
# without tracking gradients
with torch.no_grad():
    td = env.rollout(10, policy=policy, auto_cast_to_device=True).to(device)
    loss_module.value_estimator(td)
# Apply step_mdp and keep only the "agents" entry for display
data = step_mdp(td)["agents"]
data

  source[group_name][agent_name]["truncated"] = torch.tensor(


TensorDict(
    fields={
        StackingSensor_size3_OffsetRayPerceptionSensor: Tensor(shape=torch.Size([9, 32, 105]), device=cpu, dtype=torch.float32, is_shared=False),
        StackingSensor_size3_RayPerceptionSensor: Tensor(shape=torch.Size([9, 32, 105]), device=cpu, dtype=torch.float32, is_shared=False),
        advantage: Tensor(shape=torch.Size([9, 32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        discrete_action: Tensor(shape=torch.Size([9, 32, 7]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([9, 32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        log_prob: Tensor(shape=torch.Size([9, 32]), device=cpu, dtype=torch.float32, is_shared=False),
        logits: Tensor(shape=torch.Size([9, 32, 7]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([9, 32, 210]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([9, 32, 1]),

In [12]:
# Metrics of the policy rollout collected above (displayed as a dict below)
compute_single_trajectory_metrics(td)

{'return': 0.14822500944137573,
 'episode_length': 9.0,
 'entropy': 1.80357825756073}

## **Training**

### **Train Config**

In [13]:
# Autocast settings: float16 mixed precision on CUDA, plain float32 otherwise
device_type = "cuda" if str(device).startswith("cuda") else "cpu"
amp_dtype   = torch.float16 if device_type == "cuda" else torch.float32

### Training Loop Params
# os.cpu_count() returns None when the count cannot be determined; fall back to a
# single worker so the arithmetic below does not fail with a TypeError.
WORKERS = os.cpu_count() or 1
print("device:", device, "workers:", WORKERS)
STORAGE_DEVICE = device
GENERATION_SIZE = 1000 * WORKERS # 1000 is the truncation point for the env
TIMESTAMPS = GENERATION_SIZE * 50
EPOCHS = 5

# GD Params
MINIBATCH_SIZE = 64
LR = 5e-5
MAX_GRAD_NORM = 0.5

### RL Params

# ENV Params
TIME_SCALE = 20

# PPO Params
GAMMA = 0.99
GAE_LAMBDA = 0.95
EPSILON = 0.2
ENTROPY_COEF = 1e-5

LOG_KEYS = [
    "generation", "time", "collection_time", "train_time",  # Training Progress Metrics
    "return", "episode_length",                             # Performance Metrics
    "entropy",                                              # Exploration Metrics
    "policy_loss", "kl_approx", "clip_fraction", "ESS",     # Policy Metrics
    "value_loss", "explained_variance",                     # Value Metrics
]

# Output locations and run name used by the logger / checkpointer
LOG_PATH = 'logs'
CKPT_PATH = 'ckpt'
MODEL_PATH = 'models'
NAME = 'run0'

# Intervals are expressed in generations (collector iterations)
LOG_INTERVAL = 1
CHECKPOINT_INTERVAL = 1
METRIC_KEY = "return"

# When True, training resumes from the last saved checkpoint instead of resetting
CONTINUE=False

device: cpu workers: 1


### **Train Loop**

In [14]:
def train(create_env, policy, value, timestamps=TIMESTAMPS):
    """Run the PPO training loop, logging and checkpointing per generation.

    Each generation collects ``GENERATION_SIZE`` frames, computes advantages and
    trajectory metrics, then runs ``EPOCHS`` passes of minibatch gradient descent
    over the collected data.

    Args:
        create_env: Zero-argument callable that builds a fresh environment; it is
            handed to the data collector(s).
        policy: Policy module; used for collection and optimized in place.
        value: Value module; optimized in place.
        timestamps: Total number of frames to collect. When resuming from a
            checkpoint, the frames of already completed generations are subtracted.

    Returns:
        The result of ``logger.dataframe()``.
    """
    # Loss + Optimizer
    loss_module = make_loss_module(policy, value, epsilon=EPSILON, entropy_coef=ENTROPY_COEF, gamma=GAMMA, lmbda=GAE_LAMBDA)
    optimizer = optim.Adam(loss_module.parameters(), lr=LR)
    # only need scaler with float16, float32 and bfloat16 have wider exponent ranges.
    scaler = torch.amp.GradScaler(enabled=(amp_dtype == torch.float16))

    # Logger + Checkpointer
    logger = Logger(keys = LOG_KEYS, log_path=LOG_PATH, name=NAME)
    checkpointer = Checkpointer(ckpt_path=CKPT_PATH, name=NAME)

    # Continue/Reset
    start_generation = 0
    if not CONTINUE:
        logger.reset()
        checkpointer.reset()
    else:
        checkpoint = checkpointer.load_progress()
        if checkpoint:
            start_generation = int(checkpoint["generation"])
            policy.load_state_dict(checkpoint["policy_state_dict"])
            value.load_state_dict(checkpoint["value_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            if "scaler_state_dict" in checkpoint:
                scaler.load_state_dict(checkpoint["scaler_state_dict"])
            print("CHECKPOINT FOUND, STARTING FROM GENERATION:", start_generation)
        else:
            print("CHECKPOINT NOT FOUND, STARTING FROM SCRATCH")

    # Watches
    short_watch = Stopwatch()
    long_watch = Stopwatch()

    # Frames still to collect once already completed generations are accounted for
    remaining_frames = timestamps - GENERATION_SIZE*start_generation
    if WORKERS > 1:
        collector = MultiSyncDataCollector([create_env]*WORKERS, policy,
            frames_per_batch=GENERATION_SIZE,
            total_frames=remaining_frames,
            env_device="cpu", device=device, storing_device=STORAGE_DEVICE,
            update_at_each_batch=True
        )
    else:
        collector = SyncDataCollector(
            create_env, policy,
            frames_per_batch=GENERATION_SIZE,
            total_frames=remaining_frames,
            env_device="cpu", device=device, storing_device=STORAGE_DEVICE,
        )

    try:
        replay_buffer = ReplayBuffer(storage=LazyTensorStorage(GENERATION_SIZE, device=STORAGE_DEVICE), sampler=SamplerWithoutReplacement(), batch_size=MINIBATCH_SIZE)

        ### TRAINING LOOP
        short_watch.start(); long_watch.start()
        policy.eval(); value.eval()
        for i, tensordict_data in enumerate(collector):
            # 0. Time collect
            collection_time = short_watch.end()
            logger.sum({"collection_time": collection_time})

            # 1. Compute Advantages and Value Target and Metrics
            tensordict_data = tensordict_data.to(device)
            with torch.no_grad():
                loss_module.value_estimator(tensordict_data)
                metrics = compute_trajectory_metrics(tensordict_data)
            logger.add(metrics)

            # 2. Minibatch Gradient Descent Loop
            short_watch.start()
            policy.train(); value.train()
            # The buffer only ever holds the current generation
            replay_buffer.empty(); replay_buffer.extend(tensordict_data.reshape(-1))
            for epoch in range(EPOCHS):
                for j, batch in enumerate(replay_buffer):
                    # 2.1 Optimization Step
                    batch = batch.to(device)
                    with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=(amp_dtype==torch.float16)):
                        loss_data = loss_module(batch)
                        loss = loss_data["loss_objective"] + loss_data["loss_critic"] + loss_data["loss_entropy"]
                    optimizer.zero_grad(set_to_none=True)
                    scaler.scale(loss).backward()

                    # Gradients are unscaled before clipping so the norm is measured in true units
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=MAX_GRAD_NORM)

                    scaler.step(optimizer)
                    scaler.update()

                    # 2.2 Accumulate Metric
                    weight = float(batch.batch_size[0])
                    logger.accumulate(loss_dict(loss_data, weight))
            policy.eval(); value.eval()
            train_time = short_watch.end()
            logger.sum({"train_time": train_time})

            # 3. Log results
            logger.sum({"generation": 1})
            if (i % LOG_INTERVAL) == 0:
                logger.sum({"time": long_watch.end()})
                long_watch.start()
                logger.next(print_row=True)
            # 4. Checkpoint model
            if (i % CHECKPOINT_INTERVAL) == 0:
                gen = int(start_generation + i + 1)
                metric = metrics[METRIC_KEY]
                checkpointer.save_progress(metric_key=METRIC_KEY,
                state_obj={
                    "generation": gen,
                    "policy_state_dict": policy.state_dict(),
                    "value_state_dict": value.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scaler_state_dict": scaler.state_dict(),
                    METRIC_KEY: metric,
                })

            # 5. Start collection time
            short_watch.start()
    finally:
        # Always release the collector and the environments it created, including
        # when training raises or is interrupted.
        collector.shutdown()

    checkpointer.copy_model('latest', MODEL_PATH, ('policy_state_dict', 'value_state_dict'))
    return logger.dataframe()

In [None]:
# Fresh policy/value networks for the training run
policy = create_policy(MODEL_CONFIG).to(device)
value = create_value(MODEL_CONFIG).to(device)

# Train with a factory that builds environments at TIME_SCALE; the value returned
# by train() is displayed as the cell output
train(lambda: create_env(time_scale=TIME_SCALE), policy, value, timestamps=TIMESTAMPS)

[UnityMemory] Configuration Parameters - Can be set up in boot.config
    "memorysetup-bucket-allocator-granularity=16"
    "memorysetup-bucket-allocator-bucket-count=8"
    "memorysetup-bucket-allocator-block-size=4194304"
    "memorysetup-bucket-allocator-block-count=1"
    "memorysetup-main-allocator-block-size=16777216"
    "memorysetup-thread-allocator-block-size=16777216"
    "memorysetup-gfx-main-allocator-block-size=16777216"
    "memorysetup-gfx-thread-allocator-block-size=16777216"
    "memorysetup-cache-allocator-block-size=4194304"
    "memorysetup-typetree-allocator-block-size=2097152"
    "memorysetup-profiler-bucket-allocator-granularity=16"
    "memorysetup-profiler-bucket-allocator-bucket-count=8"
    "memorysetup-profiler-bucket-allocator-block-size=4194304"
    "memorysetup-profiler-bucket-allocator-block-count=1"
    "memorysetup-profiler-allocator-block-size=16777216"
    "memorysetup-profiler-editor-allocator-block-size=1048576"
    "memorysetup-temp-allocator-siz

  source[group_name][agent_name]["truncated"] = torch.tensor(


   generation       time  collection_time  train_time    return  \
0           1  34.294036        28.137577     5.78922 -0.421378   

   episode_length   entropy  policy_loss  kl_approx  clip_fraction       ESS  \
0           500.0  1.782504     0.026393   0.010858       0.148075  0.980853   

   value_loss  explained_variance  
0    0.004929            0.610575  


  source[group_name][agent_name]["truncated"] = torch.tensor(


   generation       time  collection_time  train_time    return  \
1           2  68.053759        55.654487   11.622838  0.037738   

   episode_length   entropy  policy_loss  kl_approx  clip_fraction       ESS  \
1       90.909088  1.760716     0.005723   0.013393       0.174652  0.976749   

   value_loss  explained_variance  
1    0.015734            0.514099  


  source[group_name][agent_name]["truncated"] = torch.tensor(
