In [1]:
import torchrl
import custom_torchrl_env
import importlib

In [2]:
importlib.reload(custom_torchrl_env)

<module 'custom_torchrl_env' from '/home/emil/Development/custom_torchrl_env/custom_torchrl_env.py'>

In [3]:
env = custom_torchrl_env.RodentRunEnv(batch_size=(2048,), device='cuda', worker_thread_count=16)

In [4]:
torchrl.envs.utils.check_env_specs(env)

2024-03-27 14:08:08,663 [torchrl][INFO] check_env_specs succeeded!


# Debug deadlocks

In [None]:
env.rollout(100)

In [None]:
for tensordict_data in tqdm.tqdm(collector):
    pass

In [None]:
import numpy as np

In [40]:
#env.reset()
a = env.rand_action()
a["action"] *= torch.nan
env.step(a)[("next", "fullphysics")]#["action"]

ValueError: Passed action contains NaNs.

# From TorchRL PPO tutorial

In [5]:
import collections
import torch
import tensordict
import tqdm

In [6]:
num_cells = 1024
lr = 1e-4
max_grad_norm = 1.0
device = 'cuda'
frames_per_batch = 8*1024
total_frames = 2048*1024#256*1024

In [7]:
sub_batch_size = 32  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 4  # optimisation steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

In [8]:
actor_net = torch.nn.Sequential(
    torch.nn.LazyLinear(num_cells, device=device),
    torch.nn.Tanh(),
    torch.nn.LazyLinear(num_cells, device=device),
    torch.nn.Tanh(),
    torch.nn.LazyLinear(num_cells, device=device),
    torch.nn.Tanh(),
    torch.nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    tensordict.nn.distributions.NormalParamExtractor(),
)



In [9]:
policy_module = tensordict.nn.TensorDictModule(
    actor_net, in_keys=["fullphysics"], out_keys=["loc", "scale"]
)

In [10]:
policy_module = torchrl.modules.ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=torchrl.modules.TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.low,
        "max": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

In [11]:
value_net = torch.nn.Sequential(
    torch.nn.LazyLinear(num_cells, device=device),
    torch.nn.Tanh(),
    torch.nn.LazyLinear(num_cells, device=device),
    torch.nn.Tanh(),
    torch.nn.LazyLinear(num_cells, device=device),
    torch.nn.Tanh(),
    torch.nn.LazyLinear(1, device=device),
)

value_module = torchrl.modules.ValueOperator(
    module=value_net,
    in_keys=["fullphysics"]
)

In [12]:
print("Running policy:", policy_module(env.reset()).shape)
print("Running value:", value_module(env.reset()).shape)

Running policy: torch.Size([2048])
Running value: torch.Size([2048])


In [13]:
collector = torchrl.collectors.SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

In [14]:
replay_buffer = torchrl.data.replay_buffers.ReplayBuffer(
    storage=torchrl.data.replay_buffers.storages.LazyTensorStorage(max_size=frames_per_batch),
    sampler=torchrl.data.replay_buffers.samplers.SamplerWithoutReplacement(),
)

In [15]:
advantage_module = torchrl.objectives.value.GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device
)

loss_module = torchrl.objectives.ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

In [16]:
optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

In [17]:
import time

In [18]:
for i, tensordict_data in tqdm.tqdm(enumerate(collector)):
    for j in range(tensordict_data.shape[1]):
        advantage_module(tensordict_data[:,j])
        loss_vals = loss_module(tensordict_data[:,j])
        loss_value = (
            loss_vals["loss_objective"]
            + loss_vals["loss_critic"]
            + loss_vals["loss_entropy"]
        )
        loss_value.backward()
        torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
        optim.step()
        optim.zero_grad()

256it [01:22,  3.09it/s]


In [113]:
total_frames / 81#173

25890.765432098764

In [91]:
for t in tqdm.trange(total_frames // 256):
    env.rand_step()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8192/8192 [01:00<00:00, 134.96it/s]


In [93]:
for t in tqdm.trange(total_frames // 256):
    env.simulation_pool.step()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8192/8192 [00:37<00:00, 221.16it/s]


In [44]:
s_time = time.time()
logs = collections.defaultdict(list)
#pbar = tqdm.tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    print(i, "[{:.3f}]".format(time.time()-s_time), "Training...")
    for j in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for k in range(frames_per_batch // sub_batch_size):
            #print(i, j, k, "Drawing sample")
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            #print(i, j, k, "Backward pass")
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            #print(i, j, k, "Optimization step")
            optim.step()
            optim.zero_grad()
            #print(i, j, k, "Sub-batch done")
    print(i, "[{:.3f}]".format(time.time()-s_time), "Logging...")
    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    #pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        print(i, "[{:.3f}]".format(time.time()-s_time), "Evaluating...")
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with torchrl.envs.utils.set_exploration_type(torchrl.envs.utils.ExplorationType.MEAN), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(100, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    #pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()

0 [0.306] Training...
0 [6.816] Logging...
0 [6.817] Evaluating...
1 [7.934] Training...
1 [14.475] Logging...
2 [14.738] Training...
2 [21.190] Logging...
3 [21.436] Training...
3 [28.078] Logging...
4 [28.332] Training...
4 [34.816] Logging...
5 [35.068] Training...
5 [42.020] Logging...
6 [42.260] Training...
6 [49.436] Logging...
7 [49.679] Training...
7 [56.774] Logging...
8 [57.025] Training...
8 [63.933] Logging...
9 [64.185] Training...
9 [71.032] Logging...


ValueError: Passed action contains NaNs.

In [45]:
logs

defaultdict(list,
            {'reward': [-1.5837199687957764,
              -1.9704928398132324,
              -1.7500295639038086,
              -0.8590556383132935,
              -0.6554573774337769,
              -0.7430774569511414,
              -0.4696408808231354,
              -1.022096872329712,
              -1.2704322338104248,
              -1.4466164112091064],
             'step_count': [32, 132, 164, 196, 228, 260, 292, 324, 356, 388],
             'lr': [0.0001,
              9.975923633360985e-05,
              9.903926402016151e-05,
              9.784701678661043e-05,
              9.619397662556433e-05,
              9.409606321741774e-05,
              9.157348061512726e-05,
              8.865052266813684e-05,
              8.535533905932737e-05,
              8.171966420818226e-05],
             'eval reward': [-1.0644084215164185],
             'eval reward (sum)': [-27248.85546875],
             'eval step_count': [100]})