In [12]:
# Torch
import torch
from torch import nn
from torch import optim

# Common Util
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Environment
from env import create_env

# Model
from rlkit.models import MLP
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import ProbabilisticActor, TanhNormal

# Training
from rlkit.util import Checkpointer, Logger, Stopwatch, round_up

# Config
from config import (
    ENV_PATH, 
    N_ENVS, OBSERVATION_DIM, ACTION_DIM,
    LOG_KEYS, LOG_INDEX, BEST_METRIC_KEY,
    MODEL_PATH, CKPT_PATH, LOG_PATH, RESULTS_PATH,
)

# Pick the compute device once: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using Device: {device}")

Using Device: cpu


### **Environment**

In [51]:
def print_specs(env):
    """Print the action, reward, done and observation specs of ``env``.

    Each spec goes on its own line as ``<name>: <spec>``, in that fixed order.
    Attributes are read one at a time, right before each line is printed.
    """
    for spec_name in ("action_spec", "reward_spec", "done_spec", "observation_spec"):
        print(f"{spec_name}:", getattr(env, spec_name))

# env = create_env(graphics=False, time_scale=5)
# td = env.rollout(1000, break_when_any_done=False)
# print_specs(env)

In [52]:
# act_data = pd.DataFrame(td["action"].reshape(-1))
# obs_data = pd.DataFrame(td["observation"].reshape(-1))
# obs_data = obs_data.clip(float(obs_data.quantile(0.01).iloc[0]), float(obs_data.quantile(0.99).iloc[0]))
# rew_data = pd.DataFrame(td["next", "reward"].reshape(-1))

# plt.violinplot(obs_data, positions=[0], showmedians=True, showextrema=True, widths=0.9)
# plt.violinplot(act_data, positions=[1], showmedians=True, showextrema=True, widths=0.9)
# plt.violinplot(rew_data, positions=[2], showmedians=True, showextrema=True, widths=0.9);

### **Model**

In [53]:
def create_policy(model_config):
    """Build a stochastic ``TanhNormal`` actor on top of an MLP backbone.

    The backbone reads ``"observation"`` and emits ``2 * out_features`` values.
    ``NormalParamExtractor`` turns these into ``"loc"`` and ``"scale"``. The actor
    samples ``"action"`` from a ``TanhNormal`` with those parameters and writes the
    sample's log-probability under ``"log_prob"``.

    ``model_config`` is shallow-copied, so the caller's dict is left unchanged.
    """
    actor_config = dict(model_config)
    # The network must output one location and one scale per action dimension.
    actor_config["out_features"] = actor_config["out_features"] * 2

    param_net = TensorDictModule(
        nn.Sequential(MLP(**actor_config), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )
    return ProbabilisticActor(
        module=param_net,
        distribution_class=TanhNormal,
        in_keys=["loc", "scale"],
        out_keys=["action"],
        return_log_prob=True,
        log_prob_key="log_prob",
        cache_dist=True,
    )

def create_value(model_config):
    """Build the critic: an MLP reading ``"observation"`` and writing ``"state_value"``.

    The config is shallow-copied and its ``"out_features"`` entry is overridden
    with 1, so the network has a single output. The caller's dict is not modified.
    """
    critic_config = model_config.copy()
    critic_config["out_features"] = 1  # one value estimate per input
    return TensorDictModule(
        MLP(**critic_config),
        in_keys=["observation"],
        out_keys=["state_value"],
    )

### **Config**

In [54]:
# Shared MLP hyper-parameters. create_policy and create_value each copy this dict
# and set their own "out_features" before building the network.
MODEL_CONFIG = dict(
    hidden_dim=256,
    n_blocks=3,
    in_features=OBSERVATION_DIM,
    out_features=ACTION_DIM,
)

In [10]:
# Mixed precision: float16 autocast (with a GradScaler) only on CUDA; float32 elsewhere.
device_type = "cuda" if str(device).startswith("cuda") else "cpu"
amp_dtype   = torch.float16 if device_type == "cuda" else torch.float32

### Training Loop Params
# os.cpu_count() may return None. On a single-core machine `// 2` gives 0 workers,
# which also means a zero rounding multiple below. Always use at least one worker.
WORKERS = max(1, (os.cpu_count() or 1) // 2)
STORAGE_DEVICE = "cpu"
# Target frames per generation, rounded up to a whole step across all parallel envs.
GENERATION_SIZE = round_up(64_000, WORKERS*N_ENVS)
GENERATIONS = 200
# Frames per collector batch (presumably 128 steps from each parallel env).
# NOTE(review): round_up is called with a single argument here, unlike the other calls
# in this cell; confirm its default multiple. GENERATION_SIZE is not necessarily a
# multiple of this value, so a generation built from whole collector batches can
# contain more than GENERATION_SIZE frames.
COLLECTOR_BUFFER_SIZE = round_up(WORKERS*N_ENVS*128)

# Advantage Comp Params
SLICE_LEN = 128 # GAE Window
ADV_MINIBATCH_SIZE = round_up(10_000, SLICE_LEN)

# GD Params
EPOCHS = 2
MINIBATCH_SIZE = 1024
LR = 3e-4
MAX_GRAD_NORM = 0.5
KL_TARGET = 0.01

### RL Params

# ENV Params
TIME_SCALE = 10

# PPO Params
GAMMA = 0.99
GAE_LAMBDA = 0.95
EPSILON = 0.2
ENTROPY_COEF = 1e-4

NAME = 'run0'

LOG_INTERVAL = 1
CHECKPOINT_INTERVAL = 1

CONTINUE=False

In [11]:
def summary():
    """Print the main run-size settings, one ``name = value `` line each."""
    parallel_envs = WORKERS * N_ENVS
    settings = {
        "workers": WORKERS,
        "parallel envs": parallel_envs,
        "generation_size": GENERATION_SIZE,
        "generations": GENERATIONS,
        "timesteps": GENERATIONS * GENERATION_SIZE,
        "device": device,
    }
    for name, setting in settings.items():
        print(f"{name} = {setting} ")


summary()

workers = 7 
parallel envs = 70 
generation_size = 64050 
generations = 200 
timesteps = 12810000 
device = cpu 


In [None]:
def train(create_env, policy, value, generations=GENERATIONS):
    """Train ``policy`` and ``value`` with PPO and return the logged history.

    Each generation runs five steps:
      1. collect a fresh on-policy dataset (a whole number of collector batches),
      2. compute advantages / value targets slice by slice and accumulate
         trajectory metrics,
      3. run up to ``EPOCHS`` epochs of minibatch gradient descent, stopping
         early once the logged ``"kl_target"`` exceeds ``KL_TARGET``,
      4. log,
      5. checkpoint.

    Args:
        create_env: environment factory handed to the data collector(s).
        policy: actor used for collection and optimised through the PPO loss.
        value: critic optimised through the PPO loss.
        generations: index one past the last generation to run. When ``CONTINUE``
            is set and a checkpoint exists, training resumes at the checkpointed
            generation.

    Returns:
        ``logger.dataframe()`` for this run.
    """
    # torchrl building blocks used below. The notebook's import cell does not import
    # them, so bring them into scope here to keep the function runnable on a fresh kernel.
    from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
    from torchrl.data import (
        LazyMemmapStorage,
        ReplayBuffer,
        SamplerWithoutReplacement,
        SliceSamplerWithoutReplacement,
    )

    # Loss + Optimizer
    loss_module = make_loss_module(
        policy, value,
        epsilon=EPSILON, entropy_coef=ENTROPY_COEF, gamma=GAMMA, lmbda=GAE_LAMBDA,
    )
    optimizer = optim.Adam(loss_module.parameters(), lr=LR)
    # Only float16 needs loss scaling; float32 and bfloat16 have wider exponent ranges.
    scaler = torch.amp.GradScaler(enabled=(amp_dtype == torch.float16))

    # Logger + Checkpointer
    logger = Logger(keys=LOG_KEYS, log_path=LOG_PATH, name=NAME)
    checkpointer = Checkpointer(ckpt_path=CKPT_PATH, name=NAME)

    # Continue/Reset: either wipe previous logs/checkpoints or resume from saved progress.
    start_generation = 0
    if not CONTINUE:
        logger.reset()
        checkpointer.reset()
    else:
        checkpoint = checkpointer.load_progress()
        if checkpoint:
            start_generation = int(checkpoint["generation"])
            policy.load_state_dict(checkpoint["policy_state_dict"])
            value.load_state_dict(checkpoint["value_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            if "scaler_state_dict" in checkpoint:
                scaler.load_state_dict(checkpoint["scaler_state_dict"])
            print("CHECKPOINT FOUND, STARTING FROM GENERATION:", start_generation)
        else:
            print("CHECKPOINT NOT FOUND, STARTING FROM SCRATCH")

    # short_watch times individual phases; long_watch times the span between log rows.
    short_watch = Stopwatch()
    long_watch = Stopwatch()

    # The collector always yields whole batches of COLLECTOR_BUFFER_SIZE frames. One
    # generation is therefore `collector_iters_per_gen` batches = `frames_per_gen` frames,
    # which is >= GENERATION_SIZE (equal only when GENERATION_SIZE is a multiple of the
    # batch size). Everything sized "per generation" must use frames_per_gen. Sized with
    # GENERATION_SIZE, the ring storages would overwrite the first frames of each
    # generation, and the collector's frame budget would run out before the last
    # generations.
    collector_iters_per_gen = -(-GENERATION_SIZE // COLLECTOR_BUFFER_SIZE)  # ceil division
    frames_per_gen = collector_iters_per_gen * COLLECTOR_BUFFER_SIZE

    # Replay Buffers
    # The collection buffer keeps the collector's batch dims (ndim 2, or 3 with several
    # workers), so the slice sampler can return contiguous time slices of SLICE_LEN.
    collect_replay_buffer = ReplayBuffer(
        storage=LazyMemmapStorage(frames_per_gen, device=STORAGE_DEVICE, ndim=2 + int(WORKERS > 1)),
        sampler=SliceSamplerWithoutReplacement(
            slice_len=SLICE_LEN,
            shuffle=False, strict_length=False,
            end_key=("next", "done"),
        ),
        batch_size=ADV_MINIBATCH_SIZE,
    )
    # The training buffer holds the same frames flattened, for shuffled minibatches.
    train_replay_buffer = ReplayBuffer(
        storage=LazyMemmapStorage(frames_per_gen, device=STORAGE_DEVICE),
        sampler=SamplerWithoutReplacement(),
        batch_size=MINIBATCH_SIZE,
    )

    # Collectors
    total_frames = frames_per_gen * (generations - start_generation)
    if WORKERS > 1:
        collector = MultiSyncDataCollector(
            [create_env] * WORKERS, policy,
            frames_per_batch=COLLECTOR_BUFFER_SIZE,
            total_frames=total_frames,
            env_device="cpu", device=device, storing_device=STORAGE_DEVICE,
            update_at_each_batch=True,
        )
    else:
        collector = SyncDataCollector(
            create_env, policy,
            frames_per_batch=COLLECTOR_BUFFER_SIZE,
            total_frames=total_frames,
            env_device="cpu", device=device, storing_device=STORAGE_DEVICE,
        )

    # Always release the collector (and any worker processes or envs it owns),
    # even if training is interrupted or raises.
    try:
        long_watch.start()

        ### TRAINING LOOP
        for i in range(start_generation, generations):
            # 1. Collect a fresh trajectory dataset of frames_per_gen frames.
            policy.eval(); value.eval()
            short_watch.start()
            collect_replay_buffer.empty()
            for _ in range(collector_iters_per_gen):
                data = collector.next()
                collect_replay_buffer.extend(data)
            logger.sum({"collection_time": short_watch.end()})

            # 2. Compute advantages, value targets and trajectory metrics, walking the
            #    stored data in time slices, then flatten the frames into the train buffer.
            train_replay_buffer.empty()
            for batch in collect_replay_buffer:
                batch = batch.to(device)
                with torch.no_grad():
                    loss_module.value_estimator(batch)
                    metrics = compute_trajectory_metrics(split_trajectories(batch))
                logger.accumulate(metrics)
                train_replay_buffer.extend(batch.reshape(-1).cpu())
            collect_replay_buffer.empty()  # trajectories are no longer needed

            # 3. Minibatch gradient descent over shuffled timesteps.
            short_watch.start()
            policy.train(); value.train()
            kl_exceeded = False
            for _ in range(EPOCHS):
                for batch in train_replay_buffer:
                    # 3.1 Optimization step (autocast only active for float16).
                    batch = batch.to(device)
                    with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=(amp_dtype == torch.float16)):
                        loss_data = loss_module(batch)
                        loss = loss_data["loss_objective"] + loss_data["loss_critic"] + loss_data["loss_entropy"]
                    optimizer.zero_grad(set_to_none=True)
                    scaler.scale(loss).backward()

                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=MAX_GRAD_NORM)

                    scaler.step(optimizer)
                    scaler.update()

                    # 3.2 Accumulate loss metrics, weighted by minibatch size.
                    weight = float(batch.batch_size[0])
                    logger.accumulate(loss_dict(loss_data, weight))

                    # Early stop for all remaining epochs once "kl_target" exceeds KL_TARGET.
                    if logger.last()["kl_target"] > KL_TARGET:
                        kl_exceeded = True
                        break
                if kl_exceeded:
                    break

            train_replay_buffer.empty()
            policy.eval(); value.eval()
            logger.sum({"train_time": short_watch.end()})

            # 4. Log results. "timestep" counts the frames actually collected this generation.
            logger.sum({"generation": 1})
            logger.sum({"timestep": frames_per_gen})
            if (i % LOG_INTERVAL) == 0:
                logger.sum({"time": long_watch.end()})
                long_watch.start()
                logger.next(print_row=True)

            # 5. Checkpoint model
            if (i % CHECKPOINT_INTERVAL) == 0:
                # NOTE(review): `metrics` holds the trajectory metrics of the last advantage
                # batch only, not the whole generation; consider the logger's aggregate.
                metric = metrics[BEST_METRIC_KEY]
                checkpointer.save_progress(
                    metric_key=BEST_METRIC_KEY,
                    state_obj={
                        "generation": i + 1,
                        "policy_state_dict": policy.state_dict(),
                        "value_state_dict": value.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scaler_state_dict": scaler.state_dict(),
                        BEST_METRIC_KEY: metric,
                    },
                )
    finally:
        collector.shutdown()

    checkpointer.copy_model('latest', MODEL_PATH, ('policy_state_dict', 'value_state_dict'))
    return logger.dataframe()