## Goal

This is a proof of concept showing how AlphaZero can be implemented on top of TorchRL.

We apply this technique to the `CliffWalking-v0` environment.

In [1]:
import torch

from tensordict.nn import TensorDictModule

from torchrl.modules import QValueActor

from torchrl.envs import GymEnv, TransformedEnv, Compose, DTypeCastTransform, StepCounter

from torchrl.objectives import DQNLoss


# QValue Network

Let's first create a Q-value network. The search uses it to provide an initial value estimate for each action the first time a node of the search tree is expanded.

In [2]:
def make_q_value(num_observation, num_action, action_space):
    """Build a Q-value actor backed by a single linear layer.

    Args:
        num_observation: number of input features, i.e. the size of the
            "observation" vector fed to the linear layer.
        num_action: number of outputs of the linear layer (one value per action).
        action_space: action spec forwarded to ``QValueActor``.

    Returns:
        A ``QValueActor`` that reads the "observation" entry of its input.
    """
    return QValueActor(
        torch.nn.Linear(num_observation, num_action),
        in_keys=["observation"],
        action_space=action_space,
    )


# Seed torch so the randomly initialised Q-network (and therefore the priors
# and Q-values printed by the search below) are reproducible on Restart & Run All.
torch.manual_seed(0)

env = TransformedEnv(
    GymEnv("CliffWalking-v0"),
    Compose(
        # The raw observation is a long tensor; cast it to float32 so it can be
        # fed to the linear Q-network.
        DTypeCastTransform(dtype_in=torch.long, dtype_out=torch.float32, in_keys=["observation"]),
        # Adds a "step_count" entry, which the search tree below also uses as part of its key.
        StepCounter(),
    ),
)
qvalue_module = make_q_value(
    env.observation_spec["observation"].shape[-1],
    env.action_spec.shape[-1],
    env.action_spec,
)
qvalue_module(env.reset())

  logger.warn(
  logger.warn(


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([48]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [3]:
# DQN loss built on the same Q-value network that seeds the tree search.
# NOTE(review): `loss_module` is not used in any visible cell (no training loop yet).
loss_module = DQNLoss(value_network=qvalue_module, action_space=env.action_spec)



In [4]:
from mcts.tensordict_map import TensorDictMap
from mcts.mcts_policy import SimulatedSearchPolicy, MctsPolicy, UpdateTreeStrategy, AlphaZeroExpansionStrategy, PucbSelectionPolicy

# Search hyper-parameters (previously inline magic numbers).
NUM_SIMULATIONS = 100  # forwarded as `num_simulation` to SimulatedSearchPolicy
MAX_STEPS = 1000  # forwarded as `max_steps`; presumably caps the length of one simulated rollout — TODO confirm

# Search tree shared by the expansion strategy and the tree updater.
# Presumably nodes are keyed by the ("observation", "step_count") entries — verify against TensorDictMap.
tree = TensorDictMap(["observation", "step_count"])

policy = SimulatedSearchPolicy(
    policy=MctsPolicy(
        # Newly expanded nodes take their initial action values from the Q-value network.
        expansion_strategy=AlphaZeroExpansionStrategy(value_module=qvalue_module, tree=tree),
        selection_strategy=PucbSelectionPolicy(),
    ),
    tree_updater=UpdateTreeStrategy(tree),
    env=env,
    num_simulation=NUM_SIMULATIONS,
    max_steps=MAX_STEPS,
)


In [5]:
# Run the simulated search starting from the environment's initial state.
root_state = env.reset()
res = policy(root_state)

simulation: 0

q_sa -> [-0.10321956 -0.07089151 -0.08972184 -5.        ]
p_sa -> [-0.10321956 -0.07089151 -0.08972184 -0.07059182]
n_sa -> [0. 0. 0. 1.]
action -> [0 0 0 1]
step_count -> [0]
q_sa -> [-0.10321956 -0.07089151 -0.08972184 -4.        ]
p_sa -> [-0.10321956 -0.07089151 -0.08972184 -0.07059182]
n_sa -> [0. 0. 0. 1.]
action -> [0 0 0 1]
step_count -> [1]
q_sa -> [-0.10321956 -0.07089151 -0.08972184 -3.        ]
p_sa -> [-0.10321956 -0.07089151 -0.08972184 -0.07059182]
n_sa -> [0. 0. 0. 1.]
action -> [0 0 0 1]
step_count -> [2]
q_sa -> [-0.10321956 -0.07089151 -0.08972184 -2.        ]
p_sa -> [-0.10321956 -0.07089151 -0.08972184 -0.07059182]
n_sa -> [0. 0. 0. 1.]
action -> [0 0 0 1]
step_count -> [3]
q_sa -> [-0.10321956 -0.07089151 -0.08972184 -1.        ]
p_sa -> [-0.10321956 -0.07089151 -0.08972184 -0.07059182]
n_sa -> [0. 0. 0. 1.]
action -> [0 0 0 1]
step_count -> [4]
simulation: 1

q_sa -> [-1.0321956e-01 -5.0000000e+02 -8.9721836e-02 -5.0000000e+00]
p_sa -> [-0.10321956

In [6]:
# Number of entries stored in the search tree after the search above.
# NOTE(review): this reads the private `_dict` attribute of TensorDictMap;
# prefer a public size/len API if TensorDictMap exposes one.
len(tree._dict)

32

In [7]:
# Per-action Q-value estimates returned by the search (presumably at the root node).
res.get("q_sa")

tensor([  -5., -500.,   -5.,   -5.])

In [8]:
# Per-action counts returned by the search (presumably root visit counts).
res.get("n_sa")

tensor([1., 1., 1., 7.])

In [9]:
# Action values the search policy wrote into the output TensorDict.
res.get("action_value")

tensor([  -5.0816, -500.0560,   -5.0709,   -5.0140])

In [10]:
# Action selected by the search (displayed one-hot in the saved output).
res.get("action")

tensor([0, 0, 0, 1])