## Goal

This notebook is a proof of concept showing how AlphaZero can be implemented on top of TorchRL.

We will apply this technique to the CliffWalking-v0 environment.

In [1]:
import torch

from tensordict.nn import TensorDictModule

from torchrl.modules import QValueActor

from torchrl.envs import GymEnv, TransformedEnv, Compose, DTypeCastTransform

from torchrl.objectives import DQNLoss


# QValue Network

Let's first create a QValue network. It provides an initial value for each action when we explore a node for the first time.

In [2]:
def make_q_value(num_observation, num_action, action_space):
    """Build a Q-value actor backed by a single linear layer.

    Args:
        num_observation: Input feature count of the linear layer.
        num_action: Output feature count of the linear layer (one value per action).
        action_space: Forwarded unchanged as ``QValueActor``'s ``action_space`` argument.

    Returns:
        A ``QValueActor`` wrapping the linear layer and reading the
        ``"observation"`` entry of its input.
    """
    return QValueActor(
        torch.nn.Linear(num_observation, num_action),
        in_keys=["observation"],
        action_space=action_space,
    )


# Cast the integer (torch.long) observation to float32 so it can feed a linear layer.
cast_observation = DTypeCastTransform(
    dtype_in=torch.long,
    dtype_out=torch.float32,
    in_keys=["observation"],
)
env = TransformedEnv(GymEnv("CliffWalking-v0"), Compose(cast_observation))

# Layer sizes are read from the trailing dimension of the observation and action specs.
num_observation = env.observation_spec["observation"].shape[-1]
num_action = env.action_spec.shape[-1]
qvalue_module = make_q_value(num_observation, num_action, env.action_spec)

# Smoke test: run the actor on the initial state and display the resulting tensordict.
qvalue_module(env.reset())

  logger.warn(
  logger.warn(


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([48]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [3]:
# DQN loss built around the Q-value actor; it is not used by the cells shown in this notebook.
loss_module = DQNLoss(
    value_network=qvalue_module,
    action_space=env.action_spec,
)



In [4]:
from mcts.tensordict_map import TensorDictMap
from mcts.mcts_policy import SimulatedSearchPolicy, MctsPolicy, UpdateTreeStrategy, AlphaZeroExpansionStrategy, PucbSelectionPolicy

# Search budget passed to SimulatedSearchPolicy.
# NOTE(review): presumably simulations per call and steps per simulation — confirm in mcts_policy.
NUM_SIMULATION = 5
MAX_STEPS = 3

# Tree storage, constructed with the "observation" key.
tree = TensorDictMap("observation")

# Expansion uses the Q-value actor and shares the same tree instance as the updater below.
expansion_strategy = AlphaZeroExpansionStrategy(value_module=qvalue_module, tree=tree)
selection_strategy = PucbSelectionPolicy()
mcts_policy = MctsPolicy(
    expansion_strategy=expansion_strategy,
    selection_strategy=selection_strategy,
)
tree_updater = UpdateTreeStrategy(tree)

policy = SimulatedSearchPolicy(
    policy=mcts_policy,
    tree_updater=tree_updater,
    env=env,
    num_simulation=NUM_SIMULATION,
    max_steps=MAX_STEPS,
)


In [5]:
# Run the simulated search starting from a freshly reset environment.
initial_state = env.reset()
res = policy(initial_state)

start_simulation
forward
simulation-0
action_value
q_sa: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
n_sa: [0. 0. 0. 0.]
p_sa: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
action_value: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
action: 2
action_value
q_sa: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
n_sa: [0. 0. 0. 0.]
p_sa: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
action_value: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
action: 2
action_value
q_sa: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
n_sa: [0. 0. 0. 0.]
p_sa: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
action_value: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
action: 2
updated... [0 0 1 0]
q_sa: [-0.04422648 -0.03890638 -3.         -0.14129654]
n_sa: [0. 0. 1. 0.]
p_sa: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
action_value: [-0.04422648 -0.03890638  0.09039502 -0.14129654]
updating done!
updated... [0 0 1 0]
q_sa: [-0.04422648 -0.03890638 -2.5        -0.141296

In [6]:
# Number of entries currently held in the tree's backing `_dict`.
# NOTE(review): `_dict` is underscore-prefixed (private by convention); prefer a public
# size accessor on TensorDictMap if one exists — TODO confirm.
len(tree._dict)

3

In [7]:
# Per-action `q_sa` entry of the search result.
# NOTE(review): presumably the Q-value estimates after the tree updates — confirm in mcts_policy.
res["q_sa"]

tensor([-102., -200.,   -2.,   -2.])

In [8]:
# Per-action `p_sa` entry of the search result.
# NOTE(review): in the recorded run it matches the Q-network's initial action values,
# so it presumably holds the prior used at expansion — confirm in mcts_policy.
res["p_sa"]

tensor([-0.0442, -0.0389,  0.0904, -0.1413])

In [9]:
# Per-action `action_value` entry of the search result.
# NOTE(review): presumably the selection score derived from `q_sa` and `p_sa` — TODO confirm.
res["action_value"]

tensor([-102.0399, -200.0175,   -1.9767,   -2.0637])

In [10]:
# `action` entry of the search result; the recorded output shows it as a one-hot vector.
res["action"]

tensor([0, 0, 1, 0])