In [1]:
%pip install tensordict
%pip install torchrl
%pip install gym

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
torch.manual_seed(0)
import time
from torchrl.envs import GymEnv, StepCounter, TransformedEnv, GymWrapper

In [3]:
from gymnasium.envs.registration import register
import gymnasium as gym
from env_swingup_cartpole import SUCartPoleEnv

register(
    id="CustomSUCartPole",
    entry_point="env_swingup_cartpole:SUCartPoleEnv",
    max_episode_steps=4000
)

In [4]:
base_env = gym.make("CustomSUCartPole")
print(base_env.reset())

(array([ 0.23151027, -0.32406142,  3.0433    , -0.0703684 ], dtype=float32), {})


In [5]:
env = GymWrapper(base_env)
transformed_env = TransformedEnv(env, StepCounter())
print(transformed_env.reset())

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)




In [6]:
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq 
from torchrl.modules import EGreedyModule, MLP, QValueModule

value_mlp = MLP(out_features=transformed_env.action_spec.shape[-1], num_cells=[64,64])
value_net = Mod(value_mlp, in_keys="observation", out_keys="action_value")
policy = Seq(value_net, QValueModule(spec=transformed_env.action_spec))

exploration_module = EGreedyModule(
    transformed_env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_exploration = Seq(policy, exploration_module)



In [7]:
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

init_rand_steps = 5000 #Warm up
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(
    transformed_env,
    policy,
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps
)

rb = ReplayBuffer(storage=LazyTensorStorage(150_000), sampler=SamplerWithoutReplacement(),)

In [8]:
from torch.optim import Adam
from torchrl.objectives import DQNLoss, SoftUpdate

loss = DQNLoss(value_network=policy, action_space=transformed_env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)

In [9]:
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = "./training_loop"
logger = CSVLogger(exp_name="dpn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
base_env = gym.make("CustomSUCartPole", render_mode="rgb_array")
env = GymWrapper(base_env, from_pixels=True, pixels_only=False)
record_env = TransformedEnv(
    env, video_recorder
)



In [10]:
total_counts = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    rb.extend(data)
    max_length = rb[:]["next","step_count"].max()
    if len(rb) > init_rand_steps:
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            exploration_module.step(data.numel())
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_counts += data.numel()
            total_episodes += data["next", "done"].sum()

    if i % 20 == 0:
        
            for i in range(1):
                record_env.rollout(max_steps=8000, policy=policy)
                video_recorder.dump()

    if max_length >= 4000:
        break
t1 = time.time()
torchrl_logger.info(
    f"solved after {total_counts} steps, {total_episodes} episodes and in {t1-t0}s."
)

2024-07-22 01:24:28,706 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,718 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,728 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,739 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,749 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,761 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,773 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,786 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,800 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,814 [torchrl][INFO] Max num steps: 553, rb length 5200
2024-07-22 01:24:28,937 [torchrl][INFO] Max num steps: 553, rb length 5300
2024-07-22 01:24:28,952 [torchrl][INFO] Max num steps: 553, rb length 5300
2024-07-22 01:24:28,965 [torchrl][INFO] Max num steps: 553, rb length 5300
2024-07-22 01:24:28,979 [

In [11]:
for i in range(5):
    record_env.rollout(max_steps=8000, policy=policy)
    video_recorder.dump()

KeyboardInterrupt: 