In [19]:
from torchrl.envs import EnvBase, ParallelEnv
from torchrl.envs.utils import check_env_specs
from torchrl.data import DiscreteTensorSpec, BinaryDiscreteTensorSpec, DiscreteTensorSpec, CompositeSpec, BoundedTensorSpec
from torchrl.modules import MultiAgentMLP
from tensordict import TensorDict
import torch.nn.functional as F
import torch
from torchrl.envs import set_exploration_type, ExplorationType
from typing import Optional

In [26]:
# Tic Tac Toe 
# Environment asks each agent in each turn for a move (even the agent that is not its turn).
# Environment accepts the move from active player and ignore the turn. 
# Regardless of which player is active, environment provided a reward for both agents

DEFAULT_DEVICE = 'cpu'
batch_size = [1]

class TicTacToe(EnvBase):    
    def __init__(self, seed=None, device=DEFAULT_DEVICE, *argv, **kwargs):
        super().__init__(*argv, device=device, **kwargs)
        self._make_spec()
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

        
    def _step(self, state_action: TensorDict):
        batch_size = state_action.shape
        # b x 1
        turns = state_action["turn"]
        new_turns = torch.clone(turns)
        
        # b x 1
        actions = state_action[self.action_key][:, 0] * (1 - turns) + state_action[self.action_key][:, 1] * turns

        # b x 2 x 9
        boards = torch.clone(state_action["board"])
        # b x 2
        rewards = torch.zeros(self.batch_size + (2, 1))

        # Check if the action points to a cell with zero value, otherwise it is an invalid move.
        is_valids = torch.sum(F.one_hot(actions, 9) * boards, dim=-1) == 0
        dones = torch.zeros(self.batch_size + (2, 1)).to(torch.bool)
        
        for idx in range(batch_size[0]):
            turn = turns[idx].item()
            action = actions[idx].item()
            is_valid = is_valids[idx, turn].item()
            if is_valid:
                
                boards[idx, turn, action] = 1
                boards[idx, 1-turn, action] = 2

                player_view = (boards[idx, turn] == 1).reshape(3, 3).to(torch.int)

                row_win = torch.sum((torch.sum(player_view, dim=-1) == 3).long()) > 0
                col_win = torch.sum((torch.sum(player_view, dim=-2) == 3).long()) > 0
                main_diag_win = (torch.trace(player_view) == 3).long()
                anti_diag_win = (torch.trace(torch.fliplr(player_view)) == 3).long()

                won = torch.sum(
                    (row_win + col_win + main_diag_win + anti_diag_win) > 0
                ).to(torch.float)

                rewards[idx, turn, 0] = won * 1
                rewards[idx, 1-turn, 0] = 0
                new_turns[idx] = 1 - turn
            else:
                rewards[idx, turn, 0] = 0

            dones[idx, :, 0] = (torch.sum(boards[idx, :, :] != 0, dim=-1) == 9)
        
        next_state = TensorDict({
                "board": boards,                        
                "reward": rewards,
                "turn": new_turns,
                "done": dones,
            },
            state_action.shape,
        )
        return next_state

    def _reset(self, tensordict: Optional[TensorDict]):   
        batch_size = self.batch_size
        return TensorDict(
            {
                "board": torch.zeros(batch_size + (2, 9)).long(),
                "turn": torch.zeros(batch_size).long(), 
                "done": torch.zeros(batch_size + (2, 1)).bool(),
            },
            batch_size=batch_size
        )
    
    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng

    def _make_spec(self):
        batch_size = self.batch_size
        self.observation_spec = CompositeSpec(
            {
                "board": BoundedTensorSpec(
                    minimum=0,
                    maximum=2,
                    shape=batch_size + (2, 9),
                    dtype=torch.int64,
                ),
                "turn": DiscreteTensorSpec(n=2, shape=batch_size)
            },
            shape=batch_size
        )
        
        self.state_spec = self.observation_spec.clone()
        self.action_spec = DiscreteTensorSpec(n=9, shape=batch_size + (2, ))
        self.reward_spec = BoundedTensorSpec(
            minimum=-10, 
            maximum=1,
            dtype=torch.float,
            shape=batch_size + (2, 1)
        )
        self.done_spec = DiscreteTensorSpec(n=2, shape=batch_size + (2, 1), dtype=torch.bool)
        

In [27]:
env = TicTacToe(batch_size=batch_size)
check_env_specs(env)

check_env_specs succeeded!


In [28]:
env.rollout(3)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 3, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        board: Tensor(shape=torch.Size([1, 3, 2, 9]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1, 3, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                board: Tensor(shape=torch.Size([1, 3, 2, 9]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([1, 3, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                reward: Tensor(shape=torch.Size([1, 3, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                turn: Tensor(shape=torch.Size([1, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([1, 3]),
            device=None,
            is_shared=False),
        turn: Tensor(shape=torch.Size([1, 3]), device=cpu, dtype=torch.int64, is_share

In [32]:
out = env.step(TensorDict(
    {
        "board": torch.Tensor([[
            [1, 1, 2, 
             2, 2, 1, 
             0, 2, 1],
            [2, 2, 1, 
             1, 1, 2, 
             0, 1, 2]
        ]]).long(),
        "turn": torch.ones(1).long(), 
        "action": torch.Tensor([[6, 6]]).long(),
    },
    batch_size=[1]
))



In [33]:
print(out["next", "done"].numpy())
print(out["next", "reward"].numpy())

[[[ True]
  [ True]]]
[[[0.]
  [1.]]]


# Create a policy

In [34]:
from torchrl.modules import MLP, EGreedyWrapper
from torchrl.data import OneHotDiscreteTensorSpec
import torchrl.modules.tensordict_module as td_module
from torchrl.envs import (
    TransformedEnv, 
    Compose,
    DoubleToFloat
)
import torch.nn.functional as F
import torch
from torch import nn


In [35]:
n_act = 9
n_state = 9
n_inner = 20

device = DEFAULT_DEVICE
qvalue_net = nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs=env.observation_spec["board"].shape[-1], 
        n_agent_outputs=9,
        n_agents=2,
        centralised=False, 
        share_params=False,
        device=device,
        depth=4,
        num_cells=32,
        activation_class=torch.nn.LeakyReLU
    ),
)

In [36]:
tenv = TransformedEnv(
    env, 
    Compose(
        DoubleToFloat(in_keys=["board"])
    )
)

actor = td_module.QValueActor(
    qvalue_net, 
    in_keys=["board"], 
    action_space=env.action_spec,
)
stock_actor = EGreedyWrapper(
    actor, 
    annealing_num_steps=1_000_000, 
    spec=env.action_spec, 
    eps_end=0.2
)

In [38]:
actor(tenv.reset())

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([1, 2, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        board: Tensor(shape=torch.Size([1, 2, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        turn: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

In [10]:
traj = tenv.rollout(10, policy=stock_actor)
print(traj)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1, 10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([1, 10, 2, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        board: Tensor(shape=torch.Size([1, 10, 2, 9]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1, 10, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1, 10, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                board: Tensor(shape=torch.Size([1, 10, 2, 9]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([1, 10, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                reward: Tensor(shape=torch.Size([1, 10, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                turn: Tensor(shape=torch.Size(

In [22]:
def visualize_traj(tenv, stock_actor, step_cnt=10, exploration_type=ExplorationType.MODE):
    with torch.no_grad(), set_exploration_type(exploration_type):
        traj = tenv.rollout(step_cnt, policy=stock_actor)
    for idx in range(tenv.base_env.batch_size[0]):
        for step in range(min(step_cnt, traj.shape[1])):
            state = traj[idx, step]
            turn = state["turn"].item()
            print("Next board:")
            print(state["next", "board"][0].reshape(3, 3).cpu().numpy())
            print("Trun:")
            print(turn)
            print("Action:")
            print(state["action"][turn].item())
            print("Reward:")
            print(state["next", "reward"].cpu().numpy())
            print("\n\n")

visualize_traj(tenv, stock_actor)

Next board:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]]
Trun:
0
Action:
8
Reward:
[[ 0.]
 [-0.]]



Next board:
[[0. 0. 0.]
 [0. 2. 0.]
 [0. 0. 1.]]
Trun:
1
Action:
4
Reward:
[[-0.]
 [ 0.]]



Next board:
[[0. 0. 0.]
 [0. 2. 0.]
 [1. 0. 1.]]
Trun:
0
Action:
6
Reward:
[[ 0.]
 [-0.]]



Next board:
[[0. 0. 2.]
 [0. 2. 0.]
 [1. 0. 1.]]
Trun:
1
Action:
2
Reward:
[[-0.]
 [ 0.]]



Next board:
[[0. 0. 2.]
 [0. 2. 1.]
 [1. 0. 1.]]
Trun:
0
Action:
5
Reward:
[[ 0.]
 [-0.]]



Next board:
[[2. 0. 2.]
 [0. 2. 1.]
 [1. 0. 1.]]
Trun:
1
Action:
0
Reward:
[[-0.]
 [ 0.]]



Next board:
[[2. 1. 2.]
 [0. 2. 1.]
 [1. 0. 1.]]
Trun:
0
Action:
1
Reward:
[[ 0.]
 [-0.]]



Next board:
[[2. 1. 2.]
 [2. 2. 1.]
 [1. 0. 1.]]
Trun:
1
Action:
3
Reward:
[[-0.]
 [ 0.]]



Next board:
[[2. 1. 2.]
 [2. 2. 1.]
 [1. 1. 1.]]
Trun:
0
Action:
7
Reward:
[[ 0.5]
 [-0.5]]





# Build data set and train a policy

In [12]:
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.objectives import DQNLoss, SoftUpdate
from tqdm.notebook import tqdm_notebook as tqdm

In [13]:
total_frames = 1000000

collector = SyncDataCollector(
    tenv, 
    stock_actor, 
    frames_per_batch=20,
    total_frames=total_frames,
    reset_at_each_iter=True,
)

loss_fn = DQNLoss(
    stock_actor, 
    action_space=tenv.action_spec,
    delay_value=True,
)

updater = SoftUpdate(
    loss_fn, eps=0.95
)

optim = torch.optim.Adam(stock_actor.parameters(), lr=1e-4)

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(400), 
    batch_size=10,
)

In [14]:
# check overfitting

data = next(iter(collector))

num_batches = 100
utd = 256

pbar = tqdm(total=num_batches)

for _ in range(num_batches):
    pbar.update(1)
    losses = []    
    for _ in range(utd):
        loss_value = loss_fn(data)
        loss_value["loss"].backward()
        losses.append(loss_value["loss"].item())
        optim.step()
        optim.zero_grad()
    avg_loss = sum(losses) / len(losses)
    pbar.set_description(f"Avg loss = {avg_loss:.6f}")

  0%|          | 0/100 [00:00<?, ?it/s]

In [15]:
def show(sample):
    for idx in range(sample.shape[0]):
        print("board:")
        print(sample[idx]["board"][0, :].reshape(3, 3).numpy())
        print("action:")
        print(sample[idx]["action"].numpy())
        print("reward")
        print(sample[idx]["next", "reward"].numpy())
        print("action_value")
        print(sample[idx]["action_value"].detach().numpy())
        print()
        print("----\n\n")
        
res = actor(data)
show(res.squeeze())


board:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
action:
[5 1]
reward
[[ 0.]
 [-0.]]
action_value
[[-0.17484301 -0.02507906  0.03223633 -0.20505914  0.08394265  0.21963191
   0.0335404   0.13915312 -0.08770215]
 [ 0.13596968  0.18114944  0.14345627  0.03492893  0.15403381  0.06375592
   0.1271474  -0.00805739 -0.03087601]]

----


board:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 1. 0.]]
action:
[5 1]
reward
[[-0.]
 [ 0.]]
action_value
[[-0.3368715  -0.05212075  0.09145431 -0.30340815  0.13105103  0.15367018
   0.14463265 -0.03134204 -0.11350302]
 [ 0.16112146  0.19017364  0.14953838  0.07242975  0.16655558  0.15610191
   0.15169232 -0.018997    0.07356079]]

----


board:
[[2. 0. 0.]
 [0. 0. 0.]
 [0. 1. 0.]]
action:
[2 1]
reward
[[-1.]
 [ 0.]]
action_value
[[-0.8553548  -0.0128549   0.14454132 -0.6102624   0.10912043 -0.2465922
   0.03302849 -0.8553393  -0.445301  ]
 [ 0.16942525  0.19958091  0.14828162  0.05101341  0.16112162  0.1602923
   0.14103222 -0.01092182  0.16111578]]

----


board:
[[2. 0. 0.]

In [16]:
print(loss_fn(data)["loss"])

tensor(0.3855, grad_fn=<MeanBackward0>)


In [20]:
pbar = tqdm(total=total_frames)

utd = 16

for i, data in enumerate(collector):
    pbar.update(data.numel())
    rb.extend(data.squeeze().to_tensordict().cpu())
    losses = []
    for _ in range(utd):
        s = rb.sample().to(device)
        loss_value = loss_fn(s)
        loss_value["loss"].backward()
        losses.append(loss_value["loss"].item())
        optim.step()
        optim.zero_grad()
    
    avg_loss = sum(losses) / len(losses)
    
    stock_actor.step()
    updater.step()
    
    if i % 50 == 0:
        with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
            sim = tenv.rollout(10, stock_actor)
            re = sim["next", "reward"].to(torch.float32).sum(dim=1).cpu().squeeze().numpy()
            pbar.set_description(f"Average reward = {re[0].item():.2f}, Avg loss = {avg_loss:.6f}")

  0%|          | 0/1000000 [00:00<?, ?it/s]

In [23]:
visualize_traj(tenv, stock_actor, step_cnt=20)

Next board:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]]
Trun:
0
Action:
8
Reward:
[[ 0.]
 [-0.]]



Next board:
[[0. 0. 0.]
 [0. 2. 0.]
 [0. 0. 1.]]
Trun:
1
Action:
4
Reward:
[[-0.]
 [ 0.]]



Next board:
[[0. 0. 0.]
 [0. 2. 0.]
 [1. 0. 1.]]
Trun:
0
Action:
6
Reward:
[[ 0.]
 [-0.]]



Next board:
[[0. 0. 2.]
 [0. 2. 0.]
 [1. 0. 1.]]
Trun:
1
Action:
2
Reward:
[[-0.]
 [ 0.]]



Next board:
[[0. 0. 2.]
 [0. 2. 1.]
 [1. 0. 1.]]
Trun:
0
Action:
5
Reward:
[[ 0.]
 [-0.]]



Next board:
[[2. 0. 2.]
 [0. 2. 1.]
 [1. 0. 1.]]
Trun:
1
Action:
0
Reward:
[[-0.]
 [ 0.]]



Next board:
[[2. 1. 2.]
 [0. 2. 1.]
 [1. 0. 1.]]
Trun:
0
Action:
1
Reward:
[[ 0.]
 [-0.]]



Next board:
[[2. 1. 2.]
 [2. 2. 1.]
 [1. 0. 1.]]
Trun:
1
Action:
3
Reward:
[[-0.]
 [ 0.]]



Next board:
[[2. 1. 2.]
 [2. 2. 1.]
 [1. 1. 1.]]
Trun:
0
Action:
7
Reward:
[[ 0.5]
 [-0.5]]





In [None]:
act = actor(tenv.reset())

In [None]:
act

In [None]:
sample_data = next(iter(collector))

rb.extend(sample_data.squeeze().to_tensordict().cpu())

In [None]:
for i in range(10):
    show(rb.sample().to(device))

In [None]:
for i, data in enumerate(collector):
    if i == 10:
        break
    show(data.squeeze())
    print(f"----- {i} -----")

In [None]:
show(data.squeeze())

In [None]:
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
#     print(stock_actor(data)["action"].numpy())
    print(qvalue_net(data["board"]))