# Create an environment

In [None]:
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
from torchrl.data import DiscreteTensorSpec, BinaryDiscreteTensorSpec, DiscreteTensorSpec, CompositeSpec, BoundedTensorSpec

from tensordict import TensorDict, TensorDictBase
from typing import Optional
import torch.nn.functional as F
import torch
from torch import nn

from dataclasses import dataclass

In [None]:
device = "cpu"
class TicTacToeEnv(EnvBase):
    def __init__(self, seed=None, device=device, *argv, **kwargs):
        super().__init__(*argv, device=device, **kwargs)
        self._make_spec()
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)
        
    
    def _step(self, tensordict: TensorDict):
        board = tensordict["board"]
        turn = tensordict["turn"]
        action = tensordict["action"]
        
        one_hot_action = F.one_hot(action, 9)
        flatten_board = board.reshape(-1, 9)
        
        # b x 1
        is_valid = (torch.sum(one_hot_action * flatten_board, dim=-1) == 0).to(torch.float).unsqueeze(dim=-1)
        new_flatten_board = flatten_board + turn * one_hot_action
        
        player_view = (new_flatten_board == turn).reshape(-1, 3, 3)
        
        # b x 3
        row_win = (torch.sum(player_view, dim=-1) == 3).long()
        col_win = (torch.sum(player_view, dim=-2) == 3).long()
        
        # b x 1
        won_reward = -(torch.sum(row_win + col_win, dim=-1, keepdim=True) > 0).long()
        
        next_turn = ((turn % 2) + 1) * is_valid + turn * (1 - is_valid)
        reward = won_reward * is_valid + (1 - is_valid) * -10
        
        final_board = new_flatten_board * is_valid + flatten_board * (1 - is_valid)
                
        done = torch.sum(final_board != 0, dim=-1) == 9
        
        next_state = TensorDict(
            {
                "next": {
                    "board": final_board.reshape(-1, 3, 3).to(torch.int64),
                    "turn": next_turn.to(torch.int64), 
                    "reward": reward.to(torch.int64),
                    "done": done,
                },
            },
            tensordict.shape,
        )
        return next_state
    
    def _reset(self, tensordict: Optional[TensorDict]):
        batch_size=self.batch_size
        
        return TensorDict(
            {
                "board": torch.zeros(batch_size + (3, 3)).long(),
                "turn": torch.ones(batch_size + (1,)).long(), 
            },
            batch_size=batch_size
        )

            
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng
    
    def _make_spec(self):
        batch_size = self.batch_size
        self.observation_spec = CompositeSpec(
            board=BoundedTensorSpec(
                minimum=0,
                maximum=2,
                shape=batch_size + (3, 3),
                dtype=torch.int64,
            ),
            turn=BoundedTensorSpec(
                minimum=0,
                maximum=2,
                shape=batch_size + (1,),
                dtype=torch.int64,
            ),
            shape=batch_size
        )
        
        self.state_spec = self.observation_spec.clone()
        self.action_spec = DiscreteTensorSpec(n=9, shape=batch_size)
        n = 1
        self.reward_spec = BoundedTensorSpec(
            minimum=-10, 
            maximum=1,
            dtype=torch.int64,
            shape=batch_size + (n,)
        )

In [None]:
env = TicTacToeEnv(batch_size=[1])
check_env_specs(env)

In [None]:
class ConsoleObserver:
    def __call__(self, env: EnvBase, state: TensorDict):
        print(state["board"])
        
observer = ConsoleObserver()

In [5]:
trajectory = env.rollout(10, callback=observer)

tensor([[[0, 0, 0],
         [0, 0, 0],
         [1, 0, 0]]])
tensor([[[0, 0, 0],
         [0, 0, 0],
         [1, 2, 0]]])
tensor([[[0, 0, 0],
         [0, 0, 0],
         [1, 2, 0]]])
tensor([[[0, 0, 0],
         [0, 0, 0],
         [1, 2, 0]]])
tensor([[[0, 1, 0],
         [0, 0, 0],
         [1, 2, 0]]])
tensor([[[0, 1, 0],
         [0, 0, 0],
         [1, 2, 0]]])
tensor([[[2, 1, 0],
         [0, 0, 0],
         [1, 2, 0]]])
tensor([[[2, 1, 0],
         [0, 0, 0],
         [1, 2, 1]]])
tensor([[[2, 1, 2],
         [0, 0, 0],
         [1, 2, 1]]])


In [6]:
trajectory["action"]

tensor([[6, 7, 6, 7, 1, 1, 0, 8, 2, 6]])

In [7]:
trajectory["next"]["reward"]

tensor([[[  0],
         [  0],
         [-10],
         [-10],
         [  0],
         [-10],
         [  0],
         [  0],
         [  0],
         [-10]]])

# Create a Policy

In [8]:
from torchrl.modules import MLP, EGreedyWrapper
from torchrl.data import OneHotDiscreteTensorSpec
import torchrl.modules.tensordict_module as td_module
from torchrl.envs import (
    TransformedEnv, 
    Compose,
    DoubleToFloat
)

In [9]:
n_act = 9
n_state = 9
n_inner = 20

qvalue_net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(n_state, n_inner, device=device),
    nn.LeakyReLU(),
    nn.Linear(n_inner, n_inner, device=device),
    nn.LeakyReLU(),
    nn.Linear(n_inner, n_act, device=device),
    nn.Softmax(),
)

In [10]:
tenv = TransformedEnv(
    env, 
    Compose(
        DoubleToFloat(in_keys=["board"])
    )
)

In [11]:
actor = td_module.QValueActor(qvalue_net, in_keys=["board"], action_space=env.action_spec)
stock_actor = EGreedyWrapper(actor, annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.5)

In [12]:
traj = tenv.rollout(10, policy=stock_actor)

  return forward_call(*args, **kwargs)


In [13]:
traj["action_value"]

tensor([[[0.0852, 0.1015, 0.1207, 0.0965, 0.1331, 0.1115, 0.1158, 0.1402,
          0.0956],
         [0.0891, 0.1048, 0.1157, 0.0951, 0.1355, 0.1105, 0.1130, 0.1385,
          0.0978],
         [0.0925, 0.0992, 0.1181, 0.0927, 0.1334, 0.1079, 0.1194, 0.1376,
          0.0991],
         [0.0925, 0.0992, 0.1181, 0.0927, 0.1334, 0.1079, 0.1194, 0.1376,
          0.0991],
         [0.0925, 0.0992, 0.1181, 0.0927, 0.1334, 0.1079, 0.1194, 0.1376,
          0.0991],
         [0.0925, 0.0992, 0.1181, 0.0927, 0.1334, 0.1079, 0.1194, 0.1376,
          0.0991],
         [0.0925, 0.0992, 0.1181, 0.0927, 0.1334, 0.1079, 0.1194, 0.1376,
          0.0991],
         [0.0946, 0.0996, 0.1167, 0.0934, 0.1310, 0.1057, 0.1205, 0.1400,
          0.0985],
         [0.0946, 0.0996, 0.1167, 0.0934, 0.1310, 0.1057, 0.1205, 0.1400,
          0.0985],
         [0.0919, 0.1059, 0.1112, 0.0916, 0.1331, 0.1124, 0.1102, 0.1423,
          0.1015]]], grad_fn=<StackBackward0>)

In [14]:
traj["action"]

tensor([[7, 4, 7, 7, 7, 7, 6, 7, 0, 7]])

In [15]:
actor(tenv.reset())

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        board: Tensor(shape=torch.Size([1, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        turn: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([1]),
    device=cpu,
    is_shared=False)

# Build a dataset from a policy

In [16]:
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from tqdm.notebook import tqdm_notebook as tqdm
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl.envs import set_exploration_type, ExplorationType

In [17]:
total_frames = 20000

collector = SyncDataCollector(
    tenv,
    stock_actor, 
    frames_per_batch=50,
    total_frames=total_frames,
)

loss_fn = DQNLoss(
    stock_actor,
    action_space=tenv.action_spec,
    delay_value=True,
)

updater = SoftUpdate(loss_fn, eps=0.95)
optim = torch.optim.Adam(stock_actor.parameters(), lr=3e-4)


In [18]:
pbar = tqdm(total=total_frames)

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(10_000), 
    batch_size=10,
)

utd = 16
for i, data in enumerate(collector):
    pbar.update(data.numel())
    rb.extend(data.squeeze(0).to_tensordict().cpu())
    for _ in range(utd):
        s = rb.sample().to(device)
        loss_value = loss_fn(s)
        loss_value["loss"].backward()
        optim.step()
        optim.zero_grad()
    
    stock_actor.step()
    updater.step()
    
    if i % 50 == 0:
        with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
            sim = tenv.rollout(10, stock_actor)
            re = sim["next", "reward"].to(torch.float32).mean().item()
            pbar.set_description(f"Average reward = {re:.2f}")
        


  0%|          | 0/20000 [00:00<?, ?it/s]

In [19]:
with set_exploration_type(ExplorationType.MODE):
    sim = tenv.rollout(10, stock_actor, ConsoleObserver())

tensor([[[0., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.]]])
tensor([[[0., 0., 0.],
         [1., 0., 2.],
         [0., 0., 0.]]])
tensor([[[0., 0., 0.],
         [1., 0., 2.],
         [0., 0., 1.]]])
tensor([[[0., 0., 0.],
         [1., 2., 2.],
         [0., 0., 1.]]])
tensor([[[0., 1., 0.],
         [1., 2., 2.],
         [0., 0., 1.]]])
tensor([[[0., 1., 2.],
         [1., 2., 2.],
         [0., 0., 1.]]])
tensor([[[0., 1., 2.],
         [1., 2., 2.],
         [0., 1., 1.]]])
tensor([[[2., 1., 2.],
         [1., 2., 2.],
         [0., 1., 1.]]])
tensor([[[2., 1., 2.],
         [1., 2., 2.],
         [0., 1., 1.]]])


In [20]:
print(sim["action"])
print(sim["next", "reward"])
print(sim["next", "done"])

tensor([[3, 5, 8, 4, 1, 2, 7, 0, 1, 1]])
tensor([[[  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [-10],
         [-10]]])
tensor([[[False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False]]])


In [21]:
sim[0:1, 9]["action"]=torch.tensor([4])
state = sim[0:1, 9]
state["action"] = torch.tensor([4])

In [22]:
res = tenv.step(state)

In [23]:
res["board"]

tensor([[[2., 1., 2.],
         [1., 2., 2.],
         [0., 1., 1.]]])

In [24]:
res["next", "done"]

tensor([[False]])

In [25]:
res["next", "board"]

tensor([[[2., 1., 2.],
         [1., 2., 2.],
         [0., 1., 1.]]])

In [34]:
def show(sample):
    for idx in range(sample.shape[0]):
        print("board:")
        print(sample[idx]["board"].numpy())
        print("action:")
        print(sample[idx]["action"].numpy())
        print("reward")
        print(sample[idx]["next", "reward"].numpy())


In [35]:
for i, data in enumerate(collector):
    print(f"----- {i} -----")
    if i == 10:
        break
    show(data.squeeze())


----- 0 -----
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
0
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
8
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [1. 2. 2.]
 [0. 2. 1.]]
acti

6
reward
[-10]
board:
[[0. 0. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
0
reward
[0]
board:
[[1. 0. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
2
reward
[-10]
board:
[[1. 0. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
2
reward
[-10]
board:
[[1. 0. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
1
reward
[0]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
3
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
8
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
3
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
4
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
0
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
8
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
4
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
1
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:
4
reward
[-10]
board:
[[1. 2. 1.]
 [2. 1. 0.]
 [2. 2. 1.]]
action:

[-1]
----- 9 -----
board:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
action:
3
reward
[0]
board:
[[0. 0. 0.]
 [1. 0. 0.]
 [0. 0. 0.]]
action:
5
reward
[0]
board:
[[0. 0. 0.]
 [1. 0. 2.]
 [0. 0. 0.]]
action:
0
reward
[0]
board:
[[1. 0. 0.]
 [1. 0. 2.]
 [0. 0. 0.]]
action:
2
reward
[0]
board:
[[1. 0. 2.]
 [1. 0. 2.]
 [0. 0. 0.]]
action:
8
reward
[0]
board:
[[1. 0. 2.]
 [1. 0. 2.]
 [0. 0. 1.]]
action:
1
reward
[0]
board:
[[1. 2. 2.]
 [1. 0. 2.]
 [0. 0. 1.]]
action:
0
reward
[-10]
board:
[[1. 2. 2.]
 [1. 0. 2.]
 [0. 0. 1.]]
action:
7
reward
[0]
board:
[[1. 2. 2.]
 [1. 0. 2.]
 [0. 1. 1.]]
action:
6
reward
[0]
board:
[[1. 2. 2.]
 [1. 0. 2.]
 [2. 1. 1.]]
action:
4
reward
[0]
board:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
action:
3
reward
[0]
board:
[[0. 0. 0.]
 [1. 0. 0.]
 [0. 0. 0.]]
action:
4
reward
[0]
board:
[[0. 0. 0.]
 [1. 2. 0.]
 [0. 0. 0.]]
action:
5
reward
[0]
board:
[[0. 0. 0.]
 [1. 2. 1.]
 [0. 0. 0.]]
action:
1
reward
[0]
board:
[[0. 2. 0.]
 [1. 2. 1.]
 [0. 0. 0.]]
action:
6
reward
[0]
boar