## Goal

This notebook is a proof of concept showing how AlphaZero can be implemented on top of TorchRL.

We will apply this technique to the `CliffWalking-v0` environment.

In [1]:
import torch

from tensordict.nn import TensorDictModule

from torchrl.modules import QValueActor

from torchrl.envs import GymEnv, TransformedEnv, Compose, DTypeCastTransform, StepCounter

from torchrl.objectives import DQNLoss


In [2]:
# Turn on autoreload so that edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# QValue Network

Let's first create a Q-value network. The Q-value network provides an initial value estimate for each action when we explore a node for the first time.

In [3]:
def make_q_value(num_observation, num_action, action_space):
    """Build a linear Q-value actor that reads the ``"observation"`` entry.

    Args:
        num_observation: Number of input features of the linear layer.
        num_action: Number of output features of the linear layer.
        action_space: Forwarded unchanged to ``QValueActor`` as ``action_space``.

    Returns:
        A ``QValueActor`` wrapping a single ``torch.nn.Linear`` layer.
    """
    return QValueActor(
        torch.nn.Linear(num_observation, num_action),
        in_keys=["observation"],
        action_space=action_space,
    )


# Wrap the Gym environment: cast the "observation" entry from int64 to float32
# (the linear layer needs floating-point input) and attach a step counter.
env = TransformedEnv(
    GymEnv("CliffWalking-v0"),
    Compose(
        DTypeCastTransform(
            dtype_in=torch.long,
            dtype_out=torch.float32,
            in_keys=["observation"],
        ),
        StepCounter(),
    ),
)

# Size the linear layer from the trailing dimension of the observation and action specs.
qvalue_module = make_q_value(
    num_observation=env.observation_spec["observation"].shape[-1],
    num_action=env.action_spec.shape[-1],
    action_space=env.action_spec,
)

# Sanity check: run the freshly built module on an initial state and display the result.
qvalue_module(env.reset())

  logger.warn(
  logger.warn(


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([48]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [4]:
# DQN loss built on top of the Q-value actor, using the environment's action spec.
loss_module = DQNLoss(value_network=qvalue_module, action_space=env.action_spec)



In [5]:
from mcts.tensordict_map import TensorDictMap
from mcts.mcts_policy import SimulatedSearchPolicy, MctsPolicy, UpdateTreeStrategy, AlphaZeroExpansionStrategy, PucbSelectionPolicy, SimulatedAlphaZeroSearchPolicy

# Storage shared by the expansion strategy and the tree updater.
# NOTE(review): presumably nodes are keyed by the listed entries (observation and
# step count) — confirm against mcts.tensordict_map.
tree = TensorDictMap(["observation", "step_count"])

# Components are created in the same order a nested call would evaluate them.
expansion_strategy = AlphaZeroExpansionStrategy(value_module=qvalue_module, tree=tree)
selection_strategy = PucbSelectionPolicy()
mcts_policy = MctsPolicy(
    expansion_strategy=expansion_strategy,
    selection_strategy=selection_strategy,
)
tree_updater = UpdateTreeStrategy(tree)

policy = SimulatedAlphaZeroSearchPolicy(
    policy=mcts_policy,
    tree_updater=tree_updater,
    env=env,
    num_simulation=10,
    max_steps=1000,
)


In [6]:
from torchrl.collectors import SyncDataCollector

In [7]:
from torchrl.trainers import Trainer

# Collection budget, defined once so the collector and the trainer stay in agreement.
TOTAL_FRAMES = 1_000_000
FRAMES_PER_BATCH = 10_000

# Collect rollouts from `env` by running the search policy.
data_collecter = SyncDataCollector(
    lambda: env,
    policy,
    total_frames=TOTAL_FRAMES,
    frames_per_batch=FRAMES_PER_BATCH,
)
print('done')

# `Trainer` takes `total_frames`, `frame_skip` and `optim_steps_per_batch` as
# required keyword-only arguments; leaving them out raises a TypeError before
# any training happens. No frame skipping is configured on the environment, so
# `frame_skip` is 1, and one optimization step is run per collected batch.
trainer = Trainer(
    collector=data_collecter,
    total_frames=TOTAL_FRAMES,
    frame_skip=1,
    optim_steps_per_batch=1,
    loss_module=loss_module,
    optimizer=torch.optim.Adam(qvalue_module.parameters(), lr=1e-3),
)
trainer.train()

10


: 